├── .github └── workflows │ └── publish.yaml ├── .gitignore ├── .readthedocs.yaml ├── CITATION.cff ├── LICENSE ├── MANIFEST.in ├── README-zh.md ├── README.md ├── assets ├── Framework-Features-en.png ├── Framework-Features-zh.png ├── Framework-FlexRAGv3-en.png ├── Framework-FlexRAGv3-zh.png ├── ModularAssistant.png ├── Retrievers.png ├── WebRetriever.png ├── flexrag-wide.png ├── flexrag.png ├── gui_static.png └── parse_files.png ├── benchmarks ├── README.md └── singlehop_qa.md ├── docs ├── Makefile ├── make.bat ├── requirements.docs.txt └── source │ ├── conf.py │ ├── getting_started │ ├── installation.md │ ├── quickstart1.md │ └── quickstart2.md │ ├── index.rst │ ├── locales │ └── zh_CN │ │ └── LC_MESSAGES │ │ ├── getting_started │ │ ├── installation.po │ │ ├── quickstart1.po │ │ └── quickstart2.po │ │ ├── index.po │ │ ├── reference │ │ ├── assistant.po │ │ ├── chunking.po │ │ ├── common_dataclass.po │ │ ├── datasets.po │ │ ├── document_parser.po │ │ ├── encoders.po │ │ ├── generators.po │ │ ├── metrics.po │ │ ├── prompt.po │ │ ├── rankers.po │ │ ├── refiner.po │ │ ├── retrievers.po │ │ ├── text_process.po │ │ ├── tokenizers.po │ │ └── utils.po │ │ └── tutorial │ │ ├── building_assistant.po │ │ ├── entrypoints.po │ │ ├── preparing_corpus.po │ │ ├── preparing_retriever.po │ │ ├── preparing_web_retriever.po │ │ └── using_register.po │ ├── reference │ ├── assistant.rst │ ├── chunking.rst │ ├── common_dataclass.rst │ ├── datasets.rst │ ├── document_parser.rst │ ├── encoders.rst │ ├── generators.rst │ ├── metrics.rst │ ├── prompt.rst │ ├── rankers.rst │ ├── refiner.rst │ ├── retrievers.rst │ ├── text_process.rst │ ├── tokenizers.rst │ └── utils.rst │ └── tutorial │ ├── building_assistant.md │ ├── entrypoints.md │ ├── preparing_corpus.md │ ├── preparing_retriever.md │ ├── preparing_web_retriever.md │ └── using_register.md ├── pyproject.toml ├── pytest.ini ├── requirements.txt ├── setup.py ├── src └── flexrag │ ├── __init__.py │ ├── assistant │ ├── __init__.py │ ├── assistant.py │ ├── assistant_prompts │ │ ├── longform_generate_prompt_with_context.json │ │ ├── longform_generate_prompt_without_context.json │ │ ├── shortform_generate_prompt_with_context.json │ │ └── shortform_generate_prompt_without_context.json │ ├── basic_assistant.py │ ├── chatqa_assistant.py │ ├── document_chat_assistant.py │ ├── modular_rag_assistant.py │ └── online_assistant.py │ ├── cache │ ├── __init__.py │ ├── backends.py │ ├── persistent_cache.py │ └── serializer.py │ ├── chunking │ ├── __init__.py │ ├── basic_chunkers.py │ ├── chunker_base.py │ ├── semantic_chunker.py │ └── sentence_splitter.py │ ├── common_dataclass.py │ ├── context_refine │ ├── __init__.py │ ├── arranger.py │ ├── refiner.py │ └── summarizer.py │ ├── datasets │ ├── __init__.py │ ├── dataset.py │ ├── hf_dataset.py │ ├── line_delimited_dataset.py │ ├── rag_dataset.py │ └── retrieval_dataset.py │ ├── document_parser │ ├── __init__.py │ ├── docling_parser.py │ ├── document_parser_base.py │ └── markitdown_parser.py │ ├── entrypoints │ ├── __init__.py │ ├── assets │ │ ├── flexrag-wide.png │ │ ├── flexrag.png │ │ ├── robot.png │ │ └── user.png │ ├── cache.py │ ├── combine_outputs.py │ ├── prepare_corpus.py │ ├── prepare_index.py │ ├── rebuild_index.py │ ├── run_assistant.py │ ├── run_interactive.py │ └── run_retriever.py │ ├── metrics │ ├── __init__.py │ ├── evaluator.py │ ├── generation_metrics.py │ ├── lib_rel.cpp │ ├── matching_metrics.py │ ├── metrics_base.py │ ├── retrieval_metrics.py │ ├── xfinder.py │ └── xfinder_utils.py │ ├── models │ ├── __init__.py │ ├── anthropic_model.py │ ├── cohere_model.py │ ├── hf_model.py │ ├── jina_model.py │ ├── llamacpp_model.py │ ├── model_base.py │ ├── ollama_model.py │ ├── openai_model.py │ ├── sentence_transformers_model.py │ ├── tokenizer.py │ ├── utils.py │ └── vllm_model.py │ ├── prompt │ ├── __init__.py │ ├── prompt_base.py │ └── template.py │ ├── ranker │ ├── __init__.py │ ├── cohere_ranker.py │ ├── gpt_ranker.py │ ├── hf_ranker.py │ ├── jina_ranker.py │ ├── mixedbread_ranker.py │ ├── ranker.py │ ├── ranker_prompts │ │ └── rankgpt_prompt.json │ └── voyage_ranker.py │ ├── retriever │ ├── __init__.py │ ├── bm25s_retriever.py │ ├── dense_retriever.py │ ├── elastic_retriever.py │ ├── hyde_retriever.py │ ├── index │ │ ├── __init__.py │ │ ├── annoy_index.py │ │ ├── faiss_index.py │ │ ├── index_base.py │ │ └── scann_index.py │ ├── retriever_base.py │ ├── typesense_retriever.py │ └── web_retrievers │ │ ├── __init__.py │ │ ├── utils.py │ │ ├── web_downloader.py │ │ ├── web_reader.py │ │ ├── web_retriever.py │ │ ├── web_seeker.py │ │ └── wikipedia_retriever.py │ ├── text_process │ ├── __init__.py │ ├── basic_filters.py │ ├── basic_processors.py │ ├── normalize_tokens.py │ ├── pipeline.py │ ├── processor.py │ └── utils.py │ └── utils.py ├── tested_models.md └── tests ├── test_assistant.py ├── test_cache.py ├── test_chunker.py ├── test_data.py ├── test_model.py ├── test_ranker.py ├── test_retriver.py └── testcorp └── testcorp.jsonl /.github/workflows/publish.yaml: -------------------------------------------------------------------------------- 1 | name: Publish Python Package 2 | 3 | on: 4 | release: 5 | types: 6 | - published 7 | 8 | jobs: 9 | build_and_publish: 10 | name: Build wheels on ${{ matrix.os }} 11 | runs-on: ${{ matrix.os }} 12 | strategy: 13 | matrix: 14 | os: [ubuntu-latest, windows-latest] 15 | 16 | permissions: 17 | contents: read 18 | id-token: write 19 | 20 | steps: 21 | - name: Check out repository 22 | uses: actions/checkout@v4 23 | 24 | - name: Set up Python 25 | uses: actions/setup-python@v5 26 | with: 27 | python-version: 3.11 28 | 29 | - name: Install dependencies 30 | run: pip install setuptools wheel twine cibuildwheel 31 | 32 | - name: Build wheels 33 | run: python -m cibuildwheel --output-dir wheelhouse 34 | 35 | - name: Publish to PyPI 36 | env: 37 | TWINE_USERNAME: __token__ 38 | TWINE_PASSWORD: ${{ secrets.PYPI_TOKEN }} 39 | run: twine upload wheelhouse/*.whl 40 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .idea 11 | .Python 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | applications/DeepSpeed-Chat/data 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | pip-wheel-metadata/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | db.sqlite3-journal 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # IPython 83 | profile_default/ 84 | ipython_config.py 85 | 86 | # pyenv 87 | .python-version 88 | 89 | # pipenv 90 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 91 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 92 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 93 | # install all needed dependencies. 94 | #Pipfile.lock 95 | 96 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 97 | __pypackages__/ 98 | 99 | # Celery stuff 100 | celerybeat-schedule 101 | celerybeat.pid 102 | 103 | # SageMath parsed files 104 | *.sage.py 105 | 106 | # Environments 107 | .env 108 | .venv 109 | env/ 110 | venv/ 111 | ENV/ 112 | env.bak/ 113 | venv.bak/ 114 | 115 | # Spyder project settings 116 | .spyderproject 117 | .spyproject 118 | 119 | # Rope project settings 120 | .ropeproject 121 | 122 | # mkdocs documentation 123 | /site 124 | 125 | # mypy 126 | .mypy_cache/ 127 | .dmypy.json 128 | dmypy.json 129 | 130 | # Pyre type checker 131 | .pyre/ 132 | 133 | 134 | # vscode 135 | .vscode 136 | 137 | 138 | # third party models 139 | yala/model/third_party_models 140 | 141 | # aim 142 | .aim 143 | 144 | # test files 145 | _test*.py 146 | _test*.ipynb 147 | 148 | # experimental configs 149 | experimental_configs/ 150 | 151 | # hydra logs 152 | outputs/ 153 | 154 | # pytest configs 155 | tests/configs/ 156 | 157 | # cibuildwheel 158 | wheelhouse/ -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | # Read the Docs configuration file 2 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 3 | 4 | # Required 5 | version: 2 6 | 7 | # Set the OS, Python version, and other tools you might need 8 | build: 9 | os: ubuntu-24.04 10 | tools: 11 | python: "3.13" 12 | 13 | # Build documentation in the "docs/" directory with Sphinx 14 | sphinx: 15 | configuration: docs/source/conf.py 16 | # Optionally, but recommended, 17 | # declare the Python requirements required to build your documentation 18 | # See https://docs.readthedocs.io/en/stable/guides/reproducible-builds.html 19 | python: 20 | install: 21 | - requirements: docs/requirements.docs.txt 22 | -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.2.0 2 | message: "If you use this software, please cite it as below." 3 | authors: 4 | - family-names: "Zhang" 5 | given-names: "Zhuocheng" 6 | orcid: "https://orcid.org/0009-0004-4131-7968" 7 | - family-names: "Feng" 8 | given-names: "Yang" 9 | - family-names: "Zhang" 10 | given-names: "Min" 11 | title: "FlexRAG" 12 | doi: 10.5281/zenodo.14306983 13 | url: "https://github.com/ictnlp/FlexRAG" 14 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 zhangzhuocheng 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 6 | 7 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 8 | 9 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 10 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include LICENSE 2 | include README.md 3 | include CITATION.cff 4 | recursive-include tests * 5 | recursive-include assets * -------------------------------------------------------------------------------- /README-zh.md: -------------------------------------------------------------------------------- 1 |

2 | 3 |

4 | 5 | ![Language](https://img.shields.io/badge/language-python-brightgreen) 6 | [![Code Style](https://img.shields.io/badge/code%20style-black-black)](https://github.com/psf/black) 7 | [![Imports: isort](https://img.shields.io/badge/imports-isort-blue)](https://pycqa.github.io/isort/) 8 | [![github license](https://img.shields.io/github/license/ictnlp/FlexRAG)](LICENSE) 9 | [![Read the Docs](https://img.shields.io/badge/docs-English-green)](https://flexrag.readthedocs.io/en/latest/) 10 | [![Read the Docs](https://img.shields.io/badge/docs-Chinese-yellow)](https://flexrag.readthedocs.io/zh-cn/latest/) 11 | [![PyPI - Version](https://img.shields.io/pypi/v/flexrag)](https://pypi.org/project/flexrag/) 12 | [![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.14306983.svg)](https://doi.org/10.5281/zenodo.14306983) 13 | 14 |

15 | | 16 | 介绍视频 | 17 | README (english) | 18 | 文档 | 19 | 检索器 | 20 | 示例 21 | | 22 |

23 | 24 | FlexRAG 是一个创新的开源框架,旨在简化 RAG(检索增强生成)系统的快速复现、开发和评估。它全面支持多种 RAG 场景,包括 **基于文本的、多模态的以及可通过 Web 访问的 RAG** 。借助从数据准备到系统评估的**端到端流水线**,FlexRAG 能够帮助研究人员高效地与社区共享他们的工作,并快速基于自己的算法开发演示原型。 25 | 26 | # 📖 目录 27 | - [📖 目录](#-目录) 28 | - [✨ 框架特色](#-框架特色) 29 | - [📢 最新消息](#-最新消息) 30 | - [🚀 框架入门](#-框架入门) 31 | - [🏗️ FlexRAG 架构](#️-flexrag-架构) 32 | - [📊 基准测试](#-基准测试) 33 | - [🏷️ 许可证](#️-许可证) 34 | - [🖋️ 引用](#️-引用) 35 | - [❤️ 致谢](#️-致谢) 36 | 37 | 38 | # ✨ 框架特色 39 |

40 | 41 |

42 | 43 | # 📢 最新消息 44 | - **2025-03-24**: 中文文档上线啦!请访问 [文档](https://flexrag.readthedocs.io/zh-cn/latest/) 查看。 45 | - **2025-02-25**: FlexRAG 的 LocalRetriever 现在支持从 [HuggingFace Hub](https://huggingface.co/collections/ICTNLP/flexrag-retrievers-67b5373b70123669108a2e59) 上加载啦! 46 | - **2025-01-22**: 新的命令行入口 `run_retriever` 以及大量新的信息检索指标(如 `RetrievalMAP` )现已上线,请阅读[文档](https://flexrag.readthedocs.io/en/latest/)以获取更多信息。 47 | - **2025-01-08**: FlexRAG 现已支持 Windows 系统,您可以直接通过 `pip install flexrag` 来安装。 48 | - **2025-01-08**: FlexRAG 在单跳QA数据集上的基准测试现已公开,详情请参考 [benchmarks](benchmarks/README.md) 页面。 49 | - **2025-01-05**: FlexRAG 的[文档](https://flexrag.readthedocs.io/en/latest/)现已上线。 50 | 51 | # 🚀 框架入门 52 | 从 `pip` 安装 FlexRAG: 53 | ```bash 54 | pip install flexrag 55 | ``` 56 | 57 | 访问我们的[文档](https://flexrag.readthedocs.io/zh-cn/latest/)以了解更多信息。 58 | - [安装](https://flexrag.readthedocs.io/zh-cn/latest/getting_started/installation.html) 59 | - [快速入门](https://flexrag.readthedocs.io/zh-cn/latest/getting_started/quickstart1.html) 60 | - [教程](https://flexrag.readthedocs.io/zh-cn/latest/tutorial/preparing_corpus.html) 61 | 62 | # 🏗️ FlexRAG 架构 63 | FlexRAG 采用**模块化**架构设计,让您可以轻松定制和扩展框架以满足您的特定需求。下图说明了 FlexRAG 的架构: 64 |

65 | 66 |

67 | 68 | # 📊 基准测试 69 | 我们利用 FlexRAG 进行了大量的基准测试,详情请参考 [benchmarks](benchmarks/README.md) 页面。 70 | 71 | # 🏷️ 许可证 72 | 本仓库采用 **MIT License** 开源协议. 详情请参考 [LICENSE](LICENSE) 文件。 73 | 74 | # 🖋️ 引用 75 | 如果您在研究中使用了 FlexRAG,请引用我们的项目: 76 | ```bibtex 77 | @software{Zhang_FlexRAG_2025, 78 | author = {Zhang, Zhuocheng and Feng, Yang and Zhang, Min}, 79 | doi = {10.5281/zenodo.14593327}, 80 | month = jan, 81 | title = {{FlexRAG}}, 82 | url = {https://github.com/ictnlp/FlexRAG}, 83 | year = {2025} 84 | } 85 | ``` 86 | 87 | 88 | # ❤️ 致谢 89 | 下面的开源项目对本项目有所帮助: 90 | - [Faiss](https://github.com/facebookresearch/faiss) 91 | - [FlashRAG](https://github.com/RUC-NLPIR/FlashRAG) 92 | - [LanceDB](https://github.com/lancedb/lancedb) 93 | - [ANN Benchmarks](https://github.com/erikbern/ann-benchmarks) 94 | - [Chonkie](https://github.com/chonkie-ai/chonkie) 95 | - [rerankers](https://github.com/AnswerDotAI/rerankers) 96 | -------------------------------------------------------------------------------- /assets/Framework-Features-en.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ictnlp/FlexRAG/f00cf3a36890a2d56aec508a72b599c4681be3ca/assets/Framework-Features-en.png -------------------------------------------------------------------------------- /assets/Framework-Features-zh.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ictnlp/FlexRAG/f00cf3a36890a2d56aec508a72b599c4681be3ca/assets/Framework-Features-zh.png -------------------------------------------------------------------------------- /assets/Framework-FlexRAGv3-en.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ictnlp/FlexRAG/f00cf3a36890a2d56aec508a72b599c4681be3ca/assets/Framework-FlexRAGv3-en.png -------------------------------------------------------------------------------- /assets/Framework-FlexRAGv3-zh.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ictnlp/FlexRAG/f00cf3a36890a2d56aec508a72b599c4681be3ca/assets/Framework-FlexRAGv3-zh.png -------------------------------------------------------------------------------- /assets/ModularAssistant.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ictnlp/FlexRAG/f00cf3a36890a2d56aec508a72b599c4681be3ca/assets/ModularAssistant.png -------------------------------------------------------------------------------- /assets/Retrievers.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ictnlp/FlexRAG/f00cf3a36890a2d56aec508a72b599c4681be3ca/assets/Retrievers.png -------------------------------------------------------------------------------- /assets/WebRetriever.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ictnlp/FlexRAG/f00cf3a36890a2d56aec508a72b599c4681be3ca/assets/WebRetriever.png -------------------------------------------------------------------------------- /assets/flexrag-wide.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ictnlp/FlexRAG/f00cf3a36890a2d56aec508a72b599c4681be3ca/assets/flexrag-wide.png -------------------------------------------------------------------------------- /assets/flexrag.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ictnlp/FlexRAG/f00cf3a36890a2d56aec508a72b599c4681be3ca/assets/flexrag.png -------------------------------------------------------------------------------- /assets/gui_static.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ictnlp/FlexRAG/f00cf3a36890a2d56aec508a72b599c4681be3ca/assets/gui_static.png -------------------------------------------------------------------------------- /assets/parse_files.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ictnlp/FlexRAG/f00cf3a36890a2d56aec508a72b599c4681be3ca/assets/parse_files.png -------------------------------------------------------------------------------- /benchmarks/README.md: -------------------------------------------------------------------------------- 1 | # Benchmarks 2 | This directory contains benchmarks for the FlexRAG framework. We conduct these experiments for the following reasons: 3 | 1. We hope to help users gain a deeper understanding of the various components in RAG and determine their impact on the overall RAG system. 4 | 2. We aim to test the FlexRAG framework, which will help us reduce potential bugs. 5 | 6 | ## Directory Structure 7 | - [`singlehop_qa.md`](singlehop_qa.md): This file contains the benchmark results for the single-hop QA task. -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | %SPHINXBUILD% >NUL 2>NUL 14 | if errorlevel 9009 ( 15 | echo. 16 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 17 | echo.installed, then set the SPHINXBUILD environment variable to point 18 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 19 | echo.may add the Sphinx directory to PATH. 20 | echo. 21 | echo.If you don't have Sphinx installed, grab it from 22 | echo.https://www.sphinx-doc.org/ 23 | exit /b 1 24 | ) 25 | 26 | if "%1" == "" goto help 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/requirements.docs.txt: -------------------------------------------------------------------------------- 1 | # Requirements for Sphinx Documentation 2 | sphinx==8.1.3 3 | myst-parser==4.0.0 4 | sphinx-book-theme==1.1.3 5 | sphinx-copybutton==0.5.2 6 | # Requirements for Loading FlexRAG 7 | numpy<2.0.0 8 | tenacity 9 | hydra-core>=1.3 10 | omegaconf>=2.3.0 11 | pillow 12 | accelerate>=0.26.0 13 | rouge 14 | sacrebleu>=2.4.2 15 | openai>=1.30.1 16 | transformers>=4.44.0 17 | lmdb 18 | unidecode 19 | sacremoses 20 | pandas 21 | pylance 22 | bm25s 23 | elasticsearch>=8.14.0 24 | -f https://download.pytorch.org/whl/cpu 25 | torch>=2.3.0 26 | beautifulsoup4 27 | datasets 28 | pytrec_eval 29 | colorama 30 | # gradio>=5.8.0 -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # For the full list of built-in configuration values, see the documentation: 4 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 5 | 6 | # -- Project information ----------------------------------------------------- 7 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information 8 | import re 9 | import pathlib 10 | import sys 11 | 12 | sys.path.insert(0, str(pathlib.Path(__file__).parents[2] / "src")) 13 | 14 | 15 | def get_version() -> str: 16 | version_string_path = pathlib.Path(__file__).parents[2] / "src/flexrag/utils.py" 17 | with open(version_string_path, encoding="utf-8") as f: 18 | version = re.search(r"__VERSION__ = \"(.*?)\"", f.read()).group(1) 19 | return version 20 | 21 | 22 | project = "FlexRAG Documentation" 23 | html_short_title = "FlexRAG Documentation" 24 | copyright = "2025, ZhuochengZhang" 25 | author = "ZhuochengZhang" 26 | release = get_version() 27 | 28 | # -- General configuration --------------------------------------------------- 29 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration 30 | 31 | extensions = [ 32 | "sphinx.ext.napoleon", 33 | "sphinx.ext.autodoc", 34 | "sphinx.ext.viewcode", 35 | "sphinx.ext.autosectionlabel", 36 | "sphinx_copybutton", 37 | "myst_parser", 38 | ] 39 | 40 | templates_path = ["_templates"] 41 | exclude_patterns = [] 42 | 43 | 44 | # -- Options for HTML output ------------------------------------------------- 45 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output 46 | 47 | html_theme = "sphinx_book_theme" 48 | html_static_path = ["_static", "../../assets"] 49 | html_theme_options = { 50 | "path_to_docs": "docs/source", 51 | "repository_url": "https://github.com/ictnlp/flexrag", 52 | "use_repository_button": True, 53 | } 54 | 55 | # -- Options for autodoc ----------------------------------------------------- 56 | 57 | autodoc_mock_imports = [ 58 | "gradio", # as gradio has a lot of dependencies, we mock it to speed up building the docs. 59 | ] 60 | 61 | 62 | # -- Options for copybutton -------------------------------------------------- 63 | copybutton_prompt_text = r">>> |\.\.\. " 64 | copybutton_prompt_is_regexp = True 65 | 66 | # -- Options for autosectionlabel -------------------------------------------- 67 | autosectionlabel_prefix_document = False 68 | 69 | # -- Options for multi-language ------------------------------------------------- 70 | language = "en" 71 | locale_dirs = ["./locales/"] 72 | gettext_compact = False 73 | gettext_uuid = True 74 | -------------------------------------------------------------------------------- /docs/source/getting_started/installation.md: -------------------------------------------------------------------------------- 1 | # Installation 2 | FlexRAG is a Python package that can be installed via `pip` or from source. 3 | 4 | ```{eval-rst} 5 | .. important:: 6 | FlexRAG requires Python 3.11 or later. 7 | ``` 8 | 9 | ## Installation via `pip` 10 | To install FlexRAG via pip, run the following command: 11 | 12 | ```bash 13 | pip install flexrag 14 | ``` 15 | 16 | ## Installation from source 17 | Alternatively, to install FlexRAG from the source, follow the steps below: 18 | ```bash 19 | pip install pybind11 20 | 21 | git clone https://github.com/ictnlp/FlexRAG.git 22 | cd flexrag 23 | pip install ./ 24 | ``` 25 | You can also install the FlexRAG in *editable* mode with the `-e` flag. 26 | 27 | ## Installation flags 28 | FlexRAG can be installed with additional flags to enable specific features. The following flags are available: 29 | 30 | | Flag | pip install command | Description | 31 | | ---------- | ------------------------------- | --------------------------------------------------- | 32 | | scann | pip install flexrag[scann] | Install FlexRAG with the ScaNN index. | 33 | | annoy | pip install flexrag[annoy] | Install FlexRAG with the Annoy index. | 34 | | llamacpp | pip install flexrag[llamacpp] | Install FlexRAG with the LlamaCpp Generator. | 35 | | minference | pip install flexrag[minference] | Install FlexRAG with the Minference. | 36 | | web | pip install flexrag[web] | Install FlexRAG with the Web Retrievers. | 37 | | docs | pip install flexrag[docs] | Install FlexRAG with the Document Parser. | 38 | | all | pip install flexrag[all] | Install FlexRAG with most features. | 39 | | dev | pip install flexrag[dev] | Install FlexRAG with the libraries for development. | 40 | -------------------------------------------------------------------------------- /docs/source/getting_started/quickstart2.md: -------------------------------------------------------------------------------- 1 | # Quickstart: Building your own RAG application 2 | Besides using RAG assistant, you can also import FlexRAG as a library to develop your own RAG applications. FlexRAG provides a flexible and modular API that allows you to customize your RAG application with ease. For example, you can use the following code to build a simple command line RAG QA system: 3 | 4 | ```python 5 | from flexrag.models import OpenAIGenerator, OpenAIGeneratorConfig 6 | from flexrag.retriever import LocalRetriever 7 | 8 | 9 | def main(): 10 | # load the retriever 11 | retriever = LocalRetriever.load_from_hub("FlexRAG/wiki2021_atlas_bm25s") 12 | 13 | # load the generator 14 | generator = OpenAIGenerator( 15 | OpenAIGeneratorConfig( 16 | model_name="Qwen2-7B-Instruct", 17 | base_url="http://10.28.0.148:8000/v1", 18 | ) 19 | ) 20 | 21 | # build a QA loop 22 | while True: 23 | query = input("Please input your question (type /bye to quit): ") 24 | if query == "/bye": 25 | break 26 | # retrieve the contexts 27 | contexts = retriever.search(query, top_k=3)[0] 28 | # construct the prompt 29 | user_prompt = ( 30 | "Please answer the following question based on the given contexts.\n" 31 | f"Question: {query}\n" 32 | ) 33 | for i, ctx in enumerate(contexts): 34 | user_prompt += f"Context {i+1}: {ctx.data['text']}\n" 35 | # generate the response 36 | response = generator.chat([{"role": "user", "content": user_prompt}])[0][0] 37 | print(response) 38 | 39 | return 40 | 41 | 42 | if __name__ == "__main__": 43 | main() 44 | ``` -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | | 2 | 3 | .. image:: ../../assets/flexrag-wide.png 4 | :alt: FlexRAG 5 | :align: center 6 | 7 | | 8 | 9 | Welecome to FlexRAG Documentation 10 | ================================= 11 | 12 | FlexRAG is a highly reproducible, easy-to-use, and high-performance RAG framework designed for both research and application scenarios. It supports **text**, **multimodal**, and **web-based** RAG, providing a **complete RAG pipeline and evaluation process**. With built-in **asynchronous** processing and **persistent caching**, it ensures efficiency and scalability. Easily load retrievers from Hugging Face and quickly build powerful RAG solutions out of the box. 13 | 14 | .. note:: 15 | FlexRAG is under active development and is currently in the **alpha** stage. We welcome contributions from the community and are open to feedback and suggestions. 16 | 17 | .. Check out the :doc:`tutorial` section for further information, including how to 18 | .. :ref:`install ` the project. 19 | 20 | 21 | .. toctree:: 22 | :maxdepth: 1 23 | :caption: Getting Started: 24 | 25 | getting_started/installation 26 | getting_started/quickstart1 27 | getting_started/quickstart2 28 | 29 | .. toctree:: 30 | :maxdepth: 1 31 | :caption: Tutorial: 32 | 33 | tutorial/preparing_corpus 34 | tutorial/preparing_retriever 35 | tutorial/building_assistant 36 | tutorial/entrypoints 37 | tutorial/using_register 38 | tutorial/preparing_web_retriever 39 | 40 | .. toctree:: 41 | :maxdepth: 1 42 | :caption: API Reference Manual: 43 | 44 | reference/assistant 45 | reference/chunking 46 | reference/common_dataclass 47 | reference/refiner 48 | reference/datasets 49 | reference/document_parser 50 | reference/encoders 51 | reference/generators 52 | reference/metrics 53 | reference/prompt 54 | reference/retrievers 55 | reference/rankers 56 | reference/tokenizers 57 | reference/text_process 58 | reference/utils 59 | -------------------------------------------------------------------------------- /docs/source/locales/zh_CN/LC_MESSAGES/getting_started/quickstart2.po: -------------------------------------------------------------------------------- 1 | # SOME DESCRIPTIVE TITLE. 2 | # Copyright (C) 2025, ZhuochengZhang 3 | # This file is distributed under the same license as the FlexRAG 4 | # Documentation package. 5 | # FIRST AUTHOR , 2025. 6 | # 7 | #, fuzzy 8 | msgid "" 9 | msgstr "" 10 | "Project-Id-Version: FlexRAG Documentation \n" 11 | "Report-Msgid-Bugs-To: \n" 12 | "POT-Creation-Date: 2025-03-21 20:53+0800\n" 13 | "PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n" 14 | "Last-Translator: FULL NAME \n" 15 | "Language: zh_CN\n" 16 | "Language-Team: zh_CN \n" 17 | "Plural-Forms: nplurals=1; plural=0;\n" 18 | "MIME-Version: 1.0\n" 19 | "Content-Type: text/plain; charset=utf-8\n" 20 | "Content-Transfer-Encoding: 8bit\n" 21 | "Generated-By: Babel 2.16.0\n" 22 | 23 | #: ../../source/getting_started/quickstart2.md:1 24 | #: 24f56c005d924c62801a30eefe48f5d7 25 | msgid "Quickstart: Building your own RAG application" 26 | msgstr "快速入门:构建您自己的 RAG 应用" 27 | 28 | #: ../../source/getting_started/quickstart2.md:2 29 | #: e3b83077d47f4da4bcee4584d9354d13 30 | msgid "" 31 | "Besides using RAG assistant, you can also import FlexRAG as a library to " 32 | "develop your own RAG applications. FlexRAG provides a flexible and " 33 | "modular API that allows you to customize your RAG application with ease. " 34 | "For example, you can use the following code to build a simple command " 35 | "line RAG QA system:" 36 | msgstr "" 37 | "除去使用 RAG 助手这一概念,您也可以将 FlexRAG 作为库来开发您的 RAG 应用。" 38 | "FlexRAG 提供了一个灵活且模块化的 API 来帮助您构建 RAG 应用。" 39 | "下面的代码向您展示了如何构建一个简单的命令行检索式 QA 系统:" 40 | 41 | -------------------------------------------------------------------------------- /docs/source/locales/zh_CN/LC_MESSAGES/index.po: -------------------------------------------------------------------------------- 1 | # SOME DESCRIPTIVE TITLE. 2 | # Copyright (C) 2025, ZhuochengZhang 3 | # This file is distributed under the same license as the FlexRAG 4 | # Documentation package. 5 | # FIRST AUTHOR , 2025. 6 | # 7 | #, fuzzy 8 | msgid "" 9 | msgstr "" 10 | "Project-Id-Version: FlexRAG Documentation \n" 11 | "Report-Msgid-Bugs-To: \n" 12 | "POT-Creation-Date: 2025-03-25 10:28+0800\n" 13 | "PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n" 14 | "Last-Translator: FULL NAME \n" 15 | "Language: zh_CN\n" 16 | "Language-Team: zh_CN \n" 17 | "Plural-Forms: nplurals=1; plural=0;\n" 18 | "MIME-Version: 1.0\n" 19 | "Content-Type: text/plain; charset=utf-8\n" 20 | "Content-Transfer-Encoding: 8bit\n" 21 | "Generated-By: Babel 2.16.0\n" 22 | 23 | #: ../../source/index.rst:21 24 | msgid "Getting Started:" 25 | msgstr "入门:" 26 | 27 | #: ../../source/index.rst:29 28 | msgid "Tutorial:" 29 | msgstr "教程:" 30 | 31 | #: ../../source/index.rst:39 32 | msgid "API Reference Manual:" 33 | msgstr "API参考文档:" 34 | 35 | #: ../../source/index.rst:3 036d9484788041d39f31231b7144b122 36 | msgid "FlexRAG" 37 | msgstr "FlexRAG" 38 | 39 | #: ../../source/index.rst:10 1fd4801699a34995aa1755e9f089e498 40 | msgid "Welecome to FlexRAG Documentation" 41 | msgstr "欢迎访问 FlexRAG 文档" 42 | 43 | #: ../../source/index.rst:12 cd90c2b11e6f411fafdc5d885e302d25 44 | msgid "" 45 | "FlexRAG is a highly reproducible, easy-to-use, and high-performance RAG " 46 | "framework designed for both research and application scenarios. It " 47 | "supports **text**, **multimodal**, and **web-based** RAG, providing a " 48 | "**complete RAG pipeline and evaluation process**. With built-in " 49 | "**asynchronous** processing and **persistent caching**, it ensures " 50 | "efficiency and scalability. Easily load retrievers from Hugging Face and " 51 | "quickly build powerful RAG solutions out of the box." 52 | msgstr "" 53 | "FlexRAG 是一个高可复现、易上手、性能优越的 RAG 框架,专为科研与应用场景设计。它支持 **文本、多模态以及网络 " 54 | "RAG** ,提供 **完整的 RAG 流水线与评估流程** ,开箱即用,同时具备高效的 **异步处理与持久化缓存** 能力。通过 Hugging Face " 55 | "轻松加载 Retriever,助力快速搭建强大的 RAG 解决方案。" 56 | 57 | #: ../../source/index.rst:15 3152a79ea0b34b70b433632a7280d1b1 58 | msgid "" 59 | "FlexRAG is under active development and is currently in the **alpha** " 60 | "stage. We welcome contributions from the community and are open to " 61 | "feedback and suggestions." 62 | msgstr "FlexRAG 目前正处于活跃开发的 **alpha** 阶段。我们非常欢迎来自于社区的贡献、反馈或建议。" 63 | 64 | #~ msgid "" 65 | #~ "FlexRAG is a flexible and high-" 66 | #~ "performance framework designed for " 67 | #~ "Retrieval-Augmented Generation (RAG) tasks, " 68 | #~ "offering support for multimodal data, " 69 | #~ "seamless configuration management, and out-" 70 | #~ "of-the-box performance for both " 71 | #~ "research and prototyping." 72 | #~ msgstr "FlexRAG 是一个灵活的高性能检索增强生成 (RAG) 框架,支持多模态数据以及无缝的配置管理,为研究和原型系统设计提供开箱即用的支持。" 73 | 74 | -------------------------------------------------------------------------------- /docs/source/locales/zh_CN/LC_MESSAGES/reference/common_dataclass.po: -------------------------------------------------------------------------------- 1 | # SOME DESCRIPTIVE TITLE. 2 | # Copyright (C) 2025, ZhuochengZhang 3 | # This file is distributed under the same license as the FlexRAG 4 | # Documentation package. 5 | # FIRST AUTHOR , 2025. 6 | # 7 | #, fuzzy 8 | msgid "" 9 | msgstr "" 10 | "Project-Id-Version: FlexRAG Documentation \n" 11 | "Report-Msgid-Bugs-To: \n" 12 | "POT-Creation-Date: 2025-03-19 16:54+0800\n" 13 | "PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n" 14 | "Last-Translator: FULL NAME \n" 15 | "Language: zh_CN\n" 16 | "Language-Team: zh_CN \n" 17 | "Plural-Forms: nplurals=1; plural=0;\n" 18 | "MIME-Version: 1.0\n" 19 | "Content-Type: text/plain; charset=utf-8\n" 20 | "Content-Transfer-Encoding: 8bit\n" 21 | "Generated-By: Babel 2.16.0\n" 22 | 23 | #: ../../source/reference/common_dataclass.rst:2 24 | #: a19baf5fcc2c4dd6ab490be232845b4a 25 | msgid "Common Dataclass" 26 | msgstr "" 27 | 28 | #: ../../source/reference/common_dataclass.rst:3 29 | #: 1c2ff1e4446e4df8934f948945708ba4 30 | msgid "" 31 | "This module provides several pre-defined dataclasses that are commonly " 32 | "used in the project." 33 | msgstr "" 34 | 35 | #: 6f7d2329532f44668bc7396446318953 bb1caf9a3761461ca518abb0f868280c 36 | #: flexrag.common_dataclass.Context:1 37 | #: flexrag.common_dataclass.RetrievedContext:1 of 38 | msgid "The dataclass for retrieved context." 39 | msgstr "" 40 | 41 | #: ../../source/reference/common_dataclass.rst 3a6ac042dba747e79cf02646da47ed08 42 | #: 451a02ea6b084d1dbefe5b42ce700e75 a7299aa836b54388a9c3b5bf05a10e7e 43 | #: d06c09ba34ae4c8f97186e56162c3cfd 44 | msgid "Parameters" 45 | msgstr "" 46 | 47 | #: c8ba7a884918444abca911a60c350250 flexrag.common_dataclass.Context:3 of 48 | msgid "The unique identifier of the context. Default: None." 49 | msgstr "" 50 | 51 | #: 0f42874e10a549d1a8c7e53515818e70 flexrag.common_dataclass.Context:5 of 52 | msgid "The context data. Default: {}." 53 | msgstr "" 54 | 55 | #: c75ed32b080545568c648a9a94a76f22 flexrag.common_dataclass.Context:7 of 56 | msgid "The source of the retrieved data. Default: None." 57 | msgstr "" 58 | 59 | #: 3098971d8305433ab2eae75a0481dd28 flexrag.common_dataclass.Context:9 of 60 | msgid "The metadata of the context. Default: {}." 61 | msgstr "" 62 | 63 | #: a25525af6c5547eea661cf145a3d3e7b flexrag.common_dataclass.RetrievedContext:3 64 | #: of 65 | msgid "The name of the retriever. Required." 66 | msgstr "" 67 | 68 | #: e6b598a94438440c8ea864c936636baf flexrag.common_dataclass.RetrievedContext:5 69 | #: of 70 | msgid "The query for retrieval. Required." 71 | msgstr "" 72 | 73 | #: f027bec0202942e895b6849aed67ecd3 flexrag.common_dataclass.RetrievedContext:7 74 | #: of 75 | msgid "The relevance score of the retrieved data. Default: 0.0." 76 | msgstr "" 77 | 78 | #: 6e64ac7b12d64287878456cf58eb4802 flexrag.common_dataclass.RAGEvalData:1 of 79 | msgid "The dataclass for RAG evaluation data." 80 | msgstr "" 81 | 82 | #: 45774b4d2e7747b48630e143844e4ee8 bc41b1dfd24148a9aeafbe7f9323332c 83 | #: flexrag.common_dataclass.IREvalData:3 flexrag.common_dataclass.RAGEvalData:3 84 | #: of 85 | msgid "The question for evaluation. Required." 86 | msgstr "" 87 | 88 | #: 5ed64fae32744097b1a22a1b48b42283 61a97b8a7bb343079faf21e67c2597c2 89 | #: flexrag.common_dataclass.IREvalData:5 flexrag.common_dataclass.RAGEvalData:5 90 | #: of 91 | msgid "The contexts related to the question. Default: None." 92 | msgstr "" 93 | 94 | #: 253237efa5e14e83838d4dd43f4524eb flexrag.common_dataclass.RAGEvalData:7 of 95 | msgid "The golden answers for the question. Default: None." 96 | msgstr "" 97 | 98 | #: 74f3665080794970b5d8e6431f9db399 f96337eee3a84a29ae2511771a095513 99 | #: flexrag.common_dataclass.IREvalData:7 flexrag.common_dataclass.RAGEvalData:9 100 | #: of 101 | msgid "The metadata of the evaluation data. Default: {}." 102 | msgstr "" 103 | 104 | #: 8665df545a604c22a992a078e3a45e94 flexrag.common_dataclass.IREvalData:1 of 105 | msgid "The dataclass for Information Retrieval evaluation data." 106 | msgstr "" 107 | 108 | -------------------------------------------------------------------------------- /docs/source/locales/zh_CN/LC_MESSAGES/reference/document_parser.po: -------------------------------------------------------------------------------- 1 | # SOME DESCRIPTIVE TITLE. 2 | # Copyright (C) 2025, ZhuochengZhang 3 | # This file is distributed under the same license as the FlexRAG 4 | # Documentation package. 5 | # FIRST AUTHOR , 2025. 6 | # 7 | #, fuzzy 8 | msgid "" 9 | msgstr "" 10 | "Project-Id-Version: FlexRAG Documentation \n" 11 | "Report-Msgid-Bugs-To: \n" 12 | "POT-Creation-Date: 2025-03-19 16:54+0800\n" 13 | "PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n" 14 | "Last-Translator: FULL NAME \n" 15 | "Language: zh_CN\n" 16 | "Language-Team: zh_CN \n" 17 | "Plural-Forms: nplurals=1; plural=0;\n" 18 | "MIME-Version: 1.0\n" 19 | "Content-Type: text/plain; charset=utf-8\n" 20 | "Content-Transfer-Encoding: 8bit\n" 21 | "Generated-By: Babel 2.16.0\n" 22 | 23 | #: ../../source/reference/document_parser.rst:2 24 | #: f5192c1f7fdb435c93337dc462ddfea2 25 | msgid "Document Parser" 26 | msgstr "" 27 | 28 | #: ../../source/reference/document_parser.rst:3 29 | #: a800ad80627a4cc197c8f82929ee88a3 30 | msgid "" 31 | "This module provides a set of classes and functions for parsing a " 32 | "formated document (such as PDF, Word, etc.) into a structured format." 33 | msgstr "" 34 | 35 | #: 9054257c79d94d72b8201d4dd3b6aed3 36 | #: flexrag.document_parser.document_parser_base.Document:1 of 37 | msgid "A document parsed by a DocumentParser." 38 | msgstr "" 39 | 40 | #: 48bb0cf6c9fb48a89550231171ea40bb 41 | #: flexrag.document_parser.document_parser_base.DocumentParserBase.parse:1 of 42 | msgid "Parse the document at the given path." 43 | msgstr "" 44 | 45 | #: ../../source/reference/document_parser.rst 054f7ec99e0c4dbc94b396574d62f3ee 46 | msgid "Parameters" 47 | msgstr "" 48 | 49 | #: 3fb9e3c019b84636b78dbf15fa78adcb 50 | #: flexrag.document_parser.document_parser_base.DocumentParserBase.parse:3 of 51 | msgid "The path to the document to parse." 52 | msgstr "" 53 | 54 | #: ../../source/reference/document_parser.rst 13ada7ac5c144e8e859b78bd0b28273a 55 | msgid "Returns" 56 | msgstr "" 57 | 58 | #: 6b907db4872841a6bfa98da0f6f6d29e 59 | #: flexrag.document_parser.document_parser_base.DocumentParserBase.parse:5 of 60 | msgid "The parsed document." 61 | msgstr "" 62 | 63 | #: ../../source/reference/document_parser.rst 91a749f7d3d44ff08e7cfee65cc270e3 64 | msgid "Return type" 65 | msgstr "" 66 | 67 | #: 88cffbd67e54483e930e18ef9389d0f4 d65e8fd3edb54480a8172de2bc088151 68 | #: flexrag.document_parser.docling_parser.DoclingParser:1 69 | #: flexrag.document_parser.markitdown_parser.MarkItDownParser:1 of 70 | msgid "" 71 | "Bases: " 72 | ":py:class:`~flexrag.document_parser.document_parser_base.DocumentParserBase`" 73 | msgstr "" 74 | 75 | -------------------------------------------------------------------------------- /docs/source/locales/zh_CN/LC_MESSAGES/reference/prompt.po: -------------------------------------------------------------------------------- 1 | # SOME DESCRIPTIVE TITLE. 2 | # Copyright (C) 2025, ZhuochengZhang 3 | # This file is distributed under the same license as the FlexRAG 4 | # Documentation package. 5 | # FIRST AUTHOR , 2025. 6 | # 7 | #, fuzzy 8 | msgid "" 9 | msgstr "" 10 | "Project-Id-Version: FlexRAG Documentation \n" 11 | "Report-Msgid-Bugs-To: \n" 12 | "POT-Creation-Date: 2025-03-19 16:54+0800\n" 13 | "PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n" 14 | "Last-Translator: FULL NAME \n" 15 | "Language: zh_CN\n" 16 | "Language-Team: zh_CN \n" 17 | "Plural-Forms: nplurals=1; plural=0;\n" 18 | "MIME-Version: 1.0\n" 19 | "Content-Type: text/plain; charset=utf-8\n" 20 | "Content-Transfer-Encoding: 8bit\n" 21 | "Generated-By: Babel 2.16.0\n" 22 | 23 | #: ../../source/reference/prompt.rst:2 66bc65f470714a2ebb8ce73e6ecb1663 24 | msgid "Prompt" 25 | msgstr "" 26 | 27 | #: ../../source/reference/prompt.rst:3 b05bc7889b6c43ec853dab6ac39faaa4 28 | msgid "" 29 | "This module provides two classes namely `ChatPrompt` and `ChatTemplate`. " 30 | "The `ChatPrompt` is used to store the system prompt, chat history, and " 31 | "demonstrations used to interact with the `Generator`. The `ChatTemplate` " 32 | "is used to convert the `ChatPrompt` into a string or a list of tokens " 33 | "that can be used by the model." 34 | msgstr "" 35 | 36 | #: ../../source/reference/prompt.rst:6 d7ec38b189c844d5a8aaa543cc9d5f8d 37 | msgid "Chat Prompt" 38 | msgstr "" 39 | 40 | #: 2962809d7ac14f2c97c440883b21f7e7 41 | #: flexrag.prompt.prompt_base.MultiModelChatPrompt:1 of 42 | msgid "" 43 | "This class shares almost all the methods with ChatPrompt. However, the " 44 | "Generics in Python does not support calling the TypeVar's classmethod. So" 45 | " we have to duplicate the code here." 46 | msgstr "" 47 | 48 | #: ../../source/reference/prompt.rst:26 ffdd017ff3be4a70af01cec9c4c4f828 49 | msgid "Template" 50 | msgstr "" 51 | 52 | #: d8172e9b7fd64f8193b875942ef3e901 flexrag.prompt.template.HFTemplate:1 of 53 | msgid "Bases: :py:class:`~flexrag.prompt.template.ChatTemplate`" 54 | msgstr "" 55 | 56 | #: a83d04b2aa2c49579f0ada3e332eae0a flexrag.prompt.template.load_template:1 of 57 | msgid "" 58 | "Load ChatTemplate for different models. If model_name is not provided, " 59 | "the default template in the Tokenizer will be used." 60 | msgstr "" 61 | 62 | #: ../../source/reference/prompt.rst ce3e83f32e27441680ec2e04227fc204 63 | msgid "Parameters" 64 | msgstr "" 65 | 66 | #: ebcd3a792ddd4b6c8b3d4002fa1b81d0 flexrag.prompt.template.load_template:3 of 67 | msgid "The tokenizer used to encode the prompt." 68 | msgstr "" 69 | 70 | #: 8462d283d9a148fbba18869b2486d3ea flexrag.prompt.template.load_template:4 of 71 | msgid "The name of the model. Default is None." 72 | msgstr "" 73 | 74 | #: ../../source/reference/prompt.rst e763f1c2a17245aea63e6c059e72dcb9 75 | msgid "Returns" 76 | msgstr "" 77 | 78 | #: ebce22c7ffb54b5294101bc34713bc15 flexrag.prompt.template.load_template:7 of 79 | msgid "The loaded ChatTemplate" 80 | msgstr "" 81 | 82 | #: ../../source/reference/prompt.rst b225b7df4078422ca91df70a3c4b0fc6 83 | msgid "Return type" 84 | msgstr "" 85 | 86 | -------------------------------------------------------------------------------- /docs/source/locales/zh_CN/LC_MESSAGES/reference/text_process.po: -------------------------------------------------------------------------------- 1 | # SOME DESCRIPTIVE TITLE. 2 | # Copyright (C) 2025, ZhuochengZhang 3 | # This file is distributed under the same license as the FlexRAG 4 | # Documentation package. 5 | # FIRST AUTHOR , 2025. 6 | # 7 | #, fuzzy 8 | msgid "" 9 | msgstr "" 10 | "Project-Id-Version: FlexRAG Documentation \n" 11 | "Report-Msgid-Bugs-To: \n" 12 | "POT-Creation-Date: 2025-03-19 16:54+0800\n" 13 | "PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n" 14 | "Last-Translator: FULL NAME \n" 15 | "Language: zh_CN\n" 16 | "Language-Team: zh_CN \n" 17 | "Plural-Forms: nplurals=1; plural=0;\n" 18 | "MIME-Version: 1.0\n" 19 | "Content-Type: text/plain; charset=utf-8\n" 20 | "Content-Transfer-Encoding: 8bit\n" 21 | "Generated-By: Babel 2.16.0\n" 22 | 23 | #: ../../source/reference/text_process.rst:2 bcafd7856e7f47549f7854330be2a356 24 | msgid "Text Processing" 25 | msgstr "" 26 | 27 | #: ../../source/reference/text_process.rst:3 6e070e6188914335a88aba4c0859e75a 28 | msgid "" 29 | "This module provides a set of classes and functions for preprocessing and" 30 | " filtering texts, including normalization, length filtering, etc." 31 | msgstr "" 32 | 33 | #: 917b95e3e8e944ac92ed02d7b0e4398d 34 | #: flexrag.text_process.processor.Processor.__call__:1 of 35 | msgid "" 36 | "Process the input text. If the processor has been filtered, the reserved " 37 | "flag of the input TextUnit will be set to False." 38 | msgstr "" 39 | 40 | #: ../../source/reference/text_process.rst c66af0aa305d4d899831600d8aae2abf 41 | #: cdc6caf16dfc4972add779030801799b 42 | msgid "Parameters" 43 | msgstr "" 44 | 45 | #: 941f41a569eb4a1990b87bfd5dc43b79 46 | #: flexrag.text_process.processor.Processor.__call__:4 of 47 | msgid "The input text to process." 48 | msgstr "" 49 | 50 | #: ../../source/reference/text_process.rst 89a47ae85377414db31cd68e160a560d 51 | msgid "Returns" 52 | msgstr "" 53 | 54 | #: 3bd6f2ef3122417894a12434795eb3fc 55 | #: flexrag.text_process.processor.Processor.__call__:6 of 56 | msgid "The processed text." 57 | msgstr "" 58 | 59 | #: ../../source/reference/text_process.rst 282fe9bb2d4d43f89d1e2121f3485155 60 | msgid "Return type" 61 | msgstr "" 62 | 63 | #: 9e8d7b2691e24a01bb225cbe77a1f7d1 of types.TextProcessPipelineConfig:1 64 | msgid "" 65 | "Configuration class for processor (name: TextProcessPipelineConfig, " 66 | "default: ???)." 67 | msgstr "" 68 | 69 | #: bb2876a2f865424b924c65545a9c6cc2 of types.TextProcessPipelineConfig:3 70 | msgid "The processor type to use." 71 | msgstr "" 72 | 73 | #: 8cfe929bef594bcdba3e43eea71b21df of types.TextProcessPipelineConfig:5 74 | msgid "The config for LengthFilter." 75 | msgstr "" 76 | 77 | #: 51df3dd1e06b4c9183e6c122db090a16 of types.TextProcessPipelineConfig:7 78 | msgid "The config for TokenNormalizer." 79 | msgstr "" 80 | 81 | #: 59d68c7cf96342119bbff9cb855ba3ed of types.TextProcessPipelineConfig:9 82 | msgid "The config for Truncator." 83 | msgstr "" 84 | 85 | #: 14b63f7efb154a029c233a74a1b7aedf 2c1227cc8ac74dbf8e00e1779f917223 86 | #: 3c75a8b2130941a3a996d9257f580ec3 4b23451799734d3c8259b93dc1fc695f 87 | #: 833fd71b2baa41cd89a4ff8bda7db581 89c0c8b6f0f044f49b519e08069613ba 88 | #: e4a8cc0a00ce498d81029cd80bd400c8 89 | #: flexrag.text_process.basic_filters.ExactDeduplicate:1 90 | #: flexrag.text_process.basic_processors.AnswerSimplifier:1 91 | #: flexrag.text_process.basic_processors.ChineseSimplifier:1 92 | #: flexrag.text_process.basic_processors.Lowercase:1 93 | #: flexrag.text_process.basic_processors.TokenNormalizer:1 94 | #: flexrag.text_process.basic_processors.Truncator:1 95 | #: flexrag.text_process.basic_processors.Unifier:1 of 96 | msgid "Bases: :py:class:`~flexrag.text_process.processor.Processor`" 97 | msgstr "" 98 | 99 | -------------------------------------------------------------------------------- /docs/source/reference/assistant.rst: -------------------------------------------------------------------------------- 1 | Assistant 2 | ========= 3 | 4 | The ``Assistant`` class serves as an abstraction for Retrieval-Augmented Generation (RAG) behavior. It takes the user's query as input and returns an appropriate response. This class provides a flexible interface for defining how the assistant handles queries, including whether a retrieval step is required, how the retrieval should be conducted, and how the assistant generates the response based on the retrieved information. 5 | 6 | The Assistant Interface 7 | ----------------------- 8 | ``AssistantBase`` is the base class for all assistants. It provides a simple interface for answering a user query. The answering process is controlled by a configuration object that is passed to the assistant's constructor. 9 | 10 | .. autoclass:: flexrag.assistant.AssistantBase 11 | :members: 12 | :inherited-members: 13 | 14 | FlexRAG Assistants 15 | ------------------ 16 | FlexRAG provides several assistant implementations that can be used out of the box. These implementations are designed to be flexible and extensible, allowing users to customize the assistant's behavior by providing their own retrieval and generation components. 17 | 18 | .. autoclass:: flexrag.assistant.BasicAssistantConfig 19 | :members: 20 | :show-inheritance: 21 | 22 | .. autoclass:: flexrag.assistant.BasicAssistant 23 | :members: 24 | :show-inheritance: 25 | :exclude-members: answer 26 | 27 | .. autoclass:: flexrag.assistant.ModularAssistantConfig 28 | :members: 29 | :show-inheritance: 30 | 31 | .. autoclass:: flexrag.assistant.ModularAssistant 32 | :members: 33 | :show-inheritance: 34 | :exclude-members: answer, search, answer_with_contexts 35 | 36 | .. autoclass:: flexrag.assistant.ChatQAAssistant 37 | :members: 38 | :show-inheritance: 39 | :exclude-members: answer, get_formatted_input, answer_with_contexts -------------------------------------------------------------------------------- /docs/source/reference/chunking.rst: -------------------------------------------------------------------------------- 1 | Chunking 2 | ======== 3 | This module provides a set of classes for chunking a long text into smaller chunks. 4 | 5 | 6 | The Chunker Interface 7 | --------------------- 8 | `ChunkerBase` is the base class for all chunkers. 9 | It provides a simple interface for chunking a text into smaller chunks. 10 | The chunking process is controlled by a configuration object that is passed to the chunker's constructor. 11 | 12 | .. autoclass:: flexrag.chunking.ChunkerBase 13 | :members: 14 | :inherited-members: 15 | 16 | 17 | Chunkers 18 | -------- 19 | 20 | .. autoclass:: flexrag.chunking.CharChunkerConfig 21 | :members: 22 | :inherited-members: 23 | 24 | .. autoclass:: flexrag.chunking.CharChunker 25 | :members: 26 | :show-inheritance: 27 | 28 | .. autoclass:: flexrag.chunking.TokenChunkerConfig 29 | :members: 30 | :inherited-members: 31 | :show-inheritance: 32 | 33 | .. autoclass:: flexrag.chunking.TokenChunker 34 | :members: 35 | :show-inheritance: 36 | 37 | .. autoclass:: flexrag.chunking.RecursiveChunkerConfig 38 | :members: 39 | :inherited-members: 40 | :show-inheritance: 41 | 42 | .. autoclass:: flexrag.chunking.RecursiveChunker 43 | :members: 44 | :show-inheritance: 45 | 46 | .. autoclass:: flexrag.chunking.SentenceChunkerConfig 47 | :members: 48 | :inherited-members: 49 | :show-inheritance: 50 | 51 | .. autoclass:: flexrag.chunking.SentenceChunker 52 | :members: 53 | :show-inheritance: 54 | 55 | .. autoclass:: flexrag.chunking.SemanticChunkerConfig 56 | :members: 57 | :inherited-members: 58 | :show-inheritance: 59 | 60 | .. autoclass:: flexrag.chunking.SemanticChunker 61 | :members: 62 | :show-inheritance: 63 | 64 | 65 | Sentence Splitters 66 | ------------------ 67 | This submodule provides a set of useful tools for splitting a text into sentences. 68 | 69 | .. autoclass:: flexrag.chunking.sentence_splitter.SentenceSplitterBase 70 | :members: 71 | :inherited-members: 72 | 73 | .. autoclass:: flexrag.chunking.sentence_splitter.NLTKSentenceSplitterConfig 74 | :members: 75 | :inherited-members: 76 | 77 | .. autoclass:: flexrag.chunking.sentence_splitter.NLTKSentenceSplitter 78 | :members: 79 | :show-inheritance: 80 | 81 | .. autoclass:: flexrag.chunking.sentence_splitter.RegexSplitterConfig 82 | :members: 83 | :inherited-members: 84 | 85 | .. autoclass:: flexrag.chunking.sentence_splitter.RegexSplitter 86 | :members: 87 | :show-inheritance: 88 | 89 | .. autoclass:: flexrag.chunking.sentence_splitter.SpacySentenceSplitterConfig 90 | :members: 91 | :inherited-members: 92 | 93 | .. autoclass:: flexrag.chunking.sentence_splitter.SpacySentenceSplitter 94 | :members: 95 | :show-inheritance: 96 | 97 | .. autoattribute:: flexrag.chunking.sentence_splitter.PREDEFINED_SPLIT_PATTERNS 98 | 99 | A dictionary of predefined sentence splitting patterns. 100 | The keys are the names of the patterns, and the values are the corresponding regular expressions. 101 | Currently, ``FlexRAG`` provides 2 sets of predefined patterns: "en" for English and "zh" for Chinese. 102 | Please refer to the source code for more details. 103 | 104 | General Configuration 105 | --------------------- 106 | The configuration provides a general interface for loading and configurate the chunker or the sentence splitter. 107 | 108 | .. autoclass:: flexrag.chunking.ChunkerConfig 109 | :members: 110 | :inherited-members: 111 | 112 | .. autoclass:: flexrag.chunking.sentence_splitter.SentenceSplitterConfig 113 | :members: 114 | :inherited-members: 115 | -------------------------------------------------------------------------------- /docs/source/reference/common_dataclass.rst: -------------------------------------------------------------------------------- 1 | Common Dataclass 2 | ================ 3 | This module provides several pre-defined dataclasses that are commonly used in the project. 4 | 5 | .. autoclass:: flexrag.common_dataclass.Context 6 | :members: 7 | :inherited-members: 8 | 9 | .. autoclass:: flexrag.common_dataclass.RetrievedContext 10 | :members: 11 | :inherited-members: 12 | 13 | .. autoclass:: flexrag.common_dataclass.RAGEvalData 14 | :members: 15 | :inherited-members: 16 | 17 | .. autoclass:: flexrag.common_dataclass.IREvalData 18 | :members: 19 | :inherited-members: 20 | -------------------------------------------------------------------------------- /docs/source/reference/datasets.rst: -------------------------------------------------------------------------------- 1 | Datasets 2 | ======== 3 | This module provides a set of classes and functions for loading and processing datasets. 4 | 5 | .. BaseClasses 6 | .. autoclass:: flexrag.datasets.IterableDataset 7 | :members: 8 | :inherited-members: 9 | :show-inheritance: 10 | 11 | .. autoclass:: flexrag.datasets.MappingDataset 12 | :members: 13 | :inherited-members: 14 | :show-inheritance: 15 | 16 | .. autoclass:: flexrag.datasets.ChainDataset 17 | :members: 18 | :inherited-members: 19 | :show-inheritance: 20 | 21 | .. autoclass:: flexrag.datasets.ConcatDataset 22 | :members: 23 | :inherited-members: 24 | :show-inheritance: 25 | 26 | .. HuggingFace Datasets 27 | .. autoclass:: flexrag.datasets.HFDatasetConfig 28 | :members: 29 | :inherited-members: 30 | :show-inheritance: 31 | 32 | .. autoclass:: flexrag.datasets.HFDataset 33 | :members: 34 | :inherited-members: 35 | :show-inheritance: 36 | 37 | .. Line Delimited Dataset 38 | .. autoclass:: flexrag.datasets.LineDelimitedDatasetConfig 39 | :members: 40 | :inherited-members: 41 | :show-inheritance: 42 | 43 | .. autoclass:: flexrag.datasets.LineDelimitedDataset 44 | :members: 45 | :inherited-members: 46 | :show-inheritance: 47 | 48 | .. RAG Datasets 49 | .. _RAGEvalDatasetConfig: 50 | 51 | .. autoclass:: flexrag.datasets.RAGEvalDatasetConfig 52 | :members: 53 | :inherited-members: 54 | :show-inheritance: 55 | 56 | .. autoclass:: flexrag.datasets.RAGEvalDataset 57 | :members: 58 | :inherited-members: 59 | :show-inheritance: 60 | 61 | .. autoclass:: flexrag.datasets.RAGCorpusDatasetConfig 62 | :members: 63 | :inherited-members: 64 | :show-inheritance: 65 | 66 | .. autoclass:: flexrag.datasets.RAGCorpusDataset 67 | :members: 68 | :inherited-members: 69 | :show-inheritance: 70 | 71 | .. Retrieval Datasets 72 | .. autoclass:: flexrag.datasets.MTEBDatasetConfig 73 | :members: 74 | :inherited-members: 75 | :show-inheritance: 76 | 77 | .. autoclass:: flexrag.datasets.MTEBDataset 78 | :members: 79 | :inherited-members: 80 | :show-inheritance: 81 | -------------------------------------------------------------------------------- /docs/source/reference/document_parser.rst: -------------------------------------------------------------------------------- 1 | Document Parser 2 | =============== 3 | This module provides a set of classes and functions for parsing a formated document (such as PDF, Word, etc.) into a structured format. 4 | 5 | .. autoclass:: flexrag.document_parser.Document 6 | :members: 7 | :inherited-members: 8 | 9 | .. autoclass:: flexrag.document_parser.DocumentParserBase 10 | :members: 11 | :inherited-members: 12 | 13 | .. autoclass:: flexrag.document_parser.DoclingConfig 14 | :members: 15 | :inherited-members: 16 | 17 | .. autoclass:: flexrag.document_parser.DoclingParser 18 | :members: 19 | :show-inheritance: 20 | :exclude-members: parse 21 | 22 | .. autoclass:: flexrag.document_parser.MarkItDownParser 23 | :members: 24 | :show-inheritance: 25 | :exclude-members: parse -------------------------------------------------------------------------------- /docs/source/reference/encoders.rst: -------------------------------------------------------------------------------- 1 | Encoders 2 | ======== 3 | 4 | .. autoclass:: flexrag.models.EncoderBase 5 | :members: 6 | :inherited-members: 7 | 8 | .. autoclass:: flexrag.models.EncoderConfig 9 | :members: 10 | :inherited-members: 11 | 12 | 13 | Local Encoders 14 | -------------- 15 | 16 | .. Huggingface Encoders 17 | .. autoclass:: flexrag.models.HFEncoderConfig 18 | :members: 19 | :inherited-members: 20 | 21 | .. autoclass:: flexrag.models.HFEncoder 22 | :members: 23 | :show-inheritance: 24 | :exclude-members: async_encode, encode 25 | 26 | 27 | .. Huggingface Clip Encoders 28 | .. autoclass:: flexrag.models.HFClipEncoderConfig 29 | :members: 30 | :inherited-members: 31 | 32 | .. autoclass:: flexrag.models.HFClipEncoder 33 | :members: 34 | :show-inheritance: 35 | :exclude-members: async_encode, encode 36 | 37 | 38 | .. Ollama Encoders 39 | .. autoclass:: flexrag.models.OllamaEncoderConfig 40 | :members: 41 | :inherited-members: 42 | 43 | .. autoclass:: flexrag.models.OllamaEncoder 44 | :members: 45 | :show-inheritance: 46 | :exclude-members: async_encode, encode 47 | 48 | 49 | .. Sentence Transformers Encoders 50 | .. autoclass:: flexrag.models.SentenceTransformerEncoderConfig 51 | :members: 52 | :inherited-members: 53 | 54 | .. autoclass:: flexrag.models.SentenceTransformerEncoder 55 | :members: 56 | :show-inheritance: 57 | :exclude-members: async_encode, encode 58 | 59 | 60 | Oneline Encoders 61 | ---------------- 62 | 63 | .. Coherence Encoders 64 | .. autoclass:: flexrag.models.CohereEncoderConfig 65 | :members: 66 | :inherited-members: 67 | 68 | .. autoclass:: flexrag.models.CohereEncoder 69 | :members: 70 | :show-inheritance: 71 | :exclude-members: async_encode, encode 72 | 73 | 74 | .. JinaAI Encoders 75 | .. autoclass:: flexrag.models.JinaEncoderConfig 76 | :members: 77 | :inherited-members: 78 | 79 | .. autoclass:: flexrag.models.JinaEncoder 80 | :members: 81 | :show-inheritance: 82 | :exclude-members: async_encode, encode 83 | 84 | 85 | .. OpenAI Encoders 86 | .. autoclass:: flexrag.models.OpenAIEncoderConfig 87 | :members: 88 | :show-inheritance: 89 | :inherited-members: 90 | 91 | .. autoclass:: flexrag.models.OpenAIEncoder 92 | :members: 93 | :show-inheritance: 94 | :exclude-members: async_encode, encode -------------------------------------------------------------------------------- /docs/source/reference/generators.rst: -------------------------------------------------------------------------------- 1 | Generators 2 | ========== 3 | 4 | .. autoclass:: flexrag.models.GeneratorBase 5 | :members: 6 | :inherited-members: 7 | 8 | 9 | .. autoclass:: flexrag.models.GenerationConfig 10 | :members: 11 | :inherited-members: 12 | 13 | .. autoclass:: flexrag.models.GeneratorConfig 14 | :members: 15 | :inherited-members: 16 | 17 | 18 | Local Generators 19 | ---------------- 20 | 21 | .. Hugging Face Generators 22 | .. autoclass:: flexrag.models.HFModelConfig 23 | :members: 24 | :inherited-members: 25 | 26 | 27 | .. autoclass:: flexrag.models.HFGeneratorConfig 28 | :members: 29 | :show-inheritance: 30 | 31 | .. autoclass:: flexrag.models.HFGenerator 32 | :members: 33 | :show-inheritance: 34 | :exclude-members: async_chat, async_generate, chat, generate 35 | 36 | 37 | .. Llamacpp Generators 38 | .. autoclass:: flexrag.models.LlamacppGeneratorConfig 39 | :members: 40 | :inherited-members: 41 | 42 | .. autoclass:: flexrag.models.LlamacppGenerator 43 | :members: 44 | :show-inheritance: 45 | :exclude-members: async_chat, async_generate, chat, generate 46 | 47 | 48 | .. Ollama Generators 49 | .. autoclass:: flexrag.models.OllamaGeneratorConfig 50 | :members: 51 | :inherited-members: 52 | 53 | .. autoclass:: flexrag.models.OllamaGenerator 54 | :members: 55 | :show-inheritance: 56 | :exclude-members: async_chat, async_generate, chat, generate 57 | 58 | .. VLLM Generators 59 | .. autoclass:: flexrag.models.VLLMGeneratorConfig 60 | :members: 61 | :inherited-members: 62 | 63 | .. autoclass:: flexrag.models.VLLMGenerator 64 | :members: 65 | :show-inheritance: 66 | :exclude-members: async_chat, async_generate, chat, generate 67 | 68 | 69 | Online Generators 70 | ----------------- 71 | 72 | .. Anthropic Generators 73 | .. autoclass:: flexrag.models.AnthropicGeneratorConfig 74 | :members: 75 | :inherited-members: 76 | 77 | .. autoclass:: flexrag.models.AnthropicGenerator 78 | :members: 79 | :show-inheritance: 80 | :exclude-members: async_chat, async_generate, chat, generate 81 | 82 | .. OpenAI Generators 83 | .. autoclass:: flexrag.models.OpenAIConfig 84 | :members: 85 | :show-inheritance: 86 | :inherited-members: 87 | 88 | .. autoclass:: flexrag.models.OpenAIGeneratorConfig 89 | :members: 90 | :show-inheritance: 91 | 92 | .. autoclass:: flexrag.models.OpenAIGenerator 93 | :members: 94 | :show-inheritance: 95 | :exclude-members: async_chat, async_generate, chat, generate 96 | 97 | 98 | Visual Language Model Generators 99 | -------------------------------- 100 | 101 | .. autoclass:: flexrag.models.VLMGeneratorBase 102 | :members: 103 | :inherited-members: 104 | :show-inheritance: 105 | 106 | .. HF VLM Generators 107 | .. autoclass:: flexrag.models.HFVLMGeneratorConfig 108 | :members: 109 | :show-inheritance: 110 | :inherited-members: 111 | 112 | .. autoclass:: flexrag.models.HFVLMGenerator 113 | :members: 114 | :show-inheritance: 115 | :exclude-members: chat, generate -------------------------------------------------------------------------------- /docs/source/reference/metrics.rst: -------------------------------------------------------------------------------- 1 | Metrics 2 | ======= 3 | 4 | This module contains functions for evaluating the performance of a RAG assistant or a retriever. 5 | 6 | .. autoclass:: flexrag.metrics.MetricsBase 7 | :members: 8 | :inherited-members: 9 | 10 | 11 | Helper Class 12 | ------------ 13 | The RAGEvaluator takes a list of metrics and evaluates the performance of a RAG assistant or a retriever. 14 | 15 | .. autoclass:: flexrag.metrics.EvaluatorConfig 16 | :members: 17 | :inherited-members: 18 | 19 | .. autoclass:: flexrag.metrics.Evaluator 20 | :members: 21 | :show-inheritance: 22 | 23 | 24 | RAG Generation Metrics 25 | ---------------------- 26 | 27 | .. autoclass:: flexrag.metrics.BLEUConfig 28 | :members: 29 | :inherited-members: 30 | 31 | .. autoclass:: flexrag.metrics.BLEU 32 | :members: 33 | :show-inheritance: 34 | :exclude-members: compute 35 | 36 | .. autoclass:: flexrag.metrics.Rouge 37 | :members: 38 | :show-inheritance: 39 | 40 | .. autoclass:: flexrag.metrics.chrFConfig 41 | :members: 42 | :inherited-members: 43 | 44 | .. autoclass:: flexrag.metrics.chrF 45 | :members: 46 | :show-inheritance: 47 | :exclude-members: compute 48 | 49 | .. autoclass:: flexrag.metrics.F1 50 | :members: 51 | :show-inheritance: 52 | 53 | .. autoclass:: flexrag.metrics.Accuracy 54 | :members: 55 | :show-inheritance: 56 | 57 | .. autoclass:: flexrag.metrics.ExactMatch 58 | :members: 59 | :show-inheritance: 60 | 61 | .. autoclass:: flexrag.metrics.Precision 62 | :members: 63 | :show-inheritance: 64 | 65 | .. autoclass:: flexrag.metrics.Recall 66 | :members: 67 | :show-inheritance: 68 | 69 | Information Retrieval Metrics 70 | ----------------------------- 71 | 72 | .. autoclass:: flexrag.metrics.SuccessRateConfig 73 | :members: 74 | :inherited-members: 75 | 76 | .. autoclass:: flexrag.metrics.SuccessRate 77 | :members: 78 | :show-inheritance: 79 | :exclude-members: compute 80 | 81 | .. autoclass:: flexrag.metrics.RetrievalRecallConfig 82 | :members: 83 | :inherited-members: 84 | 85 | .. autoclass:: flexrag.metrics.RetrievalRecall 86 | :members: 87 | :show-inheritance: 88 | :exclude-members: compute 89 | 90 | .. autoclass:: flexrag.metrics.RetrievalPrecisionConfig 91 | :members: 92 | :inherited-members: 93 | 94 | .. autoclass:: flexrag.metrics.RetrievalPrecision 95 | :members: 96 | :show-inheritance: 97 | :exclude-members: compute 98 | 99 | .. autoclass:: flexrag.metrics.RetrievalMAPConfig 100 | :members: 101 | :inherited-members: 102 | 103 | .. autoclass:: flexrag.metrics.RetrievalMAP 104 | :members: 105 | :show-inheritance: 106 | :exclude-members: compute 107 | 108 | .. autoclass:: flexrag.metrics.RetrievalNDCGConfig 109 | :members: 110 | :inherited-members: 111 | 112 | .. autoclass:: flexrag.metrics.RetrievalNDCG 113 | :members: 114 | :show-inheritance: 115 | :exclude-members: compute 116 | -------------------------------------------------------------------------------- /docs/source/reference/prompt.rst: -------------------------------------------------------------------------------- 1 | Prompt 2 | ====== 3 | This module provides two classes namely `ChatPrompt` and `ChatTemplate`. The `ChatPrompt` is used to store the system prompt, chat history, and demonstrations used to interact with the `Generator`. The `ChatTemplate` is used to convert the `ChatPrompt` into a string or a list of tokens that can be used by the model. 4 | 5 | Chat Prompt 6 | ----------- 7 | 8 | .. autoclass:: flexrag.prompt.ChatTurn 9 | :members: 10 | :inherited-members: 11 | 12 | .. autoclass:: flexrag.prompt.ChatPrompt 13 | :members: 14 | :inherited-members: 15 | 16 | .. autoclass:: flexrag.prompt.MultiModelChatTurn 17 | :members: 18 | :inherited-members: 19 | 20 | .. autoclass:: flexrag.prompt.MultiModelChatPrompt 21 | :members: 22 | :inherited-members: 23 | 24 | 25 | Template 26 | -------- 27 | 28 | .. autoclass:: flexrag.prompt.ChatTemplate 29 | :members: 30 | :inherited-members: 31 | 32 | .. autoclass:: flexrag.prompt.HFTemplate 33 | :members: 34 | :show-inheritance: 35 | 36 | .. autofunction:: flexrag.prompt.load_template 37 | -------------------------------------------------------------------------------- /docs/source/reference/rankers.rst: -------------------------------------------------------------------------------- 1 | Rankers 2 | ======= 3 | 4 | The ranker is the component that determines the order of the results returned by the retriever. FlexRAG provides several rankers that can be used to sort the results based on various criteria. 5 | 6 | .. autoclass:: flexrag.ranker.RankerBaseConfig 7 | :members: 8 | :inherited-members: 9 | 10 | .. autoclass:: flexrag.ranker.RankerBase 11 | :members: 12 | :inherited-members: 13 | 14 | .. autoclass:: flexrag.ranker.RankingResult 15 | :members: 16 | :inherited-members: 17 | 18 | .. autoclass:: flexrag.ranker.RankerConfig 19 | :members: 20 | :inherited-members: 21 | 22 | 23 | Local Ranker 24 | ------------ 25 | .. HF Cross Encoder Ranker 26 | .. autoclass:: flexrag.ranker.HFCrossEncoderRankerConfig 27 | :members: 28 | :inherited-members: 29 | :show-inheritance: 30 | 31 | .. autoclass:: flexrag.ranker.HFCrossEncoderRanker 32 | :members: 33 | :show-inheritance: 34 | 35 | 36 | .. HF Cross Seq2Seq Ranker 37 | .. autoclass:: flexrag.ranker.HFSeq2SeqRankerConfig 38 | :members: 39 | :inherited-members: 40 | :show-inheritance: 41 | 42 | .. autoclass:: flexrag.ranker.HFSeq2SeqRanker 43 | :members: 44 | :show-inheritance: 45 | 46 | 47 | .. HF Cross ColBERT Ranker 48 | .. autoclass:: flexrag.ranker.HFColBertRankerConfig 49 | :members: 50 | :inherited-members: 51 | :show-inheritance: 52 | 53 | .. autoclass:: flexrag.ranker.HFColBertRanker 54 | :members: 55 | :show-inheritance: 56 | 57 | 58 | .. RankGPT Ranker 59 | .. autoclass:: flexrag.ranker.RankGPTRankerConfig 60 | :members: 61 | :inherited-members: 62 | :show-inheritance: 63 | 64 | .. autoclass:: flexrag.ranker.RankGPTRanker 65 | :members: 66 | :show-inheritance: 67 | 68 | 69 | Oneline Ranker 70 | -------------- 71 | .. Cohere Ranker 72 | .. autoclass:: flexrag.ranker.CohereRankerConfig 73 | :members: 74 | :inherited-members: 75 | :show-inheritance: 76 | 77 | .. autoclass:: flexrag.ranker.CohereRanker 78 | :members: 79 | :show-inheritance: 80 | 81 | 82 | .. Jina Ranker 83 | .. autoclass:: flexrag.ranker.JinaRankerConfig 84 | :members: 85 | :inherited-members: 86 | :show-inheritance: 87 | 88 | .. autoclass:: flexrag.ranker.JinaRanker 89 | :members: 90 | :show-inheritance: 91 | 92 | 93 | .. Mixedbread Ranker 94 | .. autoclass:: flexrag.ranker.MixedbreadRankerConfig 95 | :members: 96 | :inherited-members: 97 | :show-inheritance: 98 | 99 | .. autoclass:: flexrag.ranker.MixedbreadRanker 100 | :members: 101 | :show-inheritance: 102 | 103 | 104 | .. Voyage Ranker 105 | .. autoclass:: flexrag.ranker.VoyageRankerConfig 106 | :members: 107 | :inherited-members: 108 | :show-inheritance: 109 | 110 | .. autoclass:: flexrag.ranker.VoyageRanker 111 | :members: 112 | :show-inheritance: 113 | -------------------------------------------------------------------------------- /docs/source/reference/refiner.rst: -------------------------------------------------------------------------------- 1 | Context Refiner 2 | =============== 3 | The context refiner is responsible for refining the contexts retrieved by the retriever. 4 | It can be used to rearrange the contexts, summarize them, or extract the most relevant information from them. 5 | 6 | The Context Refiner Interface 7 | ----------------------------- 8 | The `RefinerBase` is the base class for all refiners. 9 | It provides the basic interface for refining the contexts retrieved by the retriever. 10 | 11 | .. autoclass:: flexrag.context_refine.RefinerBase 12 | :members: 13 | :inherited-members: 14 | 15 | Refiners 16 | -------- 17 | FlexRAG provides several refiners that can be used to refine the contexts retrieved by the retriever. 18 | 19 | .. autoclass:: flexrag.context_refine.ContextArrangerConfig 20 | :members: 21 | :inherited-members: 22 | 23 | .. autoclass:: flexrag.context_refine.ContextArranger 24 | :members: 25 | :show-inheritance: 26 | 27 | .. autoclass:: flexrag.context_refine.AbstractiveSummarizerConfig 28 | :members: 29 | :inherited-members: 30 | 31 | .. autoclass:: flexrag.context_refine.AbstractiveSummarizer 32 | :members: 33 | :show-inheritance: 34 | 35 | .. autoclass:: flexrag.context_refine.RecompExtractiveSummarizerConfig 36 | :members: 37 | :inherited-members: 38 | 39 | .. autoclass:: flexrag.context_refine.RecompExtractiveSummarizer 40 | :members: 41 | :show-inheritance: -------------------------------------------------------------------------------- /docs/source/reference/text_process.rst: -------------------------------------------------------------------------------- 1 | Text Processing 2 | =============== 3 | This module provides a set of classes and functions for preprocessing and filtering texts, including normalization, length filtering, etc. 4 | 5 | .. autoclass:: flexrag.text_process.TextUnit 6 | :members: 7 | :inherited-members: 8 | 9 | .. autoclass:: flexrag.text_process.Processor 10 | :members: 11 | :inherited-members: 12 | :special-members: __call__ 13 | 14 | .. autoclass:: flexrag.text_process.TextProcessPipelineConfig 15 | :members: 16 | :inherited-members: 17 | 18 | .. autoclass:: flexrag.text_process.TextProcessPipeline 19 | :members: 20 | :inherited-members: 21 | 22 | .. autoclass:: flexrag.text_process.TokenNormalizerConfig 23 | :members: 24 | :inherited-members: 25 | 26 | .. autoclass:: flexrag.text_process.TokenNormalizer 27 | :members: 28 | :show-inheritance: 29 | 30 | .. autoclass:: flexrag.text_process.ChineseSimplifier 31 | :members: 32 | :show-inheritance: 33 | 34 | .. autoclass:: flexrag.text_process.Lowercase 35 | :members: 36 | :show-inheritance: 37 | 38 | .. autoclass:: flexrag.text_process.Unifier 39 | :members: 40 | :show-inheritance: 41 | 42 | .. autoclass:: flexrag.text_process.TruncatorConfig 43 | :members: 44 | :inherited-members: 45 | 46 | .. autoclass:: flexrag.text_process.Truncator 47 | :members: 48 | :show-inheritance: 49 | 50 | .. autoclass:: flexrag.text_process.AnswerSimplifier 51 | :members: 52 | :show-inheritance: 53 | 54 | .. autoclass:: flexrag.text_process.ExactDeduplicate 55 | :members: 56 | :show-inheritance: -------------------------------------------------------------------------------- /docs/source/reference/tokenizers.rst: -------------------------------------------------------------------------------- 1 | Tokenizer 2 | ========= 3 | This module is a simple wrapper around other tokenizers. 4 | It provides a simple and consistent interface for tokenizing a text into tokens (maybe string or int). 5 | 6 | The Tokenizer Interface 7 | ----------------------- 8 | ``TokenizerBase`` is the base class for all tokenizers. 9 | 10 | .. autoclass:: flexrag.models.tokenizer.TokenizerBase 11 | :members: 12 | :inherited-members: 13 | 14 | 15 | Tokenizers 16 | ---------- 17 | The wrapped tokenizers. 18 | 19 | .. autoclass:: flexrag.models.tokenizer.HuggingFaceTokenizerConfig 20 | :members: 21 | :inherited-members: 22 | 23 | .. autoclass:: flexrag.models.tokenizer.HuggingFaceTokenizer 24 | :members: 25 | :show-inheritance: 26 | 27 | .. autoclass:: flexrag.models.tokenizer.TikTokenTokenizerConfig 28 | :members: 29 | :inherited-members: 30 | 31 | .. autoclass:: flexrag.models.tokenizer.TikTokenTokenizer 32 | :members: 33 | :show-inheritance: 34 | 35 | .. autoclass:: flexrag.models.tokenizer.MosesTokenizerConfig 36 | :members: 37 | :inherited-members: 38 | 39 | .. autoclass:: flexrag.models.tokenizer.MosesTokenizer 40 | :members: 41 | :show-inheritance: 42 | 43 | .. autoclass:: flexrag.models.tokenizer.NLTKTokenizerConfig 44 | :members: 45 | :inherited-members: 46 | 47 | .. autoclass:: flexrag.models.tokenizer.NLTKTokenizer 48 | :members: 49 | :show-inheritance: 50 | 51 | .. autoclass:: flexrag.models.tokenizer.JiebaTokenizerConfig 52 | :members: 53 | :inherited-members: 54 | 55 | .. autoclass:: flexrag.models.tokenizer.JiebaTokenizer 56 | :members: 57 | :show-inheritance: -------------------------------------------------------------------------------- /docs/source/reference/utils.rst: -------------------------------------------------------------------------------- 1 | Utils 2 | ===== 3 | This module contains useful functions that are used throughout the codebase. 4 | 5 | Cache 6 | ----- 7 | .. automodule:: flexrag.cache 8 | :members: 9 | 10 | Other Utils 11 | ----------- 12 | .. autoclass:: flexrag.utils.Register 13 | :members: -------------------------------------------------------------------------------- /docs/source/tutorial/building_assistant.md: -------------------------------------------------------------------------------- 1 | # Building your own RAG Assistant 2 | FlexRAG provides a flexible and modularized framework for building RAG assistants. You can build your own RAG assistant by defining your own `Assistant` class and registering it with the `ASSISTANTS` decorator. 3 | 4 | ## Define the Assistant Class 5 | To build your RAG assistant, you can create a Python script file and import the necessary FlexRAG modules. Below is an example of how to construct a RAG assistant. In this example, we define a RAG assistant named `SimpleAssistant` by inheriting from the `AssistantBase` class. This assistant includes a dense retriever (`DenseRetriever`) and a generator (`OpenAIGenerator`). Whenever a user asks a question, `SimpleAssistant` uses `DenseRetriever` to retrieve relevant documents from the database, then concatenates these documents into the prompt and utilizes `OpenAIGenerator` to generate the final response. 6 | 7 | ```python 8 | from dataclasses import dataclass 9 | 10 | from flexrag.assistant import ASSISTANTS, AssistantBase 11 | from flexrag.models import OpenAIGenerator, OpenAIGeneratorConfig 12 | from flexrag.prompt import ChatPrompt, ChatTurn 13 | from flexrag.retriever import DenseRetriever, DenseRetrieverConfig 14 | 15 | 16 | @dataclass 17 | class SimpleAssistantConfig(DenseRetrieverConfig, OpenAIGeneratorConfig): ... 18 | 19 | 20 | @ASSISTANTS("simple", config_class=SimpleAssistantConfig) 21 | class SimpleAssistant(AssistantBase): 22 | def __init__(self, config: SimpleAssistantConfig): 23 | self.retriever = DenseRetriever(config) 24 | self.generator = OpenAIGenerator(config) 25 | return 26 | 27 | def answer(self, question: str) -> str: 28 | prompt = ChatPrompt() 29 | context = self.retriever.search(question)[0] 30 | prompt_str = "Please answer the following question based on the given text.\n\n" 31 | prompt_str += f"Question: {question}\n\n" 32 | for n, ctx in enumerate(context): 33 | prompt_str += f"Context {n}: {ctx.data['text']}\n" 34 | prompt.update(ChatTurn(role="user", content=prompt_str)) 35 | response = self.generator.chat([prompt])[0][0] 36 | prompt.update(ChatTurn(role="assistant", content=response)) 37 | return response 38 | ``` 39 | 40 | 41 | ### Running your own RAG Application 42 | After defining the `SimpleAssistant` class and registering it with the `ASSISTANTS` decorator, you can evaluate your assistant using FlexRAG's entrypoints by adding the `user_module=` argument to the command. 43 | 44 | For example, you can evaluate your assistant on the *Natural Questions* dataset using the following command: 45 | 46 | ```bash 47 | DB_PATH= 48 | OPENAI_KEY= 49 | MODULE_PATH= 50 | 51 | python -m flexrag.entrypoints.run_assistant \ 52 | user_module=${MODULE_PATH} \ 53 | name=nq \ 54 | split=test \ 55 | assistant_type=simple \ 56 | simple_config.model_name='gpt-4o-mini' \ 57 | simple_config.api_key=${OPENAI_KEY} \ 58 | simple_config.database_path=${DB_PATH} \ 59 | simple_config.index_type=faiss \ 60 | simple_config.query_encoder_config.encoder_type=hf \ 61 | simple_config.query_encoder_config.hf_config.model_path='facebook/contriever-msmarco' \ 62 | simple_config.query_encoder_config.hf_config.device_id=[0] \ 63 | eval_config.metrics_type=[retrieval_success_rate,generation_f1,generation_em] \ 64 | eval_config.retrieval_success_rate_config.eval_field=text \ 65 | eval_config.response_preprocess.processor_type=[simplify_answer] \ 66 | log_interval=10 67 | ``` 68 | 69 | In [FlexRAG_Examples](https://github.com/ictnlp/FlexRAG_Examples) repository, we provide several detailed examples of how to build a RAG assistant. 70 | -------------------------------------------------------------------------------- /docs/source/tutorial/preparing_corpus.md: -------------------------------------------------------------------------------- 1 | # Preparing the Knowledge Base 2 | In the real world, various types of knowledge are typically stored in documents such as PDFs, Word files, and PPTs. However, this semi-structured data cannot be parsed by large language models (LLMs) and is not suitable for building a knowledge base. Therefore, we need to convert it into structured text data beforehand. In this tutorial, we will use a simple example to demonstrate how to convert a batch of PDF files into structured data. 3 | 4 | ```{tip} 5 | If you already have structured data, you can skip this tutorial. 6 | ``` 7 | 8 | ## Parse Files using FlexRAG's Command-Line Tool 9 | FlexRAG provides a command-line tool `prepare_corpus` to help users parse various files into structured data. In this tutorial, we will use a paper from Arxiv as an example to demonstrate how to parse a PDF file using the built-in command-line tool of FlexRAG. 10 | 11 | Run the following command to download a paper from Arxiv: 12 | 13 | ```bash 14 | wget https://arxiv.org/pdf/2502.18139.pdf 15 | ``` 16 | 17 | You can then run the following command to parse this paper into structured knowledge base data: 18 | 19 | ```bash 20 | python -m flexrag.entrypoints.prepare_corpus \ 21 | document_paths=[2502.18139.pdf] \ 22 | output_path=knowledge.jsonl \ 23 | document_parser_type=markitdown \ 24 | chunker_type=sentence_chunker \ 25 | sentence_chunker_config.max_tokens=512 \ 26 | sentence_chunker_config.tokenizer_type=tiktoken \ 27 | sentence_chunker_config.tiktoken_config.model_name='gpt-4o' 28 | ``` 29 | 30 | In this command, we specify the following parameters: 31 | - `document_paths`:a list of file paths to be parsed. Here we only parse one paper; 32 | - `output_path`:the output path of the parsed results. The path should end with `.jsonl`, `.csv`, or `.tsv`; 33 | - `document_parser_type`:the type of document parser. Here we use `markitdown`; 34 | - `chunker_type`:the type of text chunker. Here we use `sentence_chunker`; 35 | - `sentence_chunker_config.max_tokens`:the maximum length of the text chunker. Here we set it to 512; 36 | - `sentence_chunker_config.tokenizer_type`:the type of tokenizer used by the text chunker. Here we use `tiktoken`, which is provided by OpenAI; 37 | - `sentence_chunker_config.tiktoken_config.model_name`:the model name used by the tokenizer. Here we use `gpt-4o`. 38 | 39 | After executing the above command, you will see that the PDF file has been parsed into a JSONL file. As shown in the figure below, FlexRAG executed three steps in this process: 40 | 1. **Parsing**: parsing the file into structured data; 41 | 2. **Chunking**: chunking long text paragraphs in the structured data into short text paragraphs suitable for processing; 42 | 3. **Preprocessing**: preprocessing and filtering the chunked text paragraphs. 43 | 44 | ```{eval-rst} 45 | .. image:: ../../../assets/parse_files.png 46 | :alt: Parse File 47 | :align: center 48 | :width: 80% 49 | ``` 50 | 51 | ```{tip} 52 | You can check the [FlexRAG Entrypoints](./entrypoints.md) documentation for more information about the `prepare_corpus` command. 53 | ``` 54 | 55 | ```{tip} 56 | You can check the [Preparing the Retriever](./preparing_retriever.md) documentation for how to build a retriever for your knowledge base. 57 | ``` 58 | -------------------------------------------------------------------------------- /docs/source/tutorial/using_register.md: -------------------------------------------------------------------------------- 1 | # Using Registers 2 | The `Register` class is an important component in the FlexRAG that integrates configuration files and loads various RAG components. The registrar can gather multiple components of the same type and generate a unified configuration structure to help you configure and use these components. This tutorial will show you how to use the registrar in FlexRAG. 3 | 4 | ## Using FlexRAG Registers 5 | FlexRAG provides a set of predefined registers for different components. These registers can be used to register and retrieve components of the respective type. The following registers are available in FlexRAG: 6 | 7 | - ASSISTANTS 8 | - REFINERS 9 | - CHUNKERS 10 | - DOCUMENTPARSERS 11 | - PROCESSORS 12 | - METRICS 13 | - GENERATORS 14 | - ENCODERS 15 | - RANKERS 16 | - DENSE_INDEX 17 | - RETRIEVERS 18 | - WEB_DOWNLOADERS 19 | - WEB_READERS 20 | 21 | ```{note} 22 | If you wish to develop your project by modifying the FlexRAG source code, all registrars can be used as decorators to register new components. However, if you use the `run_assistant` or `run_interactive` entrypoints of FlexRAG, **only** the `ASSISTANTS` registrar can be used to register new components. 23 | ``` 24 | 25 | ### Registering a New Component 26 | To register a new component, simply decorate the component class with the corresponding register. For example, to register a new `Assistant` component, you can use the `ASSISTANTS` register as shown below: 27 | 28 | ```python 29 | from dataclasses import dataclass 30 | from flexrag.assistant import AssistantBase, ASSISTANTS 31 | 32 | @dataclass 33 | class MyAssistantConfig: 34 | # Define your assistant configuration here 35 | pass 36 | 37 | @ASSISTANTS("my_assistant", config_class=MyAssistantConfig) 38 | class MyAssistant(AssistantBase): 39 | # Define your assistant here 40 | def answer(self, question: str) -> str: 41 | return "MyAssistant: " + question 42 | ``` 43 | 44 | The register takes the following arguments, namely `shortnames` and `config_class`. 45 | - The `shortnames` argument is a list of shortnames of the component, which serve as simplified names for the component, making it easier to reference when loading. 46 | - The `config_class` argument is the configuration class for the component. This parameter is optional—if not provided, the component will not use any configuration. 47 | 48 | ### Generating the Configuration 49 | After registering the component, you can generate the configuration `dataclass` for all the registered components using the `make_config` function. For example, to generate the configuration for all the registered `Assistant` components, you can use the `make_config` function as shown below: 50 | 51 | ```python 52 | AssistantConfig = ASSISTANTS.make_config() 53 | ``` 54 | 55 | The generated `AssistantConfig` class will have the following structure: 56 | 57 | ```python 58 | from dataclasses import dataclass 59 | 60 | @dataclass 61 | class AssistantConfig: 62 | # The shortname of the assistant 63 | assistant_type: str 64 | # The name of the configuration is the first shortname + "_config" 65 | my_assistant_config: MyAssistantConfig 66 | modular_config: ModularAssistantConfig 67 | # Other registered assistant configurations 68 | ... 69 | ``` 70 | 71 | ```{tip} 72 | In the FlexRAG entrypoints, many configurations are generated in this way. This allows us to flexibly modify the components and their configurations in the workflow through configuration files. 73 | ``` 74 | 75 | ### Loading the Component 76 | To load the component using the configuration, you can use the `load` function of the register. For example, to load the `MyAssistant` component using the configuration, you can use the `load` function as shown below: 77 | 78 | ```python 79 | AssistantConfig.assistant_type = "my_assistant" 80 | my_assistant = ASSISTANTS.load(AssistantConfig) 81 | ``` 82 | 83 | ## Defining a New Register 84 | The `Register` class can be extended to define a new register for a specific component. For example, to define a new register for a `Searcher` component, you can simply create a new instance of the `Register` class as shown below: 85 | 86 | ```python 87 | from flexrag.utils import Register 88 | 89 | SEARCHERS = Register("searcher") 90 | ``` 91 | 92 | ### Utilizing Type Hints 93 | As the `Register` class is a generic class, you can utilize type hints to specify the type of the component that the register is managing. For example, to define a register for a `Searcher` component, you can specify the type hint as follows: 94 | 95 | ```python 96 | from abc import ABC 97 | from flexrag.utils import Register 98 | 99 | class Searcher(ABC): 100 | pass 101 | 102 | SEARCHERS = Register[Searcher]("searcher") 103 | ``` 104 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=61", "wheel", "pybind11>=2.13"] 3 | build-backend = "setuptools.build_meta" 4 | -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | log_level = DEBUG 3 | asyncio_default_fixture_loop_scope = function -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # basic 2 | numpy<2.0.0 3 | tenacity 4 | hydra-core>=1.3 5 | omegaconf>=2.3.0 6 | pillow 7 | accelerate>=0.26.0 8 | colorama 9 | # metrics 10 | rouge 11 | sacrebleu>=2.4.2 12 | pytrec_eval>=0.5 13 | # datasets 14 | datasets>=3.2.0 15 | # models 16 | openai>=1.30.1 17 | anthropic 18 | cohere 19 | ollama 20 | vllm>=0.6.0 21 | sentence_transformers 22 | transformers>=4.44.0 23 | mixedbread-ai 24 | voyageai 25 | # cache 26 | lmdb 27 | cloudpickle 28 | # processors 29 | unidecode 30 | sacremoses 31 | opencc 32 | # retrievers 33 | pandas 34 | pylance 35 | bm25s 36 | elasticsearch>=8.14.0 37 | torch>=2.3.0 38 | beautifulsoup4 39 | typesense 40 | httpx 41 | scipy 42 | # gui 43 | gradio>=5.8.0 44 | # chunking 45 | regex 46 | nltk -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import logging 4 | 5 | import pybind11 6 | from setuptools import Extension, find_packages, setup 7 | 8 | logging.basicConfig(level=logging.INFO) 9 | 10 | 11 | ext_modules = [ 12 | Extension( 13 | "flexrag.metrics.lib_rel", 14 | ["src/flexrag/metrics/lib_rel.cpp"], 15 | include_dirs=[pybind11.get_include()], 16 | language="c++", 17 | extra_compile_args=["-O3"], 18 | ), 19 | ] 20 | 21 | 22 | def get_requirements() -> list[str]: 23 | with open("requirements.txt", encoding="utf-8") as f: 24 | file_content = f.read() 25 | requirements = [ 26 | line.strip() 27 | for line in file_content.strip().split("\n") 28 | if not line.startswith("#") 29 | ] 30 | # as faiss may be installed using conda, we need to remove it from the requirements 31 | try: 32 | import faiss 33 | 34 | logging.info(f"Detected installed faiss: faiss {faiss.__version__}") 35 | except ImportError: 36 | requirements.append("faiss-cpu") 37 | return requirements 38 | 39 | 40 | def get_version() -> str: 41 | with open(os.path.join("src", "flexrag", "utils.py"), encoding="utf-8") as f: 42 | file_content = f.read() 43 | pattern = r"{}\W*=\W*\"([^\"]+)\"".format("__VERSION__") 44 | (version,) = re.findall(pattern, file_content) 45 | return version 46 | 47 | 48 | def get_long_description() -> str: 49 | with open("README.md", encoding="utf-8") as f: 50 | return f.read() 51 | 52 | 53 | setup( 54 | name="flexrag", 55 | version=get_version(), 56 | author="Zhuocheng Zhang", 57 | author_email="zhuocheng_zhang@outlook.com", 58 | description="A RAG Framework for Information Retrieval and Generation.", 59 | url="https://github.com/ictnlp/flexrag", 60 | license="MIT License", 61 | long_description=get_long_description(), 62 | long_description_content_type="text/markdown", 63 | packages=find_packages(where="src"), 64 | package_dir={"": "src"}, 65 | package_data={ 66 | "flexrag": [ 67 | "ranker/ranker_prompts/*.json", 68 | "assistant/assistant_prompts/*.json", 69 | "entrypoints/assets/*.png", 70 | ], 71 | }, 72 | include_package_data=True, 73 | python_requires=">=3.11", 74 | install_requires=get_requirements(), 75 | extras_require={ 76 | "scann": ["scann>=1.3.2"], 77 | "annoy": ["annoy>1.17.0"], 78 | "llamacpp": ["llama_cpp_python>=0.2.84"], 79 | "minference": ["minference>=0.1.5"], 80 | "web": ["duckduckgo_search", "serpapi", "pytest-playwright"], 81 | "docs": ["docling", "markitdown"], 82 | "all": [ 83 | "llama_cpp_python>=0.2.84", 84 | "minference>=0.1.5", 85 | "PySocks>=1.7.1", 86 | "duckduckgo_search", 87 | "serpapi", 88 | "docling", 89 | "markitdown", 90 | "annoy>1.17.0", 91 | ], 92 | "dev": [ 93 | "black", 94 | "pytest", 95 | "pytest-asyncio", 96 | "sphinx", 97 | "sphinx-autobuild", 98 | "myst-parser", 99 | ], 100 | }, 101 | classifiers=[ 102 | "Development Status :: 3 - Alpha", 103 | "Intended Audience :: Developers", 104 | "License :: OSI Approved :: MIT License", 105 | "Operating System :: OS Independent", 106 | "Programming Language :: Python :: 3", 107 | "Programming Language :: Python :: 3.11", 108 | "Programming Language :: Python :: 3.12", 109 | "Programming Language :: Python :: 3.13", 110 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 111 | ], 112 | ext_modules=ext_modules, 113 | ) 114 | -------------------------------------------------------------------------------- /src/flexrag/__init__.py: -------------------------------------------------------------------------------- 1 | from .retriever import RETRIEVERS 2 | from .assistant import ASSISTANTS 3 | from .ranker import RANKERS 4 | from .models import GENERATORS, ENCODERS 5 | from .utils import __VERSION__ 6 | 7 | 8 | __all__ = [ 9 | "RETRIEVERS", 10 | "ASSISTANTS", 11 | "RANKERS", 12 | "GENERATORS", 13 | "ENCODERS", 14 | "__VERSION__", 15 | ] 16 | -------------------------------------------------------------------------------- /src/flexrag/assistant/__init__.py: -------------------------------------------------------------------------------- 1 | from .assistant import ASSISTANTS, AssistantBase, SearchHistory, PREDEFINED_PROMPTS 2 | from .basic_assistant import BasicAssistant, BasicAssistantConfig 3 | from .modular_rag_assistant import ModularAssistant, ModularAssistantConfig 4 | from .chatqa_assistant import ChatQAAssistant 5 | from .online_assistant import ( 6 | JinaDeepSearch, 7 | JinaDeepSearchConfig, 8 | PerplexityAssistant, 9 | PerplexityAssistantConfig, 10 | ) 11 | 12 | __all__ = [ 13 | "ASSISTANTS", 14 | "AssistantBase", 15 | "SearchHistory", 16 | "PREDEFINED_PROMPTS", 17 | "BasicAssistant", 18 | "BasicAssistantConfig", 19 | "ModularAssistant", 20 | "ModularAssistantConfig", 21 | "ChatQAAssistant", 22 | "JinaDeepSearch", 23 | "JinaDeepSearchConfig", 24 | "PerplexityAssistant", 25 | "PerplexityAssistantConfig", 26 | ] 27 | -------------------------------------------------------------------------------- /src/flexrag/assistant/assistant.py: -------------------------------------------------------------------------------- 1 | import os 2 | from abc import ABC, abstractmethod 3 | from dataclasses import dataclass 4 | from typing import Optional 5 | 6 | from flexrag.common_dataclass import RetrievedContext 7 | from flexrag.prompt import ChatPrompt 8 | from flexrag.utils import Register 9 | 10 | 11 | class AssistantBase(ABC): 12 | @abstractmethod 13 | def answer( 14 | self, question: str 15 | ) -> tuple[str, Optional[list[RetrievedContext]], Optional[dict]]: 16 | """Answer the given question. 17 | 18 | :param question: The question to answer. 19 | :type question: str 20 | :return: A tuple containing the following elements: 21 | - The response to the question. 22 | - The contexts used to answer the question. 23 | - The metadata of the assistant. 24 | :rtype: tuple[str, Optional[list[RetrievedContext]], Optional[dict]] 25 | """ 26 | return 27 | 28 | 29 | @dataclass 30 | class SearchHistory: 31 | query: str 32 | contexts: list[RetrievedContext] 33 | 34 | def to_dict(self) -> dict[str, str | list[dict]]: 35 | return { 36 | "query": self.query, 37 | "contexts": [ctx.to_dict() for ctx in self.contexts], 38 | } 39 | 40 | 41 | ASSISTANTS = Register[AssistantBase]("assistant") 42 | 43 | 44 | PREDEFINED_PROMPTS = { 45 | "shortform_with_context": ChatPrompt.from_json( 46 | os.path.join( 47 | os.path.dirname(__file__), 48 | "assistant_prompts", 49 | "shortform_generate_prompt_with_context.json", 50 | ) 51 | ), 52 | "shortform_without_context": ChatPrompt.from_json( 53 | os.path.join( 54 | os.path.dirname(__file__), 55 | "assistant_prompts", 56 | "shortform_generate_prompt_without_context.json", 57 | ) 58 | ), 59 | "longform_with_context": ChatPrompt.from_json( 60 | os.path.join( 61 | os.path.dirname(__file__), 62 | "assistant_prompts", 63 | "longform_generate_prompt_with_context.json", 64 | ) 65 | ), 66 | "longform_without_context": ChatPrompt.from_json( 67 | os.path.join( 68 | os.path.dirname(__file__), 69 | "assistant_prompts", 70 | "longform_generate_prompt_without_context.json", 71 | ) 72 | ), 73 | } 74 | -------------------------------------------------------------------------------- /src/flexrag/assistant/assistant_prompts/longform_generate_prompt_with_context.json: -------------------------------------------------------------------------------- 1 | { 2 | "system": { 3 | "role": "system", 4 | "content": "Answer the question based on the given contexts. Note that the context might not always contain relevant information to answer the question." 5 | }, 6 | "history": [], 7 | "demonstrations": [] 8 | } -------------------------------------------------------------------------------- /src/flexrag/assistant/assistant_prompts/longform_generate_prompt_without_context.json: -------------------------------------------------------------------------------- 1 | { 2 | "system": { 3 | "role": "system", 4 | "content": "Answer the following question." 5 | }, 6 | "history": [], 7 | "demonstrations": [] 8 | } -------------------------------------------------------------------------------- /src/flexrag/assistant/assistant_prompts/shortform_generate_prompt_with_context.json: -------------------------------------------------------------------------------- 1 | { 2 | "system": { 3 | "role": "system", 4 | "content": "Answer the question based on the given contexts. Note that the context might not always contain relevant information to answer the question. Only give me the answer and do not output any other words." 5 | }, 6 | "history": [], 7 | "demonstrations": [] 8 | } -------------------------------------------------------------------------------- /src/flexrag/assistant/assistant_prompts/shortform_generate_prompt_without_context.json: -------------------------------------------------------------------------------- 1 | { 2 | "system": { 3 | "role": "system", 4 | "content": "Answer the following question. Only give me the answer and do not output any other words." 5 | }, 6 | "history": [], 7 | "demonstrations": [] 8 | } -------------------------------------------------------------------------------- /src/flexrag/assistant/basic_assistant.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | from dataclasses import dataclass 3 | from typing import Optional 4 | 5 | from flexrag.models import GENERATORS, GenerationConfig, GeneratorConfig 6 | from flexrag.prompt import ChatPrompt, ChatTurn 7 | from flexrag.utils import LOGGER_MANAGER 8 | 9 | from .assistant import ASSISTANTS, AssistantBase 10 | 11 | logger = LOGGER_MANAGER.get_logger("flexrag.assistant") 12 | 13 | 14 | @dataclass 15 | class BasicAssistantConfig(GeneratorConfig, GenerationConfig): 16 | """The configuration for the basic assistant. 17 | 18 | :param prompt_path: The path to the prompt file. Defaults to None. 19 | :type prompt_path: str, optional 20 | :param use_history: Whether to save the chat history for multi-turn conversation. Defaults to False. 21 | :type use_history: bool, optional 22 | """ 23 | 24 | prompt_path: Optional[str] = None 25 | use_history: bool = False 26 | 27 | 28 | @ASSISTANTS("basic", config_class=BasicAssistantConfig) 29 | class BasicAssistant(AssistantBase): 30 | """A basic assistant that generates response without retrieval.""" 31 | 32 | def __init__(self, cfg: BasicAssistantConfig): 33 | # set basic args 34 | self.gen_cfg = cfg 35 | if self.gen_cfg.sample_num > 1: 36 | logger.warning("Sample num > 1 is not supported for Assistant") 37 | self.gen_cfg.sample_num = 1 38 | 39 | # load generator 40 | self.generator = GENERATORS.load(cfg) 41 | 42 | # load prompts 43 | if cfg.prompt_path is not None: 44 | self.prompt = ChatPrompt.from_json(cfg.prompt_path) 45 | else: 46 | self.prompt = ChatPrompt() 47 | if cfg.use_history: 48 | self.history_prompt = deepcopy(self.prompt) 49 | else: 50 | self.history_prompt = None 51 | return 52 | 53 | def answer(self, question: str) -> tuple[str, None, dict[str, ChatPrompt]]: 54 | # prepare system prompt 55 | if self.history_prompt is not None: 56 | prompt = deepcopy(self.history_prompt) 57 | else: 58 | prompt = deepcopy(self.prompt) 59 | 60 | prompt.update(ChatTurn(role="user", content=question)) 61 | 62 | # generate response 63 | response = self.generator.chat([prompt], generation_config=self.gen_cfg)[0][0] 64 | 65 | # update history prompt 66 | if self.history_prompt is not None: 67 | self.history_prompt.update(ChatTurn(role="user", content=question)) 68 | self.history_prompt.update(ChatTurn(role="assistant", content=response)) 69 | return response, None, {"prompt": prompt} 70 | 71 | def clear_history(self) -> None: 72 | if self.history_prompt is not None: 73 | self.history_prompt = deepcopy(self.prompt) 74 | return 75 | -------------------------------------------------------------------------------- /src/flexrag/assistant/chatqa_assistant.py: -------------------------------------------------------------------------------- 1 | from flexrag.common_dataclass import RetrievedContext 2 | from flexrag.utils import LOGGER_MANAGER 3 | 4 | from .assistant import ASSISTANTS 5 | from .modular_rag_assistant import ModularAssistant, ModularAssistantConfig 6 | 7 | 8 | logger = LOGGER_MANAGER.get_logger("flexrag.assistant.chatqa") 9 | 10 | 11 | @ASSISTANTS("chatqa", config_class=ModularAssistantConfig) 12 | class ChatQAAssistant(ModularAssistant): 13 | """The Modular assistant that employs the ChatQA model for response generation.""" 14 | 15 | sys_prompt = ( 16 | "System: This is a chat between a user and an artificial intelligence assistant. " 17 | "The assistant gives helpful, detailed, and polite answers to the user's questions based on the context. " 18 | "The assistant should also indicate when the answer cannot be found in the context." 19 | ) 20 | instruction = "Please give a full and complete answer for the question." 21 | allowed_models = [ 22 | "nvidia/Llama3-ChatQA-2-8B", 23 | "nvidia/Llama3-ChatQA-2-70B", 24 | "nvidia/Llama3-ChatQA-1.5-8B", 25 | "nvidia/Llama3-ChatQA-1.5-70B", 26 | ] 27 | 28 | def __init__(self, cfg: ModularAssistantConfig): 29 | super().__init__(cfg) 30 | logger.warning( 31 | f"ChatQA Assistant expects the model to be one of {self.allowed_models}." 32 | ) 33 | return 34 | 35 | def answer_with_contexts( 36 | self, question: str, contexts: list[RetrievedContext] 37 | ) -> tuple[str, str]: 38 | prefix = self.get_formatted_input(question, contexts) 39 | response = self.generator.generate([prefix], generation_config=self.gen_cfg) 40 | return response[0][0], prefix 41 | 42 | def get_formatted_input( 43 | self, question: str, contexts: list[RetrievedContext] 44 | ) -> str: 45 | # prepare system prompts 46 | prefix = f"{self.sys_prompt}\n\n" 47 | 48 | # prepare context string 49 | for n, context in enumerate(contexts): 50 | if len(self.used_fields) == 0: 51 | ctx = "" 52 | for field_name, field_value in context.data.items(): 53 | ctx += f"{field_name}: {field_value}\n" 54 | elif len(self.used_fields) == 1: 55 | ctx = context.data[self.used_fields[0]] 56 | else: 57 | ctx = "" 58 | for field_name in self.used_fields: 59 | ctx += f"{field_name}: {context.data[field_name]}\n" 60 | prefix += f"Context {n + 1}: {ctx}\n\n" 61 | 62 | # prepare user instruction 63 | prefix += f"User: {self.instruction} {question}\n\nAssistant:" 64 | return prefix 65 | -------------------------------------------------------------------------------- /src/flexrag/assistant/document_chat_assistant.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Any 3 | 4 | from flexrag.common_dataclass import RetrievedContext 5 | from flexrag.chunking import CHUNKERS, ChunkerConfig 6 | from flexrag.document_parser import DOCUMENTPARSERS, DocumentParserConfig 7 | from flexrag.models import GENERATORS, GenerationConfig, GeneratorConfig 8 | from flexrag.prompt import ChatPrompt, ChatTurn 9 | from flexrag.ranker import RANKERS, RankerConfig 10 | from flexrag.retriever import DenseRetriever, DenseRetrieverConfig 11 | from flexrag.utils import LOGGER_MANAGER 12 | 13 | from .assistant import ASSISTANTS, AssistantBase 14 | 15 | logger = LOGGER_MANAGER.get_logger("flexrag.assistant.modular") 16 | 17 | 18 | @dataclass 19 | class DocumentChatAssistantConfig( 20 | GeneratorConfig, 21 | GenerationConfig, 22 | DenseRetrieverConfig, 23 | RankerConfig, 24 | DocumentParserConfig, 25 | ChunkerConfig, 26 | ): ... 27 | 28 | 29 | @ASSISTANTS("document_chat", config_class=DocumentChatAssistantConfig) 30 | class DocumentChatAssistant(AssistantBase): 31 | def __init__(self, cfg: DocumentChatAssistantConfig): 32 | # set basic args 33 | self.gen_cfg = cfg 34 | if self.gen_cfg.sample_num > 1: 35 | logger.warning("Sample num > 1 is not supported for Assistant") 36 | self.gen_cfg.sample_num = 1 37 | 38 | # load generator 39 | self.generator = GENERATORS.load(cfg) 40 | 41 | # load retriever 42 | self.retriever = DenseRetriever(cfg) 43 | assert len(self.retriever) == 0, "Retriever is not empty." 44 | 45 | # load ranker 46 | self.reranker = RANKERS.load(cfg) 47 | 48 | # load parser 49 | self.parser = DOCUMENTPARSERS.load(cfg) 50 | 51 | # load chunker 52 | self.chunker = CHUNKERS.load(cfg) 53 | return 54 | 55 | def attach_document(self, document_path: str = None) -> None: 56 | if document_path is None: 57 | self.retriever.clean() 58 | return 59 | # parse document 60 | self.retriever.clean() 61 | document = self.parser.parse(document_path) 62 | if self.chunker is not None: 63 | chunks = self.chunker.chunk(document.text) 64 | else: 65 | chunks = [document.text] 66 | 67 | # build index 68 | self.retriever.add_passages(chunks) 69 | return 70 | 71 | def answer( 72 | self, question: str 73 | ) -> tuple[str, list[RetrievedContext], dict[str, Any]]: 74 | # answer without contexts 75 | if len(self.retriever) == 0: 76 | prompt = ChatPrompt() 77 | prompt.update(ChatTurn(role="user", content=question)) 78 | response = self.generator.chat([prompt], generation_config=self.gen_cfg) 79 | return response[0][0], [], {"prompt": prompt} 80 | 81 | # retrieve 82 | retrieved_contexts = self.retriever.search(question)[0] 83 | 84 | # rerank 85 | if self.reranker is not None: 86 | contexts = self.reranker.rank(question, retrieved_contexts).candidates 87 | else: 88 | contexts = retrieved_contexts 89 | 90 | # prepare prompt 91 | prompt = ChatPrompt( 92 | system="Answer the user question based on the given contexts." 93 | ) 94 | usr_prompt = "" 95 | for n, context in enumerate(contexts): 96 | ctx = "" 97 | for field_name, field_value in context.data.items(): 98 | ctx += f"{field_name}: {field_value}\n" 99 | usr_prompt += f"Context {n + 1}: {ctx}\n\n" 100 | usr_prompt += f"Question: {question}" 101 | prompt.update(ChatTurn(role="user", content=usr_prompt)) 102 | 103 | # generate response 104 | response = self.generator.chat([prompt], generation_config=self.gen_cfg)[0][0] 105 | return response, contexts, {"prompt": prompt} 106 | -------------------------------------------------------------------------------- /src/flexrag/cache/__init__.py: -------------------------------------------------------------------------------- 1 | from .backends import ( 2 | STORAGEBACKENDS, 3 | DictBackend, 4 | LMDBBackend, 5 | LMDBBackendConfig, 6 | StorageBackendBase, 7 | StorageBackendConfig, 8 | ) 9 | from .persistent_cache import ( 10 | FIFOPersistentCache, 11 | LFUPersistentCache, 12 | LRUPersistentCache, 13 | PersistentCacheBase, 14 | PersistentCacheConfig, 15 | RandomPersistentCache, 16 | ) 17 | from .serializer import ( 18 | SERIALIZERS, 19 | CloudPickleSerializer, 20 | JsonSerializer, 21 | MsgpackSerializer, 22 | PickleSerializer, 23 | SerializerBase, 24 | SerializerConfig, 25 | ) 26 | 27 | __all__ = [ 28 | "LMDBBackend", 29 | "LMDBBackendConfig", 30 | "STORAGEBACKENDS", 31 | "StorageBackendConfig", 32 | "StorageBackendBase", 33 | "DictBackend", 34 | "SERIALIZERS", 35 | "SerializerConfig", 36 | "JsonSerializer", 37 | "MsgpackSerializer", 38 | "PickleSerializer", 39 | "CloudPickleSerializer", 40 | "SerializerBase", 41 | "PersistentCacheConfig", 42 | "PersistentCacheBase", 43 | "RandomPersistentCache", 44 | "LRUPersistentCache", 45 | "LFUPersistentCache", 46 | "FIFOPersistentCache", 47 | ] 48 | -------------------------------------------------------------------------------- /src/flexrag/cache/backends.py: -------------------------------------------------------------------------------- 1 | import atexit 2 | import os 3 | from collections import OrderedDict 4 | from dataclasses import dataclass 5 | from typing import MutableMapping 6 | 7 | import lmdb 8 | from omegaconf import MISSING 9 | 10 | from flexrag.utils import Register 11 | 12 | 13 | class StorageBackendBase(MutableMapping[bytes, bytes]): 14 | """The Binary Storage Backend For ``PersistentCache``. 15 | The backend should provide interfaces like ``MutableMapping``. 16 | Thus, The following methods should be implemented: 17 | 18 | >>> def __getitem__(self, key: bytes) -> bytes: 19 | ... pass 20 | >>> def __setitem__(self, key: bytes, value: bytes) -> None: 21 | ... pass 22 | >>> def __delitem__(self, key: bytes) -> None: 23 | ... pass 24 | >>> def __iter__(self) -> Iterable[bytes]: 25 | ... pass 26 | >>> def __len__(self) -> int: 27 | ... pass 28 | 29 | The following methods will be implemented automatically: 30 | 31 | >>> def __contains__(self, key: bytes) -> bool: 32 | ... pass 33 | >>> def keys(self) -> KeysView: 34 | ... pass 35 | >>> def values(self) -> ValuesView: 36 | ... pass 37 | >>> def items(self) -> ItemsView: 38 | ... pass 39 | >>> def get(self, key: bytes, default: Any = None) -> bytes | Any: 40 | ... pass 41 | >>> def __eq__(self, other: StorageBackend) -> bool: 42 | ... pass 43 | >>> def __ne__(self, other: StorageBackend) -> bool: 44 | ... pass 45 | >>> def pop(self, key: bytes, default: Any = None) -> bytes | Any: 46 | ... pass 47 | >>> def popitem(self) -> Tuple: 48 | ... pass 49 | >>> def clear(self) -> None: 50 | ... pass 51 | >>> def update(self, other: MutableMapping) -> None: 52 | ... pass 53 | >>> def setdefault(self, key: bytes, default: Any = None) -> Any: 54 | ... pass 55 | """ 56 | 57 | def __repr__(self) -> str: 58 | f"{self.__class__.__name__}(len={len(self)})" 59 | 60 | 61 | STORAGEBACKENDS = Register[StorageBackendBase]("storage_backend") 62 | 63 | 64 | @dataclass 65 | class LMDBBackendConfig: 66 | db_path: str = MISSING 67 | db_size: int = 1 << 30 # 2^30 bytes = 1GB 68 | 69 | 70 | @STORAGEBACKENDS("lmdb", config_class=LMDBBackendConfig) 71 | class LMDBBackend(StorageBackendBase): 72 | def __init__(self, cfg: LMDBBackendConfig) -> None: 73 | self.db_path = cfg.db_path 74 | if not os.path.exists(os.path.dirname(cfg.db_path)): 75 | os.makedirs(os.path.dirname(cfg.db_path), exist_ok=True) 76 | self.database = lmdb.open(cfg.db_path, map_size=cfg.db_size) 77 | atexit.register(self.database.close) 78 | return 79 | 80 | def __getitem__(self, key: bytes) -> bytes: 81 | with self.database.begin() as txn: 82 | data = txn.get(key) 83 | if data is None: 84 | raise KeyError(key) 85 | return data 86 | 87 | def __setitem__(self, key: bytes, value: bytes) -> None: 88 | with self.database.begin(write=True) as txn: 89 | txn.put(key, value) 90 | return 91 | 92 | def __delitem__(self, key: bytes) -> None: 93 | with self.database.begin(write=True) as txn: 94 | txn.delete(key) 95 | return 96 | 97 | def __len__(self) -> int: 98 | with self.database.begin() as txn: 99 | return txn.stat()["entries"] 100 | 101 | def __iter__(self): 102 | with self.database.begin() as txn: 103 | cursor = txn.cursor() 104 | for key, _ in cursor: 105 | yield key 106 | return 107 | 108 | def __repr__(self) -> str: 109 | return f"{self.__class__.__name__}(db_path={self.db_path}, len={len(self)})" 110 | 111 | 112 | @STORAGEBACKENDS("dict") 113 | class DictBackend(OrderedDict, StorageBackendBase): ... 114 | 115 | 116 | StorageBackendConfig = STORAGEBACKENDS.make_config(default="dict") 117 | -------------------------------------------------------------------------------- /src/flexrag/cache/serializer.py: -------------------------------------------------------------------------------- 1 | import json 2 | import pickle 3 | from abc import ABC, abstractmethod 4 | from typing import Any 5 | 6 | from flexrag.utils import Register 7 | 8 | 9 | class SerializerBase(ABC): 10 | """A simple interface for serializing and deserializing python objects.""" 11 | 12 | @abstractmethod 13 | def serialize(self, obj: Any) -> bytes: 14 | """Serialize the object into bytes. 15 | 16 | :param obj: The object to serialize. 17 | :type obj: Any 18 | :return: The serialized object. 19 | :rtype: bytes 20 | """ 21 | return 22 | 23 | @abstractmethod 24 | def deserialize(self, data: bytes) -> Any: 25 | """Deserialize the bytes into an object. 26 | 27 | :param data: The serialized object. 28 | :type data: bytes 29 | :return: The deserialized object. 30 | :rtype: Any 31 | """ 32 | return 33 | 34 | 35 | SERIALIZERS = Register[SerializerBase]("serializer") 36 | 37 | 38 | @SERIALIZERS("pickle") 39 | class PickleSerializer(SerializerBase): 40 | """A serializer that uses the pickle module.""" 41 | 42 | def serialize(self, obj: Any) -> bytes: 43 | return pickle.dumps(obj) 44 | 45 | def deserialize(self, data: bytes) -> Any: 46 | return pickle.loads(data) 47 | 48 | 49 | @SERIALIZERS("cloudpickle") 50 | class CloudPickleSerializer(SerializerBase): 51 | """A serializer that uses the cloudpickle module.""" 52 | 53 | def __init__(self): 54 | try: 55 | import cloudpickle 56 | 57 | self.pickler = cloudpickle 58 | except: 59 | raise ImportError( 60 | "Please install cloudpickle using `pip install cloudpickle`." 61 | ) 62 | return 63 | 64 | def serialize(self, obj: Any) -> bytes: 65 | return self.pickler.dumps(obj) 66 | 67 | def deserialize(self, data: bytes) -> Any: 68 | return self.pickler.loads(data) 69 | 70 | 71 | @SERIALIZERS("json") 72 | class JsonSerializer(SerializerBase): 73 | """A serializer that uses the json module.""" 74 | 75 | def serialize(self, obj: Any) -> bytes: 76 | return json.dumps(obj).encode("utf-8") 77 | 78 | def deserialize(self, data: bytes) -> Any: 79 | return json.loads(data.decode("utf-8")) 80 | 81 | 82 | @SERIALIZERS("msgpack") 83 | class MsgpackSerializer(SerializerBase): 84 | """A serializer that uses the msgpack module.""" 85 | 86 | def __init__(self) -> None: 87 | try: 88 | import msgpack 89 | 90 | self.msgpack = msgpack 91 | except ImportError: 92 | raise ImportError("Please install msgpack using `pip install msgpack`.") 93 | return 94 | 95 | def serialize(self, obj: Any) -> bytes: 96 | return self.msgpack.packb(obj, use_bin_type=True) 97 | 98 | def deserialize(self, data: bytes) -> Any: 99 | return self.msgpack.unpackb(data, raw=False) 100 | 101 | 102 | SerializerConfig = SERIALIZERS.make_config(default="pickle") 103 | -------------------------------------------------------------------------------- /src/flexrag/chunking/__init__.py: -------------------------------------------------------------------------------- 1 | from .chunker_base import ChunkerBase, CHUNKERS 2 | from .basic_chunkers import ( 3 | CharChunker, 4 | CharChunkerConfig, 5 | TokenChunker, 6 | TokenChunkerConfig, 7 | RecursiveChunker, 8 | RecursiveChunkerConfig, 9 | SentenceChunker, 10 | SentenceChunkerConfig, 11 | ) 12 | from .semantic_chunker import SemanticChunker, SemanticChunkerConfig 13 | 14 | 15 | ChunkerConfig = CHUNKERS.make_config( 16 | default="sentence_chunker", config_name="ChunkerConfig" 17 | ) 18 | 19 | 20 | __all__ = [ 21 | "ChunkerBase", 22 | "CHUNKERS", 23 | "ChunkerConfig", 24 | "CharChunker", 25 | "CharChunkerConfig", 26 | "TokenChunker", 27 | "TokenChunkerConfig", 28 | "RecursiveChunker", 29 | "RecursiveChunkerConfig", 30 | "SentenceChunker", 31 | "SentenceChunkerConfig", 32 | "SemanticChunker", 33 | "SemanticChunkerConfig", 34 | ] 35 | -------------------------------------------------------------------------------- /src/flexrag/chunking/chunker_base.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from dataclasses import dataclass 3 | from typing import Optional 4 | 5 | from flexrag.utils import Register 6 | 7 | 8 | @dataclass 9 | class Chunk: 10 | """The dataclass for a chunk of text. 11 | 12 | :param text: The text of the chunk. 13 | :type text: str 14 | :param start: The start index of the chunk in the original text. 15 | :type start: Optional[int] 16 | :param end: The end index of the chunk in the original text. 17 | :type end: Optional[int] 18 | """ 19 | 20 | text: str 21 | start: Optional[int] = None 22 | end: Optional[int] = None 23 | 24 | 25 | class ChunkerBase(ABC): 26 | """Chunker that splits text into chunks of fixed size. 27 | This is an abstract class that defines the interface for all chunkers. 28 | The subclasses should implement the `chunk` method to split the text. 29 | """ 30 | 31 | @abstractmethod 32 | def chunk(self, text: str) -> list[Chunk]: 33 | """Chunk the given text into smaller chunks. 34 | 35 | :param text: The text to chunk. 36 | :type text: str 37 | :return: The chunks of the text. 38 | :rtype: list[Chunk] 39 | """ 40 | return 41 | 42 | 43 | CHUNKERS = Register[ChunkerBase]("chunker") 44 | -------------------------------------------------------------------------------- /src/flexrag/common_dataclass.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from typing import Optional 3 | 4 | from omegaconf import MISSING 5 | 6 | 7 | @dataclass 8 | class Context: 9 | """The dataclass for retrieved context. 10 | 11 | :param context_id: The unique identifier of the context. Default: None. 12 | :type context_id: Optional[str] 13 | :param data: The context data. Default: {}. 14 | :type data: dict 15 | :param source: The source of the retrieved data. Default: None. 16 | :type source: Optional[str] 17 | :param meta_data: The metadata of the context. Default: {}. 18 | :type meta_data: dict 19 | """ 20 | 21 | context_id: Optional[str] = None 22 | data: dict = field(default_factory=dict) 23 | source: Optional[str] = None 24 | meta_data: dict = field(default_factory=dict) 25 | 26 | def to_dict(self): 27 | return { 28 | "context_id": self.context_id, 29 | "source": self.source, 30 | "data": self.data, 31 | "meta_data": self.meta_data, 32 | } 33 | 34 | 35 | @dataclass 36 | class RetrievedContext(Context): 37 | """The dataclass for retrieved context. 38 | 39 | :param retriever: The name of the retriever. Required. 40 | :type retriever: str 41 | :param query: The query for retrieval. Required. 42 | :type query: str 43 | :param score: The relevance score of the retrieved data. Default: 0.0. 44 | :type score: float 45 | """ 46 | 47 | retriever: str = MISSING 48 | query: str = MISSING 49 | score: float = 0.0 50 | 51 | def to_dict(self): 52 | return { 53 | **super().to_dict(), 54 | "retriever": self.retriever, 55 | "query": self.query, 56 | "score": self.score, 57 | } 58 | 59 | 60 | @dataclass 61 | class RAGEvalData: 62 | """The dataclass for RAG evaluation data. 63 | 64 | :param question: The question for evaluation. Required. 65 | :type question: str 66 | :param golden_contexts: The contexts related to the question. Default: None. 67 | :type golden_contexts: Optional[list[Context]] 68 | :param golden_answers: The golden answers for the question. Default: None. 69 | :type golden_answers: Optional[list[str]] 70 | :param meta_data: The metadata of the evaluation data. Default: {}. 71 | :type meta_data: dict 72 | """ 73 | 74 | question: str = MISSING 75 | golden_contexts: Optional[list[Context]] = None 76 | golden_answers: Optional[list[str]] = None 77 | meta_data: dict = field(default_factory=dict) 78 | 79 | 80 | @dataclass 81 | class IREvalData: 82 | """The dataclass for Information Retrieval evaluation data. 83 | 84 | :param question: The question for evaluation. Required. 85 | :type question: str 86 | :param contexts: The contexts related to the question. Default: None. 87 | :type contexts: Optional[list[Context]] 88 | :param meta_data: The metadata of the evaluation data. Default: {}. 89 | :type meta_data: dict 90 | """ 91 | 92 | question: str 93 | contexts: Optional[list[Context]] = None 94 | meta_data: dict = field(default_factory=dict) 95 | -------------------------------------------------------------------------------- /src/flexrag/context_refine/__init__.py: -------------------------------------------------------------------------------- 1 | from .arranger import ContextArranger, ContextArrangerConfig 2 | from .summarizer import ( 3 | RecompExtractiveSummarizer, 4 | RecompExtractiveSummarizerConfig, 5 | AbstractiveSummarizer, 6 | AbstractiveSummarizerConfig, 7 | ) 8 | from .refiner import RefinerBase, REFINERS 9 | 10 | 11 | RefinerConfig = REFINERS.make_config( 12 | allow_multiple=True, default=None, config_name="RefinerConfig" 13 | ) 14 | 15 | 16 | __all__ = [ 17 | "ContextArranger", 18 | "ContextArrangerConfig", 19 | "RecompExtractiveSummarizer", 20 | "RecompExtractiveSummarizerConfig", 21 | "AbstractiveSummarizer", 22 | "AbstractiveSummarizerConfig", 23 | "RefinerBase", 24 | "REFINERS", 25 | "RefinerConfig", 26 | ] 27 | -------------------------------------------------------------------------------- /src/flexrag/context_refine/arranger.py: -------------------------------------------------------------------------------- 1 | import random as rd 2 | from dataclasses import dataclass 3 | 4 | from flexrag.common_dataclass import RetrievedContext 5 | from flexrag.utils import Choices, TIME_METER 6 | 7 | from .refiner import REFINERS, RefinerBase 8 | 9 | 10 | @dataclass 11 | class ContextArrangerConfig: 12 | """The configuration for the ``ContextArranger``. 13 | 14 | :param order: The order to arrange the contexts. Defaults to "ascending". 15 | available choices: "ascending", "descending", "side", "random". 16 | :type order: str 17 | """ 18 | 19 | order: Choices(["ascending", "descending", "side", "random"]) = "ascending" # type: ignore 20 | 21 | 22 | @REFINERS("context_arranger", config_class=ContextArrangerConfig) 23 | class ContextArranger(RefinerBase): 24 | """The ``ContextArranger`` arranges the contexts based on the given order. 25 | 26 | As the `lost-in-the-middle` problem encountered by the LLMs, the order of the contexts may affect the performance. 27 | This refiner helps to arrange the contexts in a specific order. 28 | """ 29 | 30 | def __init__(self, config: ContextArrangerConfig): 31 | self.order = config.order 32 | return 33 | 34 | @TIME_METER("repack") 35 | def refine(self, contexts: list[RetrievedContext]) -> list[RetrievedContext]: 36 | match self.order: 37 | case "ascending": 38 | contexts = sorted(contexts, key=lambda x: x.score) 39 | case "descending": 40 | contexts = sorted(contexts, key=lambda x: x.score, reverse=True) 41 | case "random": 42 | indices = list(range(len(contexts))) 43 | rd.shuffle(indices) 44 | contexts = [contexts[i] for i in indices] 45 | case "side": 46 | sort_ctxs = sorted(contexts, key=lambda x: x.score, reverse=True) 47 | contexts_left = [] 48 | contexts_right = [] 49 | for i in range(0, len(sort_ctxs), 2): 50 | contexts_left.append(sort_ctxs[i]) 51 | for i in range(1, len(sort_ctxs), 2): 52 | contexts_right.append(sort_ctxs[i]) 53 | contexts = contexts_left + contexts_right[::-1] 54 | case _: 55 | raise ValueError(f"Invalid order: {self.order}") 56 | return contexts 57 | -------------------------------------------------------------------------------- /src/flexrag/context_refine/refiner.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | from flexrag.common_dataclass import RetrievedContext 4 | from flexrag.utils import Register 5 | 6 | 7 | class RefinerBase(ABC): 8 | """The base class for context refiners. 9 | The subclasses should implement the ``refine`` method. 10 | """ 11 | 12 | @abstractmethod 13 | def refine(self, contexts: list[RetrievedContext]) -> list[RetrievedContext]: 14 | """Refine the contexts. 15 | 16 | :param contexts: The retrieved contexts to refine. 17 | :type contexts: list[RetrievedContext] 18 | :return: The refined contexts. 19 | :rtype: list[RetrievedContext] 20 | """ 21 | return 22 | 23 | 24 | REFINERS = Register[RefinerBase]("refiner") 25 | -------------------------------------------------------------------------------- /src/flexrag/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # datasets 2 | from .dataset import ChainDataset, ConcatDataset, IterableDataset, MappingDataset 3 | from .hf_dataset import HFDataset, HFDatasetConfig 4 | from .line_delimited_dataset import LineDelimitedDataset, LineDelimitedDatasetConfig 5 | from .rag_dataset import ( 6 | RAGCorpusDataset, 7 | RAGCorpusDatasetConfig, 8 | RAGEvalDataset, 9 | RAGEvalDatasetConfig, 10 | ) 11 | from .retrieval_dataset import MTEBDataset, MTEBDatasetConfig 12 | 13 | __all__ = [ 14 | "ChainDataset", 15 | "IterableDataset", 16 | "MappingDataset", 17 | "ConcatDataset", 18 | "HFDataset", 19 | "HFDatasetConfig", 20 | "LineDelimitedDataset", 21 | "LineDelimitedDatasetConfig", 22 | "RAGEvalDatasetConfig", 23 | "RAGEvalDataset", 24 | "RAGCorpusDatasetConfig", 25 | "RAGCorpusDataset", 26 | "MTEBDataset", 27 | "MTEBDatasetConfig", 28 | ] 29 | -------------------------------------------------------------------------------- /src/flexrag/datasets/hf_dataset.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from typing import Optional 3 | 4 | from datasets import Dataset as _Dataset 5 | from datasets import DatasetDict as _DatasetDict 6 | from datasets import load_dataset 7 | 8 | from .dataset import MappingDataset 9 | 10 | 11 | @dataclass 12 | class HFDatasetConfig: 13 | """The configuration for the ``HFDataset``. 14 | The ``HFDataset`` is a wrapper class that employs the ``load_dataset`` method in HuggingFace ``datasets`` library to load the dataset. 15 | 16 | :param path: Path or name of the dataset. 17 | :type path: str 18 | :param name: Defining the name of the dataset configuration. 19 | :type name: Optional[str] 20 | :param data_dir: Defining the ``data_dir`` of the dataset configuration. 21 | :type data_dir: Optional[str] 22 | :param data_files: Paths to source data files. 23 | :type data_files: list[str] 24 | :param split: Which split of the data to load. 25 | :type split: Optional[str] 26 | :param cache_dir: Directory to read/write data. 27 | :type cache_dir: Optional[str] 28 | :param token: Optional string or boolean to use as Bearer token for remote files on the Datasets Hub. 29 | :type token: Optional[str] 30 | :param trust_remote_code: Whether or not to allow for datasets defined on the Hub using a dataset script. 31 | :type trust_remote_code: bool 32 | 33 | For example, you can load the dataset from the HuggingFace by running the following code: 34 | 35 | >>> cfg = HFDatasetConfig( 36 | ... path="mteb/nq", 37 | ... split="test", 38 | ... ) 39 | >>> dataset = HFDataset(cfg) 40 | 41 | You can also load the dataset from a local repository by specifying the path: 42 | 43 | >>> cfg = HFDatasetConfig( 44 | ... path="json", 45 | ... data_files=["path/to/local/my_dataset.json"], 46 | ... ) 47 | >>> dataset = HFDataset(cfg) 48 | 49 | For more information about the parameters, please refer to the HuggingFace ``datasets`` documentation: 50 | https://huggingface.co/docs/datasets/main/en/package_reference/loading_methods#datasets.load_dataset 51 | """ 52 | 53 | path: str 54 | name: Optional[str] = None 55 | data_dir: Optional[str] = None 56 | data_files: list[str] = field(default_factory=list) 57 | split: Optional[str] = None 58 | cache_dir: Optional[str] = None 59 | token: Optional[str] = None 60 | trust_remote_code: bool = False 61 | 62 | 63 | class HFDataset(MappingDataset): 64 | """HFDataset is a dataset that wraps the HaggingFace ``datasets`` library.""" 65 | 66 | dataset: _Dataset 67 | 68 | def __init__(self, cfg: HFDatasetConfig) -> None: 69 | super().__init__() 70 | self.dataset = load_dataset( 71 | path=cfg.path, 72 | name=cfg.name, 73 | data_dir=cfg.data_dir, 74 | data_files=cfg.data_files if cfg.data_files else None, 75 | split=cfg.split, 76 | cache_dir=cfg.cache_dir, 77 | token=cfg.token, 78 | trust_remote_code=cfg.trust_remote_code, 79 | ) 80 | if isinstance(self.dataset, _DatasetDict): 81 | raise ValueError( 82 | "Split is missing.\n" 83 | "Please pick one among the following splits: " 84 | f"{list(self.dataset.keys())}" 85 | ) 86 | return 87 | 88 | def __getitem__(self, index: int): 89 | return self.dataset[index] 90 | 91 | def __len__(self): 92 | return len(self.dataset) 93 | -------------------------------------------------------------------------------- /src/flexrag/datasets/retrieval_dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from dataclasses import dataclass 4 | 5 | from omegaconf import MISSING 6 | 7 | from flexrag.common_dataclass import Context, IREvalData 8 | 9 | from .dataset import MappingDataset 10 | 11 | 12 | @dataclass 13 | class MTEBDatasetConfig: 14 | """Configuration for loading `MTEB `_ Retrieval Dataset. 15 | The __getitem__ method will return `IREvalData` objects. 16 | 17 | For example, to load the NQ dataset, you can download the test set by running the following command: 18 | 19 | >>> git lfs install 20 | >>> git clone https://huggingface.co/datasets/mteb/nq nq 21 | 22 | Then you can use the following code to load the dataset: 23 | 24 | >>> config = MTEBDatasetConfig( 25 | ... data_path="nq", 26 | ... subset="test", 27 | ... load_corpus=False, 28 | ... ) 29 | >>> dataset = MTEBDataset(config) 30 | 31 | :param data_path: Path to the data directory. Required. 32 | :type data_path: str 33 | :param subset: Subset of the dataset to load. Required. 34 | :type subset: str 35 | :param encoding: Encoding of the data files. Default is 'utf-8'. 36 | :type encoding: str 37 | :param load_corpus: Whether to load the corpus data. Default is False. 38 | :type load_corpus: bool 39 | """ 40 | 41 | data_path: str = MISSING 42 | subset: str = MISSING 43 | encoding: str = "utf-8" 44 | load_corpus: bool = False 45 | 46 | 47 | class MTEBDataset(MappingDataset[IREvalData]): 48 | """Dataset for loading MTEB Retrieval Dataset.""" 49 | 50 | def __init__(self, config: MTEBDatasetConfig) -> None: 51 | qrels: list[dict] = [ 52 | json.loads(line) 53 | for line in open( 54 | os.path.join(config.data_path, "qrels", f"{config.subset}.jsonl"), 55 | "r", 56 | encoding=config.encoding, 57 | ) 58 | ] 59 | queries = [ 60 | json.loads(line) 61 | for line in open( 62 | os.path.join(config.data_path, "queries.jsonl"), 63 | "r", 64 | encoding=config.encoding, 65 | ) 66 | ] 67 | queries = {query["_id"]: query for query in queries} 68 | 69 | if config.load_corpus: 70 | corpus = [ 71 | json.loads(line) 72 | for line in open( 73 | os.path.join(config.data_path, "corpus.jsonl"), 74 | "r", 75 | encoding=config.encoding, 76 | ) 77 | ] 78 | corpus = {doc["_id"]: doc for doc in corpus} 79 | else: 80 | corpus = None 81 | 82 | # merge qrels, queries, and corpus into RetrievalData 83 | dataset_map: dict[str, int] = {} 84 | self.dataset: list[IREvalData] = [] 85 | for qrel in qrels: 86 | # construct the context 87 | context = Context(context_id=qrel["corpus-id"]) 88 | if corpus is not None: 89 | context.data = corpus[qrel["corpus-id"]] 90 | if "score" in qrel: # relevance level of the context 91 | context.meta_data["score"] = int(qrel["score"]) 92 | query = queries[qrel["query-id"]]["text"] 93 | 94 | if qrel["query-id"] not in dataset_map: 95 | dataset_map[qrel["query-id"]] = len(self.dataset) 96 | self.dataset.append( 97 | IREvalData( 98 | question=query, 99 | contexts=[context], 100 | meta_data={"query-id": qrel["query-id"]}, 101 | ) 102 | ) 103 | else: 104 | index = dataset_map[qrel["query-id"]] 105 | self.dataset[index].contexts.append(context) 106 | return 107 | 108 | def __len__(self) -> int: 109 | return len(self.dataset) 110 | 111 | def __getitem__(self, index: int) -> IREvalData: 112 | return self.dataset[index] 113 | -------------------------------------------------------------------------------- /src/flexrag/document_parser/__init__.py: -------------------------------------------------------------------------------- 1 | from .document_parser_base import DocumentParserBase, Document, DOCUMENTPARSERS 2 | from .docling_parser import DoclingParser, DoclingConfig 3 | from .markitdown_parser import MarkItDownParser 4 | 5 | 6 | DocumentParserConfig = DOCUMENTPARSERS.make_config(default="markitdown") 7 | 8 | 9 | __all__ = [ 10 | "DocumentParserBase", 11 | "Document", 12 | "DOCUMENTPARSERS", 13 | "DocumentParserConfig", 14 | "DoclingParser", 15 | "DoclingConfig", 16 | "MarkItDownParser", 17 | ] 18 | -------------------------------------------------------------------------------- /src/flexrag/document_parser/docling_parser.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dataclasses import dataclass 3 | 4 | 5 | from .document_parser_base import Document, DocumentParserBase, DOCUMENTPARSERS 6 | 7 | 8 | @dataclass 9 | class DoclingConfig: 10 | do_ocr: bool = False 11 | do_table_structure: bool = True 12 | generate_page_images: bool = False 13 | generate_picture_images: bool = False 14 | 15 | 16 | @DOCUMENTPARSERS("docling", config_class=DoclingConfig) 17 | class DoclingParser(DocumentParserBase): 18 | def __init__(self, config: DoclingConfig): 19 | try: 20 | from docling.datamodel.base_models import InputFormat 21 | from docling.datamodel.pipeline_options import PdfPipelineOptions 22 | from docling.document_converter import DocumentConverter, PdfFormatOption 23 | except ImportError: 24 | raise ImportError( 25 | "Docling is not installed. Please install it via `pip install docling`." 26 | ) 27 | 28 | pdf_pipeline_options = PdfPipelineOptions( 29 | do_ocr=config.do_ocr, 30 | do_table_structure=config.do_table_structure, 31 | generate_page_images=config.generate_page_images, 32 | generate_picture_images=config.generate_picture_images, 33 | ) 34 | self.doc_converter = DocumentConverter( 35 | format_options={ 36 | InputFormat.PDF: PdfFormatOption(pipeline_options=pdf_pipeline_options) 37 | } 38 | ) 39 | return 40 | 41 | def parse(self, input_file_path: str) -> Document: 42 | assert os.path.exists(input_file_path) 43 | document_ = self.doc_converter.convert(input_file_path).document 44 | document = Document( 45 | source_file_path=input_file_path, 46 | text=document_.export_to_markdown(), 47 | title=document_.name, 48 | ) 49 | if document.pagaes.image is not None: 50 | document.screenshots = [p.image.pil_image for p in document_.pages] 51 | if document.pictures.image is not None: 52 | document.images = [p.image.pil_image for p in document_.pictures] 53 | return document 54 | -------------------------------------------------------------------------------- /src/flexrag/document_parser/document_parser_base.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from dataclasses import dataclass, field 3 | from typing import Optional 4 | 5 | from PIL.Image import Image 6 | 7 | from flexrag.utils import Register 8 | 9 | 10 | @dataclass 11 | class Document: 12 | """A document parsed by a DocumentParser.""" 13 | 14 | source_file_path: str 15 | title: Optional[str] = None 16 | text: Optional[str] = None 17 | screenshots: list[Image] = field(default_factory=list) 18 | images: list[Image] = field(default_factory=list) 19 | 20 | 21 | class DocumentParserBase(ABC): 22 | @abstractmethod 23 | def parse(self, document_path: str) -> Document: 24 | """Parse the document at the given path. 25 | 26 | :param document_path: The path to the document to parse. 27 | :type document_path: str 28 | :return: The parsed document. 29 | :rtype: Document 30 | """ 31 | return 32 | 33 | 34 | DOCUMENTPARSERS = Register[DocumentParserBase]("document_parser") 35 | -------------------------------------------------------------------------------- /src/flexrag/document_parser/markitdown_parser.py: -------------------------------------------------------------------------------- 1 | from .document_parser_base import DOCUMENTPARSERS, Document, DocumentParserBase 2 | 3 | 4 | @DOCUMENTPARSERS("markitdown") 5 | class MarkItDownParser(DocumentParserBase): 6 | def __init__(self): 7 | try: 8 | from markitdown import MarkItDown 9 | except ImportError: 10 | raise ImportError( 11 | "MarkItDown is not installed. Please install it via `pip install markitdown`." 12 | ) 13 | finally: 14 | self.parser = MarkItDown() 15 | return 16 | 17 | def parse(self, path: str) -> Document: 18 | doc = self.parser.convert(path) 19 | return Document(source_file_path=path, text=doc.text_content, title=doc.title) 20 | -------------------------------------------------------------------------------- /src/flexrag/entrypoints/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ictnlp/FlexRAG/f00cf3a36890a2d56aec508a72b599c4681be3ca/src/flexrag/entrypoints/__init__.py -------------------------------------------------------------------------------- /src/flexrag/entrypoints/assets/flexrag-wide.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ictnlp/FlexRAG/f00cf3a36890a2d56aec508a72b599c4681be3ca/src/flexrag/entrypoints/assets/flexrag-wide.png -------------------------------------------------------------------------------- /src/flexrag/entrypoints/assets/flexrag.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ictnlp/FlexRAG/f00cf3a36890a2d56aec508a72b599c4681be3ca/src/flexrag/entrypoints/assets/flexrag.png -------------------------------------------------------------------------------- /src/flexrag/entrypoints/assets/robot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ictnlp/FlexRAG/f00cf3a36890a2d56aec508a72b599c4681be3ca/src/flexrag/entrypoints/assets/robot.png -------------------------------------------------------------------------------- /src/flexrag/entrypoints/assets/user.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ictnlp/FlexRAG/f00cf3a36890a2d56aec508a72b599c4681be3ca/src/flexrag/entrypoints/assets/user.png -------------------------------------------------------------------------------- /src/flexrag/entrypoints/cache.py: -------------------------------------------------------------------------------- 1 | import json 2 | from dataclasses import dataclass 3 | 4 | import hydra 5 | from hydra.core.config_store import ConfigStore 6 | from omegaconf import MISSING 7 | 8 | from flexrag.retriever.retriever_base import RETRIEVAL_CACHE 9 | from flexrag.utils import Choices 10 | 11 | 12 | @dataclass 13 | class Config: 14 | export_path: str = MISSING 15 | action: Choices(["clear", "export", "_"]) = "_" # type: ignore 16 | 17 | 18 | cs = ConfigStore.instance() 19 | cs.store(name="default", node=Config) 20 | 21 | 22 | @hydra.main(version_base="1.3", config_path=None, config_name="default") 23 | def main(config: Config): 24 | match config.action: 25 | case "clear": 26 | RETRIEVAL_CACHE.clear() 27 | case "export": 28 | with open(config.export_path, "w", encoding="utf-8") as f: 29 | for data in RETRIEVAL_CACHE: 30 | data["retrieved_contexts"] = RETRIEVAL_CACHE[data] 31 | f.write(json.dumps(data) + "\n") 32 | case _: 33 | raise ValueError("No action specified") 34 | return 35 | 36 | 37 | if __name__ == "__main__": 38 | main() 39 | -------------------------------------------------------------------------------- /src/flexrag/entrypoints/combine_outputs.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from dataclasses import dataclass 4 | 5 | import hydra 6 | from hydra.core.config_store import ConfigStore 7 | from omegaconf import MISSING, OmegaConf 8 | 9 | from flexrag.metrics import Evaluator 10 | from flexrag.utils import LOGGER_MANAGER 11 | 12 | 13 | @dataclass 14 | class Config: 15 | result_paths: list[str] = MISSING 16 | output_path: str = MISSING 17 | 18 | 19 | cs = ConfigStore.instance() 20 | cs.store(name="default", node=Config) 21 | logger = LOGGER_MANAGER.get_logger("combine_outputs") 22 | 23 | 24 | @hydra.main(version_base="1.3", config_path=None, config_name="default") 25 | def main(cfg: Config): 26 | # load the metadata 27 | config_path = os.path.join(cfg.result_paths[0], "config.yaml") 28 | loaded_config = OmegaConf.load(config_path) 29 | evaluator = Evaluator(loaded_config.eval_config) 30 | 31 | # prepare output path 32 | if not os.path.exists(cfg.output_path): 33 | os.makedirs(cfg.output_path) 34 | output_details_path = os.path.join(cfg.output_path, "details.jsonl") 35 | output_eval_score_path = os.path.join(cfg.output_path, "eval_score.json") 36 | output_config_path = os.path.join(cfg.output_path, "config.yaml") 37 | OmegaConf.save(loaded_config, output_config_path) 38 | 39 | # combine the results 40 | logger.info("Combining the results...") 41 | questions = [] 42 | golden_answers = [] 43 | golden_contexts = [] 44 | responses = [] 45 | contexts = [] 46 | with open(output_details_path, "w", encoding="utf-8") as f: 47 | for result_path in cfg.result_paths: 48 | details_path = os.path.join(result_path, "details.jsonl") 49 | for line in open(details_path, "r", encoding="utf-8"): 50 | f.write(line) 51 | data = json.loads(line) 52 | questions.append(data["question"]) 53 | golden_answers.append(data["golden"]) 54 | golden_contexts.append(data["golden_contexts"]) 55 | responses.append(data["response"]) 56 | contexts.append(data["contexts"]) 57 | 58 | # re-evaluate the combined results 59 | logger.info("Re-evaluating the combined results...") 60 | resp_score, resp_score_detail = evaluator.evaluate( 61 | questions=questions, 62 | responses=responses, 63 | golden_responses=golden_answers, 64 | retrieved_contexts=contexts, 65 | golden_contexts=golden_contexts, 66 | log=True, 67 | ) 68 | with open(output_eval_score_path, "w", encoding="utf-8") as f: 69 | json.dump( 70 | { 71 | "eval_scores": resp_score, 72 | "eval_details": resp_score_detail, 73 | }, 74 | f, 75 | indent=4, 76 | ensure_ascii=False, 77 | ) 78 | return 79 | 80 | 81 | if __name__ == "__main__": 82 | main() 83 | -------------------------------------------------------------------------------- /src/flexrag/entrypoints/prepare_index.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | 3 | import hydra 4 | from hydra.core.config_store import ConfigStore 5 | from omegaconf import OmegaConf 6 | 7 | from flexrag.datasets import RAGCorpusDataset, RAGCorpusDatasetConfig 8 | from flexrag.retriever import ( 9 | BM25SRetriever, 10 | BM25SRetrieverConfig, 11 | DenseRetriever, 12 | DenseRetrieverConfig, 13 | ElasticRetriever, 14 | ElasticRetrieverConfig, 15 | TypesenseRetriever, 16 | TypesenseRetrieverConfig, 17 | ) 18 | from flexrag.utils import LOGGER_MANAGER, Choices 19 | 20 | logger = LOGGER_MANAGER.get_logger("flexrag.prepare_index") 21 | 22 | 23 | # fmt: off 24 | @dataclass 25 | class Config(RAGCorpusDatasetConfig): 26 | # retriever configs 27 | retriever_type: Choices(["dense", "elastic", "typesense", "bm25s"]) = "dense" # type: ignore 28 | bm25s_config: BM25SRetrieverConfig = field(default_factory=BM25SRetrieverConfig) 29 | dense_config: DenseRetrieverConfig = field(default_factory=DenseRetrieverConfig) 30 | elastic_config: ElasticRetrieverConfig = field(default_factory=ElasticRetrieverConfig) 31 | typesense_config: TypesenseRetrieverConfig = field(default_factory=TypesenseRetrieverConfig) 32 | reinit: bool = False 33 | # fmt: on 34 | 35 | 36 | cs = ConfigStore.instance() 37 | cs.store(name="default", node=Config) 38 | 39 | 40 | @hydra.main(version_base="1.3", config_path=None, config_name="default") 41 | def main(cfg: Config): 42 | # load retriever 43 | match cfg.retriever_type: 44 | case "bm25s": 45 | retriever = BM25SRetriever(cfg.bm25s_config) 46 | case "dense": 47 | retriever = DenseRetriever(cfg.dense_config) 48 | case "elastic": 49 | retriever = ElasticRetriever(cfg.elastic_config) 50 | case "typesense": 51 | retriever = TypesenseRetriever(cfg.typesense_config) 52 | case _: 53 | raise ValueError(f"Unsupported retriever type: {cfg.retriever_type}") 54 | 55 | # add passages 56 | if cfg.reinit and (len(retriever) > 0): 57 | logger.warning("Reinitializing retriever and removing all passages") 58 | retriever.clean() 59 | 60 | retriever.add_passages(passages=RAGCorpusDataset(cfg)) 61 | return 62 | 63 | 64 | if __name__ == "__main__": 65 | main() 66 | -------------------------------------------------------------------------------- /src/flexrag/entrypoints/rebuild_index.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | from hydra.core.config_store import ConfigStore 3 | from omegaconf import OmegaConf 4 | 5 | from flexrag.retriever import DenseRetriever, DenseRetrieverConfig 6 | 7 | 8 | cs = ConfigStore.instance() 9 | cs.store(name="default", node=DenseRetrieverConfig) 10 | 11 | 12 | @hydra.main(version_base="1.3", config_path=None, config_name="default") 13 | def main(cfg: DenseRetrieverConfig): 14 | # rebuild index 15 | retriever = DenseRetriever(cfg, no_check=True) 16 | retriever.build_index(rebuild=True) 17 | return 18 | 19 | 20 | if __name__ == "__main__": 21 | main() 22 | -------------------------------------------------------------------------------- /src/flexrag/entrypoints/run_assistant.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import os 4 | import sys 5 | from dataclasses import dataclass, field 6 | from typing import Optional 7 | 8 | import hydra 9 | from hydra.core.config_store import ConfigStore 10 | from omegaconf import OmegaConf 11 | 12 | from flexrag.assistant import ASSISTANTS 13 | from flexrag.common_dataclass import RetrievedContext 14 | from flexrag.datasets import RAGEvalDataset, RAGEvalDatasetConfig 15 | from flexrag.metrics import Evaluator, EvaluatorConfig 16 | from flexrag.utils import LOGGER_MANAGER, SimpleProgressLogger, load_user_module 17 | 18 | # load user modules before loading config 19 | for arg in sys.argv: 20 | if arg.startswith("user_module="): 21 | load_user_module(arg.split("=")[1]) 22 | sys.argv.remove(arg) 23 | 24 | 25 | AssistantConfig = ASSISTANTS.make_config() 26 | 27 | 28 | @dataclass 29 | class Config(AssistantConfig, RAGEvalDatasetConfig): 30 | eval_config: EvaluatorConfig = field(default_factory=EvaluatorConfig) # fmt: skip 31 | log_interval: int = 10 32 | output_path: Optional[str] = None 33 | 34 | 35 | cs = ConfigStore.instance() 36 | cs.store(name="default", node=Config) 37 | logger = LOGGER_MANAGER.get_logger("run_assistant") 38 | 39 | 40 | @hydra.main(version_base="1.3", config_path=None, config_name="default") 41 | def main(config: Config): 42 | # load dataset 43 | testset = RAGEvalDataset(config) 44 | 45 | # load assistant 46 | assistant = ASSISTANTS.load(config) 47 | 48 | # prepare output paths 49 | if config.output_path is not None: 50 | if not os.path.exists(config.output_path): 51 | os.makedirs(config.output_path) 52 | details_path = os.path.join(config.output_path, "details.jsonl") 53 | eval_score_path = os.path.join(config.output_path, "eval_score.json") 54 | config_path = os.path.join(config.output_path, "config.yaml") 55 | log_path = os.path.join(config.output_path, "log.txt") 56 | else: 57 | details_path = os.devnull 58 | eval_score_path = os.devnull 59 | config_path = os.devnull 60 | log_path = os.devnull 61 | 62 | # save config and set logger 63 | with open(config_path, "w", encoding="utf-8") as f: 64 | OmegaConf.save(config, f) 65 | handler = logging.FileHandler(log_path) 66 | LOGGER_MANAGER.add_handler(handler) 67 | logger.debug(f"Configs:\n{OmegaConf.to_yaml(config)}") 68 | 69 | # search and generate 70 | p_logger = SimpleProgressLogger(logger, interval=config.log_interval) 71 | questions = [] 72 | golden_answers = [] 73 | golden_contexts = [] 74 | responses = [] 75 | contexts: list[list[RetrievedContext]] = [] 76 | with open(details_path, "w", encoding="utf-8") as f: 77 | for item in testset: 78 | questions.append(item.question) 79 | golden_answers.append(item.golden_answers) 80 | golden_contexts.append(item.golden_contexts) 81 | response, ctxs, metadata = assistant.answer(question=item.question) 82 | responses.append(response) 83 | contexts.append(ctxs) 84 | json.dump( 85 | { 86 | "question": item.question, 87 | "golden": item.golden_answers, 88 | "golden_contexts": item.golden_contexts, 89 | "metadata_test": item.meta_data, 90 | "response": response, 91 | "contexts": ctxs, 92 | "metadata": metadata, 93 | }, 94 | f, 95 | ensure_ascii=False, 96 | ) 97 | f.write("\n") 98 | p_logger.update(desc="Searching") 99 | 100 | # evaluate 101 | evaluator = Evaluator(config.eval_config) 102 | resp_score, resp_score_detail = evaluator.evaluate( 103 | questions=questions, 104 | responses=responses, 105 | golden_responses=golden_answers, 106 | retrieved_contexts=contexts, 107 | golden_contexts=golden_contexts, 108 | log=True, 109 | ) 110 | with open(eval_score_path, "w", encoding="utf-8") as f: 111 | json.dump( 112 | { 113 | "eval_scores": resp_score, 114 | "eval_details": resp_score_detail, 115 | }, 116 | f, 117 | indent=4, 118 | ensure_ascii=False, 119 | ) 120 | return 121 | 122 | 123 | if __name__ == "__main__": 124 | main() 125 | -------------------------------------------------------------------------------- /src/flexrag/entrypoints/run_retriever.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import os 4 | import sys 5 | from dataclasses import dataclass, field 6 | from typing import Optional 7 | 8 | import hydra 9 | from hydra.core.config_store import ConfigStore 10 | from omegaconf import OmegaConf 11 | 12 | from flexrag.common_dataclass import Context, RetrievedContext 13 | from flexrag.datasets import MTEBDataset, MTEBDatasetConfig 14 | from flexrag.metrics import Evaluator, EvaluatorConfig 15 | from flexrag.retriever import RETRIEVERS 16 | from flexrag.utils import LOGGER_MANAGER, SimpleProgressLogger, load_user_module 17 | 18 | # load user modules before loading config 19 | for arg in sys.argv: 20 | if arg.startswith("user_module="): 21 | load_user_module(arg.split("=")[1]) 22 | sys.argv.remove(arg) 23 | 24 | RetrieverConfig = RETRIEVERS.make_config(config_name="RetrieverConfig") 25 | 26 | 27 | @dataclass 28 | class Config(RetrieverConfig, MTEBDatasetConfig): 29 | output_path: Optional[str] = None 30 | eval_config: EvaluatorConfig = field(default_factory=EvaluatorConfig) # fmt: skip 31 | log_interval: int = 10 32 | 33 | 34 | cs = ConfigStore.instance() 35 | cs.store(name="default", node=Config) 36 | logger = LOGGER_MANAGER.get_logger("run_retriever") 37 | 38 | 39 | @hydra.main(version_base="1.3", config_path=None, config_name="default") 40 | def main(config: Config): 41 | # load dataset 42 | testset = MTEBDataset(config) 43 | 44 | # load assistant 45 | retriever = RETRIEVERS.load(config) 46 | 47 | # prepare output paths 48 | if config.output_path is not None: 49 | if not os.path.exists(config.output_path): 50 | os.makedirs(config.output_path) 51 | details_path = os.path.join(config.output_path, "details.jsonl") 52 | eval_score_path = os.path.join(config.output_path, "eval_score.json") 53 | config_path = os.path.join(config.output_path, "config.yaml") 54 | log_path = os.path.join(config.output_path, "log.txt") 55 | else: 56 | details_path = os.devnull 57 | eval_score_path = os.devnull 58 | config_path = os.devnull 59 | log_path = os.devnull 60 | 61 | # save config and set logger 62 | with open(config_path, "w", encoding="utf-8") as f: 63 | OmegaConf.save(config, f) 64 | handler = logging.FileHandler(log_path) 65 | LOGGER_MANAGER.add_handler(handler) 66 | logger.debug(f"Configs:\n{OmegaConf.to_yaml(config)}") 67 | 68 | # search and generate 69 | p_logger = SimpleProgressLogger(logger, interval=config.log_interval) 70 | questions = [] 71 | goldens: list[list[Context]] = [] 72 | retrieved: list[list[RetrievedContext]] = [] 73 | with open(details_path, "w", encoding="utf-8") as f: 74 | for item in testset: 75 | questions.append(item.question) 76 | goldens.append(item.contexts) 77 | ctxs = retriever.search(query=item.question)[0] 78 | retrieved.append(ctxs) 79 | json.dump( 80 | { 81 | "question": item.question, 82 | "golden_contexts": item.contexts, 83 | "metadata": item.meta_data, 84 | "contexts": ctxs, 85 | }, 86 | f, 87 | ensure_ascii=False, 88 | ) 89 | f.write("\n") 90 | p_logger.update(desc="Searching") 91 | 92 | # evaluate 93 | evaluator = Evaluator(config.eval_config) 94 | resp_score, resp_score_detail = evaluator.evaluate( 95 | questions=questions, 96 | retrieved_contexts=retrieved, 97 | golden_contexts=goldens, 98 | log=True, 99 | ) 100 | with open(eval_score_path, "w", encoding="utf-8") as f: 101 | json.dump( 102 | { 103 | "eval_scores": resp_score, 104 | "eval_details": resp_score_detail, 105 | }, 106 | f, 107 | indent=4, 108 | ensure_ascii=False, 109 | ) 110 | return 111 | 112 | 113 | if __name__ == "__main__": 114 | main() 115 | -------------------------------------------------------------------------------- /src/flexrag/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from .generation_metrics import ( 2 | BLEU, 3 | BLEUConfig, 4 | Rouge, 5 | chrF, 6 | chrFConfig, 7 | ) 8 | from .matching_metrics import ( 9 | F1, 10 | Accuracy, 11 | ExactMatch, 12 | MatchingMetrics, 13 | Precision, 14 | Recall, 15 | ) 16 | from .metrics_base import MetricsBase 17 | from .retrieval_metrics import ( 18 | SuccessRate, 19 | SuccessRateConfig, 20 | RetrievalRecall, 21 | RetrievalRecallConfig, 22 | RetrievalPrecision, 23 | RetrievalPrecisionConfig, 24 | RetrievalMAP, 25 | RetrievalMAPConfig, 26 | RetrievalNDCG, 27 | RetrievalNDCGConfig, 28 | ) 29 | 30 | from .evaluator import Evaluator, EvaluatorConfig # isort: skip 31 | 32 | __all__ = [ 33 | "MetricsBase", 34 | "MatchingMetrics", 35 | "Accuracy", 36 | "ExactMatch", 37 | "F1", 38 | "Recall", 39 | "Precision", 40 | "BLEU", 41 | "BLEUConfig", 42 | "Rouge", 43 | "chrF", 44 | "chrFConfig", 45 | "SuccessRate", 46 | "SuccessRateConfig", 47 | "RetrievalRecall", 48 | "RetrievalRecallConfig", 49 | "RetrievalPrecision", 50 | "RetrievalPrecisionConfig", 51 | "RetrievalMAP", 52 | "RetrievalMAPConfig", 53 | "RetrievalNDCG", 54 | "RetrievalNDCGConfig", 55 | "Evaluator", 56 | "EvaluatorConfig", 57 | ] 58 | -------------------------------------------------------------------------------- /src/flexrag/metrics/evaluator.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | 3 | from flexrag.common_dataclass import RetrievedContext 4 | from flexrag.text_process import TextProcessPipeline, TextProcessPipelineConfig 5 | from flexrag.utils import LOGGER_MANAGER 6 | 7 | from .metrics_base import METRICS, MetricsBase 8 | 9 | logger = LOGGER_MANAGER.get_logger("flexrag.metrics") 10 | MetricConfig = METRICS.make_config(allow_multiple=True) 11 | 12 | 13 | @dataclass 14 | class EvaluatorConfig(MetricConfig): 15 | round: int = 2 16 | response_preprocess: TextProcessPipelineConfig = field(default_factory=TextProcessPipelineConfig) # type: ignore 17 | 18 | 19 | class Evaluator: 20 | def __init__(self, cfg: EvaluatorConfig) -> None: 21 | self.metrics: dict[str, MetricsBase] = { 22 | name: metric for name, metric in zip(cfg.metrics_type, METRICS.load(cfg)) 23 | } 24 | self.response_pipeline = TextProcessPipeline(cfg.response_preprocess) 25 | self.round = cfg.round 26 | return 27 | 28 | def evaluate( 29 | self, 30 | questions: list[str] = None, 31 | responses: list[str] = None, 32 | golden_responses: list[list[str]] = None, 33 | retrieved_contexts: list[list[str | RetrievedContext]] = None, 34 | golden_contexts: list[list[str]] = None, 35 | log: bool = True, 36 | ): 37 | """Evaluate the generated responses against the ground truth responses. 38 | 39 | :param questions: A list of questions. Defaults to None. 40 | :param responses: A list of responses. Defaults to None. 41 | :param golden_responses: A list of golden responses. Defaults to None. 42 | :param retrieved_contexts: A list of retrieved contexts. Defaults to None. 43 | :param golden_contexts: A list of golden contexts. Defaults to None. 44 | :param log: Whether to log the evaluation results. Defaults to True. 45 | :type questions: list[str], optional 46 | :type responses: list[str], optional 47 | :type golden_responses: list[list[str]], optional 48 | :type retrieved_contexts: list[list[str | RetrievedContext]], optional 49 | :type golden_contexts: list[list[str]], optional 50 | :type log: bool, optional 51 | :return: The evaluation results and the evaluation details. 52 | :rtype: tuple[dict[str, float], dict[str, Any]] 53 | """ 54 | # check the input arguments 55 | not_none_args = [ 56 | arg 57 | for arg in [ 58 | questions, 59 | responses, 60 | golden_responses, 61 | retrieved_contexts, 62 | golden_contexts, 63 | ] 64 | if arg is not None 65 | ] 66 | assert len(not_none_args) > 1, "At least one argument must be provided." 67 | assert all( 68 | len(i) == len(not_none_args[0]) for i in not_none_args 69 | ), "All arguments must have the same length." 70 | 71 | # evaluate 72 | evaluation_results = {} 73 | evaluation_details = {} 74 | if responses is not None: 75 | responses = [self.response_pipeline(res) for res in responses] 76 | if golden_responses is not None: 77 | golden_responses = [ 78 | [self.response_pipeline(g) for g in golds] for golds in golden_responses 79 | ] 80 | for metric in self.metrics: 81 | metric = str(metric) # make json serializable 82 | r, r_detail = self.metrics[metric]( 83 | questions=questions, 84 | responses=responses, 85 | golden_responses=golden_responses, 86 | retrieved_contexts=retrieved_contexts, 87 | golden_contexts=golden_contexts, 88 | ) 89 | if log: 90 | for name, score in r.items(): 91 | logger.info(f"{name}: {score*100:.{self.round}f}%") 92 | evaluation_results.update(r) 93 | evaluation_details[metric] = r_detail 94 | return evaluation_results, evaluation_details 95 | -------------------------------------------------------------------------------- /src/flexrag/metrics/lib_rel.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | 6 | namespace py = pybind11; 7 | 8 | std::vector> get_contain_map(const std::vector& evs, const std::vector& rets) { 9 | std::vector> results; 10 | results.reserve(rets.size()); 11 | 12 | for (const auto& ret : rets) { 13 | std::vector result_row; 14 | result_row.reserve(evs.size()); 15 | for (const auto& ev : evs) { 16 | result_row.push_back(ret.find(ev) != std::string::npos); 17 | } 18 | results.push_back(std::move(result_row)); 19 | } 20 | 21 | return results; 22 | } 23 | 24 | PYBIND11_MODULE(lib_rel, m) { 25 | m.def("get_contain_map", &get_contain_map, "Get contain map."); 26 | } 27 | -------------------------------------------------------------------------------- /src/flexrag/metrics/matching_metrics.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from collections import Counter 3 | 4 | from flexrag.utils import TIME_METER 5 | 6 | from .metrics_base import MetricsBase, METRICS 7 | 8 | 9 | class MatchingMetrics(MetricsBase): 10 | name: str 11 | 12 | @abstractmethod 13 | def compute_item(self, golds: list[str], response: str) -> float: 14 | return 15 | 16 | @TIME_METER("metrics.matching_score") 17 | def compute( 18 | self, responses: list[str], golden_responses: list[list[str]], **kwargs 19 | ) -> tuple[float, dict[str, list[float]]]: 20 | matching_list = [] 21 | for golds, response in zip(golden_responses, responses): 22 | matching_list.append(self.compute_item(golds, response)) 23 | matching_score = sum(matching_list) / len(matching_list) 24 | return {self.name: matching_score}, {"item_score": matching_list} 25 | 26 | 27 | @METRICS("generation_em") 28 | class ExactMatch(MatchingMetrics): 29 | """ExactMatch metric computes if any of the golden responses is exactly the same as the predicted response.""" 30 | 31 | name = "generation_em" 32 | 33 | def compute_item(self, golds: list[str], response: str) -> float: 34 | return float(response in golds) 35 | 36 | 37 | @METRICS("generation_accuracy") 38 | class Accuracy(MatchingMetrics): 39 | """Accuracy metric computes if any of the golden responses is in the predicted response.""" 40 | 41 | name = "generation_accuracy" 42 | 43 | def compute_item(self, golds: list[str], response: str) -> float: 44 | return float(any(gold in response for gold in golds)) 45 | 46 | 47 | def f1_recall_precision(golds: list[str], response: str) -> tuple[float, float, float]: 48 | true_counters = [Counter(gold.split()) for gold in golds] 49 | pred_counter = Counter(response.split()) 50 | precision = 0.0 51 | recall = 0.0 52 | f1 = 0.0 53 | for gold in true_counters: 54 | common = sum((gold & pred_counter).values()) 55 | if common == 0: 56 | continue 57 | p = 1.0 * common / sum(pred_counter.values()) 58 | r = 1.0 * common / sum(gold.values()) 59 | f1_ = (2 * p * r) / (p + r) 60 | precision = max(p, precision) 61 | recall = max(r, recall) 62 | f1 = max(f1, f1_) 63 | return f1, recall, precision 64 | 65 | 66 | @METRICS("generation_f1") 67 | class F1(MatchingMetrics): 68 | """F1 metric computes the F1 score of the predicted response against the golden responses.""" 69 | 70 | name = "generation_f1" 71 | 72 | def compute_item(self, golds: list[str], response: str) -> float: 73 | return f1_recall_precision(golds, response)[0] 74 | 75 | 76 | @METRICS("generation_recall") 77 | class Recall(MatchingMetrics): 78 | """Recall metric computes the recall of the predicted response against the golden responses.""" 79 | 80 | name = "generation_recall" 81 | 82 | def compute_item(self, golds: list[str], response: str) -> float: 83 | return f1_recall_precision(golds, response)[1] 84 | 85 | 86 | @METRICS("generation_precision") 87 | class Precision(MatchingMetrics): 88 | """Precision metric computes the precision of the predicted response against the golden responses.""" 89 | 90 | name = "generation_precision" 91 | 92 | def compute_item(self, golds: list[str], response: str) -> float: 93 | return f1_recall_precision(golds, response)[2] 94 | -------------------------------------------------------------------------------- /src/flexrag/metrics/metrics_base.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | from flexrag.common_dataclass import RetrievedContext 4 | from flexrag.utils import Register 5 | 6 | 7 | class MetricsBase(ABC): 8 | def __call__( 9 | self, 10 | questions: list[str] = None, 11 | responses: list[str] = None, 12 | golden_responses: list[list[str]] = None, 13 | retrieved_contexts: list[list[str | RetrievedContext]] = None, 14 | golden_contexts: list[list[str]] = None, 15 | ) -> tuple[dict[str, float], dict]: 16 | """ 17 | Compute the metric value. 18 | 19 | :param questions: A list of questions. Defaults to None. 20 | :param responses: A list of responses. Defaults to None. 21 | :param golden_responses: A list of golden responses. Defaults to None. 22 | :param retrieved_contexts: A list of retrieved contexts. Defaults to None. 23 | :param golden_contexts: A list of golden contexts. Defaults to None. 24 | :type questions: list[str], optional 25 | :type responses: list[str], optional 26 | :type golden_responses: list[list[str]], optional 27 | :type retrieved_contexts: list[list[str | RetrievedContext]], optional 28 | :type golden_contexts: list[list[str]], optional 29 | :return: The metric scores and the metadata of the metric. 30 | :rtype: tuple[dict[str, float], dict] 31 | """ 32 | return self.compute( 33 | questions=questions, 34 | responses=responses, 35 | golden_responses=golden_responses, 36 | retrieved_contexts=retrieved_contexts, 37 | golden_contexts=golden_contexts, 38 | ) 39 | 40 | @abstractmethod 41 | def compute( 42 | self, 43 | questions: list[str] = None, 44 | responses: list[str] = None, 45 | golden_responses: list[list[str]] = None, 46 | retrieved_contexts: list[list[str | RetrievedContext]] = None, 47 | golden_contexts: list[list[str]] = None, 48 | ) -> tuple[dict[str, float], dict]: 49 | """ 50 | Compute the metric value. 51 | 52 | :param questions: A list of questions. Defaults to None. 53 | :param responses: A list of responses. Defaults to None. 54 | :param golden_responses: A list of golden responses. Defaults to None. 55 | :param retrieved_contexts: A list of retrieved contexts. Defaults to None. 56 | :param golden_contexts: A list of golden contexts. Defaults to None. 57 | :type questions: list[str], optional 58 | :type responses: list[str], optional 59 | :type golden_responses: list[list[str]], optional 60 | :type retrieved_contexts: list[list[str | RetrievedContext]], optional 61 | :type golden_contexts: list[list[str]], optional 62 | :return: The metric scores and the metadata of the metric. 63 | :rtype: tuple[dict[str, float], dict] 64 | """ 65 | return 66 | 67 | 68 | METRICS = Register[MetricsBase]("metrics") 69 | -------------------------------------------------------------------------------- /src/flexrag/metrics/xfinder.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | 3 | from flexrag.utils import TIME_METER, Choices 4 | 5 | from .xfinder_utils import Evaluator 6 | from .metrics_base import METRICS, MetricsBase 7 | 8 | 9 | @dataclass 10 | class xFinderConfig: 11 | model_type: Choices(["qwen", "llama"]) = "qwen" # type: ignore 12 | model_path: str = "IAAR-Shanghai/xFinder-qwen1505" 13 | answer_type: Choices( # type: ignore 14 | ["math", "short_text", "categorical_label", "alphabet_option"] 15 | ) = "short_text" 16 | temperature: float = 0.7 17 | max_tokens: int = 100 18 | device_id: list[int] = field(default_factory=list) 19 | 20 | 21 | @METRICS("generation_xfinder") 22 | class xFinder(MetricsBase): 23 | def __init__(self, config: xFinderConfig): 24 | if config.model_type == "qwen": 25 | model_name = "xFinder-qwen1505" 26 | else: 27 | model_name = "xFinder-llama38it" 28 | self.evaluator = Evaluator( 29 | model_name=model_name, 30 | model_path_or_url=config.model_path, 31 | temperature=config.temperature, 32 | max_tokens=config.max_tokens, 33 | device_id=config.device_id, 34 | ) 35 | self.answer_type = config.answer_type 36 | return 37 | 38 | @TIME_METER("metrics.xfinder_score") 39 | def compute( 40 | self, 41 | questions: list[str], 42 | responses: list[str], 43 | golden_responses: list[list[str]], 44 | choices: list[list[str]], 45 | **kwargs 46 | ) -> tuple[float, dict[str, list[float]]]: 47 | results = [] 48 | for question, response, goldens, choice in zip( 49 | questions, responses, golden_responses, choices 50 | ): 51 | self.evaluator.evaluate_single_item( 52 | question=question, 53 | llm_output=response, 54 | answer_type=self.answer_type, 55 | correct_answer=goldens[0], 56 | answer_range=",".join(choice), 57 | ) 58 | 59 | correct_count = sum(result[-1] for result in results) 60 | accuracy = correct_count / max(1, len(results)) if results else 0 61 | return accuracy 62 | -------------------------------------------------------------------------------- /src/flexrag/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .anthropic_model import AnthropicGenerator, AnthropicGeneratorConfig 2 | from .cohere_model import ( 3 | CohereEncoder, 4 | CohereEncoderConfig, 5 | ) 6 | from .hf_model import ( 7 | HFModelConfig, 8 | HFEncoder, 9 | HFEncoderConfig, 10 | HFGenerator, 11 | HFGeneratorConfig, 12 | HFClipEncoder, 13 | HFClipEncoderConfig, 14 | HFVLMGenerator, 15 | HFVLMGeneratorConfig, 16 | ) 17 | from .jina_model import JinaEncoder, JinaEncoderConfig 18 | from .llamacpp_model import LlamacppGenerator, LlamacppGeneratorConfig 19 | from .model_base import ( 20 | EncoderBase, 21 | GenerationConfig, 22 | GeneratorBase, 23 | VLMGeneratorBase, 24 | GENERATORS, 25 | ENCODERS, 26 | ) 27 | from .ollama_model import ( 28 | OllamaGenerator, 29 | OllamaGeneratorConfig, 30 | OllamaEncoder, 31 | OllamaEncoderConfig, 32 | ) 33 | from .openai_model import ( 34 | OpenAIConfig, 35 | OpenAIEncoder, 36 | OpenAIEncoderConfig, 37 | OpenAIGenerator, 38 | OpenAIGeneratorConfig, 39 | ) 40 | from .vllm_model import VLLMGenerator, VLLMGeneratorConfig 41 | from .sentence_transformers_model import ( 42 | SentenceTransformerEncoder, 43 | SentenceTransformerEncoderConfig, 44 | ) 45 | 46 | 47 | GeneratorConfig = GENERATORS.make_config(config_name="GeneratorConfig") 48 | EncoderConfig = ENCODERS.make_config(config_name="EncoderConfig", default=None) 49 | 50 | 51 | __all__ = [ 52 | "GeneratorBase", 53 | "VLMGeneratorBase", 54 | "GenerationConfig", 55 | "EncoderBase", 56 | "AnthropicGenerator", 57 | "AnthropicGeneratorConfig", 58 | "HFModelConfig", 59 | "HFGenerator", 60 | "HFGeneratorConfig", 61 | "HFEncoder", 62 | "HFEncoderConfig", 63 | "HFClipEncoder", 64 | "HFClipEncoderConfig", 65 | "HFVLMGenerator", 66 | "HFVLMGeneratorConfig", 67 | "OllamaGenerator", 68 | "OllamaGeneratorConfig", 69 | "OllamaEncoder", 70 | "OllamaEncoderConfig", 71 | "OpenAIGenerator", 72 | "OpenAIGeneratorConfig", 73 | "OpenAIConfig", 74 | "OpenAIEncoder", 75 | "OpenAIEncoderConfig", 76 | "VLLMGenerator", 77 | "VLLMGeneratorConfig", 78 | "LlamacppGenerator", 79 | "LlamacppGeneratorConfig", 80 | "JinaEncoder", 81 | "JinaEncoderConfig", 82 | "CohereEncoder", 83 | "CohereEncoderConfig", 84 | "SentenceTransformerEncoder", 85 | "SentenceTransformerEncoderConfig", 86 | "GENERATORS", 87 | "ENCODERS", 88 | ] 89 | -------------------------------------------------------------------------------- /src/flexrag/models/cohere_model.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import os 3 | from dataclasses import dataclass 4 | from typing import Optional 5 | 6 | import httpx 7 | import numpy as np 8 | from numpy import ndarray 9 | from omegaconf import MISSING 10 | 11 | from flexrag.utils import TIME_METER, Choices 12 | 13 | from .model_base import ENCODERS, EncoderBase 14 | 15 | 16 | @dataclass 17 | class CohereEncoderConfig: 18 | """Configuration for CohereEncoder. 19 | 20 | :param model: The model to use. Default is "embed-multilingual-v3.0". 21 | :type model: str 22 | :param input_type: Specifies the type of input passed to the model. Required for embedding models v3 and higher. Default is "search_document". Available options are "search_document", "search_query", "classification", "clustering", "image". 23 | :type input_type: str 24 | :param base_url: The base URL of the API. Default is None. 25 | :type base_url: Optional[str] 26 | :param api_key: The API key to use. Default is os.environ.get("COHERE_API_KEY", MISSING). 27 | :type api_key: str 28 | :param proxy: The proxy to use. Default is None. 29 | :type proxy: Optional[str] 30 | """ 31 | 32 | model: str = "embed-multilingual-v3.0" 33 | input_type: Choices( # type: ignore 34 | [ 35 | "search_document", 36 | "search_query", 37 | "classification", 38 | "clustering", 39 | "image", 40 | ] 41 | ) = "search_document" 42 | base_url: Optional[str] = None 43 | api_key: str = os.environ.get("COHERE_API_KEY", MISSING) 44 | proxy: Optional[str] = None 45 | 46 | 47 | @ENCODERS("cohere", config_class=CohereEncoderConfig) 48 | class CohereEncoder(EncoderBase): 49 | def __init__(self, cfg: CohereEncoderConfig): 50 | from cohere import ClientV2 51 | 52 | if cfg.proxy is not None: 53 | httpx_client = httpx.Client(proxies=cfg.proxy) 54 | else: 55 | httpx_client = None 56 | self.client = ClientV2( 57 | api_key=cfg.api_key, 58 | base_url=cfg.base_url, 59 | httpx_client=httpx_client, 60 | ) 61 | self.model = cfg.model 62 | self.input_type = cfg.input_type 63 | return 64 | 65 | @TIME_METER("cohere_encode") 66 | def _encode(self, texts: list[str]) -> ndarray: 67 | r = self.client.embed( 68 | texts=texts, 69 | model=self.model, 70 | input_type=self.input_type, 71 | embedding_types=["float"], 72 | ) 73 | embeddings = r.embeddings.float 74 | return np.array(embeddings) 75 | 76 | @TIME_METER("cohere_encode") 77 | async def async_encode(self, texts: list[str]): 78 | task = asyncio.create_task( 79 | asyncio.to_thread( 80 | self.client.embed, 81 | texts=texts, 82 | model=self.model, 83 | input_type=self.input_type, 84 | embedding_types=["float"], 85 | ) 86 | ) 87 | embeddings = (await task).embeddings.float 88 | return np.array(embeddings) 89 | 90 | @property 91 | def embedding_size(self) -> int: 92 | return self._data_template["dimension"] 93 | -------------------------------------------------------------------------------- /src/flexrag/models/jina_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dataclasses import dataclass 3 | from typing import Optional 4 | 5 | import httpx 6 | import numpy as np 7 | from numpy import ndarray 8 | from omegaconf import MISSING 9 | 10 | from flexrag.utils import TIME_METER, Choices 11 | 12 | from .model_base import EncoderBase, ENCODERS 13 | 14 | 15 | @dataclass 16 | class JinaEncoderConfig: 17 | """Configuration for JinaEncoder. 18 | 19 | :param model: The model to use. Default is "jina-embeddings-v3". 20 | :type model: str 21 | :param base_url: The base URL of the Jina embeddings API. Default is "https://api.jina.ai/v1/embeddings". 22 | :type base_url: str 23 | :param api_key: The API key for the Jina embeddings API. 24 | :type api_key: str 25 | :param dimensions: The dimension of the embeddings. Default is 1024. 26 | :type dimensions: int 27 | :param task: The task for the embeddings. Default is None. Available options are "retrieval.query", "retrieval.passage", "separation", "classification", and "text-matching". 28 | :type task: str 29 | :param proxy: The proxy to use. Defaults to None. 30 | :type proxy: Optional[str] 31 | """ 32 | 33 | model: str = "jina-embeddings-v3" 34 | base_url: str = "https://api.jina.ai/v1/embeddings" 35 | api_key: str = os.environ.get("JINA_API_KEY", MISSING) 36 | dimensions: int = 1024 37 | task: Optional[ 38 | Choices( # type: ignore 39 | [ 40 | "retrieval.query", 41 | "retrieval.passage", 42 | "separation", 43 | "classification", 44 | "text-matching", 45 | ] 46 | ) 47 | ] = None 48 | proxy: Optional[str] = None 49 | 50 | 51 | @ENCODERS("jina", config_class=JinaEncoderConfig) 52 | class JinaEncoder(EncoderBase): 53 | def __init__(self, cfg: JinaEncoderConfig): 54 | # prepare client 55 | self.client = httpx.Client( 56 | headers={ 57 | "Content-Type": "application/json", 58 | "Authorization": f"Bearer {cfg.api_key}", 59 | }, 60 | proxy=cfg.proxy, 61 | base_url=cfg.base_url, 62 | follow_redirects=True, 63 | ) 64 | self.async_client = httpx.AsyncClient( 65 | headers={ 66 | "Content-Type": "application/json", 67 | "Authorization": f"Bearer {cfg.api_key}", 68 | }, 69 | proxy=cfg.proxy, 70 | base_url=cfg.base_url, 71 | follow_redirects=True, 72 | ) 73 | # prepare template 74 | self.data_template = { 75 | "model": cfg.model, 76 | "task": cfg.task, 77 | "dimensions": cfg.dimensions, 78 | "late_chunking": False, 79 | "embedding_type": "float", 80 | "input": [], 81 | } 82 | return 83 | 84 | @TIME_METER("jina_encode") 85 | def _encode(self, texts: list[str]) -> ndarray: 86 | data = self.data_template.copy() 87 | data["input"] = texts 88 | response = self.client.post("", json=data) 89 | response.raise_for_status() 90 | embeddings = [i["embedding"] for i in response.json()["data"]] 91 | return np.array(embeddings) 92 | 93 | @TIME_METER("jina_encode") 94 | async def async_encode(self, texts: list[str]) -> ndarray: 95 | data = self.data_template.copy() 96 | data["input"] = texts 97 | response = await self.async_client.post("", json=data) 98 | embeddings = [i["embedding"] for i in response.json()["data"]] 99 | return np.array(embeddings) 100 | 101 | @property 102 | def embedding_size(self) -> int: 103 | return self.data_template["dimension"] 104 | -------------------------------------------------------------------------------- /src/flexrag/models/sentence_transformers_model.py: -------------------------------------------------------------------------------- 1 | import math 2 | from dataclasses import dataclass, field 3 | from typing import Any, Optional 4 | 5 | import numpy as np 6 | from omegaconf import MISSING 7 | 8 | from flexrag.utils import TIME_METER 9 | 10 | from .model_base import ENCODERS, EncoderBase 11 | 12 | 13 | @dataclass 14 | class SentenceTransformerEncoderConfig: 15 | """Configuration for SentenceTransformerEncoder. 16 | 17 | :param model_path: The path to the model. Required. 18 | :type model_path: str 19 | :param device_id: The device id to use. [] for CPU. Defaults to []. 20 | :type device_id: list[int] 21 | :param trust_remote_code: Whether to trust remote code. Defaults to False. 22 | :type trust_remote_code: bool 23 | :param task: The task to use. Defaults to None. 24 | :type task: Optional[str] 25 | :param prompt_name: The prompt name to use. Defaults to None. 26 | :type prompt_name: Optional[str] 27 | :param prompt: The prompt to use. Defaults to None. 28 | :type prompt: Optional[str] 29 | :param prompt_dict: The prompt dictionary to use. Defaults to None. 30 | :type prompt_dict: Optional[dict] 31 | :param normalize: Whether to normalize embeddings. Defaults to False. 32 | :type normalize: bool 33 | :param model_kwargs: Additional keyword arguments for loading the model. Defaults to {}. 34 | :type model_kwargs: dict[str, Any] 35 | """ 36 | 37 | model_path: str = MISSING 38 | device_id: list[int] = field(default_factory=list) 39 | trust_remote_code: bool = False 40 | task: Optional[str] = None 41 | prompt_name: Optional[str] = None 42 | prompt: Optional[str] = None 43 | prompt_dict: Optional[dict] = None 44 | normalize: bool = False 45 | model_kwargs: dict[str, Any] = field(default_factory=dict) 46 | 47 | 48 | @ENCODERS("sentence_transformer", config_class=SentenceTransformerEncoderConfig) 49 | class SentenceTransformerEncoder(EncoderBase): 50 | def __init__(self, config: SentenceTransformerEncoderConfig) -> None: 51 | super().__init__() 52 | from sentence_transformers import SentenceTransformer 53 | 54 | self.devices = config.device_id 55 | self.model = SentenceTransformer( 56 | model_name_or_path=config.model_path, 57 | device=f"cuda:{config.device_id[0]}" if config.device_id else "cpu", 58 | trust_remote_code=config.trust_remote_code, 59 | backend="torch", 60 | prompts=config.prompt_dict, 61 | model_kwargs=config.model_kwargs, 62 | ) 63 | if len(config.device_id) > 1: 64 | self.pool = self.model.start_multi_process_pool( 65 | target_devices=[f"cuda:{i}" for i in config.device_id] 66 | ) 67 | else: 68 | self.pool = None 69 | 70 | # set args 71 | self.prompt_name = config.prompt_name 72 | self.task = config.task 73 | self.prompt = config.prompt 74 | self.normalize = config.normalize 75 | return 76 | 77 | @TIME_METER("st_encode") 78 | def _encode(self, texts: list[str], **kwargs) -> np.ndarray: 79 | args = { 80 | "sentences": texts, 81 | "batch_size": len(texts), 82 | "show_progress_bar": False, 83 | "convert_to_numpy": True, 84 | "normalize_embeddings": self.normalize, 85 | } 86 | if kwargs.get("task", self.task) is not None: 87 | args["task"] = self.task 88 | if kwargs.get("prompt_name", self.prompt_name) is not None: 89 | args["prompt_name"] = self.prompt_name 90 | if kwargs.get("prompt", self.prompt) is not None: 91 | args["prompt"] = self.prompt 92 | if (len(texts) >= len(self.devices) * 8) and (self.pool is not None): 93 | args["pool"] = self.pool 94 | args["batch_size"] = math.ceil(args["batch_size"] / len(self.devices)) 95 | embeddings = self.model.encode_multi_process(**args) 96 | else: 97 | embeddings = self.model.encode(**args) 98 | return embeddings 99 | 100 | @property 101 | def embedding_size(self) -> int: 102 | return self.model.get_sentence_embedding_dimension() 103 | -------------------------------------------------------------------------------- /src/flexrag/prompt/__init__.py: -------------------------------------------------------------------------------- 1 | from .template import load_template, ChatTemplate, HFTemplate 2 | from .prompt_base import ChatPrompt, ChatTurn, MultiModelChatPrompt, MultiModelChatTurn 3 | 4 | 5 | __all__ = [ 6 | "ChatPrompt", 7 | "ChatTurn", 8 | "load_template", 9 | "ChatTemplate", 10 | "HFTemplate", 11 | "MultiModelChatPrompt", 12 | "MultiModelChatTurn", 13 | ] 14 | -------------------------------------------------------------------------------- /src/flexrag/ranker/__init__.py: -------------------------------------------------------------------------------- 1 | from .cohere_ranker import CohereRanker, CohereRankerConfig 2 | from .gpt_ranker import RankGPTRanker, RankGPTRankerConfig 3 | from .hf_ranker import ( 4 | HFColBertRanker, 5 | HFColBertRankerConfig, 6 | HFCrossEncoderRanker, 7 | HFCrossEncoderRankerConfig, 8 | HFSeq2SeqRanker, 9 | HFSeq2SeqRankerConfig, 10 | ) 11 | from .jina_ranker import JinaRanker, JinaRankerConfig 12 | from .mixedbread_ranker import MixedbreadRanker, MixedbreadRankerConfig 13 | from .voyage_ranker import VoyageRanker, VoyageRankerConfig 14 | 15 | from .ranker import RankerBase, RankerBaseConfig, RANKERS, RankingResult # isort: skip 16 | 17 | 18 | RankerConfig = RANKERS.make_config(config_name="RankerConfig", default=None) 19 | 20 | 21 | __all__ = [ 22 | "RankerBase", 23 | "RankerBaseConfig", 24 | "RANKERS", 25 | "RankerConfig", 26 | "RankingResult", 27 | "HFCrossEncoderRanker", 28 | "HFCrossEncoderRankerConfig", 29 | "HFSeq2SeqRanker", 30 | "HFSeq2SeqRankerConfig", 31 | "HFColBertRanker", 32 | "HFColBertRankerConfig", 33 | "CohereRanker", 34 | "CohereRankerConfig", 35 | "JinaRanker", 36 | "JinaRankerConfig", 37 | "MixedbreadRanker", 38 | "MixedbreadRankerConfig", 39 | "VoyageRanker", 40 | "VoyageRankerConfig", 41 | "RankGPTRanker", 42 | "RankGPTRankerConfig", 43 | ] 44 | -------------------------------------------------------------------------------- /src/flexrag/ranker/cohere_ranker.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from dataclasses import dataclass 3 | from typing import Optional 4 | 5 | import httpx 6 | import numpy as np 7 | from omegaconf import MISSING 8 | 9 | from flexrag.utils import TIME_METER 10 | 11 | from .ranker import RankerBase, RankerBaseConfig, RANKERS 12 | 13 | 14 | @dataclass 15 | class CohereRankerConfig(RankerBaseConfig): 16 | """The configuration for the Cohere ranker. 17 | 18 | :param model: the model name of the ranker. Default is "rerank-multilingual-v3.0". 19 | :type model: str 20 | :param base_url: the base URL of the Cohere ranker. Default is None. 21 | :type base_url: Optional[str] 22 | :param api_key: the API key for the Cohere ranker. Required. 23 | :type api_key: str 24 | :param proxy: the proxy for the request. Default is None. 25 | :type proxy: Optional[str] 26 | """ 27 | 28 | model: str = "rerank-multilingual-v3.0" 29 | base_url: Optional[str] = None 30 | api_key: str = MISSING 31 | proxy: Optional[str] = None 32 | 33 | 34 | @RANKERS("cohere", config_class=CohereRankerConfig) 35 | class CohereRanker(RankerBase): 36 | """CohereRanker: The ranker based on the Cohere API.""" 37 | 38 | def __init__(self, cfg: CohereRankerConfig) -> None: 39 | super().__init__(cfg) 40 | from cohere import Client 41 | 42 | if cfg.proxy is not None: 43 | httpx_client = httpx.Client(proxies=cfg.proxy) 44 | else: 45 | httpx_client = None 46 | self.client = Client( 47 | api_key=cfg.api_key, base_url=cfg.base_url, httpx_client=httpx_client 48 | ) 49 | self.model = cfg.model 50 | return 51 | 52 | @TIME_METER("cohere_rank") 53 | def _rank(self, query: str, candidates: list[str]) -> tuple[np.ndarray, np.ndarray]: 54 | result = self.client.rerank( 55 | query=query, 56 | documents=candidates, 57 | model=self.model, 58 | top_n=len(candidates), 59 | ) 60 | scores = [i.relevance_score for i in result.results] 61 | return None, scores 62 | 63 | @TIME_METER("cohere_rank") 64 | async def _async_rank(self, query: str, candidates: list[str]): 65 | result = await asyncio.create_task( 66 | asyncio.to_thread( 67 | self.client.rerank, 68 | query=query, 69 | documents=candidates, 70 | model=self.model, 71 | top_n=len(candidates), 72 | ) 73 | ) 74 | scores = [i.relevance_score for i in result.results] 75 | return None, scores 76 | -------------------------------------------------------------------------------- /src/flexrag/ranker/jina_ranker.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dataclasses import dataclass 3 | from typing import Optional 4 | 5 | import numpy as np 6 | import httpx 7 | from omegaconf import MISSING 8 | 9 | from flexrag.utils import TIME_METER 10 | 11 | from .ranker import RankerBase, RankerBaseConfig, RANKERS 12 | 13 | 14 | @dataclass 15 | class JinaRankerConfig(RankerBaseConfig): 16 | """The configuration for the Jina ranker. 17 | 18 | :param model: the model name of the ranker. Default is "jina-reranker-v2-base-multilingual". 19 | :type model: str 20 | :param base_url: the base URL of the Jina ranker. Default is "https://api.jina.ai/v1/rerank". 21 | :type base_url: str 22 | :param api_key: the API key for the Jina ranker. Required. 23 | :type api_key: str 24 | :param proxy: The proxy to use. Defaults to None. 25 | :type proxy: Optional[str] 26 | """ 27 | 28 | model: str = "jina-reranker-v2-base-multilingual" 29 | base_url: str = "https://api.jina.ai/v1/rerank" 30 | api_key: str = os.environ.get("JINA_API_KEY", MISSING) 31 | proxy: Optional[str] = None 32 | 33 | 34 | @RANKERS("jina", config_class=JinaRankerConfig) 35 | class JinaRanker(RankerBase): 36 | """JinaRanker: The ranker based on the Jina API.""" 37 | 38 | def __init__(self, cfg: JinaRankerConfig) -> None: 39 | super().__init__(cfg) 40 | # prepare client 41 | self.client = httpx.Client( 42 | base_url=cfg.base_url, 43 | headers={ 44 | "Content-Type": "application/json", 45 | "Authorization": f"Bearer {cfg.api_key}", 46 | }, 47 | proxy=cfg.proxy, 48 | follow_redirects=True, 49 | ) 50 | self.async_client = httpx.AsyncClient( 51 | base_url=cfg.base_url, 52 | headers={ 53 | "Content-Type": "application/json", 54 | "Authorization": f"Bearer {cfg.api_key}", 55 | }, 56 | proxy=cfg.proxy, 57 | follow_redirects=True, 58 | ) 59 | 60 | # prepare data template 61 | self.data_template = { 62 | "model": cfg.model, 63 | "query": "", 64 | "top_n": 0, 65 | "documents": [], 66 | } 67 | return 68 | 69 | @TIME_METER("jina_rank") 70 | def _rank(self, query: str, candidates: list[str]) -> tuple[np.ndarray, np.ndarray]: 71 | data = self.data_template.copy() 72 | data["query"] = query 73 | data["documents"] = candidates 74 | data["top_n"] = len(candidates) 75 | response = self.client.post("", json=data) 76 | response.raise_for_status() 77 | scores = [i["relevance_score"] for i in response.json()["results"]] 78 | return None, scores 79 | 80 | @TIME_METER("jina_rank") 81 | async def _async_rank( 82 | self, query: str, candidates: list[str] 83 | ) -> tuple[np.ndarray, np.ndarray]: 84 | data = self.data_template.copy() 85 | data["query"] = query 86 | data["documents"] = candidates 87 | data["top_n"] = len(candidates) 88 | response = await self.async_client.post("", json=data) 89 | response.raise_for_status() 90 | scores = [i["relevance_score"] for i in response.json()["results"]] 91 | return None, scores 92 | -------------------------------------------------------------------------------- /src/flexrag/ranker/mixedbread_ranker.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from dataclasses import dataclass 3 | from typing import Optional 4 | 5 | import httpx 6 | import numpy as np 7 | from omegaconf import MISSING 8 | 9 | from flexrag.utils import TIME_METER 10 | 11 | from .ranker import RankerBase, RankerBaseConfig, RANKERS 12 | 13 | 14 | @dataclass 15 | class MixedbreadRankerConfig(RankerBaseConfig): 16 | """The configuration for the Mixedbread ranker. 17 | 18 | :param model: the model name of the ranker. Default is "mxbai-rerank-large-v1". 19 | :type model: str 20 | :param api_key: the API key for the Mixedbread ranker. Required. 21 | :type api_key: str 22 | :param base_url: the base URL of the Mixedbread ranker. Default is None. 23 | :type base_url: Optional[str] 24 | :param proxy: the proxy for the request. Default is None. 25 | :type proxy: Optional[str] 26 | """ 27 | 28 | model: str = "mxbai-rerank-large-v1" 29 | base_url: Optional[str] = None 30 | api_key: str = MISSING 31 | proxy: Optional[str] = None 32 | 33 | 34 | @RANKERS("mixedbread", config_class=MixedbreadRankerConfig) 35 | class MixedbreadRanker(RankerBase): 36 | """MixedbreadRanker: The ranker based on the Mixedbread API.""" 37 | 38 | def __init__(self, cfg: MixedbreadRankerConfig) -> None: 39 | super().__init__(cfg) 40 | from mixedbread_ai.client import MixedbreadAI 41 | 42 | if cfg.proxy is not None: 43 | httpx_client = httpx.Client(proxies=cfg.proxy) 44 | else: 45 | httpx_client = None 46 | self.client = MixedbreadAI( 47 | api_key=cfg.api_key, base_url=cfg.base_url, httpx_client=httpx_client 48 | ) 49 | self.model = cfg.model 50 | return 51 | 52 | @TIME_METER("mixedbread_rank") 53 | def _rank(self, query: str, candidates: list[str]) -> tuple[np.ndarray, np.ndarray]: 54 | result = self.client.reranking( 55 | query=query, 56 | input=candidates, 57 | model=self.model, 58 | top_k=len(candidates), 59 | ) 60 | scores = [i.score for i in result.data] 61 | return None, scores 62 | 63 | @TIME_METER("mixedbread_rank") 64 | async def _async_rank( 65 | self, query: str, candidates: list[str] 66 | ) -> tuple[np.ndarray, np.ndarray]: 67 | result = await asyncio.create_task( 68 | asyncio.to_thread( 69 | self.client.reranking, 70 | query=query, 71 | input=candidates, 72 | model=self.model, 73 | top_k=len(candidates), 74 | ) 75 | ) 76 | scores = [i.score for i in result.data] 77 | return None, scores 78 | -------------------------------------------------------------------------------- /src/flexrag/ranker/ranker_prompts/rankgpt_prompt.json: -------------------------------------------------------------------------------- 1 | { 2 | "system": { 3 | "role": "system", 4 | "content": "You are RankGPT, an intelligent assistant that can rank passages based on their relevancy to the query." 5 | }, 6 | "history": [ 7 | { 8 | "role": "user", 9 | "content": "I will provide you with {num} passages, each indicated by number identifier []. \nRank the passages based on their relevance to query: {query}." 10 | }, 11 | { 12 | "role": "assistant", 13 | "content": "Okay, please provide the passages." 14 | }, 15 | { 16 | "role": "user", 17 | "content": "Search Query: {query}. \nRank the {num} passages above based on their relevance to the search query. The passages should be listed in descending order using identifiers. The most relevant passages should be listed first. The output format should be [] > [], e.g., [1] > [2]. Only response the ranking results, do not say any word or explain." 18 | } 19 | ], 20 | "demonstrations": [] 21 | } -------------------------------------------------------------------------------- /src/flexrag/ranker/voyage_ranker.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from dataclasses import dataclass 3 | 4 | import numpy as np 5 | from omegaconf import MISSING 6 | 7 | from flexrag.utils import TIME_METER 8 | 9 | from .ranker import RankerBase, RankerBaseConfig, RANKERS 10 | 11 | 12 | @dataclass 13 | class VoyageRankerConfig(RankerBaseConfig): 14 | """The configuration for the Voyage ranker. 15 | 16 | :param model: the model name of the ranker. Default is "rerank-2". 17 | :type model: str 18 | :param api_key: the API key for the Voyage ranker. Required. 19 | :type api_key: str 20 | :param timeout: the timeout for the request. Default is 3.0. 21 | :type timeout: float 22 | :param max_retries: the maximum number of retries. Default is 3. 23 | :type max_retries: int 24 | """ 25 | 26 | model: str = "rerank-2" 27 | api_key: str = MISSING 28 | timeout: float = 3.0 29 | max_retries: int = 3 30 | 31 | 32 | @RANKERS("voyage", config_class=VoyageRankerConfig) 33 | class VoyageRanker(RankerBase): 34 | """VoyageRanker: The ranker based on the Voyage API.""" 35 | 36 | def __init__(self, cfg: VoyageRankerConfig) -> None: 37 | super().__init__(cfg) 38 | from voyageai import Client 39 | 40 | self.client = Client( 41 | api_key=cfg.api_key, max_retries=cfg.max_retries, timeout=cfg.timeout 42 | ) 43 | self.model = cfg.model 44 | return 45 | 46 | @TIME_METER("voyage_rank") 47 | def _rank(self, query: str, candidates: list[str]) -> tuple[np.ndarray, np.ndarray]: 48 | result = self.client.rerank( 49 | query=query, 50 | documents=candidates, 51 | model=self.model, 52 | top_k=len(candidates), 53 | ) 54 | scores = [i.relevance_score for i in result.results] 55 | return None, scores 56 | 57 | @TIME_METER("voyage_rank") 58 | async def _async_rank( 59 | self, query: str, candidates: list[str] 60 | ) -> tuple[np.ndarray, np.ndarray]: 61 | result = await asyncio.create_task( 62 | asyncio.to_thread( 63 | self.client.rerank, 64 | query=query, 65 | documents=candidates, 66 | model=self.model, 67 | top_k=len(candidates), 68 | ) 69 | ) 70 | scores = [i.relevance_score for i in result.results] 71 | return None, scores 72 | -------------------------------------------------------------------------------- /src/flexrag/retriever/__init__.py: -------------------------------------------------------------------------------- 1 | from .bm25s_retriever import BM25SRetriever, BM25SRetrieverConfig 2 | from .dense_retriever import DenseRetriever, DenseRetrieverConfig 3 | from .elastic_retriever import ElasticRetriever, ElasticRetrieverConfig 4 | from .retriever_base import ( 5 | EditableRetriever, 6 | EditableRetrieverConfig, 7 | RetrieverBase, 8 | RetrieverBaseConfig, 9 | LocalRetriever, 10 | LocalRetrieverConfig, 11 | RETRIEVERS, 12 | ) 13 | from .typesense_retriever import TypesenseRetriever, TypesenseRetrieverConfig 14 | from .web_retrievers import ( 15 | SimpleWebRetriever, 16 | SimpleWebRetrieverConfig, 17 | WikipediaRetriever, 18 | WikipediaRetrieverConfig, 19 | ) 20 | from .hyde_retriever import HydeRetriever, HydeRetrieverConfig 21 | 22 | 23 | RetrieverConfig = RETRIEVERS.make_config(config_name="RetrieverConfig", default=None) 24 | 25 | 26 | __all__ = [ 27 | "BM25SRetriever", 28 | "BM25SRetrieverConfig", 29 | "EditableRetriever", 30 | "EditableRetrieverConfig", 31 | "LocalRetriever", 32 | "LocalRetrieverConfig", 33 | "RetrieverBase", 34 | "RetrieverBaseConfig", 35 | "DenseRetriever", 36 | "DenseRetrieverConfig", 37 | "ElasticRetriever", 38 | "ElasticRetrieverConfig", 39 | "TypesenseRetriever", 40 | "TypesenseRetrieverConfig", 41 | "SimpleWebRetriever", 42 | "SimpleWebRetrieverConfig", 43 | "RETRIEVERS", 44 | "RetrieverConfig", 45 | "WikipediaRetriever", 46 | "WikipediaRetrieverConfig", 47 | "HydeRetriever", 48 | "HydeRetrieverConfig", 49 | ] 50 | -------------------------------------------------------------------------------- /src/flexrag/retriever/hyde_retriever.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | from flexrag.common_dataclass import RetrievedContext 4 | from flexrag.models import GENERATORS, GeneratorBase, GeneratorConfig 5 | from flexrag.utils import TIME_METER, Choices 6 | 7 | from .dense_retriever import DenseRetriever, DenseRetrieverConfig 8 | from .retriever_base import RETRIEVERS 9 | 10 | 11 | class HydeRewriter: 12 | Prompts = { 13 | "WEB_SEARCH": "Please write a passage to answer the question.\nQuestion: {}\nPassage:", 14 | "SCIFACT": "Please write a scientific paper passage to support/refute the claim.\nClaim: {}\nPassage:", 15 | "ARGUANA": "Please write a counter argument for the passage.\nPassage: {}\nCounter Argument:", 16 | "TREC_COVID": "Please write a scientific paper passage to answer the question.\nQuestion: {}\nPassage:", 17 | "FIQA": "Please write a financial article passage to answer the question.\nQuestion: {}\nPassage:", 18 | "DBPEDIA_ENTITY": "Please write a passage to answer the question.\nQuestion: {}\nPassage:", 19 | "TREC_NEWS": "Please write a news passage about the topic.\nTopic: {}\nPassage:", 20 | "MR_TYDI": "Please write a passage in {} to answer the question in detail.\nQuestion: {}\nPassage:", 21 | } 22 | 23 | def __init__(self, generator: GeneratorBase, task: str, language: str = "en"): 24 | self.task = task 25 | self.language = language 26 | self.generator = generator 27 | return 28 | 29 | def rewrite(self, queries: list[str] | str) -> list[str]: 30 | if isinstance(queries, str): 31 | queries = [queries] 32 | prompts = [self.Prompts[self.task].format(q) for q in queries] 33 | new_queries = [q[0] for q in self.generator.generate(prompts)] 34 | return new_queries 35 | 36 | 37 | @dataclass 38 | class HydeRetrieverConfig(DenseRetrieverConfig, GeneratorConfig): 39 | """Configuration class for HydeRetriever. 40 | 41 | :param task: Task for rewriting the query. Default: "WEB_SEARCH". 42 | Available options: "WEB_SEARCH", "SCIFACT", "ARGUANA", "TREC_COVID", "FIQA", "DBPEDIA_ENTITY", "TREC_NEWS", "MR_TYDI". 43 | :type task: str 44 | :param language: Language for rewriting. Default: "en". 45 | :type language: str 46 | """ 47 | 48 | task: Choices(HydeRewriter.Prompts.keys()) = "WEB_SEARCH" # type: ignore 49 | language: str = "en" 50 | 51 | 52 | @RETRIEVERS("hyde", config_class=HydeRetrieverConfig) 53 | class HydeRetriever(DenseRetriever): 54 | """HydeRetriever is a retriever that rewrites the query before searching. 55 | 56 | The original paper is available at https://aclanthology.org/2023.acl-long.99/. 57 | """ 58 | 59 | def __init__(self, cfg: HydeRetrieverConfig, no_check=False): 60 | super().__init__(cfg, no_check) 61 | generator = GENERATORS.load(cfg) 62 | self.rewriter = HydeRewriter( 63 | generator=generator, task=cfg.task, language=cfg.language 64 | ) 65 | return 66 | 67 | @TIME_METER("hyde_retriever", "search") 68 | def search_batch( 69 | self, 70 | query: list[str], 71 | **search_kwargs, 72 | ) -> list[list[RetrievedContext]]: 73 | new_query = self.rewriter.rewrite(query) 74 | return super().search_batch(new_query, **search_kwargs) 75 | -------------------------------------------------------------------------------- /src/flexrag/retriever/index/__init__.py: -------------------------------------------------------------------------------- 1 | from .annoy_index import AnnoyIndex, AnnoyIndexConfig 2 | from .faiss_index import FaissIndex, FaissIndexConfig 3 | from .index_base import DenseIndexBase, DenseIndexBaseConfig, DENSE_INDEX 4 | from .scann_index import ScaNNIndex, ScaNNIndexConfig 5 | 6 | 7 | DenseIndexConfig = DENSE_INDEX.make_config( 8 | default="faiss", config_name="DenseIndexConfig" 9 | ) 10 | 11 | 12 | __all__ = [ 13 | "AnnoyIndex", 14 | "AnnoyIndexConfig", 15 | "FaissIndex", 16 | "FaissIndexConfig", 17 | "ScaNNIndex", 18 | "ScaNNIndexConfig", 19 | "DenseIndexBase", 20 | "DenseIndexBaseConfig", 21 | "DENSE_INDEX", 22 | "DenseIndexConfig", 23 | ] 24 | -------------------------------------------------------------------------------- /src/flexrag/retriever/web_retrievers/__init__.py: -------------------------------------------------------------------------------- 1 | from .utils import WebResource 2 | from .web_downloader import ( 3 | WEB_DOWNLOADERS, 4 | PlaywrightWebDownloader, 5 | PlaywrightWebDownloaderConfig, 6 | SimpleWebDownloader, 7 | SimpleWebDownloaderConfig, 8 | WebDownloaderBase, 9 | WebDownloaderBaseConfig, 10 | WebDownloaderConfig, 11 | ) 12 | from .web_reader import ( 13 | WEB_READERS, 14 | JinaReader, 15 | JinaReaderConfig, 16 | JinaReaderLM, 17 | JinaReaderLMConfig, 18 | ScreenshotWebReader, 19 | ScreenshotWebReaderConfig, 20 | SnippetWebReader, 21 | WebReaderBase, 22 | WebReaderConfig, 23 | ) 24 | from .web_retriever import ( 25 | SimpleWebRetriever, 26 | SimpleWebRetrieverConfig, 27 | WebRetrieverBase, 28 | WebRetrieverBaseConfig, 29 | ) 30 | from .web_seeker import ( 31 | SEARCH_ENGINES, 32 | WEB_SEEKERS, 33 | BingEngine, 34 | BingEngineConfig, 35 | DuckDuckGoEngine, 36 | DuckDuckGoEngineConfig, 37 | GoogleEngine, 38 | GoogleEngineConfig, 39 | SerpApi, 40 | SerpApiConfig, 41 | WebSeekerBase, 42 | SearchEngineConfig, 43 | WebSeekerConfig, 44 | ) 45 | from .wikipedia_retriever import WikipediaRetriever, WikipediaRetrieverConfig 46 | 47 | __all__ = [ 48 | "WebResource", 49 | "WEB_DOWNLOADERS", 50 | "PlaywrightWebDownloader", 51 | "PlaywrightWebDownloaderConfig", 52 | "WebDownloaderBase", 53 | "WebDownloaderBaseConfig", 54 | "SimpleWebDownloader", 55 | "SimpleWebDownloaderConfig", 56 | "WebDownloaderConfig", 57 | "WEB_READERS", 58 | "JinaReader", 59 | "JinaReaderConfig", 60 | "JinaReaderLM", 61 | "JinaReaderLMConfig", 62 | "ScreenshotWebReader", 63 | "ScreenshotWebReaderConfig", 64 | "SnippetWebReader", 65 | "WebReaderBase", 66 | "WebReaderConfig", 67 | "SimpleWebRetriever", 68 | "SimpleWebRetrieverConfig", 69 | "WebRetrieverBase", 70 | "WebRetrieverBaseConfig", 71 | "SEARCH_ENGINES", 72 | "WEB_SEEKERS", 73 | "BingEngine", 74 | "BingEngineConfig", 75 | "DuckDuckGoEngine", 76 | "DuckDuckGoEngineConfig", 77 | "GoogleEngine", 78 | "GoogleEngineConfig", 79 | "SerpApi", 80 | "SerpApiConfig", 81 | "WebSeekerBase", 82 | "SearchEngineConfig", 83 | "WebSeekerConfig", 84 | "WikipediaRetriever", 85 | "WikipediaRetrieverConfig", 86 | ] 87 | -------------------------------------------------------------------------------- /src/flexrag/retriever/web_retrievers/utils.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from typing import Any, Optional 3 | 4 | 5 | @dataclass 6 | class WebResource: 7 | """The web resource dataclass. 8 | ``WebResource`` is the fundamental component for information transmission in the ``web_retrievers`` module of FlexRAG. 9 | The ``WebSeeker`` retrieves the corresponding ``WebResource`` based on the user's query, 10 | while the ``WebDownloader`` downloads the resource based on the URL in the ``WebResource`` and stores it in the ``data`` field of the ``WebResource``. 11 | The ``WebReader`` then converts the ``data`` field of the ``WebResource`` into a LLM friendly format and returns the ``RetrievedContext``. 12 | 13 | :param url: The URL of the resource. 14 | :type url: str 15 | :param query: The query for the resource. Default is None. 16 | :type query: Optional[str] 17 | :param metadata: The metadata of the resource, offen provided by the WebSeeker. Default is {}. 18 | :type metadata: dict 19 | :param data: The content of the resource, offen filled by the WebDownloader. Default is None. 20 | :type data: Any 21 | """ 22 | 23 | url: str 24 | query: Optional[str] = None 25 | metadata: dict = field(default_factory=dict) 26 | data: Any = None 27 | -------------------------------------------------------------------------------- /src/flexrag/retriever/web_retrievers/wikipedia_retriever.py: -------------------------------------------------------------------------------- 1 | import time 2 | from dataclasses import dataclass 3 | from typing import Optional 4 | 5 | import httpx 6 | from bs4 import BeautifulSoup 7 | 8 | from flexrag.common_dataclass import RetrievedContext 9 | from flexrag.utils import LOGGER_MANAGER, SimpleProgressLogger 10 | 11 | from ..retriever_base import RETRIEVERS, RetrieverBase, RetrieverBaseConfig 12 | 13 | logger = LOGGER_MANAGER.get_logger("flexrag.retrievers.web_retriever") 14 | 15 | 16 | @dataclass 17 | class WikipediaRetrieverConfig(RetrieverBaseConfig): 18 | """The configuration for the ``WikipediaRetriever``. 19 | 20 | :param search_url: The search URL for Wikipedia. 21 | Default is "https://en.wikipedia.org/w/index.php?search=". 22 | :type search_url: str 23 | :param proxy: The proxy to use. Default is None. 24 | :type proxy: Optional[str] 25 | """ 26 | 27 | search_url: str = "https://en.wikipedia.org/w/index.php?search=" 28 | proxy: Optional[str] = None 29 | 30 | 31 | @RETRIEVERS("wikipedia", config_class=WikipediaRetrieverConfig) 32 | class WikipediaRetriever(RetrieverBase): 33 | """WikipediaRetriever retrieves information from Wikipedia directly. 34 | Adapted from https://github.com/ysymyth/ReAct""" 35 | 36 | name = "wikipedia" 37 | 38 | def __init__(self, cfg: WikipediaRetrieverConfig): 39 | super().__init__(cfg) 40 | # set basic configs 41 | self.search_url = cfg.search_url 42 | self.client = httpx.Client(proxy=cfg.proxy) 43 | return 44 | 45 | def search( 46 | self, 47 | query: list[str] | str, 48 | delay: float = 0.1, 49 | **search_kwargs, 50 | ) -> list[list[RetrievedContext]]: 51 | if isinstance(query, str): 52 | query = [query] 53 | 54 | # search & parse 55 | results = [] 56 | p_logger = SimpleProgressLogger(logger, len(query), self.log_interval) 57 | for q in query: 58 | time.sleep(delay) 59 | p_logger.update(1, "Searching") 60 | results.append([self.search_item(q, **search_kwargs)]) 61 | return results 62 | 63 | def search_item(self, query: str, **kwargs) -> RetrievedContext: 64 | search_url = self.search_url + query.replace(" ", "+") 65 | response_text = self.client.get(search_url).text 66 | 67 | soup = BeautifulSoup(response_text, features="html.parser") 68 | result_divs = soup.find_all("div", {"class": "mw-search-result-heading"}) 69 | if result_divs: # mismatch 70 | similar_entities = [ 71 | self._clear_str(div.get_text().strip()) for div in result_divs 72 | ] 73 | page_content = None 74 | summary = None 75 | else: 76 | similar_entities = [] 77 | page = [ 78 | p.get_text().strip() for p in soup.find_all("p") + soup.find_all("ul") 79 | ] 80 | if any("may refer to:" in p for p in page): 81 | return self.search_item("[" + query + "]") 82 | else: # concatenate all paragraphs 83 | page_content = "" 84 | for p in page: 85 | if len(p.split(" ")) > 2: 86 | page_content += self._clear_str(p) 87 | if not p.endswith("\n"): 88 | page_content += "\n" 89 | summary = self._get_summary(page_content) 90 | return RetrievedContext( 91 | retriever=self.name, 92 | query=query, 93 | data={ 94 | "raw_page": response_text, 95 | "page_content": page_content, 96 | "summary": summary, 97 | "similar_entities": similar_entities, 98 | }, 99 | source=search_url, 100 | ) 101 | 102 | def _clear_str(self, text: str) -> str: 103 | return text.encode().decode("unicode-escape").encode("latin1").decode("utf-8") 104 | 105 | def _get_summary(self, page: str) -> str: 106 | # find all paragraphs 107 | paragraphs = page.split("\n") 108 | paragraphs = [p.strip() for p in paragraphs if p.strip()] 109 | 110 | # find all sentence 111 | sentences = [] 112 | for p in paragraphs: 113 | sentences += p.split(". ") 114 | sentences = [s.strip() + "." for s in sentences if s.strip()] 115 | summary = " ".join(sentences[:5]) 116 | return summary.replace("\\n", "") 117 | 118 | @property 119 | def fields(self): 120 | return ["raw_page", "page_content", "summary", "similar_entities"] 121 | -------------------------------------------------------------------------------- /src/flexrag/text_process/__init__.py: -------------------------------------------------------------------------------- 1 | from .basic_filters import ExactDeduplicate, LengthFilter, LengthFilterConfig 2 | from .basic_processors import ( 3 | AnswerSimplifier, 4 | ChineseSimplifier, 5 | Lowercase, 6 | TokenNormalizer, 7 | TokenNormalizerConfig, 8 | Truncator, 9 | TruncatorConfig, 10 | Unifier, 11 | ) 12 | from .pipeline import TextProcessPipeline, TextProcessPipelineConfig 13 | from .processor import PROCESSORS, Processor, TextUnit 14 | 15 | __all__ = [ 16 | "TextProcessPipeline", 17 | "TextProcessPipelineConfig", 18 | "PROCESSORS", 19 | "Processor", 20 | "TextUnit", 21 | "TokenNormalizerConfig", 22 | "TokenNormalizer", 23 | "ChineseSimplifier", 24 | "Lowercase", 25 | "Unifier", 26 | "TruncatorConfig", 27 | "Truncator", 28 | "AnswerSimplifier", 29 | "ExactDeduplicate", 30 | "LengthFilter", 31 | "LengthFilterConfig", 32 | ] 33 | -------------------------------------------------------------------------------- /src/flexrag/text_process/basic_filters.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from typing import Optional 3 | 4 | from .processor import PROCESSORS, Processor, TextUnit 5 | from .utils import UnifiedTokenizer, UTokenizerConfig 6 | 7 | 8 | @PROCESSORS("exact_deduplicate") 9 | class ExactDeduplicate(Processor): 10 | def __init__(self) -> None: 11 | self.seen = set() 12 | return 13 | 14 | def process(self, input_text: TextUnit) -> TextUnit: 15 | if input_text.content in self.seen: 16 | input_text.reserved = False 17 | self.seen.add(input_text.content) 18 | return input_text 19 | 20 | 21 | @dataclass 22 | class LengthFilterConfig: 23 | max_tokens: Optional[int] = None 24 | min_tokens: Optional[int] = None 25 | max_chars: Optional[int] = None 26 | min_chars: Optional[int] = None 27 | max_bytes: Optional[int] = None 28 | min_bytes: Optional[int] = None 29 | tokenizer_config: UTokenizerConfig = field(default_factory=UTokenizerConfig) 30 | 31 | 32 | @PROCESSORS("length_filter", config_class=LengthFilterConfig) 33 | class LengthFilter(Processor): 34 | def __init__(self, cfg: LengthFilterConfig) -> None: 35 | super().__init__() 36 | self.max_tokens = cfg.max_tokens 37 | self.min_tokens = cfg.min_tokens 38 | self.max_chars = cfg.max_chars 39 | self.min_chars = cfg.min_chars 40 | self.max_bytes = cfg.max_bytes 41 | self.min_bytes = cfg.min_bytes 42 | if self.max_tokens is not None or self.min_tokens is not None: 43 | self.tokenizer = UnifiedTokenizer(cfg.tokenizer_config) 44 | else: 45 | self.tokenizer = None 46 | return 47 | 48 | def process(self, input_text: TextUnit) -> TextUnit: 49 | if self.tokenizer is not None: 50 | tokens = self.tokenizer.tokenize(input_text.content) 51 | if self.max_tokens is not None and len(tokens) > self.max_tokens: 52 | input_text.reserved = False 53 | if self.min_tokens is not None and len(tokens) < self.min_tokens: 54 | input_text.reserved = False 55 | if self.max_chars is not None and len(input_text.content) > self.max_chars: 56 | input_text.reserved = False 57 | if self.min_chars is not None and len(input_text.content) < self.min_chars: 58 | input_text.reserved = False 59 | if ( 60 | self.max_bytes is not None 61 | and len(input_text.content.encode("utf-8")) > self.max_bytes 62 | ): 63 | input_text.reserved = False 64 | if ( 65 | self.min_bytes is not None 66 | and len(input_text.content.encode("utf-8")) < self.min_bytes 67 | ): 68 | input_text.reserved = False 69 | return input_text 70 | -------------------------------------------------------------------------------- /src/flexrag/text_process/pipeline.py: -------------------------------------------------------------------------------- 1 | from flexrag.utils import TIME_METER 2 | 3 | from .processor import PROCESSORS, Processor, TextUnit 4 | 5 | 6 | TextProcessPipelineConfig = PROCESSORS.make_config( 7 | allow_multiple=True, config_name="TextProcessPipelineConfig" 8 | ) 9 | 10 | 11 | class TextProcessPipeline: 12 | def __init__(self, cfg: TextProcessPipelineConfig) -> None: # type: ignore 13 | # load processors 14 | self.processors: list[Processor] = PROCESSORS.load(cfg) 15 | return 16 | 17 | @TIME_METER("text_process_pipeline") 18 | def __call__(self, text: str, return_detail: bool = False) -> str | TextUnit | None: 19 | unit = TextUnit(content=text) 20 | for processor in self.processors: 21 | unit = processor(unit) 22 | if not unit.reserved: 23 | break 24 | if return_detail: 25 | return unit 26 | return unit.content if unit.reserved else None 27 | 28 | def __contains__(self, processor: Processor | str) -> bool: 29 | if isinstance(processor, str): 30 | return any( 31 | isinstance(p, PROCESSORS[processor]["item"]) for p in self.processors 32 | ) 33 | return processor in self.processors 34 | 35 | def __getitem__(self, processor: str | int) -> Processor: 36 | if isinstance(processor, int): 37 | return self.processors[processor] 38 | assert isinstance(processor, str), "str or int is required" 39 | for p in self.processors: 40 | if isinstance(p, PROCESSORS[processor]["item"]): 41 | return p 42 | raise KeyError(f"Processor {processor} not found in the pipeline") 43 | 44 | def __repr__(self) -> str: 45 | return f"Pipeline({[p.name for p in self.processors]})" 46 | -------------------------------------------------------------------------------- /src/flexrag/text_process/processor.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from dataclasses import dataclass, field 3 | 4 | from flexrag.utils import Register 5 | 6 | 7 | @dataclass 8 | class TextUnit: 9 | content: str 10 | reserved: bool = True 11 | processed_by: list[str] = field(default_factory=list) 12 | 13 | 14 | class Processor(ABC): 15 | def __call__(self, input_text: TextUnit) -> TextUnit: 16 | """Process the input text. 17 | If the processor has been filtered, the reserved flag of the input TextUnit will be set to False. 18 | 19 | :param input_text: The input text to process. 20 | :type input_text: TextUnit 21 | :return: The processed text. 22 | :rtype: TextUnit 23 | """ 24 | input_text.processed_by.append(self.name) 25 | return self.process(input_text) 26 | 27 | @abstractmethod 28 | def process(self, input_text: TextUnit) -> TextUnit: 29 | return 30 | 31 | @property 32 | def name(self): 33 | return self.__class__.__name__ 34 | 35 | 36 | PROCESSORS = Register[Processor]("processor") 37 | -------------------------------------------------------------------------------- /src/flexrag/text_process/utils.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Optional 3 | 4 | from flexrag.utils import Choices 5 | 6 | 7 | @dataclass 8 | class UTokenizerConfig: 9 | tokenizer_type: Choices(["hf", "tiktoken", "moses"]) = "moses" # type: ignore 10 | hf_tokenizer_path: Optional[str] = None 11 | tiktok_tokenizer_name: Optional[str] = None 12 | lang: Optional[str] = None 13 | 14 | 15 | class UnifiedTokenizer: 16 | def __init__(self, cfg: UTokenizerConfig) -> None: 17 | self.tokenizer_type = cfg.tokenizer_type 18 | match self.tokenizer_type: 19 | case "hf": 20 | from transformers import AutoTokenizer 21 | 22 | self.tokenizer = AutoTokenizer.from_pretrained(cfg.hf_tokenizer_path) 23 | case "tiktoken": 24 | import tiktoken 25 | 26 | self.tokenizer = tiktoken.get_encoding(cfg.tiktok_tokenizer_name) 27 | case "moses": 28 | from sacremoses import MosesDetokenizer, MosesTokenizer 29 | 30 | self.tokenizer = MosesTokenizer(cfg.lang) 31 | self.detokenizer = MosesDetokenizer(cfg.lang) 32 | case _: 33 | raise ValueError(f"Unknown tokenizer type: {self.tokenizer_type}") 34 | return 35 | 36 | def tokenize(self, texts: str) -> list[str | int]: 37 | match self.tokenizer_type: 38 | case "hf": 39 | tokens = self.tokenizer.encode(texts) 40 | case "tiktoken": 41 | tokens = self.tokenizer.encode(texts) 42 | case "moses": 43 | tokens = self.tokenizer.tokenize(texts) 44 | return tokens 45 | 46 | def detokenize(self, tokens: list[str | int]) -> str: 47 | match self.tokenizer_type: 48 | case "hf": 49 | texts = self.tokenizer.decode(tokens) 50 | case "tiktoken": 51 | texts = self.tokenizer.decode(tokens) 52 | case "moses": 53 | texts = self.detokenizer.detokenize(tokens) 54 | return texts 55 | -------------------------------------------------------------------------------- /tested_models.md: -------------------------------------------------------------------------------- 1 | # Tested Huggingface Models 2 | 3 | ## Tested HFEncoders 4 | - jinaai/jina-embeddings-v3 5 | - BAAI/bge-m3 6 | - facebook/contriever 7 | - facebook/contriever-msmarco 8 | - facebook/dragon-plus-query-encoder 9 | - facebook/dragon-plus-context-encoder 10 | - nomic-ai/nomic-embed-text-v1.5 11 | - sentence-transformers/msmarco-MiniLM-L-12-v3 12 | - intfloat/e5-base-v2 13 | - Alibaba-NLP/gte-multilingual-base 14 | - Alibaba-NLP/gte-modernbert-base 15 | 16 | ## Tested HFClipEncoders 17 | - openai/clip-vit-base-patch32 18 | - jinaai/jina-clip-v2 19 | 20 | ## Tested ReRankers 21 | - unicamp-dl/InRanker-base 22 | - colbert-ir/colbertv2.0 23 | - jinaai/Jina-colbert-v2 24 | - jinaai/jina-reranker-v2-base-multilingual 25 | - BAAI/bge-reranker-v2-m3 26 | - intfloat/e5-base-v2 27 | 28 | ## Tested HFGenerators 29 | - Llama-3.2-1B-Instruct 30 | 31 | ## Tested VLMs 32 | - Qwen/Qwen2-VL-7B-Instruct 33 | - meta-llama/Llama-3.2-11B-Vision-Instruct 34 | - google/paligemma2-10b-ft-docci-448 -------------------------------------------------------------------------------- /tests/test_assistant.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dataclasses import dataclass, field 3 | 4 | import pytest 5 | from omegaconf import OmegaConf 6 | 7 | from flexrag.assistant import BasicAssistant, BasicAssistantConfig 8 | 9 | 10 | @dataclass 11 | class AssistantTestConfig: 12 | assistant_config: BasicAssistantConfig = field(default_factory=BasicAssistantConfig) 13 | 14 | 15 | class TestAssistant: 16 | cfg: AssistantTestConfig = OmegaConf.merge( 17 | OmegaConf.structured(AssistantTestConfig), 18 | OmegaConf.load( 19 | os.path.join(os.path.dirname(__file__), "configs", "assistant.yaml") 20 | ), 21 | ) 22 | query = "Who is Bruce Wayne?" 23 | # contexts = ["Bruce Wayne is Batman.", "Batman is a superhero."] 24 | 25 | @pytest.mark.asyncio 26 | async def test_answer(self): 27 | assistant = BasicAssistant(self.cfg.assistant_config) 28 | r1, _, _ = assistant.answer(self.query) 29 | return 30 | -------------------------------------------------------------------------------- /tests/test_data.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from flexrag.common_dataclass import RAGEvalData 4 | from flexrag.datasets import RAGEvalDataset, RAGEvalDatasetConfig 5 | from flexrag.utils import LOGGER_MANAGER 6 | 7 | 8 | logger = LOGGER_MANAGER.get_logger("tests.datasets") 9 | 10 | 11 | class TestRAGEvalDataset: 12 | datasets = { 13 | "2wikimultihopqa": ["dev", "train"], 14 | "ambig_qa": ["dev", "train"], 15 | "arc": ["dev", "test", "train"], 16 | "asqa": ["dev", "train"], 17 | "ay2": ["dev", "train"], 18 | "bamboogle": ["test"], 19 | "boolq": ["dev", "train"], 20 | "commonsenseqa": ["dev", "train"], 21 | "curatedtrec": ["test", "train"], 22 | # "domainrag": ["test"], # Error in loading due to dataset schema 23 | "eli5": ["dev", "train"], 24 | "fermi": ["dev", "test", "train"], 25 | "fever": ["dev", "train"], 26 | "hellaswag": ["dev", "train"], 27 | "hotpotqa": ["dev", "train"], 28 | "mmlu": ["5_shot", "dev", "test", "train"], 29 | "msmarco-qa": ["dev", "train"], 30 | "musique": ["dev", "train"], 31 | "narrativeqa": ["dev", "test", "train"], 32 | "nq": ["dev", "test", "train"], 33 | "openbookqa": ["dev", "test", "train"], 34 | "piqa": ["dev", "train"], 35 | "popqa": ["test"], 36 | "quartz": ["dev", "test", "train"], 37 | "siqa": ["dev", "train"], 38 | "squad": ["dev", "train"], 39 | "t-rex": ["dev", "train"], 40 | "triviaqa": ["dev", "test", "train"], 41 | "truthful_qa": ["dev"], 42 | "web_questions": ["test", "train"], 43 | "wikiasp": ["dev", "test", "train"], 44 | "wikiqa": ["dev", "test", "train"], 45 | "wned": ["dev"], 46 | "wow": ["dev", "train"], 47 | "zero-shot_re": ["dev", "train"], 48 | } 49 | 50 | async def run_test(self, name: str, split: str): 51 | # load dataset 52 | logger.info(f"Testing {name} {split}") 53 | dataset = RAGEvalDataset(RAGEvalDatasetConfig(name=name, split=split)) 54 | 55 | # check dataset 56 | assert len(dataset) > 0 57 | for i in dataset: 58 | assert isinstance(i, RAGEvalData) 59 | for i in range(len(dataset)): 60 | assert isinstance(dataset[i], RAGEvalData) 61 | return 62 | 63 | @pytest.mark.asyncio 64 | async def test_rageval_dataset(self): 65 | logger.info("Testing RAGEvalDataset") 66 | for name in self.datasets: 67 | for split in self.datasets[name]: 68 | await self.run_test(name, split) 69 | return 70 | --------------------------------------------------------------------------------