├── .github
└── workflows
│ └── publish.yaml
├── .gitignore
├── .readthedocs.yaml
├── CITATION.cff
├── LICENSE
├── MANIFEST.in
├── README-zh.md
├── README.md
├── assets
├── Framework-Features-en.png
├── Framework-Features-zh.png
├── Framework-FlexRAGv3-en.png
├── Framework-FlexRAGv3-zh.png
├── ModularAssistant.png
├── Retrievers.png
├── WebRetriever.png
├── flexrag-wide.png
├── flexrag.png
├── gui_static.png
└── parse_files.png
├── benchmarks
├── README.md
└── singlehop_qa.md
├── docs
├── Makefile
├── make.bat
├── requirements.docs.txt
└── source
│ ├── conf.py
│ ├── getting_started
│ ├── installation.md
│ ├── quickstart1.md
│ └── quickstart2.md
│ ├── index.rst
│ ├── locales
│ └── zh_CN
│ │ └── LC_MESSAGES
│ │ ├── getting_started
│ │ ├── installation.po
│ │ ├── quickstart1.po
│ │ └── quickstart2.po
│ │ ├── index.po
│ │ ├── reference
│ │ ├── assistant.po
│ │ ├── chunking.po
│ │ ├── common_dataclass.po
│ │ ├── datasets.po
│ │ ├── document_parser.po
│ │ ├── encoders.po
│ │ ├── generators.po
│ │ ├── metrics.po
│ │ ├── prompt.po
│ │ ├── rankers.po
│ │ ├── refiner.po
│ │ ├── retrievers.po
│ │ ├── text_process.po
│ │ ├── tokenizers.po
│ │ └── utils.po
│ │ └── tutorial
│ │ ├── building_assistant.po
│ │ ├── entrypoints.po
│ │ ├── preparing_corpus.po
│ │ ├── preparing_retriever.po
│ │ ├── preparing_web_retriever.po
│ │ └── using_register.po
│ ├── reference
│ ├── assistant.rst
│ ├── chunking.rst
│ ├── common_dataclass.rst
│ ├── datasets.rst
│ ├── document_parser.rst
│ ├── encoders.rst
│ ├── generators.rst
│ ├── metrics.rst
│ ├── prompt.rst
│ ├── rankers.rst
│ ├── refiner.rst
│ ├── retrievers.rst
│ ├── text_process.rst
│ ├── tokenizers.rst
│ └── utils.rst
│ └── tutorial
│ ├── building_assistant.md
│ ├── entrypoints.md
│ ├── preparing_corpus.md
│ ├── preparing_retriever.md
│ ├── preparing_web_retriever.md
│ └── using_register.md
├── pyproject.toml
├── pytest.ini
├── requirements.txt
├── setup.py
├── src
└── flexrag
│ ├── __init__.py
│ ├── assistant
│ ├── __init__.py
│ ├── assistant.py
│ ├── assistant_prompts
│ │ ├── longform_generate_prompt_with_context.json
│ │ ├── longform_generate_prompt_without_context.json
│ │ ├── shortform_generate_prompt_with_context.json
│ │ └── shortform_generate_prompt_without_context.json
│ ├── basic_assistant.py
│ ├── chatqa_assistant.py
│ ├── document_chat_assistant.py
│ ├── modular_rag_assistant.py
│ └── online_assistant.py
│ ├── cache
│ ├── __init__.py
│ ├── backends.py
│ ├── persistent_cache.py
│ └── serializer.py
│ ├── chunking
│ ├── __init__.py
│ ├── basic_chunkers.py
│ ├── chunker_base.py
│ ├── semantic_chunker.py
│ └── sentence_splitter.py
│ ├── common_dataclass.py
│ ├── context_refine
│ ├── __init__.py
│ ├── arranger.py
│ ├── refiner.py
│ └── summarizer.py
│ ├── datasets
│ ├── __init__.py
│ ├── dataset.py
│ ├── hf_dataset.py
│ ├── line_delimited_dataset.py
│ ├── rag_dataset.py
│ └── retrieval_dataset.py
│ ├── document_parser
│ ├── __init__.py
│ ├── docling_parser.py
│ ├── document_parser_base.py
│ └── markitdown_parser.py
│ ├── entrypoints
│ ├── __init__.py
│ ├── assets
│ │ ├── flexrag-wide.png
│ │ ├── flexrag.png
│ │ ├── robot.png
│ │ └── user.png
│ ├── cache.py
│ ├── combine_outputs.py
│ ├── prepare_corpus.py
│ ├── prepare_index.py
│ ├── rebuild_index.py
│ ├── run_assistant.py
│ ├── run_interactive.py
│ └── run_retriever.py
│ ├── metrics
│ ├── __init__.py
│ ├── evaluator.py
│ ├── generation_metrics.py
│ ├── lib_rel.cpp
│ ├── matching_metrics.py
│ ├── metrics_base.py
│ ├── retrieval_metrics.py
│ ├── xfinder.py
│ └── xfinder_utils.py
│ ├── models
│ ├── __init__.py
│ ├── anthropic_model.py
│ ├── cohere_model.py
│ ├── hf_model.py
│ ├── jina_model.py
│ ├── llamacpp_model.py
│ ├── model_base.py
│ ├── ollama_model.py
│ ├── openai_model.py
│ ├── sentence_transformers_model.py
│ ├── tokenizer.py
│ ├── utils.py
│ └── vllm_model.py
│ ├── prompt
│ ├── __init__.py
│ ├── prompt_base.py
│ └── template.py
│ ├── ranker
│ ├── __init__.py
│ ├── cohere_ranker.py
│ ├── gpt_ranker.py
│ ├── hf_ranker.py
│ ├── jina_ranker.py
│ ├── mixedbread_ranker.py
│ ├── ranker.py
│ ├── ranker_prompts
│ │ └── rankgpt_prompt.json
│ └── voyage_ranker.py
│ ├── retriever
│ ├── __init__.py
│ ├── bm25s_retriever.py
│ ├── dense_retriever.py
│ ├── elastic_retriever.py
│ ├── hyde_retriever.py
│ ├── index
│ │ ├── __init__.py
│ │ ├── annoy_index.py
│ │ ├── faiss_index.py
│ │ ├── index_base.py
│ │ └── scann_index.py
│ ├── retriever_base.py
│ ├── typesense_retriever.py
│ └── web_retrievers
│ │ ├── __init__.py
│ │ ├── utils.py
│ │ ├── web_downloader.py
│ │ ├── web_reader.py
│ │ ├── web_retriever.py
│ │ ├── web_seeker.py
│ │ └── wikipedia_retriever.py
│ ├── text_process
│ ├── __init__.py
│ ├── basic_filters.py
│ ├── basic_processors.py
│ ├── normalize_tokens.py
│ ├── pipeline.py
│ ├── processor.py
│ └── utils.py
│ └── utils.py
├── tested_models.md
└── tests
├── test_assistant.py
├── test_cache.py
├── test_chunker.py
├── test_data.py
├── test_model.py
├── test_ranker.py
├── test_retriver.py
└── testcorp
└── testcorp.jsonl
/.github/workflows/publish.yaml:
--------------------------------------------------------------------------------
1 | name: Publish Python Package
2 |
3 | on:
4 | release:
5 | types:
6 | - published
7 |
8 | jobs:
9 | build_and_publish:
10 | name: Build wheels on ${{ matrix.os }}
11 | runs-on: ${{ matrix.os }}
12 | strategy:
13 | matrix:
14 | os: [ubuntu-latest, windows-latest]
15 |
16 | permissions:
17 | contents: read
18 | id-token: write
19 |
20 | steps:
21 | - name: Check out repository
22 | uses: actions/checkout@v4
23 |
24 | - name: Set up Python
25 | uses: actions/setup-python@v5
26 | with:
27 | python-version: 3.11
28 |
29 | - name: Install dependencies
30 | run: pip install setuptools wheel twine cibuildwheel
31 |
32 | - name: Build wheels
33 | run: python -m cibuildwheel --output-dir wheelhouse
34 |
35 | - name: Publish to PyPI
36 | env:
37 | TWINE_USERNAME: __token__
38 | TWINE_PASSWORD: ${{ secrets.PYPI_TOKEN }}
39 | run: twine upload wheelhouse/*.whl
40 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .idea
11 | .Python
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | applications/DeepSpeed-Chat/data
17 | eggs/
18 | .eggs/
19 | lib/
20 | lib64/
21 | parts/
22 | sdist/
23 | var/
24 | wheels/
25 | pip-wheel-metadata/
26 | share/python-wheels/
27 | *.egg-info/
28 | .installed.cfg
29 | *.egg
30 | MANIFEST
31 |
32 | # PyInstaller
33 | # Usually these files are written by a python script from a template
34 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
35 | *.manifest
36 | *.spec
37 |
38 | # Installer logs
39 | pip-log.txt
40 | pip-delete-this-directory.txt
41 |
42 | # Unit test / coverage reports
43 | htmlcov/
44 | .tox/
45 | .nox/
46 | .coverage
47 | .coverage.*
48 | .cache
49 | nosetests.xml
50 | coverage.xml
51 | *.cover
52 | *.py,cover
53 | .hypothesis/
54 | .pytest_cache/
55 |
56 | # Translations
57 | *.mo
58 | *.pot
59 |
60 | # Django stuff:
61 | *.log
62 | local_settings.py
63 | db.sqlite3
64 | db.sqlite3-journal
65 |
66 | # Flask stuff:
67 | instance/
68 | .webassets-cache
69 |
70 | # Scrapy stuff:
71 | .scrapy
72 |
73 | # Sphinx documentation
74 | docs/_build/
75 |
76 | # PyBuilder
77 | target/
78 |
79 | # Jupyter Notebook
80 | .ipynb_checkpoints
81 |
82 | # IPython
83 | profile_default/
84 | ipython_config.py
85 |
86 | # pyenv
87 | .python-version
88 |
89 | # pipenv
90 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
91 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
92 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
93 | # install all needed dependencies.
94 | #Pipfile.lock
95 |
96 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
97 | __pypackages__/
98 |
99 | # Celery stuff
100 | celerybeat-schedule
101 | celerybeat.pid
102 |
103 | # SageMath parsed files
104 | *.sage.py
105 |
106 | # Environments
107 | .env
108 | .venv
109 | env/
110 | venv/
111 | ENV/
112 | env.bak/
113 | venv.bak/
114 |
115 | # Spyder project settings
116 | .spyderproject
117 | .spyproject
118 |
119 | # Rope project settings
120 | .ropeproject
121 |
122 | # mkdocs documentation
123 | /site
124 |
125 | # mypy
126 | .mypy_cache/
127 | .dmypy.json
128 | dmypy.json
129 |
130 | # Pyre type checker
131 | .pyre/
132 |
133 |
134 | # vscode
135 | .vscode
136 |
137 |
138 | # third party models
139 | yala/model/third_party_models
140 |
141 | # aim
142 | .aim
143 |
144 | # test files
145 | _test*.py
146 | _test*.ipynb
147 |
148 | # experimental configs
149 | experimental_configs/
150 |
151 | # hydra logs
152 | outputs/
153 |
154 | # pytest configs
155 | tests/configs/
156 |
157 | # cibuildwheel
158 | wheelhouse/
--------------------------------------------------------------------------------
/.readthedocs.yaml:
--------------------------------------------------------------------------------
1 | # Read the Docs configuration file
2 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
3 |
4 | # Required
5 | version: 2
6 |
7 | # Set the OS, Python version, and other tools you might need
8 | build:
9 | os: ubuntu-24.04
10 | tools:
11 | python: "3.13"
12 |
13 | # Build documentation in the "docs/" directory with Sphinx
14 | sphinx:
15 | configuration: docs/source/conf.py
16 | # Optionally, but recommended,
17 | # declare the Python requirements required to build your documentation
18 | # See https://docs.readthedocs.io/en/stable/guides/reproducible-builds.html
19 | python:
20 | install:
21 | - requirements: docs/requirements.docs.txt
22 |
--------------------------------------------------------------------------------
/CITATION.cff:
--------------------------------------------------------------------------------
1 | cff-version: 1.2.0
2 | message: "If you use this software, please cite it as below."
3 | authors:
4 | - family-names: "Zhang"
5 | given-names: "Zhuocheng"
6 | orcid: "https://orcid.org/0009-0004-4131-7968"
7 | - family-names: "Feng"
8 | given-names: "Yang"
9 | - family-names: "Zhang"
10 | given-names: "Min"
11 | title: "FlexRAG"
12 | doi: 10.5281/zenodo.14306983
13 | url: "https://github.com/ictnlp/FlexRAG"
14 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 zhangzhuocheng
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
6 |
7 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
8 |
9 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
10 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include LICENSE
2 | include README.md
3 | include CITATION.cff
4 | recursive-include tests *
5 | recursive-include assets *
--------------------------------------------------------------------------------
/README-zh.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | 
6 | [](https://github.com/psf/black)
7 | [](https://pycqa.github.io/isort/)
8 | [](LICENSE)
9 | [](https://flexrag.readthedocs.io/en/latest/)
10 | [](https://flexrag.readthedocs.io/zh-cn/latest/)
11 | [](https://pypi.org/project/flexrag/)
12 | [](https://doi.org/10.5281/zenodo.14306983)
13 |
14 |
15 | |
16 | 介绍视频 |
17 | README (english) |
18 | 文档 |
19 | 检索器 |
20 | 示例
21 | |
22 |
23 |
24 | FlexRAG 是一个创新的开源框架,旨在简化 RAG(检索增强生成)系统的快速复现、开发和评估。它全面支持多种 RAG 场景,包括 **基于文本的、多模态的以及可通过 Web 访问的 RAG** 。借助从数据准备到系统评估的**端到端流水线**,FlexRAG 能够帮助研究人员高效地与社区共享他们的工作,并快速基于自己的算法开发演示原型。
25 |
26 | # 📖 目录
27 | - [📖 目录](#-目录)
28 | - [✨ 框架特色](#-框架特色)
29 | - [📢 最新消息](#-最新消息)
30 | - [🚀 框架入门](#-框架入门)
31 | - [🏗️ FlexRAG 架构](#️-flexrag-架构)
32 | - [📊 基准测试](#-基准测试)
33 | - [🏷️ 许可证](#️-许可证)
34 | - [🖋️ 引用](#️-引用)
35 | - [❤️ 致谢](#️-致谢)
36 |
37 |
38 | # ✨ 框架特色
39 |
40 |
41 |
42 |
43 | # 📢 最新消息
44 | - **2025-03-24**: 中文文档上线啦!请访问 [文档](https://flexrag.readthedocs.io/zh-cn/latest/) 查看。
45 | - **2025-02-25**: FlexRAG 的 LocalRetriever 现在支持从 [HuggingFace Hub](https://huggingface.co/collections/ICTNLP/flexrag-retrievers-67b5373b70123669108a2e59) 上加载啦!
46 | - **2025-01-22**: 新的命令行入口 `run_retriever` 以及大量新的信息检索指标(如 `RetrievalMAP` )现已上线,请阅读[文档](https://flexrag.readthedocs.io/en/latest/)以获取更多信息。
47 | - **2025-01-08**: FlexRAG 现已支持 Windows 系统,您可以直接通过 `pip install flexrag` 来安装。
48 | - **2025-01-08**: FlexRAG 在单跳QA数据集上的基准测试现已公开,详情请参考 [benchmarks](benchmarks/README.md) 页面。
49 | - **2025-01-05**: FlexRAG 的[文档](https://flexrag.readthedocs.io/en/latest/)现已上线。
50 |
51 | # 🚀 框架入门
52 | 从 `pip` 安装 FlexRAG:
53 | ```bash
54 | pip install flexrag
55 | ```
56 |
57 | 访问我们的[文档](https://flexrag.readthedocs.io/zh-cn/latest/)以了解更多信息。
58 | - [安装](https://flexrag.readthedocs.io/zh-cn/latest/getting_started/installation.html)
59 | - [快速入门](https://flexrag.readthedocs.io/zh-cn/latest/getting_started/quickstart1.html)
60 | - [教程](https://flexrag.readthedocs.io/zh-cn/latest/tutorial/preparing_corpus.html)
61 |
62 | # 🏗️ FlexRAG 架构
63 | FlexRAG 采用**模块化**架构设计,让您可以轻松定制和扩展框架以满足您的特定需求。下图说明了 FlexRAG 的架构:
64 |
65 |
66 |
67 |
68 | # 📊 基准测试
69 | 我们利用 FlexRAG 进行了大量的基准测试,详情请参考 [benchmarks](benchmarks/README.md) 页面。
70 |
71 | # 🏷️ 许可证
72 | 本仓库采用 **MIT License** 开源协议. 详情请参考 [LICENSE](LICENSE) 文件。
73 |
74 | # 🖋️ 引用
75 | 如果您在研究中使用了 FlexRAG,请引用我们的项目:
76 | ```bibtex
77 | @software{Zhang_FlexRAG_2025,
78 | author = {Zhang, Zhuocheng and Feng, Yang and Zhang, Min},
79 | doi = {10.5281/zenodo.14593327},
80 | month = jan,
81 | title = {{FlexRAG}},
82 | url = {https://github.com/ictnlp/FlexRAG},
83 | year = {2025}
84 | }
85 | ```
86 |
87 |
88 | # ❤️ 致谢
89 | 下面的开源项目对本项目有所帮助:
90 | - [Faiss](https://github.com/facebookresearch/faiss)
91 | - [FlashRAG](https://github.com/RUC-NLPIR/FlashRAG)
92 | - [LanceDB](https://github.com/lancedb/lancedb)
93 | - [ANN Benchmarks](https://github.com/erikbern/ann-benchmarks)
94 | - [Chonkie](https://github.com/chonkie-ai/chonkie)
95 | - [rerankers](https://github.com/AnswerDotAI/rerankers)
96 |
--------------------------------------------------------------------------------
/assets/Framework-Features-en.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ictnlp/FlexRAG/f00cf3a36890a2d56aec508a72b599c4681be3ca/assets/Framework-Features-en.png
--------------------------------------------------------------------------------
/assets/Framework-Features-zh.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ictnlp/FlexRAG/f00cf3a36890a2d56aec508a72b599c4681be3ca/assets/Framework-Features-zh.png
--------------------------------------------------------------------------------
/assets/Framework-FlexRAGv3-en.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ictnlp/FlexRAG/f00cf3a36890a2d56aec508a72b599c4681be3ca/assets/Framework-FlexRAGv3-en.png
--------------------------------------------------------------------------------
/assets/Framework-FlexRAGv3-zh.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ictnlp/FlexRAG/f00cf3a36890a2d56aec508a72b599c4681be3ca/assets/Framework-FlexRAGv3-zh.png
--------------------------------------------------------------------------------
/assets/ModularAssistant.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ictnlp/FlexRAG/f00cf3a36890a2d56aec508a72b599c4681be3ca/assets/ModularAssistant.png
--------------------------------------------------------------------------------
/assets/Retrievers.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ictnlp/FlexRAG/f00cf3a36890a2d56aec508a72b599c4681be3ca/assets/Retrievers.png
--------------------------------------------------------------------------------
/assets/WebRetriever.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ictnlp/FlexRAG/f00cf3a36890a2d56aec508a72b599c4681be3ca/assets/WebRetriever.png
--------------------------------------------------------------------------------
/assets/flexrag-wide.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ictnlp/FlexRAG/f00cf3a36890a2d56aec508a72b599c4681be3ca/assets/flexrag-wide.png
--------------------------------------------------------------------------------
/assets/flexrag.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ictnlp/FlexRAG/f00cf3a36890a2d56aec508a72b599c4681be3ca/assets/flexrag.png
--------------------------------------------------------------------------------
/assets/gui_static.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ictnlp/FlexRAG/f00cf3a36890a2d56aec508a72b599c4681be3ca/assets/gui_static.png
--------------------------------------------------------------------------------
/assets/parse_files.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ictnlp/FlexRAG/f00cf3a36890a2d56aec508a72b599c4681be3ca/assets/parse_files.png
--------------------------------------------------------------------------------
/benchmarks/README.md:
--------------------------------------------------------------------------------
1 | # Benchmarks
2 | This directory contains benchmarks for the FlexRAG framework. We conduct these experiments for the following reasons:
3 | 1. We hope to help users gain a deeper understanding of the various components in RAG and determine their impact on the overall RAG system.
4 | 2. We aim to test the FlexRAG framework, which will help us reduce potential bugs.
5 |
6 | ## Directory Structure
7 | - [`singlehop_qa.md`](singlehop_qa.md): This file contains the benchmark results for the single-hop QA task.
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line, and also
5 | # from the environment for the first two.
6 | SPHINXOPTS ?=
7 | SPHINXBUILD ?= sphinx-build
8 | SOURCEDIR = source
9 | BUILDDIR = build
10 |
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 |
15 | .PHONY: help Makefile
16 |
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile
20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
21 |
--------------------------------------------------------------------------------
/docs/make.bat:
--------------------------------------------------------------------------------
1 | @ECHO OFF
2 |
3 | pushd %~dp0
4 |
5 | REM Command file for Sphinx documentation
6 |
7 | if "%SPHINXBUILD%" == "" (
8 | set SPHINXBUILD=sphinx-build
9 | )
10 | set SOURCEDIR=source
11 | set BUILDDIR=build
12 |
13 | %SPHINXBUILD% >NUL 2>NUL
14 | if errorlevel 9009 (
15 | echo.
16 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
17 | echo.installed, then set the SPHINXBUILD environment variable to point
18 | echo.to the full path of the 'sphinx-build' executable. Alternatively you
19 | echo.may add the Sphinx directory to PATH.
20 | echo.
21 | echo.If you don't have Sphinx installed, grab it from
22 | echo.https://www.sphinx-doc.org/
23 | exit /b 1
24 | )
25 |
26 | if "%1" == "" goto help
27 |
28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
29 | goto end
30 |
31 | :help
32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
33 |
34 | :end
35 | popd
36 |
--------------------------------------------------------------------------------
/docs/requirements.docs.txt:
--------------------------------------------------------------------------------
1 | # Requirements for Sphinx Documentation
2 | sphinx==8.1.3
3 | myst-parser==4.0.0
4 | sphinx-book-theme==1.1.3
5 | sphinx-copybutton==0.5.2
6 | # Requirements for Loading FlexRAG
7 | numpy<2.0.0
8 | tenacity
9 | hydra-core>=1.3
10 | omegaconf>=2.3.0
11 | pillow
12 | accelerate>=0.26.0
13 | rouge
14 | sacrebleu>=2.4.2
15 | openai>=1.30.1
16 | transformers>=4.44.0
17 | lmdb
18 | unidecode
19 | sacremoses
20 | pandas
21 | pylance
22 | bm25s
23 | elasticsearch>=8.14.0
24 | -f https://download.pytorch.org/whl/cpu
25 | torch>=2.3.0
26 | beautifulsoup4
27 | datasets
28 | pytrec_eval
29 | colorama
30 | # gradio>=5.8.0
--------------------------------------------------------------------------------
/docs/source/conf.py:
--------------------------------------------------------------------------------
1 | # Configuration file for the Sphinx documentation builder.
2 | #
3 | # For the full list of built-in configuration values, see the documentation:
4 | # https://www.sphinx-doc.org/en/master/usage/configuration.html
5 |
6 | # -- Project information -----------------------------------------------------
7 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
8 | import re
9 | import pathlib
10 | import sys
11 |
12 | sys.path.insert(0, str(pathlib.Path(__file__).parents[2] / "src"))
13 |
14 |
15 | def get_version() -> str:
16 | version_string_path = pathlib.Path(__file__).parents[2] / "src/flexrag/utils.py"
17 | with open(version_string_path, encoding="utf-8") as f:
18 | version = re.search(r"__VERSION__ = \"(.*?)\"", f.read()).group(1)
19 | return version
20 |
21 |
22 | project = "FlexRAG Documentation"
23 | html_short_title = "FlexRAG Documentation"
24 | copyright = "2025, ZhuochengZhang"
25 | author = "ZhuochengZhang"
26 | release = get_version()
27 |
28 | # -- General configuration ---------------------------------------------------
29 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
30 |
31 | extensions = [
32 | "sphinx.ext.napoleon",
33 | "sphinx.ext.autodoc",
34 | "sphinx.ext.viewcode",
35 | "sphinx.ext.autosectionlabel",
36 | "sphinx_copybutton",
37 | "myst_parser",
38 | ]
39 |
40 | templates_path = ["_templates"]
41 | exclude_patterns = []
42 |
43 |
44 | # -- Options for HTML output -------------------------------------------------
45 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
46 |
47 | html_theme = "sphinx_book_theme"
48 | html_static_path = ["_static", "../../assets"]
49 | html_theme_options = {
50 | "path_to_docs": "docs/source",
51 | "repository_url": "https://github.com/ictnlp/flexrag",
52 | "use_repository_button": True,
53 | }
54 |
55 | # -- Options for autodoc -----------------------------------------------------
56 |
57 | autodoc_mock_imports = [
58 | "gradio", # as gradio has a lot of dependencies, we mock it to speed up building the docs.
59 | ]
60 |
61 |
62 | # -- Options for copybutton --------------------------------------------------
63 | copybutton_prompt_text = r">>> |\.\.\. "
64 | copybutton_prompt_is_regexp = True
65 |
66 | # -- Options for autosectionlabel --------------------------------------------
67 | autosectionlabel_prefix_document = False
68 |
69 | # -- Options for multi-language -------------------------------------------------
70 | language = "en"
71 | locale_dirs = ["./locales/"]
72 | gettext_compact = False
73 | gettext_uuid = True
74 |
--------------------------------------------------------------------------------
/docs/source/getting_started/installation.md:
--------------------------------------------------------------------------------
1 | # Installation
2 | FlexRAG is a Python package that can be installed via `pip` or from source.
3 |
4 | ```{eval-rst}
5 | .. important::
6 | FlexRAG requires Python 3.11 or later.
7 | ```
8 |
9 | ## Installation via `pip`
10 | To install FlexRAG via pip, run the following command:
11 |
12 | ```bash
13 | pip install flexrag
14 | ```
15 |
16 | ## Installation from source
17 | Alternatively, to install FlexRAG from the source, follow the steps below:
18 | ```bash
19 | pip install pybind11
20 |
21 | git clone https://github.com/ictnlp/FlexRAG.git
22 | cd flexrag
23 | pip install ./
24 | ```
25 | You can also install the FlexRAG in *editable* mode with the `-e` flag.
26 |
27 | ## Installation flags
28 | FlexRAG can be installed with additional flags to enable specific features. The following flags are available:
29 |
30 | | Flag | pip install command | Description |
31 | | ---------- | ------------------------------- | --------------------------------------------------- |
32 | | scann | pip install flexrag[scann] | Install FlexRAG with the ScaNN index. |
33 | | annoy | pip install flexrag[annoy] | Install FlexRAG with the Annoy index. |
34 | | llamacpp | pip install flexrag[llamacpp] | Install FlexRAG with the LlamaCpp Generator. |
35 | | minference | pip install flexrag[minference] | Install FlexRAG with the Minference. |
36 | | web | pip install flexrag[web] | Install FlexRAG with the Web Retrievers. |
37 | | docs | pip install flexrag[docs] | Install FlexRAG with the Document Parser. |
38 | | all | pip install flexrag[all] | Install FlexRAG with most features. |
39 | | dev | pip install flexrag[dev] | Install FlexRAG with the libraries for development. |
40 |
--------------------------------------------------------------------------------
/docs/source/getting_started/quickstart2.md:
--------------------------------------------------------------------------------
1 | # Quickstart: Building your own RAG application
2 | Besides using RAG assistant, you can also import FlexRAG as a library to develop your own RAG applications. FlexRAG provides a flexible and modular API that allows you to customize your RAG application with ease. For example, you can use the following code to build a simple command line RAG QA system:
3 |
4 | ```python
5 | from flexrag.models import OpenAIGenerator, OpenAIGeneratorConfig
6 | from flexrag.retriever import LocalRetriever
7 |
8 |
9 | def main():
10 | # load the retriever
11 | retriever = LocalRetriever.load_from_hub("FlexRAG/wiki2021_atlas_bm25s")
12 |
13 | # load the generator
14 | generator = OpenAIGenerator(
15 | OpenAIGeneratorConfig(
16 | model_name="Qwen2-7B-Instruct",
17 | base_url="http://10.28.0.148:8000/v1",
18 | )
19 | )
20 |
21 | # build a QA loop
22 | while True:
23 | query = input("Please input your question (type /bye to quit): ")
24 | if query == "/bye":
25 | break
26 | # retrieve the contexts
27 | contexts = retriever.search(query, top_k=3)[0]
28 | # construct the prompt
29 | user_prompt = (
30 | "Please answer the following question based on the given contexts.\n"
31 | f"Question: {query}\n"
32 | )
33 | for i, ctx in enumerate(contexts):
34 | user_prompt += f"Context {i+1}: {ctx.data['text']}\n"
35 | # generate the response
36 | response = generator.chat([{"role": "user", "content": user_prompt}])[0][0]
37 | print(response)
38 |
39 | return
40 |
41 |
42 | if __name__ == "__main__":
43 | main()
44 | ```
--------------------------------------------------------------------------------
/docs/source/index.rst:
--------------------------------------------------------------------------------
1 | |
2 |
3 | .. image:: ../../assets/flexrag-wide.png
4 | :alt: FlexRAG
5 | :align: center
6 |
7 | |
8 |
9 | Welecome to FlexRAG Documentation
10 | =================================
11 |
12 | FlexRAG is a highly reproducible, easy-to-use, and high-performance RAG framework designed for both research and application scenarios. It supports **text**, **multimodal**, and **web-based** RAG, providing a **complete RAG pipeline and evaluation process**. With built-in **asynchronous** processing and **persistent caching**, it ensures efficiency and scalability. Easily load retrievers from Hugging Face and quickly build powerful RAG solutions out of the box.
13 |
14 | .. note::
15 | FlexRAG is under active development and is currently in the **alpha** stage. We welcome contributions from the community and are open to feedback and suggestions.
16 |
17 | .. Check out the :doc:`tutorial` section for further information, including how to
18 | .. :ref:`install ` the project.
19 |
20 |
21 | .. toctree::
22 | :maxdepth: 1
23 | :caption: Getting Started:
24 |
25 | getting_started/installation
26 | getting_started/quickstart1
27 | getting_started/quickstart2
28 |
29 | .. toctree::
30 | :maxdepth: 1
31 | :caption: Tutorial:
32 |
33 | tutorial/preparing_corpus
34 | tutorial/preparing_retriever
35 | tutorial/building_assistant
36 | tutorial/entrypoints
37 | tutorial/using_register
38 | tutorial/preparing_web_retriever
39 |
40 | .. toctree::
41 | :maxdepth: 1
42 | :caption: API Reference Manual:
43 |
44 | reference/assistant
45 | reference/chunking
46 | reference/common_dataclass
47 | reference/refiner
48 | reference/datasets
49 | reference/document_parser
50 | reference/encoders
51 | reference/generators
52 | reference/metrics
53 | reference/prompt
54 | reference/retrievers
55 | reference/rankers
56 | reference/tokenizers
57 | reference/text_process
58 | reference/utils
59 |
--------------------------------------------------------------------------------
/docs/source/locales/zh_CN/LC_MESSAGES/getting_started/quickstart2.po:
--------------------------------------------------------------------------------
1 | # SOME DESCRIPTIVE TITLE.
2 | # Copyright (C) 2025, ZhuochengZhang
3 | # This file is distributed under the same license as the FlexRAG
4 | # Documentation package.
5 | # FIRST AUTHOR , 2025.
6 | #
7 | #, fuzzy
8 | msgid ""
9 | msgstr ""
10 | "Project-Id-Version: FlexRAG Documentation \n"
11 | "Report-Msgid-Bugs-To: \n"
12 | "POT-Creation-Date: 2025-03-21 20:53+0800\n"
13 | "PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n"
14 | "Last-Translator: FULL NAME \n"
15 | "Language: zh_CN\n"
16 | "Language-Team: zh_CN \n"
17 | "Plural-Forms: nplurals=1; plural=0;\n"
18 | "MIME-Version: 1.0\n"
19 | "Content-Type: text/plain; charset=utf-8\n"
20 | "Content-Transfer-Encoding: 8bit\n"
21 | "Generated-By: Babel 2.16.0\n"
22 |
23 | #: ../../source/getting_started/quickstart2.md:1
24 | #: 24f56c005d924c62801a30eefe48f5d7
25 | msgid "Quickstart: Building your own RAG application"
26 | msgstr "快速入门:构建您自己的 RAG 应用"
27 |
28 | #: ../../source/getting_started/quickstart2.md:2
29 | #: e3b83077d47f4da4bcee4584d9354d13
30 | msgid ""
31 | "Besides using RAG assistant, you can also import FlexRAG as a library to "
32 | "develop your own RAG applications. FlexRAG provides a flexible and "
33 | "modular API that allows you to customize your RAG application with ease. "
34 | "For example, you can use the following code to build a simple command "
35 | "line RAG QA system:"
36 | msgstr ""
37 | "除了直接使用 RAG 助手之外,您也可以将 FlexRAG 作为库来开发您的 RAG 应用。"
38 | "FlexRAG 提供了一个灵活且模块化的 API 来帮助您构建 RAG 应用。"
39 | "下面的代码向您展示了如何构建一个简单的命令行检索式 QA 系统:"
40 |
41 |
--------------------------------------------------------------------------------
/docs/source/locales/zh_CN/LC_MESSAGES/index.po:
--------------------------------------------------------------------------------
1 | # SOME DESCRIPTIVE TITLE.
2 | # Copyright (C) 2025, ZhuochengZhang
3 | # This file is distributed under the same license as the FlexRAG
4 | # Documentation package.
5 | # FIRST AUTHOR , 2025.
6 | #
7 | #, fuzzy
8 | msgid ""
9 | msgstr ""
10 | "Project-Id-Version: FlexRAG Documentation \n"
11 | "Report-Msgid-Bugs-To: \n"
12 | "POT-Creation-Date: 2025-03-25 10:28+0800\n"
13 | "PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n"
14 | "Last-Translator: FULL NAME \n"
15 | "Language: zh_CN\n"
16 | "Language-Team: zh_CN \n"
17 | "Plural-Forms: nplurals=1; plural=0;\n"
18 | "MIME-Version: 1.0\n"
19 | "Content-Type: text/plain; charset=utf-8\n"
20 | "Content-Transfer-Encoding: 8bit\n"
21 | "Generated-By: Babel 2.16.0\n"
22 |
23 | #: ../../source/index.rst:21
24 | msgid "Getting Started:"
25 | msgstr "入门:"
26 |
27 | #: ../../source/index.rst:29
28 | msgid "Tutorial:"
29 | msgstr "教程:"
30 |
31 | #: ../../source/index.rst:39
32 | msgid "API Reference Manual:"
33 | msgstr "API参考文档:"
34 |
35 | #: ../../source/index.rst:3 036d9484788041d39f31231b7144b122
36 | msgid "FlexRAG"
37 | msgstr "FlexRAG"
38 |
39 | #: ../../source/index.rst:10 1fd4801699a34995aa1755e9f089e498
40 | msgid "Welecome to FlexRAG Documentation"
41 | msgstr "欢迎访问 FlexRAG 文档"
42 |
43 | #: ../../source/index.rst:12 cd90c2b11e6f411fafdc5d885e302d25
44 | msgid ""
45 | "FlexRAG is a highly reproducible, easy-to-use, and high-performance RAG "
46 | "framework designed for both research and application scenarios. It "
47 | "supports **text**, **multimodal**, and **web-based** RAG, providing a "
48 | "**complete RAG pipeline and evaluation process**. With built-in "
49 | "**asynchronous** processing and **persistent caching**, it ensures "
50 | "efficiency and scalability. Easily load retrievers from Hugging Face and "
51 | "quickly build powerful RAG solutions out of the box."
52 | msgstr ""
53 | "FlexRAG 是一个高可复现、易上手、性能优越的 RAG 框架,专为科研与应用场景设计。它支持 **文本、多模态以及网络 "
54 | "RAG** ,提供 **完整的 RAG 流水线与评估流程** ,开箱即用,同时具备高效的 **异步处理与持久化缓存** 能力。通过 Hugging Face "
55 | "轻松加载 Retriever,助力快速搭建强大的 RAG 解决方案。"
56 |
57 | #: ../../source/index.rst:15 3152a79ea0b34b70b433632a7280d1b1
58 | msgid ""
59 | "FlexRAG is under active development and is currently in the **alpha** "
60 | "stage. We welcome contributions from the community and are open to "
61 | "feedback and suggestions."
62 | msgstr "FlexRAG 目前正处于活跃开发的 **alpha** 阶段。我们非常欢迎来自于社区的贡献、反馈或建议。"
63 |
64 | #~ msgid ""
65 | #~ "FlexRAG is a flexible and high-"
66 | #~ "performance framework designed for "
67 | #~ "Retrieval-Augmented Generation (RAG) tasks, "
68 | #~ "offering support for multimodal data, "
69 | #~ "seamless configuration management, and out-"
70 | #~ "of-the-box performance for both "
71 | #~ "research and prototyping."
72 | #~ msgstr "FlexRAG 是一个灵活的高性能检索增强生成 (RAG) 框架,支持多模态数据以及无缝的配置管理,为研究和原型系统设计提供开箱即用的支持。"
73 |
74 |
--------------------------------------------------------------------------------
/docs/source/locales/zh_CN/LC_MESSAGES/reference/common_dataclass.po:
--------------------------------------------------------------------------------
1 | # SOME DESCRIPTIVE TITLE.
2 | # Copyright (C) 2025, ZhuochengZhang
3 | # This file is distributed under the same license as the FlexRAG
4 | # Documentation package.
5 | # FIRST AUTHOR , 2025.
6 | #
7 | #, fuzzy
8 | msgid ""
9 | msgstr ""
10 | "Project-Id-Version: FlexRAG Documentation \n"
11 | "Report-Msgid-Bugs-To: \n"
12 | "POT-Creation-Date: 2025-03-19 16:54+0800\n"
13 | "PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n"
14 | "Last-Translator: FULL NAME \n"
15 | "Language: zh_CN\n"
16 | "Language-Team: zh_CN \n"
17 | "Plural-Forms: nplurals=1; plural=0;\n"
18 | "MIME-Version: 1.0\n"
19 | "Content-Type: text/plain; charset=utf-8\n"
20 | "Content-Transfer-Encoding: 8bit\n"
21 | "Generated-By: Babel 2.16.0\n"
22 |
23 | #: ../../source/reference/common_dataclass.rst:2
24 | #: a19baf5fcc2c4dd6ab490be232845b4a
25 | msgid "Common Dataclass"
26 | msgstr ""
27 |
28 | #: ../../source/reference/common_dataclass.rst:3
29 | #: 1c2ff1e4446e4df8934f948945708ba4
30 | msgid ""
31 | "This module provides several pre-defined dataclasses that are commonly "
32 | "used in the project."
33 | msgstr ""
34 |
35 | #: 6f7d2329532f44668bc7396446318953 bb1caf9a3761461ca518abb0f868280c
36 | #: flexrag.common_dataclass.Context:1
37 | #: flexrag.common_dataclass.RetrievedContext:1 of
38 | msgid "The dataclass for retrieved context."
39 | msgstr ""
40 |
41 | #: ../../source/reference/common_dataclass.rst 3a6ac042dba747e79cf02646da47ed08
42 | #: 451a02ea6b084d1dbefe5b42ce700e75 a7299aa836b54388a9c3b5bf05a10e7e
43 | #: d06c09ba34ae4c8f97186e56162c3cfd
44 | msgid "Parameters"
45 | msgstr ""
46 |
47 | #: c8ba7a884918444abca911a60c350250 flexrag.common_dataclass.Context:3 of
48 | msgid "The unique identifier of the context. Default: None."
49 | msgstr ""
50 |
51 | #: 0f42874e10a549d1a8c7e53515818e70 flexrag.common_dataclass.Context:5 of
52 | msgid "The context data. Default: {}."
53 | msgstr ""
54 |
55 | #: c75ed32b080545568c648a9a94a76f22 flexrag.common_dataclass.Context:7 of
56 | msgid "The source of the retrieved data. Default: None."
57 | msgstr ""
58 |
59 | #: 3098971d8305433ab2eae75a0481dd28 flexrag.common_dataclass.Context:9 of
60 | msgid "The metadata of the context. Default: {}."
61 | msgstr ""
62 |
63 | #: a25525af6c5547eea661cf145a3d3e7b flexrag.common_dataclass.RetrievedContext:3
64 | #: of
65 | msgid "The name of the retriever. Required."
66 | msgstr ""
67 |
68 | #: e6b598a94438440c8ea864c936636baf flexrag.common_dataclass.RetrievedContext:5
69 | #: of
70 | msgid "The query for retrieval. Required."
71 | msgstr ""
72 |
73 | #: f027bec0202942e895b6849aed67ecd3 flexrag.common_dataclass.RetrievedContext:7
74 | #: of
75 | msgid "The relevance score of the retrieved data. Default: 0.0."
76 | msgstr ""
77 |
78 | #: 6e64ac7b12d64287878456cf58eb4802 flexrag.common_dataclass.RAGEvalData:1 of
79 | msgid "The dataclass for RAG evaluation data."
80 | msgstr ""
81 |
82 | #: 45774b4d2e7747b48630e143844e4ee8 bc41b1dfd24148a9aeafbe7f9323332c
83 | #: flexrag.common_dataclass.IREvalData:3 flexrag.common_dataclass.RAGEvalData:3
84 | #: of
85 | msgid "The question for evaluation. Required."
86 | msgstr ""
87 |
88 | #: 5ed64fae32744097b1a22a1b48b42283 61a97b8a7bb343079faf21e67c2597c2
89 | #: flexrag.common_dataclass.IREvalData:5 flexrag.common_dataclass.RAGEvalData:5
90 | #: of
91 | msgid "The contexts related to the question. Default: None."
92 | msgstr ""
93 |
94 | #: 253237efa5e14e83838d4dd43f4524eb flexrag.common_dataclass.RAGEvalData:7 of
95 | msgid "The golden answers for the question. Default: None."
96 | msgstr ""
97 |
98 | #: 74f3665080794970b5d8e6431f9db399 f96337eee3a84a29ae2511771a095513
99 | #: flexrag.common_dataclass.IREvalData:7 flexrag.common_dataclass.RAGEvalData:9
100 | #: of
101 | msgid "The metadata of the evaluation data. Default: {}."
102 | msgstr ""
103 |
104 | #: 8665df545a604c22a992a078e3a45e94 flexrag.common_dataclass.IREvalData:1 of
105 | msgid "The dataclass for Information Retrieval evaluation data."
106 | msgstr ""
107 |
108 |
--------------------------------------------------------------------------------
/docs/source/locales/zh_CN/LC_MESSAGES/reference/document_parser.po:
--------------------------------------------------------------------------------
1 | # SOME DESCRIPTIVE TITLE.
2 | # Copyright (C) 2025, ZhuochengZhang
3 | # This file is distributed under the same license as the FlexRAG
4 | # Documentation package.
5 | # FIRST AUTHOR , 2025.
6 | #
7 | #, fuzzy
8 | msgid ""
9 | msgstr ""
10 | "Project-Id-Version: FlexRAG Documentation \n"
11 | "Report-Msgid-Bugs-To: \n"
12 | "POT-Creation-Date: 2025-03-19 16:54+0800\n"
13 | "PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n"
14 | "Last-Translator: FULL NAME \n"
15 | "Language: zh_CN\n"
16 | "Language-Team: zh_CN \n"
17 | "Plural-Forms: nplurals=1; plural=0;\n"
18 | "MIME-Version: 1.0\n"
19 | "Content-Type: text/plain; charset=utf-8\n"
20 | "Content-Transfer-Encoding: 8bit\n"
21 | "Generated-By: Babel 2.16.0\n"
22 |
23 | #: ../../source/reference/document_parser.rst:2
24 | #: f5192c1f7fdb435c93337dc462ddfea2
25 | msgid "Document Parser"
26 | msgstr ""
27 |
28 | #: ../../source/reference/document_parser.rst:3
29 | #: a800ad80627a4cc197c8f82929ee88a3
30 | msgid ""
31 | "This module provides a set of classes and functions for parsing a "
32 | "formated document (such as PDF, Word, etc.) into a structured format."
33 | msgstr ""
34 |
35 | #: 9054257c79d94d72b8201d4dd3b6aed3
36 | #: flexrag.document_parser.document_parser_base.Document:1 of
37 | msgid "A document parsed by a DocumentParser."
38 | msgstr ""
39 |
40 | #: 48bb0cf6c9fb48a89550231171ea40bb
41 | #: flexrag.document_parser.document_parser_base.DocumentParserBase.parse:1 of
42 | msgid "Parse the document at the given path."
43 | msgstr ""
44 |
45 | #: ../../source/reference/document_parser.rst 054f7ec99e0c4dbc94b396574d62f3ee
46 | msgid "Parameters"
47 | msgstr ""
48 |
49 | #: 3fb9e3c019b84636b78dbf15fa78adcb
50 | #: flexrag.document_parser.document_parser_base.DocumentParserBase.parse:3 of
51 | msgid "The path to the document to parse."
52 | msgstr ""
53 |
54 | #: ../../source/reference/document_parser.rst 13ada7ac5c144e8e859b78bd0b28273a
55 | msgid "Returns"
56 | msgstr ""
57 |
58 | #: 6b907db4872841a6bfa98da0f6f6d29e
59 | #: flexrag.document_parser.document_parser_base.DocumentParserBase.parse:5 of
60 | msgid "The parsed document."
61 | msgstr ""
62 |
63 | #: ../../source/reference/document_parser.rst 91a749f7d3d44ff08e7cfee65cc270e3
64 | msgid "Return type"
65 | msgstr ""
66 |
67 | #: 88cffbd67e54483e930e18ef9389d0f4 d65e8fd3edb54480a8172de2bc088151
68 | #: flexrag.document_parser.docling_parser.DoclingParser:1
69 | #: flexrag.document_parser.markitdown_parser.MarkItDownParser:1 of
70 | msgid ""
71 | "Bases: "
72 | ":py:class:`~flexrag.document_parser.document_parser_base.DocumentParserBase`"
73 | msgstr ""
74 |
75 |
--------------------------------------------------------------------------------
/docs/source/locales/zh_CN/LC_MESSAGES/reference/prompt.po:
--------------------------------------------------------------------------------
1 | # SOME DESCRIPTIVE TITLE.
2 | # Copyright (C) 2025, ZhuochengZhang
3 | # This file is distributed under the same license as the FlexRAG
4 | # Documentation package.
5 | # FIRST AUTHOR , 2025.
6 | #
7 | #, fuzzy
8 | msgid ""
9 | msgstr ""
10 | "Project-Id-Version: FlexRAG Documentation \n"
11 | "Report-Msgid-Bugs-To: \n"
12 | "POT-Creation-Date: 2025-03-19 16:54+0800\n"
13 | "PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n"
14 | "Last-Translator: FULL NAME \n"
15 | "Language: zh_CN\n"
16 | "Language-Team: zh_CN \n"
17 | "Plural-Forms: nplurals=1; plural=0;\n"
18 | "MIME-Version: 1.0\n"
19 | "Content-Type: text/plain; charset=utf-8\n"
20 | "Content-Transfer-Encoding: 8bit\n"
21 | "Generated-By: Babel 2.16.0\n"
22 |
23 | #: ../../source/reference/prompt.rst:2 66bc65f470714a2ebb8ce73e6ecb1663
24 | msgid "Prompt"
25 | msgstr ""
26 |
27 | #: ../../source/reference/prompt.rst:3 b05bc7889b6c43ec853dab6ac39faaa4
28 | msgid ""
29 | "This module provides two classes namely `ChatPrompt` and `ChatTemplate`. "
30 | "The `ChatPrompt` is used to store the system prompt, chat history, and "
31 | "demonstrations used to interact with the `Generator`. The `ChatTemplate` "
32 | "is used to convert the `ChatPrompt` into a string or a list of tokens "
33 | "that can be used by the model."
34 | msgstr ""
35 |
36 | #: ../../source/reference/prompt.rst:6 d7ec38b189c844d5a8aaa543cc9d5f8d
37 | msgid "Chat Prompt"
38 | msgstr ""
39 |
40 | #: 2962809d7ac14f2c97c440883b21f7e7
41 | #: flexrag.prompt.prompt_base.MultiModelChatPrompt:1 of
42 | msgid ""
43 | "This class shares almost all the methods with ChatPrompt. However, the "
44 | "Generics in Python does not support calling the TypeVar's classmethod. So"
45 | " we have to duplicate the code here."
46 | msgstr ""
47 |
48 | #: ../../source/reference/prompt.rst:26 ffdd017ff3be4a70af01cec9c4c4f828
49 | msgid "Template"
50 | msgstr ""
51 |
52 | #: d8172e9b7fd64f8193b875942ef3e901 flexrag.prompt.template.HFTemplate:1 of
53 | msgid "Bases: :py:class:`~flexrag.prompt.template.ChatTemplate`"
54 | msgstr ""
55 |
56 | #: a83d04b2aa2c49579f0ada3e332eae0a flexrag.prompt.template.load_template:1 of
57 | msgid ""
58 | "Load ChatTemplate for different models. If model_name is not provided, "
59 | "the default template in the Tokenizer will be used."
60 | msgstr ""
61 |
62 | #: ../../source/reference/prompt.rst ce3e83f32e27441680ec2e04227fc204
63 | msgid "Parameters"
64 | msgstr ""
65 |
66 | #: ebcd3a792ddd4b6c8b3d4002fa1b81d0 flexrag.prompt.template.load_template:3 of
67 | msgid "The tokenizer used to encode the prompt."
68 | msgstr ""
69 |
70 | #: 8462d283d9a148fbba18869b2486d3ea flexrag.prompt.template.load_template:4 of
71 | msgid "The name of the model. Default is None."
72 | msgstr ""
73 |
74 | #: ../../source/reference/prompt.rst e763f1c2a17245aea63e6c059e72dcb9
75 | msgid "Returns"
76 | msgstr ""
77 |
78 | #: ebce22c7ffb54b5294101bc34713bc15 flexrag.prompt.template.load_template:7 of
79 | msgid "The loaded ChatTemplate"
80 | msgstr ""
81 |
82 | #: ../../source/reference/prompt.rst b225b7df4078422ca91df70a3c4b0fc6
83 | msgid "Return type"
84 | msgstr ""
85 |
86 |
--------------------------------------------------------------------------------
/docs/source/locales/zh_CN/LC_MESSAGES/reference/text_process.po:
--------------------------------------------------------------------------------
1 | # SOME DESCRIPTIVE TITLE.
2 | # Copyright (C) 2025, ZhuochengZhang
3 | # This file is distributed under the same license as the FlexRAG
4 | # Documentation package.
5 | # FIRST AUTHOR , 2025.
6 | #
7 | #, fuzzy
8 | msgid ""
9 | msgstr ""
10 | "Project-Id-Version: FlexRAG Documentation \n"
11 | "Report-Msgid-Bugs-To: \n"
12 | "POT-Creation-Date: 2025-03-19 16:54+0800\n"
13 | "PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n"
14 | "Last-Translator: FULL NAME \n"
15 | "Language: zh_CN\n"
16 | "Language-Team: zh_CN \n"
17 | "Plural-Forms: nplurals=1; plural=0;\n"
18 | "MIME-Version: 1.0\n"
19 | "Content-Type: text/plain; charset=utf-8\n"
20 | "Content-Transfer-Encoding: 8bit\n"
21 | "Generated-By: Babel 2.16.0\n"
22 |
23 | #: ../../source/reference/text_process.rst:2 bcafd7856e7f47549f7854330be2a356
24 | msgid "Text Processing"
25 | msgstr ""
26 |
27 | #: ../../source/reference/text_process.rst:3 6e070e6188914335a88aba4c0859e75a
28 | msgid ""
29 | "This module provides a set of classes and functions for preprocessing and"
30 | " filtering texts, including normalization, length filtering, etc."
31 | msgstr ""
32 |
33 | #: 917b95e3e8e944ac92ed02d7b0e4398d
34 | #: flexrag.text_process.processor.Processor.__call__:1 of
35 | msgid ""
36 | "Process the input text. If the processor has been filtered, the reserved "
37 | "flag of the input TextUnit will be set to False."
38 | msgstr ""
39 |
40 | #: ../../source/reference/text_process.rst c66af0aa305d4d899831600d8aae2abf
41 | #: cdc6caf16dfc4972add779030801799b
42 | msgid "Parameters"
43 | msgstr ""
44 |
45 | #: 941f41a569eb4a1990b87bfd5dc43b79
46 | #: flexrag.text_process.processor.Processor.__call__:4 of
47 | msgid "The input text to process."
48 | msgstr ""
49 |
50 | #: ../../source/reference/text_process.rst 89a47ae85377414db31cd68e160a560d
51 | msgid "Returns"
52 | msgstr ""
53 |
54 | #: 3bd6f2ef3122417894a12434795eb3fc
55 | #: flexrag.text_process.processor.Processor.__call__:6 of
56 | msgid "The processed text."
57 | msgstr ""
58 |
59 | #: ../../source/reference/text_process.rst 282fe9bb2d4d43f89d1e2121f3485155
60 | msgid "Return type"
61 | msgstr ""
62 |
63 | #: 9e8d7b2691e24a01bb225cbe77a1f7d1 of types.TextProcessPipelineConfig:1
64 | msgid ""
65 | "Configuration class for processor (name: TextProcessPipelineConfig, "
66 | "default: ???)."
67 | msgstr ""
68 |
69 | #: bb2876a2f865424b924c65545a9c6cc2 of types.TextProcessPipelineConfig:3
70 | msgid "The processor type to use."
71 | msgstr ""
72 |
73 | #: 8cfe929bef594bcdba3e43eea71b21df of types.TextProcessPipelineConfig:5
74 | msgid "The config for LengthFilter."
75 | msgstr ""
76 |
77 | #: 51df3dd1e06b4c9183e6c122db090a16 of types.TextProcessPipelineConfig:7
78 | msgid "The config for TokenNormalizer."
79 | msgstr ""
80 |
81 | #: 59d68c7cf96342119bbff9cb855ba3ed of types.TextProcessPipelineConfig:9
82 | msgid "The config for Truncator."
83 | msgstr ""
84 |
85 | #: 14b63f7efb154a029c233a74a1b7aedf 2c1227cc8ac74dbf8e00e1779f917223
86 | #: 3c75a8b2130941a3a996d9257f580ec3 4b23451799734d3c8259b93dc1fc695f
87 | #: 833fd71b2baa41cd89a4ff8bda7db581 89c0c8b6f0f044f49b519e08069613ba
88 | #: e4a8cc0a00ce498d81029cd80bd400c8
89 | #: flexrag.text_process.basic_filters.ExactDeduplicate:1
90 | #: flexrag.text_process.basic_processors.AnswerSimplifier:1
91 | #: flexrag.text_process.basic_processors.ChineseSimplifier:1
92 | #: flexrag.text_process.basic_processors.Lowercase:1
93 | #: flexrag.text_process.basic_processors.TokenNormalizer:1
94 | #: flexrag.text_process.basic_processors.Truncator:1
95 | #: flexrag.text_process.basic_processors.Unifier:1 of
96 | msgid "Bases: :py:class:`~flexrag.text_process.processor.Processor`"
97 | msgstr ""
98 |
99 |
--------------------------------------------------------------------------------
/docs/source/reference/assistant.rst:
--------------------------------------------------------------------------------
1 | Assistant
2 | =========
3 |
4 | The ``Assistant`` class serves as an abstraction for Retrieval-Augmented Generation (RAG) behavior. It takes the user's query as input and returns an appropriate response. This class provides a flexible interface for defining how the assistant handles queries, including whether a retrieval step is required, how the retrieval should be conducted, and how the assistant generates the response based on the retrieved information.
5 |
6 | The Assistant Interface
7 | -----------------------
8 | ``AssistantBase`` is the base class for all assistants. It provides a simple interface for answering a user query. The answering process is controlled by a configuration object that is passed to the assistant's constructor.
9 |
10 | .. autoclass:: flexrag.assistant.AssistantBase
11 | :members:
12 | :inherited-members:
13 |
14 | FlexRAG Assistants
15 | ------------------
16 | FlexRAG provides several assistant implementations that can be used out of the box. These implementations are designed to be flexible and extensible, allowing users to customize the assistant's behavior by providing their own retrieval and generation components.
17 |
18 | .. autoclass:: flexrag.assistant.BasicAssistantConfig
19 | :members:
20 | :show-inheritance:
21 |
22 | .. autoclass:: flexrag.assistant.BasicAssistant
23 | :members:
24 | :show-inheritance:
25 | :exclude-members: answer
26 |
27 | .. autoclass:: flexrag.assistant.ModularAssistantConfig
28 | :members:
29 | :show-inheritance:
30 |
31 | .. autoclass:: flexrag.assistant.ModularAssistant
32 | :members:
33 | :show-inheritance:
34 | :exclude-members: answer, search, answer_with_contexts
35 |
36 | .. autoclass:: flexrag.assistant.ChatQAAssistant
37 | :members:
38 | :show-inheritance:
39 | :exclude-members: answer, get_formatted_input, answer_with_contexts
--------------------------------------------------------------------------------
/docs/source/reference/chunking.rst:
--------------------------------------------------------------------------------
1 | Chunking
2 | ========
3 | This module provides a set of classes for chunking a long text into smaller chunks.
4 |
5 |
6 | The Chunker Interface
7 | ---------------------
8 | `ChunkerBase` is the base class for all chunkers.
9 | It provides a simple interface for chunking a text into smaller chunks.
10 | The chunking process is controlled by a configuration object that is passed to the chunker's constructor.
11 |
12 | .. autoclass:: flexrag.chunking.ChunkerBase
13 | :members:
14 | :inherited-members:
15 |
16 |
17 | Chunkers
18 | --------
19 |
20 | .. autoclass:: flexrag.chunking.CharChunkerConfig
21 | :members:
22 | :inherited-members:
23 |
24 | .. autoclass:: flexrag.chunking.CharChunker
25 | :members:
26 | :show-inheritance:
27 |
28 | .. autoclass:: flexrag.chunking.TokenChunkerConfig
29 | :members:
30 | :inherited-members:
31 | :show-inheritance:
32 |
33 | .. autoclass:: flexrag.chunking.TokenChunker
34 | :members:
35 | :show-inheritance:
36 |
37 | .. autoclass:: flexrag.chunking.RecursiveChunkerConfig
38 | :members:
39 | :inherited-members:
40 | :show-inheritance:
41 |
42 | .. autoclass:: flexrag.chunking.RecursiveChunker
43 | :members:
44 | :show-inheritance:
45 |
46 | .. autoclass:: flexrag.chunking.SentenceChunkerConfig
47 | :members:
48 | :inherited-members:
49 | :show-inheritance:
50 |
51 | .. autoclass:: flexrag.chunking.SentenceChunker
52 | :members:
53 | :show-inheritance:
54 |
55 | .. autoclass:: flexrag.chunking.SemanticChunkerConfig
56 | :members:
57 | :inherited-members:
58 | :show-inheritance:
59 |
60 | .. autoclass:: flexrag.chunking.SemanticChunker
61 | :members:
62 | :show-inheritance:
63 |
64 |
65 | Sentence Splitters
66 | ------------------
67 | This submodule provides a set of useful tools for splitting a text into sentences.
68 |
69 | .. autoclass:: flexrag.chunking.sentence_splitter.SentenceSplitterBase
70 | :members:
71 | :inherited-members:
72 |
73 | .. autoclass:: flexrag.chunking.sentence_splitter.NLTKSentenceSplitterConfig
74 | :members:
75 | :inherited-members:
76 |
77 | .. autoclass:: flexrag.chunking.sentence_splitter.NLTKSentenceSplitter
78 | :members:
79 | :show-inheritance:
80 |
81 | .. autoclass:: flexrag.chunking.sentence_splitter.RegexSplitterConfig
82 | :members:
83 | :inherited-members:
84 |
85 | .. autoclass:: flexrag.chunking.sentence_splitter.RegexSplitter
86 | :members:
87 | :show-inheritance:
88 |
89 | .. autoclass:: flexrag.chunking.sentence_splitter.SpacySentenceSplitterConfig
90 | :members:
91 | :inherited-members:
92 |
93 | .. autoclass:: flexrag.chunking.sentence_splitter.SpacySentenceSplitter
94 | :members:
95 | :show-inheritance:
96 |
97 | .. autoattribute:: flexrag.chunking.sentence_splitter.PREDEFINED_SPLIT_PATTERNS
98 |
99 | A dictionary of predefined sentence splitting patterns.
100 | The keys are the names of the patterns, and the values are the corresponding regular expressions.
101 | Currently, ``FlexRAG`` provides 2 sets of predefined patterns: "en" for English and "zh" for Chinese.
102 | Please refer to the source code for more details.
103 |
104 | General Configuration
105 | ---------------------
106 | The configuration provides a general interface for loading and configuring the chunker or the sentence splitter.
107 |
108 | .. autoclass:: flexrag.chunking.ChunkerConfig
109 | :members:
110 | :inherited-members:
111 |
112 | .. autoclass:: flexrag.chunking.sentence_splitter.SentenceSplitterConfig
113 | :members:
114 | :inherited-members:
115 |
--------------------------------------------------------------------------------
/docs/source/reference/common_dataclass.rst:
--------------------------------------------------------------------------------
1 | Common Dataclass
2 | ================
3 | This module provides several pre-defined dataclasses that are commonly used in the project.
4 |
5 | .. autoclass:: flexrag.common_dataclass.Context
6 | :members:
7 | :inherited-members:
8 |
9 | .. autoclass:: flexrag.common_dataclass.RetrievedContext
10 | :members:
11 | :inherited-members:
12 |
13 | .. autoclass:: flexrag.common_dataclass.RAGEvalData
14 | :members:
15 | :inherited-members:
16 |
17 | .. autoclass:: flexrag.common_dataclass.IREvalData
18 | :members:
19 | :inherited-members:
20 |
--------------------------------------------------------------------------------
/docs/source/reference/datasets.rst:
--------------------------------------------------------------------------------
1 | Datasets
2 | ========
3 | This module provides a set of classes and functions for loading and processing datasets.
4 |
5 | .. BaseClasses
6 | .. autoclass:: flexrag.datasets.IterableDataset
7 | :members:
8 | :inherited-members:
9 | :show-inheritance:
10 |
11 | .. autoclass:: flexrag.datasets.MappingDataset
12 | :members:
13 | :inherited-members:
14 | :show-inheritance:
15 |
16 | .. autoclass:: flexrag.datasets.ChainDataset
17 | :members:
18 | :inherited-members:
19 | :show-inheritance:
20 |
21 | .. autoclass:: flexrag.datasets.ConcatDataset
22 | :members:
23 | :inherited-members:
24 | :show-inheritance:
25 |
26 | .. HuggingFace Datasets
27 | .. autoclass:: flexrag.datasets.HFDatasetConfig
28 | :members:
29 | :inherited-members:
30 | :show-inheritance:
31 |
32 | .. autoclass:: flexrag.datasets.HFDataset
33 | :members:
34 | :inherited-members:
35 | :show-inheritance:
36 |
37 | .. Line Delimited Dataset
38 | .. autoclass:: flexrag.datasets.LineDelimitedDatasetConfig
39 | :members:
40 | :inherited-members:
41 | :show-inheritance:
42 |
43 | .. autoclass:: flexrag.datasets.LineDelimitedDataset
44 | :members:
45 | :inherited-members:
46 | :show-inheritance:
47 |
48 | .. RAG Datasets
49 | .. _RAGEvalDatasetConfig:
50 |
51 | .. autoclass:: flexrag.datasets.RAGEvalDatasetConfig
52 | :members:
53 | :inherited-members:
54 | :show-inheritance:
55 |
56 | .. autoclass:: flexrag.datasets.RAGEvalDataset
57 | :members:
58 | :inherited-members:
59 | :show-inheritance:
60 |
61 | .. autoclass:: flexrag.datasets.RAGCorpusDatasetConfig
62 | :members:
63 | :inherited-members:
64 | :show-inheritance:
65 |
66 | .. autoclass:: flexrag.datasets.RAGCorpusDataset
67 | :members:
68 | :inherited-members:
69 | :show-inheritance:
70 |
71 | .. Retrieval Datasets
72 | .. autoclass:: flexrag.datasets.MTEBDatasetConfig
73 | :members:
74 | :inherited-members:
75 | :show-inheritance:
76 |
77 | .. autoclass:: flexrag.datasets.MTEBDataset
78 | :members:
79 | :inherited-members:
80 | :show-inheritance:
81 |
--------------------------------------------------------------------------------
/docs/source/reference/document_parser.rst:
--------------------------------------------------------------------------------
1 | Document Parser
2 | ===============
3 | This module provides a set of classes and functions for parsing a formatted document (such as PDF, Word, etc.) into a structured format.
4 |
5 | .. autoclass:: flexrag.document_parser.Document
6 | :members:
7 | :inherited-members:
8 |
9 | .. autoclass:: flexrag.document_parser.DocumentParserBase
10 | :members:
11 | :inherited-members:
12 |
13 | .. autoclass:: flexrag.document_parser.DoclingConfig
14 | :members:
15 | :inherited-members:
16 |
17 | .. autoclass:: flexrag.document_parser.DoclingParser
18 | :members:
19 | :show-inheritance:
20 | :exclude-members: parse
21 |
22 | .. autoclass:: flexrag.document_parser.MarkItDownParser
23 | :members:
24 | :show-inheritance:
25 | :exclude-members: parse
--------------------------------------------------------------------------------
/docs/source/reference/encoders.rst:
--------------------------------------------------------------------------------
1 | Encoders
2 | ========
3 |
4 | .. autoclass:: flexrag.models.EncoderBase
5 | :members:
6 | :inherited-members:
7 |
8 | .. autoclass:: flexrag.models.EncoderConfig
9 | :members:
10 | :inherited-members:
11 |
12 |
13 | Local Encoders
14 | --------------
15 |
16 | .. Huggingface Encoders
17 | .. autoclass:: flexrag.models.HFEncoderConfig
18 | :members:
19 | :inherited-members:
20 |
21 | .. autoclass:: flexrag.models.HFEncoder
22 | :members:
23 | :show-inheritance:
24 | :exclude-members: async_encode, encode
25 |
26 |
27 | .. Huggingface Clip Encoders
28 | .. autoclass:: flexrag.models.HFClipEncoderConfig
29 | :members:
30 | :inherited-members:
31 |
32 | .. autoclass:: flexrag.models.HFClipEncoder
33 | :members:
34 | :show-inheritance:
35 | :exclude-members: async_encode, encode
36 |
37 |
38 | .. Ollama Encoders
39 | .. autoclass:: flexrag.models.OllamaEncoderConfig
40 | :members:
41 | :inherited-members:
42 |
43 | .. autoclass:: flexrag.models.OllamaEncoder
44 | :members:
45 | :show-inheritance:
46 | :exclude-members: async_encode, encode
47 |
48 |
49 | .. Sentence Transformers Encoders
50 | .. autoclass:: flexrag.models.SentenceTransformerEncoderConfig
51 | :members:
52 | :inherited-members:
53 |
54 | .. autoclass:: flexrag.models.SentenceTransformerEncoder
55 | :members:
56 | :show-inheritance:
57 | :exclude-members: async_encode, encode
58 |
59 |
60 | Online Encoders
61 | ----------------
62 |
63 | .. Cohere Encoders
64 | .. autoclass:: flexrag.models.CohereEncoderConfig
65 | :members:
66 | :inherited-members:
67 |
68 | .. autoclass:: flexrag.models.CohereEncoder
69 | :members:
70 | :show-inheritance:
71 | :exclude-members: async_encode, encode
72 |
73 |
74 | .. JinaAI Encoders
75 | .. autoclass:: flexrag.models.JinaEncoderConfig
76 | :members:
77 | :inherited-members:
78 |
79 | .. autoclass:: flexrag.models.JinaEncoder
80 | :members:
81 | :show-inheritance:
82 | :exclude-members: async_encode, encode
83 |
84 |
85 | .. OpenAI Encoders
86 | .. autoclass:: flexrag.models.OpenAIEncoderConfig
87 | :members:
88 | :show-inheritance:
89 | :inherited-members:
90 |
91 | .. autoclass:: flexrag.models.OpenAIEncoder
92 | :members:
93 | :show-inheritance:
94 | :exclude-members: async_encode, encode
--------------------------------------------------------------------------------
/docs/source/reference/generators.rst:
--------------------------------------------------------------------------------
1 | Generators
2 | ==========
3 |
4 | .. autoclass:: flexrag.models.GeneratorBase
5 | :members:
6 | :inherited-members:
7 |
8 |
9 | .. autoclass:: flexrag.models.GenerationConfig
10 | :members:
11 | :inherited-members:
12 |
13 | .. autoclass:: flexrag.models.GeneratorConfig
14 | :members:
15 | :inherited-members:
16 |
17 |
18 | Local Generators
19 | ----------------
20 |
21 | .. Hugging Face Generators
22 | .. autoclass:: flexrag.models.HFModelConfig
23 | :members:
24 | :inherited-members:
25 |
26 |
27 | .. autoclass:: flexrag.models.HFGeneratorConfig
28 | :members:
29 | :show-inheritance:
30 |
31 | .. autoclass:: flexrag.models.HFGenerator
32 | :members:
33 | :show-inheritance:
34 | :exclude-members: async_chat, async_generate, chat, generate
35 |
36 |
37 | .. Llamacpp Generators
38 | .. autoclass:: flexrag.models.LlamacppGeneratorConfig
39 | :members:
40 | :inherited-members:
41 |
42 | .. autoclass:: flexrag.models.LlamacppGenerator
43 | :members:
44 | :show-inheritance:
45 | :exclude-members: async_chat, async_generate, chat, generate
46 |
47 |
48 | .. Ollama Generators
49 | .. autoclass:: flexrag.models.OllamaGeneratorConfig
50 | :members:
51 | :inherited-members:
52 |
53 | .. autoclass:: flexrag.models.OllamaGenerator
54 | :members:
55 | :show-inheritance:
56 | :exclude-members: async_chat, async_generate, chat, generate
57 |
58 | .. VLLM Generators
59 | .. autoclass:: flexrag.models.VLLMGeneratorConfig
60 | :members:
61 | :inherited-members:
62 |
63 | .. autoclass:: flexrag.models.VLLMGenerator
64 | :members:
65 | :show-inheritance:
66 | :exclude-members: async_chat, async_generate, chat, generate
67 |
68 |
69 | Online Generators
70 | -----------------
71 |
72 | .. Anthropic Generators
73 | .. autoclass:: flexrag.models.AnthropicGeneratorConfig
74 | :members:
75 | :inherited-members:
76 |
77 | .. autoclass:: flexrag.models.AnthropicGenerator
78 | :members:
79 | :show-inheritance:
80 | :exclude-members: async_chat, async_generate, chat, generate
81 |
82 | .. OpenAI Generators
83 | .. autoclass:: flexrag.models.OpenAIConfig
84 | :members:
85 | :show-inheritance:
86 | :inherited-members:
87 |
88 | .. autoclass:: flexrag.models.OpenAIGeneratorConfig
89 | :members:
90 | :show-inheritance:
91 |
92 | .. autoclass:: flexrag.models.OpenAIGenerator
93 | :members:
94 | :show-inheritance:
95 | :exclude-members: async_chat, async_generate, chat, generate
96 |
97 |
98 | Visual Language Model Generators
99 | --------------------------------
100 |
101 | .. autoclass:: flexrag.models.VLMGeneratorBase
102 | :members:
103 | :inherited-members:
104 | :show-inheritance:
105 |
106 | .. HF VLM Generators
107 | .. autoclass:: flexrag.models.HFVLMGeneratorConfig
108 | :members:
109 | :show-inheritance:
110 | :inherited-members:
111 |
112 | .. autoclass:: flexrag.models.HFVLMGenerator
113 | :members:
114 | :show-inheritance:
115 | :exclude-members: chat, generate
--------------------------------------------------------------------------------
/docs/source/reference/metrics.rst:
--------------------------------------------------------------------------------
1 | Metrics
2 | =======
3 |
4 | This module contains functions for evaluating the performance of a RAG assistant or a retriever.
5 |
6 | .. autoclass:: flexrag.metrics.MetricsBase
7 | :members:
8 | :inherited-members:
9 |
10 |
11 | Helper Class
12 | ------------
13 | The `Evaluator` takes a list of metrics and evaluates the performance of a RAG assistant or a retriever.
14 |
15 | .. autoclass:: flexrag.metrics.EvaluatorConfig
16 | :members:
17 | :inherited-members:
18 |
19 | .. autoclass:: flexrag.metrics.Evaluator
20 | :members:
21 | :show-inheritance:
22 |
23 |
24 | RAG Generation Metrics
25 | ----------------------
26 |
27 | .. autoclass:: flexrag.metrics.BLEUConfig
28 | :members:
29 | :inherited-members:
30 |
31 | .. autoclass:: flexrag.metrics.BLEU
32 | :members:
33 | :show-inheritance:
34 | :exclude-members: compute
35 |
36 | .. autoclass:: flexrag.metrics.Rouge
37 | :members:
38 | :show-inheritance:
39 |
40 | .. autoclass:: flexrag.metrics.chrFConfig
41 | :members:
42 | :inherited-members:
43 |
44 | .. autoclass:: flexrag.metrics.chrF
45 | :members:
46 | :show-inheritance:
47 | :exclude-members: compute
48 |
49 | .. autoclass:: flexrag.metrics.F1
50 | :members:
51 | :show-inheritance:
52 |
53 | .. autoclass:: flexrag.metrics.Accuracy
54 | :members:
55 | :show-inheritance:
56 |
57 | .. autoclass:: flexrag.metrics.ExactMatch
58 | :members:
59 | :show-inheritance:
60 |
61 | .. autoclass:: flexrag.metrics.Precision
62 | :members:
63 | :show-inheritance:
64 |
65 | .. autoclass:: flexrag.metrics.Recall
66 | :members:
67 | :show-inheritance:
68 |
69 | Information Retrieval Metrics
70 | -----------------------------
71 |
72 | .. autoclass:: flexrag.metrics.SuccessRateConfig
73 | :members:
74 | :inherited-members:
75 |
76 | .. autoclass:: flexrag.metrics.SuccessRate
77 | :members:
78 | :show-inheritance:
79 | :exclude-members: compute
80 |
81 | .. autoclass:: flexrag.metrics.RetrievalRecallConfig
82 | :members:
83 | :inherited-members:
84 |
85 | .. autoclass:: flexrag.metrics.RetrievalRecall
86 | :members:
87 | :show-inheritance:
88 | :exclude-members: compute
89 |
90 | .. autoclass:: flexrag.metrics.RetrievalPrecisionConfig
91 | :members:
92 | :inherited-members:
93 |
94 | .. autoclass:: flexrag.metrics.RetrievalPrecision
95 | :members:
96 | :show-inheritance:
97 | :exclude-members: compute
98 |
99 | .. autoclass:: flexrag.metrics.RetrievalMAPConfig
100 | :members:
101 | :inherited-members:
102 |
103 | .. autoclass:: flexrag.metrics.RetrievalMAP
104 | :members:
105 | :show-inheritance:
106 | :exclude-members: compute
107 |
108 | .. autoclass:: flexrag.metrics.RetrievalNDCGConfig
109 | :members:
110 | :inherited-members:
111 |
112 | .. autoclass:: flexrag.metrics.RetrievalNDCG
113 | :members:
114 | :show-inheritance:
115 | :exclude-members: compute
116 |
--------------------------------------------------------------------------------
/docs/source/reference/prompt.rst:
--------------------------------------------------------------------------------
1 | Prompt
2 | ======
3 | This module provides two classes, namely `ChatPrompt` and `ChatTemplate`. The `ChatPrompt` is used to store the system prompt, chat history, and demonstrations used to interact with the `Generator`. The `ChatTemplate` is used to convert the `ChatPrompt` into a string or a list of tokens that can be used by the model.
4 |
5 | Chat Prompt
6 | -----------
7 |
8 | .. autoclass:: flexrag.prompt.ChatTurn
9 | :members:
10 | :inherited-members:
11 |
12 | .. autoclass:: flexrag.prompt.ChatPrompt
13 | :members:
14 | :inherited-members:
15 |
16 | .. autoclass:: flexrag.prompt.MultiModelChatTurn
17 | :members:
18 | :inherited-members:
19 |
20 | .. autoclass:: flexrag.prompt.MultiModelChatPrompt
21 | :members:
22 | :inherited-members:
23 |
24 |
25 | Template
26 | --------
27 |
28 | .. autoclass:: flexrag.prompt.ChatTemplate
29 | :members:
30 | :inherited-members:
31 |
32 | .. autoclass:: flexrag.prompt.HFTemplate
33 | :members:
34 | :show-inheritance:
35 |
36 | .. autofunction:: flexrag.prompt.load_template
37 |
--------------------------------------------------------------------------------
/docs/source/reference/rankers.rst:
--------------------------------------------------------------------------------
1 | Rankers
2 | =======
3 |
4 | The ranker is the component that determines the order of the results returned by the retriever. FlexRAG provides several rankers that can be used to sort the results based on various criteria.
5 |
6 | .. autoclass:: flexrag.ranker.RankerBaseConfig
7 | :members:
8 | :inherited-members:
9 |
10 | .. autoclass:: flexrag.ranker.RankerBase
11 | :members:
12 | :inherited-members:
13 |
14 | .. autoclass:: flexrag.ranker.RankingResult
15 | :members:
16 | :inherited-members:
17 |
18 | .. autoclass:: flexrag.ranker.RankerConfig
19 | :members:
20 | :inherited-members:
21 |
22 |
23 | Local Ranker
24 | ------------
25 | .. HF Cross Encoder Ranker
26 | .. autoclass:: flexrag.ranker.HFCrossEncoderRankerConfig
27 | :members:
28 | :inherited-members:
29 | :show-inheritance:
30 |
31 | .. autoclass:: flexrag.ranker.HFCrossEncoderRanker
32 | :members:
33 | :show-inheritance:
34 |
35 |
36 | .. HF Cross Seq2Seq Ranker
37 | .. autoclass:: flexrag.ranker.HFSeq2SeqRankerConfig
38 | :members:
39 | :inherited-members:
40 | :show-inheritance:
41 |
42 | .. autoclass:: flexrag.ranker.HFSeq2SeqRanker
43 | :members:
44 | :show-inheritance:
45 |
46 |
47 | .. HF Cross ColBERT Ranker
48 | .. autoclass:: flexrag.ranker.HFColBertRankerConfig
49 | :members:
50 | :inherited-members:
51 | :show-inheritance:
52 |
53 | .. autoclass:: flexrag.ranker.HFColBertRanker
54 | :members:
55 | :show-inheritance:
56 |
57 |
58 | .. RankGPT Ranker
59 | .. autoclass:: flexrag.ranker.RankGPTRankerConfig
60 | :members:
61 | :inherited-members:
62 | :show-inheritance:
63 |
64 | .. autoclass:: flexrag.ranker.RankGPTRanker
65 | :members:
66 | :show-inheritance:
67 |
68 |
69 | Online Ranker
70 | -------------
71 | .. Cohere Ranker
72 | .. autoclass:: flexrag.ranker.CohereRankerConfig
73 | :members:
74 | :inherited-members:
75 | :show-inheritance:
76 |
77 | .. autoclass:: flexrag.ranker.CohereRanker
78 | :members:
79 | :show-inheritance:
80 |
81 |
82 | .. Jina Ranker
83 | .. autoclass:: flexrag.ranker.JinaRankerConfig
84 | :members:
85 | :inherited-members:
86 | :show-inheritance:
87 |
88 | .. autoclass:: flexrag.ranker.JinaRanker
89 | :members:
90 | :show-inheritance:
91 |
92 |
93 | .. Mixedbread Ranker
94 | .. autoclass:: flexrag.ranker.MixedbreadRankerConfig
95 | :members:
96 | :inherited-members:
97 | :show-inheritance:
98 |
99 | .. autoclass:: flexrag.ranker.MixedbreadRanker
100 | :members:
101 | :show-inheritance:
102 |
103 |
104 | .. Voyage Ranker
105 | .. autoclass:: flexrag.ranker.VoyageRankerConfig
106 | :members:
107 | :inherited-members:
108 | :show-inheritance:
109 |
110 | .. autoclass:: flexrag.ranker.VoyageRanker
111 | :members:
112 | :show-inheritance:
113 |
--------------------------------------------------------------------------------
/docs/source/reference/refiner.rst:
--------------------------------------------------------------------------------
1 | Context Refiner
2 | ===============
3 | The context refiner is responsible for refining the contexts retrieved by the retriever.
4 | It can be used to rearrange the contexts, summarize them, or extract the most relevant information from them.
5 |
6 | The Context Refiner Interface
7 | -----------------------------
8 | The `RefinerBase` is the base class for all refiners.
9 | It provides the basic interface for refining the contexts retrieved by the retriever.
10 |
11 | .. autoclass:: flexrag.context_refine.RefinerBase
12 | :members:
13 | :inherited-members:
14 |
15 | Refiners
16 | --------
17 | FlexRAG provides several refiners that can be used to refine the contexts retrieved by the retriever.
18 |
19 | .. autoclass:: flexrag.context_refine.ContextArrangerConfig
20 | :members:
21 | :inherited-members:
22 |
23 | .. autoclass:: flexrag.context_refine.ContextArranger
24 | :members:
25 | :show-inheritance:
26 |
27 | .. autoclass:: flexrag.context_refine.AbstractiveSummarizerConfig
28 | :members:
29 | :inherited-members:
30 |
31 | .. autoclass:: flexrag.context_refine.AbstractiveSummarizer
32 | :members:
33 | :show-inheritance:
34 |
35 | .. autoclass:: flexrag.context_refine.RecompExtractiveSummarizerConfig
36 | :members:
37 | :inherited-members:
38 |
39 | .. autoclass:: flexrag.context_refine.RecompExtractiveSummarizer
40 | :members:
41 | :show-inheritance:
--------------------------------------------------------------------------------
/docs/source/reference/text_process.rst:
--------------------------------------------------------------------------------
1 | Text Processing
2 | ===============
3 | This module provides a set of classes and functions for preprocessing and filtering texts, including normalization, length filtering, etc.
4 |
5 | .. autoclass:: flexrag.text_process.TextUnit
6 | :members:
7 | :inherited-members:
8 |
9 | .. autoclass:: flexrag.text_process.Processor
10 | :members:
11 | :inherited-members:
12 | :special-members: __call__
13 |
14 | .. autoclass:: flexrag.text_process.TextProcessPipelineConfig
15 | :members:
16 | :inherited-members:
17 |
18 | .. autoclass:: flexrag.text_process.TextProcessPipeline
19 | :members:
20 | :inherited-members:
21 |
22 | .. autoclass:: flexrag.text_process.TokenNormalizerConfig
23 | :members:
24 | :inherited-members:
25 |
26 | .. autoclass:: flexrag.text_process.TokenNormalizer
27 | :members:
28 | :show-inheritance:
29 |
30 | .. autoclass:: flexrag.text_process.ChineseSimplifier
31 | :members:
32 | :show-inheritance:
33 |
34 | .. autoclass:: flexrag.text_process.Lowercase
35 | :members:
36 | :show-inheritance:
37 |
38 | .. autoclass:: flexrag.text_process.Unifier
39 | :members:
40 | :show-inheritance:
41 |
42 | .. autoclass:: flexrag.text_process.TruncatorConfig
43 | :members:
44 | :inherited-members:
45 |
46 | .. autoclass:: flexrag.text_process.Truncator
47 | :members:
48 | :show-inheritance:
49 |
50 | .. autoclass:: flexrag.text_process.AnswerSimplifier
51 | :members:
52 | :show-inheritance:
53 |
54 | .. autoclass:: flexrag.text_process.ExactDeduplicate
55 | :members:
56 | :show-inheritance:
--------------------------------------------------------------------------------
/docs/source/reference/tokenizers.rst:
--------------------------------------------------------------------------------
1 | Tokenizer
2 | =========
3 | This module is a simple wrapper around other tokenizers.
4 | It provides a simple and consistent interface for tokenizing a text into tokens (maybe string or int).
5 |
6 | The Tokenizer Interface
7 | -----------------------
8 | ``TokenizerBase`` is the base class for all tokenizers.
9 |
10 | .. autoclass:: flexrag.models.tokenizer.TokenizerBase
11 | :members:
12 | :inherited-members:
13 |
14 |
15 | Tokenizers
16 | ----------
17 | The wrapped tokenizers.
18 |
19 | .. autoclass:: flexrag.models.tokenizer.HuggingFaceTokenizerConfig
20 | :members:
21 | :inherited-members:
22 |
23 | .. autoclass:: flexrag.models.tokenizer.HuggingFaceTokenizer
24 | :members:
25 | :show-inheritance:
26 |
27 | .. autoclass:: flexrag.models.tokenizer.TikTokenTokenizerConfig
28 | :members:
29 | :inherited-members:
30 |
31 | .. autoclass:: flexrag.models.tokenizer.TikTokenTokenizer
32 | :members:
33 | :show-inheritance:
34 |
35 | .. autoclass:: flexrag.models.tokenizer.MosesTokenizerConfig
36 | :members:
37 | :inherited-members:
38 |
39 | .. autoclass:: flexrag.models.tokenizer.MosesTokenizer
40 | :members:
41 | :show-inheritance:
42 |
43 | .. autoclass:: flexrag.models.tokenizer.NLTKTokenizerConfig
44 | :members:
45 | :inherited-members:
46 |
47 | .. autoclass:: flexrag.models.tokenizer.NLTKTokenizer
48 | :members:
49 | :show-inheritance:
50 |
51 | .. autoclass:: flexrag.models.tokenizer.JiebaTokenizerConfig
52 | :members:
53 | :inherited-members:
54 |
55 | .. autoclass:: flexrag.models.tokenizer.JiebaTokenizer
56 | :members:
57 | :show-inheritance:
--------------------------------------------------------------------------------
/docs/source/reference/utils.rst:
--------------------------------------------------------------------------------
1 | Utils
2 | =====
3 | This module contains useful functions that are used throughout the codebase.
4 |
5 | Cache
6 | -----
7 | .. automodule:: flexrag.cache
8 | :members:
9 |
10 | Other Utils
11 | -----------
12 | .. autoclass:: flexrag.utils.Register
13 | :members:
--------------------------------------------------------------------------------
/docs/source/tutorial/building_assistant.md:
--------------------------------------------------------------------------------
1 | # Building your own RAG Assistant
2 | FlexRAG provides a flexible and modularized framework for building RAG assistants. You can build your own RAG assistant by defining your own `Assistant` class and registering it with the `ASSISTANTS` decorator.
3 |
4 | ## Define the Assistant Class
5 | To build your RAG assistant, you can create a Python script file and import the necessary FlexRAG modules. Below is an example of how to construct a RAG assistant. In this example, we define a RAG assistant named `SimpleAssistant` by inheriting from the `AssistantBase` class. This assistant includes a dense retriever (`DenseRetriever`) and a generator (`OpenAIGenerator`). Whenever a user asks a question, `SimpleAssistant` uses `DenseRetriever` to retrieve relevant documents from the database, then concatenates these documents into the prompt and utilizes `OpenAIGenerator` to generate the final response.
6 |
7 | ```python
8 | from dataclasses import dataclass
9 |
10 | from flexrag.assistant import ASSISTANTS, AssistantBase
11 | from flexrag.models import OpenAIGenerator, OpenAIGeneratorConfig
12 | from flexrag.prompt import ChatPrompt, ChatTurn
13 | from flexrag.retriever import DenseRetriever, DenseRetrieverConfig
14 |
15 |
16 | @dataclass
17 | class SimpleAssistantConfig(DenseRetrieverConfig, OpenAIGeneratorConfig): ...
18 |
19 |
20 | @ASSISTANTS("simple", config_class=SimpleAssistantConfig)
21 | class SimpleAssistant(AssistantBase):
22 | def __init__(self, config: SimpleAssistantConfig):
23 | self.retriever = DenseRetriever(config)
24 | self.generator = OpenAIGenerator(config)
25 | return
26 |
27 | def answer(self, question: str) -> str:
28 | prompt = ChatPrompt()
29 | context = self.retriever.search(question)[0]
30 | prompt_str = "Please answer the following question based on the given text.\n\n"
31 | prompt_str += f"Question: {question}\n\n"
32 | for n, ctx in enumerate(context):
33 | prompt_str += f"Context {n}: {ctx.data['text']}\n"
34 | prompt.update(ChatTurn(role="user", content=prompt_str))
35 | response = self.generator.chat([prompt])[0][0]
36 | prompt.update(ChatTurn(role="assistant", content=response))
37 | return response
38 | ```
39 |
40 |
41 | ### Running your own RAG Application
42 | After defining the `SimpleAssistant` class and registering it with the `ASSISTANTS` decorator, you can evaluate your assistant using FlexRAG's entrypoints by adding the `user_module=` argument to the command.
43 |
44 | For example, you can evaluate your assistant on the *Natural Questions* dataset using the following command:
45 |
46 | ```bash
47 | DB_PATH=
48 | OPENAI_KEY=
49 | MODULE_PATH=
50 |
51 | python -m flexrag.entrypoints.run_assistant \
52 | user_module=${MODULE_PATH} \
53 | name=nq \
54 | split=test \
55 | assistant_type=simple \
56 | simple_config.model_name='gpt-4o-mini' \
57 | simple_config.api_key=${OPENAI_KEY} \
58 | simple_config.database_path=${DB_PATH} \
59 | simple_config.index_type=faiss \
60 | simple_config.query_encoder_config.encoder_type=hf \
61 | simple_config.query_encoder_config.hf_config.model_path='facebook/contriever-msmarco' \
62 | simple_config.query_encoder_config.hf_config.device_id=[0] \
63 | eval_config.metrics_type=[retrieval_success_rate,generation_f1,generation_em] \
64 | eval_config.retrieval_success_rate_config.eval_field=text \
65 | eval_config.response_preprocess.processor_type=[simplify_answer] \
66 | log_interval=10
67 | ```
68 |
69 | In [FlexRAG_Examples](https://github.com/ictnlp/FlexRAG_Examples) repository, we provide several detailed examples of how to build a RAG assistant.
70 |
--------------------------------------------------------------------------------
/docs/source/tutorial/preparing_corpus.md:
--------------------------------------------------------------------------------
1 | # Preparing the Knowledge Base
2 | In the real world, various types of knowledge are typically stored in documents such as PDFs, Word files, and PPTs. However, this semi-structured data cannot be parsed by large language models (LLMs) and is not suitable for building a knowledge base. Therefore, we need to convert it into structured text data beforehand. In this tutorial, we will use a simple example to demonstrate how to convert a batch of PDF files into structured data.
3 |
4 | ```{tip}
5 | If you already have structured data, you can skip this tutorial.
6 | ```
7 |
8 | ## Parse Files using FlexRAG's Command-Line Tool
9 | FlexRAG provides a command-line tool `prepare_corpus` to help users parse various files into structured data. In this tutorial, we will use a paper from Arxiv as an example to demonstrate how to parse a PDF file using the built-in command-line tool of FlexRAG.
10 |
11 | Run the following command to download a paper from Arxiv:
12 |
13 | ```bash
14 | wget https://arxiv.org/pdf/2502.18139.pdf
15 | ```
16 |
17 | You can then run the following command to parse this paper into structured knowledge base data:
18 |
19 | ```bash
20 | python -m flexrag.entrypoints.prepare_corpus \
21 | document_paths=[2502.18139.pdf] \
22 | output_path=knowledge.jsonl \
23 | document_parser_type=markitdown \
24 | chunker_type=sentence_chunker \
25 | sentence_chunker_config.max_tokens=512 \
26 | sentence_chunker_config.tokenizer_type=tiktoken \
27 | sentence_chunker_config.tiktoken_config.model_name='gpt-4o'
28 | ```
29 |
30 | In this command, we specify the following parameters:
31 | - `document_paths`: a list of file paths to be parsed. Here we only parse one paper;
32 | - `output_path`: the output path of the parsed results. The path should end with `.jsonl`, `.csv`, or `.tsv`;
33 | - `document_parser_type`: the type of document parser. Here we use `markitdown`;
34 | - `chunker_type`: the type of text chunker. Here we use `sentence_chunker`;
35 | - `sentence_chunker_config.max_tokens`: the maximum number of tokens per text chunk. Here we set it to 512;
36 | - `sentence_chunker_config.tokenizer_type`: the type of tokenizer used by the text chunker. Here we use `tiktoken`, which is provided by OpenAI;
37 | - `sentence_chunker_config.tiktoken_config.model_name`: the model name used by the tokenizer. Here we use `gpt-4o`.
38 |
39 | After executing the above command, you will see that the PDF file has been parsed into a JSONL file. As shown in the figure below, FlexRAG executed three steps in this process:
40 | 1. **Parsing**: parsing the file into structured data;
41 | 2. **Chunking**: chunking long text paragraphs in the structured data into short text paragraphs suitable for processing;
42 | 3. **Preprocessing**: preprocessing and filtering the chunked text paragraphs.
43 |
44 | ```{eval-rst}
45 | .. image:: ../../../assets/parse_files.png
46 | :alt: Parse File
47 | :align: center
48 | :width: 80%
49 | ```
50 |
51 | ```{tip}
52 | You can check the [FlexRAG Entrypoints](./entrypoints.md) documentation for more information about the `prepare_corpus` command.
53 | ```
54 |
55 | ```{tip}
56 | You can check the [Preparing the Retriever](./preparing_retriever.md) documentation for how to build a retriever for your knowledge base.
57 | ```
58 |
--------------------------------------------------------------------------------
/docs/source/tutorial/using_register.md:
--------------------------------------------------------------------------------
1 | # Using Registers
2 | The `Register` class is an important component of FlexRAG that integrates configuration files and loads various RAG components. A register can gather multiple components of the same type and generate a unified configuration structure to help you configure and use these components. This tutorial will show you how to use registers in FlexRAG.
3 |
4 | ## Using FlexRAG Registers
5 | FlexRAG provides a set of predefined registers for different components. These registers can be used to register and retrieve components of the respective type. The following registers are available in FlexRAG:
6 |
7 | - ASSISTANTS
8 | - REFINERS
9 | - CHUNKERS
10 | - DOCUMENTPARSERS
11 | - PROCESSORS
12 | - METRICS
13 | - GENERATORS
14 | - ENCODERS
15 | - RANKERS
16 | - DENSE_INDEX
17 | - RETRIEVERS
18 | - WEB_DOWNLOADERS
19 | - WEB_READERS
20 |
21 | ```{note}
22 | If you wish to develop your project by modifying the FlexRAG source code, all registers can be used as decorators to register new components. However, if you use the `run_assistant` or `run_interactive` entrypoints of FlexRAG, **only** the `ASSISTANTS` register can be used to register new components.
23 | ```
24 |
25 | ### Registering a New Component
26 | To register a new component, simply decorate the component class with the corresponding register. For example, to register a new `Assistant` component, you can use the `ASSISTANTS` register as shown below:
27 |
28 | ```python
29 | from dataclasses import dataclass
30 | from flexrag.assistant import AssistantBase, ASSISTANTS
31 |
32 | @dataclass
33 | class MyAssistantConfig:
34 | # Define your assistant configuration here
35 | pass
36 |
37 | @ASSISTANTS("my_assistant", config_class=MyAssistantConfig)
38 | class MyAssistant(AssistantBase):
39 | # Define your assistant here
40 | def answer(self, question: str) -> str:
41 | return "MyAssistant: " + question
42 | ```
43 |
44 | The register takes the following arguments, namely `shortnames` and `config_class`.
45 | - The `shortnames` argument is a list of shortnames of the component, which serve as simplified names for the component, making it easier to reference when loading.
46 | - The `config_class` argument is the configuration class for the component. This parameter is optional—if not provided, the component will not use any configuration.
47 |
48 | ### Generating the Configuration
49 | After registering the component, you can generate the configuration `dataclass` for all the registered components using the `make_config` function. For example, to generate the configuration for all the registered `Assistant` components, you can use the `make_config` function as shown below:
50 |
51 | ```python
52 | AssistantConfig = ASSISTANTS.make_config()
53 | ```
54 |
55 | The generated `AssistantConfig` class will have the following structure:
56 |
57 | ```python
58 | from dataclasses import dataclass
59 |
60 | @dataclass
61 | class AssistantConfig:
62 | # The shortname of the assistant
63 | assistant_type: str
64 | # The name of the configuration is the first shortname + "_config"
65 | my_assistant_config: MyAssistantConfig
66 | modular_config: ModularAssistantConfig
67 | # Other registered assistant configurations
68 | ...
69 | ```
70 |
71 | ```{tip}
72 | In the FlexRAG entrypoints, many configurations are generated in this way. This allows us to flexibly modify the components and their configurations in the workflow through configuration files.
73 | ```
74 |
75 | ### Loading the Component
76 | To load the component using the configuration, you can use the `load` function of the register. For example, to load the `MyAssistant` component using the configuration, you can use the `load` function as shown below:
77 |
78 | ```python
79 | AssistantConfig.assistant_type = "my_assistant"
80 | my_assistant = ASSISTANTS.load(AssistantConfig)
81 | ```
82 |
83 | ## Defining a New Register
84 | The `Register` class can be extended to define a new register for a specific component. For example, to define a new register for a `Searcher` component, you can simply create a new instance of the `Register` class as shown below:
85 |
86 | ```python
87 | from flexrag.utils import Register
88 |
89 | SEARCHERS = Register("searcher")
90 | ```
91 |
92 | ### Utilizing Type Hints
93 | As the `Register` class is a generic class, you can utilize type hints to specify the type of the component that the register is managing. For example, to define a register for a `Searcher` component, you can specify the type hint as follows:
94 |
95 | ```python
96 | from abc import ABC
97 | from flexrag.utils import Register
98 |
99 | class Searcher(ABC):
100 | pass
101 |
102 | SEARCHERS = Register[Searcher]("searcher")
103 | ```
104 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools>=61", "wheel", "pybind11>=2.13"]
3 | build-backend = "setuptools.build_meta"
4 |
--------------------------------------------------------------------------------
/pytest.ini:
--------------------------------------------------------------------------------
1 | [pytest]
2 | log_level = DEBUG
3 | asyncio_default_fixture_loop_scope = function
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | # basic
2 | numpy<2.0.0
3 | tenacity
4 | hydra-core>=1.3
5 | omegaconf>=2.3.0
6 | pillow
7 | accelerate>=0.26.0
8 | colorama
9 | # metrics
10 | rouge
11 | sacrebleu>=2.4.2
12 | pytrec_eval>=0.5
13 | # datasets
14 | datasets>=3.2.0
15 | # models
16 | openai>=1.30.1
17 | anthropic
18 | cohere
19 | ollama
20 | vllm>=0.6.0
21 | sentence_transformers
22 | transformers>=4.44.0
23 | mixedbread-ai
24 | voyageai
25 | # cache
26 | lmdb
27 | cloudpickle
28 | # processors
29 | unidecode
30 | sacremoses
31 | opencc
32 | # retrievers
33 | pandas
34 | pylance
35 | bm25s
36 | elasticsearch>=8.14.0
37 | torch>=2.3.0
38 | beautifulsoup4
39 | typesense
40 | httpx
41 | scipy
42 | # gui
43 | gradio>=5.8.0
44 | # chunking
45 | regex
46 | nltk
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 | import logging
4 |
5 | import pybind11
6 | from setuptools import Extension, find_packages, setup
7 |
8 | logging.basicConfig(level=logging.INFO)
9 |
10 |
# Compiled extension: a pybind11-based C++ module for relevance-metric
# computations, built with -O3 optimization.
ext_modules = [
    Extension(
        "flexrag.metrics.lib_rel",
        ["src/flexrag/metrics/lib_rel.cpp"],
        include_dirs=[pybind11.get_include()],
        language="c++",
        extra_compile_args=["-O3"],
    ),
]
20 |
21 |
22 | def get_requirements() -> list[str]:
23 | with open("requirements.txt", encoding="utf-8") as f:
24 | file_content = f.read()
25 | requirements = [
26 | line.strip()
27 | for line in file_content.strip().split("\n")
28 | if not line.startswith("#")
29 | ]
30 | # as faiss may be installed using conda, we need to remove it from the requirements
31 | try:
32 | import faiss
33 |
34 | logging.info(f"Detected installed faiss: faiss {faiss.__version__}")
35 | except ImportError:
36 | requirements.append("faiss-cpu")
37 | return requirements
38 |
39 |
40 | def get_version() -> str:
41 | with open(os.path.join("src", "flexrag", "utils.py"), encoding="utf-8") as f:
42 | file_content = f.read()
43 | pattern = r"{}\W*=\W*\"([^\"]+)\"".format("__VERSION__")
44 | (version,) = re.findall(pattern, file_content)
45 | return version
46 |
47 |
def get_long_description() -> str:
    """Return the project README contents for use as the PyPI long description.

    :return: The full text of ``README.md``.
    :rtype: str
    """
    with open("README.md", encoding="utf-8") as readme_file:
        return readme_file.read()
51 |
52 |
# Package metadata and build configuration for the flexrag distribution.
setup(
    name="flexrag",
    version=get_version(),
    author="Zhuocheng Zhang",
    author_email="zhuocheng_zhang@outlook.com",
    description="A RAG Framework for Information Retrieval and Generation.",
    url="https://github.com/ictnlp/flexrag",
    license="MIT License",
    long_description=get_long_description(),
    long_description_content_type="text/markdown",
    # src-layout: package sources live under ``src/``.
    packages=find_packages(where="src"),
    package_dir={"": "src"},
    # Non-Python data files shipped inside the wheel (prompt templates, GUI assets).
    package_data={
        "flexrag": [
            "ranker/ranker_prompts/*.json",
            "assistant/assistant_prompts/*.json",
            "entrypoints/assets/*.png",
        ],
    },
    include_package_data=True,
    python_requires=">=3.11",
    install_requires=get_requirements(),
    # Optional feature sets, installable via ``pip install flexrag[<extra>]``.
    extras_require={
        "scann": ["scann>=1.3.2"],
        "annoy": ["annoy>1.17.0"],
        "llamacpp": ["llama_cpp_python>=0.2.84"],
        "minference": ["minference>=0.1.5"],
        "web": ["duckduckgo_search", "serpapi", "pytest-playwright"],
        "docs": ["docling", "markitdown"],
        "all": [
            "llama_cpp_python>=0.2.84",
            "minference>=0.1.5",
            "PySocks>=1.7.1",
            "duckduckgo_search",
            "serpapi",
            "docling",
            "markitdown",
            "annoy>1.17.0",
        ],
        "dev": [
            "black",
            "pytest",
            "pytest-asyncio",
            "sphinx",
            "sphinx-autobuild",
            "myst-parser",
        ],
    },
    classifiers=[
        "Development Status :: 3 - Alpha",
        "Intended Audience :: Developers",
        "License :: OSI Approved :: MIT License",
        "Operating System :: OS Independent",
        "Programming Language :: Python :: 3",
        "Programming Language :: Python :: 3.11",
        "Programming Language :: Python :: 3.12",
        "Programming Language :: Python :: 3.13",
        "Topic :: Scientific/Engineering :: Artificial Intelligence",
    ],
    # Build the pybind11 C++ extension declared above.
    ext_modules=ext_modules,
)
114 |
--------------------------------------------------------------------------------
/src/flexrag/__init__.py:
--------------------------------------------------------------------------------
1 | from .retriever import RETRIEVERS
2 | from .assistant import ASSISTANTS
3 | from .ranker import RANKERS
4 | from .models import GENERATORS, ENCODERS
5 | from .utils import __VERSION__
6 |
7 |
8 | __all__ = [
9 | "RETRIEVERS",
10 | "ASSISTANTS",
11 | "RANKERS",
12 | "GENERATORS",
13 | "ENCODERS",
14 | "__VERSION__",
15 | ]
16 |
--------------------------------------------------------------------------------
/src/flexrag/assistant/__init__.py:
--------------------------------------------------------------------------------
1 | from .assistant import ASSISTANTS, AssistantBase, SearchHistory, PREDEFINED_PROMPTS
2 | from .basic_assistant import BasicAssistant, BasicAssistantConfig
3 | from .modular_rag_assistant import ModularAssistant, ModularAssistantConfig
4 | from .chatqa_assistant import ChatQAAssistant
5 | from .online_assistant import (
6 | JinaDeepSearch,
7 | JinaDeepSearchConfig,
8 | PerplexityAssistant,
9 | PerplexityAssistantConfig,
10 | )
11 |
12 | __all__ = [
13 | "ASSISTANTS",
14 | "AssistantBase",
15 | "SearchHistory",
16 | "PREDEFINED_PROMPTS",
17 | "BasicAssistant",
18 | "BasicAssistantConfig",
19 | "ModularAssistant",
20 | "ModularAssistantConfig",
21 | "ChatQAAssistant",
22 | "JinaDeepSearch",
23 | "JinaDeepSearchConfig",
24 | "PerplexityAssistant",
25 | "PerplexityAssistantConfig",
26 | ]
27 |
--------------------------------------------------------------------------------
/src/flexrag/assistant/assistant.py:
--------------------------------------------------------------------------------
1 | import os
2 | from abc import ABC, abstractmethod
3 | from dataclasses import dataclass
4 | from typing import Optional
5 |
6 | from flexrag.common_dataclass import RetrievedContext
7 | from flexrag.prompt import ChatPrompt
8 | from flexrag.utils import Register
9 |
10 |
class AssistantBase(ABC):
    """The abstract base class for all assistants.

    Subclasses implement :meth:`answer`, which produces a response
    (optionally grounded in retrieved contexts) for a single question.
    """

    @abstractmethod
    def answer(
        self, question: str
    ) -> tuple[str, Optional[list[RetrievedContext]], Optional[dict]]:
        """Answer the given question.

        :param question: The question to answer.
        :type question: str
        :return: A tuple containing the following elements:
            - The response to the question.
            - The contexts used to answer the question (None if no retrieval was performed).
            - The metadata of the assistant (None if there is none).
        :rtype: tuple[str, Optional[list[RetrievedContext]], Optional[dict]]
        """
        return
27 |
28 |
@dataclass
class SearchHistory:
    """A single retrieval round: the query issued and the contexts it returned.

    :param query: The query used for retrieval.
    :type query: str
    :param contexts: The contexts retrieved for the query.
    :type contexts: list[RetrievedContext]
    """

    query: str
    contexts: list[RetrievedContext]

    def to_dict(self) -> dict[str, str | list[dict]]:
        """Serialize this record into plain Python types (e.g. for logging or JSON dumps)."""
        return {
            "query": self.query,
            "contexts": [ctx.to_dict() for ctx in self.contexts],
        }
39 |
40 |
# The global register for all assistants.
ASSISTANTS = Register[AssistantBase]("assistant")


# Prompts shipped with FlexRAG (JSON files under ``assistant_prompts/``),
# keyed by "<answer form>_<with|without>_context".
PREDEFINED_PROMPTS = {
    "shortform_with_context": ChatPrompt.from_json(
        os.path.join(
            os.path.dirname(__file__),
            "assistant_prompts",
            "shortform_generate_prompt_with_context.json",
        )
    ),
    "shortform_without_context": ChatPrompt.from_json(
        os.path.join(
            os.path.dirname(__file__),
            "assistant_prompts",
            "shortform_generate_prompt_without_context.json",
        )
    ),
    "longform_with_context": ChatPrompt.from_json(
        os.path.join(
            os.path.dirname(__file__),
            "assistant_prompts",
            "longform_generate_prompt_with_context.json",
        )
    ),
    "longform_without_context": ChatPrompt.from_json(
        os.path.join(
            os.path.dirname(__file__),
            "assistant_prompts",
            "longform_generate_prompt_without_context.json",
        )
    ),
}
74 |
--------------------------------------------------------------------------------
/src/flexrag/assistant/assistant_prompts/longform_generate_prompt_with_context.json:
--------------------------------------------------------------------------------
1 | {
2 | "system": {
3 | "role": "system",
4 | "content": "Answer the question based on the given contexts. Note that the context might not always contain relevant information to answer the question."
5 | },
6 | "history": [],
7 | "demonstrations": []
8 | }
--------------------------------------------------------------------------------
/src/flexrag/assistant/assistant_prompts/longform_generate_prompt_without_context.json:
--------------------------------------------------------------------------------
1 | {
2 | "system": {
3 | "role": "system",
4 | "content": "Answer the following question."
5 | },
6 | "history": [],
7 | "demonstrations": []
8 | }
--------------------------------------------------------------------------------
/src/flexrag/assistant/assistant_prompts/shortform_generate_prompt_with_context.json:
--------------------------------------------------------------------------------
1 | {
2 | "system": {
3 | "role": "system",
4 | "content": "Answer the question based on the given contexts. Note that the context might not always contain relevant information to answer the question. Only give me the answer and do not output any other words."
5 | },
6 | "history": [],
7 | "demonstrations": []
8 | }
--------------------------------------------------------------------------------
/src/flexrag/assistant/assistant_prompts/shortform_generate_prompt_without_context.json:
--------------------------------------------------------------------------------
1 | {
2 | "system": {
3 | "role": "system",
4 | "content": "Answer the following question. Only give me the answer and do not output any other words."
5 | },
6 | "history": [],
7 | "demonstrations": []
8 | }
--------------------------------------------------------------------------------
/src/flexrag/assistant/basic_assistant.py:
--------------------------------------------------------------------------------
1 | from copy import deepcopy
2 | from dataclasses import dataclass
3 | from typing import Optional
4 |
5 | from flexrag.models import GENERATORS, GenerationConfig, GeneratorConfig
6 | from flexrag.prompt import ChatPrompt, ChatTurn
7 | from flexrag.utils import LOGGER_MANAGER
8 |
9 | from .assistant import ASSISTANTS, AssistantBase
10 |
11 | logger = LOGGER_MANAGER.get_logger("flexrag.assistant")
12 |
13 |
@dataclass
class BasicAssistantConfig(GeneratorConfig, GenerationConfig):
    """The configuration for the basic assistant.
    Also inherits all generator selection and decoding options from
    ``GeneratorConfig`` and ``GenerationConfig``.

    :param prompt_path: The path to the prompt file. Defaults to None.
    :type prompt_path: str, optional
    :param use_history: Whether to save the chat history for multi-turn conversation. Defaults to False.
    :type use_history: bool, optional
    """

    prompt_path: Optional[str] = None
    use_history: bool = False
26 |
27 |
@ASSISTANTS("basic", config_class=BasicAssistantConfig)
class BasicAssistant(AssistantBase):
    """A basic assistant that generates response without retrieval."""

    def __init__(self, cfg: BasicAssistantConfig):
        # keep the generation config; multi-sample decoding is not supported
        self.gen_cfg = cfg
        if self.gen_cfg.sample_num > 1:
            logger.warning("Sample num > 1 is not supported for Assistant")
            self.gen_cfg.sample_num = 1

        # the generator that produces the responses
        self.generator = GENERATORS.load(cfg)

        # base prompt: loaded from disk when a path is given, empty otherwise
        if cfg.prompt_path is None:
            self.prompt = ChatPrompt()
        else:
            self.prompt = ChatPrompt.from_json(cfg.prompt_path)
        # when multi-turn history is enabled, keep a mutable copy of the base prompt
        self.history_prompt = deepcopy(self.prompt) if cfg.use_history else None
        return

    def answer(self, question: str) -> tuple[str, None, dict[str, ChatPrompt]]:
        """Generate a response to ``question`` without any retrieval.

        :param question: The question to answer.
        :type question: str
        :return: The response, None (no contexts), and metadata holding the prompt used.
        :rtype: tuple[str, None, dict[str, ChatPrompt]]
        """
        # start from the conversation history when enabled, the base prompt otherwise
        base = self.prompt if self.history_prompt is None else self.history_prompt
        prompt = deepcopy(base)
        prompt.update(ChatTurn(role="user", content=question))

        # generate response
        response = self.generator.chat([prompt], generation_config=self.gen_cfg)[0][0]

        # record both turns so follow-up questions see the full conversation
        if self.history_prompt is not None:
            self.history_prompt.update(ChatTurn(role="user", content=question))
            self.history_prompt.update(ChatTurn(role="assistant", content=response))
        return response, None, {"prompt": prompt}

    def clear_history(self) -> None:
        """Reset the conversation history back to the base prompt."""
        if self.history_prompt is not None:
            self.history_prompt = deepcopy(self.prompt)
        return
75 |
--------------------------------------------------------------------------------
/src/flexrag/assistant/chatqa_assistant.py:
--------------------------------------------------------------------------------
1 | from flexrag.common_dataclass import RetrievedContext
2 | from flexrag.utils import LOGGER_MANAGER
3 |
4 | from .assistant import ASSISTANTS
5 | from .modular_rag_assistant import ModularAssistant, ModularAssistantConfig
6 |
7 |
8 | logger = LOGGER_MANAGER.get_logger("flexrag.assistant.chatqa")
9 |
10 |
@ASSISTANTS("chatqa", config_class=ModularAssistantConfig)
class ChatQAAssistant(ModularAssistant):
    """The Modular assistant that employs the ChatQA model for response generation."""

    sys_prompt = (
        "System: This is a chat between a user and an artificial intelligence assistant. "
        "The assistant gives helpful, detailed, and polite answers to the user's questions based on the context. "
        "The assistant should also indicate when the answer cannot be found in the context."
    )
    instruction = "Please give a full and complete answer for the question."
    allowed_models = [
        "nvidia/Llama3-ChatQA-2-8B",
        "nvidia/Llama3-ChatQA-2-70B",
        "nvidia/Llama3-ChatQA-1.5-8B",
        "nvidia/Llama3-ChatQA-1.5-70B",
    ]

    def __init__(self, cfg: ModularAssistantConfig):
        super().__init__(cfg)
        logger.warning(
            f"ChatQA Assistant expects the model to be one of {self.allowed_models}."
        )
        return

    def answer_with_contexts(
        self, question: str, contexts: list[RetrievedContext]
    ) -> tuple[str, str]:
        """Generate an answer using the ChatQA plain-text prompt format.

        :return: The response and the full prompt string fed to the generator.
        :rtype: tuple[str, str]
        """
        prefix = self.get_formatted_input(question, contexts)
        responses = self.generator.generate([prefix], generation_config=self.gen_cfg)
        return responses[0][0], prefix

    def get_formatted_input(
        self, question: str, contexts: list[RetrievedContext]
    ) -> str:
        """Assemble the ChatQA prompt: system message, numbered contexts, then the user turn."""
        parts = [f"{self.sys_prompt}\n\n"]

        # serialize every context according to the configured fields
        for n, context in enumerate(contexts):
            if len(self.used_fields) == 1:
                # a single configured field is emitted bare, without a field label
                ctx = context.data[self.used_fields[0]]
            else:
                # no configured fields -> dump all fields; several -> dump the chosen ones
                names = self.used_fields if len(self.used_fields) > 0 else context.data.keys()
                ctx = "".join(
                    f"{field_name}: {context.data[field_name]}\n" for field_name in names
                )
            parts.append(f"Context {n + 1}: {ctx}\n\n")

        # the user turn carrying the instruction and the question
        parts.append(f"User: {self.instruction} {question}\n\nAssistant:")
        return "".join(parts)
65 |
--------------------------------------------------------------------------------
/src/flexrag/assistant/document_chat_assistant.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import Any
3 |
4 | from flexrag.common_dataclass import RetrievedContext
5 | from flexrag.chunking import CHUNKERS, ChunkerConfig
6 | from flexrag.document_parser import DOCUMENTPARSERS, DocumentParserConfig
7 | from flexrag.models import GENERATORS, GenerationConfig, GeneratorConfig
8 | from flexrag.prompt import ChatPrompt, ChatTurn
9 | from flexrag.ranker import RANKERS, RankerConfig
10 | from flexrag.retriever import DenseRetriever, DenseRetrieverConfig
11 | from flexrag.utils import LOGGER_MANAGER
12 |
13 | from .assistant import ASSISTANTS, AssistantBase
14 |
15 | logger = LOGGER_MANAGER.get_logger("flexrag.assistant.modular")
16 |
17 |
# Aggregated configuration for ``DocumentChatAssistant``: generator, generation,
# dense-retriever, ranker, document-parser and chunker options in one dataclass.
@dataclass
class DocumentChatAssistantConfig(
    GeneratorConfig,
    GenerationConfig,
    DenseRetrieverConfig,
    RankerConfig,
    DocumentParserConfig,
    ChunkerConfig,
): ...
27 |
28 |
@ASSISTANTS("document_chat", config_class=DocumentChatAssistantConfig)
class DocumentChatAssistant(AssistantBase):
    """An assistant that answers questions about a single attached document.

    The document is parsed, chunked, and indexed into a dense retriever;
    ``answer`` then retrieves (and optionally reranks) chunks and grounds
    the generation on them.
    """

    def __init__(self, cfg: DocumentChatAssistantConfig):
        """Load the generator, retriever, ranker, parser, and chunker from ``cfg``."""
        # set basic args
        self.gen_cfg = cfg
        if self.gen_cfg.sample_num > 1:
            logger.warning("Sample num > 1 is not supported for Assistant")
            self.gen_cfg.sample_num = 1

        # load generator
        self.generator = GENERATORS.load(cfg)

        # load retriever; it must start empty since it only ever holds the attached document
        self.retriever = DenseRetriever(cfg)
        assert len(self.retriever) == 0, "Retriever is not empty."

        # load ranker (presumably None when no ranker is configured; `answer` checks for None)
        self.reranker = RANKERS.load(cfg)

        # load parser
        self.parser = DOCUMENTPARSERS.load(cfg)

        # load chunker (may be None; then the whole document text is a single passage)
        self.chunker = CHUNKERS.load(cfg)
        return

    def attach_document(self, document_path: str | None = None) -> None:
        """Parse, chunk, and index the document; calling with None only clears the index."""
        if document_path is None:
            self.retriever.clean()
            return
        # parse document
        self.retriever.clean()
        document = self.parser.parse(document_path)
        if self.chunker is not None:
            # NOTE(review): chunkers return ``Chunk`` objects while the fallback
            # branch passes raw strings — confirm ``add_passages`` accepts both.
            chunks = self.chunker.chunk(document.text)
        else:
            chunks = [document.text]

        # build index
        self.retriever.add_passages(chunks)
        return

    def answer(
        self, question: str
    ) -> tuple[str, list[RetrievedContext], dict[str, Any]]:
        """Answer ``question``, grounded in the attached document when one is indexed.

        :param question: The question to answer.
        :type question: str
        :return: The response, the contexts used, and metadata holding the prompt.
        :rtype: tuple[str, list[RetrievedContext], dict[str, Any]]
        """
        # answer without contexts (no document attached yet)
        if len(self.retriever) == 0:
            prompt = ChatPrompt()
            prompt.update(ChatTurn(role="user", content=question))
            response = self.generator.chat([prompt], generation_config=self.gen_cfg)
            return response[0][0], [], {"prompt": prompt}

        # retrieve
        retrieved_contexts = self.retriever.search(question)[0]

        # rerank
        if self.reranker is not None:
            contexts = self.reranker.rank(question, retrieved_contexts).candidates
        else:
            contexts = retrieved_contexts

        # prepare prompt: every field of every context is serialized into the user turn
        prompt = ChatPrompt(
            system="Answer the user question based on the given contexts."
        )
        usr_prompt = ""
        for n, context in enumerate(contexts):
            ctx = ""
            for field_name, field_value in context.data.items():
                ctx += f"{field_name}: {field_value}\n"
            usr_prompt += f"Context {n + 1}: {ctx}\n\n"
        usr_prompt += f"Question: {question}"
        prompt.update(ChatTurn(role="user", content=usr_prompt))

        # generate response
        response = self.generator.chat([prompt], generation_config=self.gen_cfg)[0][0]
        return response, contexts, {"prompt": prompt}
106 |
--------------------------------------------------------------------------------
/src/flexrag/cache/__init__.py:
--------------------------------------------------------------------------------
1 | from .backends import (
2 | STORAGEBACKENDS,
3 | DictBackend,
4 | LMDBBackend,
5 | LMDBBackendConfig,
6 | StorageBackendBase,
7 | StorageBackendConfig,
8 | )
9 | from .persistent_cache import (
10 | FIFOPersistentCache,
11 | LFUPersistentCache,
12 | LRUPersistentCache,
13 | PersistentCacheBase,
14 | PersistentCacheConfig,
15 | RandomPersistentCache,
16 | )
17 | from .serializer import (
18 | SERIALIZERS,
19 | CloudPickleSerializer,
20 | JsonSerializer,
21 | MsgpackSerializer,
22 | PickleSerializer,
23 | SerializerBase,
24 | SerializerConfig,
25 | )
26 |
27 | __all__ = [
28 | "LMDBBackend",
29 | "LMDBBackendConfig",
30 | "STORAGEBACKENDS",
31 | "StorageBackendConfig",
32 | "StorageBackendBase",
33 | "DictBackend",
34 | "SERIALIZERS",
35 | "SerializerConfig",
36 | "JsonSerializer",
37 | "MsgpackSerializer",
38 | "PickleSerializer",
39 | "CloudPickleSerializer",
40 | "SerializerBase",
41 | "PersistentCacheConfig",
42 | "PersistentCacheBase",
43 | "RandomPersistentCache",
44 | "LRUPersistentCache",
45 | "LFUPersistentCache",
46 | "FIFOPersistentCache",
47 | ]
48 |
--------------------------------------------------------------------------------
/src/flexrag/cache/backends.py:
--------------------------------------------------------------------------------
1 | import atexit
2 | import os
3 | from collections import OrderedDict
4 | from dataclasses import dataclass
5 | from typing import MutableMapping
6 |
7 | import lmdb
8 | from omegaconf import MISSING
9 |
10 | from flexrag.utils import Register
11 |
12 |
class StorageBackendBase(MutableMapping[bytes, bytes]):
    """The Binary Storage Backend For ``PersistentCache``.
    The backend should provide interfaces like ``MutableMapping``.
    Thus, The following methods should be implemented:

    >>> def __getitem__(self, key: bytes) -> bytes:
    ...     pass
    >>> def __setitem__(self, key: bytes, value: bytes) -> None:
    ...     pass
    >>> def __delitem__(self, key: bytes) -> None:
    ...     pass
    >>> def __iter__(self) -> Iterable[bytes]:
    ...     pass
    >>> def __len__(self) -> int:
    ...     pass

    The following methods will be implemented automatically:

    >>> def __contains__(self, key: bytes) -> bool:
    ...     pass
    >>> def keys(self) -> KeysView:
    ...     pass
    >>> def values(self) -> ValuesView:
    ...     pass
    >>> def items(self) -> ItemsView:
    ...     pass
    >>> def get(self, key: bytes, default: Any = None) -> bytes | Any:
    ...     pass
    >>> def __eq__(self, other: StorageBackend) -> bool:
    ...     pass
    >>> def __ne__(self, other: StorageBackend) -> bool:
    ...     pass
    >>> def pop(self, key: bytes, default: Any = None) -> bytes | Any:
    ...     pass
    >>> def popitem(self) -> Tuple:
    ...     pass
    >>> def clear(self) -> None:
    ...     pass
    >>> def update(self, other: MutableMapping) -> None:
    ...     pass
    >>> def setdefault(self, key: bytes, default: Any = None) -> Any:
    ...     pass
    """

    def __repr__(self) -> str:
        # BUGFIX: the original built this f-string but never returned it, so
        # __repr__ returned None and repr() raised TypeError.
        return f"{self.__class__.__name__}(len={len(self)})"
59 |
60 |
61 | STORAGEBACKENDS = Register[StorageBackendBase]("storage_backend")
62 |
63 |
@dataclass
class LMDBBackendConfig:
    """Configuration for ``LMDBBackend``.

    :param db_path: Path of the LMDB database file. Required.
    :type db_path: str
    :param db_size: Maximum size of the database in bytes. Defaults to 1GB.
    :type db_size: int
    """

    db_path: str = MISSING
    db_size: int = 1 << 30  # 2^30 bytes = 1GB
68 |
69 |
@STORAGEBACKENDS("lmdb", config_class=LMDBBackendConfig)
class LMDBBackend(StorageBackendBase):
    """A persistent ``StorageBackendBase`` backed by an LMDB database."""

    def __init__(self, cfg: LMDBBackendConfig) -> None:
        self.db_path = cfg.db_path
        # BUGFIX: for a bare file name (no directory component) ``dirname`` is
        # "" and ``os.makedirs("", exist_ok=True)`` raises FileNotFoundError;
        # guard on a non-empty dirname instead of ``os.path.exists``.
        dir_name = os.path.dirname(cfg.db_path)
        if dir_name:
            os.makedirs(dir_name, exist_ok=True)
        self.database = lmdb.open(cfg.db_path, map_size=cfg.db_size)
        # close the environment (flushing data) on interpreter exit
        atexit.register(self.database.close)
        return

    def __getitem__(self, key: bytes) -> bytes:
        """Return the value stored under ``key``; raise ``KeyError`` if absent."""
        with self.database.begin() as txn:
            data = txn.get(key)
            if data is None:
                raise KeyError(key)
            return data

    def __setitem__(self, key: bytes, value: bytes) -> None:
        """Store ``value`` under ``key``, overwriting any existing entry."""
        with self.database.begin(write=True) as txn:
            txn.put(key, value)
        return

    def __delitem__(self, key: bytes) -> None:
        """Remove ``key`` from the database (a missing key is ignored by LMDB)."""
        with self.database.begin(write=True) as txn:
            txn.delete(key)
        return

    def __len__(self) -> int:
        """Return the number of entries currently stored."""
        with self.database.begin() as txn:
            return txn.stat()["entries"]

    def __iter__(self):
        """Yield all keys in key order (LMDB cursors iterate sorted keys)."""
        with self.database.begin() as txn:
            cursor = txn.cursor()
            for key, _ in cursor:
                yield key
        return

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(db_path={self.db_path}, len={len(self)})"
110 |
111 |
# A volatile in-memory backend: a plain ``OrderedDict`` already satisfies the
# full ``MutableMapping`` interface required by ``StorageBackendBase``.
@STORAGEBACKENDS("dict")
class DictBackend(OrderedDict, StorageBackendBase): ...


# The structured configuration covering all registered storage backends.
StorageBackendConfig = STORAGEBACKENDS.make_config(default="dict")
117 |
--------------------------------------------------------------------------------
/src/flexrag/cache/serializer.py:
--------------------------------------------------------------------------------
1 | import json
2 | import pickle
3 | from abc import ABC, abstractmethod
4 | from typing import Any
5 |
6 | from flexrag.utils import Register
7 |
8 |
class SerializerBase(ABC):
    """A simple interface for serializing and deserializing python objects."""

    @abstractmethod
    def serialize(self, obj: Any) -> bytes:
        """Serialize the object into bytes.

        :param obj: The object to serialize.
        :type obj: Any
        :return: The serialized object.
        :rtype: bytes
        """
        return

    @abstractmethod
    def deserialize(self, data: bytes) -> Any:
        """Deserialize the bytes into an object.

        :param data: The serialized object.
        :type data: bytes
        :return: The deserialized object.
        :rtype: Any
        """
        return


# The global register for all serializers.
SERIALIZERS = Register[SerializerBase]("serializer")
36 |
37 |
@SERIALIZERS("pickle")
class PickleSerializer(SerializerBase):
    """A serializer that uses the pickle module."""

    def serialize(self, obj: Any) -> bytes:
        """Encode ``obj`` as a pickle byte string."""
        data = pickle.dumps(obj)
        return data

    def deserialize(self, data: bytes) -> Any:
        """Decode a pickle byte string back into the original object."""
        obj = pickle.loads(data)
        return obj
47 |
48 |
@SERIALIZERS("cloudpickle")
class CloudPickleSerializer(SerializerBase):
    """A serializer that uses the cloudpickle module."""

    def __init__(self):
        # cloudpickle is an optional dependency, imported lazily so it is only
        # required when this serializer is actually selected.
        try:
            import cloudpickle
        except ImportError as err:
            # BUGFIX: the original bare ``except:`` would also swallow
            # KeyboardInterrupt/SystemExit; catch ImportError only and keep
            # the original cause in the exception chain.
            raise ImportError(
                "Please install cloudpickle using `pip install cloudpickle`."
            ) from err
        self.pickler = cloudpickle
        return

    def serialize(self, obj: Any) -> bytes:
        """Encode ``obj`` as a cloudpickle byte string."""
        return self.pickler.dumps(obj)

    def deserialize(self, data: bytes) -> Any:
        """Decode a cloudpickle byte string back into the original object."""
        return self.pickler.loads(data)
69 |
70 |
@SERIALIZERS("json")
class JsonSerializer(SerializerBase):
    """A serializer that uses the json module."""

    def serialize(self, obj: Any) -> bytes:
        """Render ``obj`` as UTF-8 encoded JSON bytes."""
        text = json.dumps(obj)
        return text.encode("utf-8")

    def deserialize(self, data: bytes) -> Any:
        """Parse UTF-8 encoded JSON bytes back into a Python object."""
        text = data.decode("utf-8")
        return json.loads(text)
80 |
81 |
@SERIALIZERS("msgpack")
class MsgpackSerializer(SerializerBase):
    """A serializer that uses the msgpack module."""

    def __init__(self) -> None:
        # msgpack is an optional dependency; fail with an actionable hint
        try:
            import msgpack
        except ImportError:
            raise ImportError("Please install msgpack using `pip install msgpack`.")
        self.msgpack = msgpack
        return

    def serialize(self, obj: Any) -> bytes:
        """Pack ``obj`` into msgpack bytes; ``use_bin_type`` keeps bytes distinct from str."""
        return self.msgpack.packb(obj, use_bin_type=True)

    def deserialize(self, data: bytes) -> Any:
        """Unpack msgpack bytes; ``raw=False`` decodes string values as ``str``."""
        return self.msgpack.unpackb(data, raw=False)
100 |
101 |
102 | SerializerConfig = SERIALIZERS.make_config(default="pickle")
103 |
--------------------------------------------------------------------------------
/src/flexrag/chunking/__init__.py:
--------------------------------------------------------------------------------
1 | from .chunker_base import ChunkerBase, CHUNKERS
2 | from .basic_chunkers import (
3 | CharChunker,
4 | CharChunkerConfig,
5 | TokenChunker,
6 | TokenChunkerConfig,
7 | RecursiveChunker,
8 | RecursiveChunkerConfig,
9 | SentenceChunker,
10 | SentenceChunkerConfig,
11 | )
12 | from .semantic_chunker import SemanticChunker, SemanticChunkerConfig
13 |
14 |
15 | ChunkerConfig = CHUNKERS.make_config(
16 | default="sentence_chunker", config_name="ChunkerConfig"
17 | )
18 |
19 |
20 | __all__ = [
21 | "ChunkerBase",
22 | "CHUNKERS",
23 | "ChunkerConfig",
24 | "CharChunker",
25 | "CharChunkerConfig",
26 | "TokenChunker",
27 | "TokenChunkerConfig",
28 | "RecursiveChunker",
29 | "RecursiveChunkerConfig",
30 | "SentenceChunker",
31 | "SentenceChunkerConfig",
32 | "SemanticChunker",
33 | "SemanticChunkerConfig",
34 | ]
35 |
--------------------------------------------------------------------------------
/src/flexrag/chunking/chunker_base.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from dataclasses import dataclass
3 | from typing import Optional
4 |
5 | from flexrag.utils import Register
6 |
7 |
@dataclass
class Chunk:
    """A piece of a larger text, optionally annotated with its position.

    :param text: The text of the chunk.
    :type text: str
    :param start: The start index of the chunk in the original text. Defaults to None.
    :type start: Optional[int]
    :param end: The end index of the chunk in the original text. Defaults to None.
    :type end: Optional[int]
    """

    text: str
    start: Optional[int] = None
    end: Optional[int] = None
23 |
24 |
class ChunkerBase(ABC):
    """Chunker that splits text into chunks of fixed size.
    This is an abstract class that defines the interface for all chunkers.
    The subclasses should implement the `chunk` method to split the text.
    """

    @abstractmethod
    def chunk(self, text: str) -> list[Chunk]:
        """Chunk the given text into smaller chunks.

        :param text: The text to chunk.
        :type text: str
        :return: The chunks of the text.
        :rtype: list[Chunk]
        """
        return


# The global register for all chunkers.
CHUNKERS = Register[ChunkerBase]("chunker")
44 |
--------------------------------------------------------------------------------
/src/flexrag/common_dataclass.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass, field
2 | from typing import Optional
3 |
4 | from omegaconf import MISSING
5 |
6 |
@dataclass
class Context:
    """A unit of knowledge that can be supplied to an assistant.

    :param context_id: The unique identifier of the context. Default: None.
    :type context_id: Optional[str]
    :param data: The context data. Default: {}.
    :type data: dict
    :param source: The source of the retrieved data. Default: None.
    :type source: Optional[str]
    :param meta_data: The metadata of the context. Default: {}.
    :type meta_data: dict
    """

    context_id: Optional[str] = None
    data: dict = field(default_factory=dict)
    source: Optional[str] = None
    meta_data: dict = field(default_factory=dict)

    def to_dict(self):
        """Return a plain ``dict`` view of this context."""
        return dict(
            context_id=self.context_id,
            source=self.source,
            data=self.data,
            meta_data=self.meta_data,
        )
33 |
34 |
@dataclass
class RetrievedContext(Context):
    """A ``Context`` produced by a retriever, carrying provenance and a score.

    :param retriever: The name of the retriever. Required.
    :type retriever: str
    :param query: The query for retrieval. Required.
    :type query: str
    :param score: The relevance score of the retrieved data. Default: 0.0.
    :type score: float
    """

    retriever: str = MISSING
    query: str = MISSING
    score: float = 0.0

    def to_dict(self):
        """Return the base-class fields plus retrieval provenance."""
        result = super().to_dict()
        result["retriever"] = self.retriever
        result["query"] = self.query
        result["score"] = self.score
        return result
58 |
59 |
@dataclass
class RAGEvalData:
    """The dataclass for RAG evaluation data.

    :param question: The question for evaluation. Required.
    :type question: str
    :param golden_contexts: The contexts related to the question. Default: None.
    :type golden_contexts: Optional[list[Context]]
    :param golden_answers: The golden answers for the question. Default: None.
    :type golden_answers: Optional[list[str]]
    :param meta_data: The metadata of the evaluation data. Default: {}.
    :type meta_data: dict
    """

    question: str = MISSING  # omegaconf sentinel: must be provided by the data
    golden_contexts: Optional[list[Context]] = None
    golden_answers: Optional[list[str]] = None
    meta_data: dict = field(default_factory=dict)
78 |
79 |
@dataclass
class IREvalData:
    """The dataclass for Information Retrieval evaluation data.

    :param question: The question for evaluation. Required.
    :type question: str
    :param contexts: The contexts related to the question. Default: None.
    :type contexts: Optional[list[Context]]
    :param meta_data: The metadata of the evaluation data. Default: {}.
    :type meta_data: dict
    """

    # CONSISTENCY: mark the required field with the omegaconf MISSING sentinel,
    # matching the sibling dataclasses (RAGEvalData, RetrievedContext); callers
    # that already pass `question` are unaffected.
    question: str = MISSING
    contexts: Optional[list[Context]] = None
    meta_data: dict = field(default_factory=dict)
95 |
--------------------------------------------------------------------------------
/src/flexrag/context_refine/__init__.py:
--------------------------------------------------------------------------------
1 | from .arranger import ContextArranger, ContextArrangerConfig
2 | from .summarizer import (
3 | RecompExtractiveSummarizer,
4 | RecompExtractiveSummarizerConfig,
5 | AbstractiveSummarizer,
6 | AbstractiveSummarizerConfig,
7 | )
8 | from .refiner import RefinerBase, REFINERS
9 |
10 |
11 | RefinerConfig = REFINERS.make_config(
12 | allow_multiple=True, default=None, config_name="RefinerConfig"
13 | )
14 |
15 |
16 | __all__ = [
17 | "ContextArranger",
18 | "ContextArrangerConfig",
19 | "RecompExtractiveSummarizer",
20 | "RecompExtractiveSummarizerConfig",
21 | "AbstractiveSummarizer",
22 | "AbstractiveSummarizerConfig",
23 | "RefinerBase",
24 | "REFINERS",
25 | "RefinerConfig",
26 | ]
27 |
--------------------------------------------------------------------------------
/src/flexrag/context_refine/arranger.py:
--------------------------------------------------------------------------------
1 | import random as rd
2 | from dataclasses import dataclass
3 |
4 | from flexrag.common_dataclass import RetrievedContext
5 | from flexrag.utils import Choices, TIME_METER
6 |
7 | from .refiner import REFINERS, RefinerBase
8 |
9 |
@dataclass
class ContextArrangerConfig:
    """The configuration for the ``ContextArranger``.

    :param order: The order to arrange the contexts. Defaults to "ascending".
        Available choices: "ascending", "descending", "side", "random".
    :type order: str
    """

    order: Choices(["ascending", "descending", "side", "random"]) = "ascending"  # type: ignore
20 |
21 |
@REFINERS("context_arranger", config_class=ContextArrangerConfig)
class ContextArranger(RefinerBase):
    """The ``ContextArranger`` arranges the contexts based on the given order.

    As the `lost-in-the-middle` problem encountered by the LLMs, the order of the contexts may affect the performance.
    This refiner helps to arrange the contexts in a specific order.
    """

    def __init__(self, config: ContextArrangerConfig):
        self.order = config.order
        return

    @TIME_METER("repack")
    def refine(self, contexts: list[RetrievedContext]) -> list[RetrievedContext]:
        """Return the contexts rearranged according to ``self.order``."""
        if self.order == "ascending":
            # lowest score first
            return sorted(contexts, key=lambda ctx: ctx.score)
        if self.order == "descending":
            # highest score first
            return sorted(contexts, key=lambda ctx: ctx.score, reverse=True)
        if self.order == "random":
            # shuffle a copy so the caller's list is left untouched
            shuffled = list(contexts)
            rd.shuffle(shuffled)
            return shuffled
        if self.order == "side":
            # best contexts at both ends, weakest in the middle: alternate the
            # ranked list into two halves and mirror the second one
            ranked = sorted(contexts, key=lambda ctx: ctx.score, reverse=True)
            evens = ranked[0::2]
            odds = ranked[1::2]
            return evens + odds[::-1]
        raise ValueError(f"Invalid order: {self.order}")
57 |
--------------------------------------------------------------------------------
/src/flexrag/context_refine/refiner.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 |
3 | from flexrag.common_dataclass import RetrievedContext
4 | from flexrag.utils import Register
5 |
6 |
class RefinerBase(ABC):
    """The base class for context refiners.
    The subclasses should implement the ``refine`` method.
    """

    @abstractmethod
    def refine(self, contexts: list[RetrievedContext]) -> list[RetrievedContext]:
        """Refine the contexts.

        :param contexts: The retrieved contexts to refine.
        :type contexts: list[RetrievedContext]
        :return: The refined contexts.
        :rtype: list[RetrievedContext]
        """
        return


# The global register for all context refiners.
REFINERS = Register[RefinerBase]("refiner")
25 |
--------------------------------------------------------------------------------
/src/flexrag/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | # datasets
2 | from .dataset import ChainDataset, ConcatDataset, IterableDataset, MappingDataset
3 | from .hf_dataset import HFDataset, HFDatasetConfig
4 | from .line_delimited_dataset import LineDelimitedDataset, LineDelimitedDatasetConfig
5 | from .rag_dataset import (
6 | RAGCorpusDataset,
7 | RAGCorpusDatasetConfig,
8 | RAGEvalDataset,
9 | RAGEvalDatasetConfig,
10 | )
11 | from .retrieval_dataset import MTEBDataset, MTEBDatasetConfig
12 |
13 | __all__ = [
14 | "ChainDataset",
15 | "IterableDataset",
16 | "MappingDataset",
17 | "ConcatDataset",
18 | "HFDataset",
19 | "HFDatasetConfig",
20 | "LineDelimitedDataset",
21 | "LineDelimitedDatasetConfig",
22 | "RAGEvalDatasetConfig",
23 | "RAGEvalDataset",
24 | "RAGCorpusDatasetConfig",
25 | "RAGCorpusDataset",
26 | "MTEBDataset",
27 | "MTEBDatasetConfig",
28 | ]
29 |
--------------------------------------------------------------------------------
/src/flexrag/datasets/hf_dataset.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass, field
2 | from typing import Optional
3 |
4 | from datasets import Dataset as _Dataset
5 | from datasets import DatasetDict as _DatasetDict
6 | from datasets import load_dataset
7 |
8 | from .dataset import MappingDataset
9 |
10 |
@dataclass
class HFDatasetConfig:
    """The configuration for the ``HFDataset``.
    The ``HFDataset`` is a wrapper class that employs the ``load_dataset`` method in HuggingFace ``datasets`` library to load the dataset.

    :param path: Path or name of the dataset.
    :type path: str
    :param name: Defining the name of the dataset configuration.
    :type name: Optional[str]
    :param data_dir: Defining the ``data_dir`` of the dataset configuration.
    :type data_dir: Optional[str]
    :param data_files: Paths to source data files. An empty list means no explicit files.
    :type data_files: list[str]
    :param split: Which split of the data to load.
    :type split: Optional[str]
    :param cache_dir: Directory to read/write data.
    :type cache_dir: Optional[str]
    :param token: Optional Bearer token for reading remote files on the Datasets Hub.
    :type token: Optional[str]
    :param trust_remote_code: Whether or not to allow for datasets defined on the Hub using a dataset script.
    :type trust_remote_code: bool

    For example, you can load the dataset from the HuggingFace by running the following code:

    >>> cfg = HFDatasetConfig(
    ...     path="mteb/nq",
    ...     split="test",
    ... )
    >>> dataset = HFDataset(cfg)

    You can also load the dataset from a local repository by specifying the path:

    >>> cfg = HFDatasetConfig(
    ...     path="json",
    ...     data_files=["path/to/local/my_dataset.json"],
    ... )
    >>> dataset = HFDataset(cfg)

    For more information about the parameters, please refer to the HuggingFace ``datasets`` documentation:
    https://huggingface.co/docs/datasets/main/en/package_reference/loading_methods#datasets.load_dataset
    """

    path: str
    name: Optional[str] = None
    data_dir: Optional[str] = None
    data_files: list[str] = field(default_factory=list)
    split: Optional[str] = None
    cache_dir: Optional[str] = None
    token: Optional[str] = None
    trust_remote_code: bool = False
61 |
62 |
class HFDataset(MappingDataset):
    """HFDataset is a dataset that wraps the HuggingFace ``datasets`` library."""

    dataset: _Dataset

    def __init__(self, cfg: HFDatasetConfig) -> None:
        """Load the dataset via ``datasets.load_dataset``.

        :param cfg: The configuration for loading the dataset.
        :type cfg: HFDatasetConfig
        :raises ValueError: If no ``split`` was specified so that
            ``load_dataset`` returned a ``DatasetDict`` rather than a
            single split.
        """
        super().__init__()
        self.dataset = load_dataset(
            path=cfg.path,
            name=cfg.name,
            data_dir=cfg.data_dir,
            data_files=cfg.data_files if cfg.data_files else None,  # empty list -> None
            split=cfg.split,
            cache_dir=cfg.cache_dir,
            token=cfg.token,
            trust_remote_code=cfg.trust_remote_code,
        )
        if isinstance(self.dataset, _DatasetDict):
            raise ValueError(
                "Split is missing.\n"
                "Please pick one among the following splits: "
                f"{list(self.dataset.keys())}"
            )
        return

    def __getitem__(self, index: int):
        # delegate item access to the underlying HuggingFace dataset
        return self.dataset[index]

    def __len__(self):
        return len(self.dataset)
93 |
--------------------------------------------------------------------------------
/src/flexrag/datasets/retrieval_dataset.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | from dataclasses import dataclass
4 |
5 | from omegaconf import MISSING
6 |
7 | from flexrag.common_dataclass import Context, IREvalData
8 |
9 | from .dataset import MappingDataset
10 |
11 |
@dataclass
class MTEBDatasetConfig:
    """Configuration for loading `MTEB <https://github.com/embeddings-benchmark/mteb>`_ Retrieval Dataset.
    The __getitem__ method will return `IREvalData` objects.

    For example, to load the NQ dataset, you can download the test set by running the following command:

    >>> git lfs install
    >>> git clone https://huggingface.co/datasets/mteb/nq nq

    Then you can use the following code to load the dataset:

    >>> config = MTEBDatasetConfig(
    ...     data_path="nq",
    ...     subset="test",
    ...     load_corpus=False,
    ... )
    >>> dataset = MTEBDataset(config)

    :param data_path: Path to the data directory. Required.
    :type data_path: str
    :param subset: Subset of the dataset to load. Required.
    :type subset: str
    :param encoding: Encoding of the data files. Default is 'utf-8'.
    :type encoding: str
    :param load_corpus: Whether to load the corpus data. Default is False.
    :type load_corpus: bool
    """

    data_path: str = MISSING
    subset: str = MISSING
    encoding: str = "utf-8"
    load_corpus: bool = False
45 |
46 |
class MTEBDataset(MappingDataset[IREvalData]):
    """Dataset for loading MTEB Retrieval Dataset."""

    def __init__(self, config: MTEBDatasetConfig) -> None:
        """Load and merge the qrels, queries, and (optionally) corpus files.

        :param config: The configuration for loading the dataset.
        :type config: MTEBDatasetConfig
        """
        # NOTE: the original implementation opened the jsonl files inside
        # list comprehensions without ever closing them; _load_jsonl uses a
        # context manager so the handles are released promptly.
        qrels = self._load_jsonl(
            os.path.join(config.data_path, "qrels", f"{config.subset}.jsonl"),
            config.encoding,
        )
        queries = {
            query["_id"]: query
            for query in self._load_jsonl(
                os.path.join(config.data_path, "queries.jsonl"), config.encoding
            )
        }

        if config.load_corpus:
            corpus = {
                doc["_id"]: doc
                for doc in self._load_jsonl(
                    os.path.join(config.data_path, "corpus.jsonl"), config.encoding
                )
            }
        else:
            corpus = None

        # merge qrels, queries, and corpus into RetrievalData
        dataset_map: dict[str, int] = {}  # query-id -> index in self.dataset
        self.dataset: list[IREvalData] = []
        for qrel in qrels:
            # construct the context
            context = Context(context_id=qrel["corpus-id"])
            if corpus is not None:
                context.data = corpus[qrel["corpus-id"]]
            if "score" in qrel:  # relevance level of the context
                context.meta_data["score"] = int(qrel["score"])
            query = queries[qrel["query-id"]]["text"]

            # group all contexts belonging to the same query into one item
            if qrel["query-id"] not in dataset_map:
                dataset_map[qrel["query-id"]] = len(self.dataset)
                self.dataset.append(
                    IREvalData(
                        question=query,
                        contexts=[context],
                        meta_data={"query-id": qrel["query-id"]},
                    )
                )
            else:
                index = dataset_map[qrel["query-id"]]
                self.dataset[index].contexts.append(context)
        return

    @staticmethod
    def _load_jsonl(file_path: str, encoding: str) -> list[dict]:
        """Read a JSON-Lines file into a list of dicts, closing the handle afterwards."""
        with open(file_path, "r", encoding=encoding) as f:
            return [json.loads(line) for line in f]

    def __len__(self) -> int:
        return len(self.dataset)

    def __getitem__(self, index: int) -> IREvalData:
        return self.dataset[index]
113 |
--------------------------------------------------------------------------------
/src/flexrag/document_parser/__init__.py:
--------------------------------------------------------------------------------
1 | from .document_parser_base import DocumentParserBase, Document, DOCUMENTPARSERS
2 | from .docling_parser import DoclingParser, DoclingConfig
3 | from .markitdown_parser import MarkItDownParser
4 |
5 |
# dynamically generated config for selecting a document parser;
# defaults to the ``markitdown`` parser
DocumentParserConfig = DOCUMENTPARSERS.make_config(default="markitdown")


# explicit public API of the ``flexrag.document_parser`` package
__all__ = [
    "DocumentParserBase",
    "Document",
    "DOCUMENTPARSERS",
    "DocumentParserConfig",
    "DoclingParser",
    "DoclingConfig",
    "MarkItDownParser",
]
18 |
--------------------------------------------------------------------------------
/src/flexrag/document_parser/docling_parser.py:
--------------------------------------------------------------------------------
1 | import os
2 | from dataclasses import dataclass
3 |
4 |
5 | from .document_parser_base import Document, DocumentParserBase, DOCUMENTPARSERS
6 |
7 |
@dataclass
class DoclingConfig:
    """Configuration for the ``DoclingParser``."""

    # whether to run OCR during parsing (needed for scanned documents)
    do_ocr: bool = False
    # whether to reconstruct table structure during parsing
    do_table_structure: bool = True
    # whether to render an image of each page
    generate_page_images: bool = False
    # whether to extract embedded pictures as images
    generate_picture_images: bool = False
14 |
15 |
@DOCUMENTPARSERS("docling", config_class=DoclingConfig)
class DoclingParser(DocumentParserBase):
    """A document parser backed by the ``docling`` library."""

    def __init__(self, config: DoclingConfig):
        """
        :param config: The configuration for the parser.
        :type config: DoclingConfig
        :raises ImportError: If ``docling`` is not installed.
        """
        try:
            from docling.datamodel.base_models import InputFormat
            from docling.datamodel.pipeline_options import PdfPipelineOptions
            from docling.document_converter import DocumentConverter, PdfFormatOption
        except ImportError:
            raise ImportError(
                "Docling is not installed. Please install it via `pip install docling`."
            )

        pdf_pipeline_options = PdfPipelineOptions(
            do_ocr=config.do_ocr,
            do_table_structure=config.do_table_structure,
            generate_page_images=config.generate_page_images,
            generate_picture_images=config.generate_picture_images,
        )
        self.doc_converter = DocumentConverter(
            format_options={
                InputFormat.PDF: PdfFormatOption(pipeline_options=pdf_pipeline_options)
            }
        )
        return

    def parse(self, input_file_path: str) -> Document:
        """Parse the file at ``input_file_path`` with docling.

        :param input_file_path: The path to the document to parse.
        :type input_file_path: str
        :return: The parsed document.
        :rtype: Document
        """
        assert os.path.exists(input_file_path)
        document_ = self.doc_converter.convert(input_file_path).document
        document = Document(
            source_file_path=input_file_path,
            text=document_.export_to_markdown(),
            title=document_.name,
        )
        # Collect page screenshots / embedded pictures when they were generated.
        # NOTE(review): the original code read ``document.pagaes.image`` and
        # ``document.pictures.image`` — attributes that ``Document`` does not
        # have (``pagaes`` is also a typo) — raising AttributeError at runtime;
        # the images live on the docling document (``document_``) instead.
        pages = getattr(document_, "pages", None) or {}
        page_items = list(pages.values()) if isinstance(pages, dict) else list(pages)
        if page_items and all(getattr(p, "image", None) is not None for p in page_items):
            document.screenshots = [p.image.pil_image for p in page_items]
        pictures = getattr(document_, "pictures", None) or []
        if pictures and all(getattr(p, "image", None) is not None for p in pictures):
            document.images = [p.image.pil_image for p in pictures]
        return document
54 |
--------------------------------------------------------------------------------
/src/flexrag/document_parser/document_parser_base.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from dataclasses import dataclass, field
3 | from typing import Optional
4 |
5 | from PIL.Image import Image
6 |
7 | from flexrag.utils import Register
8 |
9 |
@dataclass
class Document:
    """A document parsed by a DocumentParser."""

    # path of the file the document was parsed from
    source_file_path: str
    # document title, if the parser could determine one
    title: Optional[str] = None
    # extracted text content of the document
    text: Optional[str] = None
    # rendered page images, when the parser generates them
    screenshots: list[Image] = field(default_factory=list)
    # pictures embedded in the document, when the parser extracts them
    images: list[Image] = field(default_factory=list)
19 |
20 |
class DocumentParserBase(ABC):
    """The base class for document parsers.

    Subclasses must implement the ``parse`` method.
    """

    @abstractmethod
    def parse(self, document_path: str) -> Document:
        """Parse the document at the given path.

        :param document_path: The path to the document to parse.
        :type document_path: str
        :return: The parsed document.
        :rtype: Document
        """
        return


# global registry holding all available document parser implementations
DOCUMENTPARSERS = Register[DocumentParserBase]("document_parser")
35 |
--------------------------------------------------------------------------------
/src/flexrag/document_parser/markitdown_parser.py:
--------------------------------------------------------------------------------
1 | from .document_parser_base import DOCUMENTPARSERS, Document, DocumentParserBase
2 |
3 |
@DOCUMENTPARSERS("markitdown")
class MarkItDownParser(DocumentParserBase):
    """A document parser backed by the ``markitdown`` library."""

    def __init__(self):
        """
        :raises ImportError: If ``markitdown`` is not installed.
        """
        try:
            from markitdown import MarkItDown
        except ImportError:
            raise ImportError(
                "MarkItDown is not installed. Please install it via `pip install markitdown`."
            )
        # NOTE: this instantiation used to sit in a ``finally`` block, which
        # also ran when the import failed, raising a NameError on the unbound
        # ``MarkItDown`` name and masking the helpful ImportError above.
        self.parser = MarkItDown()
        return

    def parse(self, path: str) -> Document:
        """Convert the file at ``path`` to text using MarkItDown.

        :param path: The path to the document to parse.
        :type path: str
        :return: The parsed document.
        :rtype: Document
        """
        doc = self.parser.convert(path)
        return Document(source_file_path=path, text=doc.text_content, title=doc.title)
20 |
--------------------------------------------------------------------------------
/src/flexrag/entrypoints/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ictnlp/FlexRAG/f00cf3a36890a2d56aec508a72b599c4681be3ca/src/flexrag/entrypoints/__init__.py
--------------------------------------------------------------------------------
/src/flexrag/entrypoints/assets/flexrag-wide.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ictnlp/FlexRAG/f00cf3a36890a2d56aec508a72b599c4681be3ca/src/flexrag/entrypoints/assets/flexrag-wide.png
--------------------------------------------------------------------------------
/src/flexrag/entrypoints/assets/flexrag.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ictnlp/FlexRAG/f00cf3a36890a2d56aec508a72b599c4681be3ca/src/flexrag/entrypoints/assets/flexrag.png
--------------------------------------------------------------------------------
/src/flexrag/entrypoints/assets/robot.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ictnlp/FlexRAG/f00cf3a36890a2d56aec508a72b599c4681be3ca/src/flexrag/entrypoints/assets/robot.png
--------------------------------------------------------------------------------
/src/flexrag/entrypoints/assets/user.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ictnlp/FlexRAG/f00cf3a36890a2d56aec508a72b599c4681be3ca/src/flexrag/entrypoints/assets/user.png
--------------------------------------------------------------------------------
/src/flexrag/entrypoints/cache.py:
--------------------------------------------------------------------------------
1 | import json
2 | from dataclasses import dataclass
3 |
4 | import hydra
5 | from hydra.core.config_store import ConfigStore
6 | from omegaconf import MISSING
7 |
8 | from flexrag.retriever.retriever_base import RETRIEVAL_CACHE
9 | from flexrag.utils import Choices
10 |
11 |
@dataclass
class Config:
    """Configuration for the retrieval-cache maintenance entrypoint."""

    # path of the JSON-Lines file to write when ``action`` is "export"
    export_path: str = MISSING
    # "clear" empties the cache; "export" dumps it to ``export_path``
    action: Choices(["clear", "export", "_"]) = "_"  # type: ignore


cs = ConfigStore.instance()
cs.store(name="default", node=Config)
20 |
21 |
@hydra.main(version_base="1.3", config_path=None, config_name="default")
def main(config: Config):
    """Manage the global retrieval cache: clear it or export it as JSON-Lines."""
    match config.action:
        case "clear":
            RETRIEVAL_CACHE.clear()
        case "export":
            with open(config.export_path, "w", encoding="utf-8") as f:
                for data in RETRIEVAL_CACHE:
                    # attach the cached contexts to the key record before dumping
                    data["retrieved_contexts"] = RETRIEVAL_CACHE[data]
                    f.write(json.dumps(data) + "\n")
        case _:
            raise ValueError("No action specified")
    return
35 |
36 |
37 | if __name__ == "__main__":
38 | main()
39 |
--------------------------------------------------------------------------------
/src/flexrag/entrypoints/combine_outputs.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | from dataclasses import dataclass
4 |
5 | import hydra
6 | from hydra.core.config_store import ConfigStore
7 | from omegaconf import MISSING, OmegaConf
8 |
9 | from flexrag.metrics import Evaluator
10 | from flexrag.utils import LOGGER_MANAGER
11 |
12 |
@dataclass
class Config:
    """Configuration for combining the outputs of multiple runs."""

    # paths of the run directories whose ``details.jsonl`` will be merged
    result_paths: list[str] = MISSING
    # directory to write the combined details, scores, and config
    output_path: str = MISSING


cs = ConfigStore.instance()
cs.store(name="default", node=Config)
logger = LOGGER_MANAGER.get_logger("combine_outputs")
22 |
23 |
@hydra.main(version_base="1.3", config_path=None, config_name="default")
def main(cfg: Config):
    """Combine the ``details.jsonl`` outputs of several runs and re-evaluate them.

    The ``config.yaml`` of the first run directory is reused both for the
    combined output and for constructing the evaluator.
    """
    # load the metadata
    config_path = os.path.join(cfg.result_paths[0], "config.yaml")
    loaded_config = OmegaConf.load(config_path)
    evaluator = Evaluator(loaded_config.eval_config)

    # prepare output path (exist_ok avoids the check-then-create race)
    os.makedirs(cfg.output_path, exist_ok=True)
    output_details_path = os.path.join(cfg.output_path, "details.jsonl")
    output_eval_score_path = os.path.join(cfg.output_path, "eval_score.json")
    output_config_path = os.path.join(cfg.output_path, "config.yaml")
    OmegaConf.save(loaded_config, output_config_path)

    # combine the results
    logger.info("Combining the results...")
    questions = []
    golden_answers = []
    golden_contexts = []
    responses = []
    contexts = []
    with open(output_details_path, "w", encoding="utf-8") as f:
        for result_path in cfg.result_paths:
            details_path = os.path.join(result_path, "details.jsonl")
            # close each input file promptly instead of leaking the handle
            with open(details_path, "r", encoding="utf-8") as in_f:
                for line in in_f:
                    f.write(line)
                    data = json.loads(line)
                    questions.append(data["question"])
                    golden_answers.append(data["golden"])
                    golden_contexts.append(data["golden_contexts"])
                    responses.append(data["response"])
                    contexts.append(data["contexts"])

    # re-evaluate the combined results
    logger.info("Re-evaluating the combined results...")
    resp_score, resp_score_detail = evaluator.evaluate(
        questions=questions,
        responses=responses,
        golden_responses=golden_answers,
        retrieved_contexts=contexts,
        golden_contexts=golden_contexts,
        log=True,
    )
    with open(output_eval_score_path, "w", encoding="utf-8") as f:
        json.dump(
            {
                "eval_scores": resp_score,
                "eval_details": resp_score_detail,
            },
            f,
            indent=4,
            ensure_ascii=False,
        )
    return
79 |
80 |
81 | if __name__ == "__main__":
82 | main()
83 |
--------------------------------------------------------------------------------
/src/flexrag/entrypoints/prepare_index.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass, field
2 |
3 | import hydra
4 | from hydra.core.config_store import ConfigStore
5 | from omegaconf import OmegaConf
6 |
7 | from flexrag.datasets import RAGCorpusDataset, RAGCorpusDatasetConfig
8 | from flexrag.retriever import (
9 | BM25SRetriever,
10 | BM25SRetrieverConfig,
11 | DenseRetriever,
12 | DenseRetrieverConfig,
13 | ElasticRetriever,
14 | ElasticRetrieverConfig,
15 | TypesenseRetriever,
16 | TypesenseRetrieverConfig,
17 | )
18 | from flexrag.utils import LOGGER_MANAGER, Choices
19 |
20 | logger = LOGGER_MANAGER.get_logger("flexrag.prepare_index")
21 |
22 |
# fmt: off
@dataclass
class Config(RAGCorpusDatasetConfig):
    """Configuration for building a retriever index from a RAG corpus."""
    # retriever configs
    retriever_type: Choices(["dense", "elastic", "typesense", "bm25s"]) = "dense" # type: ignore
    bm25s_config: BM25SRetrieverConfig = field(default_factory=BM25SRetrieverConfig)
    dense_config: DenseRetrieverConfig = field(default_factory=DenseRetrieverConfig)
    elastic_config: ElasticRetrieverConfig = field(default_factory=ElasticRetrieverConfig)
    typesense_config: TypesenseRetrieverConfig = field(default_factory=TypesenseRetrieverConfig)
    # if True and the retriever already holds passages, wipe them before indexing
    reinit: bool = False
# fmt: on


cs = ConfigStore.instance()
cs.store(name="default", node=Config)
38 |
39 |
@hydra.main(version_base="1.3", config_path=None, config_name="default")
def main(cfg: Config):
    """Build (or rebuild) a retriever index by adding all corpus passages."""
    # load retriever
    match cfg.retriever_type:
        case "bm25s":
            retriever = BM25SRetriever(cfg.bm25s_config)
        case "dense":
            retriever = DenseRetriever(cfg.dense_config)
        case "elastic":
            retriever = ElasticRetriever(cfg.elastic_config)
        case "typesense":
            retriever = TypesenseRetriever(cfg.typesense_config)
        case _:
            raise ValueError(f"Unsupported retriever type: {cfg.retriever_type}")

    # add passages
    if cfg.reinit and (len(retriever) > 0):
        logger.warning("Reinitializing retriever and removing all passages")
        retriever.clean()

    retriever.add_passages(passages=RAGCorpusDataset(cfg))
    return
62 |
63 |
64 | if __name__ == "__main__":
65 | main()
66 |
--------------------------------------------------------------------------------
/src/flexrag/entrypoints/rebuild_index.py:
--------------------------------------------------------------------------------
1 | import hydra
2 | from hydra.core.config_store import ConfigStore
3 | from omegaconf import OmegaConf
4 |
5 | from flexrag.retriever import DenseRetriever, DenseRetrieverConfig
6 |
7 |
cs = ConfigStore.instance()
cs.store(name="default", node=DenseRetrieverConfig)


@hydra.main(version_base="1.3", config_path=None, config_name="default")
def main(cfg: DenseRetrieverConfig):
    """Rebuild the dense retriever's index from its stored data.

    ``no_check=True`` skips the retriever's consistency check on load; the
    index is about to be rebuilt anyway.
    """
    # rebuild index
    retriever = DenseRetriever(cfg, no_check=True)
    retriever.build_index(rebuild=True)
    return
18 |
19 |
20 | if __name__ == "__main__":
21 | main()
22 |
--------------------------------------------------------------------------------
/src/flexrag/entrypoints/run_assistant.py:
--------------------------------------------------------------------------------
1 | import json
2 | import logging
3 | import os
4 | import sys
5 | from dataclasses import dataclass, field
6 | from typing import Optional
7 |
8 | import hydra
9 | from hydra.core.config_store import ConfigStore
10 | from omegaconf import OmegaConf
11 |
12 | from flexrag.assistant import ASSISTANTS
13 | from flexrag.common_dataclass import RetrievedContext
14 | from flexrag.datasets import RAGEvalDataset, RAGEvalDatasetConfig
15 | from flexrag.metrics import Evaluator, EvaluatorConfig
16 | from flexrag.utils import LOGGER_MANAGER, SimpleProgressLogger, load_user_module
17 |
# Load user modules before loading config.
# NOTE: iterate over a copy of sys.argv — the original looped over sys.argv
# while calling sys.argv.remove(arg) inside the loop, which skips the element
# following each removal.
for arg in list(sys.argv):
    if arg.startswith("user_module="):
        # split only on the first "=" so module paths containing "=" survive
        load_user_module(arg.split("=", 1)[1])
        sys.argv.remove(arg)
23 |
24 |
# dynamically generated config exposing all registered assistants
AssistantConfig = ASSISTANTS.make_config()


@dataclass
class Config(AssistantConfig, RAGEvalDatasetConfig):
    """Configuration for running an assistant over an evaluation dataset."""

    # evaluation metrics configuration
    eval_config: EvaluatorConfig = field(default_factory=EvaluatorConfig)  # fmt: skip
    # log progress every ``log_interval`` processed items
    log_interval: int = 10
    # directory for outputs; if None, all outputs are discarded
    output_path: Optional[str] = None


cs = ConfigStore.instance()
cs.store(name="default", node=Config)
logger = LOGGER_MANAGER.get_logger("run_assistant")
38 |
39 |
@hydra.main(version_base="1.3", config_path=None, config_name="default")
def main(config: Config):
    """Answer every question in the evaluation dataset with the configured
    assistant, dump per-item details, and report the evaluation scores."""
    # load dataset
    testset = RAGEvalDataset(config)

    # load assistant
    assistant = ASSISTANTS.load(config)

    # prepare output paths
    if config.output_path is not None:
        if not os.path.exists(config.output_path):
            os.makedirs(config.output_path)
        details_path = os.path.join(config.output_path, "details.jsonl")
        eval_score_path = os.path.join(config.output_path, "eval_score.json")
        config_path = os.path.join(config.output_path, "config.yaml")
        log_path = os.path.join(config.output_path, "log.txt")
    else:
        # no output directory configured: discard all outputs
        details_path = os.devnull
        eval_score_path = os.devnull
        config_path = os.devnull
        log_path = os.devnull

    # save config and set logger
    with open(config_path, "w", encoding="utf-8") as f:
        OmegaConf.save(config, f)
    handler = logging.FileHandler(log_path)
    LOGGER_MANAGER.add_handler(handler)
    logger.debug(f"Configs:\n{OmegaConf.to_yaml(config)}")

    # search and generate
    p_logger = SimpleProgressLogger(logger, interval=config.log_interval)
    questions = []
    golden_answers = []
    golden_contexts = []
    responses = []
    contexts: list[list[RetrievedContext]] = []
    with open(details_path, "w", encoding="utf-8") as f:
        for item in testset:
            questions.append(item.question)
            golden_answers.append(item.golden_answers)
            golden_contexts.append(item.golden_contexts)
            response, ctxs, metadata = assistant.answer(question=item.question)
            responses.append(response)
            contexts.append(ctxs)
            # one JSON record per test item
            json.dump(
                {
                    "question": item.question,
                    "golden": item.golden_answers,
                    "golden_contexts": item.golden_contexts,
                    "metadata_test": item.meta_data,
                    "response": response,
                    "contexts": ctxs,
                    "metadata": metadata,
                },
                f,
                ensure_ascii=False,
            )
            f.write("\n")
            p_logger.update(desc="Searching")

    # evaluate
    evaluator = Evaluator(config.eval_config)
    resp_score, resp_score_detail = evaluator.evaluate(
        questions=questions,
        responses=responses,
        golden_responses=golden_answers,
        retrieved_contexts=contexts,
        golden_contexts=golden_contexts,
        log=True,
    )
    with open(eval_score_path, "w", encoding="utf-8") as f:
        json.dump(
            {
                "eval_scores": resp_score,
                "eval_details": resp_score_detail,
            },
            f,
            indent=4,
            ensure_ascii=False,
        )
    return
121 |
122 |
123 | if __name__ == "__main__":
124 | main()
125 |
--------------------------------------------------------------------------------
/src/flexrag/entrypoints/run_retriever.py:
--------------------------------------------------------------------------------
1 | import json
2 | import logging
3 | import os
4 | import sys
5 | from dataclasses import dataclass, field
6 | from typing import Optional
7 |
8 | import hydra
9 | from hydra.core.config_store import ConfigStore
10 | from omegaconf import OmegaConf
11 |
12 | from flexrag.common_dataclass import Context, RetrievedContext
13 | from flexrag.datasets import MTEBDataset, MTEBDatasetConfig
14 | from flexrag.metrics import Evaluator, EvaluatorConfig
15 | from flexrag.retriever import RETRIEVERS
16 | from flexrag.utils import LOGGER_MANAGER, SimpleProgressLogger, load_user_module
17 |
# Load user modules before loading config.
# NOTE: iterate over a copy of sys.argv — the original looped over sys.argv
# while calling sys.argv.remove(arg) inside the loop, which skips the element
# following each removal.
for arg in list(sys.argv):
    if arg.startswith("user_module="):
        # split only on the first "=" so module paths containing "=" survive
        load_user_module(arg.split("=", 1)[1])
        sys.argv.remove(arg)
23 |
# dynamically generated config exposing all registered retrievers
RetrieverConfig = RETRIEVERS.make_config(config_name="RetrieverConfig")


@dataclass
class Config(RetrieverConfig, MTEBDatasetConfig):
    """Configuration for evaluating a retriever on an MTEB dataset."""

    # directory for outputs; if None, all outputs are discarded
    output_path: Optional[str] = None
    # evaluation metrics configuration
    eval_config: EvaluatorConfig = field(default_factory=EvaluatorConfig)  # fmt: skip
    # log progress every ``log_interval`` processed items
    log_interval: int = 10


cs = ConfigStore.instance()
cs.store(name="default", node=Config)
logger = LOGGER_MANAGER.get_logger("run_retriever")
37 |
38 |
@hydra.main(version_base="1.3", config_path=None, config_name="default")
def main(config: Config):
    """Run the configured retriever over the MTEB testset, dump per-item
    details, and report the retrieval evaluation scores."""
    # load dataset
    testset = MTEBDataset(config)

    # load retriever
    retriever = RETRIEVERS.load(config)

    # prepare output paths
    if config.output_path is not None:
        if not os.path.exists(config.output_path):
            os.makedirs(config.output_path)
        details_path = os.path.join(config.output_path, "details.jsonl")
        eval_score_path = os.path.join(config.output_path, "eval_score.json")
        config_path = os.path.join(config.output_path, "config.yaml")
        log_path = os.path.join(config.output_path, "log.txt")
    else:
        # no output directory configured: discard all outputs
        details_path = os.devnull
        eval_score_path = os.devnull
        config_path = os.devnull
        log_path = os.devnull

    # save config and set logger
    with open(config_path, "w", encoding="utf-8") as f:
        OmegaConf.save(config, f)
    handler = logging.FileHandler(log_path)
    LOGGER_MANAGER.add_handler(handler)
    logger.debug(f"Configs:\n{OmegaConf.to_yaml(config)}")

    # search and record details
    p_logger = SimpleProgressLogger(logger, interval=config.log_interval)
    questions = []
    goldens: list[list[Context]] = []
    retrieved: list[list[RetrievedContext]] = []
    with open(details_path, "w", encoding="utf-8") as f:
        for item in testset:
            questions.append(item.question)
            goldens.append(item.contexts)
            # search() returns one result list per query; a single query yields index 0
            ctxs = retriever.search(query=item.question)[0]
            retrieved.append(ctxs)
            json.dump(
                {
                    "question": item.question,
                    "golden_contexts": item.contexts,
                    "metadata": item.meta_data,
                    "contexts": ctxs,
                },
                f,
                ensure_ascii=False,
            )
            f.write("\n")
            p_logger.update(desc="Searching")

    # evaluate
    evaluator = Evaluator(config.eval_config)
    resp_score, resp_score_detail = evaluator.evaluate(
        questions=questions,
        retrieved_contexts=retrieved,
        golden_contexts=goldens,
        log=True,
    )
    with open(eval_score_path, "w", encoding="utf-8") as f:
        json.dump(
            {
                "eval_scores": resp_score,
                "eval_details": resp_score_detail,
            },
            f,
            indent=4,
            ensure_ascii=False,
        )
    return
111 |
112 |
113 | if __name__ == "__main__":
114 | main()
115 |
--------------------------------------------------------------------------------
/src/flexrag/metrics/__init__.py:
--------------------------------------------------------------------------------
1 | from .generation_metrics import (
2 | BLEU,
3 | BLEUConfig,
4 | Rouge,
5 | chrF,
6 | chrFConfig,
7 | )
8 | from .matching_metrics import (
9 | F1,
10 | Accuracy,
11 | ExactMatch,
12 | MatchingMetrics,
13 | Precision,
14 | Recall,
15 | )
16 | from .metrics_base import MetricsBase
17 | from .retrieval_metrics import (
18 | SuccessRate,
19 | SuccessRateConfig,
20 | RetrievalRecall,
21 | RetrievalRecallConfig,
22 | RetrievalPrecision,
23 | RetrievalPrecisionConfig,
24 | RetrievalMAP,
25 | RetrievalMAPConfig,
26 | RetrievalNDCG,
27 | RetrievalNDCGConfig,
28 | )
29 |
30 | from .evaluator import Evaluator, EvaluatorConfig # isort: skip
31 |
# explicit public API of the ``flexrag.metrics`` package
__all__ = [
    "MetricsBase",
    "MatchingMetrics",
    "Accuracy",
    "ExactMatch",
    "F1",
    "Recall",
    "Precision",
    "BLEU",
    "BLEUConfig",
    "Rouge",
    "chrF",
    "chrFConfig",
    "SuccessRate",
    "SuccessRateConfig",
    "RetrievalRecall",
    "RetrievalRecallConfig",
    "RetrievalPrecision",
    "RetrievalPrecisionConfig",
    "RetrievalMAP",
    "RetrievalMAPConfig",
    "RetrievalNDCG",
    "RetrievalNDCGConfig",
    "Evaluator",
    "EvaluatorConfig",
]
58 |
--------------------------------------------------------------------------------
/src/flexrag/metrics/evaluator.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass, field
2 |
3 | from flexrag.common_dataclass import RetrievedContext
4 | from flexrag.text_process import TextProcessPipeline, TextProcessPipelineConfig
5 | from flexrag.utils import LOGGER_MANAGER
6 |
7 | from .metrics_base import METRICS, MetricsBase
8 |
logger = LOGGER_MANAGER.get_logger("flexrag.metrics")
# dynamically generated config that can select multiple registered metrics
MetricConfig = METRICS.make_config(allow_multiple=True)


@dataclass
class EvaluatorConfig(MetricConfig):
    """Configuration for the ``Evaluator``."""

    # number of decimal places used when logging percentage scores
    round: int = 2
    # preprocessing applied to responses and golden responses before scoring
    response_preprocess: TextProcessPipelineConfig = field(default_factory=TextProcessPipelineConfig)  # type: ignore
17 |
18 |
class Evaluator:
    """Runs a configurable set of metrics over responses and retrieved contexts."""

    def __init__(self, cfg: EvaluatorConfig) -> None:
        """
        :param cfg: The configuration specifying the metrics to run.
        :type cfg: EvaluatorConfig
        """
        # map each configured metric name to its instantiated metric object
        self.metrics: dict[str, MetricsBase] = {
            name: metric for name, metric in zip(cfg.metrics_type, METRICS.load(cfg))
        }
        self.response_pipeline = TextProcessPipeline(cfg.response_preprocess)
        self.round = cfg.round
        return

    def evaluate(
        self,
        questions: list[str] = None,
        responses: list[str] = None,
        golden_responses: list[list[str]] = None,
        retrieved_contexts: list[list[str | RetrievedContext]] = None,
        golden_contexts: list[list[str]] = None,
        log: bool = True,
    ):
        """Evaluate the generated responses against the ground truth responses.

        :param questions: A list of questions. Defaults to None.
        :param responses: A list of responses. Defaults to None.
        :param golden_responses: A list of golden responses. Defaults to None.
        :param retrieved_contexts: A list of retrieved contexts. Defaults to None.
        :param golden_contexts: A list of golden contexts. Defaults to None.
        :param log: Whether to log the evaluation results. Defaults to True.
        :type questions: list[str], optional
        :type responses: list[str], optional
        :type golden_responses: list[list[str]], optional
        :type retrieved_contexts: list[list[str | RetrievedContext]], optional
        :type golden_contexts: list[list[str]], optional
        :type log: bool, optional
        :return: The evaluation results and the evaluation details.
        :rtype: tuple[dict[str, float], dict[str, Any]]
        """
        # check the input arguments
        not_none_args = [
            arg
            for arg in [
                questions,
                responses,
                golden_responses,
                retrieved_contexts,
                golden_contexts,
            ]
            if arg is not None
        ]
        # NOTE: the original message claimed "at least one argument", but the
        # condition has always required two aligned argument lists; the
        # message is corrected to match the actual check.
        assert len(not_none_args) > 1, "At least two arguments must be provided."
        assert all(
            len(i) == len(not_none_args[0]) for i in not_none_args
        ), "All arguments must have the same length."

        # evaluate
        evaluation_results = {}
        evaluation_details = {}
        if responses is not None:
            responses = [self.response_pipeline(res) for res in responses]
        if golden_responses is not None:
            golden_responses = [
                [self.response_pipeline(g) for g in golds] for golds in golden_responses
            ]
        for metric in self.metrics:
            metric = str(metric)  # make the detail key json serializable
            r, r_detail = self.metrics[metric](
                questions=questions,
                responses=responses,
                golden_responses=golden_responses,
                retrieved_contexts=retrieved_contexts,
                golden_contexts=golden_contexts,
            )
            if log:
                for name, score in r.items():
                    logger.info(f"{name}: {score*100:.{self.round}f}%")
            evaluation_results.update(r)
            evaluation_details[metric] = r_detail
        return evaluation_results, evaluation_details
95 |
--------------------------------------------------------------------------------
/src/flexrag/metrics/lib_rel.cpp:
--------------------------------------------------------------------------------
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include <string>
#include <vector>
5 |
6 | namespace py = pybind11;
7 |
// Build a containment map for relevance checking: result[i][j] is true iff
// retrieved text i contains evidence string j as a substring.
// BUG FIX: the template arguments were stripped from the original signature
// and locals (e.g. `std::vector>`); the intended types are restored here.
std::vector<std::vector<bool>> get_contain_map(const std::vector<std::string>& evs,
                                               const std::vector<std::string>& rets) {
    std::vector<std::vector<bool>> results;
    results.reserve(rets.size());

    for (const auto& ret : rets) {
        std::vector<bool> result_row;
        result_row.reserve(evs.size());
        for (const auto& ev : evs) {
            // std::string::find performs a plain substring search
            result_row.push_back(ret.find(ev) != std::string::npos);
        }
        results.push_back(std::move(result_row));
    }

    return results;
}
23 |
// Expose get_contain_map to Python as the extension module `lib_rel`
// (argument conversion for std::vector/std::string comes from pybind11/stl.h).
PYBIND11_MODULE(lib_rel, m) {
    m.def("get_contain_map", &get_contain_map, "Get contain map.");
}
27 |
--------------------------------------------------------------------------------
/src/flexrag/metrics/matching_metrics.py:
--------------------------------------------------------------------------------
1 | from abc import abstractmethod
2 | from collections import Counter
3 |
4 | from flexrag.utils import TIME_METER
5 |
6 | from .metrics_base import MetricsBase, METRICS
7 |
8 |
class MatchingMetrics(MetricsBase):
    """Base class for lexical matching metrics.

    Subclasses implement :meth:`compute_item`, which scores one response
    against its gold answers; :meth:`compute` averages those scores.
    """

    # identifier used as the key in the returned score dict
    name: str

    @abstractmethod
    def compute_item(self, golds: list[str], response: str) -> float:
        """Score a single response against its gold answers.

        :param golds: The acceptable gold answers.
        :param response: The predicted response.
        :return: The matching score for this item.
        """
        return

    @TIME_METER("metrics.matching_score")
    def compute(
        self, responses: list[str], golden_responses: list[list[str]], **kwargs
    ) -> tuple[dict[str, float], dict[str, list[float]]]:
        """Average the per-item matching scores.

        :return: ({self.name: mean score}, {"item_score": per-item scores}).
        """
        matching_list = [
            self.compute_item(golds, response)
            for golds, response in zip(golden_responses, responses)
        ]
        # BUG FIX: empty input previously raised ZeroDivisionError; report 0.0.
        # (Also corrected the return annotation: the method returns a score
        # dict, not a bare float.)
        if matching_list:
            matching_score = sum(matching_list) / len(matching_list)
        else:
            matching_score = 0.0
        return {self.name: matching_score}, {"item_score": matching_list}
25 |
26 |
27 | @METRICS("generation_em")
28 | class ExactMatch(MatchingMetrics):
29 | """ExactMatch metric computes if any of the golden responses is exactly the same as the predicted response."""
30 |
31 | name = "generation_em"
32 |
33 | def compute_item(self, golds: list[str], response: str) -> float:
34 | return float(response in golds)
35 |
36 |
37 | @METRICS("generation_accuracy")
38 | class Accuracy(MatchingMetrics):
39 | """Accuracy metric computes if any of the golden responses is in the predicted response."""
40 |
41 | name = "generation_accuracy"
42 |
43 | def compute_item(self, golds: list[str], response: str) -> float:
44 | return float(any(gold in response for gold in golds))
45 |
46 |
def f1_recall_precision(golds: list[str], response: str) -> tuple[float, float, float]:
    """Compute the best token-level F1, recall and precision of ``response``
    against any of the ``golds`` (whitespace tokenization, bag-of-words
    overlap; the maximum over golds is taken independently per statistic).

    :param golds: The acceptable gold answers.
    :param response: The predicted response.
    :return: (f1, recall, precision), each in [0, 1]; all 0.0 when there is
        no token overlap with any gold (including empty inputs).
    """
    true_counters = [Counter(gold.split()) for gold in golds]
    pred_counter = Counter(response.split())
    # hoisted out of the loop: the prediction token total is loop-invariant
    pred_total = sum(pred_counter.values())
    precision = 0.0
    recall = 0.0
    f1 = 0.0
    for gold in true_counters:
        common = sum((gold & pred_counter).values())
        if common == 0:
            # no overlap with this gold; pred_total may also be 0 here,
            # so skipping avoids any division by zero
            continue
        p = common / pred_total
        r = common / sum(gold.values())
        f1_ = (2 * p * r) / (p + r)
        precision = max(p, precision)
        recall = max(r, recall)
        f1 = max(f1, f1_)
    return f1, recall, precision
64 |
65 |
66 | @METRICS("generation_f1")
67 | class F1(MatchingMetrics):
68 | """F1 metric computes the F1 score of the predicted response against the golden responses."""
69 |
70 | name = "generation_f1"
71 |
72 | def compute_item(self, golds: list[str], response: str) -> float:
73 | return f1_recall_precision(golds, response)[0]
74 |
75 |
76 | @METRICS("generation_recall")
77 | class Recall(MatchingMetrics):
78 | """Recall metric computes the recall of the predicted response against the golden responses."""
79 |
80 | name = "generation_recall"
81 |
82 | def compute_item(self, golds: list[str], response: str) -> float:
83 | return f1_recall_precision(golds, response)[1]
84 |
85 |
86 | @METRICS("generation_precision")
87 | class Precision(MatchingMetrics):
88 | """Precision metric computes the precision of the predicted response against the golden responses."""
89 |
90 | name = "generation_precision"
91 |
92 | def compute_item(self, golds: list[str], response: str) -> float:
93 | return f1_recall_precision(golds, response)[2]
94 |
--------------------------------------------------------------------------------
/src/flexrag/metrics/metrics_base.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 |
3 | from flexrag.common_dataclass import RetrievedContext
4 | from flexrag.utils import Register
5 |
6 |
class MetricsBase(ABC):
    """Abstract base class shared by all evaluation metrics."""

    def __call__(
        self,
        questions: list[str] = None,
        responses: list[str] = None,
        golden_responses: list[list[str]] = None,
        retrieved_contexts: list[list[str | RetrievedContext]] = None,
        golden_contexts: list[list[str]] = None,
    ) -> tuple[dict[str, float], dict]:
        """
        Compute the metric value by delegating to :meth:`compute`.

        :param questions: A list of questions. Defaults to None.
        :param responses: A list of responses. Defaults to None.
        :param golden_responses: A list of golden responses. Defaults to None.
        :param retrieved_contexts: A list of retrieved contexts. Defaults to None.
        :param golden_contexts: A list of golden contexts. Defaults to None.
        :type questions: list[str], optional
        :type responses: list[str], optional
        :type golden_responses: list[list[str]], optional
        :type retrieved_contexts: list[list[str | RetrievedContext]], optional
        :type golden_contexts: list[list[str]], optional
        :return: The metric scores and the metadata of the metric.
        :rtype: tuple[dict[str, float], dict]
        """
        inputs = {
            "questions": questions,
            "responses": responses,
            "golden_responses": golden_responses,
            "retrieved_contexts": retrieved_contexts,
            "golden_contexts": golden_contexts,
        }
        return self.compute(**inputs)

    @abstractmethod
    def compute(
        self,
        questions: list[str] = None,
        responses: list[str] = None,
        golden_responses: list[list[str]] = None,
        retrieved_contexts: list[list[str | RetrievedContext]] = None,
        golden_contexts: list[list[str]] = None,
    ) -> tuple[dict[str, float], dict]:
        """
        Compute the metric value. Subclasses must implement this method.

        :param questions: A list of questions. Defaults to None.
        :param responses: A list of responses. Defaults to None.
        :param golden_responses: A list of golden responses. Defaults to None.
        :param retrieved_contexts: A list of retrieved contexts. Defaults to None.
        :param golden_contexts: A list of golden contexts. Defaults to None.
        :type questions: list[str], optional
        :type responses: list[str], optional
        :type golden_responses: list[list[str]], optional
        :type retrieved_contexts: list[list[str | RetrievedContext]], optional
        :type golden_contexts: list[list[str]], optional
        :return: The metric scores and the metadata of the metric.
        :rtype: tuple[dict[str, float], dict]
        """
        return
66 |
67 |
68 | METRICS = Register[MetricsBase]("metrics")
69 |
--------------------------------------------------------------------------------
/src/flexrag/metrics/xfinder.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass, field
2 |
3 | from flexrag.utils import TIME_METER, Choices
4 |
5 | from .xfinder_utils import Evaluator
6 | from .metrics_base import METRICS, MetricsBase
7 |
8 |
@dataclass
class xFinderConfig:
    """Configuration for the xFinder judge metric."""

    # backbone family of the xFinder judge model
    model_type: Choices(["qwen", "llama"]) = "qwen"  # type: ignore
    # HuggingFace repo id or local path of the judge checkpoint
    model_path: str = "IAAR-Shanghai/xFinder-qwen1505"
    # kind of answer the judge should extract from the response
    answer_type: Choices(  # type: ignore
        ["math", "short_text", "categorical_label", "alphabet_option"]
    ) = "short_text"
    # sampling temperature for the judge model
    temperature: float = 0.7
    # maximum tokens generated by the judge model
    max_tokens: int = 100
    # GPU device ids; empty list means CPU — TODO confirm against Evaluator
    device_id: list[int] = field(default_factory=list)
19 |
20 |
21 | @METRICS("generation_xfinder")
22 | class xFinder(MetricsBase):
23 | def __init__(self, config: xFinderConfig):
24 | if config.model_type == "qwen":
25 | model_name = "xFinder-qwen1505"
26 | else:
27 | model_name = "xFinder-llama38it"
28 | self.evaluator = Evaluator(
29 | model_name=model_name,
30 | model_path_or_url=config.model_path,
31 | temperature=config.temperature,
32 | max_tokens=config.max_tokens,
33 | device_id=config.device_id,
34 | )
35 | self.answer_type = config.answer_type
36 | return
37 |
38 | @TIME_METER("metrics.xfinder_score")
39 | def compute(
40 | self,
41 | questions: list[str],
42 | responses: list[str],
43 | golden_responses: list[list[str]],
44 | choices: list[list[str]],
45 | **kwargs
46 | ) -> tuple[float, dict[str, list[float]]]:
47 | results = []
48 | for question, response, goldens, choice in zip(
49 | questions, responses, golden_responses, choices
50 | ):
51 | self.evaluator.evaluate_single_item(
52 | question=question,
53 | llm_output=response,
54 | answer_type=self.answer_type,
55 | correct_answer=goldens[0],
56 | answer_range=",".join(choice),
57 | )
58 |
59 | correct_count = sum(result[-1] for result in results)
60 | accuracy = correct_count / max(1, len(results)) if results else 0
61 | return accuracy
62 |
--------------------------------------------------------------------------------
/src/flexrag/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .anthropic_model import AnthropicGenerator, AnthropicGeneratorConfig
2 | from .cohere_model import (
3 | CohereEncoder,
4 | CohereEncoderConfig,
5 | )
6 | from .hf_model import (
7 | HFModelConfig,
8 | HFEncoder,
9 | HFEncoderConfig,
10 | HFGenerator,
11 | HFGeneratorConfig,
12 | HFClipEncoder,
13 | HFClipEncoderConfig,
14 | HFVLMGenerator,
15 | HFVLMGeneratorConfig,
16 | )
17 | from .jina_model import JinaEncoder, JinaEncoderConfig
18 | from .llamacpp_model import LlamacppGenerator, LlamacppGeneratorConfig
19 | from .model_base import (
20 | EncoderBase,
21 | GenerationConfig,
22 | GeneratorBase,
23 | VLMGeneratorBase,
24 | GENERATORS,
25 | ENCODERS,
26 | )
27 | from .ollama_model import (
28 | OllamaGenerator,
29 | OllamaGeneratorConfig,
30 | OllamaEncoder,
31 | OllamaEncoderConfig,
32 | )
33 | from .openai_model import (
34 | OpenAIConfig,
35 | OpenAIEncoder,
36 | OpenAIEncoderConfig,
37 | OpenAIGenerator,
38 | OpenAIGeneratorConfig,
39 | )
40 | from .vllm_model import VLLMGenerator, VLLMGeneratorConfig
41 | from .sentence_transformers_model import (
42 | SentenceTransformerEncoder,
43 | SentenceTransformerEncoderConfig,
44 | )
45 |
46 |
# Umbrella config classes aggregating every registered generator/encoder.
GeneratorConfig = GENERATORS.make_config(config_name="GeneratorConfig")
EncoderConfig = ENCODERS.make_config(config_name="EncoderConfig", default=None)


# Explicit public API of the models subpackage.
__all__ = [
    "GeneratorBase",
    "VLMGeneratorBase",
    "GenerationConfig",
    "EncoderBase",
    "AnthropicGenerator",
    "AnthropicGeneratorConfig",
    "HFModelConfig",
    "HFGenerator",
    "HFGeneratorConfig",
    "HFEncoder",
    "HFEncoderConfig",
    "HFClipEncoder",
    "HFClipEncoderConfig",
    "HFVLMGenerator",
    "HFVLMGeneratorConfig",
    "OllamaGenerator",
    "OllamaGeneratorConfig",
    "OllamaEncoder",
    "OllamaEncoderConfig",
    "OpenAIGenerator",
    "OpenAIGeneratorConfig",
    "OpenAIConfig",
    "OpenAIEncoder",
    "OpenAIEncoderConfig",
    "VLLMGenerator",
    "VLLMGeneratorConfig",
    "LlamacppGenerator",
    "LlamacppGeneratorConfig",
    "JinaEncoder",
    "JinaEncoderConfig",
    "CohereEncoder",
    "CohereEncoderConfig",
    "SentenceTransformerEncoder",
    "SentenceTransformerEncoderConfig",
    "GENERATORS",
    "ENCODERS",
]
89 |
--------------------------------------------------------------------------------
/src/flexrag/models/cohere_model.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import os
3 | from dataclasses import dataclass
4 | from typing import Optional
5 |
6 | import httpx
7 | import numpy as np
8 | from numpy import ndarray
9 | from omegaconf import MISSING
10 |
11 | from flexrag.utils import TIME_METER, Choices
12 |
13 | from .model_base import ENCODERS, EncoderBase
14 |
15 |
@dataclass
class CohereEncoderConfig:
    """Configuration for CohereEncoder.

    :param model: The model to use. Default is "embed-multilingual-v3.0".
    :type model: str
    :param input_type: Specifies the type of input passed to the model. Required for embedding models v3 and higher. Default is "search_document". Available options are "search_document", "search_query", "classification", "clustering", "image".
    :type input_type: str
    :param base_url: The base URL of the API. Default is None.
    :type base_url: Optional[str]
    :param api_key: The API key to use. Default is os.environ.get("COHERE_API_KEY", MISSING).
    :type api_key: str
    :param proxy: The proxy to use. Default is None.
    :type proxy: Optional[str]
    """

    model: str = "embed-multilingual-v3.0"
    input_type: Choices(  # type: ignore
        [
            "search_document",
            "search_query",
            "classification",
            "clustering",
            "image",
        ]
    ) = "search_document"
    base_url: Optional[str] = None
    # falls back to the COHERE_API_KEY environment variable; MISSING makes the
    # field mandatory when the variable is absent
    api_key: str = os.environ.get("COHERE_API_KEY", MISSING)
    proxy: Optional[str] = None
45 |
46 |
47 | @ENCODERS("cohere", config_class=CohereEncoderConfig)
48 | class CohereEncoder(EncoderBase):
49 | def __init__(self, cfg: CohereEncoderConfig):
50 | from cohere import ClientV2
51 |
52 | if cfg.proxy is not None:
53 | httpx_client = httpx.Client(proxies=cfg.proxy)
54 | else:
55 | httpx_client = None
56 | self.client = ClientV2(
57 | api_key=cfg.api_key,
58 | base_url=cfg.base_url,
59 | httpx_client=httpx_client,
60 | )
61 | self.model = cfg.model
62 | self.input_type = cfg.input_type
63 | return
64 |
65 | @TIME_METER("cohere_encode")
66 | def _encode(self, texts: list[str]) -> ndarray:
67 | r = self.client.embed(
68 | texts=texts,
69 | model=self.model,
70 | input_type=self.input_type,
71 | embedding_types=["float"],
72 | )
73 | embeddings = r.embeddings.float
74 | return np.array(embeddings)
75 |
76 | @TIME_METER("cohere_encode")
77 | async def async_encode(self, texts: list[str]):
78 | task = asyncio.create_task(
79 | asyncio.to_thread(
80 | self.client.embed,
81 | texts=texts,
82 | model=self.model,
83 | input_type=self.input_type,
84 | embedding_types=["float"],
85 | )
86 | )
87 | embeddings = (await task).embeddings.float
88 | return np.array(embeddings)
89 |
90 | @property
91 | def embedding_size(self) -> int:
92 | return self._data_template["dimension"]
93 |
--------------------------------------------------------------------------------
/src/flexrag/models/jina_model.py:
--------------------------------------------------------------------------------
1 | import os
2 | from dataclasses import dataclass
3 | from typing import Optional
4 |
5 | import httpx
6 | import numpy as np
7 | from numpy import ndarray
8 | from omegaconf import MISSING
9 |
10 | from flexrag.utils import TIME_METER, Choices
11 |
12 | from .model_base import EncoderBase, ENCODERS
13 |
14 |
@dataclass
class JinaEncoderConfig:
    """Configuration for JinaEncoder.

    :param model: The model to use. Default is "jina-embeddings-v3".
    :type model: str
    :param base_url: The base URL of the Jina embeddings API. Default is "https://api.jina.ai/v1/embeddings".
    :type base_url: str
    :param api_key: The API key for the Jina embeddings API.
    :type api_key: str
    :param dimensions: The dimension of the embeddings. Default is 1024.
    :type dimensions: int
    :param task: The task for the embeddings. Default is None. Available options are "retrieval.query", "retrieval.passage", "separation", "classification", and "text-matching".
    :type task: str
    :param proxy: The proxy to use. Defaults to None.
    :type proxy: Optional[str]
    """

    model: str = "jina-embeddings-v3"
    base_url: str = "https://api.jina.ai/v1/embeddings"
    # falls back to the JINA_API_KEY environment variable; MISSING makes the
    # field mandatory when the variable is absent
    api_key: str = os.environ.get("JINA_API_KEY", MISSING)
    dimensions: int = 1024
    task: Optional[
        Choices(  # type: ignore
            [
                "retrieval.query",
                "retrieval.passage",
                "separation",
                "classification",
                "text-matching",
            ]
        )
    ] = None
    proxy: Optional[str] = None
49 |
50 |
51 | @ENCODERS("jina", config_class=JinaEncoderConfig)
52 | class JinaEncoder(EncoderBase):
53 | def __init__(self, cfg: JinaEncoderConfig):
54 | # prepare client
55 | self.client = httpx.Client(
56 | headers={
57 | "Content-Type": "application/json",
58 | "Authorization": f"Bearer {cfg.api_key}",
59 | },
60 | proxy=cfg.proxy,
61 | base_url=cfg.base_url,
62 | follow_redirects=True,
63 | )
64 | self.async_client = httpx.AsyncClient(
65 | headers={
66 | "Content-Type": "application/json",
67 | "Authorization": f"Bearer {cfg.api_key}",
68 | },
69 | proxy=cfg.proxy,
70 | base_url=cfg.base_url,
71 | follow_redirects=True,
72 | )
73 | # prepare template
74 | self.data_template = {
75 | "model": cfg.model,
76 | "task": cfg.task,
77 | "dimensions": cfg.dimensions,
78 | "late_chunking": False,
79 | "embedding_type": "float",
80 | "input": [],
81 | }
82 | return
83 |
84 | @TIME_METER("jina_encode")
85 | def _encode(self, texts: list[str]) -> ndarray:
86 | data = self.data_template.copy()
87 | data["input"] = texts
88 | response = self.client.post("", json=data)
89 | response.raise_for_status()
90 | embeddings = [i["embedding"] for i in response.json()["data"]]
91 | return np.array(embeddings)
92 |
93 | @TIME_METER("jina_encode")
94 | async def async_encode(self, texts: list[str]) -> ndarray:
95 | data = self.data_template.copy()
96 | data["input"] = texts
97 | response = await self.async_client.post("", json=data)
98 | embeddings = [i["embedding"] for i in response.json()["data"]]
99 | return np.array(embeddings)
100 |
101 | @property
102 | def embedding_size(self) -> int:
103 | return self.data_template["dimension"]
104 |
--------------------------------------------------------------------------------
/src/flexrag/models/sentence_transformers_model.py:
--------------------------------------------------------------------------------
1 | import math
2 | from dataclasses import dataclass, field
3 | from typing import Any, Optional
4 |
5 | import numpy as np
6 | from omegaconf import MISSING
7 |
8 | from flexrag.utils import TIME_METER
9 |
10 | from .model_base import ENCODERS, EncoderBase
11 |
12 |
@dataclass
class SentenceTransformerEncoderConfig:
    """Configuration for SentenceTransformerEncoder.

    :param model_path: The path to the model. Required.
    :type model_path: str
    :param device_id: The device id to use. [] for CPU. Defaults to [].
    :type device_id: list[int]
    :param trust_remote_code: Whether to trust remote code. Defaults to False.
    :type trust_remote_code: bool
    :param task: The task to use. Defaults to None.
    :type task: Optional[str]
    :param prompt_name: The prompt name to use. Defaults to None.
    :type prompt_name: Optional[str]
    :param prompt: The prompt to use. Defaults to None.
    :type prompt: Optional[str]
    :param prompt_dict: The prompt dictionary to use. Defaults to None.
    :type prompt_dict: Optional[dict]
    :param normalize: Whether to normalize embeddings. Defaults to False.
    :type normalize: bool
    :param model_kwargs: Additional keyword arguments for loading the model. Defaults to {}.
    :type model_kwargs: dict[str, Any]
    """

    # MISSING: this field must be supplied by the user
    model_path: str = MISSING
    device_id: list[int] = field(default_factory=list)
    trust_remote_code: bool = False
    task: Optional[str] = None
    prompt_name: Optional[str] = None
    prompt: Optional[str] = None
    prompt_dict: Optional[dict] = None
    normalize: bool = False
    model_kwargs: dict[str, Any] = field(default_factory=dict)
46 |
47 |
48 | @ENCODERS("sentence_transformer", config_class=SentenceTransformerEncoderConfig)
49 | class SentenceTransformerEncoder(EncoderBase):
50 | def __init__(self, config: SentenceTransformerEncoderConfig) -> None:
51 | super().__init__()
52 | from sentence_transformers import SentenceTransformer
53 |
54 | self.devices = config.device_id
55 | self.model = SentenceTransformer(
56 | model_name_or_path=config.model_path,
57 | device=f"cuda:{config.device_id[0]}" if config.device_id else "cpu",
58 | trust_remote_code=config.trust_remote_code,
59 | backend="torch",
60 | prompts=config.prompt_dict,
61 | model_kwargs=config.model_kwargs,
62 | )
63 | if len(config.device_id) > 1:
64 | self.pool = self.model.start_multi_process_pool(
65 | target_devices=[f"cuda:{i}" for i in config.device_id]
66 | )
67 | else:
68 | self.pool = None
69 |
70 | # set args
71 | self.prompt_name = config.prompt_name
72 | self.task = config.task
73 | self.prompt = config.prompt
74 | self.normalize = config.normalize
75 | return
76 |
77 | @TIME_METER("st_encode")
78 | def _encode(self, texts: list[str], **kwargs) -> np.ndarray:
79 | args = {
80 | "sentences": texts,
81 | "batch_size": len(texts),
82 | "show_progress_bar": False,
83 | "convert_to_numpy": True,
84 | "normalize_embeddings": self.normalize,
85 | }
86 | if kwargs.get("task", self.task) is not None:
87 | args["task"] = self.task
88 | if kwargs.get("prompt_name", self.prompt_name) is not None:
89 | args["prompt_name"] = self.prompt_name
90 | if kwargs.get("prompt", self.prompt) is not None:
91 | args["prompt"] = self.prompt
92 | if (len(texts) >= len(self.devices) * 8) and (self.pool is not None):
93 | args["pool"] = self.pool
94 | args["batch_size"] = math.ceil(args["batch_size"] / len(self.devices))
95 | embeddings = self.model.encode_multi_process(**args)
96 | else:
97 | embeddings = self.model.encode(**args)
98 | return embeddings
99 |
100 | @property
101 | def embedding_size(self) -> int:
102 | return self.model.get_sentence_embedding_dimension()
103 |
--------------------------------------------------------------------------------
/src/flexrag/prompt/__init__.py:
--------------------------------------------------------------------------------
1 | from .template import load_template, ChatTemplate, HFTemplate
2 | from .prompt_base import ChatPrompt, ChatTurn, MultiModelChatPrompt, MultiModelChatTurn
3 |
4 |
# Explicit public API of the prompt subpackage.
__all__ = [
    "ChatPrompt",
    "ChatTurn",
    "load_template",
    "ChatTemplate",
    "HFTemplate",
    "MultiModelChatPrompt",
    "MultiModelChatTurn",
]
14 |
--------------------------------------------------------------------------------
/src/flexrag/ranker/__init__.py:
--------------------------------------------------------------------------------
1 | from .cohere_ranker import CohereRanker, CohereRankerConfig
2 | from .gpt_ranker import RankGPTRanker, RankGPTRankerConfig
3 | from .hf_ranker import (
4 | HFColBertRanker,
5 | HFColBertRankerConfig,
6 | HFCrossEncoderRanker,
7 | HFCrossEncoderRankerConfig,
8 | HFSeq2SeqRanker,
9 | HFSeq2SeqRankerConfig,
10 | )
11 | from .jina_ranker import JinaRanker, JinaRankerConfig
12 | from .mixedbread_ranker import MixedbreadRanker, MixedbreadRankerConfig
13 | from .voyage_ranker import VoyageRanker, VoyageRankerConfig
14 |
15 | from .ranker import RankerBase, RankerBaseConfig, RANKERS, RankingResult # isort: skip
16 |
17 |
# Umbrella config class aggregating every registered ranker.
RankerConfig = RANKERS.make_config(config_name="RankerConfig", default=None)


# Explicit public API of the ranker subpackage.
__all__ = [
    "RankerBase",
    "RankerBaseConfig",
    "RANKERS",
    "RankerConfig",
    "RankingResult",
    "HFCrossEncoderRanker",
    "HFCrossEncoderRankerConfig",
    "HFSeq2SeqRanker",
    "HFSeq2SeqRankerConfig",
    "HFColBertRanker",
    "HFColBertRankerConfig",
    "CohereRanker",
    "CohereRankerConfig",
    "JinaRanker",
    "JinaRankerConfig",
    "MixedbreadRanker",
    "MixedbreadRankerConfig",
    "VoyageRanker",
    "VoyageRankerConfig",
    "RankGPTRanker",
    "RankGPTRankerConfig",
]
44 |
--------------------------------------------------------------------------------
/src/flexrag/ranker/cohere_ranker.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | from dataclasses import dataclass
3 | from typing import Optional
4 |
5 | import httpx
6 | import numpy as np
7 | from omegaconf import MISSING
8 |
9 | from flexrag.utils import TIME_METER
10 |
11 | from .ranker import RankerBase, RankerBaseConfig, RANKERS
12 |
13 |
@dataclass
class CohereRankerConfig(RankerBaseConfig):
    """The configuration for the Cohere ranker.

    :param model: the model name of the ranker. Default is "rerank-multilingual-v3.0".
    :type model: str
    :param base_url: the base URL of the Cohere ranker. Default is None.
    :type base_url: Optional[str]
    :param api_key: the API key for the Cohere ranker. Required.
    :type api_key: str
    :param proxy: the proxy for the request. Default is None.
    :type proxy: Optional[str]
    """

    model: str = "rerank-multilingual-v3.0"
    base_url: Optional[str] = None
    # MISSING: this field must be supplied by the user
    api_key: str = MISSING
    proxy: Optional[str] = None
32 |
33 |
34 | @RANKERS("cohere", config_class=CohereRankerConfig)
35 | class CohereRanker(RankerBase):
36 | """CohereRanker: The ranker based on the Cohere API."""
37 |
38 | def __init__(self, cfg: CohereRankerConfig) -> None:
39 | super().__init__(cfg)
40 | from cohere import Client
41 |
42 | if cfg.proxy is not None:
43 | httpx_client = httpx.Client(proxies=cfg.proxy)
44 | else:
45 | httpx_client = None
46 | self.client = Client(
47 | api_key=cfg.api_key, base_url=cfg.base_url, httpx_client=httpx_client
48 | )
49 | self.model = cfg.model
50 | return
51 |
52 | @TIME_METER("cohere_rank")
53 | def _rank(self, query: str, candidates: list[str]) -> tuple[np.ndarray, np.ndarray]:
54 | result = self.client.rerank(
55 | query=query,
56 | documents=candidates,
57 | model=self.model,
58 | top_n=len(candidates),
59 | )
60 | scores = [i.relevance_score for i in result.results]
61 | return None, scores
62 |
63 | @TIME_METER("cohere_rank")
64 | async def _async_rank(self, query: str, candidates: list[str]):
65 | result = await asyncio.create_task(
66 | asyncio.to_thread(
67 | self.client.rerank,
68 | query=query,
69 | documents=candidates,
70 | model=self.model,
71 | top_n=len(candidates),
72 | )
73 | )
74 | scores = [i.relevance_score for i in result.results]
75 | return None, scores
76 |
--------------------------------------------------------------------------------
/src/flexrag/ranker/jina_ranker.py:
--------------------------------------------------------------------------------
1 | import os
2 | from dataclasses import dataclass
3 | from typing import Optional
4 |
5 | import numpy as np
6 | import httpx
7 | from omegaconf import MISSING
8 |
9 | from flexrag.utils import TIME_METER
10 |
11 | from .ranker import RankerBase, RankerBaseConfig, RANKERS
12 |
13 |
@dataclass
class JinaRankerConfig(RankerBaseConfig):
    """The configuration for the Jina ranker.

    :param model: the model name of the ranker. Default is "jina-reranker-v2-base-multilingual".
    :type model: str
    :param base_url: the base URL of the Jina ranker. Default is "https://api.jina.ai/v1/rerank".
    :type base_url: str
    :param api_key: the API key for the Jina ranker. Required.
    :type api_key: str
    :param proxy: The proxy to use. Defaults to None.
    :type proxy: Optional[str]
    """

    model: str = "jina-reranker-v2-base-multilingual"
    base_url: str = "https://api.jina.ai/v1/rerank"
    # falls back to the JINA_API_KEY environment variable; MISSING makes the
    # field mandatory when the variable is absent
    api_key: str = os.environ.get("JINA_API_KEY", MISSING)
    proxy: Optional[str] = None
32 |
33 |
34 | @RANKERS("jina", config_class=JinaRankerConfig)
35 | class JinaRanker(RankerBase):
36 | """JinaRanker: The ranker based on the Jina API."""
37 |
38 | def __init__(self, cfg: JinaRankerConfig) -> None:
39 | super().__init__(cfg)
40 | # prepare client
41 | self.client = httpx.Client(
42 | base_url=cfg.base_url,
43 | headers={
44 | "Content-Type": "application/json",
45 | "Authorization": f"Bearer {cfg.api_key}",
46 | },
47 | proxy=cfg.proxy,
48 | follow_redirects=True,
49 | )
50 | self.async_client = httpx.AsyncClient(
51 | base_url=cfg.base_url,
52 | headers={
53 | "Content-Type": "application/json",
54 | "Authorization": f"Bearer {cfg.api_key}",
55 | },
56 | proxy=cfg.proxy,
57 | follow_redirects=True,
58 | )
59 |
60 | # prepare data template
61 | self.data_template = {
62 | "model": cfg.model,
63 | "query": "",
64 | "top_n": 0,
65 | "documents": [],
66 | }
67 | return
68 |
69 | @TIME_METER("jina_rank")
70 | def _rank(self, query: str, candidates: list[str]) -> tuple[np.ndarray, np.ndarray]:
71 | data = self.data_template.copy()
72 | data["query"] = query
73 | data["documents"] = candidates
74 | data["top_n"] = len(candidates)
75 | response = self.client.post("", json=data)
76 | response.raise_for_status()
77 | scores = [i["relevance_score"] for i in response.json()["results"]]
78 | return None, scores
79 |
80 | @TIME_METER("jina_rank")
81 | async def _async_rank(
82 | self, query: str, candidates: list[str]
83 | ) -> tuple[np.ndarray, np.ndarray]:
84 | data = self.data_template.copy()
85 | data["query"] = query
86 | data["documents"] = candidates
87 | data["top_n"] = len(candidates)
88 | response = await self.async_client.post("", json=data)
89 | response.raise_for_status()
90 | scores = [i["relevance_score"] for i in response.json()["results"]]
91 | return None, scores
92 |
--------------------------------------------------------------------------------
/src/flexrag/ranker/mixedbread_ranker.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | from dataclasses import dataclass
3 | from typing import Optional
4 |
5 | import httpx
6 | import numpy as np
7 | from omegaconf import MISSING
8 |
9 | from flexrag.utils import TIME_METER
10 |
11 | from .ranker import RankerBase, RankerBaseConfig, RANKERS
12 |
13 |
@dataclass
class MixedbreadRankerConfig(RankerBaseConfig):
    """The configuration for the Mixedbread ranker.

    :param model: the model name of the ranker. Default is "mxbai-rerank-large-v1".
    :type model: str
    :param api_key: the API key for the Mixedbread ranker. Required.
    :type api_key: str
    :param base_url: the base URL of the Mixedbread ranker. Default is None.
    :type base_url: Optional[str]
    :param proxy: the proxy for the request. Default is None.
    :type proxy: Optional[str]
    """

    model: str = "mxbai-rerank-large-v1"
    base_url: Optional[str] = None
    # MISSING: this field must be supplied by the user
    api_key: str = MISSING
    proxy: Optional[str] = None
32 |
33 |
@RANKERS("mixedbread", config_class=MixedbreadRankerConfig)
class MixedbreadRanker(RankerBase):
    """MixedbreadRanker: The ranker based on the Mixedbread API."""

    def __init__(self, cfg: MixedbreadRankerConfig) -> None:
        super().__init__(cfg)
        from mixedbread_ai.client import MixedbreadAI

        # `proxies` was deprecated in httpx 0.26 and removed in 0.28; use the
        # `proxy` keyword instead (consistent with WikipediaRetriever).
        if cfg.proxy is not None:
            httpx_client = httpx.Client(proxy=cfg.proxy)
        else:
            httpx_client = None
        self.client = MixedbreadAI(
            api_key=cfg.api_key, base_url=cfg.base_url, httpx_client=httpx_client
        )
        self.model = cfg.model
        return

    @TIME_METER("mixedbread_rank")
    def _rank(self, query: str, candidates: list[str]) -> tuple[np.ndarray, np.ndarray]:
        """Rank ``candidates`` against ``query`` using the Mixedbread API.

        :param query: the query text.
        :param candidates: the candidate passages to rank.
        :return: (None, scores) — indices are not computed here; ``scores``
            holds one relevance score per candidate.
        """
        result = self.client.reranking(
            query=query,
            input=candidates,
            model=self.model,
            top_k=len(candidates),  # score every candidate
        )
        scores = [i.score for i in result.data]
        return None, scores

    @TIME_METER("mixedbread_rank")
    async def _async_rank(
        self, query: str, candidates: list[str]
    ) -> tuple[np.ndarray, np.ndarray]:
        """Asynchronous variant of :meth:`_rank`.

        The Mixedbread client is synchronous, so the blocking call is
        offloaded to a worker thread via ``asyncio.to_thread``. The previous
        ``asyncio.create_task`` wrapper was redundant (awaiting the coroutine
        directly is equivalent) and has been removed.
        """
        result = await asyncio.to_thread(
            self.client.reranking,
            query=query,
            input=candidates,
            model=self.model,
            top_k=len(candidates),
        )
        scores = [i.score for i in result.data]
        return None, scores
78 |
--------------------------------------------------------------------------------
/src/flexrag/ranker/ranker_prompts/rankgpt_prompt.json:
--------------------------------------------------------------------------------
1 | {
2 | "system": {
3 | "role": "system",
4 | "content": "You are RankGPT, an intelligent assistant that can rank passages based on their relevancy to the query."
5 | },
6 | "history": [
7 | {
8 | "role": "user",
9 | "content": "I will provide you with {num} passages, each indicated by number identifier []. \nRank the passages based on their relevance to query: {query}."
10 | },
11 | {
12 | "role": "assistant",
13 | "content": "Okay, please provide the passages."
14 | },
15 | {
16 | "role": "user",
17 | "content": "Search Query: {query}. \nRank the {num} passages above based on their relevance to the search query. The passages should be listed in descending order using identifiers. The most relevant passages should be listed first. The output format should be [] > [], e.g., [1] > [2]. Only response the ranking results, do not say any word or explain."
18 | }
19 | ],
20 | "demonstrations": []
21 | }
--------------------------------------------------------------------------------
/src/flexrag/ranker/voyage_ranker.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | from dataclasses import dataclass
3 |
4 | import numpy as np
5 | from omegaconf import MISSING
6 |
7 | from flexrag.utils import TIME_METER
8 |
9 | from .ranker import RankerBase, RankerBaseConfig, RANKERS
10 |
11 |
@dataclass
class VoyageRankerConfig(RankerBaseConfig):
    """The configuration for the Voyage ranker.

    :param model: the model name of the ranker. Default is "rerank-2".
    :type model: str
    :param api_key: the API key for the Voyage ranker. Required.
    :type api_key: str
    :param timeout: the timeout for the request. Default is 3.0.
    :type timeout: float
    :param max_retries: the maximum number of retries. Default is 3.
    :type max_retries: int
    """

    model: str = "rerank-2"
    # MISSING makes this field mandatory at config-resolution time (omegaconf).
    api_key: str = MISSING
    timeout: float = 3.0
    max_retries: int = 3
30 |
31 |
@RANKERS("voyage", config_class=VoyageRankerConfig)
class VoyageRanker(RankerBase):
    """VoyageRanker: The ranker based on the Voyage API."""

    def __init__(self, cfg: VoyageRankerConfig) -> None:
        super().__init__(cfg)
        from voyageai import Client

        self.client = Client(
            api_key=cfg.api_key, max_retries=cfg.max_retries, timeout=cfg.timeout
        )
        self.model = cfg.model
        return

    @TIME_METER("voyage_rank")
    def _rank(self, query: str, candidates: list[str]) -> tuple[np.ndarray, np.ndarray]:
        """Rank ``candidates`` against ``query`` using the Voyage API.

        :param query: the query text.
        :param candidates: the candidate passages to rank.
        :return: (None, scores) — indices are not computed here; ``scores``
            holds one relevance score per candidate.
        """
        result = self.client.rerank(
            query=query,
            documents=candidates,
            model=self.model,
            top_k=len(candidates),  # score every candidate
        )
        scores = [i.relevance_score for i in result.results]
        return None, scores

    @TIME_METER("voyage_rank")
    async def _async_rank(
        self, query: str, candidates: list[str]
    ) -> tuple[np.ndarray, np.ndarray]:
        """Asynchronous variant of :meth:`_rank`.

        The Voyage client is synchronous, so the blocking call is offloaded
        to a worker thread via ``asyncio.to_thread``. The previous
        ``asyncio.create_task`` wrapper was redundant (awaiting the coroutine
        directly is equivalent) and has been removed.
        """
        result = await asyncio.to_thread(
            self.client.rerank,
            query=query,
            documents=candidates,
            model=self.model,
            top_k=len(candidates),
        )
        scores = [i.relevance_score for i in result.results]
        return None, scores
72 |
--------------------------------------------------------------------------------
/src/flexrag/retriever/__init__.py:
--------------------------------------------------------------------------------
1 | from .bm25s_retriever import BM25SRetriever, BM25SRetrieverConfig
2 | from .dense_retriever import DenseRetriever, DenseRetrieverConfig
3 | from .elastic_retriever import ElasticRetriever, ElasticRetrieverConfig
4 | from .retriever_base import (
5 | EditableRetriever,
6 | EditableRetrieverConfig,
7 | RetrieverBase,
8 | RetrieverBaseConfig,
9 | LocalRetriever,
10 | LocalRetrieverConfig,
11 | RETRIEVERS,
12 | )
13 | from .typesense_retriever import TypesenseRetriever, TypesenseRetrieverConfig
14 | from .web_retrievers import (
15 | SimpleWebRetriever,
16 | SimpleWebRetrieverConfig,
17 | WikipediaRetriever,
18 | WikipediaRetrieverConfig,
19 | )
20 | from .hyde_retriever import HydeRetriever, HydeRetrieverConfig
21 |
22 |
# Aggregated config dataclass covering every registered retriever.
RetrieverConfig = RETRIEVERS.make_config(config_name="RetrieverConfig", default=None)


# Public API of the ``flexrag.retriever`` package.
__all__ = [
    "BM25SRetriever",
    "BM25SRetrieverConfig",
    "EditableRetriever",
    "EditableRetrieverConfig",
    "LocalRetriever",
    "LocalRetrieverConfig",
    "RetrieverBase",
    "RetrieverBaseConfig",
    "DenseRetriever",
    "DenseRetrieverConfig",
    "ElasticRetriever",
    "ElasticRetrieverConfig",
    "TypesenseRetriever",
    "TypesenseRetrieverConfig",
    "SimpleWebRetriever",
    "SimpleWebRetrieverConfig",
    "RETRIEVERS",
    "RetrieverConfig",
    "WikipediaRetriever",
    "WikipediaRetrieverConfig",
    "HydeRetriever",
    "HydeRetrieverConfig",
]
50 |
--------------------------------------------------------------------------------
/src/flexrag/retriever/hyde_retriever.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 |
3 | from flexrag.common_dataclass import RetrievedContext
4 | from flexrag.models import GENERATORS, GeneratorBase, GeneratorConfig
5 | from flexrag.utils import TIME_METER, Choices
6 |
7 | from .dense_retriever import DenseRetriever, DenseRetrieverConfig
8 | from .retriever_base import RETRIEVERS
9 |
10 |
class HydeRewriter:
    """Rewrites a query into a hypothetical passage (HyDE) using a generator.

    The prompt template is selected by ``task``. The ``MR_TYDI`` template is
    the only one with two placeholders: the target language and the question.
    """

    Prompts = {
        "WEB_SEARCH": "Please write a passage to answer the question.\nQuestion: {}\nPassage:",
        "SCIFACT": "Please write a scientific paper passage to support/refute the claim.\nClaim: {}\nPassage:",
        "ARGUANA": "Please write a counter argument for the passage.\nPassage: {}\nCounter Argument:",
        "TREC_COVID": "Please write a scientific paper passage to answer the question.\nQuestion: {}\nPassage:",
        "FIQA": "Please write a financial article passage to answer the question.\nQuestion: {}\nPassage:",
        "DBPEDIA_ENTITY": "Please write a passage to answer the question.\nQuestion: {}\nPassage:",
        "TREC_NEWS": "Please write a news passage about the topic.\nTopic: {}\nPassage:",
        "MR_TYDI": "Please write a passage in {} to answer the question in detail.\nQuestion: {}\nPassage:",
    }

    def __init__(self, generator: GeneratorBase, task: str, language: str = "en"):
        self.task = task
        self.language = language
        self.generator = generator
        return

    def rewrite(self, queries: list[str] | str) -> list[str]:
        """Rewrite one or more queries into hypothetical passages.

        :param queries: a single query or a list of queries.
        :return: one rewritten query per input query.
        """
        if isinstance(queries, str):
            queries = [queries]
        # BUGFIX: the MR_TYDI template has two positional placeholders
        # (language, question); formatting it with only the query used to
        # raise "IndexError: Replacement index 1 out of range".
        if self.task == "MR_TYDI":
            prompts = [
                self.Prompts[self.task].format(self.language, q) for q in queries
            ]
        else:
            prompts = [self.Prompts[self.task].format(q) for q in queries]
        new_queries = [q[0] for q in self.generator.generate(prompts)]
        return new_queries
35 |
36 |
@dataclass
class HydeRetrieverConfig(DenseRetrieverConfig, GeneratorConfig):
    """Configuration class for HydeRetriever.

    Inherits every field of ``DenseRetrieverConfig`` (the underlying dense
    retriever) and ``GeneratorConfig`` (the generator used for rewriting).

    :param task: Task for rewriting the query. Default: "WEB_SEARCH".
        Available options: "WEB_SEARCH", "SCIFACT", "ARGUANA", "TREC_COVID", "FIQA", "DBPEDIA_ENTITY", "TREC_NEWS", "MR_TYDI".
    :type task: str
    :param language: Language for rewriting. Default: "en".
    :type language: str
    """

    task: Choices(HydeRewriter.Prompts.keys()) = "WEB_SEARCH" # type: ignore
    language: str = "en"
50 |
51 |
@RETRIEVERS("hyde", config_class=HydeRetrieverConfig)
class HydeRetriever(DenseRetriever):
    """HydeRetriever is a retriever that rewrites the query before searching.

    The original paper is available at https://aclanthology.org/2023.acl-long.99/.
    """

    def __init__(self, cfg: HydeRetrieverConfig, no_check=False):
        super().__init__(cfg, no_check)
        # Load the generator that produces the hypothetical passages.
        generator = GENERATORS.load(cfg)
        self.rewriter = HydeRewriter(
            generator=generator, task=cfg.task, language=cfg.language
        )
        return

    @TIME_METER("hyde_retriever", "search")
    def search_batch(
        self,
        query: list[str],
        **search_kwargs,
    ) -> list[list[RetrievedContext]]:
        """Rewrite each query into a hypothetical passage, then dense-search.

        :param query: the queries to search for.
        :param search_kwargs: extra arguments forwarded to the dense retriever.
        :return: retrieved contexts, one list per input query.
        """
        new_query = self.rewriter.rewrite(query)
        return super().search_batch(new_query, **search_kwargs)
75 |
--------------------------------------------------------------------------------
/src/flexrag/retriever/index/__init__.py:
--------------------------------------------------------------------------------
1 | from .annoy_index import AnnoyIndex, AnnoyIndexConfig
2 | from .faiss_index import FaissIndex, FaissIndexConfig
3 | from .index_base import DenseIndexBase, DenseIndexBaseConfig, DENSE_INDEX
4 | from .scann_index import ScaNNIndex, ScaNNIndexConfig
5 |
6 |
# Aggregated config dataclass covering every registered dense index
# (defaults to the faiss backend).
DenseIndexConfig = DENSE_INDEX.make_config(
    default="faiss", config_name="DenseIndexConfig"
)


# Public API of the ``flexrag.retriever.index`` package.
__all__ = [
    "AnnoyIndex",
    "AnnoyIndexConfig",
    "FaissIndex",
    "FaissIndexConfig",
    "ScaNNIndex",
    "ScaNNIndexConfig",
    "DenseIndexBase",
    "DenseIndexBaseConfig",
    "DENSE_INDEX",
    "DenseIndexConfig",
]
24 |
--------------------------------------------------------------------------------
/src/flexrag/retriever/web_retrievers/__init__.py:
--------------------------------------------------------------------------------
1 | from .utils import WebResource
2 | from .web_downloader import (
3 | WEB_DOWNLOADERS,
4 | PlaywrightWebDownloader,
5 | PlaywrightWebDownloaderConfig,
6 | SimpleWebDownloader,
7 | SimpleWebDownloaderConfig,
8 | WebDownloaderBase,
9 | WebDownloaderBaseConfig,
10 | WebDownloaderConfig,
11 | )
12 | from .web_reader import (
13 | WEB_READERS,
14 | JinaReader,
15 | JinaReaderConfig,
16 | JinaReaderLM,
17 | JinaReaderLMConfig,
18 | ScreenshotWebReader,
19 | ScreenshotWebReaderConfig,
20 | SnippetWebReader,
21 | WebReaderBase,
22 | WebReaderConfig,
23 | )
24 | from .web_retriever import (
25 | SimpleWebRetriever,
26 | SimpleWebRetrieverConfig,
27 | WebRetrieverBase,
28 | WebRetrieverBaseConfig,
29 | )
30 | from .web_seeker import (
31 | SEARCH_ENGINES,
32 | WEB_SEEKERS,
33 | BingEngine,
34 | BingEngineConfig,
35 | DuckDuckGoEngine,
36 | DuckDuckGoEngineConfig,
37 | GoogleEngine,
38 | GoogleEngineConfig,
39 | SerpApi,
40 | SerpApiConfig,
41 | WebSeekerBase,
42 | SearchEngineConfig,
43 | WebSeekerConfig,
44 | )
45 | from .wikipedia_retriever import WikipediaRetriever, WikipediaRetrieverConfig
46 |
# Public API of the ``flexrag.retriever.web_retrievers`` package.
__all__ = [
    "WebResource",
    "WEB_DOWNLOADERS",
    "PlaywrightWebDownloader",
    "PlaywrightWebDownloaderConfig",
    "WebDownloaderBase",
    "WebDownloaderBaseConfig",
    "SimpleWebDownloader",
    "SimpleWebDownloaderConfig",
    "WebDownloaderConfig",
    "WEB_READERS",
    "JinaReader",
    "JinaReaderConfig",
    "JinaReaderLM",
    "JinaReaderLMConfig",
    "ScreenshotWebReader",
    "ScreenshotWebReaderConfig",
    "SnippetWebReader",
    "WebReaderBase",
    "WebReaderConfig",
    "SimpleWebRetriever",
    "SimpleWebRetrieverConfig",
    "WebRetrieverBase",
    "WebRetrieverBaseConfig",
    "SEARCH_ENGINES",
    "WEB_SEEKERS",
    "BingEngine",
    "BingEngineConfig",
    "DuckDuckGoEngine",
    "DuckDuckGoEngineConfig",
    "GoogleEngine",
    "GoogleEngineConfig",
    "SerpApi",
    "SerpApiConfig",
    "WebSeekerBase",
    "SearchEngineConfig",
    "WebSeekerConfig",
    "WikipediaRetriever",
    "WikipediaRetrieverConfig",
]
87 |
--------------------------------------------------------------------------------
/src/flexrag/retriever/web_retrievers/utils.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass, field
2 | from typing import Any, Optional
3 |
4 |
@dataclass
class WebResource:
    """The web resource dataclass.
    ``WebResource`` is the fundamental component for information transmission in the ``web_retrievers`` module of FlexRAG.
    The ``WebSeeker`` retrieves the corresponding ``WebResource`` based on the user's query,
    while the ``WebDownloader`` downloads the resource based on the URL in the ``WebResource`` and stores it in the ``data`` field of the ``WebResource``.
    The ``WebReader`` then converts the ``data`` field of the ``WebResource`` into a LLM friendly format and returns the ``RetrievedContext``.

    :param url: The URL of the resource.
    :type url: str
    :param query: The query for the resource. Default is None.
    :type query: Optional[str]
    :param metadata: The metadata of the resource, often provided by the WebSeeker. Default is {}.
    :type metadata: dict
    :param data: The content of the resource, often filled by the WebDownloader. Default is None.
    :type data: Any
    """

    url: str
    query: Optional[str] = None
    metadata: dict = field(default_factory=dict)
    data: Any = None
27 |
--------------------------------------------------------------------------------
/src/flexrag/retriever/web_retrievers/wikipedia_retriever.py:
--------------------------------------------------------------------------------
1 | import time
2 | from dataclasses import dataclass
3 | from typing import Optional
4 |
5 | import httpx
6 | from bs4 import BeautifulSoup
7 |
8 | from flexrag.common_dataclass import RetrievedContext
9 | from flexrag.utils import LOGGER_MANAGER, SimpleProgressLogger
10 |
11 | from ..retriever_base import RETRIEVERS, RetrieverBase, RetrieverBaseConfig
12 |
13 | logger = LOGGER_MANAGER.get_logger("flexrag.retrievers.web_retriever")
14 |
15 |
@dataclass
class WikipediaRetrieverConfig(RetrieverBaseConfig):
    """The configuration for the ``WikipediaRetriever``.

    :param search_url: The search URL for Wikipedia.
        Default is "https://en.wikipedia.org/w/index.php?search=".
        Change the host to query a non-English Wikipedia.
    :type search_url: str
    :param proxy: The proxy to use. Default is None.
    :type proxy: Optional[str]
    """

    search_url: str = "https://en.wikipedia.org/w/index.php?search="
    proxy: Optional[str] = None
29 |
30 |
@RETRIEVERS("wikipedia", config_class=WikipediaRetrieverConfig)
class WikipediaRetriever(RetrieverBase):
    """WikipediaRetriever retrieves information from Wikipedia directly.
    Adapted from https://github.com/ysymyth/ReAct"""

    name = "wikipedia"

    def __init__(self, cfg: WikipediaRetrieverConfig):
        super().__init__(cfg)
        # set basic configs
        self.search_url = cfg.search_url
        self.client = httpx.Client(proxy=cfg.proxy)
        return

    def search(
        self,
        query: list[str] | str,
        delay: float = 0.1,
        **search_kwargs,
    ) -> list[list[RetrievedContext]]:
        """Search Wikipedia for each query sequentially.

        :param query: a single query or a list of queries.
        :param delay: seconds slept before each request (simple rate limiting).
        :param search_kwargs: extra arguments forwarded to :meth:`search_item`.
        :return: a single-element list of RetrievedContext per input query.
        """
        if isinstance(query, str):
            query = [query]

        # search & parse
        results = []
        p_logger = SimpleProgressLogger(logger, len(query), self.log_interval)
        for q in query:
            time.sleep(delay)
            p_logger.update(1, "Searching")
            results.append([self.search_item(q, **search_kwargs)])
        return results

    def search_item(self, query: str, **kwargs) -> RetrievedContext:
        """Fetch and parse the Wikipedia search page for a single query.

        If the query does not resolve to a page, only ``similar_entities``
        is populated; for disambiguation pages ("may refer to:") the search
        is retried with the query wrapped in brackets.
        """
        search_url = self.search_url + query.replace(" ", "+")
        response_text = self.client.get(search_url).text

        soup = BeautifulSoup(response_text, features="html.parser")
        result_divs = soup.find_all("div", {"class": "mw-search-result-heading"})
        if result_divs:  # mismatch
            # No exact page hit: Wikipedia returned a list of search results.
            similar_entities = [
                self._clear_str(div.get_text().strip()) for div in result_divs
            ]
            page_content = None
            summary = None
        else:
            similar_entities = []
            page = [
                p.get_text().strip() for p in soup.find_all("p") + soup.find_all("ul")
            ]
            if any("may refer to:" in p for p in page):
                # Disambiguation page: retry with the bracketed entity name.
                return self.search_item("[" + query + "]")
            else:  # concatenate all paragraphs
                page_content = ""
                for p in page:
                    # Skip fragments shorter than three words.
                    if len(p.split(" ")) > 2:
                        page_content += self._clear_str(p)
                        if not p.endswith("\n"):
                            page_content += "\n"
                summary = self._get_summary(page_content)
        return RetrievedContext(
            retriever=self.name,
            query=query,
            data={
                "raw_page": response_text,
                "page_content": page_content,
                "summary": summary,
                "similar_entities": similar_entities,
            },
            source=search_url,
        )

    def _clear_str(self, text: str) -> str:
        # Round-trips through unicode-escape/latin1 to repair mojibake in the
        # scraped text. NOTE(review): this assumes mis-decoded UTF-8 input and
        # can raise on some already-clean strings — verify on non-ASCII pages.
        return text.encode().decode("unicode-escape").encode("latin1").decode("utf-8")

    def _get_summary(self, page: str) -> str:
        """Return the first five sentences of ``page`` as a summary."""
        # find all paragraphs
        paragraphs = page.split("\n")
        paragraphs = [p.strip() for p in paragraphs if p.strip()]

        # find all sentence
        sentences = []
        for p in paragraphs:
            sentences += p.split(". ")
        sentences = [s.strip() + "." for s in sentences if s.strip()]
        summary = " ".join(sentences[:5])
        return summary.replace("\\n", "")

    @property
    def fields(self):
        # Keys available in the ``data`` dict of the returned contexts.
        return ["raw_page", "page_content", "summary", "similar_entities"]
121 |
--------------------------------------------------------------------------------
/src/flexrag/text_process/__init__.py:
--------------------------------------------------------------------------------
1 | from .basic_filters import ExactDeduplicate, LengthFilter, LengthFilterConfig
2 | from .basic_processors import (
3 | AnswerSimplifier,
4 | ChineseSimplifier,
5 | Lowercase,
6 | TokenNormalizer,
7 | TokenNormalizerConfig,
8 | Truncator,
9 | TruncatorConfig,
10 | Unifier,
11 | )
12 | from .pipeline import TextProcessPipeline, TextProcessPipelineConfig
13 | from .processor import PROCESSORS, Processor, TextUnit
14 |
# Public API of the ``flexrag.text_process`` package.
__all__ = [
    "TextProcessPipeline",
    "TextProcessPipelineConfig",
    "PROCESSORS",
    "Processor",
    "TextUnit",
    "TokenNormalizerConfig",
    "TokenNormalizer",
    "ChineseSimplifier",
    "Lowercase",
    "Unifier",
    "TruncatorConfig",
    "Truncator",
    "AnswerSimplifier",
    "ExactDeduplicate",
    "LengthFilter",
    "LengthFilterConfig",
]
33 |
--------------------------------------------------------------------------------
/src/flexrag/text_process/basic_filters.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass, field
2 | from typing import Optional
3 |
4 | from .processor import PROCESSORS, Processor, TextUnit
5 | from .utils import UnifiedTokenizer, UTokenizerConfig
6 |
7 |
@PROCESSORS("exact_deduplicate")
class ExactDeduplicate(Processor):
    """Drops text units whose exact content has already been seen."""

    def __init__(self) -> None:
        # Contents of every unit observed so far.
        self.seen = set()
        return

    def process(self, input_text: TextUnit) -> TextUnit:
        """Mark ``input_text`` as dropped if its content is a duplicate."""
        content = input_text.content
        if content in self.seen:
            input_text.reserved = False
        else:
            self.seen.add(content)
        return input_text
19 |
20 |
@dataclass
class LengthFilterConfig:
    """Configuration for :class:`LengthFilter`; a None bound is not enforced."""

    # Upper/lower bounds on token count (requires a tokenizer).
    max_tokens: Optional[int] = None
    min_tokens: Optional[int] = None
    # Upper/lower bounds on character count.
    max_chars: Optional[int] = None
    min_chars: Optional[int] = None
    # Upper/lower bounds on UTF-8 byte count.
    max_bytes: Optional[int] = None
    min_bytes: Optional[int] = None
    # Tokenizer used only when token bounds are set.
    tokenizer_config: UTokenizerConfig = field(default_factory=UTokenizerConfig)
30 |
31 |
@PROCESSORS("length_filter", config_class=LengthFilterConfig)
class LengthFilter(Processor):
    """Drops text units whose token / character / byte length violates the
    bounds configured in :class:`LengthFilterConfig`."""

    def __init__(self, cfg: LengthFilterConfig) -> None:
        super().__init__()
        self.max_tokens = cfg.max_tokens
        self.min_tokens = cfg.min_tokens
        self.max_chars = cfg.max_chars
        self.min_chars = cfg.min_chars
        self.max_bytes = cfg.max_bytes
        self.min_bytes = cfg.min_bytes
        # The tokenizer is only needed when a token-based bound is configured.
        needs_tokenizer = cfg.max_tokens is not None or cfg.min_tokens is not None
        self.tokenizer = UnifiedTokenizer(cfg.tokenizer_config) if needs_tokenizer else None
        return

    @staticmethod
    def _out_of_bounds(size: int, lower, upper) -> bool:
        """Return True when ``size`` violates either configured bound."""
        if upper is not None and size > upper:
            return True
        if lower is not None and size < lower:
            return True
        return False

    def process(self, input_text: TextUnit) -> TextUnit:
        """Clear the ``reserved`` flag when any length bound is violated."""
        content = input_text.content
        if self.tokenizer is not None:
            n_tokens = len(self.tokenizer.tokenize(content))
            if self._out_of_bounds(n_tokens, self.min_tokens, self.max_tokens):
                input_text.reserved = False
        if self._out_of_bounds(len(content), self.min_chars, self.max_chars):
            input_text.reserved = False
        n_bytes = len(content.encode("utf-8"))
        if self._out_of_bounds(n_bytes, self.min_bytes, self.max_bytes):
            input_text.reserved = False
        return input_text
70 |
--------------------------------------------------------------------------------
/src/flexrag/text_process/pipeline.py:
--------------------------------------------------------------------------------
1 | from flexrag.utils import TIME_METER
2 |
3 | from .processor import PROCESSORS, Processor, TextUnit
4 |
5 |
# Aggregated config dataclass covering every registered processor;
# ``allow_multiple`` lets a pipeline chain several processors in order.
TextProcessPipelineConfig = PROCESSORS.make_config(
    allow_multiple=True, config_name="TextProcessPipelineConfig"
)
9 |
10 |
class TextProcessPipeline:
    """Runs a text string through an ordered chain of registered processors."""

    def __init__(self, cfg: TextProcessPipelineConfig) -> None:  # type: ignore
        # load processors
        self.processors: list[Processor] = PROCESSORS.load(cfg)
        return

    @TIME_METER("text_process_pipeline")
    def __call__(self, text: str, return_detail: bool = False) -> str | TextUnit | None:
        """Process ``text`` through all processors in order.

        :param text: the raw text to process.
        :param return_detail: if True, return the full TextUnit instead of str.
        :return: the processed text, the TextUnit (when ``return_detail``),
            or None when a filter dropped the text.
        """
        unit = TextUnit(content=text)
        for processor in self.processors:
            unit = processor(unit)
            # Stop early once a filter marks the unit as dropped.
            if not unit.reserved:
                break
        if return_detail:
            return unit
        return unit.content if unit.reserved else None

    def __contains__(self, processor: Processor | str) -> bool:
        """Check membership by processor instance or by registered name."""
        if isinstance(processor, str):
            return any(
                isinstance(p, PROCESSORS[processor]["item"]) for p in self.processors
            )
        return processor in self.processors

    def __getitem__(self, processor: str | int) -> Processor:
        """Look up a processor by position (int) or registered name (str)."""
        if isinstance(processor, int):
            return self.processors[processor]
        assert isinstance(processor, str), "str or int is required"
        for p in self.processors:
            if isinstance(p, PROCESSORS[processor]["item"]):
                return p
        raise KeyError(f"Processor {processor} not found in the pipeline")

    def __repr__(self) -> str:
        return f"Pipeline({[p.name for p in self.processors]})"
46 |
--------------------------------------------------------------------------------
/src/flexrag/text_process/processor.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from dataclasses import dataclass, field
3 |
4 | from flexrag.utils import Register
5 |
6 |
@dataclass
class TextUnit:
    """The unit of text flowing through a processing pipeline."""

    # The text content being processed.
    content: str
    # Whether the unit is kept; filters set this to False to drop it.
    reserved: bool = True
    # Names of the processors that have handled this unit, in order.
    processed_by: list[str] = field(default_factory=list)
12 |
13 |
class Processor(ABC):
    """Base class for text processors; subclasses implement :meth:`process`."""

    def __call__(self, input_text: TextUnit) -> TextUnit:
        """Process the input text.
        If the text unit is filtered out by this processor, the ``reserved``
        flag of the input TextUnit will be set to False.

        :param input_text: The input text to process.
        :type input_text: TextUnit
        :return: The processed text.
        :rtype: TextUnit
        """
        # Record the processing history before delegating to the subclass.
        input_text.processed_by.append(self.name)
        return self.process(input_text)

    @abstractmethod
    def process(self, input_text: TextUnit) -> TextUnit:
        """Process one text unit; implemented by subclasses."""
        return

    @property
    def name(self):
        # Defaults to the subclass name; used for the processing history.
        return self.__class__.__name__
34 |
35 |
# Global registry of text processors, keyed by their registered names.
PROCESSORS = Register[Processor]("processor")
37 |
--------------------------------------------------------------------------------
/src/flexrag/text_process/utils.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import Optional
3 |
4 | from flexrag.utils import Choices
5 |
6 |
@dataclass
class UTokenizerConfig:
    """Configuration for :class:`UnifiedTokenizer`."""

    # Backend: HuggingFace ("hf"), tiktoken, or sacremoses ("moses").
    tokenizer_type: Choices(["hf", "tiktoken", "moses"]) = "moses"  # type: ignore
    # Model path/name for the HF tokenizer (used when type is "hf").
    hf_tokenizer_path: Optional[str] = None
    # Encoding name for tiktoken (used when type is "tiktoken").
    tiktok_tokenizer_name: Optional[str] = None
    # Language code for the moses (de)tokenizer (used when type is "moses").
    lang: Optional[str] = None
13 |
14 |
15 | class UnifiedTokenizer:
16 | def __init__(self, cfg: UTokenizerConfig) -> None:
17 | self.tokenizer_type = cfg.tokenizer_type
18 | match self.tokenizer_type:
19 | case "hf":
20 | from transformers import AutoTokenizer
21 |
22 | self.tokenizer = AutoTokenizer.from_pretrained(cfg.hf_tokenizer_path)
23 | case "tiktoken":
24 | import tiktoken
25 |
26 | self.tokenizer = tiktoken.get_encoding(cfg.tiktok_tokenizer_name)
27 | case "moses":
28 | from sacremoses import MosesDetokenizer, MosesTokenizer
29 |
30 | self.tokenizer = MosesTokenizer(cfg.lang)
31 | self.detokenizer = MosesDetokenizer(cfg.lang)
32 | case _:
33 | raise ValueError(f"Unknown tokenizer type: {self.tokenizer_type}")
34 | return
35 |
36 | def tokenize(self, texts: str) -> list[str | int]:
37 | match self.tokenizer_type:
38 | case "hf":
39 | tokens = self.tokenizer.encode(texts)
40 | case "tiktoken":
41 | tokens = self.tokenizer.encode(texts)
42 | case "moses":
43 | tokens = self.tokenizer.tokenize(texts)
44 | return tokens
45 |
46 | def detokenize(self, tokens: list[str | int]) -> str:
47 | match self.tokenizer_type:
48 | case "hf":
49 | texts = self.tokenizer.decode(tokens)
50 | case "tiktoken":
51 | texts = self.tokenizer.decode(tokens)
52 | case "moses":
53 | texts = self.detokenizer.detokenize(tokens)
54 | return texts
55 |
--------------------------------------------------------------------------------
/tested_models.md:
--------------------------------------------------------------------------------
# Tested Hugging Face Models
2 |
3 | ## Tested HFEncoders
4 | - jinaai/jina-embeddings-v3
5 | - BAAI/bge-m3
6 | - facebook/contriever
7 | - facebook/contriever-msmarco
8 | - facebook/dragon-plus-query-encoder
9 | - facebook/dragon-plus-context-encoder
10 | - nomic-ai/nomic-embed-text-v1.5
11 | - sentence-transformers/msmarco-MiniLM-L-12-v3
12 | - intfloat/e5-base-v2
13 | - Alibaba-NLP/gte-multilingual-base
14 | - Alibaba-NLP/gte-modernbert-base
15 |
16 | ## Tested HFClipEncoders
17 | - openai/clip-vit-base-patch32
18 | - jinaai/jina-clip-v2
19 |
20 | ## Tested ReRankers
21 | - unicamp-dl/InRanker-base
22 | - colbert-ir/colbertv2.0
23 | - jinaai/Jina-colbert-v2
24 | - jinaai/jina-reranker-v2-base-multilingual
25 | - BAAI/bge-reranker-v2-m3
26 | - intfloat/e5-base-v2
27 |
28 | ## Tested HFGenerators
29 | - Llama-3.2-1B-Instruct
30 |
31 | ## Tested VLMs
32 | - Qwen/Qwen2-VL-7B-Instruct
33 | - meta-llama/Llama-3.2-11B-Vision-Instruct
34 | - google/paligemma2-10b-ft-docci-448
--------------------------------------------------------------------------------
/tests/test_assistant.py:
--------------------------------------------------------------------------------
1 | import os
2 | from dataclasses import dataclass, field
3 |
4 | import pytest
5 | from omegaconf import OmegaConf
6 |
7 | from flexrag.assistant import BasicAssistant, BasicAssistantConfig
8 |
9 |
@dataclass
class AssistantTestConfig:
    """Top-level config schema for the assistant test, merged with YAML."""

    # Configuration for the BasicAssistant under test.
    assistant_config: BasicAssistantConfig = field(default_factory=BasicAssistantConfig)
14 |
class TestAssistant:
    """Smoke test for :class:`BasicAssistant`."""

    # Merge the structured defaults with the test YAML configuration.
    cfg: AssistantTestConfig = OmegaConf.merge(
        OmegaConf.structured(AssistantTestConfig),
        OmegaConf.load(
            os.path.join(os.path.dirname(__file__), "configs", "assistant.yaml")
        ),
    )
    query = "Who is Bruce Wayne?"
    # contexts = ["Bruce Wayne is Batman.", "Batman is a superhero."]

    @pytest.mark.asyncio
    async def test_answer(self):
        """Verify the assistant answers a query without raising.

        NOTE(review): ``answer`` is called synchronously even though the test
        is marked asyncio, and the returned values are not asserted on.
        """
        assistant = BasicAssistant(self.cfg.assistant_config)
        r1, _, _ = assistant.answer(self.query)
        return
30 |
--------------------------------------------------------------------------------
/tests/test_data.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from flexrag.common_dataclass import RAGEvalData
4 | from flexrag.datasets import RAGEvalDataset, RAGEvalDatasetConfig
5 | from flexrag.utils import LOGGER_MANAGER
6 |
7 |
8 | logger = LOGGER_MANAGER.get_logger("tests.datasets")
9 |
10 |
class TestRAGEvalDataset:
    """Smoke tests that every bundled RAG evaluation dataset loads and yields
    ``RAGEvalData`` items for each of its available splits.

    NOTE(review): loading these datasets presumably requires network access
    or a local cache — verify before running in CI.
    """

    # Mapping of dataset name -> available splits.
    datasets = {
        "2wikimultihopqa": ["dev", "train"],
        "ambig_qa": ["dev", "train"],
        "arc": ["dev", "test", "train"],
        "asqa": ["dev", "train"],
        "ay2": ["dev", "train"],
        "bamboogle": ["test"],
        "boolq": ["dev", "train"],
        "commonsenseqa": ["dev", "train"],
        "curatedtrec": ["test", "train"],
        # "domainrag": ["test"], # Error in loading due to dataset schema
        "eli5": ["dev", "train"],
        "fermi": ["dev", "test", "train"],
        "fever": ["dev", "train"],
        "hellaswag": ["dev", "train"],
        "hotpotqa": ["dev", "train"],
        "mmlu": ["5_shot", "dev", "test", "train"],
        "msmarco-qa": ["dev", "train"],
        "musique": ["dev", "train"],
        "narrativeqa": ["dev", "test", "train"],
        "nq": ["dev", "test", "train"],
        "openbookqa": ["dev", "test", "train"],
        "piqa": ["dev", "train"],
        "popqa": ["test"],
        "quartz": ["dev", "test", "train"],
        "siqa": ["dev", "train"],
        "squad": ["dev", "train"],
        "t-rex": ["dev", "train"],
        "triviaqa": ["dev", "test", "train"],
        "truthful_qa": ["dev"],
        "web_questions": ["test", "train"],
        "wikiasp": ["dev", "test", "train"],
        "wikiqa": ["dev", "test", "train"],
        "wned": ["dev"],
        "wow": ["dev", "train"],
        "zero-shot_re": ["dev", "train"],
    }

    async def run_test(self, name: str, split: str):
        """Load one dataset split and verify items via iteration and indexing."""
        # load dataset
        logger.info(f"Testing {name} {split}")
        dataset = RAGEvalDataset(RAGEvalDatasetConfig(name=name, split=split))

        # check dataset
        assert len(dataset) > 0
        for i in dataset:
            assert isinstance(i, RAGEvalData)
        for i in range(len(dataset)):
            assert isinstance(dataset[i], RAGEvalData)
        return

    @pytest.mark.asyncio
    async def test_rageval_dataset(self):
        """Run :meth:`run_test` for every dataset/split pair."""
        logger.info("Testing RAGEvalDataset")
        for name in self.datasets:
            for split in self.datasets[name]:
                await self.run_test(name, split)
        return
70 |
--------------------------------------------------------------------------------