├── .gitignore
├── CONTRIBUTING.md
├── LICENSE
├── PROMPT_EXAMPLE.txt
├── README.md
├── download.sh
├── evaluation
│   ├── __init__.py
│   ├── example_predictions
│   │   ├── icl_date
│   │   │   ├── preds.jsonl
│   │   │   └── queries.jsonl
│   │   ├── rag_nq
│   │   │   ├── preds.jsonl
│   │   │   └── queries.jsonl
│   │   ├── rag_quest
│   │   │   ├── preds.jsonl
│   │   │   └── queries.jsonl
│   │   ├── retrieval_nq
│   │   │   ├── preds.jsonl
│   │   │   └── queries.jsonl
│   │   ├── retrieval_quest
│   │   │   ├── preds.jsonl
│   │   │   └── queries.jsonl
│   │   ├── sql_sparc
│   │   │   ├── preds.jsonl
│   │   │   └── queries.jsonl
│   │   └── sql_spider
│   │       ├── preds.jsonl
│   │       └── queries.jsonl
│   ├── icl.py
│   ├── loft_evaluation.py
│   ├── rag.py
│   ├── retrieval.py
│   ├── sql.py
│   └── utils.py
├── infer_eval.sh
├── inference
│   └── models.py
├── preprocess.py
├── prompts
│   ├── __init__.py
│   ├── constants
│   │   ├── __init__.py
│   │   ├── common.py
│   │   ├── icl.py
│   │   ├── mm.py
│   │   ├── rag.py
│   │   ├── retrieval.py
│   │   └── sql.py
│   ├── prompt_registry.py
│   ├── prompts_icl.py
│   ├── prompts_mm.py
│   ├── prompts_rag.py
│   ├── prompts_retrieval.py
│   ├── prompts_sql.py
│   └── utils.py
├── requirements.txt
├── run_evaluation.py
├── run_inference.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # Distribution / packaging
7 | .Python
8 | build/
9 | develop-eggs/
10 | dist/
11 | downloads/
12 | eggs/
13 | .eggs/
14 | lib/
15 | lib64/
16 | parts/
17 | sdist/
18 | var/
19 | wheels/
20 | share/python-wheels/
21 | *.egg-info/
22 | .installed.cfg
23 | *.egg
24 | MANIFEST
25 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 |
3 | ## Contributor License Agreement
4 |
5 | Contributions to this project must be accompanied by a Contributor License
6 | Agreement. You (or your employer) retain the copyright to your contribution;
7 | this simply gives us permission to use and redistribute your contributions as
8 | part of the project. Head over to <https://cla.developers.google.com/> to see
9 | your current agreements on file or to sign a new one.
10 |
11 | You generally only need to submit a CLA once, so if you've already submitted one
12 | (even if it was for a different project), you probably don't need to do it
13 | again.
14 |
15 | ## Code reviews
16 |
17 | All submissions, including submissions by project members, require review. We
18 | use GitHub pull requests for this purpose. Consult
19 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
20 | information on using pull requests.
21 |
22 | ## Community Guidelines
23 |
24 | This project follows [Google's Open Source Community
25 | Guidelines](https://opensource.google/conduct/).
26 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 |
2 | Apache License
3 | Version 2.0, January 2004
4 | http://www.apache.org/licenses/
5 |
6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 |
8 | 1. Definitions.
9 |
10 | "License" shall mean the terms and conditions for use, reproduction,
11 | and distribution as defined by Sections 1 through 9 of this document.
12 |
13 | "Licensor" shall mean the copyright owner or entity authorized by
14 | the copyright owner that is granting the License.
15 |
16 | "Legal Entity" shall mean the union of the acting entity and all
17 | other entities that control, are controlled by, or are under common
18 | control with that entity. For the purposes of this definition,
19 | "control" means (i) the power, direct or indirect, to cause the
20 | direction or management of such entity, whether by contract or
21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
22 | outstanding shares, or (iii) beneficial ownership of such entity.
23 |
24 | "You" (or "Your") shall mean an individual or Legal Entity
25 | exercising permissions granted by this License.
26 |
27 | "Source" form shall mean the preferred form for making modifications,
28 | including but not limited to software source code, documentation
29 | source, and configuration files.
30 |
31 | "Object" form shall mean any form resulting from mechanical
32 | transformation or translation of a Source form, including but
33 | not limited to compiled object code, generated documentation,
34 | and conversions to other media types.
35 |
36 | "Work" shall mean the work of authorship, whether in Source or
37 | Object form, made available under the License, as indicated by a
38 | copyright notice that is included in or attached to the work
39 | (an example is provided in the Appendix below).
40 |
41 | "Derivative Works" shall mean any work, whether in Source or Object
42 | form, that is based on (or derived from) the Work and for which the
43 | editorial revisions, annotations, elaborations, or other modifications
44 | represent, as a whole, an original work of authorship. For the purposes
45 | of this License, Derivative Works shall not include works that remain
46 | separable from, or merely link (or bind by name) to the interfaces of,
47 | the Work and Derivative Works thereof.
48 |
49 | "Contribution" shall mean any work of authorship, including
50 | the original version of the Work and any modifications or additions
51 | to that Work or Derivative Works thereof, that is intentionally
52 | submitted to Licensor for inclusion in the Work by the copyright owner
53 | or by an individual or Legal Entity authorized to submit on behalf of
54 | the copyright owner. For the purposes of this definition, "submitted"
55 | means any form of electronic, verbal, or written communication sent
56 | to the Licensor or its representatives, including but not limited to
57 | communication on electronic mailing lists, source code control systems,
58 | and issue tracking systems that are managed by, or on behalf of, the
59 | Licensor for the purpose of discussing and improving the Work, but
60 | excluding communication that is conspicuously marked or otherwise
61 | designated in writing by the copyright owner as "Not a Contribution."
62 |
63 | "Contributor" shall mean Licensor and any individual or Legal Entity
64 | on behalf of whom a Contribution has been received by Licensor and
65 | subsequently incorporated within the Work.
66 |
67 | 2. Grant of Copyright License. Subject to the terms and conditions of
68 | this License, each Contributor hereby grants to You a perpetual,
69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70 | copyright license to reproduce, prepare Derivative Works of,
71 | publicly display, publicly perform, sublicense, and distribute the
72 | Work and such Derivative Works in Source or Object form.
73 |
74 | 3. Grant of Patent License. Subject to the terms and conditions of
75 | this License, each Contributor hereby grants to You a perpetual,
76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77 | (except as stated in this section) patent license to make, have made,
78 | use, offer to sell, sell, import, and otherwise transfer the Work,
79 | where such license applies only to those patent claims licensable
80 | by such Contributor that are necessarily infringed by their
81 | Contribution(s) alone or by combination of their Contribution(s)
82 | with the Work to which such Contribution(s) was submitted. If You
83 | institute patent litigation against any entity (including a
84 | cross-claim or counterclaim in a lawsuit) alleging that the Work
85 | or a Contribution incorporated within the Work constitutes direct
86 | or contributory patent infringement, then any patent licenses
87 | granted to You under this License for that Work shall terminate
88 | as of the date such litigation is filed.
89 |
90 | 4. Redistribution. You may reproduce and distribute copies of the
91 | Work or Derivative Works thereof in any medium, with or without
92 | modifications, and in Source or Object form, provided that You
93 | meet the following conditions:
94 |
95 | (a) You must give any other recipients of the Work or
96 | Derivative Works a copy of this License; and
97 |
98 | (b) You must cause any modified files to carry prominent notices
99 | stating that You changed the files; and
100 |
101 | (c) You must retain, in the Source form of any Derivative Works
102 | that You distribute, all copyright, patent, trademark, and
103 | attribution notices from the Source form of the Work,
104 | excluding those notices that do not pertain to any part of
105 | the Derivative Works; and
106 |
107 | (d) If the Work includes a "NOTICE" text file as part of its
108 | distribution, then any Derivative Works that You distribute must
109 | include a readable copy of the attribution notices contained
110 | within such NOTICE file, excluding those notices that do not
111 | pertain to any part of the Derivative Works, in at least one
112 | of the following places: within a NOTICE text file distributed
113 | as part of the Derivative Works; within the Source form or
114 | documentation, if provided along with the Derivative Works; or,
115 | within a display generated by the Derivative Works, if and
116 | wherever such third-party notices normally appear. The contents
117 | of the NOTICE file are for informational purposes only and
118 | do not modify the License. You may add Your own attribution
119 | notices within Derivative Works that You distribute, alongside
120 | or as an addendum to the NOTICE text from the Work, provided
121 | that such additional attribution notices cannot be construed
122 | as modifying the License.
123 |
124 | You may add Your own copyright statement to Your modifications and
125 | may provide additional or different license terms and conditions
126 | for use, reproduction, or distribution of Your modifications, or
127 | for any such Derivative Works as a whole, provided Your use,
128 | reproduction, and distribution of the Work otherwise complies with
129 | the conditions stated in this License.
130 |
131 | 5. Submission of Contributions. Unless You explicitly state otherwise,
132 | any Contribution intentionally submitted for inclusion in the Work
133 | by You to the Licensor shall be under the terms and conditions of
134 | this License, without any additional terms or conditions.
135 | Notwithstanding the above, nothing herein shall supersede or modify
136 | the terms of any separate license agreement you may have executed
137 | with Licensor regarding such Contributions.
138 |
139 | 6. Trademarks. This License does not grant permission to use the trade
140 | names, trademarks, service marks, or product names of the Licensor,
141 | except as required for reasonable and customary use in describing the
142 | origin of the Work and reproducing the content of the NOTICE file.
143 |
144 | 7. Disclaimer of Warranty. Unless required by applicable law or
145 | agreed to in writing, Licensor provides the Work (and each
146 | Contributor provides its Contributions) on an "AS IS" BASIS,
147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148 | implied, including, without limitation, any warranties or conditions
149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150 | PARTICULAR PURPOSE. You are solely responsible for determining the
151 | appropriateness of using or redistributing the Work and assume any
152 | risks associated with Your exercise of permissions under this License.
153 |
154 | 8. Limitation of Liability. In no event and under no legal theory,
155 | whether in tort (including negligence), contract, or otherwise,
156 | unless required by applicable law (such as deliberate and grossly
157 | negligent acts) or agreed to in writing, shall any Contributor be
158 | liable to You for damages, including any direct, indirect, special,
159 | incidental, or consequential damages of any character arising as a
160 | result of this License or out of the use or inability to use the
161 | Work (including but not limited to damages for loss of goodwill,
162 | work stoppage, computer failure or malfunction, or any and all
163 | other commercial damages or losses), even if such Contributor
164 | has been advised of the possibility of such damages.
165 |
166 | 9. Accepting Warranty or Additional Liability. While redistributing
167 | the Work or Derivative Works thereof, You may choose to offer,
168 | and charge a fee for, acceptance of support, warranty, indemnity,
169 | or other liability obligations and/or rights consistent with this
170 | License. However, in accepting such obligations, You may act only
171 | on Your own behalf and on Your sole responsibility, not on behalf
172 | of any other Contributor, and only if You agree to indemnify,
173 | defend, and hold each Contributor harmless for any liability
174 | incurred by, or claims asserted against, such Contributor by reason
175 | of your accepting any such warranty or additional liability.
176 |
177 | END OF TERMS AND CONDITIONS
178 |
179 | APPENDIX: How to apply the Apache License to your work.
180 |
181 | To apply the Apache License to your work, attach the following
182 | boilerplate notice, with the fields enclosed by brackets "[]"
183 | replaced with your own identifying information. (Don't include
184 | the brackets!) The text should be enclosed in the appropriate
185 | comment syntax for the file format. We also recommend that a
186 | file or class name and description of purpose be included on the
187 | same "printed page" as the copyright notice for easier
188 | identification within third-party archives.
189 |
190 | Copyright [yyyy] [name of copyright owner]
191 |
192 | Licensed under the Apache License, Version 2.0 (the "License");
193 | you may not use this file except in compliance with the License.
194 | You may obtain a copy of the License at
195 |
196 | http://www.apache.org/licenses/LICENSE-2.0
197 |
198 | Unless required by applicable law or agreed to in writing, software
199 | distributed under the License is distributed on an "AS IS" BASIS,
200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201 | See the License for the specific language governing permissions and
202 | limitations under the License.
203 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # LOFT: A 1 Million+ Token Long-Context Benchmark
2 |
3 | This repository houses the resources for LOFT, the Long Context Frontiers benchmark, introduced in the research paper [Can Long-Context Language Models Subsume Retrieval, RAG, SQL, and More?](https://arxiv.org/abs/2406.13121).
4 | LOFT consists of 6 long-context task categories spanning retrieval, multi-hop
5 | compositional reasoning, and more, totaling 35 datasets and 4 modalities.
6 |
7 | ## Installation
8 | ```bash
9 | $ git clone git@github.com:google-deepmind/loft.git
10 | $ cd loft/
11 | $ pip install -r requirements.txt
12 | ```
13 |
14 | ## Download Datasets and Prompts
15 | The script below downloads all the LOFT datasets under `BASE_DIR`.
16 |
17 | ```bash
18 | $ BASE_DIR=your-choice-of-directory
19 | $ sh download.sh $BASE_DIR
20 | ```
21 |
22 | Each dataset is also available from the links in the [Datasets](#datasets) table.
23 | For a small subset, `download.sh` additionally runs `preprocess.py`, which
24 | fills in the missing fields of the queries and corpus files (sketched below).
25 | Once the download completes, you will see the following file structure:
26 |
27 | ```
28 | $BASE_DIR
29 | └── data
30 |     ├── retrieval
31 |     │   ├── arguana
32 |     │   │   ├── 128k
33 |     │   │   │   ├── corpus.jsonl
34 |     │   │   │   ├── dev_queries.jsonl
35 |     │   │   │   ├── few_shot_queries.jsonl
36 |     │   │   │   └── test_queries.jsonl
37 |     │   │   ├── 1m
38 |     │   │   └── 32k
39 |     │   ├── fever
40 |     │   │   ├── ...
41 |     │   ├── ...
42 |     ├── rag
43 |     ├── sql
44 |     ├── icl
45 |     └── mm
46 | ```
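
As a rough illustration of the infilling step, the idea is to restore fields
that are not shipped inside the LOFT zips from a locally obtained copy of the
original dataset. This is our own sketch with hypothetical field names; see
`preprocess.py` for the actual schema and logic:

```python
import json

def infill_corpus(loft_corpus_path, source_texts, output_path):
    """Fills empty passage fields from the original source dataset.

    `source_texts` maps passage id -> original text; building that mapping
    from the upstream dataset is dataset-specific and not shown here.
    """
    with open(loft_corpus_path) as f, open(output_path, "w") as out:
        for line in f:
            record = json.loads(line)
            # Hypothetical field names, for illustration only.
            if not record.get("passage_text"):
                record["passage_text"] = source_texts[record["pid"]]
            out.write(json.dumps(record) + "\n")
```

`download.sh` runs this step automatically for `fiqa`, `msmarco`, `quora`, and
`webis_touche2020`.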
47 |
48 | We also provide an example prompt in `PROMPT_EXAMPLE.txt` showing how
49 | Corpus-in-Context (CiC) prompting can be done for the text retrieval task.
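Roughly, a CiC prompt places the entire corpus in the context window with each
passage tagged by an ID, followed by few-shot examples and the test query, and
the model answers with passage IDs. The abridged sketch below is illustrative
only (the IDs and layout are made up; see `PROMPT_EXAMPLE.txt` for the real
format):

```
You will be given a list of documents, each prefixed with its ID.
Answer the final query with the ID of the most relevant document.

ID: doc0 | TITLE: ... | CONTENT: ...
ID: doc1 | TITLE: ... | CONTENT: ...
...

query: who sang the song i wanna be sedated
final answer: doc41558
```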
50 |
51 | ## Inference and Evaluation
52 | We currently support inference with Gemini models (e.g.,
53 | `gemini-1.5-flash-002`) from Vertex AI.
54 | Please prepare your `PROJECT_ID` from [Google Cloud](https://cloud.google.com/vertex-ai/generative-ai/docs/start/quickstarts/quickstart-multimodal#expandable-1).
55 | To run inference with `gemini-1.5-flash-002` and evaluate the predictions:
56 |
57 | ```bash
58 | BASE_DIR=$1
59 | DATASET=$2
60 | LENGTH="128k"
61 | TASK_TYPE="retrieval"
62 | SPLIT="dev"
63 | PROMPT_TYPE="few_shot_with_cot"
64 | PROMPT="${TASK_TYPE}_${DATASET}_${LENGTH}_${SPLIT}:${PROMPT_TYPE}"
65 | echo "Prompt: ${PROMPT}"
66 |
67 | mkdir -p ${BASE_DIR}/outputs/${TASK_TYPE}/${DATASET}/${LENGTH}
68 | answer_file_extension="jsonl"
69 |
70 | python run_inference.py \
71 | --prompt_name ${PROMPT} \
72 | --task_type ${TASK_TYPE} \
73 | --base_dir ${BASE_DIR} \
74 | --data_dir ${TASK_TYPE}/${DATASET}/${LENGTH} \
75 | --split ${SPLIT} \
76 | --context_length ${LENGTH} \
77 | --output_path ${BASE_DIR}/outputs/${TASK_TYPE}/${DATASET}/${LENGTH}/${SPLIT}_predictions.jsonl \
78 | --project_id ${PROJECT_ID} \
79 | --overwrite
80 |
81 | python run_evaluation.py \
82 | --answer_file_path ${BASE_DIR}/data/${TASK_TYPE}/${DATASET}/${LENGTH}/dev_queries.${answer_file_extension} \
83 | --pred_file_path ${BASE_DIR}/outputs/${TASK_TYPE}/${DATASET}/${LENGTH}/${SPLIT}_predictions.jsonl \
84 | --task_type ${TASK_TYPE}
85 | ```
86 |
87 | The same script is available as `infer_eval.sh`.
88 | We provide example queries and predictions files in [evaluation/example_predictions/](evaluation/example_predictions/).
89 | Each `task_type` reports several different metric scores.
90 | To see which `task_type` to use for each dataset, along with the primary evaluation metric reported in the paper, consult the [Datasets](#datasets) table.
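
Concretely, each line of a predictions file is a JSON object that is paired
with the gold answers in the queries file by its `qid`; `model_outputs` holds
one list of answers per conversation turn. For example, the first lines of
`evaluation/example_predictions/retrieval_nq/preds.jsonl`:

```json
{"qid": "test1018", "num_turns": 1, "model_outputs": [["doc35916"]]}
{"qid": "test103", "num_turns": 1, "model_outputs": [["doc3528"]]}
```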
91 |
92 | ## Datasets
93 |
94 | | Task | Dataset | Description | Task Type | Primary Metric | Infilling Needed? | Download |
95 | | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
96 | | Text Retrieval | [ArguAna](https://github.com/beir-cellar/beir) | Argument Retrieval | `retrieval` | `recall@1` | - | [Link](https://storage.googleapis.com/loft-bench/retrieval/arguana.zip) |
97 | | Text Retrieval | [FEVER](https://github.com/beir-cellar/beir) | Fact Checking | `retrieval` | `recall@1` | - | [Link](https://storage.googleapis.com/loft-bench/retrieval/fever.zip) |
98 | | Text Retrieval | [FIQA](https://github.com/beir-cellar/beir) | Question Answering | `retrieval` | `recall@1` | ✅ | [Link](https://storage.googleapis.com/loft-bench/retrieval/fiqa.zip) |
99 | | Text Retrieval | [MS MARCO](https://github.com/beir-cellar/beir) | Web Search | `retrieval` |`recall@1` | ✅ | [Link](https://storage.googleapis.com/loft-bench/retrieval/msmarco.zip) |
100 | | Text Retrieval | [NQ](https://github.com/beir-cellar/beir) | Question Answering | `retrieval` |`recall@1` | - | [Link](https://storage.googleapis.com/loft-bench/retrieval/nq.zip) |
101 | | Text Retrieval | [Quora](https://github.com/beir-cellar/beir) | Duplication Detection | `retrieval` |`recall@1` | ✅ | [Link](https://storage.googleapis.com/loft-bench/retrieval/quora.zip) |
102 | | Text Retrieval | [SciFact](https://github.com/beir-cellar/beir) | Citation Prediction | `retrieval` |`recall@1` | - | [Link](https://storage.googleapis.com/loft-bench/retrieval/scifact.zip) |
103 | | Text Retrieval | [Touché-2020](https://github.com/beir-cellar/beir) | Argument Retrieval | `retrieval` | `recall@1` | ✅ | [Link](https://storage.googleapis.com/loft-bench/retrieval/webis_touche2020.zip) |
104 | | Text Retrieval | [TopiOCQA](https://github.com/McGill-NLP/topiocqa) | Multi-turn QA | `retrieval` |`recall@1` | - | [Link](https://storage.googleapis.com/loft-bench/retrieval/topiocqa.zip) |
105 | | Text Retrieval | [HotPotQA](https://github.com/beir-cellar/beir) | Multi-hop QA | `retrieval` | `mrecall@2` | - | [Link](https://storage.googleapis.com/loft-bench/retrieval/hotpotqa.zip) |
106 | | Text Retrieval | [MuSiQue](https://allenai.org/data/musique) | Multi-hop QA | `retrieval` | `mrecall@5` | - | [Link](https://storage.googleapis.com/loft-bench/retrieval/musique.zip) |
107 | | Text Retrieval | [QAMPARI](https://github.com/samsam3232/qampari) | Multi-target QA | `retrieval` | `mrecall@5` | - | [Link](https://storage.googleapis.com/loft-bench/retrieval/qampari.zip) |
108 | | Text Retrieval | [QUEST](https://github.com/google-research/language/tree/master/language/quest) | Multi-target QA | `retrieval` | `mrecall@3` | - | [Link](https://storage.googleapis.com/loft-bench/retrieval/quest.zip) |
109 | | Visual Retrieval | [Flickr30k](https://www.kaggle.com/datasets/hsankesara/flickr-image-dataset) | Image Retrieval | `retrieval` | `recall@1` | - | [Link](https://storage.googleapis.com/loft-bench/mm/flickr30k.zip) |
110 | | Visual Retrieval | [MS COCO](https://cocodataset.org) | Image Retrieval | `retrieval` | `recall@1` | - | [Link](https://storage.googleapis.com/loft-bench/mm/mscoco.zip) |
111 | | Visual Retrieval | [OVEN](https://github.com/open-vision-language/oven) | Image-text Retrieval | `retrieval` | `recall@1` | - | [Link](https://storage.googleapis.com/loft-bench/mm/oven.zip) |
112 | | Visual Retrieval | [MSR-VTT](https://cove.thecvf.com/datasets/839) | Video Retrieval | `retrieval` | `recall@1`| - | [Link](https://storage.googleapis.com/loft-bench/mm/msrvtt.zip) |
113 | | Audio Retrieval | [FLEURS-en](https://huggingface.co/datasets/google/fleurs) | Audio Retrieval | `retrieval` | `recall@1` | - | [Link](https://storage.googleapis.com/loft-bench/mm/fleurs_en_tts.zip) |
114 | | Audio Retrieval | [FLEURS-es](https://huggingface.co/datasets/google/fleurs) | Audio Retrieval | `retrieval` | `recall@1` | - | [Link](https://storage.googleapis.com/loft-bench/mm/fleurs_es_tts.zip) |
115 | | Audio Retrieval | [FLEURS-fr](https://huggingface.co/datasets/google/fleurs) | Audio Retrieval | `retrieval` | `recall@1`| - | [Link](https://storage.googleapis.com/loft-bench/mm/fleurs_fr_tts.zip) |
116 | | Audio Retrieval | [FLEURS-hi](https://huggingface.co/datasets/google/fleurs) | Audio Retrieval | `retrieval` | `recall@1` | - | [Link](https://storage.googleapis.com/loft-bench/mm/fleurs_hi_tts.zip) |
117 | | Audio Retrieval | [FLEURS-zh](https://huggingface.co/datasets/google/fleurs) | Audio Retrieval | `retrieval` | `recall@1` | - | [Link](https://storage.googleapis.com/loft-bench/mm/fleurs_zh_tts.zip) |
118 | | RAG | [NQ](https://github.com/beir-cellar/beir) | Question Answering | `rag` | `subspan_em` | - | [Link](https://storage.googleapis.com/loft-bench/rag/nq.zip) |
119 | | RAG | [TopiOCQA](https://github.com/McGill-NLP/topiocqa) | Multi-turn QA | `rag` | `subspan_em` | - | [Link](https://storage.googleapis.com/loft-bench/rag/topiocqa.zip) |
120 | | RAG | [HotPotQA](https://github.com/beir-cellar/beir) | Multi-hop QA | `rag` | `subspan_em` | - | [Link](https://storage.googleapis.com/loft-bench/rag/hotpotqa.zip) |
121 | | RAG | [MuSiQue](https://allenai.org/data/musique) | Multi-hop QA | `rag` | `subspan_em` | - | [Link](https://storage.googleapis.com/loft-bench/rag/musique.zip) |
122 | | RAG | [QAMPARI](https://github.com/samsam3232/qampari) | Multi-target QA | `multi_value_rag` | `subspan_em` | - | [Link](https://storage.googleapis.com/loft-bench/rag/qampari.zip) |
123 | | RAG | [QUEST](https://github.com/google-research/language/tree/master/language/quest) | Multi-target QA | `multi_value_rag` | `subspan_em` | - | [Link](https://storage.googleapis.com/loft-bench/rag/quest.zip) |
124 | | SQL | [Spider](https://yale-lily.github.io/spider) | Single-turn SQL | `sql` | `exec_acc` | - | [Link](https://storage.googleapis.com/loft-bench/sql/spider.zip) |
125 | | SQL | [SParC](https://yale-lily.github.io/sparc) | Multi-turn SQL | `sql` | `exec_acc` | - | [Link](https://storage.googleapis.com/loft-bench/sql/sparc.zip) |
126 | | Many-Shot ICL | [BBH-date](https://github.com/suzgunmirac/BIG-Bench-Hard) | Multiple-choice QA | `icl` | `em` | - | [Link](https://storage.googleapis.com/loft-bench/icl/date_understanding.zip) |
127 | | Many-Shot ICL |[BBH-salient](https://github.com/suzgunmirac/BIG-Bench-Hard) | Multiple-choice QA | `icl` | `em` | - | [Link](https://storage.googleapis.com/loft-bench/icl/salient_translation_error_detection.zip) |
128 | | Many-Shot ICL |[BBH-tracking7](https://github.com/suzgunmirac/BIG-Bench-Hard) | Multiple-choice QA | `icl` | `em` | - | [Link](https://storage.googleapis.com/loft-bench/icl/tracking_shuffled_objects_seven_objects.zip) |
129 | | Many-Shot ICL |[BBH-web](https://github.com/suzgunmirac/BIG-Bench-Hard) | Multiple-choice QA | `icl` | `em` | - | [Link](https://storage.googleapis.com/loft-bench/icl/web_of_lies.zip) |
130 | | Many-Shot ICL |[LIB-dialogue](https://github.com/TIGER-AI-Lab/LongICLBench) | Classification | - | - | - | ❌ |
131 |
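Several of the multi-target datasets above use `mrecall@k` as the primary
metric: a prediction is credited only if its top k IDs cover all gold
documents (or k of them, when more than k exist). The sketch below is our own
illustration of that idea, not the repository's implementation (see
`evaluation/retrieval.py` for the real one):

```python
def mrecall_at_k(predicted_ids: list[str], gold_ids: list[str], k: int) -> float:
    """Scores 1.0 iff the top-k predictions cover min(k, |gold|) gold ids."""
    top_k = set(predicted_ids[:k])
    gold = set(gold_ids)
    hits = sum(1 for gold_id in gold if gold_id in top_k)
    return float(hits >= min(k, len(gold)))
```
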
132 | ## LOFT-Hard Subset
133 | From the experiments in our [paper](https://arxiv.org/abs/2406.13121), we
134 | learned that Gemini 1.5 already performs well on many LOFT datasets but
135 | still shows headroom on others.
136 | Hence, we recommend iterating on the following three datasets:
137 |
138 | * **MuSiQue, QAMPARI, QUEST**
139 |
140 | The full datasets and inference for these are supported in the current open-source release.
141 |
142 | ## LOFT Multimodal Datasets
143 | For three of the LOFT multimodal datasets (Flickr30k, MS COCO, and MSR-VTT),
144 | we ask users of this repository to download the datasets from their
145 | respective websites:
146 |
147 | * Flickr30k: https://www.kaggle.com/datasets/hsankesara/flickr-image-dataset
148 | * MS COCO: https://cocodataset.org/
149 | * MSR-VTT: https://cove.thecvf.com/datasets/839
150 |
151 | ## Past & Upcoming Releases
152 |
153 | * [x] Remaining multi-modal data and inference.
154 | * [x] Prompt conversion code (data => prompt).
155 | * [x] Inference code and prompts for retrieval (10/25/24).
156 | * [x] Evaluation code for ICL, plus some ICL and visual retrieval datasets (8/30/24).
157 | * [x] Evaluation code for text tasks and code to regenerate some of the LOFT datasets (6/29/24).
158 | * [x] Initial release with links to download many of the LOFT text datasets (6/20/24).
159 |
160 | ## Citing this work
161 |
162 | ```
163 | @article{Lee2024LongContext,
164 | title={Can Long-Context Language Models Subsume Retrieval, RAG, SQL, and More?},
165 | author={Jinhyuk Lee and Anthony Chen and Zhuyun Dai and Dheeru Dua and Devendra Singh Sachan and Michael Boratko and Yi Luan and Sébastien M. R. Arnold and Vincent Perot and Siddharth Dalmia and Hexiang Hu and Xudong Lin and Panupong Pasupat and Aida Amini and Jeremy R. Cole and Sebastian Riedel and Iftekhar Naim and Ming-Wei Chang and Kelvin Guu},
166 | journal={ArXiv},
167 | year={2024},
168 | volume={abs/2406.13121},
169 | url={https://arxiv.org/abs/2406.13121}
170 | }
171 | ```
172 |
173 | ## License and disclaimer
174 |
175 | Copyright 2024 DeepMind Technologies Limited
176 |
177 | All software is licensed under the Apache License, Version 2.0 (Apache 2.0);
178 | you may not use this file except in compliance with the Apache 2.0 license.
179 | You may obtain a copy of the Apache 2.0 license at:
180 | https://www.apache.org/licenses/LICENSE-2.0
181 |
182 | All other materials are licensed under the Creative Commons Attribution 4.0
183 | International License (CC-BY). You may obtain a copy of the CC-BY license at:
184 | https://creativecommons.org/licenses/by/4.0/legalcode
185 |
186 | Individual tasks may be subject to copyright and licensing from their respective
187 | owners - please see individual download files for details.
188 |
189 | Unless required by applicable law or agreed to in writing, all software and
190 | materials distributed here under the Apache 2.0 or CC-BY licenses are
191 | distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
192 | either express or implied. See the licenses for the specific language governing
193 | permissions and limitations under those licenses.
194 |
195 | This is not an official Google product.
196 |
--------------------------------------------------------------------------------
/download.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Copyright 2025 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 |
17 | BASE_DIR=$1
18 | ORIGINAL_DIR=$(pwd)
19 | mkdir -p ${BASE_DIR}
20 | cd ${BASE_DIR}
21 | BASE_DIR=$(pwd) # Converts to absolute path once in directory.
22 |
23 | # Text retrieval datasets.
24 | cd ${BASE_DIR}
25 | mkdir -p data/retrieval/
26 | cd data/retrieval
27 | DATASETS=("arguana" "fever" "fiqa" "msmarco" "nq" "quora" "scifact" "webis_touche2020" "topiocqa" "hotpotqa" "musique" "qampari" "quest")
28 | for DATASET in "${DATASETS[@]}"; do
29 | wget https://storage.googleapis.com/loft-bench/retrieval/${DATASET}.zip
30 | unzip ${DATASET}.zip
31 | rm ${DATASET}.zip
32 | done
33 |
34 | # Text RAG datasets.
35 | cd ${BASE_DIR}
36 | mkdir -p data/rag/
37 | cd data/rag
38 | DATASETS=("nq" "hotpotqa" "musique" "qampari" "quest" "topiocqa")
39 | for DATASET in "${DATASETS[@]}"; do
40 | wget https://storage.googleapis.com/loft-bench/rag/${DATASET}.zip
41 | unzip ${DATASET}.zip
42 | rm ${DATASET}.zip
43 | done
44 |
45 | # SQL datasets.
46 | cd ${BASE_DIR}
47 | mkdir -p data/sql/
48 | cd data/sql
49 | DATASETS=("spider" "sparc")
50 | for DATASET in "${DATASETS[@]}"; do
51 | wget https://storage.googleapis.com/loft-bench/sql/${DATASET}.zip
52 | unzip ${DATASET}.zip
53 | rm ${DATASET}.zip
54 | done
55 |
56 | # MM datasets.
57 | cd ${BASE_DIR}
58 | mkdir -p data/mm/
59 | cd data/mm
60 | DATASETS=("fleurs_en_tts" "fleurs_es_tts" "fleurs_fr_tts" "fleurs_hi_tts" "fleurs_zh_tts" "oven")
61 | for DATASET in "${DATASETS[@]}"; do
62 | wget https://storage.googleapis.com/loft-bench/mm/${DATASET}.zip
63 | unzip ${DATASET}.zip
64 | rm ${DATASET}.zip
65 | done
66 |
67 | # ICL datasets.
68 | cd ${BASE_DIR}
69 | mkdir -p data/icl/
70 | cd data/icl
71 | DATASETS=("date_understanding" "salient_translation_error_detection" "tracking_shuffled_objects_seven_objects" "web_of_lies")
72 | for DATASET in "${DATASETS[@]}"; do
73 | wget https://storage.googleapis.com/loft-bench/icl/${DATASET}.zip
74 | unzip ${DATASET}.zip
75 | rm ${DATASET}.zip
76 | done
77 |
78 | # Preprocess and fill in required fields.
79 | cd ${ORIGINAL_DIR}
80 | DATASETS=("fiqa" "msmarco" "quora" "webis_touche2020")
81 | for DATASET in "${DATASETS[@]}"; do
82 | python preprocess.py \
83 | --input_dir ${BASE_DIR}/data/retrieval/${DATASET} \
84 | --dataset ${DATASET}
85 | done
86 |
--------------------------------------------------------------------------------
/evaluation/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Evaluation mapper."""
17 |
18 | import types
19 |
20 | from evaluation import icl
21 | from evaluation import loft_evaluation
22 | from evaluation import rag
23 | from evaluation import retrieval
24 | from evaluation import sql
25 | from evaluation import utils
26 |
27 |
28 | EVALUATION_TASKS = types.MappingProxyType({
29 | "icl": icl.IclEvaluation(
30 | config=loft_evaluation.EvaluationConfig(
31 | process_model_response_fns=[utils.normalize_answers],
32 | process_gold_answer_fns=[utils.normalize_answer],
33 | )
34 | ),
35 | "rag": rag.RagEvaluation(
36 | config=loft_evaluation.EvaluationConfig(
37 | process_model_response_fns=[
38 | utils.convert_to_str,
39 | utils.normalize_answers,
40 | ],
41 | process_gold_answer_fns=[utils.normalize_answer],
42 | )
43 | ),
44 | "multi_value_rag": rag.MultiValueRagEvaluation(
45 | config=loft_evaluation.EvaluationConfig(
46 | process_model_response_fns=[
47 | utils.convert_to_str,
48 | utils.normalize_answers,
49 | ],
50 | process_gold_answer_fns=[utils.normalize_answer],
51 | )
52 | ),
53 | "retrieval": retrieval.RetrievalEvaluation(
54 | config=loft_evaluation.EvaluationConfig(
55 | process_model_response_fns=[
56 | utils.normalize_passage_ids,
57 | ],
58 | process_gold_answer_fns=[
59 | utils.extract_gold_passage_ids,
60 | utils.normalize_passage_id,
61 | ],
62 | )
63 | ),
64 | "mm": retrieval.RetrievalEvaluation(
65 | config=loft_evaluation.EvaluationConfig(
66 | process_model_response_fns=[
67 | utils.normalize_passage_ids,
68 | ],
69 | process_gold_answer_fns=[
70 | utils.extract_gold_passage_ids,
71 | utils.normalize_passage_id,
72 | ],
73 | )
74 | ),
75 | "sql": sql.SqlEvaluation(
76 | config=loft_evaluation.EvaluationConfig(
77 | process_model_response_fns=[],
78 | process_gold_answer_fns=[],
79 | )
80 | ),
81 | })
82 |
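# Illustrative lookup (our sketch; the actual entry point is run_evaluation.py,
# which selects an evaluator from this mapping via its --task_type flag):
#
#   evaluator = EVALUATION_TASKS["retrieval"]  # -> retrieval.RetrievalEvaluation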
--------------------------------------------------------------------------------
/evaluation/example_predictions/icl_date/preds.jsonl:
--------------------------------------------------------------------------------
1 | {"qid": "0", "num_turns": 1, "model_outputs": [["(D)"]]}
2 | {"qid": "1", "num_turns": 1, "model_outputs": [["(B)"]]}
3 | {"qid": "2", "num_turns": 1, "model_outputs": [["(A)"]]}
4 | {"qid": "3", "num_turns": 1, "model_outputs": [["(D)"]]}
5 | {"qid": "4", "num_turns": 1, "model_outputs": [["(D)"]]}
6 | {"qid": "5", "num_turns": 1, "model_outputs": [["(A)"]]}
7 | {"qid": "6", "num_turns": 1, "model_outputs": [["(F)"]]}
8 | {"qid": "7", "num_turns": 1, "model_outputs": [["(D)"]]}
9 | {"qid": "8", "num_turns": 1, "model_outputs": [["(E)"]]}
10 | {"qid": "9", "num_turns": 1, "model_outputs": [["(A)"]]}
11 |
--------------------------------------------------------------------------------
/evaluation/example_predictions/icl_date/queries.jsonl:
--------------------------------------------------------------------------------
1 | {"qid": "0", "query_text": "Jane booked a flight for tomorrow, Jul 29, 2002. What is the date today in MM/DD/YYYY?\nOptions:\n(A) 10/09/2002\n(B) 08/18/2002\n(C) 07/16/2002\n(D) 07/28/2002\n(E) 11/28/2002\n(F) 09/11/2002", "answers": [["(D)"]]}
2 | {"qid": "1", "query_text": "Jane got her job in 2016. Today is her 3-year work anniversary. She still remember that on Dec 2, her second day at work, she spilled coffee on her laptop. What is the date 10 days ago in MM/DD/YYYY?\nOptions:\n(A) 11/21/2019\n(B) 11/20/2019\n(C) 03/21/2020\n(D) 11/21/2080\n(E) 02/21/2020\n(F) 11/22/2019", "answers": [["(A)"]]}
3 | {"qid": "2", "query_text": "Today is Apr 10, 1985. Jane's appointment will be 3 days later. What is the date 24 hours later in MM/DD/YYYY?\nOptions:\n(A) 04/11/1985\n(B) 04/13/1985\n(C) 04/12/1987\n(D) 04/07/1975\n(E) 04/12/1986\n(F) 04/10/1985", "answers": [["(A)"]]}
4 | {"qid": "3", "query_text": "2015 is coming in 36 hours. What is the date 24 hours later in MM/DD/YYYY?\nOptions:\n(A) 01/09/2015\n(B) 12/30/2059\n(C) 12/30/2014\n(D) 01/01/2015\n(E) 01/04/2015\n(F) 12/31/2014", "answers": [["(C)"]]}
5 | {"qid": "4", "query_text": "2015 is coming in 36 hours. What is the date 10 days ago in MM/DD/YYYY?\nOptions:\n(A) 01/02/2015\n(B) 12/23/2014\n(C) 12/19/2014\n(D) 12/18/2014\n(E) 11/21/2014\n(F) 12/20/2014", "answers": [["(C)"]]}
6 | {"qid": "5", "query_text": "Yesterday was 12/31/1929. Today could not be 12/32/1929 because December has only 31 days. What is the date 24 hours later in MM/DD/YYYY?\nOptions:\n(A) 01/02/1930\n(B) 08/02/1930\n(C) 01/02/1884\n(D) 11/06/1929\n(E) 03/30/1930\n(F) 01/02/1877", "answers": [["(A)"]]}
7 | {"qid": "6", "query_text": "Yesterday was 12/31/1929. Today could not be 12/32/1929 because December has only 31 days. What is the date one year ago from today in MM/DD/YYYY?\nOptions:\n(A) 01/01/1898\n(B) 01/01/1994\n(C) 08/01/1929\n(D) 01/08/1929\n(E) 01/01/1891\n(F) 01/01/1929", "answers": [["(F)"]]}
8 | {"qid": "7", "query_text": "Today is 9/7. Jane is watching NFL 2003. What is the date one week ago from today in MM/DD/YYYY?\nOptions:\n(A) 09/05/2003\n(B) 08/30/2003\n(C) 08/31/2074\n(D) 08/31/2003\n(E) 06/30/2004", "answers": [["(D)"]]}
9 | {"qid": "8", "query_text": "Jane is celebrating the last day of Jan 2012. What is the date tomorrow in MM/DD/YYYY?\nOptions:\n(A) 02/02/2012\n(B) 02/15/2012\n(C) 01/25/2012\n(D) 04/22/2012\n(E) 02/01/2012\n(F) 02/11/2012", "answers": [["(E)"]]}
10 | {"qid": "9", "query_text": "May 6, 1992 is like yesterday to Jane, but that is actually ten years ago. What is the date one week ago from today in MM/DD/YYYY?\nOptions:\n(A) 04/29/2002\n(B) 04/24/2002\n(C) 04/19/2002\n(D) 04/28/2002\n(E) 02/13/2002\n(F) 05/20/2002", "answers": [["(A)"]]}
11 |
--------------------------------------------------------------------------------
/evaluation/example_predictions/rag_nq/preds.jsonl:
--------------------------------------------------------------------------------
1 | {"qid": "test1018", "num_turns": 1, "model_outputs": [["the following day"]]}
2 | {"qid": "test103", "num_turns": 1, "model_outputs": [["Spain"]]}
3 | {"qid": "test1032", "num_turns": 1, "model_outputs": [["September 13, 2012"]]}
4 | {"qid": "test1041", "num_turns": 1, "model_outputs": [["Haliaeetus"]]}
5 | {"qid": "test1062", "num_turns": 1, "model_outputs": [["After returning to Galactica"]]}
6 | {"qid": "test1065", "num_turns": 1, "model_outputs": [["Nala"]]}
7 | {"qid": "test1087", "num_turns": 1, "model_outputs": [["Christopher Lloyd"]]}
8 | {"qid": "test1136", "num_turns": 1, "model_outputs": [["Charlotte of Mecklenburg-Strelitz"]]}
9 | {"qid": "test1156", "num_turns": 1, "model_outputs": [["Glenn Close"]]}
10 | {"qid": "test1172", "num_turns": 1, "model_outputs": [["the Ramones"]]}
11 |
--------------------------------------------------------------------------------
/evaluation/example_predictions/rag_nq/queries.jsonl:
--------------------------------------------------------------------------------
1 | {"qid": "test1018", "query_text": "when does monday night raw come on hulu", "metadata": {"qrels": [["doc35916", 1]]}, "answers": ["the following day"]}
2 | {"qid": "test103", "query_text": "who did puerto rico belong to before the u.s", "metadata": {"qrels": [["doc3528", 1]]}, "answers": ["Spain", "Ta\u00edno", "indigenous Ta\u00edno people"]}
3 | {"qid": "test1032", "query_text": "when did season 4 of glee come out", "metadata": {"qrels": [["doc36245", 1]]}, "answers": ["September 13, 2012"]}
4 | {"qid": "test1041", "query_text": "what is the genus of a bald eagle", "metadata": {"qrels": [["doc36467", 1]]}, "answers": ["Haliaeetus"]}
5 | {"qid": "test1062", "query_text": "when does boomer find out she a cylon", "metadata": {"qrels": [["doc36987", 1]]}, "answers": ["Kobol's Last Gleaming"]}
6 | {"qid": "test1065", "query_text": "what is the female lion called in lion king", "metadata": {"qrels": [["doc37055", 1]]}, "answers": ["Nala"]}
7 | {"qid": "test1087", "query_text": "who plays the woodsman in over the garden wall", "metadata": {"qrels": [["doc37964", 1]]}, "answers": ["Christopher Lloyd"]}
8 | {"qid": "test1136", "query_text": "who introduced the first chrismas tree to the uk", "metadata": {"qrels": [["doc40236", 1]]}, "answers": ["Charlotte of Mecklenburg-Strelitz"]}
9 | {"qid": "test1156", "query_text": "who played cruella de vil in 101 dalmatians", "metadata": {"qrels": [["doc41039", 1]]}, "answers": ["Glenn Close"]}
10 | {"qid": "test1172", "query_text": "who sang the song i wanna be sedated", "metadata": {"qrels": [["doc41558", 1]]}, "answers": ["the Ramones"]}
11 |
--------------------------------------------------------------------------------
/evaluation/example_predictions/rag_quest/preds.jsonl:
--------------------------------------------------------------------------------
1 | {"qid": "11425373", "num_turns": 1, "model_outputs": [["Dinner at the Homesick Restaurant"]]}
2 | {"qid": "11879839", "num_turns": 1, "model_outputs": [["Kanden Kadhalai"]]}
3 | {"qid": "12780412", "num_turns": 1, "model_outputs": [["Ludwig II (2012 film)", "It's Only Love (film)", "Operation Crossbow (film)", "The Magic Pipe"]]}
4 | {"qid": "13017425", "num_turns": 1, "model_outputs": [["Care Bears Movie II: A New Generation", "Game Over (2013 film)", "Roberto Carlos em Ritmo de Aventura", "Space Odyssey (TV series)"]]}
5 | {"qid": "15066686", "num_turns": 1, "model_outputs": [["Hour of the Star"]]}
6 | {"qid": "15128548", "num_turns": 1, "model_outputs": [["Look for a Woman"]]}
7 | {"qid": "151565", "num_turns": 1, "model_outputs": [["Boruto: Naruto the Movie", "Children Who Chase Lost Voices", "Moomins and the Comet Chase"]]}
8 | {"qid": "15424189", "num_turns": 1, "model_outputs": [["A Case of Need"]]}
9 | {"qid": "15624559", "num_turns": 1, "model_outputs": [["Crisis of Conscience"]]}
10 | {"qid": "15886701", "num_turns": 1, "model_outputs": [["The Perfect Sap"]]}
11 |
--------------------------------------------------------------------------------
/evaluation/example_predictions/rag_quest/queries.jsonl:
--------------------------------------------------------------------------------
1 | {"qid": "11425373", "query_text": "Novels about families set in Boston and New England.", "metadata": {"qrels": [["quest_153157_0", 1]]}, "answers": ["A Case of Need"]}
2 | {"qid": "11879839", "query_text": "2009 Indian, or by Chetan Bhagat, novels", "metadata": {"qrels": [["quest_154285_0", 1], ["quest_155326_0", 1], ["quest_169419_0", 1]]}, "answers": ["Five Point Someone", "One Night @ the Call Center", "The 3 Mistakes of My Life"]}
3 | {"qid": "12780412", "query_text": "what are some Hungarian Revolution of 1956 or Austrian war films", "metadata": {"qrels": [["quest_118289_0", 1], ["quest_121059_0", 1], ["quest_12755_0", 1]]}, "answers": ["Duel with Death", "Fly Away Home (2016 film)", "Sunshine (1999 film)"]}
4 | {"qid": "13017425", "query_text": "Finnish television and animated films", "metadata": {"qrels": [["quest_120526_0", 1], ["quest_137164_0", 1], ["quest_54562_0", 1]]}, "answers": ["Moomins and the Comet Chase", "Quest for a Heart", "Santa Claus and the Magic Drum"]}
5 | {"qid": "15066686", "query_text": "Hispanic and Latino American novels set anywhere besides North America", "metadata": {"qrels": [["quest_163073_0", 1], ["quest_183935_0", 1], ["quest_189608_0", 1]]}, "answers": ["At Night We Walk in Circles", "La reina de Am\u00e9rica", "Lost City Radio"]}
6 | {"qid": "15128548", "query_text": "Swiss Films about stalking", "metadata": {"qrels": [["quest_71842_0", 1]]}, "answers": ["One Way Trip 3D"]}
7 | {"qid": "151565", "query_text": "2010s supernatural children's animated films", "metadata": {"qrels": [["quest_106496_0", 1], ["quest_117229_0", 1], ["quest_126830_0", 1]]}, "answers": ["Lego Scooby-Doo! Blowout Beach Bash", "Scooby-Doo! and the Curse of the 13th Ghost", "The Magic Snowflake"]}
8 | {"qid": "15424189", "query_text": "1971 British thriller novels", "metadata": {"qrels": [["quest_196055_0", 1], ["quest_196660_0", 1]]}, "answers": ["Firecrest (novel)", "Lament for Leto"]}
9 | {"qid": "15624559", "query_text": "Critical Jehovah's Witness books.", "metadata": {"qrels": [["quest_176606_0", 1]]}, "answers": ["Crisis of Conscience"]}
10 | {"qid": "15886701", "query_text": "what are some Novels set in Jiangsu, 1750s, or Novels by Charlotte Lennox?", "metadata": {"qrels": [["quest_149383_0", 1], ["quest_150293_0", 1], ["quest_152967_0", 1]]}, "answers": ["Demi-Gods and Semi-Devils", "The Adventures of Peregrine Pickle", "The History of Sir Charles Grandison"]}
11 |
--------------------------------------------------------------------------------
/evaluation/example_predictions/retrieval_nq/preds.jsonl:
--------------------------------------------------------------------------------
1 | {"qid": "test1018", "num_turns": 1, "model_outputs": [["doc35916"]]}
2 | {"qid": "test103", "num_turns": 1, "model_outputs": [["doc3528"]]}
3 | {"qid": "test1032", "num_turns": 1, "model_outputs": [["doc36245"]]}
4 | {"qid": "test1041", "num_turns": 1, "model_outputs": [["doc36467"]]}
5 | {"qid": "test1062", "num_turns": 1, "model_outputs": [["doc36987"]]}
6 | {"qid": "test1065", "num_turns": 1, "model_outputs": [["doc37055"]]}
7 | {"qid": "test1087", "num_turns": 1, "model_outputs": [["doc37964"]]}
8 | {"qid": "test1136", "num_turns": 1, "model_outputs": [["doc40236"]]}
9 | {"qid": "test1156", "num_turns": 1, "model_outputs": [["doc41039"]]}
10 | {"qid": "test1172", "num_turns": 1, "model_outputs": [["doc41558"]]}
11 |
--------------------------------------------------------------------------------
/evaluation/example_predictions/retrieval_nq/queries.jsonl:
--------------------------------------------------------------------------------
1 | {"qid": "test1018", "query_text": "when does monday night raw come on hulu", "metadata": {"qrels": [["doc35916", 1]]}, "answers": [["doc35916", 1]]}
2 | {"qid": "test103", "query_text": "who did puerto rico belong to before the u.s", "metadata": {"qrels": [["doc3528", 1]]}, "answers": [["doc3528", 1]]}
3 | {"qid": "test1032", "query_text": "when did season 4 of glee come out", "metadata": {"qrels": [["doc36245", 1]]}, "answers": [["doc36245", 1]]}
4 | {"qid": "test1041", "query_text": "what is the genus of a bald eagle", "metadata": {"qrels": [["doc36467", 1]]}, "answers": [["doc36467", 1]]}
5 | {"qid": "test1062", "query_text": "when does boomer find out she a cylon", "metadata": {"qrels": [["doc36987", 1]]}, "answers": [["doc36987", 1]]}
6 | {"qid": "test1065", "query_text": "what is the female lion called in lion king", "metadata": {"qrels": [["doc37055", 1]]}, "answers": [["doc37055", 1]]}
7 | {"qid": "test1087", "query_text": "who plays the woodsman in over the garden wall", "metadata": {"qrels": [["doc37964", 1]]}, "answers": [["doc37964", 1]]}
8 | {"qid": "test1136", "query_text": "who introduced the first chrismas tree to the uk", "metadata": {"qrels": [["doc40236", 1]]}, "answers": [["doc40236", 1]]}
9 | {"qid": "test1156", "query_text": "who played cruella de vil in 101 dalmatians", "metadata": {"qrels": [["doc41039", 1]]}, "answers": [["doc41039", 1]]}
10 | {"qid": "test1172", "query_text": "who sang the song i wanna be sedated", "metadata": {"qrels": [["doc41558", 1]]}, "answers": [["doc41558", 1]]}
11 |
--------------------------------------------------------------------------------
/evaluation/example_predictions/retrieval_quest/preds.jsonl:
--------------------------------------------------------------------------------
1 | {"qid": "11425373", "num_turns": 1, "model_outputs": [["quest_147362_0", "quest_159422_2", "quest_147471_3"]]}
2 | {"qid": "11879839", "num_turns": 1, "model_outputs": [["quest_169419_0"]]}
3 | {"qid": "12780412", "num_turns": 1, "model_outputs": [["quest_642_0", "quest_43368_0", "quest_135623_0", "quest_242279_14"]]}
4 | {"qid": "13017425", "num_turns": 1, "model_outputs": [["quest_77806_1", "quest_65207_1"]]}
5 | {"qid": "15066686", "num_turns": 1, "model_outputs": [[]]}
6 | {"qid": "15128548", "num_turns": 1, "model_outputs": [["quest_198987_0"]]}
7 | {"qid": "151565", "num_turns": 1, "model_outputs": [["quest_59551_6"]]}
8 | {"qid": "15424189", "num_turns": 1, "model_outputs": [["quest_172458_0"]]}
9 | {"qid": "15624559", "num_turns": 1, "model_outputs": [["quest_164444_13"]]}
10 | {"qid": "15886701", "num_turns": 1, "model_outputs": [["quest_215428_0", "quest_152967_0"]]}
11 |
--------------------------------------------------------------------------------
/evaluation/example_predictions/retrieval_quest/queries.jsonl:
--------------------------------------------------------------------------------
1 | {"qid": "11425373", "query_text": "Novels about families set in Boston and New England.", "metadata": {"qrels": [["quest_153157_0", 1]]}, "answers": [["quest_153157_0", 1]]}
2 | {"qid": "11879839", "query_text": "2009 Indian, or by Chetan Bhagat, novels", "metadata": {"qrels": [["quest_154285_0", 1], ["quest_155326_0", 1], ["quest_169419_0", 1]]}, "answers": [["quest_154285_0", 1], ["quest_155326_0", 1], ["quest_169419_0", 1]]}
3 | {"qid": "12780412", "query_text": "what are some Hungarian Revolution of 1956 or Austrian war films", "metadata": {"qrels": [["quest_118289_0", 1], ["quest_121059_0", 1], ["quest_12755_0", 1]]}, "answers": [["quest_118289_0", 1], ["quest_121059_0", 1], ["quest_12755_0", 1]]}
4 | {"qid": "13017425", "query_text": "Finnish television and animated films", "metadata": {"qrels": [["quest_120526_0", 1], ["quest_137164_0", 1], ["quest_54562_0", 1]]}, "answers": [["quest_120526_0", 1], ["quest_137164_0", 1], ["quest_54562_0", 1]]}
5 | {"qid": "15066686", "query_text": "Hispanic and Latino American novels set anywhere besides North America", "metadata": {"qrels": [["quest_163073_0", 1], ["quest_183935_0", 1], ["quest_189608_0", 1]]}, "answers": [["quest_163073_0", 1], ["quest_183935_0", 1], ["quest_189608_0", 1]]}
6 | {"qid": "15128548", "query_text": "Swiss Films about stalking", "metadata": {"qrels": [["quest_71842_0", 1]]}, "answers": [["quest_71842_0", 1]]}
7 | {"qid": "151565", "query_text": "2010s supernatural children's animated films", "metadata": {"qrels": [["quest_106496_0", 1], ["quest_117229_0", 1], ["quest_126830_0", 1]]}, "answers": [["quest_106496_0", 1], ["quest_117229_0", 1], ["quest_126830_0", 1]]}
8 | {"qid": "15424189", "query_text": "1971 British thriller novels", "metadata": {"qrels": [["quest_196055_0", 1], ["quest_196660_0", 1]]}, "answers": [["quest_196055_0", 1], ["quest_196660_0", 1]]}
9 | {"qid": "15624559", "query_text": "Critical Jehovah's Witness books.", "metadata": {"qrels": [["quest_176606_0", 1]]}, "answers": [["quest_176606_0", 1]]}
10 | {"qid": "15886701", "query_text": "what are some Novels set in Jiangsu, 1750s, or Novels by Charlotte Lennox?", "metadata": {"qrels": [["quest_149383_0", 1], ["quest_150293_0", 1], ["quest_152967_0", 1]]}, "answers": [["quest_149383_0", 1], ["quest_150293_0", 1], ["quest_152967_0", 1]]}
11 |
--------------------------------------------------------------------------------
/evaluation/example_predictions/sql_sparc/preds.jsonl:
--------------------------------------------------------------------------------
1 | {"qid": "beef7de997c45136d6ec1935b9fbbf6cb843b6d4f9830be42523d71cc65d53e7", "num_turns": 2, "model_outputs": [[["Nassau", 45], ["Painter", 86], ["Alumni", 143], ["Lambeau", 348], ["Garfield", 119]], [["Lamberton", 134], ["Chandler", 375], ["Fairchild", 145], ["Nassau", 45], ["Grace", 40], ["Whitman", 134], ["Lamberton", 143], ["Taylor", 812], ["Saucon", 113], ["Painter", 86], ["Alumni", 547], ["Alumni", 143], ["Drown", 757], ["Saucon", 180], ["Whitman", 434], ["Saucon", 844], ["Bronfman", 700], ["Polya", 808], ["Gates", 707], ["Gates", 314], ["Main", 45], ["Taylor", 183], ["Power", 972], ["Garfield", 119], ["Rathbone", 261], ["Stabler", 105], ["Power", 717], ["Main", 425], ["Lambeau", 348], ["Chandler", 804]]]}
2 | {"qid": "55e5ef86821f02be10f8e448754d9d37d456562342c91cf5fbaae27d1b78e342", "num_turns": 2, "model_outputs": [[["Lamberton", 134], ["Lamberton", 143]], []]}
3 | {"qid": "3faafbbd50daeb96baff9a9166e96b3fa14796d5f2c3fb0326230f51143b6b0a", "num_turns": 3, "model_outputs": [[[612, "Mobile Computing"], [376, "Cost Accounting"], [959, "Bacteriology"], [267, "Hydraulics"], [436, "Stream Processing"], [731, "The Music of Donovan"], [130, "Differential Geometry"], [580, "The Music of Dave Edmunds"], [239, "The Music of the Ramones"]], [], []]}
4 | {"qid": "a4a8550136ba675ef0aafe16ba049ca36f44111d0ff24366e9f6c8497c38b4c5", "num_turns": 3, "model_outputs": [[[210627.58]], [[866831.75]], [[8889884.29]]]}
5 |
--------------------------------------------------------------------------------
/evaluation/example_predictions/sql_sparc/queries.jsonl:
--------------------------------------------------------------------------------
1 | {"qid": "beef7de997c45136d6ec1935b9fbbf6cb843b6d4f9830be42523d71cc65d53e7", "query_text": ["Which classrooms have capacity between 50 and 100?", "What are their buildings and room numbers?"], "answers": [[["Nassau", "45", 92], ["Painter", "86", 97], ["Gates", "707", 65], ["Taylor", "183", 71], ["Garfield", "119", 59], ["Rathbone", "261", 60], ["Lambeau", "348", 51]], [["Nassau", "45"], ["Painter", "86"], ["Gates", "707"], ["Taylor", "183"], ["Garfield", "119"], ["Rathbone", "261"], ["Lambeau", "348"]]], "metadata": {"db_id": "college_2", "sql_query": ["SELECT * FROM classroom WHERE capacity BETWEEN 50 AND 100", "SELECT building , room_number FROM classroom WHERE capacity BETWEEN 50 AND 100"], "candidate_pids": ["4e0e5654f90405ba48c1c4f526613330f104c6371625a58bd8f7612d1cea2810", "2e52a28f5648269b8a2a92f728b7faeddf32ab680c539f20328e518191db1369", "47d5c44b4f3b3fbe448206ff31eec50cfcfe16d536365c9744aeb56b46fab817", "59299033aa2e18f88810e7f6f64242b63a581f6d1e85f7840fd29dd86ae72747", "ac7d0506a6cccefa1c3b8b11bf35a7df931db5b24ef7585fbc30c08230036034", "8fcfa1d611817ab25dab2bcce31b6bac3b8c86f8fd3fe7211183e1faea41cb60", "45016232aacfc6ce0f29334f997d794377a7df3c4201d6c389ff1d9360e43658", "664aaad74b94d4d45a9dfd23db73072c96403a4528c7370f407cf3665e5c994c", "da63f93753cb5157316d7bd7be759bc863a45b5bbacef790dec341fe76a9cc6a", "99561ff9221ebfd68153e6fa82a881ef16f4c029764962b60869a91e760400af", "ad116c5457cb220b07f23c6ecd4630c785b62a995cd1ae0b4d2f2455a426e826"], "num_turns": 2}}
2 | {"qid": "55e5ef86821f02be10f8e448754d9d37d456562342c91cf5fbaae27d1b78e342", "query_text": ["What are all the classrooms in Lamberton?", "How many are there?"], "answers": [[["Lamberton", "134", 10], ["Lamberton", "143", 10]], [[2]]], "metadata": {"db_id": "college_2", "sql_query": ["SELECT * FROM classroom WHERE building = 'Lamberton'", "SELECT count(*) FROM classroom WHERE building = 'Lamberton'"], "candidate_pids": ["4e0e5654f90405ba48c1c4f526613330f104c6371625a58bd8f7612d1cea2810", "2e52a28f5648269b8a2a92f728b7faeddf32ab680c539f20328e518191db1369", "47d5c44b4f3b3fbe448206ff31eec50cfcfe16d536365c9744aeb56b46fab817", "59299033aa2e18f88810e7f6f64242b63a581f6d1e85f7840fd29dd86ae72747", "ac7d0506a6cccefa1c3b8b11bf35a7df931db5b24ef7585fbc30c08230036034", "8fcfa1d611817ab25dab2bcce31b6bac3b8c86f8fd3fe7211183e1faea41cb60", "45016232aacfc6ce0f29334f997d794377a7df3c4201d6c389ff1d9360e43658", "664aaad74b94d4d45a9dfd23db73072c96403a4528c7370f407cf3665e5c994c", "da63f93753cb5157316d7bd7be759bc863a45b5bbacef790dec341fe76a9cc6a", "99561ff9221ebfd68153e6fa82a881ef16f4c029764962b60869a91e760400af", "ad116c5457cb220b07f23c6ecd4630c785b62a995cd1ae0b4d2f2455a426e826"], "num_turns": 2}}
3 | {"qid": "3faafbbd50daeb96baff9a9166e96b3fa14796d5f2c3fb0326230f51143b6b0a", "query_text": ["What are all the courses in the Physics department?", "What are their course ids?", "How many are there?"], "answers": [[["612", "Mobile Computing", "Physics", 3], ["376", "Cost Accounting", "Physics", 4], ["959", "Bacteriology", "Physics", 4], ["267", "Hydraulics", "Physics", 4], ["436", "Stream Processing", "Physics", 4], ["731", "The Music of Donovan", "Physics", 4], ["130", "Differential Geometry", "Physics", 3], ["239", "The Music of the Ramones", "Physics", 4], ["580", "The Music of Dave Edmunds", "Physics", 4], ["443", "Journalism", "Physics", 4]], [["130"], ["239"], ["267"], ["376"], ["436"], ["443"], ["580"], ["612"], ["731"], ["959"]], [[10]]], "metadata": {"db_id": "college_2", "sql_query": ["SELECT * FROM course WHERE dept_name = 'Physics'", "SELECT DISTINCT course_id FROM course WHERE dept_name = 'Physics'", "SELECT count(DISTINCT course_id) FROM course WHERE dept_name = 'Physics'"], "candidate_pids": ["4e0e5654f90405ba48c1c4f526613330f104c6371625a58bd8f7612d1cea2810", "2e52a28f5648269b8a2a92f728b7faeddf32ab680c539f20328e518191db1369", "47d5c44b4f3b3fbe448206ff31eec50cfcfe16d536365c9744aeb56b46fab817", "59299033aa2e18f88810e7f6f64242b63a581f6d1e85f7840fd29dd86ae72747", "ac7d0506a6cccefa1c3b8b11bf35a7df931db5b24ef7585fbc30c08230036034", "8fcfa1d611817ab25dab2bcce31b6bac3b8c86f8fd3fe7211183e1faea41cb60", "45016232aacfc6ce0f29334f997d794377a7df3c4201d6c389ff1d9360e43658", "664aaad74b94d4d45a9dfd23db73072c96403a4528c7370f407cf3665e5c994c", "da63f93753cb5157316d7bd7be759bc863a45b5bbacef790dec341fe76a9cc6a", "99561ff9221ebfd68153e6fa82a881ef16f4c029764962b60869a91e760400af", "ad116c5457cb220b07f23c6ecd4630c785b62a995cd1ae0b4d2f2455a426e826"], "num_turns": 3}}
4 | {"qid": "a4a8550136ba675ef0aafe16ba049ca36f44111d0ff24366e9f6c8497c38b4c5", "query_text": ["What are the budgets of the Marketing department?", "How about that for the Finance department?", "What is their total budget?"], "answers": [[[210627.58]], [[866831.75]], [[1077459.33]]], "metadata": {"db_id": "college_2", "sql_query": ["SELECT budget FROM department WHERE dept_name = 'Marketing'", "SELECT budget FROM department WHERE dept_name = 'Finance'", "SELECT sum(budget) FROM department WHERE dept_name = 'Marketing' OR dept_name = 'Finance'"], "candidate_pids": ["4e0e5654f90405ba48c1c4f526613330f104c6371625a58bd8f7612d1cea2810", "2e52a28f5648269b8a2a92f728b7faeddf32ab680c539f20328e518191db1369", "47d5c44b4f3b3fbe448206ff31eec50cfcfe16d536365c9744aeb56b46fab817", "59299033aa2e18f88810e7f6f64242b63a581f6d1e85f7840fd29dd86ae72747", "ac7d0506a6cccefa1c3b8b11bf35a7df931db5b24ef7585fbc30c08230036034", "8fcfa1d611817ab25dab2bcce31b6bac3b8c86f8fd3fe7211183e1faea41cb60", "45016232aacfc6ce0f29334f997d794377a7df3c4201d6c389ff1d9360e43658", "664aaad74b94d4d45a9dfd23db73072c96403a4528c7370f407cf3665e5c994c", "da63f93753cb5157316d7bd7be759bc863a45b5bbacef790dec341fe76a9cc6a", "99561ff9221ebfd68153e6fa82a881ef16f4c029764962b60869a91e760400af", "ad116c5457cb220b07f23c6ecd4630c785b62a995cd1ae0b4d2f2455a426e826"], "num_turns": 3}}
5 |
--------------------------------------------------------------------------------
/evaluation/example_predictions/sql_spider/preds.jsonl:
--------------------------------------------------------------------------------
1 | {"qid": "72424cb88d329e71d075b226fc1f5ad371a3e1895933ff277c403b5bc5ce6da3", "num_turns": 1, "model_outputs": [[["Nassau"], ["Grace"], ["Whitman"], ["Taylor"], ["Saucon"], ["Painter"], ["Alumni"], ["Gates"], ["Main"], ["Lambeau"], ["Stabler"], ["Power"]]]}
2 | {"qid": "632bbee73f015dd096e8225c616761c23448a2895bed549786c94d628525aaaf", "num_turns": 1, "model_outputs": [[[24]]]}
3 | {"qid": "2ae0048fecf71b7576ef972fce93b7a1bc73dc8c57ceeccc84a3c88ba134de15", "num_turns": 1, "model_outputs": [[["Nassau", 45], ["Painter", 86], ["Alumni", 143], ["Painter", 86], ["Lambeau", 348], ["Garfield", 119], ["Rathbone", 261], ["Gates", 707]]]}
4 | {"qid": "04e52913595d4bb18c9f334232586f147f3cca79e12bdc50d16a34fc0c7467f0", "num_turns": 1, "model_outputs": [[["Physics", "Wrigley"]]]}
5 | {"qid": "402e3c337784d85a52e58e850844f54b6a2cd8e9af13f53f3ccb52310a5b84bd", "num_turns": 1, "model_outputs": [[["Cadis"]]]}
6 | {"qid": "aef823a3a4d0d36e3a866097ad48fd5144a2fe94a2c8e1622dc91f5784c8963c", "num_turns": 1, "model_outputs": [[[2]]]}
7 | {"qid": "4aa1df0f9828d97b6bfe3d1214474c87fa6aa8c17f302e22a636b8503f652cb2", "num_turns": 1, "model_outputs": [[[404]]]}
8 | {"qid": "081668a7b3c66e359047181955b39d49b3d7b209c6b15aa75f39409643b4250f", "num_turns": 1, "model_outputs": [[[21]]]}
9 | {"qid": "460f2d616f50c6de77d75056c9c65c8d3fb316bd2890c2f734a01a41ed82fb8c", "num_turns": 1, "model_outputs": [[[6]]]}
10 | {"qid": "e97a31925f9a173e85b6022f69086aa3579a132659f715379e4220c3d1ffd314", "num_turns": 1, "model_outputs": [[[172]]]}
11 |
--------------------------------------------------------------------------------
/evaluation/example_predictions/sql_spider/queries.jsonl:
--------------------------------------------------------------------------------
1 | {"qid": "72424cb88d329e71d075b226fc1f5ad371a3e1895933ff277c403b5bc5ce6da3", "query_text": "Find the buildings which have rooms with capacity more than 50.", "answers": [["Garfield"], ["Gates"], ["Lambeau"], ["Nassau"], ["Painter"], ["Rathbone"], ["Saucon"], ["Stabler"], ["Taylor"], ["Whitman"]], "metadata": {"db_id": "college_2", "sql_query": "SELECT DISTINCT building FROM classroom WHERE capacity > 50", "candidate_pids": ["4e0e5654f90405ba48c1c4f526613330f104c6371625a58bd8f7612d1cea2810", "2e52a28f5648269b8a2a92f728b7faeddf32ab680c539f20328e518191db1369", "47d5c44b4f3b3fbe448206ff31eec50cfcfe16d536365c9744aeb56b46fab817", "59299033aa2e18f88810e7f6f64242b63a581f6d1e85f7840fd29dd86ae72747", "ac7d0506a6cccefa1c3b8b11bf35a7df931db5b24ef7585fbc30c08230036034", "8fcfa1d611817ab25dab2bcce31b6bac3b8c86f8fd3fe7211183e1faea41cb60", "45016232aacfc6ce0f29334f997d794377a7df3c4201d6c389ff1d9360e43658", "664aaad74b94d4d45a9dfd23db73072c96403a4528c7370f407cf3665e5c994c", "da63f93753cb5157316d7bd7be759bc863a45b5bbacef790dec341fe76a9cc6a", "99561ff9221ebfd68153e6fa82a881ef16f4c029764962b60869a91e760400af", "ad116c5457cb220b07f23c6ecd4630c785b62a995cd1ae0b4d2f2455a426e826"]}}
2 | {"qid": "632bbee73f015dd096e8225c616761c23448a2895bed549786c94d628525aaaf", "query_text": "Count the number of rooms that are not in the Lamberton building.", "answers": [[28]], "metadata": {"db_id": "college_2", "sql_query": "SELECT count(*) FROM classroom WHERE building != 'Lamberton'", "candidate_pids": ["4e0e5654f90405ba48c1c4f526613330f104c6371625a58bd8f7612d1cea2810", "2e52a28f5648269b8a2a92f728b7faeddf32ab680c539f20328e518191db1369", "47d5c44b4f3b3fbe448206ff31eec50cfcfe16d536365c9744aeb56b46fab817", "59299033aa2e18f88810e7f6f64242b63a581f6d1e85f7840fd29dd86ae72747", "ac7d0506a6cccefa1c3b8b11bf35a7df931db5b24ef7585fbc30c08230036034", "8fcfa1d611817ab25dab2bcce31b6bac3b8c86f8fd3fe7211183e1faea41cb60", "45016232aacfc6ce0f29334f997d794377a7df3c4201d6c389ff1d9360e43658", "664aaad74b94d4d45a9dfd23db73072c96403a4528c7370f407cf3665e5c994c", "da63f93753cb5157316d7bd7be759bc863a45b5bbacef790dec341fe76a9cc6a", "99561ff9221ebfd68153e6fa82a881ef16f4c029764962b60869a91e760400af", "ad116c5457cb220b07f23c6ecd4630c785b62a995cd1ae0b4d2f2455a426e826"]}}
3 | {"qid": "2ae0048fecf71b7576ef972fce93b7a1bc73dc8c57ceeccc84a3c88ba134de15", "query_text": "Find the room number of the rooms which can sit 50 to 100 students and their buildings.", "answers": [["Nassau", "45"], ["Painter", "86"], ["Gates", "707"], ["Taylor", "183"], ["Garfield", "119"], ["Rathbone", "261"], ["Lambeau", "348"]], "metadata": {"db_id": "college_2", "sql_query": "SELECT building , room_number FROM classroom WHERE capacity BETWEEN 50 AND 100", "candidate_pids": ["4e0e5654f90405ba48c1c4f526613330f104c6371625a58bd8f7612d1cea2810", "2e52a28f5648269b8a2a92f728b7faeddf32ab680c539f20328e518191db1369", "47d5c44b4f3b3fbe448206ff31eec50cfcfe16d536365c9744aeb56b46fab817", "59299033aa2e18f88810e7f6f64242b63a581f6d1e85f7840fd29dd86ae72747", "ac7d0506a6cccefa1c3b8b11bf35a7df931db5b24ef7585fbc30c08230036034", "8fcfa1d611817ab25dab2bcce31b6bac3b8c86f8fd3fe7211183e1faea41cb60", "45016232aacfc6ce0f29334f997d794377a7df3c4201d6c389ff1d9360e43658", "664aaad74b94d4d45a9dfd23db73072c96403a4528c7370f407cf3665e5c994c", "da63f93753cb5157316d7bd7be759bc863a45b5bbacef790dec341fe76a9cc6a", "99561ff9221ebfd68153e6fa82a881ef16f4c029764962b60869a91e760400af", "ad116c5457cb220b07f23c6ecd4630c785b62a995cd1ae0b4d2f2455a426e826"]}}
4 | {"qid": "04e52913595d4bb18c9f334232586f147f3cca79e12bdc50d16a34fc0c7467f0", "query_text": "Find the name and building of the department with the highest budget.", "answers": [["Physics", "Wrigley"]], "metadata": {"db_id": "college_2", "sql_query": "SELECT dept_name , building FROM department ORDER BY budget DESC LIMIT 1", "candidate_pids": ["4e0e5654f90405ba48c1c4f526613330f104c6371625a58bd8f7612d1cea2810", "2e52a28f5648269b8a2a92f728b7faeddf32ab680c539f20328e518191db1369", "47d5c44b4f3b3fbe448206ff31eec50cfcfe16d536365c9744aeb56b46fab817", "59299033aa2e18f88810e7f6f64242b63a581f6d1e85f7840fd29dd86ae72747", "ac7d0506a6cccefa1c3b8b11bf35a7df931db5b24ef7585fbc30c08230036034", "8fcfa1d611817ab25dab2bcce31b6bac3b8c86f8fd3fe7211183e1faea41cb60", "45016232aacfc6ce0f29334f997d794377a7df3c4201d6c389ff1d9360e43658", "664aaad74b94d4d45a9dfd23db73072c96403a4528c7370f407cf3665e5c994c", "da63f93753cb5157316d7bd7be759bc863a45b5bbacef790dec341fe76a9cc6a", "99561ff9221ebfd68153e6fa82a881ef16f4c029764962b60869a91e760400af", "ad116c5457cb220b07f23c6ecd4630c785b62a995cd1ae0b4d2f2455a426e826"]}}
5 | {"qid": "402e3c337784d85a52e58e850844f54b6a2cd8e9af13f53f3ccb52310a5b84bd", "query_text": "What is the name of the student who has the highest total credits in the History department.", "answers": [["Cadis"]], "metadata": {"db_id": "college_2", "sql_query": "SELECT name FROM student WHERE dept_name = 'History' ORDER BY tot_cred DESC LIMIT 1", "candidate_pids": ["4e0e5654f90405ba48c1c4f526613330f104c6371625a58bd8f7612d1cea2810", "2e52a28f5648269b8a2a92f728b7faeddf32ab680c539f20328e518191db1369", "47d5c44b4f3b3fbe448206ff31eec50cfcfe16d536365c9744aeb56b46fab817", "59299033aa2e18f88810e7f6f64242b63a581f6d1e85f7840fd29dd86ae72747", "ac7d0506a6cccefa1c3b8b11bf35a7df931db5b24ef7585fbc30c08230036034", "8fcfa1d611817ab25dab2bcce31b6bac3b8c86f8fd3fe7211183e1faea41cb60", "45016232aacfc6ce0f29334f997d794377a7df3c4201d6c389ff1d9360e43658", "664aaad74b94d4d45a9dfd23db73072c96403a4528c7370f407cf3665e5c994c", "da63f93753cb5157316d7bd7be759bc863a45b5bbacef790dec341fe76a9cc6a", "99561ff9221ebfd68153e6fa82a881ef16f4c029764962b60869a91e760400af", "ad116c5457cb220b07f23c6ecd4630c785b62a995cd1ae0b4d2f2455a426e826"]}}
6 | {"qid": "aef823a3a4d0d36e3a866097ad48fd5144a2fe94a2c8e1622dc91f5784c8963c", "query_text": "How many rooms does the Lamberton building have?", "answers": [[2]], "metadata": {"db_id": "college_2", "sql_query": "SELECT count(*) FROM classroom WHERE building = 'Lamberton'", "candidate_pids": ["4e0e5654f90405ba48c1c4f526613330f104c6371625a58bd8f7612d1cea2810", "2e52a28f5648269b8a2a92f728b7faeddf32ab680c539f20328e518191db1369", "47d5c44b4f3b3fbe448206ff31eec50cfcfe16d536365c9744aeb56b46fab817", "59299033aa2e18f88810e7f6f64242b63a581f6d1e85f7840fd29dd86ae72747", "ac7d0506a6cccefa1c3b8b11bf35a7df931db5b24ef7585fbc30c08230036034", "8fcfa1d611817ab25dab2bcce31b6bac3b8c86f8fd3fe7211183e1faea41cb60", "45016232aacfc6ce0f29334f997d794377a7df3c4201d6c389ff1d9360e43658", "664aaad74b94d4d45a9dfd23db73072c96403a4528c7370f407cf3665e5c994c", "da63f93753cb5157316d7bd7be759bc863a45b5bbacef790dec341fe76a9cc6a", "99561ff9221ebfd68153e6fa82a881ef16f4c029764962b60869a91e760400af", "ad116c5457cb220b07f23c6ecd4630c785b62a995cd1ae0b4d2f2455a426e826"]}}
7 | {"qid": "4aa1df0f9828d97b6bfe3d1214474c87fa6aa8c17f302e22a636b8503f652cb2", "query_text": "How many students have advisors?", "answers": [[2000]], "metadata": {"db_id": "college_2", "sql_query": "SELECT count(DISTINCT s_id) FROM advisor", "candidate_pids": ["4e0e5654f90405ba48c1c4f526613330f104c6371625a58bd8f7612d1cea2810", "2e52a28f5648269b8a2a92f728b7faeddf32ab680c539f20328e518191db1369", "47d5c44b4f3b3fbe448206ff31eec50cfcfe16d536365c9744aeb56b46fab817", "59299033aa2e18f88810e7f6f64242b63a581f6d1e85f7840fd29dd86ae72747", "ac7d0506a6cccefa1c3b8b11bf35a7df931db5b24ef7585fbc30c08230036034", "8fcfa1d611817ab25dab2bcce31b6bac3b8c86f8fd3fe7211183e1faea41cb60", "45016232aacfc6ce0f29334f997d794377a7df3c4201d6c389ff1d9360e43658", "664aaad74b94d4d45a9dfd23db73072c96403a4528c7370f407cf3665e5c994c", "da63f93753cb5157316d7bd7be759bc863a45b5bbacef790dec341fe76a9cc6a", "99561ff9221ebfd68153e6fa82a881ef16f4c029764962b60869a91e760400af", "ad116c5457cb220b07f23c6ecd4630c785b62a995cd1ae0b4d2f2455a426e826"]}}
8 | {"qid": "081668a7b3c66e359047181955b39d49b3d7b209c6b15aa75f39409643b4250f", "query_text": "How many departments offer courses?", "answers": [[20]], "metadata": {"db_id": "college_2", "sql_query": "SELECT count(DISTINCT dept_name) FROM course", "candidate_pids": ["4e0e5654f90405ba48c1c4f526613330f104c6371625a58bd8f7612d1cea2810", "2e52a28f5648269b8a2a92f728b7faeddf32ab680c539f20328e518191db1369", "47d5c44b4f3b3fbe448206ff31eec50cfcfe16d536365c9744aeb56b46fab817", "59299033aa2e18f88810e7f6f64242b63a581f6d1e85f7840fd29dd86ae72747", "ac7d0506a6cccefa1c3b8b11bf35a7df931db5b24ef7585fbc30c08230036034", "8fcfa1d611817ab25dab2bcce31b6bac3b8c86f8fd3fe7211183e1faea41cb60", "45016232aacfc6ce0f29334f997d794377a7df3c4201d6c389ff1d9360e43658", "664aaad74b94d4d45a9dfd23db73072c96403a4528c7370f407cf3665e5c994c", "da63f93753cb5157316d7bd7be759bc863a45b5bbacef790dec341fe76a9cc6a", "99561ff9221ebfd68153e6fa82a881ef16f4c029764962b60869a91e760400af", "ad116c5457cb220b07f23c6ecd4630c785b62a995cd1ae0b4d2f2455a426e826"]}}
9 | {"qid": "460f2d616f50c6de77d75056c9c65c8d3fb316bd2890c2f734a01a41ed82fb8c", "query_text": "How many different courses offered by Physics department?", "answers": [[10]], "metadata": {"db_id": "college_2", "sql_query": "SELECT count(DISTINCT course_id) FROM course WHERE dept_name = 'Physics'", "candidate_pids": ["4e0e5654f90405ba48c1c4f526613330f104c6371625a58bd8f7612d1cea2810", "2e52a28f5648269b8a2a92f728b7faeddf32ab680c539f20328e518191db1369", "47d5c44b4f3b3fbe448206ff31eec50cfcfe16d536365c9744aeb56b46fab817", "59299033aa2e18f88810e7f6f64242b63a581f6d1e85f7840fd29dd86ae72747", "ac7d0506a6cccefa1c3b8b11bf35a7df931db5b24ef7585fbc30c08230036034", "8fcfa1d611817ab25dab2bcce31b6bac3b8c86f8fd3fe7211183e1faea41cb60", "45016232aacfc6ce0f29334f997d794377a7df3c4201d6c389ff1d9360e43658", "664aaad74b94d4d45a9dfd23db73072c96403a4528c7370f407cf3665e5c994c", "da63f93753cb5157316d7bd7be759bc863a45b5bbacef790dec341fe76a9cc6a", "99561ff9221ebfd68153e6fa82a881ef16f4c029764962b60869a91e760400af", "ad116c5457cb220b07f23c6ecd4630c785b62a995cd1ae0b4d2f2455a426e826"]}}
10 | {"qid": "e97a31925f9a173e85b6022f69086aa3579a132659f715379e4220c3d1ffd314", "query_text": "How many courses that do not have prerequisite?", "answers": [[121]], "metadata": {"db_id": "college_2", "sql_query": "SELECT count(*) FROM course WHERE course_id NOT IN (SELECT course_id FROM prereq)", "candidate_pids": ["4e0e5654f90405ba48c1c4f526613330f104c6371625a58bd8f7612d1cea2810", "2e52a28f5648269b8a2a92f728b7faeddf32ab680c539f20328e518191db1369", "47d5c44b4f3b3fbe448206ff31eec50cfcfe16d536365c9744aeb56b46fab817", "59299033aa2e18f88810e7f6f64242b63a581f6d1e85f7840fd29dd86ae72747", "ac7d0506a6cccefa1c3b8b11bf35a7df931db5b24ef7585fbc30c08230036034", "8fcfa1d611817ab25dab2bcce31b6bac3b8c86f8fd3fe7211183e1faea41cb60", "45016232aacfc6ce0f29334f997d794377a7df3c4201d6c389ff1d9360e43658", "664aaad74b94d4d45a9dfd23db73072c96403a4528c7370f407cf3665e5c994c", "da63f93753cb5157316d7bd7be759bc863a45b5bbacef790dec341fe76a9cc6a", "99561ff9221ebfd68153e6fa82a881ef16f4c029764962b60869a91e760400af", "ad116c5457cb220b07f23c6ecd4630c785b62a995cd1ae0b4d2f2455a426e826"]}}
11 |
--------------------------------------------------------------------------------
/evaluation/icl.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Evaluation functions for Many-Shot ICL."""
17 |
18 | import collections
19 | from typing import Any
20 |
21 | from evaluation import loft_evaluation as evaluation
22 | from evaluation import utils
23 |
24 |
25 | class IclEvaluation(evaluation.LOFTEvalution):
26 | """Evaluation for ICL model outputs."""
27 |
28 | def __init__(self, config: evaluation.EvaluationConfig):
29 | super().__init__()
30 | self.config = config
31 | self.metrics = collections.defaultdict(list)
32 |
33 | def evaluate(
34 | self, instance: evaluation.EvaluationInstance
35 | ) -> list[dict[str, Any]]:
36 | """Evaluates an ICL prediction."""
37 |
38 | multi_turn_metrics = []
39 | for turn_number in range(instance.num_turns):
40 | gold_answers = self.process_goldens(instance.gold_answers[turn_number])
41 | pred_answers = self.process_prediction(instance.model_output[turn_number])
42 | instance_metrics = {'qid': instance.qid, 'turn_id': str(turn_number)}
43 |
44 | if not pred_answers:
45 | instance_metrics['em'] = 0.0
46 | else:
47 |         # Only a single prediction is matched against the gold answers.
48 |         # Ill-formed model outputs may contain multiple answers; all but the first are ignored.
49 | if len(pred_answers) > 1:
50 | print(
51 | 'Warning: Multiple answers found in prediction for single value'
52 | f' answers: {pred_answers}.'
53 | )
54 | instance_metrics['em'] = utils.compute_em(gold_answers, pred_answers[0])
55 |
56 | # Make sure to call below to aggregate all the metrics.
57 | self.add_instance_metrics(instance_metrics)
58 | multi_turn_metrics.append(instance_metrics)
59 |
60 | return multi_turn_metrics
61 |
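
A minimal sketch of driving IclEvaluation end to end, assuming toy processing
functions (the lambdas and values below are illustrative, not from the
repository):

from evaluation import icl
from evaluation import loft_evaluation

# Toy config: predictions are wrapped in a list, golds are lowercased.
config = loft_evaluation.EvaluationConfig(
    process_model_response_fns=[lambda pred: [str(pred).lower()]],
    process_gold_answer_fns=[lambda gold: str(gold).lower()],
)
evaluator = icl.IclEvaluation(config)
instance = loft_evaluation.EvaluationInstance(
    qid='q0', gold_answers=[['Yes']], model_output=['yes'], num_turns=1)
print(evaluator.evaluate(instance))
# -> [{'qid': 'q0', 'turn_id': '0', 'em': 1.0}]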
--------------------------------------------------------------------------------
/evaluation/loft_evaluation.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """LOFT evaluation protocol."""
17 |
18 | import abc
19 | import dataclasses
20 | from typing import Any, Callable, Protocol
21 | import numpy as np
22 |
23 |
24 | @dataclasses.dataclass(frozen=True)
25 | class EvaluationConfig:
26 | """Evaluation configuration."""
27 |
28 | # Sequence of processing functions for model response.
29 | process_model_response_fns: list[Callable[..., Any]]
30 |
31 | # Sequence of processing functions for gold answer.
32 | process_gold_answer_fns: list[Callable[..., Any]]
33 |
34 |
35 | @dataclasses.dataclass(frozen=True)
36 | class EvaluationInstance:
37 | """The protocol for classes that perform evaluation on LOFT EvaluationInstance."""
38 |
39 | # Unique instance identifier.
40 | qid: str
41 | # Multiple gold references in an instance.
42 | gold_answers: list[Any]
43 | # Single model response (greedy) in an instance.
44 | model_output: Any | list[Any]
45 |   # Number of conversational turns in the instance.
46 | num_turns: int
47 | # Any additional metadata about the instance.
48 | metadata: dict[str, Any] | None = None
49 |
50 |
51 | class LOFTEvalution(Protocol):
52 | """The protocol for classes that perform evaluation on LOFT EvaluationInstance."""
53 |
54 | config: EvaluationConfig
55 | metrics: dict[str, Any]
56 |
57 | def process_goldens(self, goldens: list[Any]) -> list[Any]:
58 | """Processes goldens by executing functions defined in EvaluationConfig."""
59 | assert self.config is not None
60 | assert self.config.process_gold_answer_fns
61 |
62 | processed_goldens = []
63 | for gold in goldens:
64 | for fn in self.config.process_gold_answer_fns:
65 | gold = fn(gold)
66 | processed_goldens.append(gold)
67 |
68 | return processed_goldens
69 |
70 | def process_prediction(self, prediction: str) -> Any:
71 | """Processes model response by executing functions defined in EvaluationConfig."""
72 | assert self.config is not None
73 | assert self.config.process_model_response_fns
74 |
75 | for fn in self.config.process_model_response_fns:
76 | prediction = fn(prediction)
77 |
78 | return prediction
79 |
80 | def add_instance_metrics(self, instance_metrics: dict[str, Any]):
81 | """Add instance specific metrics to the global metrics field."""
82 | assert self.metrics is not None
83 | for metric_name, value in instance_metrics.items():
84 | self.metrics[metric_name].append(value)
85 |
86 | @abc.abstractmethod
87 | def evaluate(self, instance: EvaluationInstance) -> list[dict[str, Any]]:
88 | """Returns a list of dictionaries containing evaluation metrics."""
89 |
90 | def aggregate_metrics(self) -> dict[str, Any]:
91 | assert self.metrics is not None
92 | aggregated_metrics = {}
93 | for metric_name, metric_values in self.metrics.items():
94 |       if any(
95 |           not isinstance(value, (float, int))
96 |           for value in metric_values
97 |       ):
98 | continue
99 | aggregated_metrics[metric_name] = np.mean(metric_values)
100 | return aggregated_metrics
101 |
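
A minimal sketch of the processing pipeline defined by EvaluationConfig: each
function runs in order over every gold answer and over the model response. The
subclass and functions here are toy examples, not from the repository:

from evaluation import loft_evaluation

class _DemoEval(loft_evaluation.LOFTEvalution):
  """Trivial subclass used only to exercise the shared processing helpers."""

  def __init__(self, config):
    self.config = config
    self.metrics = {}

  def evaluate(self, instance):
    return []

demo = _DemoEval(loft_evaluation.EvaluationConfig(
    process_model_response_fns=[str.strip, str.lower],
    process_gold_answer_fns=[str.strip, str.lower],
))
print(demo.process_goldens(['  New York  ']))  # -> ['new york']
print(demo.process_prediction(' CHICAGO '))    # -> 'chicago'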
--------------------------------------------------------------------------------
/evaluation/rag.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Evaluation functions for RAG. EM/F1 are from SQuAD and slightly modified."""
17 |
18 | import collections
19 | from typing import Any
20 |
21 | from evaluation import loft_evaluation as evaluation
22 | from evaluation import utils
23 | import numpy as np
24 | import scipy.optimize
25 |
26 |
27 | def compute_em_multi_value(
28 | gold_answers: list[str], pred_answers: list[str]
29 | ) -> float:
30 | """Calculates exact match score. Taken from SQuAD evaluation."""
31 | return float(set(gold_answers) == set(pred_answers))
32 |
33 |
34 | def compute_coverage(gold_answers: list[str], pred_answers: list[str]) -> float:
35 | """Calculates coverage of gold_answers in pred_answers."""
36 | return len(set(pred_answers).intersection(set(gold_answers))) / float(
37 | len(gold_answers)
38 | )
39 |
40 |
41 | def compute_multi_value_subspan_em(
42 | gold_answers: list[str], pred_answers: list[str]
43 | ) -> float:
44 | """Calculates subspan match score. Adopted from DROP evaluation."""
45 | scores = np.zeros([len(gold_answers), len(pred_answers)])
46 | for gold_index, gold_item in enumerate(gold_answers):
47 | for pred_index, pred_item in enumerate(pred_answers):
48 | if gold_item in pred_item or pred_item in gold_item:
49 | scores[gold_index, pred_index] = 1
50 | row_ind, col_ind = scipy.optimize.linear_sum_assignment(-scores)
51 | aligned_scores = np.zeros(len(gold_answers))
52 | for r, c in zip(row_ind, col_ind):
53 | aligned_scores[r] = scores[r, c]
54 | return float(all(aligned_scores))
55 |
56 |
57 | class MultiValueRagEvaluation(evaluation.LOFTEvalution):
58 | """Evaluation for RAG model outputs for single-turn Set based datasets."""
59 |
60 | def __init__(
61 | self,
62 | config: evaluation.EvaluationConfig,
63 | ):
64 | super().__init__()
65 | self.config = config
66 | self.metrics = collections.defaultdict(list)
67 |
68 | def evaluate(
69 | self, instance: evaluation.EvaluationInstance
70 | ) -> list[dict[str, Any]]:
71 | """Evaluates a RAG prediction."""
72 |
73 | gold_answers = self.process_goldens(instance.gold_answers)
74 | pred_answers = self.process_prediction(instance.model_output[0])
75 |
76 | if not pred_answers:
77 | instance_metrics = {
78 | 'qid': instance.qid,
79 | 'em': 0.0,
80 |           'coverage': 0.0,
81 |           'subspan_em': 0.0,
82 | }
83 | self.add_instance_metrics(instance_metrics)
84 | return [instance_metrics]
85 |
86 | instance_metrics = {}
87 | instance_metrics['qid'] = instance.qid
88 | instance_metrics['em'] = compute_em_multi_value(gold_answers, pred_answers)
89 | instance_metrics['coverage'] = compute_coverage(gold_answers, pred_answers)
90 | instance_metrics['subspan_em'] = compute_multi_value_subspan_em(
91 | gold_answers, pred_answers
92 | )
93 | # Make sure to call below to aggregate all the metrics.
94 | self.add_instance_metrics(instance_metrics)
95 |
96 | return [instance_metrics]
97 |
98 |
99 | class RagEvaluation(evaluation.LOFTEvalution):
100 | """Evaluation for multi-turn RAG model outputs."""
101 |
102 | def __init__(self, config: evaluation.EvaluationConfig):
103 | super().__init__()
104 | self.config = config
105 | self.metrics = collections.defaultdict(list)
106 |
107 | def evaluate(
108 | self, instance: evaluation.EvaluationInstance
109 | ) -> list[dict[str, Any]]:
110 | """Evaluates a RAG prediction."""
111 |
112 | multi_turn_metrics = []
113 | for turn_number in range(instance.num_turns):
114 | if instance.num_turns == 1:
115 | gold_answers = self.process_goldens(instance.gold_answers)
116 | else:
117 | gold_answers = self.process_goldens(instance.gold_answers[turn_number])
118 | pred_answers = self.process_prediction(instance.model_output[turn_number])
119 |
120 | if not pred_answers:
121 | instance_metrics = {
122 | 'qid': instance.qid,
123 | 'turn_id': str(turn_number),
124 | 'em': 0.0,
125 | 'subspan_em': 0.0,
126 | 'f1': 0.0,
127 | }
128 |
129 | multi_turn_metrics.append(instance_metrics)
130 | self.add_instance_metrics(instance_metrics)
131 | continue
132 |
133 | instance_metrics = {}
134 |       # Only a single prediction is matched against the gold answers.
135 |       # Ill-formed model outputs may contain multiple answers; all but the first are ignored.
136 | if len(pred_answers) > 1:
137 | print(
138 | 'Warning: Multiple answers found in prediction for single value'
139 | f' retrieval: {pred_answers}.'
140 | )
141 |
142 | pred_answer = pred_answers[0]
143 | instance_metrics['qid'] = instance.qid
144 | instance_metrics['turn_id'] = str(turn_number)
145 | instance_metrics['em'] = utils.compute_em(gold_answers, pred_answer)
146 | instance_metrics['subspan_em'] = utils.compute_subspan_em(
147 | gold_answers, pred_answer
148 | )
149 | instance_metrics['f1'] = utils.compute_f1(gold_answers, pred_answer)
150 |
151 | # Make sure to call below to aggregate all the metrics.
152 | self.add_instance_metrics(instance_metrics)
153 |
154 | multi_turn_metrics.append(instance_metrics)
155 |
156 | return multi_turn_metrics
157 |
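
A small worked example of compute_multi_value_subspan_em: the Hungarian
assignment (scipy.optimize.linear_sum_assignment) pairs each gold answer with
at most one prediction, and the score is 1.0 only if every gold answer gets a
subspan match of its own. The strings below are illustrative:

from evaluation import rag

golds = ['new york', 'los angeles']
# 'new york city' contains 'new york'; 'los angeles' matches exactly.
print(rag.compute_multi_value_subspan_em(golds, ['new york city', 'los angeles']))  # 1.0
# Both predictions align only with 'new york'; nothing covers 'los angeles'.
print(rag.compute_multi_value_subspan_em(golds, ['new york', 'york']))  # 0.0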
--------------------------------------------------------------------------------
/evaluation/retrieval.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Evaluation functions for retrieval."""
17 |
18 | import collections
19 | import json
20 | from typing import Any
21 | from evaluation import loft_evaluation as evaluation
22 |
23 | EvaluationInstance = evaluation.EvaluationInstance
24 |
25 |
26 | def compute_recall_at_k(
27 | gold_ids: list[str],
28 | pred_ids: list[str],
29 | top_k: int,
30 | capped: bool = False,
31 | ) -> float:
32 | """Calculates the recall at k."""
33 | assert top_k > 0
34 | if not pred_ids:
35 | return 0
36 | pred_ids = set(pred_ids[:top_k])
37 | relevant_in_top_k = float(len(pred_ids.intersection(gold_ids)))
38 |
39 | # Capped recall@k is triggered when # of gold docs is > top_k
40 | if capped and len(gold_ids) > top_k:
41 | recall = relevant_in_top_k / top_k
42 | else:
43 | recall = relevant_in_top_k / len(gold_ids)
44 | return recall
45 |
46 |
47 | def compute_mrecall_at_k(
48 | gold_ids: list[str],
49 | pred_ids: list[str],
50 | top_k: int,
51 | ) -> float:
52 | """Calculates the mRecall at k.
53 |
54 | This metric was introduced in Min et al., 2021:
55 | https://aclanthology.org/2021.emnlp-main.560.pdf
56 |
57 | Args:
58 | gold_ids: A list of gold IDs.
59 | pred_ids: A list of prediction IDs.
60 | top_k: The number of predictions to consider.
61 |
62 | Returns:
63 | mRecall@k metric.
64 | """
65 | assert top_k > 0
66 | if not pred_ids:
67 | return 0
68 | pred_ids = set(pred_ids[:top_k])
69 | relevant_in_top_k = float(len(pred_ids.intersection(gold_ids)))
70 |
71 | # This computes the completeness of the answers.
72 | return float(relevant_in_top_k == min(top_k, len(gold_ids)))
73 |
74 |
75 | class RetrievalEvaluation(evaluation.LOFTEvalution):
76 | """Evaluation for Multi-turn retrieval datasets."""
77 |
78 | def __init__(
79 | self,
80 | config: evaluation.EvaluationConfig,
81 | ):
82 | super().__init__()
83 | self.config = config
84 | self.metrics = collections.defaultdict(list)
85 |
86 | def evaluate(self, instance: EvaluationInstance) -> list[dict[str, Any]]:
87 | """Evaluates a retrieval prediction."""
88 |
89 | metrics = []
90 | for turn_number in range(instance.num_turns):
91 | instance_metrics = {}
92 | pid2text = {}
93 | if instance.metadata and instance.metadata.get("candidate_path", None):
94 | def _get_text(candidate):
95 | return (
96 | candidate["title_text"].strip()
97 | + candidate["passage_text"].strip()
98 | )
99 |
100 | with open(instance.metadata["candidate_path"]) as f:
101 | for line in f:
102 | candidate = json.loads(line)
103 | pid2text[candidate["pid"]] = _get_text(candidate)
104 |
105 | if instance.num_turns == 1:
106 | gold_ids = self.process_goldens(instance.gold_answers)
107 | else:
108 | gold_ids = self.process_goldens([instance.gold_answers[turn_number]])
109 | pred_ids = self.process_prediction(instance.model_output[turn_number])
110 | if "candidate_path" in instance.metadata:
111 | gold_ids = [pid2text[pid] for pid in gold_ids]
112 | pred_ids = [pid2text[pid] for pid in pred_ids]
113 |
114 | instance_metrics["qid"] = instance.qid
115 | instance_metrics["turn_id"] = str(turn_number)
116 | instance_metrics["recall@1"] = compute_recall_at_k(
117 | gold_ids, pred_ids, 1, False
118 | )
119 | instance_metrics["recall@2"] = compute_recall_at_k(
120 | gold_ids, pred_ids, 2, False
121 | )
122 | instance_metrics["recall@3"] = compute_recall_at_k(
123 | gold_ids, pred_ids, 3, False
124 | )
125 | instance_metrics["recall@5"] = compute_recall_at_k(
126 | gold_ids, pred_ids, 5, False
127 | )
128 | instance_metrics["mrecall@1"] = compute_mrecall_at_k(
129 | gold_ids, pred_ids, 1
130 | )
131 | instance_metrics["mrecall@2"] = compute_mrecall_at_k(
132 | gold_ids, pred_ids, 2
133 | )
134 | instance_metrics["mrecall@3"] = compute_mrecall_at_k(
135 | gold_ids, pred_ids, 3
136 | )
137 | instance_metrics["mrecall@5"] = compute_mrecall_at_k(
138 | gold_ids, pred_ids, 5
139 | )
140 | instance_metrics["capped_recall@1"] = compute_recall_at_k(
141 | gold_ids, pred_ids, 1, True
142 | )
143 | metrics.append(instance_metrics)
144 |
145 | # Make sure to call below to aggregate all the metrics.
146 | self.add_instance_metrics(instance_metrics)
147 |
148 | return metrics
149 |
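
A small example contrasting the two metrics above: recall@k is the fraction of
gold IDs found in the top k, while mRecall@k is 1.0 only when every gold ID
that could fit in the top k is actually there. The IDs below are illustrative:

from evaluation import retrieval

gold_ids = ['p1', 'p2', 'p3']
pred_ids = ['p1', 'p9', 'p2', 'p3']
print(retrieval.compute_recall_at_k(gold_ids, pred_ids, top_k=2))  # 1/3
print(retrieval.compute_recall_at_k(gold_ids, pred_ids, top_k=2, capped=True))  # 1/2
print(retrieval.compute_mrecall_at_k(gold_ids, pred_ids, top_k=2))  # 0.0
print(retrieval.compute_mrecall_at_k(gold_ids, pred_ids, top_k=4))  # 1.0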
--------------------------------------------------------------------------------
/evaluation/sql.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Evaluation functions for SQL.
17 |
18 | The primary metric for SQL is execution accuracy. This compares the gold answer
19 | from executing the gold SQL query to the predicted answer. Gold and predicted
20 | answers must be lists of lists.
21 |
22 | We do not enforce order on the predicted answer. For instance, if the question
23 | is "What are the top 3 largest cities in the United States?" and the gold
24 | answer is [["New York"], ["Los Angeles"], ["Chicago"]], we consider the
25 | predicted answer [["Los Angeles"], ["Chicago"], ["New York"]] correct. Some
26 | questions do require a sorted answer; we filter those out when creating the
27 | SQL data for LOFT.
28 | """
29 |
30 | import collections
31 | from typing import Any
32 |
33 | from evaluation import loft_evaluation as evaluation
34 |
35 |
36 | def compute_execution_accuracy(
37 | gold_answer: list[list[str]],
38 | pred_answer: list[list[str]],
39 | ) -> float:
40 | """Calculates the execution accuracy."""
41 | if len(gold_answer) != len(pred_answer):
42 | return 0.0
43 |
44 | # Convert the list of lists into a list of Sets to allow for different
45 | # ordering (very relaxed).
46 | gold_answer = [set(ga) for ga in gold_answer]
47 | pred_answer = [set(pa) for pa in pred_answer]
48 |   # Check that the gold answer perfectly matches the predicted answer.
49 | for pa in pred_answer:
50 | if pa not in gold_answer:
51 | return 0.0
52 | return 1.0
53 |
54 |
55 | def normalize_sql_answer(answers: list[list[Any]]) -> list[list[str]]:
56 | """Normalizes all answers.
57 |
58 | Takes a list of list of answers and converts all elements in the list
59 | of lists to a string while applying some form of normalization.
60 |
61 | Args:
62 | answers: A list of lists of answers.
63 |
64 | Returns:
65 | normalized_answers: A list of list of answers as strings.
66 | """
67 | normalized_answers = []
68 | for subanswer in answers:
69 | normalized_answers.append([])
70 | for item in subanswer:
71 | # Try to convert all numbers in string form into a number for rounding.
72 | # Round all numbers to 2 decimals to handle various answer precision
73 | try:
74 | item = f"{float(item):.2f}"
75 | except Exception: # pylint: disable=broad-exception-caught
76 | pass
77 | item = str(item).strip().lower()
78 | if item:
79 | normalized_answers[-1].append(item)
80 | return normalized_answers
81 |
82 |
83 | class SqlEvaluation(evaluation.LOFTEvalution):
84 | """Evaluation for SQL model outputs."""
85 |
86 | def __init__(
87 | self,
88 | config: evaluation.EvaluationConfig,
89 | ):
90 | super().__init__()
91 | self.config = config
92 | self.metrics = collections.defaultdict(list)
93 |
94 | def evaluate(
95 | self, instance: evaluation.EvaluationInstance
96 | ) -> list[dict[str, Any]]:
97 | """Evaluates a SQL prediction."""
98 | multi_turn_metrics = []
99 | for turn_number in range(instance.num_turns):
100 | if instance.num_turns > 1:
101 | answers = normalize_sql_answer(instance.gold_answers[turn_number])
102 | else:
103 | answers = normalize_sql_answer(instance.gold_answers)
104 |       if not answers or not isinstance(answers[0], list):
105 | raise ValueError(
106 | f"Gold answers must be a list of lists but got {answers}"
107 | )
108 |
109 | # We want our output to be nested lists.
110 | pred_answers = instance.model_output[turn_number]
111 | if pred_answers and not isinstance(pred_answers[0], list):
112 | pred_answers = [pred_answers]
113 | pred_answers = normalize_sql_answer(pred_answers)
114 |
115 | instance_metrics = {
116 | "qid": instance.qid,
117 | "exec_acc": compute_execution_accuracy(answers, pred_answers),
118 | "metadata": {"turn_number": turn_number},
119 | }
120 | # Make sure to call below to aggregate all the metrics.
121 | self.add_instance_metrics(instance_metrics)
122 | multi_turn_metrics.append(instance_metrics)
123 |
124 | return multi_turn_metrics
125 |
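
A minimal sketch of the order-insensitive comparison described in the module
docstring, reusing the docstring's city example after normalization:

from evaluation import sql

gold = sql.normalize_sql_answer([['New York'], ['Los Angeles'], ['Chicago']])
pred = sql.normalize_sql_answer([['Los Angeles'], ['Chicago'], ['New York']])
print(sql.compute_execution_accuracy(gold, pred))  # 1.0: row order is ignored
# Numbers are rounded to two decimals during normalization, so these match too:
print(sql.compute_execution_accuracy(
    sql.normalize_sql_answer([[1077459.331]]),
    sql.normalize_sql_answer([['1077459.33']])))  # 1.0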
--------------------------------------------------------------------------------
/evaluation/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Evaluation utilities."""
17 |
18 | import collections
19 | import re
20 | import string
21 | from typing import Any
22 | import unicodedata
23 |
24 |
25 | def extract_gold_passage_ids(gold_answer: list[str | int]) -> str:
26 | """Extracts passage IDs from the gold answers field.
27 |
28 |   The gold answer in the query file for retrieval looks like ["doc35916", 1].
29 |   We extract the document ID from the gold answer to use for evaluation
30 |   purposes.
31 |
32 | Args:
33 | gold_answer: The gold answer from the query file.
34 |
35 | Returns:
36 | The document ID for the gold answer.
37 | """
38 | if not isinstance(gold_answer[0], str) or not isinstance(gold_answer[1], int):
39 | raise ValueError(
40 | "Gold answer must be a list consisting of a str and an int."
41 | )
42 | return gold_answer[0]
43 |
44 |
45 | def normalize_passage_id(passage_id: Any) -> str:
46 | return str(passage_id).strip()
47 |
48 |
49 | def normalize_passage_ids(passage_ids: list[Any]) -> list[str]:
50 | return [normalize_passage_id(passage_id) for passage_id in passage_ids]
51 |
52 |
53 | def normalize_answer(s: str) -> str:
54 | """Taken from SQuAD evaluation."""
55 |
56 | s = unicodedata.normalize("NFD", s)
57 |
58 | def remove_articles(text: str) -> str:
59 | regex = re.compile(r"\b(a|an|the)\b", re.UNICODE)
60 | return re.sub(regex, " ", text)
61 |
62 | def white_space_fix(text: str) -> str:
63 | return " ".join(text.split())
64 |
65 | def remove_punc(text: str) -> str:
66 | exclude = set(string.punctuation)
67 | return "".join(ch for ch in text if ch not in exclude)
68 |
69 | def lower(text: str) -> str:
70 | return text.lower()
71 |
72 | return white_space_fix(remove_articles(remove_punc(lower(s))))
73 |
74 |
75 | def normalize_answers(answers: list[str]) -> list[str]:
76 | return [normalize_answer(answer) for answer in answers]
77 |
78 |
79 | def convert_to_str(texts: list[Any]) -> list[str]:
80 | return [str(text) for text in texts]
81 |
82 |
83 | def get_tokens(s: str) -> list[str]:
84 | """Taken from SQuAD evaluation."""
85 | if not s:
86 | return []
87 | return normalize_answer(s).split()
88 |
89 |
90 | def compute_em(gold_answers: list[str], pred_answer: str) -> float:
91 | """Calculates exact match score. Taken from SQuAD evaluation."""
92 | return max([float(ga == pred_answer) for ga in gold_answers])
93 |
94 |
95 | def compute_subspan_em(gold_answers: list[str], pred_answer: str) -> float:
96 | """Calculates subspan match score."""
97 | return max([1.0 if ga in pred_answer else 0.0 for ga in gold_answers])
98 |
99 |
100 | def compute_f1(gold_answers: list[str], pred_answer: str) -> float:
101 | """Calculates F1 score. Taken from SQuAD evaluation."""
102 | pred_toks = get_tokens(pred_answer)
103 |
104 | f1_scores = []
105 | for ga in gold_answers:
106 | gold_toks = get_tokens(ga)
107 | common = collections.Counter(gold_toks) & collections.Counter(pred_toks)
108 | num_same = sum(common.values())
109 |
110 | if num_same == 0:
111 | f1_scores.append(0.0)
112 | continue
113 |
114 | if not gold_toks or not pred_toks:
115 | # If either is no-answer, then F1 is 1 if they agree, 0 otherwise
116 | f1 = float(gold_toks == pred_toks)
117 | else:
118 | precision = 1.0 * num_same / len(pred_toks)
119 | recall = 1.0 * num_same / len(gold_toks)
120 | f1 = (2 * precision * recall) / (precision + recall)
121 | f1_scores.append(f1)
122 |
123 | return max(f1_scores)
124 |
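
A short illustration of the SQuAD-style helpers above: normalization
lowercases, strips punctuation and articles, and collapses whitespace, while F1
takes the best token overlap over all gold answers. The strings are
illustrative:

from evaluation import utils

print(utils.normalize_answer('The  Empire State Building!'))  # 'empire state building'
print(utils.compute_em(['empire state building'], 'empire state'))  # 0.0
print(utils.compute_f1(['empire state building'], 'the empire state'))  # 0.8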
--------------------------------------------------------------------------------
/infer_eval.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Copyright 2025 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 |
17 | BASE_DIR=$1
18 | DATASET=$2
19 | LENGTH="32k"
20 | TASK_TYPE="mm"
21 | SPLIT="dev"
22 | if [[ ${TASK_TYPE} == "icl" ]]; then
23 | PROMPT_TYPE="many_shot" # Use "many_shot" for ICL.
24 | else
25 | PROMPT_TYPE="few_shot_with_cot"
26 | fi
27 | PROMPT="${TASK_TYPE}_${DATASET}_${LENGTH}_${SPLIT}:${PROMPT_TYPE}"
28 | echo "Prompt: ${PROMPT}"
29 |
30 | mkdir -p ${BASE_DIR}/outputs/${TASK_TYPE}/${DATASET}/${LENGTH}
31 |
32 | answer_file_extension="jsonl"
33 | if [[ ${TASK_TYPE} == "icl" ]]; then
34 | answer_file_extension="json"
35 | fi
36 |
37 | # The TopiOCQA retrieval task has duplicate PIDs.
38 | deduplicate_pids=false
39 | if [[ ${DATASET} == "topiocqa" ]]; then
40 | if [[ ${TASK_TYPE} == "retrieval" ]]; then
41 | deduplicate_pids=true
42 | fi
43 | fi
44 |
45 | python run_inference.py \
46 | --prompt_name ${PROMPT} \
47 | --task_type ${TASK_TYPE} \
48 | --base_dir ${BASE_DIR} \
49 | --data_dir ${TASK_TYPE}/${DATASET}/${LENGTH} \
50 | --split ${SPLIT} \
51 | --context_length ${LENGTH} \
52 | --output_path ${BASE_DIR}/outputs/${TASK_TYPE}/${DATASET}/${LENGTH}/${SPLIT}_predictions.jsonl \
53 | --project_id ${PROJECT_ID} \
54 | --overwrite
55 |
56 | python run_evaluation.py \
57 |   --answer_file_path ${BASE_DIR}/data/${TASK_TYPE}/${DATASET}/${LENGTH}/${SPLIT}_queries.${answer_file_extension} \
58 | --pred_file_path ${BASE_DIR}/outputs/${TASK_TYPE}/${DATASET}/${LENGTH}/${SPLIT}_predictions.jsonl \
59 | --deduplicate_pids=${deduplicate_pids} \
60 | --task_type ${TASK_TYPE}
61 |
--------------------------------------------------------------------------------
/inference/models.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Models used for inference."""
17 |
18 | import abc
19 | import traceback
20 | from typing import Any, List
21 |
22 | from absl import logging
23 | import utils
24 | import vertexai
25 | from vertexai.generative_models import GenerationConfig
26 | from vertexai.generative_models import GenerativeModel
27 | from vertexai.generative_models import HarmCategory
28 | from vertexai.generative_models import Part
29 | from vertexai.generative_models import SafetySetting
30 |
31 |
32 | ContentChunk = utils.ContentChunk
33 | MimeType = utils.MimeType
34 | LOCATION = 'us-central1'
35 | TEMPERATURE = 0.0
36 |
37 |
38 | class Model(metaclass=abc.ABCMeta):
39 | """Base class for models."""
40 |
41 | def index(
42 | self,
43 | content_chunks: List[ContentChunk],
44 | document_indices: List[tuple[int, int]],
45 | **kwargs: Any,
46 | ) -> str:
47 | """Indexes the example containing the corpus.
48 |
49 | Arguments:
50 | content_chunks: list of content chunks to send to the model.
51 |       document_indices: list of (start, end) indices marking the document
52 |         boundaries within content_chunks.
53 | **kwargs: additional arguments to pass.
54 |
55 | Returns:
56 | Indexing result.
57 | """
58 | del content_chunks, document_indices, kwargs # Unused.
59 | return 'Indexing skipped since not supported by model.'
60 |
61 | @abc.abstractmethod
62 | def infer(
63 | self,
64 | content_chunks: List[ContentChunk],
65 | document_indices: List[tuple[int, int]],
66 | **kwargs: Any,
67 | ) -> str:
68 | """Runs inference on model and returns text response.
69 |
70 | Arguments:
71 | content_chunks: list of content chunks to send to the model.
72 |       document_indices: list of (start, end) indices marking the document
73 |         boundaries within content_chunks.
74 | **kwargs: additional arguments to pass to the model.
75 |
76 | Returns:
77 | Inference result.
78 | """
79 | raise NotImplementedError
80 |
81 |
82 | class VertexAIModel(Model):
83 | """GCP VertexAI wrapper for general Gemini models."""
84 |
85 | def __init__(
86 | self,
87 | project_id: str,
88 | model_name: str,
89 | pid_mapper: dict[str, str],
90 | answer_prefix: str = 'final answer',
91 | ):
92 | self.project_id = project_id
93 | self.model_name = model_name
94 | self.pid_mapper = pid_mapper
95 | vertexai.init(project=project_id, location=LOCATION)
96 | self.model = GenerativeModel(self.model_name)
97 | self.answer_prefix = answer_prefix
98 |
99 | def _process_content_chunk(self, content_chunk: ContentChunk) -> Part:
100 | if content_chunk.mime_type in [
101 | MimeType.TEXT,
102 | MimeType.IMAGE_JPEG,
103 | MimeType.AUDIO_WAV,
104 | ]:
105 | return Part.from_data(
106 | content_chunk.data, mime_type=content_chunk.mime_type
107 | )
108 | else:
109 | raise ValueError(f'Unsupported MimeType: {content_chunk.mime_type}')
110 |
111 | def _get_safety_settings(
112 | self, content_chunks: List[ContentChunk]
113 | ) -> List[SafetySetting]:
114 | """Returns safety settings for the given content chunks."""
115 | # Audio prompts cannot use BLOCK_NONE.
116 | if any(
117 | content_chunk.mime_type == MimeType.AUDIO_WAV
118 | for content_chunk in content_chunks
119 | ):
120 | threshold = SafetySetting.HarmBlockThreshold.BLOCK_ONLY_HIGH
121 | else:
122 | threshold = SafetySetting.HarmBlockThreshold.BLOCK_NONE
123 | return [
124 | SafetySetting(
125 | category=category,
126 | threshold=threshold,
127 | )
128 | for category in HarmCategory
129 | ]
130 |
131 | def _postprocess_response(self, response: Any) -> List[str]:
132 | """Postprocesses the response from the model."""
133 | try:
134 | output_text = getattr(response, 'candidates')[0].content.parts[0].text
135 | final_answers = utils.extract_prediction(output_text, self.answer_prefix)
136 | if self.pid_mapper is not None:
137 | final_answers = [
138 | self.pid_mapper[str(answer)] for answer in final_answers
139 | ]
140 | except Exception as e: # pylint:disable=broad-exception-caught
141 | logging.error('Bad response %s with error: %s', response, str(e))
142 | traceback.print_exc()
143 | raise ValueError(f'Unexpected response: {response}') from e
144 |
145 | return final_answers
146 |
147 | def infer(
148 | self,
149 | content_chunks: List[ContentChunk],
150 | **kwargs: Any,
151 | ) -> List[str]:
152 | response = self.model.generate_content(
153 | [
154 | self._process_content_chunk(content_chunk)
155 | for content_chunk in content_chunks
156 | ],
157 | generation_config=GenerationConfig(temperature=TEMPERATURE, top_p=1.0),
158 | safety_settings=self._get_safety_settings(content_chunks),
159 | )
160 |
161 | return self._postprocess_response(response)
162 |
163 |
164 | def get_model(
165 | model_url_or_name: str,
166 | project_id: str | None,
167 | pid_mapper: dict[str, str],
168 | answer_prefix: str = 'final answer',
169 | ) -> Model:
170 | """Returns the model to use."""
171 |
172 | if model_url_or_name.startswith('gemini-'):
173 | if project_id is None:
174 | raise ValueError(
175 |           'Project ID is required for VertexAIModel.'
176 | )
177 | model = VertexAIModel(
178 | project_id=project_id,
179 | model_name=model_url_or_name,
180 | pid_mapper=pid_mapper,
181 | answer_prefix=answer_prefix,
182 | )
183 | else:
184 | raise ValueError(f'Unsupported model: {model_url_or_name}')
185 | return model
186 |
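
A minimal usage sketch for get_model; the model name, project ID, and
pid_mapper values below are placeholders, and a real GCP project with Vertex AI
enabled is needed before infer() can run:

from inference import models

model = models.get_model(
    model_url_or_name='gemini-1.5-pro',  # assumed name; any 'gemini-*' routes to VertexAIModel
    project_id='my-gcp-project',         # hypothetical project ID
    pid_mapper={'0': 'passage-0'},       # hypothetical PID mapping
    answer_prefix='final answer',
)
# content_chunks would be built from a LOFT prompt via utils before calling:
# answers = model.infer(content_chunks=content_chunks)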
--------------------------------------------------------------------------------
/preprocess.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | r"""Preprocess the LOFT data by filling in the missing fields.
17 |
18 | Example usage:
19 | python preprocess.py \
20 | --input_dir=data/loft/retrieval/fiqa \
21 | --dataset=fiqa
22 | """
23 |
24 | from collections.abc import Sequence
25 | import glob
26 | import json
27 | import os
28 | import zipfile
29 |
30 | from absl import app
31 | from absl import flags
32 | import cv2
33 | import numpy as np
34 | import tqdm
35 | import wget
36 |
37 |
38 | # pylint: disable=line-too-long
39 | DATASET_DOWNLOAD_LINKS = {
40 | "fiqa": "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/fiqa.zip",
41 | "msmarco": "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/msmarco.zip",
42 | "quora": "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/quora.zip",
43 | "webis_touche2020": "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/webis-touche2020.zip",
44 | "msrvtt": (
45 | "https://www.robots.ox.ac.uk/~maxbain/frozen-in-time/data/MSRVTT.zip"
46 | ),
47 | }
48 | # pylint: enable=line-too-long
49 |
50 | _INPUT_DIR = flags.DEFINE_string(
51 | "input_dir",
52 | default=None,
53 | help="The input directory to extract the LOFT data from.",
54 | required=True,
55 | )
56 | _DATASET = flags.DEFINE_enum(
57 | "dataset",
58 | default=None,
59 | enum_values=list(DATASET_DOWNLOAD_LINKS),
60 | help="Dataset to download and preprocess.",
61 | required=True,
62 | )
63 | _COMPRESSION_TYPE = flags.DEFINE_enum(
64 | "compression_type",
65 | default="zip",
66 | enum_values=["zip"],
67 | help="Compression type of the dataset.",
68 | )
69 |
70 | VIDEO_FILEPATTERN = "msrvtt/videos/all/{}.mp4"
71 | DATASET_LENGTHS = ["32k", "128k", "1m"]
72 | QUERY_FILES = [
73 | "dev_queries.jsonl",
74 | "few_shot_queries.jsonl",
75 | "test_queries.jsonl",
76 | ]
77 |
78 |
79 | def extract_frames_from_video(video_path, output_pattern, num_frames=3):
80 | """Extract video frames from a input video at a given frame rate."""
81 | # Open the video file
82 | video_capture = cv2.VideoCapture(video_path)
83 |
84 | # Get the total number of frames in the video
85 | total_frames = int(video_capture.get(cv2.CAP_PROP_FRAME_COUNT))
86 |
87 | # Generate the frame indices to sample uniformly
88 | frame_indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)
89 |
90 | frame_names = []
91 | for frame_index in frame_indices:
92 | # Set the video capture to the specific frame index
93 | video_capture.set(cv2.CAP_PROP_POS_FRAMES, frame_index)
94 |
95 | ret, frame = video_capture.read()
96 |
97 | if ret:
98 | # Save the sampled frame to the output folder
99 | frame_name = f"{output_pattern}_{frame_index:08d}.jpg"
100 | cv2.imwrite(frame_name, frame)
101 | frame_names.append(os.path.basename(frame_name))
102 |
103 | video_capture.release()
104 | return frame_names
105 |
106 |
107 | def extract_video_resource(
108 | download_dir: str, resource_dir: str
109 | ) -> dict[str, str]:
110 | """Extract video into image frames."""
111 | if not os.path.exists(resource_dir):
112 | os.makedirs(resource_dir)
113 |
114 | video2frames = dict()
115 | video_filepaths = glob.glob(
116 | os.path.join(download_dir, VIDEO_FILEPATTERN.format("*"))
117 | )
118 | for video_filepath in tqdm.tqdm(video_filepaths):
119 | video_id = os.path.basename(video_filepath).split(".")[0]
120 | video_frame_pattern = os.path.join(resource_dir, video_id)
121 |
122 | video2frames[video_id] = extract_frames_from_video(
123 | video_filepath, video_frame_pattern
124 | )
125 | return video2frames
126 |
127 |
128 | def extract_dataset(
129 | dataset: str, input_dir: str, compression_type: str
130 | ) -> None:
131 | """Extracts the dataset from the compressed file."""
132 | if compression_type == "zip":
133 | with zipfile.ZipFile(
134 | os.path.join(input_dir, dataset + ".zip"), "r"
135 | ) as zip_ref:
136 | extracted_dir = zip_ref.namelist()[0]
137 | zip_ref.extractall(input_dir)
138 | # Rename the extracted directory to the dataset name. Needed for datasets
139 | # like msrvtt and webis_touche2020 where the extracted directory name is
140 | # different from the dataset name.
141 | os.rename(
142 | os.path.join(input_dir, extracted_dir),
143 | os.path.join(input_dir, dataset),
144 | )
145 | else:
146 | raise ValueError(f"Unsupported compression type: {compression_type}")
147 |
148 |
149 | def download_dataset(dataset: str, download_dir: str) -> None:
150 | """Downloads the dataset from the dataset download link."""
151 |
152 | os.makedirs(download_dir, exist_ok=True)
153 | zipped_filepath = os.path.join(download_dir, dataset + ".zip")
154 | if not os.path.exists(zipped_filepath):
155 | wget.download(DATASET_DOWNLOAD_LINKS[dataset], out=zipped_filepath)
156 | else:
157 | print("Skipping downloading as the zip file already exists.")
158 |
159 | if not os.path.exists(os.path.join(download_dir, dataset)):
160 | extract_dataset(dataset, download_dir, _COMPRESSION_TYPE.value)
161 | else:
162 | print("Skipping extracting as the dataset already exists.")
163 |
164 |
165 | def load_dataset(
166 | dataset: str, input_dir: str
167 | ) -> tuple[dict[str, str], dict[str, dict[str, str]]]:
168 | """Load the downloaded source dataset."""
169 | qid2text = {}
170 | pid2text = {}
171 | # Other datasets like Flickr will be added later.
172 | if dataset in ["fiqa", "msmarco", "quora", "webis_touche2020"]:
173 | # Fill in the missing fields in the query and corpus files.
174 | source_dir = os.path.join(input_dir, "source", dataset)
175 | with open(os.path.join(source_dir, "queries.jsonl"), "r") as f:
176 | for line in f:
177 | query = json.loads(line)
178 | qid2text[query["_id"]] = query["text"]
179 | with open(os.path.join(source_dir, "corpus.jsonl"), "r") as f:
180 | for line in f:
181 | passage = json.loads(line)
182 | pid2text[passage["_id"]] = {
183 | "title": passage["title"],
184 | "text": passage["text"],
185 | }
186 | else:
187 | raise ValueError(f"Dataset {dataset} not available.")
188 |
189 | return qid2text, pid2text
190 |
191 |
192 | def update_loft_dataset(
193 | qid2text: dict[str, str],
194 | pid2text: dict[str, dict[str, str]],
195 | input_dir: str,
196 | ) -> None:
197 | """Update the LOFT dataset with the missing fields."""
198 | for length in DATASET_LENGTHS:
199 | for query_file in QUERY_FILES:
200 | target_query_file = os.path.join(input_dir, length, query_file)
201 | if not os.path.exists(target_query_file):
202 | print(f"Skipping {target_query_file} as it does not exist.")
203 | continue
204 | queries = []
205 | with open(target_query_file, "r") as f:
206 | for line in f:
207 | query = json.loads(line)
208 | if query["qid"] not in qid2text:
209 | raise ValueError(f"Query {query['qid']} not found in the queries.")
210 | query["query_text"] = qid2text[query["qid"]]
211 | queries.append(query)
212 | with open(target_query_file, "w") as f:
213 | for query in queries:
214 | json.dump(query, f)
215 | f.write("\n")
216 | print(f"Wrote to {target_query_file}.")
217 |
218 | target_corpus_file = os.path.join(input_dir, length, "corpus.jsonl")
219 | passages = []
220 | with open(target_corpus_file, "r") as f:
221 | for line in f:
222 | passage = json.loads(line)
223 | if passage["pid"] not in pid2text:
224 | raise ValueError(f"Passage {passage['pid']} not found in the corpus.")
225 | passage["title_text"] = pid2text[passage["pid"]]["title"]
226 | passage["passage_text"] = pid2text[passage["pid"]]["text"]
227 | passages.append(passage)
228 | with open(target_corpus_file, "w") as f:
229 | for passage in passages:
230 | json.dump(passage, f)
231 | f.write("\n")
232 | print(f"Wrote to {target_corpus_file}.")
233 |
234 |
235 | def update_mm_loft_dataset(
236 | input_dir: str,
237 | resource_mapping: dict[str, str],
238 | ) -> None:
239 | """Update the LOFT dataset with the missing fields."""
240 | for length in DATASET_LENGTHS:
241 | # Loading the corpus file.
242 | target_corpus_file = os.path.join(input_dir, length, "corpus.jsonl")
243 | passages = []
244 | with open(target_corpus_file, "r") as f:
245 | for line in f:
246 | passage = json.loads(line)
247 | resource_id = passage["pid"]
248 | passage["metadata"]["img_paths"] = resource_mapping[resource_id]
249 | passages.append(passage)
250 |
251 | # Writing the corpus file.
252 | with open(target_corpus_file, "w") as f:
253 | for passage in passages:
254 | json.dump(passage, f)
255 | f.write("\n")
256 | print(f"Wrote to {target_corpus_file}.")
257 |
258 |
259 | def main(argv: Sequence[str]) -> None:
260 | if len(argv) > 1:
261 | raise app.UsageError("Too many command-line arguments.")
262 |
263 | download_dataset(_DATASET.value, os.path.join(_INPUT_DIR.value, "source"))
264 | if _DATASET.value in ["msrvtt"]:
265 | resource_mapping = extract_video_resource(
266 | os.path.join(_INPUT_DIR.value, "source"),
267 | os.path.join(_INPUT_DIR.value, "resource"),
268 | )
269 | update_mm_loft_dataset(_INPUT_DIR.value, resource_mapping)
270 | elif _DATASET.value in ["fiqa", "msmarco", "quora", "webis_touche2020"]:
271 | qid2text, pid2text = load_dataset(_DATASET.value, _INPUT_DIR.value)
272 | update_loft_dataset(qid2text, pid2text, _INPUT_DIR.value)
273 | else:
274 | raise ValueError(
275 | f"Preprocessor for dataset {_DATASET.value} not available."
276 | )
277 |
278 |
279 | if __name__ == "__main__":
280 | app.run(main)
281 |
--------------------------------------------------------------------------------
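Usage sketch for the preprocessing helpers above. This hand-wires the same steps main() performs, with an illustrative dataset name and a hypothetical directory; in practice the script runs through app.run(main) so the absl flags (_DATASET, _INPUT_DIR, _COMPRESSION_TYPE) are parsed before download_dataset reads them.

import os

import preprocess  # the repo-root module shown above

dataset = "fiqa"              # illustrative; any supported text dataset works
input_dir = "loft_data/fiqa"  # hypothetical directory

# 1. Download the source zip into <input_dir>/source and unzip it there.
#    Note: extract_dataset reads the _COMPRESSION_TYPE flag, so this step
#    assumes flags have been parsed (e.g. via app.run(main)).
preprocess.download_dataset(dataset, os.path.join(input_dir, "source"))
# 2. Build qid -> text and pid -> {title, text} maps from the BEIR-style
#    queries.jsonl and corpus.jsonl source files.
qid2text, pid2text = preprocess.load_dataset(dataset, input_dir)
# 3. Write query_text / title_text / passage_text back into the LOFT
#    queries and corpus files for every dataset length.
preprocess.update_loft_dataset(qid2text, pid2text, input_dir)

--------------------------------------------------------------------------------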
/prompts/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Init for LOFT Prompts."""
17 |
18 | from prompts import prompts_icl
19 | from prompts import prompts_mm
20 | from prompts import prompts_rag
21 | from prompts import prompts_retrieval
22 | from prompts import prompts_sql
23 |
--------------------------------------------------------------------------------
/prompts/constants/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Init for LOFT Prompt Constants."""
17 |
--------------------------------------------------------------------------------
/prompts/constants/common.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Constants for prompt formatting."""
17 |
18 | CORPUS_INSTRUCTION = """
19 | You will be given a list of candidates such as documents, images, videos, audio clips, etc. You need to check them carefully and understand all of them. Then you will be given a query, and your goal is to find all candidates from the list that can help answer the query. Print out the ID of each candidate.
20 | """.strip()
21 |
22 | CORPUS_FORMAT = "ID: {pid} | TITLE: {title} | CONTENT: {passage}"
23 | CORPUS_FORMAT_NOTITLE = "ID: {pid} | CONTENT: {passage}"
24 | CORPUS_FORMAT_ECHO = (
25 | "ID: {pid} | TITLE: {title} | CONTENT: {passage} | END ID: {pid}"
26 | )
27 | CORPUS_FORMAT_ECHO_NOTITLE = "ID: {pid} | CONTENT: {passage} | END ID: {pid}"
28 | CORPUS_FORMAT_ID_CONCAT = 'TITLE: "{title} {pid}" | CONTENT: {passage}'
29 | CORPUS_FORMAT_ID_CONCAT_NOTITLE = 'CONTENT: "{passage} {pid}"'
30 | CORPUS_FORMAT_RAG = (
31 | "[{pid}] ({title}) CONTENT: {passage}"
32 | )
33 |
34 | # Few-shot example answer formats
35 | FEW_SHOT_EXAMPLE_ANSWER_FORMAT_SIMPLE = "ID: {pid} | TITLE: {title}"
36 | FEW_SHOT_EXAMPLE_ANSWER_FORMAT_SIMPLE_NOTITLE = "ID: {pid} | CONTENT: {passage}"
37 | FEW_SHOT_EXAMPLE_ANSWER_FORMAT_SIMPLE_REVERSE = "TITLE: {title} | ID: {pid}"
38 | FEW_SHOT_EXAMPLE_ANSWER_FORMAT_SIMPLE_REVERSE_NOTITLE = (
39 | "CONTENT: {passage} | ID: {pid}"
40 | )
41 | FEW_SHOT_EXAMPLE_ANSWER_FORMAT_ID_ONLY = "ID: {pid}"
42 | FEW_SHOT_EXAMPLE_ANSWER_FORMAT_ID_CONCAT = 'TITLE: "{title} {pid}"'
43 | FEW_SHOT_EXAMPLE_ANSWER_FORMAT_ID_CONCAT_NOTITLE = 'CONTENT: "{passage} {pid}"'
44 | FEW_SHOT_EXAMPLE_ANSWER_RAG = "[{pid}] ({title})"
45 |
46 | # Final answer used for evaluation.
47 | FINAL_ANSWER_FORMAT = "Final Answer: {final_answer}"
48 |
49 | FEW_SHOT_SEPARATOR = "====== Example {example_id} ======\n"
50 | TEST_QUERY_SEPARATOR = "====== Now let's start! ======\n"
51 |
--------------------------------------------------------------------------------
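These are plain str.format templates; a short sketch, with made-up values, of how a corpus line and the final-answer line are presumably rendered:

from prompts.constants import common

# Render one corpus entry (values are invented for illustration).
line = common.CORPUS_FORMAT.format(
    pid="7", title="Kirkcaldy", passage="Kirkcaldy is a town in Fife."
)
# line == "ID: 7 | TITLE: Kirkcaldy | CONTENT: Kirkcaldy is a town in Fife."

# Render the final-answer line used for evaluation.
answer = common.FINAL_ANSWER_FORMAT.format(final_answer="[7]")
# answer == "Final Answer: [7]"

--------------------------------------------------------------------------------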
/prompts/constants/icl.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Prompt constants for ICL."""
17 |
18 | ############################ Corpus Instruction ################################
19 | LONG_ICL_BENCH_DIALOGUE_RE_LABELS = [
20 | "per:alternate_names",
21 | "per:alumni",
22 | "per:place_of_residence",
23 | "per:employee_or_member_of",
24 | "per:girl/boyfriend",
25 | "per:title",
26 | "per:positive_impression",
27 | "gpe:residents_of_place",
28 | "org:employees_or_members",
29 | "per:children",
30 | "per:parents",
31 | "per:siblings",
32 | "per:spouse",
33 | "per:friends",
34 | "per:negative_impression",
35 | "per:client",
36 | "per:pet",
37 | "per:place_of_work",
38 | "per:boss",
39 | "per:subordinate",
40 | "per:acquaintance",
41 | "per:roommate",
42 | "per:dates",
43 | "per:other_family",
44 | "per:age",
45 | "per:visited_place",
46 | "gpe:visitors_of_place",
47 | "per:origin",
48 | "per:neighbor",
49 | "per:works",
50 | "per:schools_attended",
51 | "org:students",
52 | "per:major",
53 | "per:date_of_birth",
54 | ]
55 |
56 | CORPUS_INSTRUCTION = {
57 | "bbh": """
58 | Please answer the following questions and ensure you follow a consistent format. In particular, ensure your final answer always looks like `Output: ['your_answer_here']`.
59 | """.strip(),
60 | "long_icl_bench_dialogue_re": (
61 | """
62 | Given the dialogue, please find the name pair entities in the dialogue and their corresponding relation types in the strict format of the given examples. Please only strictly choose from the following relation types (note that the number of entities has to strictly have the same value as the number of respective relations):
63 | """.strip()
64 | + "\n"
65 | + "\n".join(LONG_ICL_BENCH_DIALOGUE_RE_LABELS)
66 | + "\n\n"
67 | + """Note that expected output is a series of relations from the provided relation list each in a new line.\nDo NOT include any information other than the relations separated by a new line.\nDo not print and markers.\nNote that the number of the relations should strictly match the number of entity pairs provided.\nYou can use the examples below to learn how to find the correct relations:
68 | """.strip()
69 | ),
70 | }
71 |
72 | CORPUS_FORMAT = {
73 | "bbh": "{input}\nOutput: {target}\n",
74 | "long_icl_bench_dialogue_re": """
75 | {target_proposition}\n{dialogue}\n\nThe list of {pair_length} entity pairs are:\n{pair_list}\nThe {pair_length} respective relations between each entity pairs are:{relation_list}{ending_notation}""".rstrip(),
76 | }
77 |
78 | TARGET_PROPOSITION = (
79 | "Now look at the dialogue below and mark the relations for the given"
80 | " entity pairs:\n\n"
81 | )
82 | ENDING_NOTATION = "\n\n\n"
83 |
--------------------------------------------------------------------------------
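A small sketch of how the BBH templates above compose into an in-context example; the input and target values are invented:

from prompts.constants import icl

shot = icl.CORPUS_FORMAT["bbh"].format(
    input="Q: Is the following statement plausible? ...",  # placeholder input
    target="['yes']",
)
# shot == "Q: Is the following statement plausible? ...\nOutput: ['yes']\n"
prompt = icl.CORPUS_INSTRUCTION["bbh"] + "\n\n" + shot

--------------------------------------------------------------------------------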
/prompts/constants/mm.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Prompt constants for multimodal retrieval."""
17 | from prompts.constants import common as common_constants
18 |
19 | # Fall back to a similar prompt if the dataset is not found in the constants.
20 | PROMPT_MAPPER = {
21 | "fleurs_en_tts": "fleurs_tts",
22 | "fleurs_es_tts": "fleurs_tts",
23 | "fleurs_hi_tts": "fleurs_tts",
24 | "fleurs_zh_tts": "fleurs_tts",
25 | "fleurs_fr_tts": "fleurs_tts",
26 | "fleurs_en_stt": "fleurs_stt",
27 | "fleurs_es_stt": "fleurs_stt",
28 | "fleurs_hi_stt": "fleurs_stt",
29 | "fleurs_zh_stt": "fleurs_stt",
30 | "fleurs_fr_stt": "fleurs_stt",
31 | }
32 |
33 | ############################ Corpus Instruction ################################
34 | CLOSED_BOOK_CORPUS_INSTRUCTION = {
35 | "oven": """
36 | You will be given an input image and a question related to the image, and your goal is to find the most relevant Wikipedia entry that can be used to best answer the question. Output the Wikipedia title.
37 | """.strip(),
38 | }
39 |
40 | CORPUS_INSTRUCTION = {
41 | "oven": """
42 | You will be given a list of Wikipedia entries, each containing a Wikipedia ID, Title and Description image. You need to watch carefully and memorize all of them. Then you will be given an input image and a question related to the image, and your goal is to find the most relevant Wikipedia entry from the list that can be used to best answer the question. First output the Wikipedia title, then output the Wikipedia ID.
43 | """.strip(),
44 | "msrvtt": """
45 | You will be given a list of videos, each containing a video ID and video content (presented as a sequence of images, with timestamps in text). You need to watch carefully and memorize all of them. Then you will be given a text query, and your goal is to find the most relevant video from the list that can best answer the query. Output the corresponding video ID.
46 | """.strip(),
47 | "flickr30k": """
48 | You will be given a list of images. You need to watch carefully and memorize all of them. Then you will be given a new sentence, and your goal is to find the most relevant image from the list for the given sentence. Print out the image index, which is presented before the image in the corpus.
49 | """.strip(),
50 | "mscoco": """
51 | You will be given a list of images. You need to watch carefully and memorize all of them. Then you will be given a new sentence, and your goal is to find the most relevant image from the list for the given sentence. Print out the image index, which is presented before the image in the corpus.
52 | """.strip(),
53 | "fleurs_tts": """
54 | You will be given a list of audio clips, each containing an Audio ID and the audio. You need to listen carefully and memorize all of them. Then you will be given a transcript, and your goal is to find the most relevant audio from the list that matches the given transcript. Print out the Audio ID of the audio presented in the list.
55 |
56 | AUDIO CORPUS
57 | """.strip(),
58 | "fleurs_stt": """
59 | You will be given a list of transcripts, each containing a Transcript ID and the transcript. You need to read carefully and memorize all of them. Then you will be given an audio clip, and your goal is to find the most relevant transcript from the list that matches the given audio. Print out the Transcript ID of the transcript presented in the list.
60 |
61 | TRANSCRIPT CORPUS
62 | """.strip(),
63 | }
64 | CORPUS_FORMAT = {
65 | "oven": "ID: {pid} | TITLE: {title}",
66 | "msrvtt": "Video ID: {pid}",
67 | "flickr30k": "{pid}",
68 | "mscoco": "{pid}",
69 | "fleurs_tts": "Audio ID: {pid}",
70 | "fleurs_stt": "Transcript ID: {pid} | Transcript: {passage}",
71 | }
72 | CORPUS_FORMAT_RAG = {
73 | "oven": "ID: {pid} | TITLE: {title} | DESCRIPTION: {passage} | IMAGE:",
74 | }
75 |
76 | ############################# Query Formats ####################################
77 | CLOSED_BOOK_QUERY_FORMAT_PREFIX = {
78 | "oven": """
79 | ====== Now let's start! ======
80 | Given an input Image and a Question, find the most relevant Wikipedia entry for the given question. First output the ID, then output the TITLE. Then format the Wikipedia IDs into a list in Final Answer.
81 | """.strip(),
82 | }
83 |
84 | QUERY_FORMAT_PREFIX = {
85 | "oven": """
86 | ====== Now let's start! ======
87 | Given an input Image and a Question, find the most relevant Wikipedia entry from the above list for the given question. First output the ID, then output the TITLE. Then format the Wikipedia IDs into a list in Final Answer.
88 | """,
89 | "msrvtt": """
90 | ====== Now let's start! ======
91 | Given the text query, find the most relevant video entry from the above list, and output the corresponding Video ID into a list in Final Answer.
92 | """,
93 | "flickr30k": """
94 | ====== Now let's start! ======
95 | Given a sentence, find the most relevant image from the list for the given sentence. Print out the image index, which is presented before the image in the corpus.
96 | """,
97 | "mscoco": """
98 | ====== Now let's start! ======
99 | Given a sentence, find the most relevant image from the list for the given sentence. Print out the image index, which is presented before the image in the corpus.
100 | """,
101 | "fleurs_tts": """
102 | ====== Now let's start! ======
103 | Given a transcript, find the most relevant audio from the above list of audio in the corpus. Print out the Audio ID of the audio and then format the Audio ID into a list in Final Answer.
104 | """,
105 | "fleurs_stt": """
106 | ====== Now let's start! ======
107 | Given an audio clip, find the most relevant transcript from the above list of transcripts in the corpus. Print out the Transcript ID of the transcript and then format the Transcript ID into a list in Final Answer.
108 | Audio:
109 | """,
110 | }
111 |
112 | CLOSED_BOOK_QUERY_FORMAT_SUFFIX = {
113 | "oven": """
114 | Question: {query}
115 | The following wikipedia entry can answer the question:
116 | """.strip(),
117 | }
118 |
119 | QUERY_FORMAT_SUFFIX = {
120 | "oven": """
121 | Question: {query}
122 | The following wikipedia entry can answer the question:
123 | """,
124 | "msrvtt": """
125 | Query: {query}
126 | """,
127 | "flickr30k": """
128 | Sentence: {query}
129 | The most relevant image is:
130 | """.strip(),
131 | "mscoco": """
132 | Sentence: {query}
133 | The most relevant image is:
134 | """.strip(),
135 | "fleurs_tts": "Transcript: {query}",
136 | "fleurs_stt": "",
137 | }
138 |
139 | ################################ Few-shot Formats ##############################
140 | FEW_SHOT_QUERY_FORMAT_PREFIX = {
141 | "oven": """
142 | ====== Example {example_id} ======
143 | Given an input Image and a Question, find the most relevant Wikipedia entry from the above list for the given question. First output the ID, then output the TITLE. Then format the Wikipedia IDs into a list in Final Answer.
144 | """.strip(),
145 | "msrvtt": """
146 | ====== Example {example_id} ======
147 | Given the text query, find the most relevant video from the above list and output the Video ID.
148 | """.strip(),
149 | "flickr30k": """
150 | ====== Example {example_id} ======
151 | Given a sentence, find the most relevant image from the list for the given sentence. Print out the image index, which is presented before the image in the corpus.
152 | """.strip(),
153 | "mscoco": """
154 | ====== Example {example_id} ======
155 | Given a sentence, find the most relevant image from the list for the given sentence. Print out the image index, which is presented before the image in the corpus.
156 | """.strip(),
157 | "fleurs_tts": """
158 | ====== Example {example_id} ======
159 | Given a transcript, find the most relevant audio from the above list of audio in the corpus. Print out the Audio ID of the audio and then format the Audio ID into a list in Final Answer.
160 | """.strip(),
161 | "fleurs_stt": """
162 | ====== Example {example_id} ======
163 | Given an audio clip, find the most relevant transcript from the above list of transcripts in the corpus. Print out the Transcript ID of the transcript and then format the Transcript ID into a list in Final Answer.
164 | Audio:
165 | """.strip(),
166 | }
167 |
168 | FEW_SHOT_QUERY_FORMAT_SUFFIX = {
169 | "oven": """Question: {query}
170 | The following wikipedia entry can answer the question:
171 | """.strip(),
172 | "msrvtt": """
173 | Query: {query}
174 | """.strip(),
175 | "flickr30k": """
176 | Sentence: {query}
177 | The most relevant image is:
178 | """.strip(),
179 | "mscoco": """
180 | Sentence: {query}
181 | The most relevant image is:
182 | """.strip(),
183 | "fleurs_tts": "Transcript: {query}",
184 | "fleurs_stt": "",
185 | }
186 |
187 | FEW_SHOT_EXAMPLE_ANSWER_FORMAT = {
188 | "oven": common_constants.FEW_SHOT_EXAMPLE_ANSWER_FORMAT_SIMPLE,
189 | "msrvtt": common_constants.FEW_SHOT_EXAMPLE_ANSWER_FORMAT_ID_ONLY,
190 | "flickr30k": common_constants.FEW_SHOT_EXAMPLE_ANSWER_FORMAT_ID_ONLY,
191 | "mscoco": common_constants.FEW_SHOT_EXAMPLE_ANSWER_FORMAT_ID_ONLY,
192 | "fleurs_tts": "Audio ID: {pid}",
193 | "fleurs_stt": "Transcript ID: {pid}"
194 | }
195 |
--------------------------------------------------------------------------------
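A minimal sketch of the fallback the comment at the top of this file describes; corpus_instruction is a hypothetical helper, not part of the repo:

from prompts.constants import mm

def corpus_instruction(dataset: str) -> str:
    # Per-language FLEURS datasets reuse the generic fleurs_tts / fleurs_stt
    # prompts when no dataset-specific entry exists.
    key = mm.PROMPT_MAPPER.get(dataset, dataset)
    return mm.CORPUS_INSTRUCTION[key]

assert corpus_instruction("fleurs_en_tts") == mm.CORPUS_INSTRUCTION["fleurs_tts"]

--------------------------------------------------------------------------------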
/prompts/constants/rag.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Prompt constants for RAG."""
17 |
18 | from prompts.constants import common
19 |
20 | # Fall back to a similar prompt if the dataset is not found in the constants.
21 | PROMPT_MAPPER = {
22 | "hotpotqa": "nq",
23 | "musique": "nq",
24 | "qampari": "nq",
25 | "quest": "nq",
26 | "topiocqa": "nq",
27 | "romqa": "nq",
28 | }
29 |
30 |
31 | ############################# Corpus Formats ###################################
32 |
33 | # All datasets have title.
34 | CORPUS_FORMAT_SIMPLE = {
35 | "nq": common.CORPUS_FORMAT,
36 | }
37 |
38 | CORPUS_FORMAT_ECHO = {
39 | "nq": common.CORPUS_FORMAT_ECHO,
40 | }
41 |
42 | CORPUS_FORMAT_REMOVED = {
43 | "nq": "",
44 | }
45 |
46 |
47 | CORPUS_FORMAT_RAG = {
48 | "nq": common.CORPUS_FORMAT_RAG,
49 | }
50 |
51 | ############################ Corpus Instruction ################################
52 |
53 | FORMATTING_INSTRUCTION = """
54 | Your final answer should be in a list, in the following format:
55 | Final Answer: ['answer1', 'answer2', ...]
56 | If there is only one answer, it should be in the format:
57 | Final Answer: ['answer']
58 | """
59 |
60 | CLOSED_BOOK_CORPUS_INSTRUCTION = {
61 | "nq": """
62 | You will be given a query, and your goal is to answer the query.
63 | """.strip(),
64 | }
65 |
66 | CORPUS_INSTRUCTION = {
67 | "nq": """
68 | You will be given a list of documents. You need to read carefully and understand all of them. Then you will be given a query, and your goal is to answer the query based on the documents you have read.
69 | """.strip(),
70 | }
71 |
72 | BASELINE_CORPUS_INSTRUCTION = {
73 | "nq": """
74 | You will be given a query and a list of documents. Your goal is to answer the query based on the documents you have read.
75 | """.strip(),
76 | }
77 |
78 | ############################# Query Formats ####################################
79 | CLOSED_BOOK_QUERY_FORMAT = {
80 | "nq": """
81 | ====== Now let's start! ======
82 | Can you answer the query? Format the answers into a list.
83 | query: {query}
84 | """.strip(),
85 | }
86 |
87 |
88 | QUERY_FORMAT_NO_COT = {
89 | "nq": """
90 | Based on the documents above, can you answer the following query? Format the answer into a list.
91 | query: {query}
92 | """.strip(),
93 | }
94 |
95 |
96 | QUERY_FORMAT_SIMPLE = {
97 | "nq": """
98 | Based on the documents above, can you answer the following query? Print out the ID and TITLE of the documents you use to answer. Then format the answers into a list.
99 | query: {query}
100 | """.strip(),
101 | }
102 |
103 | QUERY_FORMAT_RAG = {
104 | "nq": """
105 | Based on the documents above, can you answer the following query? Print out the passage number and TITLE of the documents you use to answer. Then format the answers into a list.
106 | query: {query}
107 | """.strip(),
108 | }
109 |
110 |
111 | QUERY_FORMAT_SIMPLE_REVERSE = {
112 | k: v.replace("ID and TITLE", "TITLE and ID").replace(
113 | "ID and CONTENT", "CONTENT and ID"
114 | )
115 | for k, v in QUERY_FORMAT_SIMPLE.items()
116 | }
117 |
118 | QUERY_FORMAT_WITH_COT = {
119 | "nq": """
120 | ====== Now let's start! ======
121 | Which document is most relevant to the query and can answer the query? Think step-by-step and then format the answers into a list.
122 | query: {query}
123 | """.strip(),
124 | }
125 |
126 | QUERY_FORMAT_NO_COT = {  # NOTE: rebinds (overrides) the definition above.
127 | "nq": """
128 | Based on the documents above, answer the following query. Format the answers into a list.
129 | query: {query}
130 | """.strip(),
131 | }
132 |
133 | QUERY_FORMAT_NO_COT_BASELINE = """
134 | ====== Now let's start! ======
135 | Based on the documents below, answer the following query. Format the answers into a list.
136 | query: {query}
137 | documents:
138 | """.strip()
139 |
140 | ################################ Few-shot Formats ##############################
141 | CLOSED_BOOK_FEW_SHOT_QUERY_FORMAT = {
142 | "nq": """
143 | ====== Example {example_id} ======
144 | Can you answer the query? Format the answers into a list.
145 | query: {query}
146 | """.strip(),
147 | }
148 |
149 | FEW_SHOT_QUERY_FORMAT_0 = {
150 | "nq": """
151 | ====== Example {example_id} ======
152 | Based on the documents above, can you answer the following query? Format the answers into a list.
153 | query: {query}
154 | """.strip(),
155 | }
156 |
157 | FEW_SHOT_QUERY_FORMAT_WITH_COT = {
158 | "nq": """
159 | ====== Example {example_id} ======
160 | Which document is most relevant to the query and can answer the query? Think step-by-step and then format the answers into a list.
161 | query: {query}
162 | """.strip(),
163 | }
164 |
165 | FEW_SHOT_QUERY_FORMAT_NO_COT_BASELINE = """
166 | ====== Example {example_id} ======
167 | Based on the documents below, answer the following query. Format the answers into a list.
168 | query: {query}
169 | documents:
170 | """.strip()
171 |
172 |
173 | ################## Few-Shot Example Reasoning Formats #######################
174 | FEW_SHOT_EXAMPLE_COT_FORMAT_SIMPLE = {
175 | "nq": common.FEW_SHOT_EXAMPLE_ANSWER_FORMAT_SIMPLE,
176 | }
177 | FEW_SHOT_EXAMPLE_COT_FORMAT_SIMPLE_REVERSE = {
178 | "nq": common.FEW_SHOT_EXAMPLE_ANSWER_FORMAT_SIMPLE_REVERSE,
179 | }
180 | FEW_SHOT_EXAMPLE_COT_FORMAT_RAG = {
181 | "nq": common.FEW_SHOT_EXAMPLE_ANSWER_RAG,
182 | }
183 |
--------------------------------------------------------------------------------
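An illustrative assembly of a single-document "nq" RAG prompt from the constants above; the actual composition lives in prompts_rag.py, so this is only a plausible sketch with invented values:

from prompts.constants import rag

docs = [
    rag.CORPUS_FORMAT_SIMPLE["nq"].format(
        pid="0", title="Hamlet", passage="Hamlet is a tragedy by Shakespeare."
    ),
]
prompt = "\n".join([
    rag.CORPUS_INSTRUCTION["nq"],
    *docs,
    rag.QUERY_FORMAT_SIMPLE["nq"].format(query="who wrote hamlet"),
])

--------------------------------------------------------------------------------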
/prompts/constants/retrieval.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Prompt constants for retrieval."""
17 |
18 | from prompts.constants import common
19 |
20 | # Fall back to a similar prompt if the dataset is not found in the constants.
21 | PROMPT_MAPPER = {
22 | "fever": "scifact",
23 | "msmarco": "fiqa",
24 | "musique": "hotpotqa",
25 | "qampari": "hotpotqa",
26 | "quest": "hotpotqa",
27 | "topiocqa": "nq",
28 | }
29 |
30 | ############################ Corpus Instruction ################################
31 |
32 | FORMATTING_INSTRUCTION = """
33 | Your final answer should be a list of IDs, in the following format:
34 | Final Answer: [id1, id2, ...]
35 | If there is only one ID, it should be in the format:
36 | Final Answer: [id1]
37 |
38 | If there is no perfect answer, output the closest one. Do not give an empty final answer.
39 | """
40 |
41 | CLOSED_BOOK_CORPUS_INSTRUCTION = {
42 | "hotpotqa": """
43 | You will be given a query, and your goal is to find the titles of Wikipedia articles that can help answer the query.
44 | """.strip(),
45 | "nq": """
46 | You will be given a query, and your goal is to find the titles of Wikipedia articles that can help answer the query.
47 | """.strip(),
48 | }
49 |
50 | CORPUS_INSTRUCTION = {
51 | "arguana": """
52 | You will be given a list of statements. You need to read carefully and understand all of them. Then you will be given a claim, and your goal is to find all statements from the list that can counterargue the claim.
53 | """.strip(),
54 | "fiqa": """
55 | You will be given a list of documents. You need to read carefully and understand all of them. Then you will be given a query, and your goal is to find all documents from the list that can help answer the query. Print out the ID and CONTENT of each document.
56 | """.strip(),
57 | "hotpotqa": """
58 | You will be given a list of documents. You need to read carefully and understand all of them. Then you will be given a query that may require you to use 1 or more documents to find the answer. Your goal is to find all documents from the list that can help answer the query. Print out the ID and TITLE of each document.
59 | """.strip(),
60 | "nq": """
61 | You will be given a list of documents. You need to read carefully and understand all of them. Then you will be given a query, and your goal is to find all documents from the list that can help answer the query. Print out the ID and TITLE of each document.
62 | """.strip(),
63 | "quora": """
64 | You will be given a list of questions. You need to read carefully and understand all of them. Then you will be given a new question, and your goal is to find all questions from the list that are near duplicates of the new question. Print out the ID and CONTENT of each question.
65 | """.strip(),
66 | "scifact": """
67 | You will be given a list of passages. You need to read carefully and understand all of them. Then you will be given a claim, and your goal is to find all passages from the list that can help verify the claim as true or false. Print out the ID and TITLE of each passage.
68 | """.strip(),
69 | "webis_touche2020": """
70 | You will be given a list of arguments. You need to read carefully and understand all of them. Then you will be given a controversial debating topic, and your goal is to find arguments from the list that are relevant to the topic. Print out the ID and TITLE of each argument.
71 | """.strip(),
72 | }
73 |
74 | ############################# Corpus Formats ###################################
75 | CORPUS_FORMAT_SIMPLE = {
76 | "arguana": common.CORPUS_FORMAT_NOTITLE,
77 | "fiqa": common.CORPUS_FORMAT_NOTITLE,
78 | "hotpotqa": common.CORPUS_FORMAT,
79 | "nq": common.CORPUS_FORMAT,
80 | "quora": common.CORPUS_FORMAT_NOTITLE,
81 | "scifact": common.CORPUS_FORMAT,
82 | "webis_touche2020": common.CORPUS_FORMAT,
83 | }
84 |
85 | CORPUS_FORMAT_ECHO = {
86 | "arguana": common.CORPUS_FORMAT_ECHO_NOTITLE,
87 | "fiqa": common.CORPUS_FORMAT_ECHO_NOTITLE,
88 | "hotpotqa": common.CORPUS_FORMAT_ECHO,
89 | "nq": common.CORPUS_FORMAT_ECHO,
90 | "quora": common.CORPUS_FORMAT_ECHO_NOTITLE,
91 | "scifact": common.CORPUS_FORMAT_ECHO,
92 | "webis_touche2020": common.CORPUS_FORMAT_ECHO,
93 | }
94 | ################## Few-Shot Example Reasoning Formats #######################
95 | FEW_SHOT_EXAMPLE_COT_FORMAT_SIMPLE = {
96 | "arguana": common.FEW_SHOT_EXAMPLE_ANSWER_FORMAT_SIMPLE_NOTITLE,
97 | "fiqa": common.FEW_SHOT_EXAMPLE_ANSWER_FORMAT_SIMPLE_NOTITLE,
98 | "hotpotqa": common.FEW_SHOT_EXAMPLE_ANSWER_FORMAT_SIMPLE,
99 | "nq": common.FEW_SHOT_EXAMPLE_ANSWER_FORMAT_SIMPLE,
100 | "quora": common.FEW_SHOT_EXAMPLE_ANSWER_FORMAT_SIMPLE_NOTITLE,
101 | "scifact": common.FEW_SHOT_EXAMPLE_ANSWER_FORMAT_SIMPLE,
102 | "webis_touche2020": common.FEW_SHOT_EXAMPLE_ANSWER_FORMAT_SIMPLE,
103 | }
104 | FEW_SHOT_EXAMPLE_COT_FORMAT_SIMPLE_REVERSE = {
105 | "arguana": common.FEW_SHOT_EXAMPLE_ANSWER_FORMAT_SIMPLE_REVERSE_NOTITLE,
106 | "fiqa": common.FEW_SHOT_EXAMPLE_ANSWER_FORMAT_SIMPLE_REVERSE_NOTITLE,
107 | "hotpotqa": common.FEW_SHOT_EXAMPLE_ANSWER_FORMAT_SIMPLE_REVERSE,
108 | "nq": common.FEW_SHOT_EXAMPLE_ANSWER_FORMAT_SIMPLE_REVERSE,
109 | "quora": common.FEW_SHOT_EXAMPLE_ANSWER_FORMAT_SIMPLE_REVERSE_NOTITLE,
110 | "scifact": common.FEW_SHOT_EXAMPLE_ANSWER_FORMAT_SIMPLE_REVERSE,
111 | "webis_touche2020": (
112 | common.FEW_SHOT_EXAMPLE_ANSWER_FORMAT_SIMPLE_REVERSE_NOTITLE
113 | ),
114 | }
115 |
116 | ################ Query Formats for few-shots and test queries ################
117 | CLOSED_BOOK_QUERY_FORMAT = {
118 | "hotpotqa": """
119 | Which Wikipedia article is most relevant to the query and can answer the query? You will need two articles to answer the query. Format the titles into a list.
120 | query: {query}
121 | The following articles can help answer the query:
122 | """.strip(),
123 | "nq": """
124 | Which Wikipedia article is most relevant to the query and can answer the query? Format the titles into a list.
125 | query: {query}
126 | The following articles can help answer the query:
127 | """.strip(),
128 | }
129 |
130 | QUERY_FORMAT_SIMPLE = {
131 | "arguana": """
132 | Given a claim, which statements provide a counterargument? Print out the ID and CONTENT of each statement. Then format the IDs into a list.
133 | If there is no perfect answer, output the closest one. Do not give an empty final answer.
134 | claim: {query}
135 | The following statements can counterargue the claim:
136 | """.strip(),
137 | "fiqa": """
138 | Which document is most relevant to answering the query? Print out the ID and CONTENT of the document. Then format the IDs into a list.
139 | If there is no perfect answer, output the closest one. Do not give an empty final answer.
140 | query: {query}
141 | The following documents can answer the query:
142 | """.strip(),
143 | "hotpotqa": """
144 | Which documents can help answer the query? Print out the ID and TITLE of each document. Then format the IDs into a list.
145 | If there is no perfect answer, output the closest one. Do not give an empty final answer.
146 | query: {query}
147 | The following documents can help answer the query:
148 | """.strip(),
149 | "nq": """
150 | Which document is most relevant to answering the query? Print out the ID and TITLE of the document. Then format the IDs into a list.
151 | If there is no perfect answer, output the closest one. Do not give an empty final answer.
152 | query: {query}
153 | The following documents can help answer the query:
154 | """.strip(),
155 | "quora": """
156 | Given the following query, which existing question is most similar to it? Print out the ID and CONTENT of the question. Then format the IDs into a list.
157 | If there is no perfect answer, output the closest one. Do not give an empty final answer.
158 | query: {query}
159 | The following existing questions are most similar to the given query:
160 | """.strip(),
161 | "scifact": """
162 | Which passage is most relevant to the claim, and can help verify the claim as true or false? Print out the ID and TITLE of the passage. Then format the IDs into a list.
163 | If there is no perfect answer, output the closest one. Do not give an empty final answer.
164 | claim: {query}
165 | The following passages can help verify this sentence:
166 | """.strip(),
167 | "webis_touche2020": """
168 | Which argument is most relevant to the query? Print out the ID and TITLE of the argument. Then format the IDs into a list.
169 | If there is no perfect answer, output the closest one. Do not give an empty final answer.
170 | query: {query}
171 | The following argument is most relevant to the query:
172 | """.strip(),
173 | }
174 |
175 | QUERY_FORMAT_SIMPLE_REVERSE = {
176 | k: v.replace("ID and TITLE", "TITLE and ID").replace(
177 | "ID and CONTENT", "CONTENT and ID"
178 | )
179 | for k, v in QUERY_FORMAT_SIMPLE.items()
180 | }
181 |
182 | QUERY_FORMAT_NO_COT = {
183 | k: v.replace("ID and TITLE", "ID").replace("ID and CONTENT", "ID")
184 | for k, v in QUERY_FORMAT_SIMPLE.items()
185 | }
186 |
187 |
188 | QUERY_FORMAT_WITH_COT = {
189 | "hotpotqa": """
190 | Which documents are relevant to answering the query? Let's think step-by-step to find two relevant documents. Then format the IDs into a list.
191 | If there is no perfect answer, output the closest one. Do not give an empty final answer.
192 | query: {query}
193 | """.strip(),
194 | "nq": """
195 | Which document is most relevant to answering the query? Think step-by-step and then format the document IDs into a list.
196 | If there is no perfect answer, output the closest one. Do not give an empty final answer.
197 | query: {query}
198 | The following documents can help answer the query:
199 | """.strip(),
200 | }
201 |
--------------------------------------------------------------------------------
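The same PROMPT_MAPPER fallback convention applies here; a small sketch (dataset and query invented for illustration) of resolving a dataset without its own templates:

from prompts.constants import retrieval

# "quest" has no templates of its own; PROMPT_MAPPER redirects it to the
# "hotpotqa" prompts.
dataset = "quest"
key = retrieval.PROMPT_MAPPER.get(dataset, dataset)
query_block = retrieval.QUERY_FORMAT_SIMPLE[key].format(
    query="novels written by Jane Austen"
)

--------------------------------------------------------------------------------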
/prompts/constants/sql.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Prompt constants for SQL."""
17 |
18 | # Fall back to a similar prompt if the dataset is not found in the constants.
19 | PROMPT_MAPPER = {
20 | "sparc": "spider",
21 | }
22 |
23 | ############################ Corpus Instruction ################################
24 | CORPUS_INSTRUCTION = {
25 | "spider": """
26 | You will be given a list of tables. You need to memorize all of the rows of each table. Then you will be given a query, and your goal is to get the answer from the tables. Then format the answer into a list of lists. When formatting the answer into a list of lists, make sure you use the exact fields that are provided in the tables.
27 | """.strip(),
28 | }
29 |
30 | CORPUS_INSTRUCTION_TEXT2SQL_BASELINE = {"spider": """
31 | You will be given a list of SQL tables. Then you will be given a query. Your goal is to write the SQL query to get the answer from the tables. Do not include any explanation.
32 | """.strip()}
33 |
34 | CORPUS_FORMAT = {"spider": "Table: {title}\n{passage}"}
35 |
36 | CORPUS_FORMAT_TEXT2SQL_BASELINE = {"spider": "{passage}"}
37 |
38 | ############################# Query Formats ####################################
39 | QUERY_FORMAT_0 = {
40 | "spider": """
41 | ====== Now let's start! ======
42 | Given a query, find from the following tables all the information relevant to the query. Then answer the query. Then format the answer into a list of lists.
43 | TABLES
44 | {corpus}
45 |
46 | Query: {query}
47 | Answer: Here's a step-by-step approach using the provided tables:
48 | """.strip(),
49 | }
50 |
51 | FOLLOW_UP_QUERY_FORMAT_0 = {
52 | "spider": """
53 | ====== Now let's start! ======
54 | Given a query, find from the following tables all the information relevant to the query. Then answer the query. Then format the answer into a list of lists.
55 |
56 | Query: {query}
57 |
58 | Answer:
59 | """.strip(),
60 | }
61 |
62 |
63 | QUERY_NO_COT_FORMAT_0 = {
64 | "spider": """
65 | ====== Now let's start! ======
66 | Given a query, find from the following tables all the information relevant to the query. Then answer the query. Then format the answer into a list of lists.
67 | TABLES
68 | {corpus}
69 |
70 | Query: {query}
71 | Answer: Here is the final answer.
72 | """.strip(),
73 | }
74 |
75 | FOLLOW_UP_QUERY_NO_COT_FORMAT_0 = {
76 | "spider": """
77 | ====== Now let's start! ======
78 | Given a query, find from the following tables all the information relevant to the query. Then answer the query. Then format the answer into a list of lists.
79 |
80 | Query: {query}
81 | Answer: Here is the final answer.
82 | """.strip(),
83 | }
84 |
85 | QUERY_FORMAT_TEXT2SQL_BASELINE = {
86 | "spider": """
87 | ====== Now let's start! ======
88 | Given the following database schema, answer the query. Do not include any explanation.
89 |
90 | {corpus}
91 |
92 | Query: {query}
93 |
94 | Answer:
95 | """.strip(),
96 | }
97 |
98 | FOLLOW_UP_QUERY_FORMAT_TEXT2SQL_BASELINE = {
99 | "spider": """
100 | ====== Now let's start! ======
101 | Given the following database schema, answer the query. Do not include any explanation.
102 |
103 | Query: {query}
104 |
105 | Answer:
106 | """.strip(),
107 | }
108 |
109 | ################################ Few-shot Formats ##############################
110 | FEW_SHOT_EXAMPLES_V0 = {"spider": ["""
111 | ====== Example 1 ======
112 | Given a query, find from the following tables all the information relevant to the query. Then answer the query. Then format the answer into a list of lists.
113 | TABLES
114 | Table: Concert
115 | concert_ID,concert_Name,Theme,Stadium_ID,Year
116 | 1,Auditions,Free choice,1,2014
117 | 2,Super bootcamp,Free choice 2,2,2014
118 | 3,Home Visits,Bleeding Love,2,2015
119 | 4,Week 1,Wide Awake,10,2014
120 | 5,Week 1,Happy Tonight,9,2015
121 | 6,Week 2,Party All Night,7,2015
122 |
123 | Table: Singer
124 | Singer_ID,Name,Country,Song_Name,Song_release_year,Age,Is_male
125 | 1,Joe Sharp,Netherlands,You,1992,52,F
126 | 2,Timbaland,United States,Dangerous,2008,32,T
127 | 3,Justin Brown,France,Hey Oh,2013,29,T
128 | 4,Rose White,France,Sun,2003,41,F
129 | 5,John Nizinik,France,Gentleman,2014,43,T
130 | 6,Tribal King,France,Love,2016,25,T
131 |
132 | Table: Singer_in_Concert
133 | concert_ID,Singer_ID
134 | 1,2
135 | 1,3
136 | 1,5
137 | 2,3
138 | 2,6
139 | 3,5
140 | 4,4
141 | 5,6
142 | 5,3
143 | 6,2
144 |
145 | Table: Stadium
146 | Stadium_ID,Location,Name,Capacity,Highest,Lowest,Average
147 | 1,Raith Rovers,Stark's Park,10104,4812,1294,2106
148 | 2,Ayr United,Somerset Park,11998,2363,1057,1477
149 | 3,East Fife,Bayview Stadium,2000,1980,533,864
150 | 4,Queen's Park,Hampden Park,52500,1763,466,730
151 | 5,Stirling Albion,Forthbank Stadium,3808,1125,404,642
152 | 6,Arbroath,Gayfield Park,4125,921,411,638
153 | 7,Alloa Athletic,Recreation Park,3100,1057,331,637
154 | 9,Peterhead,Balmoor,4000,837,400,615
155 | 10,Brechin City,Glebe Park,3960,780,315,552
156 |
157 | Query: Show the stadium name and the number of concerts in each stadium.
158 |
159 | Answer: Here's a step-by-step approach using the provided tables:
160 |
161 | **1. Access relevant data:**
162 | We need information from two tables:
163 | * **Concert:** This table contains details about each concert, including the stadium ID where it was held.
164 | * **Stadium:** This table provides information about each stadium, including its name.
165 |
166 | **2. Combine data based on stadium ID:**
167 | We need to link the concert data with the corresponding stadium information. This can be done by joining the "Concert" and "Stadium" tables based on the common column "Stadium_ID".
168 |
169 | **3. Get concerts per stadium:**
170 | After joining the tables, we can group the data by stadium name and get the concerts associated with each stadium.
171 | Here are the concert_ID associated with each stadium name:
172 | * Stark's Park: 1
173 | * Somerset Park: 2, 3
174 | * Glebe Park: 4
175 | * Recreation Park: 6
176 | * Balmoor: 5
177 |
178 | **Note:** Stadiums with no associated concerts are not included.
179 |
180 | **4. Count concerts per stadium:**
181 | After we have gotten the concerts associated with each stadium we can count how many concerts were in each stadium.
182 | * Stark's Park: 1
183 | * Somerset Park: 2
184 | * Glebe Park: 1
185 | * Recreation Park: 1
186 | * Balmoor: 1
187 |
188 | **5. Present the results:**
189 | The final output will be a table showing each stadium name and the corresponding number of concerts held there.
190 |
191 | **Based on the data provided, here's the breakdown of concerts per stadium:**
192 |
193 | | Stadium Name | Number of Concerts |
194 | |---|---|
195 | | Stark's Park | 1 |
196 | | Somerset Park | 2 |
197 | | Glebe Park | 1 |
198 | | Recreation Park | 1 |
199 | | Balmoor | 1 |
200 |
201 | Final Answer: [["Stark's Park", 1], ["Somerset Park", 2], ["Glebe Park", 1], ["Recreation Park", 1], ["Balmoor", 1]]
202 | """.strip()]}
203 |
204 | FEW_SHOT_EXAMPLES_V1 = {"spider": ["""====== Example 1 ======
205 | Given a query, find from the following tables all the information relevant to the query. Then answer the query. Then format the answer into a list of lists.
206 | TABLES
207 | Table: Concert
208 | concert_ID,concert_Name,Theme,Stadium_ID,Year
209 | 1,Auditions,Free choice,1,2014
210 | 2,Super bootcamp,Free choice 2,2,2014
211 | 3,Home Visits,Bleeding Love,2,2015
212 | 4,Week 1,Wide Awake,10,2014
213 | 5,Week 1,Happy Tonight,9,2015
214 | 6,Week 2,Party All Night,7,2015
215 |
216 | Table: Singer
217 | Singer_ID,Name,Country,Song_Name,Song_release_year,Age,Is_male
218 | 1,Joe Sharp,Netherlands,You,1992,52,F
219 | 2,Timbaland,United States,Dangerous,2008,32,T
220 | 3,Justin Brown,France,Hey Oh,2013,29,T
221 | 4,Rose White,France,Sun,2003,41,F
222 | 5,John Nizinik,France,Gentleman,2014,43,T
223 | 6,Tribal King,France,Love,2016,25,T
224 |
225 | Table: Singer_in_Concert
226 | concert_ID,Singer_ID
227 | 1,2
228 | 1,3
229 | 1,5
230 | 2,3
231 | 2,6
232 | 3,5
233 | 4,4
234 | 5,6
235 | 5,3
236 | 6,2
237 |
238 | Table: Stadium
239 | Stadium_ID,Location,Name,Capacity,Highest,Lowest,Average
240 | 1,Raith Rovers,Stark's Park,10104,4812,1294,2106
241 | 2,Ayr United,Somerset Park,11998,2363,1057,1477
242 | 3,East Fife,Bayview Stadium,2000,1980,533,864
243 | 4,Queen's Park,Hampden Park,52500,1763,466,730
244 | 5,Stirling Albion,Forthbank Stadium,3808,1125,404,642
245 | 6,Arbroath,Gayfield Park,4125,921,411,638
246 | 7,Alloa Athletic,Recreation Park,3100,1057,331,637
247 | 9,Peterhead,Balmoor,4000,837,400,615
248 | 10,Brechin City,Glebe Park,3960,780,315,552
249 |
250 | Query: What is the total number of singers?
251 | Answer: Here's a step-by-step approach using the provided tables:
252 |
253 | **1. Access relevant data:**
254 | We need information from the "Singer" table, which stores details about each singer.
255 |
256 | **2. Get the singers:**
257 | We can directly count the number of rows in the "Singer" table. Each row represents a unique singer.
258 |
259 | **Based on the data provided, the "Singer" table has the singers Joe Sharp, Timbaland, Justin Brown, Rose White, John Nizinik, and Tribal King.**
260 |
261 | **3. Count the number of singers:**
262 | There are 6 singers in the table.
263 |
264 | Final Answer: [[6]]
265 |
266 | Query: What is the name of the singer who has a song with 'Hey' in its name?
267 | Answer: Here's a step-by-step approach using the provided tables:
268 |
269 | **1. Identify relevant data:**
270 | We need to look at the "Song_Name" column in the "Singer" table to find songs with "Hey" in their names.
271 |
272 | **2. Search for matching songs:**
273 | Scan through the "Song_Name" column and identify the song that contains "Hey."
274 |
275 | **Based on the provided data, the song "Hey Oh" is the only one with "Hey" in its name.**
276 |
277 | **3. Find the corresponding singer:**
278 | Once you've identified the song, look at the "Name" column in the same row to find the singer's name.
279 |
280 | **The singer associated with "Hey Oh" is Justin Brown.**
281 |
282 | Final Answer: [["Justin Brown"]]
283 |
284 | Query: Show the stadium name and the number of concerts in each stadium.
285 | Answer: Here's a step-by-step approach using the provided tables:
286 |
287 | **1. Access relevant data:**
288 | We need information from two tables:
289 | * **Concert:** This table contains details about each concert, including the stadium ID where it was held.
290 | * **Stadium:** This table provides information about each stadium, including its name.
291 |
292 | **2. Combine data based on stadium ID:**
293 | We need to link the concert data with the corresponding stadium information. This can be done by joining the "Concert" and "Stadium" tables based on the common column "Stadium_ID".
294 |
295 | **3. Get concerts per stadium:**
296 | After joining the tables, we can group the data by stadium name and get the concerts associated with each stadium.
297 | Here are the concert_ID associated with each stadium name:
298 | * Stark's Park: 1
299 | * Somerset Park: 2, 3
300 | * Glebe Park: 4
301 | * Recreation Park: 6
302 | * Balmoor: 5
303 |
304 | **Note:** Stadiums with no associated concerts are not included.
305 |
306 | **4. Count concerts per stadium:**
307 | After we have gotten the concerts associated with each stadium we can count how many concerts were in each stadium.
308 | * Stark's Park: 1
309 | * Somerset Park: 2
310 | * Glebe Park: 1
311 | * Recreation Park: 1
312 | * Balmoor: 1
313 |
314 | **5. Present the results:**
315 | The final output will be a table showing each stadium name and the corresponding number of concerts held there.
316 |
317 | **Based on the data provided, here's the breakdown of concerts per stadium:**
318 |
319 | | Stadium Name | Number of Concerts |
320 | |---|---|
321 | | Stark's Park | 1 |
322 | | Somerset Park | 2 |
323 | | Glebe Park | 1 |
324 | | Recreation Park | 1 |
325 | | Balmoor | 1 |
326 |
327 | Final Answer: [["Stark's Park", 1], ["Somerset Park", 2], ["Glebe Park", 1], ["Recreation Park", 1], ["Balmoor", 1]]
328 | """.strip()]}
329 |
330 | FEW_SHOT_EXAMPLES_TEXT2SQL_BASELINE = {"spider": ["""
331 | ====== Example 1 ======
332 | Given the following database schema, answer the query. Do not include any explanation.
333 |
334 | CREATE TABLE IF NOT EXISTS "stadium" (
335 | "Stadium_ID" int,
336 | "Location" text,
337 | "Name" text,
338 | "Capacity" int,
339 | "Highest" int,
340 | "Lowest" int,
341 | "Average" int,
342 | PRIMARY KEY ("Stadium_ID")
343 | );
344 |
345 | CREATE TABLE IF NOT EXISTS "singer" (
346 | "Singer_ID" int,
347 | "Name" text,
348 | "Country" text,
349 | "Song_Name" text,
350 | "Song_release_year" text,
351 | "Age" int,
352 | "Is_male" bool,
353 | PRIMARY KEY ("Singer_ID")
354 | );
355 |
356 | CREATE TABLE IF NOT EXISTS "concert" (
357 | "concert_ID" int,
358 | "concert_Name" text,
359 | "Theme" text,
360 | "Stadium_ID" text,
361 | "Year" text,
362 | PRIMARY KEY ("concert_ID"),
363 | FOREIGN KEY ("Stadium_ID") REFERENCES "stadium"("Stadium_ID")
364 | );
365 |
366 | CREATE TABLE IF NOT EXISTS "singer_in_concert" (
367 | "concert_ID" int,
368 | "Singer_ID" text,
369 | PRIMARY KEY ("concert_ID","Singer_ID"),
370 | FOREIGN KEY ("concert_ID") REFERENCES "concert"("concert_ID"),
371 | FOREIGN KEY ("Singer_ID") REFERENCES "singer"("Singer_ID")
372 | );
373 |
374 | Query: What is the total number of singers?
375 |
376 | Answer: SELECT count(*) FROM singer
377 |
378 | Query: What is the name of the singer who has a song with 'Hey' in its name?
379 |
380 | Answer: SELECT Name FROM singer WHERE Song_Name LIKE '%Hey%'
381 |
382 | Query: Show the stadium name and the number of concerts in each stadium.
383 |
384 | Answer: SELECT stadium.name, count(*) FROM concert JOIN stadium ON concert.stadium_id = stadium.stadium_id GROUP BY concert.stadium_id
385 | """.strip()]}
386 |
387 | FEW_SHOT_NO_COT_EXAMPLES_V0 = {"spider": ["""====== Example 1 ======
388 | Given a query, find from the following tables all the information relevant to the query. Then answer the query by formatting the answer into a list of lists.
389 | TABLES
390 | Table: Concert
391 | concert_ID,concert_Name,Theme,Stadium_ID,Year
392 | 1,Auditions,Free choice,1,2014
393 | 2,Super bootcamp,Free choice 2,2,2014
394 | 3,Home Visits,Bleeding Love,2,2015
395 | 4,Week 1,Wide Awake,10,2014
396 | 5,Week 1,Happy Tonight,9,2015
397 | 6,Week 2,Party All Night,7,2015
398 |
399 | Table: Singer
400 | Singer_ID,Name,Country,Song_Name,Song_release_year,Age,Is_male
401 | 1,Joe Sharp,Netherlands,You,1992,52,F
402 | 2,Timbaland,United States,Dangerous,2008,32,T
403 | 3,Justin Brown,France,Hey Oh,2013,29,T
404 | 4,Rose White,France,Sun,2003,41,F
405 | 5,John Nizinik,France,Gentleman,2014,43,T
406 | 6,Tribal King,France,Love,2016,25,T
407 |
408 | Table: Singer_in_Concert
409 | concert_ID,Singer_ID
410 | 1,2
411 | 1,3
412 | 1,5
413 | 2,3
414 | 2,6
415 | 3,5
416 | 4,4
417 | 5,6
418 | 5,3
419 | 6,2
420 |
421 | Table: Stadium
422 | Stadium_ID,Location,Name,Capacity,Highest,Lowest,Average
423 | 1,Raith Rovers,Stark's Park,10104,4812,1294,2106
424 | 2,Ayr United,Somerset Park,11998,2363,1057,1477
425 | 3,East Fife,Bayview Stadium,2000,1980,533,864
426 | 4,Queen's Park,Hampden Park,52500,1763,466,730
427 | 5,Stirling Albion,Forthbank Stadium,3808,1125,404,642
428 | 6,Arbroath,Gayfield Park,4125,921,411,638
429 | 7,Alloa Athletic,Recreation Park,3100,1057,331,637
430 | 9,Peterhead,Balmoor,4000,837,400,615
431 | 10,Brechin City,Glebe Park,3960,780,315,552
432 |
433 | Query: What is the total number of singers?
434 | Answer: Here is the final answer.
435 | Final Answer: [[6]]
436 |
437 | Query: What is the name of the singer who has a song with 'Hey' in its name?
438 | Answer: Here is the final answer.
439 | Final Answer: [["Justin Brown"]]
440 |
441 | Query: Show the stadium name and the number of concerts in each stadium.
442 | Answer: Here is the final answer.
443 | Final Answer: [["Stark's Park", 1], ["Somerset Park", 2], ["Glebe Park", 1], ["Recreation Park", 1], ["Balmoor", 1]]
444 | """.strip()]}
445 |
--------------------------------------------------------------------------------
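A plausible assembly of a full Spider prompt from the constants above; the actual composition lives in prompts_sql.py, and the serialized table here is a one-table stand-in for the CSV-serialized LOFT corpus:

from prompts.constants import sql

tables = "Table: Singer\nSinger_ID,Name\n1,Joe Sharp\n2,Timbaland"
prompt = "\n\n".join([
    sql.CORPUS_INSTRUCTION["spider"],
    sql.FEW_SHOT_EXAMPLES_V1["spider"][0],
    sql.QUERY_FORMAT_0["spider"].format(
        corpus=tables, query="What is the total number of singers?"
    ),
])

--------------------------------------------------------------------------------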
/prompts/prompt_registry.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Define PromptRegistry class to store prompts.
17 |
18 | PromptRegistry is a class that stores prompts based on the LOFT data.
19 | """
20 |
21 | from collections.abc import Callable, Sequence
22 | import copy
23 | import dataclasses
24 | import random
25 | from typing import Any, Optional
26 | import utils
27 |
28 | ContentChunk = utils.ContentChunk
29 | QueryTurn = utils.QueryTurn
30 | Example = utils.Example
31 |
32 |
33 | @dataclasses.dataclass(frozen=True)
34 | class LOFTPrompt:
35 | """Prompt for LOFT."""
36 |
37 | data_dir: str
38 | split: str
39 | data_loader: Callable[..., Any] = utils.load_data_from_file
40 | context_processors: Optional[Sequence[Callable[..., Any]]] = None
41 | query_turn_processors: Optional[Sequence[Callable[..., Any]]] = None
42 | gold_answer_processors: Optional[Sequence[Callable[..., Any]]] = None
43 | share_context: bool = True
44 | use_context_first: bool = True
45 | append_gold_answers_to_query_turns: bool = False
46 | is_multi_turn: bool = False
47 | # Whether the corpus is cacheable for that particular prompt.
48 | # Defaults to False since that is safer (always correct).
49 | cacheable_corpus: bool = False
50 |
51 |
52 | class PromptRegistry:
53 | """Class to store prompts based on the LOFT data."""
54 |
55 | prompts = {}
56 | base_dir = ""
57 |
58 | @classmethod
59 | def add(
60 | cls,
61 | name: str,
62 | data_dir: str,
63 | split: str,
64 | data_loader: Callable[..., Any] = utils.load_data_from_file,
65 | context_processors: Optional[Sequence[Callable[..., Any]]] = None,
66 | query_turn_processors: Optional[Sequence[Callable[..., Any]]] = None,
67 | gold_answer_processors: Optional[Sequence[Callable[..., Any]]] = None,
68 | share_context: bool = True,
69 | use_context_first: bool = True,
70 | append_gold_answers_to_query_turns: bool = False,
71 | is_multi_turn: bool = False,
72 | cacheable_corpus: bool = False,
73 | ):
74 | """Adds a prompt to the registry."""
75 | if name in cls.prompts:
76 | raise ValueError(f"Prompt {name} already exists in the prompt registry.")
77 |
78 | cls.prompts[name] = LOFTPrompt(
79 | data_dir=data_dir,
80 | split=split,
81 | data_loader=data_loader,
82 | context_processors=context_processors,
83 | query_turn_processors=query_turn_processors,
84 | gold_answer_processors=gold_answer_processors,
85 | share_context=share_context,
86 | use_context_first=use_context_first,
87 | append_gold_answers_to_query_turns=append_gold_answers_to_query_turns,
88 | is_multi_turn=is_multi_turn,
89 | cacheable_corpus=cacheable_corpus,
90 | )
91 |
92 | @classmethod
93 | def get_examples(
94 | cls,
95 | name: str,
96 | base_dir: Optional[str] = None,
97 | max_examples: Optional[int] = None,
98 | loft_data: Optional[utils.LOFTData] = None,
99 | **kwargs,
100 | ) -> list[Example]:
101 | """Returns the examples for a given prompt name."""
102 | shuffle_queries = kwargs.get("shuffle_queries", False)
103 |
104 | if name not in cls.prompts:
105 | task_name = name.split("_")[0]
106 | registry_str = "\n".join(
107 | filter(
108 | lambda x: x.startswith(task_name),
109 | list(cls.prompts.keys()),
110 | )
111 | )
112 | raise ValueError(
113 | f"Prompt {name} not found in registry.\nAvailable"
114 | f" prompts:\n{registry_str}"
115 | )
116 | loft_prompt = cls.prompts[name]
117 | examples = []
118 |
119 | if loft_data is None:
120 | if not base_dir:
121 | base_dir = cls.base_dir
122 | else:
123 | cls.base_dir = base_dir
124 | # 1. Load the LOFT data if not provided.
125 | loft_data = loft_prompt.data_loader(
126 | data_dir=loft_prompt.data_dir,
127 | base_dir=base_dir,
128 | split=loft_prompt.split,
129 | )
130 |
131 | # 2. For each query, create an example.
132 | # NOTE: Each chunk (query or context) is added in-place to the list.
133 | context_chunks: list[ContentChunk] = []
134 | corpus_document_boundaries: list[tuple[int, int]] = []
135 |     queries = list(loft_data.queries.keys())
136 |
137 |     if shuffle_queries:
138 |       random.shuffle(queries)
139 |
140 |     # Locates where a particular passage with a given pid is in the context.
141 |     # Will be updated if fixed_pid_mapper is given for context processors.
142 |     # NOTE: fixed_pid_mapper given to add_corpus_chunks can update pid_mapper.
143 |     pid_mapper: dict[str, str] = {
144 |         pid: str(p_idx) for p_idx, pid in enumerate(loft_data.corpus)
145 |     }
146 | loft_data.metadata["pid_mapper"] = pid_mapper
147 |
148 | for example_idx, qid in enumerate(queries):
149 | # Each query turn can be a list of query chunks (e.g. image + text).
150 | query_turns: list[list[ContentChunk]] = []
151 | gold_pids: list[Any] = []
152 | gold_answers: list[Any] = copy.deepcopy(loft_data.answers[qid])
153 | if not loft_prompt.is_multi_turn:
154 | if loft_data.metadata.values() and "qrels" in next(
155 | iter(loft_data.metadata.values())
156 | ):
157 | gold_pids = copy.deepcopy(
158 | [pid for pid, _ in loft_data.metadata[qid]["qrels"]]
159 | )
160 | else:
161 | if (
162 | isinstance(loft_data.answers[qid], list)
163 |             and all(len(ans) == 2 for ans in loft_data.answers[qid])
164 | ):
165 | gold_pids = copy.deepcopy(
166 | [pid for pid, _ in loft_data.answers[qid]]
167 | )
168 | # Make these as a single turn.
169 | gold_pids = [gold_pids]
170 | gold_answers = [gold_answers]
171 | else:
172 | if loft_data.metadata.values() and "qrels" in next(
173 | iter(loft_data.metadata.values())
174 | ):
175 | gold_pids = copy.deepcopy([
176 | [pid for pid, _ in qrels]
177 | for qrels in loft_data.metadata[qid]["qrels"]
178 | ])
179 | gold_answers = copy.deepcopy(
180 | [[gold_answer] for gold_answer in gold_answers]
181 | )
182 |
183 | # 2-1. Process the context chunks.
184 | if loft_prompt.context_processors:
185 | if (
186 | not loft_prompt.cacheable_corpus
187 | or not context_chunks
188 | or not loft_prompt.share_context
189 | ):
190 | context_chunks: list[ContentChunk] = [] # Reset context chunks.
191 | corpus_document_boundaries: list[tuple[int, int]] = (
192 | []
193 | ) # Reset corpus document boundaries.
194 |         for context_processor in loft_prompt.context_processors:
195 | context_processor(
196 | chunks=context_chunks,
197 | pid_mapper=pid_mapper,
198 | gold_pids=gold_pids,
199 | loft_data=loft_data,
200 | qid=qid,
201 | corpus_document_boundaries=corpus_document_boundaries,
202 | **kwargs,
203 | )
204 |
205 | # 2-2. Process the gold answers.
206 | # NOTE: We need to process the gold answers first for the cases where we
207 | # need to append gold answers to query turns.
208 | if loft_prompt.gold_answer_processors:
209 |       for gold_answer_processor in loft_prompt.gold_answer_processors:
210 | gold_answer_processor(
211 | query_turns=query_turns,
212 | gold_answers=gold_answers,
213 | loft_data=loft_data,
214 | qid=qid,
215 | pid_mapper=pid_mapper,
216 | )
217 |
218 | # 2-3. Process the query chunks.
219 | if loft_prompt.query_turn_processors:
220 |       for query_turn_processor in loft_prompt.query_turn_processors:
221 | query_turn_processor(
222 | query_turns=query_turns,
223 | # Needed for the reasoning chain.
224 | gold_pids=gold_pids,
225 | gold_answers=gold_answers,
226 | loft_data=loft_data,
227 | qid=qid,
228 | example_id=str(example_idx + 1),
229 | pid_mapper=pid_mapper,
230 | )
231 |
232 | # 2-4. Create the example.
233 | examples.append(
234 | Example(
235 | qid=qid,
236 | num_turns=len(query_turns),
237 | context_chunks=context_chunks,
238 | # Lazily convert to QueryTurn objects.
239 | query_turns=[QueryTurn(chunks=chunks) for chunks in query_turns],
240 | gold_answers=gold_answers,
241 | gold_pids=gold_pids,
242 | corpus_document_boundaries=corpus_document_boundaries,
243 | )
244 | )
245 | if max_examples and len(examples) >= max_examples:
246 | break
247 |
248 | return examples
249 |
--------------------------------------------------------------------------------
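
For reference, a minimal usage sketch for the registry above. It assumes the LOFT data has already been downloaded to `./data` (e.g., via download.sh); the prompt name used is one that `prompts/prompts_retrieval.py` below actually registers.

```python
# Minimal sketch: fetch examples for a registered prompt.
import prompts  # noqa: F401  # Importing registers all prompts_*.py modules.
from prompts import prompt_registry

examples = prompt_registry.PromptRegistry.get_examples(
    name="retrieval_nq_32k_dev:few_shot_with_cot",
    base_dir="./data",  # Assumed download location of the LOFT data.
    max_examples=2,
)
for example in examples:
  print(example.qid, example.num_turns, len(example.context_chunks))
```
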
/prompts/prompts_icl.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Register prompts for Many-shot ICL."""
17 |
18 | import functools
19 |
20 | from prompts import prompt_registry
21 | from prompts import utils as prompt_utils
22 | from prompts.constants import icl as task_constants
23 | import utils
24 |
25 |
26 | PromptRegistry = prompt_registry.PromptRegistry
27 | TASK = 'icl'
28 |
29 | BBH_TASK_LENGTH = {
30 | 'date_understanding': 23578,
31 | 'salient_translation_error_detection': 28299,
32 | 'tracking_shuffled_objects_seven_objects': 28360,
33 | 'web_of_lies': 10380,
34 | }
35 |
36 | LENGTHS = ('32k', '128k', '1m')
37 | SPLITS = ('dev', 'test')
38 | LENGTH_TO_MAX_TOKENS = {
39 | '2k': 2_000,
40 | '4k': 4_000,
41 | '8k': 8_000,
42 | '16k': 16_000,
43 | '32k': 32_000,
44 | '128k': 128_000,
45 | '200k': 200_000,
46 | '1m': 1_000_000,
47 | }
48 |
49 | dataset = 'bbh'
50 | for length in ['32k', '16k', '8k', '4k', '2k']:
51 | for subtask_name, subtask_length in BBH_TASK_LENGTH.items():
52 | data_dir = f'{TASK}/{subtask_name}/{length}'
53 | max_tokens = LENGTH_TO_MAX_TOKENS[length]
54 | dataset_name = f'{subtask_name}'
55 | for split in SPLITS:
56 | PromptRegistry.add(
57 | name=f'{TASK}_{dataset_name}_{length}_{split}:many_shot',
58 | data_dir=data_dir,
59 | split=split,
60 | cacheable_corpus=True,
61 | data_loader=utils.load_bbh_data_from_file,
62 | context_processors=[
63 | functools.partial(
64 | prompt_utils.add_text_chunks,
65 | texts=[
66 | task_constants.CORPUS_INSTRUCTION[dataset],
67 | ],
68 | ),
69 | functools.partial(
70 | prompt_utils.add_many_shot_chunks,
71 | chunk_format_fn=prompt_utils.get_bbh_example_chunk,
72 | corpus_format=task_constants.CORPUS_FORMAT[dataset],
73 | ),
74 | ],
75 | query_turn_processors=[
76 | functools.partial(
77 | prompt_utils.add_query_turns_for_many_shot,
78 | chunk_format_fn=prompt_utils.get_bbh_example_chunk,
79 | corpus_format=task_constants.CORPUS_FORMAT[dataset],
80 | ),
81 | ],
82 | gold_answer_processors=[],
83 | )
84 |
85 | dataset = 'long_icl_bench_dialogue_re'
86 | for length in LENGTHS:
87 | data_dir = f'{TASK}/{dataset}/{length}'
88 | dataset_name = dataset
89 | for split in SPLITS:
90 | PromptRegistry.add(
91 | name=f'{TASK}_{dataset_name}_{length}_{split}:many_shot',
92 | data_dir=data_dir,
93 | split=split,
94 | cacheable_corpus=True,
95 | data_loader=functools.partial(
96 | utils.load_long_icl_bench_dialogue_re_data_from_file,
97 | ),
98 | context_processors=[
99 | functools.partial(
100 | prompt_utils.add_text_chunks,
101 | texts=[
102 | task_constants.CORPUS_INSTRUCTION[dataset],
103 | ],
104 | ),
105 | functools.partial(
106 | prompt_utils.add_many_shot_chunks,
107 | chunk_format_fn=prompt_utils.get_long_icl_bench_dialogue_re_example_chunk,
108 | corpus_format=task_constants.CORPUS_FORMAT[dataset],
109 | target_proposition=task_constants.TARGET_PROPOSITION,
110 | ending_notation=task_constants.ENDING_NOTATION,
111 | ),
112 | ],
113 | query_turn_processors=[
114 | functools.partial(
115 | prompt_utils.add_query_turns_for_many_shot,
116 | chunk_format_fn=prompt_utils.get_long_icl_bench_dialogue_re_example_chunk,
117 | corpus_format=task_constants.CORPUS_FORMAT[dataset],
118 | target_proposition=task_constants.TARGET_PROPOSITION,
119 | ending_notation=task_constants.ENDING_NOTATION,
120 | ),
121 | ],
122 | gold_answer_processors=[],
123 | )
124 |
--------------------------------------------------------------------------------
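
The loops above register names of the form `icl_{subtask}_{length}_{split}:many_shot`. A quick sketch for listing what actually got registered:

```python
# Sketch: list the ICL prompt names registered by this module.
import prompts  # noqa: F401  # Importing registers all prompts_*.py modules.
from prompts import prompt_registry

icl_names = sorted(
    name
    for name in prompt_registry.PromptRegistry.prompts
    if name.startswith("icl_")
)
print("\n".join(icl_names))  # e.g. icl_date_understanding_32k_dev:many_shot
```
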
/prompts/prompts_mm.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Register prompts for multimodal retrieval."""
17 |
18 | import functools
19 |
20 | from prompts import prompt_registry
21 | from prompts import utils as prompt_utils
22 | from prompts.constants import common as common_constants
23 | from prompts.constants import mm as task_constants
24 | import utils
25 |
26 |
27 | PromptRegistry = prompt_registry.PromptRegistry
28 | TASK = 'mm'
29 |
30 | MULTIMODAL_RETRIEVAL_DATASETS = (
31 | 'oven',
32 | 'msrvtt',
33 | 'flickr30k',
34 | 'mscoco',
35 | 'fleurs_en_tts',
36 | 'fleurs_es_tts',
37 | 'fleurs_hi_tts',
38 | 'fleurs_zh_tts',
39 | 'fleurs_fr_tts',
40 | 'fleurs_en_stt',
41 | 'fleurs_es_stt',
42 | 'fleurs_hi_stt',
43 | 'fleurs_zh_stt',
44 | 'fleurs_fr_stt',
45 | )
46 | RESOURCE_DIR = 'resources/'
47 | LENGTHS = ('32k', '128k', '1m')
48 | SPLITS = ('dev', 'test')
49 |
50 | # Register few-shot prompts first.
51 | # NOTE: These prompts are used inside the other prompts and are not designed
52 | # for direct use.
53 | for length in LENGTHS:
54 | for dataset in MULTIMODAL_RETRIEVAL_DATASETS:
55 | if (dataset.endswith('stt') and length != '32k') or (
56 | dataset.endswith('tts') and length == '1m'
57 | ):
58 | continue
59 | # If there exists a mapper for the prompt, use it.
60 | prompt_name = task_constants.PROMPT_MAPPER.get(dataset, dataset)
61 | PromptRegistry.add(
62 | name=f'{TASK}_{dataset}_{length}:few_shot_examples',
63 | data_dir=f'{TASK}/{dataset}/{length}',
64 | cacheable_corpus=True,
65 | split='few_shot',
66 | data_loader=utils.load_data_from_file, # resource_dir is deprecated.
67 | context_processors=[], # No shared context is used.
68 | query_turn_processors=[
69 | functools.partial(
70 | prompt_utils.add_multimodal_query_turns,
71 | query_prefix_format=task_constants.FEW_SHOT_QUERY_FORMAT_PREFIX[
72 | prompt_name
73 | ],
74 | query_suffix_format=task_constants.FEW_SHOT_QUERY_FORMAT_SUFFIX[
75 | prompt_name
76 | ],
77 | use_example_id=True,
78 | ),
79 | functools.partial(
80 | prompt_utils.append_reasoning_to_query_turns,
81 | reasoning_format=task_constants.FEW_SHOT_EXAMPLE_ANSWER_FORMAT[
82 | prompt_name
83 | ],
84 | qid2reasoning=None,
85 | ),
86 | functools.partial(
87 | prompt_utils.append_gold_answers_to_query_turns,
88 | answer_format=common_constants.FINAL_ANSWER_FORMAT,
89 | ),
90 | ],
91 | gold_answer_processors=[
92 | prompt_utils.convert_pids_into_gold_answers,
93 | ],
94 | )
95 |
96 | for length in LENGTHS:
97 | for dataset in MULTIMODAL_RETRIEVAL_DATASETS:
98 | for split in SPLITS:
99 | prompt_name = task_constants.PROMPT_MAPPER.get(dataset, dataset)
100 |
101 | name = f'{TASK}_{dataset}_{length}_{split}:few_shot_with_cot'
102 | corpus_instruction = task_constants.CORPUS_INSTRUCTION[prompt_name]
103 | PromptRegistry.add(
104 | name=name,
105 | data_dir=f'{TASK}/{dataset}/{length}',
106 | split=split,
107 | cacheable_corpus=False,
108 | data_loader=utils.load_data_from_file,
109 | context_processors=[
110 | functools.partial(
111 | prompt_utils.add_text_chunks,
112 | texts=[
113 | corpus_instruction,
114 | ],
115 | ),
116 | # Adds both corpus chunks and few-shot.
117 | functools.partial(
118 | prompt_utils.add_corpus_chunks_and_query_turns_from_few_shot_examples,
119 | corpus_format=task_constants.CORPUS_FORMAT[prompt_name],
120 | shuffle_seed=None,
121 | add_image_chunk=True,
122 | few_shot_prompt_name=(
123 | f'{TASK}_{dataset}_{length}:few_shot_examples'
124 | ),
125 | ),
126 | ],
127 | query_turn_processors=[
128 | functools.partial(
129 | prompt_utils.add_multimodal_query_turns,
130 | use_example_id=False,
131 | query_prefix_format=task_constants.QUERY_FORMAT_PREFIX[
132 | prompt_name
133 | ],
134 | query_suffix_format=task_constants.QUERY_FORMAT_SUFFIX[
135 | prompt_name
136 | ],
137 | ),
138 | ],
139 | gold_answer_processors=[
140 | prompt_utils.convert_pids_into_gold_answers,
141 | ],
142 | )
143 |
--------------------------------------------------------------------------------
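
The guard in the first loop above means STT datasets are registered only at 32k, and TTS datasets at every length except 1m. The combinations can be enumerated directly; a sketch with a reduced dataset list:

```python
# Sketch: reproduce the (dataset, length) combinations the mm loops register.
LENGTHS = ("32k", "128k", "1m")
DATASETS = ("oven", "fleurs_en_tts", "fleurs_en_stt")  # Subset for brevity.

for length in LENGTHS:
  for dataset in DATASETS:
    if (dataset.endswith("stt") and length != "32k") or (
        dataset.endswith("tts") and length == "1m"
    ):
      continue  # stt: 32k only; tts: all lengths except 1m.
    print(f"mm_{dataset}_{length}:few_shot_examples")
```
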
/prompts/prompts_rag.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Register prompts for RAG."""
17 |
18 | import functools
19 | from prompts import prompt_registry
20 | from prompts import utils as prompt_utils
21 | from prompts.constants import common as common_constants
22 | from prompts.constants import rag as task_constants
23 |
24 |
25 | PromptRegistry = prompt_registry.PromptRegistry
26 | TASK = 'rag'
27 |
28 | RAG_DATASETS = (
29 | 'nq',
30 | 'hotpotqa',
31 | 'musique',
32 | 'qampari',
33 | 'quest',
34 | 'topiocqa',
35 | 'romqa', # cannot be used for training
36 | )
37 | LENGTHS = ('32k', '128k', '1m')
38 | SPLITS = ('dev', 'test')
39 |
40 |
41 | # NOTE: These format helpers are duplicated between rag and retrieval.
42 | def get_query_format(dataset_name: str, few_shot_type_name: str) -> str:
43 | """Get query format str for the given dataset and few-shot type."""
44 |
45 | p = task_constants.PROMPT_MAPPER.get(dataset_name, dataset_name)
46 | if few_shot_type_name == 'few_shot_with_cot':
47 | return task_constants.QUERY_FORMAT_SIMPLE_REVERSE[p]
48 | else:
49 | raise ValueError(f'Unsupported few-shot type: {few_shot_type_name}')
50 |
51 |
52 | def get_few_shot_reasoning_format(
53 | dataset_name: str, few_shot_type_name: str
54 | ) -> str | None:
55 | """Get few-shot reasoning format str for the given dataset and few-shot type."""
56 | p = task_constants.PROMPT_MAPPER.get(dataset_name, dataset_name)
57 | if few_shot_type_name == 'few_shot_with_cot':
58 | return task_constants.FEW_SHOT_EXAMPLE_COT_FORMAT_SIMPLE_REVERSE[p]
59 | else:
60 | raise ValueError(f'Unsupported few-shot type: {few_shot_type_name}')
61 |
62 |
63 | # Register few-shot prompts first.
64 | # NOTE: These prompts are used inside the other prompts and are not designed
65 | # for direct use.
66 | for length in LENGTHS:
67 | for dataset in RAG_DATASETS:
68 | # Processors for the few-shot examples.
69 | query_turn_processors = [
70 | # Format few-shot example's query.
71 | functools.partial(
72 | prompt_utils.add_query_turns,
73 | query_format=common_constants.FEW_SHOT_SEPARATOR
74 | + get_query_format(
75 | dataset_name=dataset, few_shot_type_name='few_shot_with_cot'
76 | ),
77 | use_example_id=True,
78 | ),
79 | ]
80 | query_turn_processors.append(
81 | functools.partial(
82 | prompt_utils.append_reasoning_to_query_turns,
83 | reasoning_format=get_few_shot_reasoning_format(
84 | dataset_name=dataset, few_shot_type_name='few_shot_with_cot'
85 | ),
86 | qid2reasoning=None,
87 | ),
88 | )
89 | query_turn_processors.append(
90 | # Format few-shot example's answer.
91 | functools.partial(
92 | prompt_utils.append_gold_answers_to_query_turns,
93 | answer_format=common_constants.FINAL_ANSWER_FORMAT,
94 | ),
95 | )
96 |
97 | PromptRegistry.add(
98 | name=f'{TASK}_{dataset}_{length}:few_shot_examples',
99 | data_dir=f'{TASK}/{dataset}/{length}',
100 | cacheable_corpus=True,
101 | split='few_shot',
102 | is_multi_turn=dataset == 'topiocqa',
103 | context_processors=[], # No shared context is used.
104 | query_turn_processors=query_turn_processors,
105 | gold_answer_processors=[],
106 | )
107 |
108 | # Register the dev/test prompts that embed the few-shot examples above.
109 | for length in LENGTHS:
110 | for dataset in RAG_DATASETS:
111 | for split in SPLITS:
112 | prompt_name = task_constants.PROMPT_MAPPER.get(dataset, dataset)
113 | corpus_instruction = task_constants.CORPUS_INSTRUCTION[
114 | prompt_name
115 | ]
116 | corpus_format = task_constants.CORPUS_FORMAT_ECHO[prompt_name]
117 | name = f'{TASK}_{dataset}_{length}_{split}:few_shot_with_cot'
118 |
119 | test_query_processors = [
120 | functools.partial(
121 | prompt_utils.add_query_turns,
122 | query_format=common_constants.TEST_QUERY_SEPARATOR
123 | + get_query_format(
124 | dataset_name=dataset,
125 | few_shot_type_name='few_shot_with_cot',
126 | ),
127 | )
128 | ]
129 | PromptRegistry.add(
130 | name=name,
131 | data_dir=f'{TASK}/{dataset}/{length}',
132 | split=split,
133 | cacheable_corpus=False,
134 | is_multi_turn=dataset == 'topiocqa',
135 | context_processors=[
136 | functools.partial(
137 | prompt_utils.add_text_chunks,
138 | texts=[
139 | corpus_instruction,
140 | task_constants.FORMATTING_INSTRUCTION,
141 | ],
142 | ),
143 | # Adds both corpus chunks and few-shot.
144 | functools.partial(
145 | prompt_utils.add_corpus_chunks_and_query_turns_from_few_shot_examples,
146 | corpus_format=corpus_format,
147 | shuffle_seed=None,
148 | few_shot_prompt_name=(
149 | f'{TASK}_{dataset}_{length}:few_shot_examples'
150 | ),
151 | ),
152 | ],
153 | query_turn_processors=test_query_processors,
154 | gold_answer_processors=[],
155 | )
156 |
--------------------------------------------------------------------------------
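
Note how each dev/test RAG prompt embeds its few-shot sibling by name rather than by value: `add_corpus_chunks_and_query_turns_from_few_shot_examples` receives `few_shot_prompt_name=f'{TASK}_{dataset}_{length}:few_shot_examples'` and resolves it through the same registry. A sketch of that indirection:

```python
# Sketch: the test prompt stores only the *name* of its few-shot sibling,
# which is resolved through the registry at example-construction time.
import prompts  # noqa: F401  # Importing registers all prompts_*.py modules.
from prompts import prompt_registry

registry = prompt_registry.PromptRegistry
assert "rag_nq_32k_dev:few_shot_with_cot" in registry.prompts
assert "rag_nq_32k:few_shot_examples" in registry.prompts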
/prompts/prompts_retrieval.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Register prompts for retrieval."""
17 |
18 | import functools
19 | from prompts import prompt_registry
20 | from prompts import utils as prompt_utils
21 | from prompts.constants import common as common_constants
22 | from prompts.constants import retrieval as task_constants
23 |
24 |
25 | PromptRegistry = prompt_registry.PromptRegistry
26 | TASK = 'retrieval'
27 |
28 | RETRIEVAL_DATASETS = (
29 | 'arguana',
30 | 'fever',
31 | 'fiqa',
32 | 'msmarco',
33 | 'nq',
34 | 'quora',
35 | 'scifact',
36 | 'webis_touche2020',
37 | 'topiocqa',
38 | )
39 |
40 | SET_RETRIEVAL_DATASETS = (
41 | 'hotpotqa',
42 | 'musique',
43 | 'qampari',
44 | 'quest',
45 | )
46 |
47 | LENGTHS = ('32k', '128k', '1m')
48 | SPLITS = ('dev', 'test')
49 | FEW_SHOT_TYPES = (
50 | 'few_shot_with_cot', # Content as the CoT
51 | )
52 |
53 |
54 | def get_query_format(dataset_name: str, few_shot_type_name: str) -> str:
55 | """Get query format str for the given dataset and few-shot type."""
56 |
57 | p = task_constants.PROMPT_MAPPER.get(dataset_name, dataset_name)
58 | if few_shot_type_name == 'few_shot_with_cot':
59 | return task_constants.QUERY_FORMAT_SIMPLE_REVERSE[p]
60 | else:
61 | raise ValueError(f'Unsupported few-shot type: {few_shot_type_name}')
62 |
63 |
64 | def get_few_shot_reasoning_format(
65 | dataset_name: str, few_shot_type_name: str
66 | ) -> str | None:
67 | """Get few-shot reasoning format str for the given dataset and few-shot type."""
68 | p = task_constants.PROMPT_MAPPER.get(dataset_name, dataset_name)
69 | if few_shot_type_name == 'few_shot_with_cot':
70 | return task_constants.FEW_SHOT_EXAMPLE_COT_FORMAT_SIMPLE_REVERSE[p]
71 | else:
72 | raise ValueError(f'Unsupported few-shot type: {few_shot_type_name}')
73 |
74 |
75 | # Register few-shot prompts first.
76 | # NOTE: These prompts are used inside the other prompts and are not designed
77 | # for direct use.
78 | for length in LENGTHS:
79 | for dataset in RETRIEVAL_DATASETS + SET_RETRIEVAL_DATASETS:
80 | for few_shot_type in FEW_SHOT_TYPES:
81 |
82 | # Processors for the few-shot examples.
83 | query_turn_processors = [
84 | # Format few-shot example's query.
85 | functools.partial(
86 | prompt_utils.add_query_turns,
87 | query_format=common_constants.FEW_SHOT_SEPARATOR
88 | + get_query_format(
89 | dataset_name=dataset, few_shot_type_name=few_shot_type
90 | ),
91 | use_example_id=True,
92 | ),
93 | ]
94 | # Format few-shot example's reasoning.
95 | query_turn_processors.append(
96 | functools.partial(
97 | prompt_utils.append_reasoning_to_query_turns,
98 | reasoning_format=get_few_shot_reasoning_format(
99 | dataset_name=dataset, few_shot_type_name=few_shot_type
100 | ),
101 | ),
102 | )
103 | # Format few-shot example's answer.
104 | query_turn_processors.append(
105 | functools.partial(
106 | prompt_utils.append_gold_answers_to_query_turns,
107 | answer_format=common_constants.FINAL_ANSWER_FORMAT,
108 | ),
109 | )
110 |
111 | PromptRegistry.add(
112 | name=f'{TASK}_{dataset}_{length}:few_shot_examples',
113 | data_dir=f'{TASK}/{dataset}/{length}',
114 | split='few_shot',
115 | cacheable_corpus=True,
116 | is_multi_turn=dataset == 'topiocqa',
117 | context_processors=[], # No shared context is used.
118 | query_turn_processors=query_turn_processors,
119 | gold_answer_processors=[
120 | prompt_utils.convert_pids_into_gold_answers,
121 | ],
122 | )
123 |
124 | # Register the dev/test prompts that embed the few-shot examples above.
125 | for length in LENGTHS:
126 | for dataset in RETRIEVAL_DATASETS + SET_RETRIEVAL_DATASETS:
127 | for split in SPLITS:
128 | for few_shot_type in FEW_SHOT_TYPES:
129 | prompt_name = task_constants.PROMPT_MAPPER.get(
130 | dataset, dataset
131 | )
132 | corpus_format = task_constants.CORPUS_FORMAT_ECHO[
133 | prompt_name
134 | ]
135 | name = f'{TASK}_{dataset}_{length}_{split}:{few_shot_type}'
136 | corpus_instruction = task_constants.CORPUS_INSTRUCTION[
137 | prompt_name
138 | ]
139 | PromptRegistry.add(
140 | name=name,
141 | data_dir=f'{TASK}/{dataset}/{length}',
142 | split=split,
143 | cacheable_corpus=False,
144 | is_multi_turn=dataset == 'topiocqa',
145 | context_processors=[
146 | functools.partial(
147 | prompt_utils.add_text_chunks,
148 | texts=[
149 | corpus_instruction,
150 | task_constants.FORMATTING_INSTRUCTION,
151 | ],
152 | ),
153 | # Adds both corpus chunks and few-shot.
154 | functools.partial(
155 | prompt_utils.add_corpus_chunks_and_query_turns_from_few_shot_examples,
156 | corpus_format=corpus_format,
157 | shuffle_seed=None,
158 | few_shot_prompt_name=(
159 | f'{TASK}_{dataset}_{length}:few_shot_examples'
160 | ),
161 | ),
162 | ],
163 | query_turn_processors=[
164 | functools.partial(
165 | prompt_utils.add_query_turns,
166 | query_format=common_constants.TEST_QUERY_SEPARATOR
167 | + get_query_format(
168 | dataset_name=dataset,
169 | few_shot_type_name=few_shot_type,
170 | ),
171 | ),
172 | ],
173 | gold_answer_processors=[
174 | prompt_utils.convert_pids_into_gold_answers,
175 | ],
176 | )
177 |
--------------------------------------------------------------------------------
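
Retrieval gold answers are passage ids, which `convert_pids_into_gold_answers` rewrites into the in-context ids assigned by the `pid_mapper` built in `PromptRegistry.get_examples`. The helper itself lives in `prompts/utils.py` (not shown in this listing); the following is only a hypothetical sketch of the mapping it performs:

```python
# Hypothetical sketch only: the real convert_pids_into_gold_answers lives in
# prompts/utils.py and may differ in detail.
def convert_pids(
    gold_pids: list[list[str]], pid_mapper: dict[str, str]
) -> list[list[str]]:
  # Rewrite each turn's raw pids to the positional ids used in the context.
  return [[pid_mapper[pid] for pid in turn] for turn in gold_pids]

pid_mapper = {"doc_42": "0", "doc_7": "1"}  # pid -> position in the context.
print(convert_pids([["doc_7", "doc_42"]], pid_mapper))  # [['1', '0']]
```
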
/prompts/prompts_sql.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Register prompts for SQL."""
17 |
18 | import functools
19 | from prompts import prompt_registry
20 | from prompts import utils as prompt_utils
21 | from prompts.constants import sql as task_constants
22 |
23 |
24 | PromptRegistry = prompt_registry.PromptRegistry
25 | TASK = 'sql'
26 |
27 | SQL_DATASETS = (
28 | 'spider',
29 | 'sparc',
30 | )
31 | LENGTHS = ('32k', '128k', '1m')
32 | SPLITS = ('dev', 'test')
33 |
34 | # Few-shot examples are directly provided as a list of strings in SQL.
35 | for length in LENGTHS:
36 | for dataset in SQL_DATASETS:
37 | for split in SPLITS:
38 | prompt_name = task_constants.PROMPT_MAPPER.get(dataset, dataset)
39 | PromptRegistry.add(
40 | name=f'{TASK}_{dataset}_{length}_{split}:few_shot_with_cot',
41 | data_dir=f'{TASK}/{dataset}/{length}',
42 | split=split,
43 | cacheable_corpus=True,
44 | context_processors=[
45 | functools.partial(
46 | prompt_utils.add_text_chunks,
47 | texts=[
48 | task_constants.CORPUS_INSTRUCTION[prompt_name],
49 | ],
50 | ),
51 | # Add few-shot examples.
52 | functools.partial(
53 | prompt_utils.add_text_chunks,
54 | texts=task_constants.FEW_SHOT_EXAMPLES_V1[prompt_name],
55 | ),
56 | ],
57 | query_turn_processors=[
58 | functools.partial(
59 | prompt_utils.add_query_turns_with_corpus,
60 | query_format=task_constants.QUERY_FORMAT_0[prompt_name],
61 | corpus_format=task_constants.CORPUS_FORMAT[prompt_name],
62 | follow_up_query_format=task_constants.FOLLOW_UP_QUERY_FORMAT_0[
63 | prompt_name
64 | ],
65 | ),
66 | ],
67 | gold_answer_processors=[],
68 | )
69 |
--------------------------------------------------------------------------------
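
The SQL few-shot examples (see the tail of `prompts/constants/sql.py` above) always end with a `Final Answer: [[...]]` line. Below is a hedged sketch of pulling that payload out of a model response; the repo's actual answer extraction lives in the inference/evaluation code and is not reproduced here.

```python
import ast
import re


# Illustrative only: extract the last "Final Answer: [[...]]" payload.
def extract_final_answer(response: str):
  matches = re.findall(r"Final Answer:\s*(\[\[.*?\]\])", response, re.DOTALL)
  if not matches:
    return None
  return ast.literal_eval(matches[-1])


print(extract_final_answer("Here is the final answer.\nFinal Answer: [[6]]"))
# -> [[6]]
```
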
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py==2.1.0
2 | numpy==2.0.1
3 | scipy==1.14.0
4 | wget==3.2
5 | opencv-python==4.10.0.84
6 | tqdm==4.66.4
7 | attrs==24.2.0
8 | pillow==10.4.0
9 |
--------------------------------------------------------------------------------
/run_evaluation.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | r"""Run evaluation on a set of predictions.
17 |
18 | We provide a script to run evaluation on a set of predictions. The predictions
19 | are expected to be in jsonl format, where each line is a json
20 | dictionary containing the following fields:
21 | * qid: The id of the query.
22 | * model_outputs: The model predictions extracted from the model response.
23 | * num_turns: The number of turns in the conversation.
24 |
25 | We provide example predictions for each task under
26 | evaluation/example_predictions. To run evaluation on the example predictions:
27 | ```
28 | python run_evaluation.py \
29 | --answer_file_path evaluation/example_predictions/icl_date/queries.jsonl \
30 | --pred_file_path evaluation/example_predictions/icl_date/preds.jsonl \
31 | --task_type icl
32 |
33 | python run_evaluation.py \
34 | --answer_file_path evaluation/example_predictions/rag_nq/queries.jsonl \
35 | --pred_file_path evaluation/example_predictions/rag_nq/preds.jsonl \
36 | --task_type rag
37 |
38 | python run_evaluation.py \
39 | --answer_file_path evaluation/example_predictions/rag_quest/queries.jsonl \
40 | --pred_file_path evaluation/example_predictions/rag_quest/preds.jsonl \
41 | --task_type multi_value_rag
42 |
43 | python run_evaluation.py \
44 | --answer_file_path evaluation/example_predictions/retrieval_nq/queries.jsonl \
45 | --pred_file_path evaluation/example_predictions/retrieval_nq/preds.jsonl \
46 | --task_type retrieval
47 |
48 | python run_evaluation.py \
49 | --answer_file_path evaluation/example_predictions/sql_spider/queries.jsonl \
50 | --pred_file_path evaluation/example_predictions/sql_spider/preds.jsonl \
51 | --task_type sql
52 | ```
53 | where task_type is one of the keys in evaluation.EVALUATION_TASKS.
54 |
55 | To understand which task_type to use for a given dataset, see the table in
56 | the README.
57 | """
58 |
59 | from collections.abc import Sequence
60 | import json
61 | import os
62 | from typing import Any
63 |
64 | from absl import app
65 | from absl import flags
66 | import evaluation
67 | from evaluation import loft_evaluation
68 |
69 |
70 | _ANSWER_FILE_PATH = flags.DEFINE_string(
71 | "answer_file_path",
72 | None,
73 | help="Path to gold answers",
74 | required=True,
75 | )
76 | _PRED_FILE_PATH = flags.DEFINE_string(
77 | "pred_file_path",
78 | None,
79 | help="Path to predictions to run evaluation on.",
80 | required=True,
81 | )
82 | _TASK_TYPE = flags.DEFINE_enum(
83 | "task_type",
84 | None,
85 | enum_values=evaluation.EVALUATION_TASKS.keys(),
86 | help="Task name to run evaluation on.",
87 | required=True,
88 | )
89 | _DEDUPLICATE_PIDS = flags.DEFINE_bool(
90 | "deduplicate_pids",
91 | False,
92 | help="Whether to deduplicate pids for gold answers and predictions.",
93 | )
94 |
95 |
96 | def _load_predictions_from_jsonl(path: str) -> dict[str, Any]:
97 | """Loads predictions from a jsonl file."""
98 | predictions = {}
99 | for line in open(path):
100 |     record = json.loads(line)
101 |     predictions[record["qid"]] = record
102 | return predictions
103 |
104 |
105 | def run_evaluation(
106 | answer_file_path: str,
107 | pred_file_path: str,
108 | task_type: str,
109 | ) -> tuple[list[dict[str, Any]], dict[str, Any]]:
110 | """Evaluates predictions and returns metrics.
111 |
112 | Args:
113 | answer_file_path: Path to gold answers.
114 | pred_file_path: Path to predictions to run evaluation on.
115 | task_type: Task name to run evaluation on.
116 |
117 | Returns:
118 | metrics_per_line: List of metrics dictionaries per prediction.
119 | final_metrics: Metrics averaged over all predictions.
120 | """
121 | eval_task = evaluation.EVALUATION_TASKS[task_type]
122 | predictions = _load_predictions_from_jsonl(pred_file_path)
123 |
124 | metrics_per_line = []
125 | num_unanswered_queries = 0
126 | if task_type == "icl":
127 | if answer_file_path.endswith(".json"):
128 | with open(answer_file_path) as f:
129 | json_objects = json.load(f)
130 | else:
131 | raise ValueError(f"Unsupported answer file extension: {answer_file_path}")
132 | else:
133 | if answer_file_path.endswith(".jsonl"):
134 | with open(answer_file_path) as f:
135 | lines = f.readlines()
136 | json_objects = []
137 | for line in lines:
138 | json_objects.append(json.loads(line))
139 | else:
140 | raise ValueError(f"Unsupported answer file extension: {answer_file_path}")
141 | for data_idx, data_blob in enumerate(json_objects):
142 | qid = data_blob.get("qid", str(data_idx))
143 | prediction = predictions.get(qid, None)
144 | if not prediction:
145 | num_unanswered_queries += 1
146 | # An EvaluationInstance is an object that contains the information
147 | # necessary to compute the metrics for a single prediction. Here we
148 | # create a dummy instance with an empty model output, which will cause
149 | # the evaluation to return 0 for all metrics because we do not have an
150 | # answer to compare to.
151 | eval_instance = loft_evaluation.EvaluationInstance(
152 | qid=qid,
153 | gold_answers=data_blob.get(
154 | "answers", [data_blob.get("target", "").split(" ")]
155 | ),
156 | model_output=[""],
157 | num_turns=1,
158 | )
159 | print(f"[Warning] Query {qid} was unanswered, marking it as incorrect.")
160 | else:
161 | if _DEDUPLICATE_PIDS.value:
162 | candidate_path = os.path.join(
163 |             os.path.dirname(answer_file_path), "corpus.jsonl"
164 | )
165 | prediction["metadata"].update({"candidate_path": candidate_path})
166 | eval_instance = loft_evaluation.EvaluationInstance(
167 | qid=qid,
168 | gold_answers=data_blob.get(
169 | "answers", [data_blob.get("target", "").split(" ")]
170 | ),
171 | model_output=prediction["model_outputs"],
172 | num_turns=prediction["num_turns"],
173 | metadata=prediction.get("metadata", None),
174 | )
175 |
176 | # Compute list of metrics dictionaries per instance
177 | instance_metric = eval_task.evaluate(eval_instance)
178 | metrics_per_line.extend(instance_metric)
179 |
180 | # Average all metrics over all predictions
181 | quality_metrics = eval_task.aggregate_metrics()
182 | final_metrics = {
183 | "quality": quality_metrics,
184 | "num_unanswered_queries": num_unanswered_queries,
185 | }
186 |
187 | return metrics_per_line, final_metrics
188 |
189 |
190 | def main(argv: Sequence[str]) -> None:
191 | if len(argv) > 1:
192 | raise app.UsageError("Too many command-line arguments.")
193 |
194 | metrics_per_line, final_metrics = run_evaluation(
195 | answer_file_path=_ANSWER_FILE_PATH.value,
196 | pred_file_path=_PRED_FILE_PATH.value,
197 | task_type=_TASK_TYPE.value,
198 | )
199 |
200 |   # Write two metrics files. E.g., if the preds file is /path/to/nq.jsonl, write:
201 | # * /path/to/nq_metrics_per_line.jsonl: Metrics per prediction.
202 | # * /path/to/nq_metrics.json: Metrics averaged over all predictions.
203 | output_dir = os.path.dirname(_PRED_FILE_PATH.value)
204 | file_basename = os.path.splitext(os.path.basename(_PRED_FILE_PATH.value))[0]
205 |
206 | mplpath = os.path.join(output_dir, f"{file_basename}_metrics_per_line.jsonl")
207 | with open(mplpath, "w") as f:
208 | for l in metrics_per_line:
209 | f.write(json.dumps(l) + "\n")
210 |
211 | metrics_path = os.path.join(output_dir, f"{file_basename}_metrics.json")
212 | with open(metrics_path, "w") as f:
213 | f.write(json.dumps(final_metrics, indent=4))
214 |
215 | print(json.dumps(final_metrics, indent=4))
216 | print(f"""Two files written to directory {output_dir}:
217 | * {os.path.basename(mplpath)}: Metrics for each line in the prediction file.
218 | * {os.path.basename(metrics_path)}: Metrics for all predictions.
219 | """.strip())
220 |
221 |
222 | if __name__ == "__main__":
223 | app.run(main)
224 |
--------------------------------------------------------------------------------
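
For reference, a minimal sketch of writing a predictions file in the shape `run_evaluation.py` expects: one JSON object per line with `qid`, `model_outputs`, and `num_turns` (`metadata` is optional, and the exact shape of `model_outputs` varies by task).

```python
import json

# Sketch: produce a preds.jsonl that run_evaluation.py can read.
preds = [
    {"qid": "query_0", "model_outputs": ["some answer"], "num_turns": 1},
    {"qid": "query_1", "model_outputs": ["a", "b"], "num_turns": 1},
]
with open("preds.jsonl", "w") as f:
  for pred in preds:
    f.write(json.dumps(pred) + "\n")
```
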
/run_inference.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | r"""Run LCLM inference on LOFT.
17 |
18 | We run LCLMs on LOFT by having LCLMs ingest the entire corpus along with each
19 | query and output the answer in natural language. We use a few-shot prompt to
20 | show what the CoT reasoning should look like.
21 |
22 | Example run command:
23 | # Retrieval
24 | BASE_DIR=./data
25 | LENGTH="32k"
26 | TASK_TYPE="retrieval"
27 | DATASET="nq"
28 | SPLIT="dev"
29 | PROMPT="${TASK_TYPE}_${DATASET}_${LENGTH}_${SPLIT}:few_shot_with_cot"
30 |
31 | mkdir -p ${BASE_DIR}/outputs/${TASK_TYPE}/${DATASET}/${LENGTH}
32 |
33 | python run_inference.py \
34 |   --prompt_name ${PROMPT} --task_type ${TASK_TYPE} \
35 | --base_dir ${BASE_DIR} \
36 | --data_dir ${TASK_TYPE}/${DATASET}/${LENGTH} \
37 | --split ${SPLIT} \
38 | --context_length ${LENGTH} \
39 | --output_path ${BASE_DIR}/outputs/${TASK_TYPE}/${DATASET}/${LENGTH}/${SPLIT}_predictions.jsonl \
40 | --project_id ${PROJECT_ID} \
41 | --overwrite
42 | """
43 |
44 | import concurrent.futures
45 | import functools
46 | import json
47 | import os
48 | from typing import Any, Dict, Sequence
49 |
50 | from absl import app
51 | from absl import flags
52 | from inference import models
53 | import prompts # pylint: disable=unused-import
54 | from prompts import prompt_registry
55 | import tqdm
56 | import utils
57 |
58 |
59 | CONTEXT_LENGTH_TO_NUM_TOKENS = {
60 | "32k": 32000,
61 | "128k": 128000,
62 | "1m": 1000000,
63 | }
64 |
65 | _PROMPT_NAME = flags.DEFINE_string(
66 | "prompt_name",
67 | None,
68 | "Name of the prompt to use.",
69 | required=True,
70 | )
71 | _TASK_TYPE = flags.DEFINE_string(
72 | "task_type",
73 | None,
74 | "Task type of the prompt to use.",
75 | required=True,
76 | )
77 | _BASE_DIR = flags.DEFINE_string(
78 | "base_dir",
79 | None,
80 | "Path to the base directory.",
81 | required=True,
82 | )
83 | _DATA_DIR = flags.DEFINE_string(
84 | "data_dir",
85 | None,
86 | "Relative path to the data directory given the base directory.",
87 | required=True,
88 | )
89 | _SPLIT = flags.DEFINE_string(
90 | "split",
91 | "dev",
92 | "Split of the data to use.",
93 | )
94 | _OUTPUT_PATH = flags.DEFINE_string(
95 | "output_path",
96 | None,
97 | "Path to write prediction outputs as a JSONL file.",
98 | required=True,
99 | )
100 | _MODEL_URL_OR_NAME = flags.DEFINE_string(
101 | "model_url_or_name",
102 | "gemini-1.5-pro",
103 | "Evergreen model URL or API-based model name.",
104 | )
105 | _PROJECT_ID = flags.DEFINE_string(
106 | "project_id",
107 | None,
108 | "Project ID of Google Cloud Project.",
109 | required=True,
110 | )
111 | _CONTEXT_LENGTH = flags.DEFINE_enum(
112 | "context_length",
113 | "32k",
114 | CONTEXT_LENGTH_TO_NUM_TOKENS.keys(),
115 |     "Context length of the prompt. Three pre-defined lengths are available.",
116 | )
117 | _OVERWRITE = flags.DEFINE_bool(
118 | "overwrite",
119 | False,
120 |     "If True, regenerate the outputs. If False, reuse results from the output"
121 |     " file if it already exists.",
122 | )
123 | _CACHE_FIRST_INPUT = flags.DEFINE_bool(
124 | "cache_first_input",
125 | False,
126 | "If True, run and cache the first input to the model.",
127 | )
128 | _MAX_WORKERS = flags.DEFINE_integer(
129 | "max_workers",
130 | 1,
131 |     "Maximum number of workers to use for multi-threaded inference. This"
132 |     " should be 1x to 2x the number of model replicas available.",
133 | )
134 | _LOG_FAILING_PROMPTS = flags.DEFINE_bool(
135 | "log_failing_prompts",
136 | True,
137 | "If True, log the failing prompts. This is useful for debugging VertexAI.",
138 | )
139 |
140 | MimeType = utils.MimeType
141 | ContentChunk = utils.ContentChunk
142 | PromptRegistry = prompt_registry.PromptRegistry
143 |
144 |
145 | def get_num_tokens(text_input: str) -> int:
146 |   # Crude whitespace tokenization to estimate the number of tokens.
147 | return len(text_input.strip().split(" "))
148 |
149 |
150 | def _run_one_example(
151 | example: utils.Example,
152 | model: models.Model,
153 | finished_lines: Dict[str, Any],
154 | ) -> Dict[str, Any] | None:
155 | """Runs one example and returns the output."""
156 | try:
157 | return utils.run_one_example(example, model, finished_lines)
158 |   except Exception as exception:  # pylint: disable=broad-exception-caught
159 |     print(exception)
160 |     if _LOG_FAILING_PROMPTS.value:
161 |       output_path = f"{_OUTPUT_PATH.value}.failed_prompt.{example.qid}"
162 |       print(f"Logging failing prompt to {output_path}")
163 |       os.makedirs(os.path.dirname(output_path), exist_ok=True)
164 |       with open(output_path, "wb") as f:
165 |         for chunk in example.all_chunks:
166 |           f.write(chunk.data)
167 |         f.flush()
168 |     return None
169 |
170 |
171 | def main(argv: Sequence[str]) -> None:
172 | del argv
173 | if _PROMPT_NAME.value not in PromptRegistry.prompts:
174 | task_name = _PROMPT_NAME.value.split("_")[0]
175 | print(PromptRegistry.prompts.keys())
176 | registry_str = "\n".join(
177 | filter(
178 | lambda x: x.startswith(task_name),
179 | list(PromptRegistry.prompts.keys()),
180 | )
181 | )
182 | raise ValueError(
183 | f"Prompt {_PROMPT_NAME.value} not found in registry.\nAvailable"
184 | f" prompts:\n{registry_str}"
185 | )
186 |
187 | pid_mapper = None
188 | if _TASK_TYPE.value in ["retrieval", "mm"]:
189 | pid_mapper = {
190 | str(idx): pid
191 | for idx, pid in enumerate(
192 | utils.load_data_from_file(
193 | data_dir=_DATA_DIR.value,
194 | base_dir=_BASE_DIR.value,
195 | split=_SPLIT.value,
196 | ).corpus
197 | )
198 | }
199 | answer_prefix = "final answer"
200 | if _TASK_TYPE.value == "icl":
201 | answer_prefix = "output"
202 |
203 | model = models.get_model(
204 | model_url_or_name=_MODEL_URL_OR_NAME.value,
205 | project_id=_PROJECT_ID.value,
206 | pid_mapper=pid_mapper,
207 | answer_prefix=answer_prefix,
208 | )
209 |
210 | finished_lines = {}
211 | if os.path.exists(_OUTPUT_PATH.value) and not _OVERWRITE.value:
212 | with open(_OUTPUT_PATH.value) as f:
213 | for l in f:
214 | l = json.loads(l)
215 | finished_lines[l["qid"]] = l
216 | else:
217 | os.makedirs(os.path.dirname(_OUTPUT_PATH.value), exist_ok=True)
218 | print(f"Found {len(finished_lines)} finished lines.")
219 |
220 |   # Log the configuration that was used. This is useful for recovering the
221 |   # exact command that was run when you revisit the results months after an
222 |   # experiment (e.g., during a rebuttal).
223 | utils.save_run_metadata(
224 | flags.FLAGS.flags_by_module_dict(), output_path_prefix=_OUTPUT_PATH.value
225 | )
226 |
227 |   # Load the examples and the few-shot prompt, then run inference.
228 | examples = PromptRegistry.get_examples(
229 | name=_PROMPT_NAME.value, base_dir=_BASE_DIR.value
230 | )
231 | qid2example = {ex.qid: ex for ex in examples}
232 |
233 | for ex in qid2example.values():
234 | if not all(chunk.mime_type == MimeType.TEXT for chunk in ex.context_chunks):
235 | continue
236 | num_tokens = get_num_tokens(
237 | "\n".join(chunk.data.decode("utf-8") for chunk in ex.all_chunks)
238 | )
239 | if num_tokens > CONTEXT_LENGTH_TO_NUM_TOKENS[_CONTEXT_LENGTH.value]:
240 | raise ValueError(
241 | f"qid={ex.qid} has {num_tokens} tokens in its prompt, which is more"
242 | f" than the context length of {_CONTEXT_LENGTH.value}"
243 | )
244 |
245 |   # Cache and save one prompt to disk.
246 |   print("Saving one example prompt to disk...")
247 | indexing_example = list(qid2example.values())[0]
248 | prompt_path = f"{_OUTPUT_PATH.value}.prompt_first_query_example"
249 | utils.save_content_chunks(indexing_example.all_chunks, prompt_path)
250 |   print(f"Finished saving one prompt to {prompt_path}.txt")
251 |
252 | if _CACHE_FIRST_INPUT.value:
253 | try:
254 | print("Starting caching.")
255 |       # Warm the prefix cache by running inference once.
256 | model_output = model.infer(
257 | list(indexing_example.all_chunks),
258 | )
259 | print("Finished caching. Model output:", model_output)
260 | except Exception as exception: # pylint: disable=broad-exception-caught
261 | print(exception)
262 | print("Failed to cache; continuing inference without caching...")
263 |
264 | with open(_OUTPUT_PATH.value, "w", encoding="utf-8") as f:
265 | with concurrent.futures.ThreadPoolExecutor(
266 | max_workers=_MAX_WORKERS.value
267 | ) as executor:
268 | eval_futures = executor.map(
269 | functools.partial(
270 | _run_one_example, model=model, finished_lines=finished_lines
271 | ),
272 | qid2example.values(),
273 | )
274 | for output in tqdm.tqdm(eval_futures, total=len(qid2example)):
275 | if output:
276 | f.write(json.dumps(output, ensure_ascii=False) + "\n")
277 | f.flush()
278 |
279 | print(f"Wrote results to {_OUTPUT_PATH.value}")
280 |
281 |
282 | if __name__ == "__main__":
283 | app.run(main)
284 |
--------------------------------------------------------------------------------
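
The token budget check in `main()` above relies on the crude whitespace estimate from `get_num_tokens`. A standalone sketch of the same check:

```python
# Sketch: the budget check main() applies to text-only prompts.
CONTEXT_LENGTH_TO_NUM_TOKENS = {"32k": 32000, "128k": 128000, "1m": 1000000}


def get_num_tokens(text_input: str) -> int:
  # Whitespace split as a rough token estimate.
  return len(text_input.strip().split(" "))


def check_budget(prompt_text: str, context_length: str) -> None:
  num_tokens = get_num_tokens(prompt_text)
  if num_tokens > CONTEXT_LENGTH_TO_NUM_TOKENS[context_length]:
    raise ValueError(
        f"{num_tokens} tokens exceeds the {context_length} budget."
    )


check_budget("hello world", "32k")  # OK: 2 estimated tokens.
```
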