├── LICENSE
├── README.md
├── evaluation
│   ├── qdu-tasks
│   │   ├── cqa.sh
│   │   ├── eval_rerank.py
│   │   ├── postprocess_cqa.py
│   │   ├── run_eval.sh
│   │   └── src
│   │       ├── __init__.py
│   │       ├── args.py
│   │       ├── metrics.py
│   │       ├── modeling.py
│   │       └── utils.py
│   ├── qu-du-tasks
│   │   ├── eval_sampling.py
│   │   ├── inference_dataset.py
│   │   ├── inference_qu_du.py
│   │   └── inference_tasks
│   │       ├── conversational_qa.py
│   │       ├── fact_verification.py
│   │       ├── query_clarification.py
│   │       ├── query_description.py
│   │       ├── query_expansion.py
│   │       ├── query_intent_classification.py
│   │       ├── query_matching.py
│   │       ├── query_reformulation.py
│   │       ├── query_subtopic_generation.py
│   │       ├── query_suggestion.py
│   │       ├── reading_comprehension.py
│   │       └── summarization.py
│   └── readme.md
├── img
│   ├── dataset.png
│   ├── in-domain-google.png
│   ├── intro.jpg
│   ├── logo1.jpg
│   └── process.jpg
└── instruct_templates.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Yutao ZHU
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
![INTERS logo](img/logo1.jpg)
3 |
4 |
5 |
6 | ## INTERS: Unlocking the Power of Large Language Models in Search with Instruction Tuning
7 |
8 |
9 |
10 |
11 |
12 |
13 | **Authors**: Yutao Zhu, Peitian Zhang, Chenghao Zhang, Yifei Chen, Binyu Xie, Zhicheng Dou, Zheng Liu, and Ji-Rong Wen
14 |
15 |
16 | 📃 Paper
17 | •
18 | 📚 Dataset
19 |
20 |
21 | 🤗 HuggingFace Model List
22 |
23 |
24 | | Model | Backbone Model |
25 | |:---------------------------------------------------------------------------------|:------------------------------------------------------------------------|
26 | | [INTERS-LLaMA-7b-Chat](https://huggingface.co/yutaozhu94/INTERS-LLaMA-7b-chat) | [LLaMA-2-7b-chat](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf) |
27 | | [INTERS-LLaMA-7b-Base](https://huggingface.co/yutaozhu94/INTERS-LLaMA-7b-base) | [LLaMA-2-7b](https://huggingface.co/meta-llama/Llama-2-7b-hf) |
28 | | [INTERS-Mistral-7b](https://huggingface.co/yutaozhu94/INTERS-Mistral-7b) | [Mistral-7b](https://huggingface.co/mistralai/Mistral-7B-v0.1) |
29 | | [INTERS-Minima-3b](https://huggingface.co/yutaozhu94/INTERS-Minima-3b) | [MiniMA-2-3B](https://huggingface.co/GeneZC/MiniMA-2-3B) |
30 | | [INTERS-Falcon-1b](https://huggingface.co/yutaozhu94/INTERS-Falcon-1b) | [Falcon-rw-1b](https://huggingface.co/tiiuae/falcon-rw-1b) |
31 |
32 | ## News
33 | - May 2024: We are happy that INTERS has been accepted to the ACL 2024 main conference!
34 | - Feb 2024: We have released the dataset, instruction templates, fine-tuned models, and evaluation scripts.
35 |
36 | ## Introduction
37 |
38 |
![INTERS overview](img/intro.jpg)
39 |
40 |
41 | Large language models (LLMs) have demonstrated impressive capabilities in various natural language processing tasks. Despite this, their application to information retrieval (IR) tasks is still challenging due to the infrequent occurrence of many IR-specific concepts in natural language. While prompt-based methods can provide task descriptions to LLMs, they often fall short in facilitating a comprehensive understanding and execution of IR tasks, thereby limiting LLMs' applicability. To address this gap, in this work, we explore the potential of instruction tuning to enhance LLMs' proficiency in IR tasks. We introduce a novel instruction tuning dataset, INTERS, encompassing 20 tasks across three fundamental IR categories: query understanding, document understanding, and query-document relationship understanding. The data are derived from 43 distinct datasets with manually written templates. Our empirical results reveal that INTERS significantly boosts the performance of various publicly available LLMs, such as LLaMA, Mistral, and Phi, in IR tasks. Furthermore, we conduct extensive experiments to analyze the effects of instruction design, template diversity, few-shot demonstrations, and the volume of instructions on performance.
42 |
43 | ## Tasks & Datasets
44 | We consider tasks under the categories of query understanding, document understanding, and query-document relationship understanding. Our dataset consists of 20 tasks derived from 43 datasets. All tasks and datasets we used are shown in the figure below.
45 |
46 |
![Tasks and datasets in INTERS](img/dataset.png)
47 |
48 |
49 | ## Dataset Construction
50 |
51 |
![Dataset construction process](img/process.jpg)
52 |
53 |
54 | ## General Performance
55 |
56 |
![In-domain evaluation results](img/in-domain-google.png)
57 |
58 |
59 | ## Zero-shot Evaluation
60 | The evaluation scripts are under the ``evaluation`` directory.
61 |
62 | ### Required packages
63 | ```
64 | torch 2.0.0
65 | transformers 4.36.2
66 | numpy 1.26.3
67 | tqdm 4.66.1
68 | scikit-learn 1.4.0
69 | rouge_score 0.1.2
70 | nltk 3.8.1
71 | accelerate 0.26.1
72 | ```
73 |
74 | ### For query understanding tasks and document understanding tasks (qu-du-tasks)
75 | This evaluation script uses PyTorch DDP for text generation.
76 |
77 | 1. Download the [test data](https://huggingface.co/datasets/yutaozhu94/INTERS/tree/main/test-qu-du-zero-shot) and save it to ``data/in-domain/zero_shot/``. The directory structure should look like this:
78 | ```
79 | qu-du-tasks
80 | ├── eval_sampling.py
81 | ├── inference_dataset.py
82 | ├── inference_qu_du.py
83 | ├── inference_tasks
84 | │ ├── conversational_qa.py
85 | │ ├── fact_verification.py
86 | │ └── ...
87 | └── data
88 | └── in-domain
89 | └── zero-shot
90 | ├── conversational_qa_coqa.zero_shot.test.jsonl
91 | ├── conversational_qa_quac.zero_shot.test.jsonl
92 | ├── fact_verification_climate_fever.zero_shot.test.jsonl
93 | ├── fact_verification_fever.zero_shot.test.jsonl
94 | ├── fact_verification_scifact.zero_shot.test.jsonl
95 | └── ...
96 | ```
97 | 2. If you place the test files in a different directory, modify the path in each task file under the ``inference_tasks`` directory (in its ``get_path()`` function); a sketch of such a change is shown after this list.
98 |
99 | 3. Run the evaluation:
100 | ```
101 | TOKENIZERS_PARALLELISM=True python3 inference_qu_du.py \
102 | --model_name_or_path your/model/path \
103 | --tokenizer_name your/tokenizer/path \
104 | --setting in-domain \
105 | --n_shots zero_shot
106 | ```
107 |
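A minimal sketch of the path change mentioned in step 2, assuming a task file whose ``get_path()`` builds the test-file location from a data root. The class and constant names here are illustrative; the real files under ``inference_tasks`` may differ:

```python
import os

# Hypothetical constant: point this at wherever the downloaded test files live.
DATA_ROOT = "data/in-domain/zero_shot"


class ConversationalQA:
    """Illustrative stand-in for a task class in inference_tasks/."""

    def __init__(self, dataset_name: str = "coqa"):
        self.dataset_name = dataset_name

    def get_path(self) -> str:
        # Return the jsonl test file for this task/dataset combination, e.g.
        # data/in-domain/zero_shot/conversational_qa_coqa.zero_shot.test.jsonl
        return os.path.join(
            DATA_ROOT,
            f"conversational_qa_{self.dataset_name}.zero_shot.test.jsonl",
        )
```
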
108 | ### For query-document relationship understanding tasks (qdu-tasks)
109 | 1. Download the [test data](https://huggingface.co/datasets/yutaozhu94/INTERS/tree/main/test-qdu) and save it to ``data/``. The directory structure should look like this:
110 | ```
111 | qdu-tasks
112 | ├── cqa.sh
113 | ├── eval_rerank.py
114 | ├── postprocess_cqa.py
115 | ├── run_eval.sh
116 | └── data
117 | ├── cqadupstack
118 | │ ├── android
119 | │ │ └── test.pt.key.do-not-overwrite.json
120 | │ ├── english
121 | │ │ └── test.pt.key.do-not-overwrite.json
122 | │ └── ...
123 | ├── arguana.bm25.100.jsonl
124 | ├── climate_fever.bm25.100.jsonl
125 | └── ...
126 | ```
127 | 2. For datasets other than cqadupstack, modify the paths in ``run_eval.sh``, then run the script:
128 | ```
129 | MODEL_PATH="your/model/path"
130 | TOKENIZER_PATH="your/tokenizer/path"
131 | RESULT_PATH="your/result/path"
132 | EVAL_DATA_PATH="data"
133 |
134 | -----------------------
135 | bash run_eval.sh
136 | ```
137 | 3. For the cqadupstack dataset, modify the paths in ``cqa.sh``, then run the script:
138 | ```
139 | MODEL_PATH="your/model/path"
140 | TOKENIZER_PATH="your/tokenizer/path"
141 | RESULT_PATH="your/result/path"
142 |
143 | -----------------------
144 | bash cqa.sh
145 | ```
146 | 4. ``eval_rerank.py`` supports pointwise, pairwise, and listwise reranking. Modify its arguments in ``run_eval.sh`` or ``cqa.sh`` (a sliding-window sketch of the listwise method follows this list):
147 | ```
148 | # pointwise: (default)
149 | --rerank_method pointwise
150 |
151 | # pairwise:
152 | --rerank_method pairwise
153 |
154 | # listwise:
155 | --rerank_method listwise \
156 | --listwise_window 5 \
157 | --listwise_stride 5
158 | ```
159 |
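For intuition, the snippet below sketches how a sliding-window listwise reranker with ``--listwise_window 5 --listwise_stride 5`` could walk over a candidate list. The ``rank_window`` callable is a hypothetical helper standing in for the model call (prompting with the listwise template and parsing the returned identifier order); the actual logic lives in ``src/modeling.py`` and may differ in detail:

```python
from typing import Callable, List


def listwise_rerank(docs: List[str],
                    rank_window: Callable[[List[str]], List[int]],
                    window: int = 5, stride: int = 5) -> List[str]:
    """Rerank docs by reordering window-sized slices from the tail to the head.

    rank_window receives one slice of documents and returns the within-window
    order as a permutation of 0..len(slice)-1 (best first).
    """
    docs = list(docs)
    start = max(len(docs) - window, 0)
    while True:
        end = min(start + window, len(docs))
        order = rank_window(docs[start:end])
        docs[start:end] = [docs[start + i] for i in order]
        if start == 0:
            break
        start = max(start - stride, 0)
    return docs


# Toy usage: a fake "model" that simply reverses each window it sees.
print(listwise_rerank(list("abcdefgh"), rank_window=lambda xs: list(range(len(xs)))[::-1]))
```
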
160 | ## Citation
161 | Please kindly cite our paper if it helps your research:
162 | ```BibTex
163 | @inproceedings{INTERS,
164 | author = {Yutao Zhu and
165 | Peitian Zhang and
166 | Chenghao Zhang and
167 | Yifei Chen and
168 | Binyu Xie and
169 | Zheng Liu and
170 | Ji{-}Rong Wen and
171 | Zhicheng Dou},
172 | editor = {Lun{-}Wei Ku and
173 | Andre Martins and
174 | Vivek Srikumar},
175 | title = {{INTERS:} Unlocking the Power of Large Language Models in Search with
176 | Instruction Tuning},
177 | booktitle = {Proceedings of the 62nd Annual Meeting of the Association for Computational
178 | Linguistics (Volume 1: Long Papers), {ACL} 2024, Bangkok, Thailand,
179 | August 11-16, 2024},
180 | pages = {2782--2809},
181 | publisher = {Association for Computational Linguistics},
182 | year = {2024},
183 | url = {https://doi.org/10.18653/v1/2024.acl-long.154},
184 | doi = {10.18653/V1/2024.ACL-LONG.154},
185 | }
186 | ```
187 |
--------------------------------------------------------------------------------
/evaluation/qdu-tasks/cqa.sh:
--------------------------------------------------------------------------------
1 | MODEL_PATH="your/model/path"
2 | TOKENIZER_PATH="your/tokenizer/path"
3 | RESULT_PATH="your/result/path"
4 |
5 | RESULT_PATH=${RESULT_PATH}/cqadupstack
6 | EVAL_DATA_PATH="data"
7 |
8 | mkdir -p ${RESULT_PATH}
9 |
10 | # to gather metrics from all sub-datasets
11 | TMP_PATH=${RESULT_PATH}/tmp.log
12 |
13 | ############# MODIFY PATH HERE #############
14 | # containing all sub-dataset folders
15 | CQA_ROOT=${EVAL_DATA_PATH}/cqadupstack/
16 | ################################################
17 |
18 | COUNTER=0
19 | for dataset in $CQA_ROOT/*
20 | do
21 |
22 | # get fewshot data files
23 | fewshot_data=($dataset/*train.pt.neg.do-not-overwrite.fewshot*)
24 | fewshot_data=${fewshot_data[0]}
25 |
26 | eval_data="$dataset/test.pt.key.do-not-overwrite.json"
27 |
28 | ############# MODIFY COMMANDS HERE #############
29 | outputString=`torchrun --nproc_per_node 8 eval_rerank.py \
30 | --eval_data $eval_data \
31 | --output_dir $RESULT_PATH \
32 | --model_name_or_path $MODEL_PATH \
33 | --tokenizer_name_or_path $TOKENIZER_PATH \
34 | --hits 10 \
35 | --rerank_method pointwise \
36 | --dataset_name cqadupstack \
37 | --batch_size 4`
38 |
39 | # to add 1-shot
40 | # --fewshot_data $fewshot_data \
41 | # --shots 1
42 | ################################################
43 |
44 | if [[ $COUNTER == 0 ]]
45 | then
46 | echo $outputString > $TMP_PATH
47 | else
48 | echo $outputString >> $TMP_PATH
49 | fi
50 |
51 | COUNTER=$((COUNTER + 1))
52 | done
53 |
54 | python postprocess_cqa.py -t $TMP_PATH
--------------------------------------------------------------------------------
/evaluation/qdu-tasks/eval_rerank.py:
--------------------------------------------------------------------------------
1 | import os
2 | import datasets
3 | import numpy as np
4 | from typing import List
5 | from dataclasses import dataclass, field, asdict
6 | from torch.utils.data import DataLoader
7 | from collections import defaultdict
8 | from accelerate import Accelerator
9 | from transformers import HfArgumentParser
10 | from transformers.utils import logging
11 |
12 | from src import ModelArgs, Metric, DatasetProcessFn, DefaultDataCollator, FileLogger, get_model_and_tokenizer, makedirs
13 |
14 |
15 | logger = logging.get_logger(__name__)
16 |
17 |
18 | @dataclass
19 | class Args(ModelArgs):
20 | eval_data: str = field(
21 | default=None,
22 | metadata={'help': 'The evaluation json data path.'}
23 | )
24 | fewshot_data: str = field(
25 | default=None,
26 | metadata={'help': 'The fewshot json data path.'}
27 | )
28 | output_dir: str = field(
29 | default="data/results/rerank",
30 | metadata={'help': 'Output directory for results and logs.'}
31 | )
32 | batch_size: int = field(
33 | default=16,
34 | metadata={'help': 'Evaluation batch size.'}
35 | )
36 |
37 | query_max_length: int = field(
38 | default=64,
39 | metadata={'help': 'How many tokens at maximum?'}
40 | )
41 | doc_max_length: int = field(
42 | default=512,
43 | metadata={'help': 'How many tokens at maximum?'}
44 | )
45 |
46 | with_description: bool = field(
47 | default=True,
48 | metadata={'help': "Whether to add task description"}
49 | )
50 | rerank_method: str = field(
51 | default="pointwise",
52 | metadata={'help': 'How to evaluate reranking? {pointwise, pairwise, listwise, no}'}
53 | )
54 | dataset_name: str = field(
55 | default="msmarco",
56 | metadata={'help': 'Select one from [msmarco | trec_covid | nfcorpus | nq | fiqa | hotpot_qa | arguana | touche | cqadupstack | quora | dbpedia | scidocs | climate_fever | fever | scifact]'},
57 | )
58 | hits: int = field(
59 | default=100,
60 | metadata={'help': 'How many candidates to rerank?'}
61 | )
62 | shots: int = field(
63 | default=None,
64 | metadata={'help': 'How many shots to use for fewshot testing?'}
65 | )
66 |
67 | metrics: List[str] = field(
68 | default_factory=lambda: ["mrr", "recall", "ndcg"],
69 | metadata={'help': 'List of metrics. {mrr, recall, ndcg}'}
70 | )
71 | cutoffs: List[int] = field(
72 | default_factory=lambda: [1, 5, 10],
73 | metadata={'help': 'Cutoffs to evaluate retrieval metrics.'}
74 | )
75 | listwise_window: int = field(
76 | default=10,
77 | metadata={'help': 'How long is the window in listwise?'}
78 | )
79 | listwise_stride: int = field(
80 | default=5,
81 | metadata={'help': 'How long is the step in listwise?'}
82 | )
83 |
84 |
85 | TASK_DESCRIPTION = "In the reranking task, search engines must understand the relationship between the user's query, which may be keywords or a sentence, and the potential documents. The goal is to ensure that the most relevant documents, those that best cover the user's information needs, are ranked highest. This requires a nuanced understanding of both the query's intent and the content of the documents."
86 |
87 | PROMPT_TEMPLATE = {
88 | "pointwise": {
89 | "msmarco": "Assess the relevance between the provided document:\n{text}\nand the query: \"{query}\". Respond with 'Yes' if the document is relevant to the query or 'No' if not.",
90 | "trec_covid": "Document: {text}\n\nQuery: {query}\n\nAssess the relevance of the provided document in the context of the COVID-19-related query. Answer 'Yes' if the document explicitly addresses or pertains to the query, or 'No' if it is unrelated.",
91 | "nfcorpus": "Assess the relevance of the medical document:\n{text}\nin relation to the search query \"{query}\". Determine if the document is relevant by responding with 'Yes' for relevance or 'No' for irrelevance.",
92 | "nq": "Review the content of the document:\n{text}\nand ascertain its relevance to the topic: \"{query}\". Provide your determination by responding with either 'Yes' for relevance or 'No' for irrelevance.",
93 | "fiqa": "Evaluate the financial document:\n{text}\nin the context of the query: \"{query}\". Determine if the document is relevant to the query and provide your judgment with 'Yes' for relevance or 'No' for irrelevance.",
94 | "hotpot_qa": "Analyze the relationship between the document:\n{text}\nand the query: \"{query}\". Offer a definitive response by indicating whether they are relevant, and reply with either 'Yes' for relevance or 'No' for irrelevance.",
95 | "arguana": "Document:\n\n{text}\n\nQuery:\n\n{query}\n\nDetermine their relevance and return the judgment about the document's relevance to the query by responding with either 'Yes' for relevance or 'No' for irrelevance.",
96 | "touche": "Evaluate the relevance between the given document on controversial subjects:\n{text}\nand the query: \"{query}\", and provide a clear 'Yes' for relevance or 'No' for irrelevance judgment regarding their relevance.",
97 | "cqadupstack": "Evaluate the relevance between the given document on controversial subjects:\n{text}\nand the query: \"{query}\", and provide a clear 'Yes' for relevance or 'No' for irrelevance judgment regarding their relevance.",
98 | "quora": "Determine the relevance between the document:\n{text}\nand the query: {query}. Judge whether they are related by responding with 'Yes' for relevance or 'No' for irrelevance.",
99 | "dbpedia": "Determine the relevance of the document:\n{text}\nto the query: \"{query}\". Conclude with 'Yes' for relevance or 'No' for irrelevance.",
100 | "scidocs": "Determine the correlation between the provided literature-related document:\n{text}\nand the query: \"{query}\". Conclude if they are closely connected with a 'Yes' for relevance or 'No' for irrelevance.",
101 | "climate_fever": "Analyze the correlation between this document on the topic of climate change:\n{text}\nand the following query: {query}. Determine if the document is relevant to the query. Respond with 'Yes' for relevance or 'No' for irrelevance.",
102 | "fever": "Assess the relationship between the given document:\n{text}\nand the query: \"{query}\" to determine if the document is relevant. Answer with 'Yes' for relevance or 'No' for irrelevance.",
103 | "scifact": "Assess the relevance between the scientific document:\n{text}\nand query: \"{query}\", and state 'Yes' or 'No' to reflect the relevance judgment. Answer 'Yes' for relevance and 'No' for irrelevance."
104 | },
105 | "pairwise": {
106 | "msmarco": "Consider a query \"{query}\" alongside two documents:\n\n[1] {doc1}\n\n[2] {doc2}\n\nDecide which document is more relevant to the given query by providing the corresponding document identifier.",
107 | "trec_covid": "Given a query: \"{query}\" and two documents:\n\n[1] {doc1}\n\n[2] {doc2}\n\nEach marked with a unique identifier and related to COVID-19, assess and identify which document is more closely related to the query. Respond with the identifier of the more relevant document.",
108 | "nfcorpus": "Assess which of the two medical field documents is more relevant to the provided query \"{query}\". Each document is presented with a unique identifier.\nDocuments:\n\n[1] {doc1}\n\n[2] {doc2}\n\nIdentify and return the identifier of the document that best aligns with the query.",
109 | "nq": "Evaluate the relevance of the provided query \"{query}\" to a pair of documents:\n\n[1] {doc1}\n\n[2] {doc2}\n\neach identified separately. Determine which document is more relevant to the query and return the identifier of the more relevant document.",
110 | "fiqa": "Evaluate the relevance of the query: \"{query}\" and the pair of financial documents:\n\n[1] {doc1}\n\n[2] {doc2}\n\neach assigned a unique identifier. Determine which document is more relevant to the provided query and specify its document identifier.",
111 | "hotpot_qa": "Compare the relevance of two documents:\n\n[1] {doc1}\n\n[2] {doc2}\n\nto the provided query: \"{query}\". Identify the document with the higher relevance by specifying its identifier.",
112 | "arguana": "Evaluate the relevance of two documents:\n\n[1] {doc1}\n\n[2] {doc2}\n\nto the provided query \"{query}\". Express the identifier of the document that is more relevant to the query.",
113 | "touche": "Given a query: \"{query}\" and two documents:\n\n[1] {doc1}\n\n[2] {doc2}\n\neach with its unique identifier, determine which document is more relevant to the given query by providing the document identifier.",
114 | "cqadupstack": "Compare the relevance of two documents:\n\n[1] {doc1}\n\n[2] {doc2}\n\nto the query: \"{query}\", and specify the document identifier that has higher relevance.",
115 | "quora": "With a given query: \"{query}\" and two unique documents:\n\n[1] {doc1}\n\n[2] {doc2}\n\nIdentify the document that aligns more closely with the query by stating its identifier.",
116 | "dbpedia": "Given a query: \"{query}\" and two documents:\n\n[1] {doc1}\n\n[2] {doc2}\n\neach identified by a distinct number. Determine the document identifier for the one more relevant to the provided query.",
117 | "scidocs": "Given a query:\"{query}\" and two literature-related documents:\n\n[1] {doc1}\n\n[2] {doc2}\n\neach with a unique identifier, identify which document is more relevant to the query by specifying its identifier.",
118 | "climate_fever": "Given this query:\"{query}\" and two climate change documents:\n\n[1] {doc1}\n\n[2] {doc2}\n\neach with a unique identifier, identify which document aligns more closely with the query by indicating the document's identifier.",
119 | "fever": "Evaluate two documents:\n\n[1] {doc1}\n\n[2] {doc2}\n\nagainst a given query: \"{query}\" and identify the document that is more relevant by its identifier.",
120 | "scifact": "Analyze the relevance of two scientific documents:\n\n[1] {doc1}\n\n[2] {doc2}\n\nto the query: \"{query}\" and identify the more relevant document by its identifier."
121 | },
122 | "listwise": {
123 | "msmarco": "Here are {num} documents:\n{docs}\nand a query \"{query}\", and each document is indicated by a number identifier. Please sort the documents in an order based on their relevance to the above query by returning an identifier list. Be careful to sort documents in order of their relevance to the query from highest to lowest. Result: ",
124 | "trec_covid": "Given {num} documents:\n{docs}\neach pertaining to COVID-19, and a specific query: \"{query}\", organize the documents in order of relevance to the query. List the identifiers of these documents starting from the most relevant to the least relevant. Result: ",
125 | "nfcorpus": "Arrange the given {num} medical field documents:\n{docs}\nin order of relevance to the specified query \"{query}\", with the most relevant document at the top and the least relevant at the bottom. Use their unique number identifiers to indicate the sequence of relevance. Result: ",
126 | "nq": "Arrange the provided {num} documents:\n{docs}\nin accordance with their relevance to the specified query \"{query}\". Utilize number identifiers to denote the sequence of relevance, with the most relevant document at the top and the least relevant document at the end. Result: ",
127 | "fiqa": "Assess the relevance of the query: \"{query}\" to a set of {num} financial documents:\n{docs}\nGenerate a list of document identifiers, arranging them from the most relevant to the least relevant in relation to the query, and return the identifiers list. Result: ",
128 | "hotpot_qa": "Rank the following {num} documents in descending order of relevance to the query: \"{query}\"\n{docs}\nProvide the list of identifiers. Result: ",
129 | "arguana": "Rank the {num} documents:\n{docs}\n in descending order of relevance to the provided query: \"{query}\". Return the list of identifiers associated with each document in order. Result:",
130 | "touche": "Rerank the provided {num} documents on a controversial topic:\n{docs}\nbased on their relevance to the query \"{query}\". Return a list of identifiers in the order of descending relevance. Result: ",
131 | "cqadupstack": "Rerank the provided {num} documents on a controversial topic:\n{docs}\nbased on their relevance to the query \"{query}\". Return a list of identifiers in the order of descending relevance. Result: ",
132 | "quora": "Rank {num} documents:\n{docs}\nby their relevance to the query: \"{query}\" in descending order. List their identifiers. Result: ",
133 | "dbpedia": "Rank {num} documents:\n{docs}\nbased on their relevance to the specified query \"{query}\". Return the list of identifiers in descending order. Result: ",
134 | "scidocs": "Rank {num} documents:\n{docs}\nin order of their relevance to the query: {query}. List the identifiers starting with the most relevant document. Result: ",
135 | "climate_fever": "Rerank these {num} climate change documents:\n{docs}\nby their relevance to the query: \"{query}\" and list the identifiers. Rank from most to least relevant. Result: ",
136 | "fever": "Rank a set of {num} documents:\n{docs}\nby their relevance to a query: \"{query}\". List their identifiers in order of decreasing relevance. Result: ",
137 | "scifact": "Order the provided {num} scientific documents:\n{docs}\n based on their relevance to the query: \"{query}\", from most to least relevant. Only output a list of identifiers. Result: "
138 | # "Here is a query:\n\n{query}\n\nHere are {num} documents, and each document has its own identifier:\n\n{docs}\n\nRank the {num} documents above based on their relevance to the given query. The documents should be listed in descending order using identifiers. The most relevant documents should be listed first. The output format should be [] > [], e.g., [1] > [2]. Each two identifiers are separated by ' > '. Only response the ranking results, do not say any word or explain.",
139 | },
140 | "no": defaultdict(lambda: ""),
141 | }
142 |
143 |
144 | def truncate(text, tokenizer, max_length):
145 | if tokenizer is not None:
146 | return tokenizer.decode(tokenizer.encode(text, add_special_tokens=False, max_length=max_length, truncation=True))
147 | else:
148 | return text
149 |
150 |
151 | def process_rerank(tokenizer, rerank_method, prompt_template, query_max_length=64, doc_max_length=512, hits=10, fewshot_data=None, shots=None, cache_dir=None):
152 | if fewshot_data is not None:
153 | fewshot_data = datasets.load_dataset("json", data_files=fewshot_data, split="train", cache_dir=cache_dir)
154 | rng = np.random.default_rng(42)
155 | indices = rng.choice(range(len(fewshot_data)), size=shots, replace=False).tolist()
156 |
157 | fewshot_prompt = []
158 |
159 | for index in indices:
160 | item = fewshot_data[index]
161 | # NOTE: do not use query, pos, neg here, because they are the arguments for the _process function
162 | fs_query = item["query"]
163 | fs_pos = item["pos"]
164 | fs_neg = item["neg"]
165 |
166 | fs_query = truncate(fs_query, tokenizer, query_max_length)
167 |
168 | if rerank_method == "pointwise":
169 | # sample 1 candidate from the union of pos and neg
170 | candidate = fs_pos + fs_neg
171 | candidate_range = range(len(candidate))
172 | candidate_index = rng.choice(candidate_range).tolist()
173 | candidate = truncate(candidate[candidate_index], tokenizer, doc_max_length)
174 | prompt = prompt_template.format(query=fs_query, text=candidate)
175 |
176 | if candidate_index < len(fs_pos):
177 | # sampled from pos
178 | prompt += " Yes."
179 | else:
180 | prompt += " No."
181 |
182 | elif rerank_method == "pairwise":
183 | # sample 1 positive and 1 negative for comparison
184 | fs_pos = rng.choice(fs_pos).tolist()
185 | fs_neg = rng.choice(fs_neg).tolist()
186 |
187 | reverse = rng.choice([True, False])
188 | if reverse:
189 | prompt = prompt_template.format(query=fs_query, doc1=fs_neg, doc2=fs_pos)
190 | prompt += "[2]"
191 | else:
192 | prompt = prompt_template.format(query=fs_query, doc1=fs_pos, doc2=fs_neg)
193 | prompt += "[1]"
194 |
195 | elif rerank_method == "listwise":
196 | len_pos = len(fs_pos)
197 | len_neg = len(fs_neg)
198 | len_all = len_pos + len_neg
199 | all_documents = fs_pos + fs_neg
200 | result = list(range(1, len_all + 1))
201 | rng.shuffle(result)
202 | idx_document = {idx: document for idx, document in zip(result, all_documents)}
203 | idx_document = dict(sorted(idx_document.items()))
204 | result = [f"[{num}]" for num in result]
205 | result = " > ".join(result)
206 | docs = ""
207 | rank = 0
208 | for value in idx_document.values():
209 | rank += 1
210 | docs += f"[{rank}] " + value + "\n"
211 | prompt = prompt_template.format(query=fs_query, num=rank, docs=docs)
212 | prompt += " " + result
213 |
214 | else:
215 | raise NotImplementedError(f"Rerank method {rerank_method} not implemented for fewshot!")
216 |
217 | fewshot_prompt.append(prompt)
218 |
219 | fewshot_prompt = "\n\n".join(fewshot_prompt) + "\n\n"
220 |
221 | else:
222 | fewshot_prompt = None
223 |
224 |
225 | @DatasetProcessFn(augment=True)
226 | def _process_pointwise(query, pos=None, neg=None, pos_index=None, neg_index=None, pos_score=None, key=None, key_index=None, query_id=None, _index=None, **kwds):
227 | outputs = defaultdict(list)
228 | if pos_score is None:
229 | pos_score = [1 for _ in pos]
230 |
231 | # rerank positive and negative documents when there are no pre-defined candidates
232 | if neg is not None:
233 | key = pos + neg
234 | key_index = pos_index + neg_index
235 |
236 | if query_id is None:
237 | assert _index is not None, f"Make sure to set with_indices=True when there is no given query_id!"
238 | query_id = _index
239 |
240 | # truncate query
241 | query = truncate(query, tokenizer, query_max_length)
242 | # only rerank the top-hits candidates
243 | key = key[:hits]
244 | key_index = key_index[:hits]
245 |
246 | for doc_id, doc_text in zip(key_index, key):
247 | # truncate document
248 | doc_text = truncate(doc_text, tokenizer, doc_max_length)
249 |
250 | prompt = prompt_template.format(query=query, text=doc_text)
251 | # NOTE: prepend fewshot prompt
252 | if fewshot_prompt is not None:
253 | prompt = fewshot_prompt + prompt
254 |
255 | output = tokenizer(prompt)
256 |
257 | output["query_id"] = query_id
258 | output["doc_id"] = doc_id
259 |
260 | for k, v in output.items():
261 | outputs[k].append(v)
262 | return outputs
263 |
264 |
265 | @DatasetProcessFn()
266 | def _process_other(query, pos=None, neg=None, pos_index=None, neg_index=None, pos_score=None, key=None, key_index=None, query_id=None, _index=None, **kwds):
267 | if pos_score is None:
268 | pos_score = [1 for _ in pos]
269 |
270 | # rerank positive and negative documents when there are no pre-defined candidates
271 | if neg is not None:
272 | key = pos + neg
273 | key_index = pos_index + neg_index
274 |
275 | if query_id is None:
276 | assert _index is not None, f"Make sure to set with_indices=True when there is no given query_id!"
277 | query_id = _index
278 |
279 | # truncate query
280 | query = truncate(query, tokenizer, query_max_length)
281 |
282 | if len(key) < hits:
283 | return None
284 |
285 | # only rerank the top-hits candidates
286 | key = key[:hits]
287 | key_index = key_index[:hits]
288 | for i, k in enumerate(key):
289 | key[i] = truncate(k, tokenizer, doc_max_length)
290 |
291 | outputs = {
292 | "query": query,
293 | "query_id": query_id,
294 | "docs": key,
295 | "doc_ids": key_index
296 | }
297 |
298 | # always add fewshot prompt
299 | outputs["fewshot_prompt"] = fewshot_prompt
300 | return outputs
301 |
302 | if rerank_method == "pointwise":
303 | return _process_pointwise
304 | else:
305 | return _process_other
306 |
307 |
308 | def main():
309 | parser = HfArgumentParser([Args])
310 | args: Args = parser.parse_args_into_dataclasses()[0]
311 |
312 | # NOTE: if reranking is skipped (rerank_method == "no"), we run on CPU and do not load the model
313 | # import json
314 | # with open("/share/peitian/Data/Datasets/searchgpt/qrels.train.pt.neg.do-not-overwrite.fewshot.jsonl", 'r', encoding='utf-8') as fr:
315 | # lines = [json.loads(line.strip()) for line in fr.readlines()[:100]]
316 | # with open("/share/yutao/yifei/data/test_fewshot_100.jsonl", 'w', encoding='utf-8') as fw:
317 | # for line in lines:
318 | # json.dump(line, fw)
319 | # fw.write('\n')
320 | # with open("/share/peitian/Data/Datasets/searchgpt/qrels.dev.pt.key.do-not-overwrite.jsonl", 'r', encoding='utf-8') as fr:
321 | # lines = [json.loads(line.strip()) for line in fr.readlines()[:100]]
322 | # with open("/share/yutao/yifei/data/test_100.jsonl", 'w', encoding='utf-8') as fw:
323 | # for line in lines:
324 | # json.dump(line, fw)
325 | # fw.write('\n')
326 | # return
327 |
328 | accelerator = Accelerator(cpu=args.cpu or args.rerank_method == "no")
329 |
330 | if args.rerank_method == "no":
331 | model = None
332 | tokenizer = None
333 | logger.info(f"directly evaluating {args.eval_data}...")
334 | else:
335 | model, tokenizer = get_model_and_tokenizer(args, accelerator=accelerator)
336 |
337 | with accelerator.main_process_first():
338 | prompt_template = PROMPT_TEMPLATE[args.rerank_method][args.dataset_name]
339 | if args.with_description:
340 | prompt_template = TASK_DESCRIPTION + " " + prompt_template
341 | process_fn = process_rerank(
342 | tokenizer=tokenizer,
343 | rerank_method=args.rerank_method,
344 | prompt_template=prompt_template,
345 | query_max_length=args.query_max_length,
346 | doc_max_length=args.doc_max_length,
347 | hits=args.hits,
348 | fewshot_data=args.fewshot_data,
349 | shots=args.shots,
350 | cache_dir=args.dataset_cache_dir,
351 | )
352 | dataset = datasets.load_dataset("json", data_files=args.eval_data, cache_dir=args.dataset_cache_dir, split="train")
353 | dataset = dataset.map(process_fn, batched=True, num_proc=32, remove_columns=dataset.column_names, with_indices=True)
354 |
355 | if args.rerank_method == "no":
356 | # directly compose rerank results from retrieval results
357 | from src import RerankResult
358 | results = {}
359 | for x in dataset:
360 | query_id = x["query_id"]
361 | results[query_id] = [RerankResult(doc_id, 0) for doc_id in x["doc_ids"]]
362 |
363 | else:
364 | data_collator = DefaultDataCollator(tokenizer=tokenizer)
365 | dataloader = DataLoader(
366 | dataset,
367 | batch_size=args.batch_size,
368 | collate_fn=data_collator,
369 | pin_memory=model.device.index is not None,
370 | )
371 |
372 | if args.rerank_method == "pointwise":
373 | results = model.rerank_pointwise(dataloader, accelerator=accelerator)
374 | elif args.rerank_method == "pairwise":
375 | results = model.rerank_pairwise(dataloader, prompt_template=prompt_template, accelerator=accelerator)
376 | elif args.rerank_method == "listwise":
377 | results = model.rerank_listwise(dataloader, prompt_template=prompt_template, accelerator=accelerator, window=args.listwise_window, stride=args.listwise_stride)
378 | else:
379 | raise NotImplementedError(f"Rerank method {args.rerank_method} not implemented!")
380 |
381 | if accelerator.process_index == 0:
382 | result_path = Metric._get_save_path(args.eval_data, args.output_dir)
383 | Metric._save_rerank_result(results, result_path, eval_data=args.eval_data)
384 | metrics = Metric.get_metric_fn(metric_names=args.metrics, eval_data=args.eval_data, cutoffs=args.cutoffs)(results)
385 |
386 | file_logger = FileLogger(makedirs(os.path.join(args.output_dir, "metrics.log")))
387 | file_logger.log(metrics, Args=asdict(args))
388 |
389 |
390 | if __name__ == "__main__":
391 | main()
--------------------------------------------------------------------------------
/evaluation/qdu-tasks/postprocess_cqa.py:
--------------------------------------------------------------------------------
1 | import re
2 | import json
3 | import argparse
4 | from collections import defaultdict
5 |
6 | def main(args):
7 | all_results = defaultdict(list)
8 | with open(args.tmp_path, encoding="utf-8") as f:
9 | for line in f:
10 | result = re.search(r"Metrics : (\{.*\})", line).group(1)
11 | result = json.loads(result)
12 | for k, v in result.items():
13 | all_results[k].append(v)
14 | for k, v in all_results.items():
15 | all_results[k] = sum(v) / len(v)
16 | print(dict(all_results))
17 |
18 |
19 | if __name__ == "__main__":
20 | parser = argparse.ArgumentParser()
21 | parser.add_argument("-t", "--tmp_path", default="data/results/cqadupstack/tmp.log")
22 | args = parser.parse_args()
23 |
24 | main(args)
25 |
--------------------------------------------------------------------------------
/evaluation/qdu-tasks/run_eval.sh:
--------------------------------------------------------------------------------
1 | GPU_NUM=8
2 | MODEL_PATH="your/model/path"
3 | TOKENIZER_PATH="your/tokenizer/path"
4 | RESULT_PATH="your/result/path"
5 | EVAL_DATA_PATH="data"
6 |
7 | torchrun --nproc_per_node 8 eval_rerank.py \
8 | --eval_data ${EVAL_DATA_PATH}/msmarco.bm25.100.jsonl \
9 | --output_dir ${RESULT_PATH} \
10 | --model_name_or_path ${MODEL_PATH} \
11 | --tokenizer_name_or_path ${TOKENIZER_PATH} \
12 | --dataset_cache_dir hf_cache/dataset/ \
13 | --use_flash_attention_2 False \
14 | --max_length 2048 \
15 | --batch_size 4 \
16 | --with_description True \
17 | --dataset_name msmarco
18 |
19 | torchrun --nproc_per_node 8 eval_rerank.py \
20 | --eval_data ${EVAL_DATA_PATH}/touche.bm25.100.jsonl \
21 | --output_dir ${RESULT_PATH} \
22 | --model_name_or_path ${MODEL_PATH} \
23 | --tokenizer_name_or_path ${TOKENIZER_PATH} \
24 | --dataset_cache_dir hf_cache/dataset/ \
25 | --use_flash_attention_2 False \
26 | --max_length 2048 \
27 | --batch_size 4 \
28 | --with_description True \
29 | --dataset_name touche
30 |
31 | torchrun --nproc_per_node 8 eval_rerank.py \
32 | --eval_data ${EVAL_DATA_PATH}/arguana.bm25.100.jsonl \
33 | --output_dir ${RESULT_PATH} \
34 | --model_name_or_path ${MODEL_PATH} \
35 | --tokenizer_name_or_path ${TOKENIZER_PATH} \
36 | --dataset_cache_dir hf_cache/dataset/ \
37 | --use_flash_attention_2 False \
38 | --max_length 2048 \
39 | --batch_size 4 \
40 | --with_description True \
41 | --dataset_name arguana
42 |
43 | torchrun --nproc_per_node 8 eval_rerank.py \
44 | --eval_data ${EVAL_DATA_PATH}/trec_covid.bm25.100.jsonl \
45 | --output_dir ${RESULT_PATH} \
46 | --model_name_or_path ${MODEL_PATH} \
47 | --tokenizer_name_or_path ${TOKENIZER_PATH} \
48 | --dataset_cache_dir hf_cache/dataset/ \
49 | --use_flash_attention_2 False \
50 | --max_length 2048 \
51 | --batch_size 4 \
52 | --with_description True \
53 | --dataset_name trec_covid
54 |
55 | torchrun --nproc_per_node 8 eval_rerank.py \
56 | --eval_data ${EVAL_DATA_PATH}/nfcorpus.bm25.100.jsonl \
57 | --output_dir ${RESULT_PATH} \
58 | --model_name_or_path ${MODEL_PATH} \
59 | --tokenizer_name_or_path ${TOKENIZER_PATH} \
60 | --dataset_cache_dir hf_cache/dataset/ \
61 | --use_flash_attention_2 False \
62 | --max_length 2048 \
63 | --batch_size 4 \
64 | --with_description True \
65 | --dataset_name nfcorpus
66 |
67 | torchrun --nproc_per_node 8 eval_rerank.py \
68 | --eval_data ${EVAL_DATA_PATH}/scidocs.bm25.100.jsonl \
69 | --output_dir ${RESULT_PATH} \
70 | --model_name_or_path ${MODEL_PATH} \
71 | --tokenizer_name_or_path ${TOKENIZER_PATH} \
72 | --dataset_cache_dir hf_cache/dataset/ \
73 | --use_flash_attention_2 False \
74 | --max_length 2048 \
75 | --batch_size 4 \
76 | --with_description True \
77 | --dataset_name scidocs
78 |
79 | torchrun --nproc_per_node 8 eval_rerank.py \
80 | --eval_data ${EVAL_DATA_PATH}/quora.bm25.100.jsonl \
81 | --output_dir ${RESULT_PATH} \
82 | --model_name_or_path ${MODEL_PATH} \
83 | --tokenizer_name_or_path ${TOKENIZER_PATH} \
84 | --dataset_cache_dir hf_cache/dataset/ \
85 | --use_flash_attention_2 False \
86 | --max_length 2048 \
87 | --batch_size 4 \
88 | --with_description True \
89 | --dataset_name quora
90 |
91 | torchrun --nproc_per_node 8 eval_rerank.py \
92 | --eval_data ${EVAL_DATA_PATH}/dbpedia.bm25.100.jsonl \
93 | --output_dir ${RESULT_PATH} \
94 | --model_name_or_path ${MODEL_PATH} \
95 | --tokenizer_name_or_path ${TOKENIZER_PATH} \
96 | --dataset_cache_dir hf_cache/dataset/ \
97 | --use_flash_attention_2 False \
98 | --max_length 2048 \
99 | --batch_size 4 \
100 | --with_description True \
101 | --rerank_method listwise \
102 | --listwise_window 5 \
103 | --listwise_stride 5 \
104 | --dataset_name dbpedia
105 |
106 | torchrun --nproc_per_node 8 eval_rerank.py \
107 | --eval_data ${EVAL_DATA_PATH}/fever.bm25.100.jsonl \
108 | --output_dir ${RESULT_PATH} \
109 | --model_name_or_path ${MODEL_PATH} \
110 | --tokenizer_name_or_path ${TOKENIZER_PATH} \
111 | --dataset_cache_dir hf_cache/dataset/ \
112 | --use_flash_attention_2 False \
113 | --max_length 2048 \
114 | --batch_size 4 \
115 | --with_description True \
116 | --dataset_name fever
117 |
118 | torchrun --nproc_per_node 8 eval_rerank.py \
119 | --eval_data ${EVAL_DATA_PATH}/climate_fever.bm25.100.jsonl \
120 | --output_dir ${RESULT_PATH} \
121 | --model_name_or_path ${MODEL_PATH} \
122 | --tokenizer_name_or_path ${TOKENIZER_PATH} \
123 | --dataset_cache_dir hf_cache/dataset/ \
124 | --use_flash_attention_2 False \
125 | --max_length 2048 \
126 | --batch_size 4 \
127 | --with_description True \
128 | --dataset_name climate_fever
129 |
130 | torchrun --nproc_per_node 8 eval_rerank.py \
131 | --eval_data ${EVAL_DATA_PATH}/scifact.bm25.100.jsonl \
132 | --output_dir ${RESULT_PATH} \
133 | --model_name_or_path ${MODEL_PATH} \
134 | --tokenizer_name_or_path ${TOKENIZER_PATH} \
135 | --dataset_cache_dir hf_cache/dataset/ \
136 | --use_flash_attention_2 False \
137 | --max_length 2048 \
138 | --batch_size 4 \
139 | --with_description True \
140 | --dataset_name scifact
141 |
142 | torchrun --nproc_per_node 8 eval_rerank.py \
143 | --eval_data ${EVAL_DATA_PATH}/nq.bm25.100.jsonl \
144 | --output_dir ${RESULT_PATH} \
145 | --model_name_or_path ${MODEL_PATH} \
146 | --tokenizer_name_or_path ${TOKENIZER_PATH} \
147 | --dataset_cache_dir hf_cache/dataset/ \
148 | --use_flash_attention_2 False \
149 | --max_length 2048 \
150 | --batch_size 4 \
151 | --with_description True \
152 | --dataset_name nq
153 |
154 | torchrun --nproc_per_node 8 eval_rerank.py \
155 | --eval_data ${EVAL_DATA_PATH}/fiqa.bm25.100.jsonl \
156 | --output_dir ${RESULT_PATH} \
157 | --model_name_or_path ${MODEL_PATH} \
158 | --tokenizer_name_or_path ${TOKENIZER_PATH} \
159 | --dataset_cache_dir hf_cache/dataset/ \
160 | --use_flash_attention_2 False \
161 | --max_length 2048 \
162 | --batch_size 4 \
163 | --with_description True \
164 | --dataset_name fiqa
165 |
166 | torchrun --nproc_per_node 8 eval_rerank.py \
167 | --eval_data ${EVAL_DATA_PATH}/hotpot_qa.bm25.100.jsonl \
168 | --output_dir ${RESULT_PATH} \
169 | --model_name_or_path ${MODEL_PATH} \
170 | --tokenizer_name_or_path ${TOKENIZER_PATH} \
171 | --dataset_cache_dir hf_cache/dataset/ \
172 | --use_flash_attention_2 False \
173 | --max_length 2048 \
174 | --batch_size 4 \
175 | --with_description True \
176 | --dataset_name hotpot_qa
--------------------------------------------------------------------------------
/evaluation/qdu-tasks/src/__init__.py:
--------------------------------------------------------------------------------
1 | from .modeling import get_model_and_tokenizer, RerankResult
2 | from .args import ModelArgs
3 | from .utils import FileLogger, DatasetProcessFn, DefaultDataCollator, makedirs, split_file_dir_name_ext, mask_nested_lists
4 | from .metrics import Metric
5 |
6 | import logging
7 | logging.basicConfig(
8 | level=logging.INFO,
9 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
10 | datefmt="%m/%d/%Y %H:%M:%S",
11 | )
--------------------------------------------------------------------------------
/evaluation/qdu-tasks/src/args.py:
--------------------------------------------------------------------------------
1 | import os
2 | from dataclasses import dataclass, field
3 | from typing import Optional, List, Union
4 |
5 |
6 | @dataclass
7 | class ModelArgs:
8 | model_cache_dir: Optional[str] = field(
9 | # default=None,
10 | default="/share/LMs",
11 | metadata={'help': 'Default path to save language models.'}
12 | )
13 | dataset_cache_dir: Optional[str] = field(
14 | # default=None,
15 | default="/share/peitian/Data/Datasets/huggingface",
16 | metadata={'help': 'Default path to save huggingface datasets.'}
17 | )
18 | eval_data: Optional[str] = field(
19 | default=None,
20 | metadata={'help': 'Evaluation json file.'},
21 | )
22 |
23 | model_name_or_path: str = field(
24 | default='meta-llama/Llama-2-7b-chat-hf',
25 | metadata={'help': 'Path to pretrained model or model identifier from huggingface.co/models'}
26 | )
27 | tokenizer_name_or_path: str = field(
28 | default='meta-llama/Llama-2-7b-chat-hf',
29 | metadata={'help': 'Path to pretrained tokenizer or tokenizer identifier from huggingface.co/models'}
30 | )
31 | padding_side: str = field(
32 | default="left",
33 | metadata={'help': 'Tokenizer padding side.'}
34 | )
35 | access_token: Optional[str] = field(
36 | default=None,
37 | metadata={'help': 'Huggingface access token.'}
38 | )
39 | max_length: int = field(
40 | default=2048,
41 | metadata={'help': 'How many tokens at maximum for each input?'},
42 | )
43 |
44 | lora: Optional[str] = field(
45 | default=None,
46 | metadata={'help': 'LoRA ID.'},
47 | )
48 |
49 | dtype: str = field(
50 | default="bf16",
51 | metadata={'help': 'Data type for embeddings.'}
52 | )
53 | device_map: Optional[str] = field(
54 | default=None,
55 | metadata={'help': 'Device map for loading the model. Set to auto to load across devices.'}
56 | )
57 | use_flash_attention_2: bool = field(
58 | default=True,
59 | metadata={'help': 'Use flash attention?'}
60 | )
61 | cpu: bool = field(
62 | default=False,
63 | metadata={'help': 'Use cpu?'}
64 | )
65 |
66 | metrics: List[str] = field(
67 | default_factory=lambda: [],
68 | metadata={'help': 'List of metrics. {rouge, acc}'}
69 | )
70 |
--------------------------------------------------------------------------------
/evaluation/qdu-tasks/src/metrics.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import logging
4 | import inspect
5 | import numpy as np
6 | from tqdm import tqdm
7 | from typing import List, Dict, Tuple
8 | from .modeling import RerankResult
9 | from .utils import makedirs, split_file_dir_name_ext, normalize_text
10 |
11 | logger = logging.getLogger(__name__)
12 |
13 |
14 | class Metric:
15 | """Class for computing metrics and some post-processings."""
16 | @classmethod
17 | def get_metric_fn(cls, metric_names, **kwds):
18 | assert isinstance(metric_names, list) or isinstance(metric_names, tuple), "You must pass metric_names in a list or tuple!"
19 | all_metrics = {}
20 | # get all methods
21 | all_implemented_fns = [x[0] for x in inspect.getmembers(cls, predicate=inspect.isfunction) if not x[0].startswith("_")]
22 |
23 | def compute_metrics(*args, **kwargs):
24 | for metric_name in metric_names:
25 | # call corresponding method
26 | if metric_name in all_implemented_fns:
27 | metric_fn = getattr(cls, metric_name)
28 | metric = metric_fn(**kwds)(*args, **kwargs)
29 | # NOTE: some metric_fn are only used for post-processing and saving results, which return None by default
30 | if metric is not None:
31 | all_metrics.update(metric)
32 | else:
33 | raise NotImplementedError(f"Metric {metric_name} not implemented!")
34 | return all_metrics
35 | return compute_metrics
36 |
37 | @staticmethod
38 | def _get_save_path(eval_data, output_dir=None, field="result", save_name=None):
39 | """
40 | if output_dir is None:
41 | -> {eval_data_dir}/{eval_data_name}.{field}.{save_name}.{eval_data_ext}
42 | else:
43 | -> {output_dir}/{eval_data_name}.{field}.{save_name}.{eval_data_ext}
44 | """
45 | eval_data_dir, eval_data_name, eval_data_ext = split_file_dir_name_ext(eval_data)
46 | if output_dir is None:
47 | output_dir = eval_data_dir
48 | fields = [eval_data_name, field]
49 | if save_name is not None:
50 | fields.append(save_name)
51 | save_path = os.path.join(output_dir, ".".join(fields) + eval_data_ext)
52 | makedirs(save_path)
53 | return save_path
54 |
55 | @staticmethod
56 | def _save_generation_result(result, path, eval_data=None):
57 | if eval_data is not None:
58 | items = {}
59 | with open(eval_data, encoding="utf-8") as f:
60 | for i, line in enumerate(f):
61 | item = json.loads(line)
62 | if "query_id" in eval_data:
63 | index = item["query_id"]
64 | else:
65 | index = i
66 | items[index] = item
67 | with open(path, "w") as f:
68 | for index, pred in result.items():
69 | res = {"query_id": index}
70 | if eval_data is not None:
71 | item = items[index]
72 | res.update(item)
73 | res["pred"] = pred
74 | f.write(json.dumps(res, ensure_ascii=False) + "\n")
75 |
76 | @staticmethod
77 | def _save_rerank_result(results: Dict[int, List[RerankResult]], path, eval_data=None):
78 | if eval_data is not None:
79 | items = {}
80 | with open(eval_data, encoding="utf-8") as f:
81 | for i, line in enumerate(f):
82 | full_item = json.loads(line)
83 | # we only save the index and score of positive samples
84 | item = {
85 | "pos_index": full_item["pos_index"],
86 | "pos_score": full_item["pos_score"],
87 | }
88 | if "query_id" in eval_data:
89 | index = full_item["query_id"]
90 | else:
91 | index = i
92 | items[index] = item
93 |
94 | with open(path, "w") as f:
95 | for index, preds in results.items():
96 | item = {
97 | "query_id": index,
98 | "doc_index": [x.doc_id for x in preds],
99 | "doc_score": [x.doc_score for x in preds],
100 | }
101 | if eval_data is not None:
102 | data_item = items[index]
103 | item.update(data_item)
104 | f.write(json.dumps(item, ensure_ascii=False) + "\n")
105 |
106 | @staticmethod
107 | def _prepare_label_for_generation(eval_data):
108 | labels = {}
109 | with open(eval_data) as f:
110 | for i, line in enumerate(f):
111 | item = json.loads(line)
112 | if "query_id" in item:
113 | index = item["query_id"]
114 | else:
115 | index = i
116 | label = item["completion"]
117 | labels[index] = label
118 | return labels
119 |
120 | @staticmethod
121 | def _prepare_label_for_retrieval(eval_data):
122 | labels = {}
123 | with open(eval_data) as f:
124 | for i, line in enumerate(f):
125 | item = json.loads(line)
126 | if "query_id" in item:
127 | query_id = item["query_id"]
128 | else:
129 | query_id = i
130 | # save positive indices and their scores (for computing ndcg)
131 | labels[query_id] = (item["pos_index"], item["pos_score"])
132 | return labels
133 |
134 | @staticmethod
135 | def _prepare_pred_for_retrieval(preds:List[RerankResult]) -> List[RerankResult]:
136 | valid_preds = [x for x in preds if x.doc_id > -1]
137 | return valid_preds
138 |
139 | @staticmethod
140 | def mrr(eval_data=None, cutoffs=[10], **kwds):
141 | if eval_data is not None:
142 | data_labels = Metric._prepare_label_for_retrieval(eval_data)
143 |
144 | def compute_metric(results: Dict[int, List[RerankResult]], labels: Dict[int, Tuple[List[int], List[int]]] = None, **kwargs):
145 | if labels is None:
146 | labels = data_labels
147 |
148 | mrrs = np.zeros(len(cutoffs))
149 | counts = 0
150 |
151 | for query_id, preds in results.items():
152 | pos_indices, pos_scores = labels[query_id]
153 | # remove irrelevant documents
154 | pos_indices = [pos_index for i, pos_index in enumerate(pos_indices) if pos_scores[i] > 0]
155 | if len(pos_indices) == 0:
156 | continue
157 |
158 | preds = Metric._prepare_pred_for_retrieval(preds)
159 | jump = False
160 | counts += 1
161 |
162 | for i, pred in enumerate(preds, 1):
163 | if pred.doc_id in pos_indices:
164 | for k, cutoff in enumerate(cutoffs):
165 | if i <= cutoff:
166 | mrrs[k] += 1 / i
167 | jump = True
168 | if jump:
169 | break
170 |
171 | mrrs /= counts
172 |
173 | metric = {}
174 | for i, cutoff in enumerate(cutoffs):
175 | mrr = mrrs[i]
176 | metric[f"mrr@{cutoff}"] = mrr
177 |
178 | return metric
179 | return compute_metric
180 |
181 | @staticmethod
182 | def recall(eval_data=None, cutoffs=[10], **kwds):
183 | if eval_data is not None:
184 | data_labels = Metric._prepare_label_for_retrieval(eval_data)
185 |
186 | def compute_metric(results: Dict[int, List[RerankResult]], labels: Dict[int, Tuple[List[int], List[int]]] = None, **kwargs):
187 | if labels is None:
188 | labels = data_labels
189 |
190 | recalls = np.zeros(len(cutoffs))
191 | counts = 0
192 |
193 | for query_id, preds in results.items():
194 | pos_indices, pos_scores = labels[query_id]
195 | # remove irrelevant documents
196 | pos_indices = [pos_index for i, pos_index in enumerate(pos_indices) if pos_scores[i] > 0]
197 | if len(pos_indices) == 0:
198 | continue
199 |
200 | preds = Metric._prepare_pred_for_retrieval(preds)
201 | preds_indices = [x.doc_id for x in preds]
202 | counts += 1
203 |
204 | for k, cutoff in enumerate(cutoffs):
205 | recall = np.intersect1d(pos_indices, preds_indices[:cutoff])
206 | recalls[k] += len(recall) / len(pos_indices)
207 |
208 | recalls /= counts
209 |
210 | metric = {}
211 | for i, cutoff in enumerate(cutoffs):
212 | recall = recalls[i]
213 | metric[f"recall@{cutoff}"] = recall
214 |
215 | return metric
216 | return compute_metric
217 |
218 | @staticmethod
219 | def ndcg(eval_data=None, cutoffs=[10], **kwds):
220 | if eval_data is not None:
221 | data_labels = Metric._prepare_label_for_retrieval(eval_data)
222 |
223 | def compute_metric(results: Dict[int, List[RerankResult]], labels: Dict[int, Tuple[List[int], List[int]]] = None, **kwargs):
224 | if labels is None:
225 | labels = data_labels
226 |
227 | ndcgs = np.zeros(len(cutoffs))
228 | counts = 0
229 |
230 | for query_id, preds in results.items():
231 | pos_indices, pos_scores = labels[query_id]
232 | preds = Metric._prepare_pred_for_retrieval(preds)
233 |
234 | pos_indices_to_scores = {k: v for k, v in zip(pos_indices, pos_scores)}
235 | if len(pos_indices_to_scores) == 0:
236 | continue
237 |
238 | dcg = np.zeros(len(cutoffs))
239 | idcg = np.zeros(len(cutoffs))
240 | counts += 1
241 |
242 | for i, pred in enumerate(preds, 1):
243 | if pred.doc_id in pos_indices:
244 | for k, cutoff in enumerate(cutoffs):
245 | if i <= cutoff:
246 | # get the relevance score of the pred
247 | dcg[k] += (2 ** pos_indices_to_scores[pred.doc_id] - 1) / np.log2(i + 1)
248 |
249 | # descendingly sort positives to acquire the ideal ranking
250 | ideal_ranking = sorted(pos_scores, reverse=True)
251 | for j, y in enumerate(ideal_ranking, 1):
252 | for k, cutoff in enumerate(cutoffs):
253 | if j <= cutoff:
254 | idcg[k] += (2 ** y - 1) / np.log2(j + 1)
255 |
256 | ndcgs += dcg / idcg
257 |
258 | ndcgs /= counts
259 |
260 | metric = {}
261 | for i, cutoff in enumerate(cutoffs):
262 | ndcg = ndcgs[i]
263 | metric[f"ndcg@{cutoff}"] = ndcg
264 | return metric
265 | return compute_metric
266 |
--------------------------------------------------------------------------------
/evaluation/qdu-tasks/src/modeling.py:
--------------------------------------------------------------------------------
1 | import re
2 | import torch
3 | import copy
4 | import time
5 | import random
6 | import numpy as np
7 | import torch.nn as nn
8 | from transformers import (
9 | AutoTokenizer,
10 | AutoConfig,
11 | AutoModelForCausalLM,
12 | AutoModelForSeq2SeqLM,
13 | logging
14 | )
15 | from torch.utils.data import DataLoader
16 | from accelerate import Accelerator
17 | from typing import Mapping, Tuple, List, Optional
18 | from tqdm import tqdm
19 | from collections import defaultdict
20 | from dataclasses import dataclass, asdict
21 |
22 | logger = logging.get_logger(__name__)
23 |
24 |
25 | @dataclass
26 | class RerankResult:
27 | doc_id: int
28 | doc_score: float
29 |
30 |
31 | class LM(nn.Module):
32 | def __init__(self, model_name_or_path, tokenizer_name_or_path, padding_side="left", dtype="bf16", device_map=None, use_flash_attention_2=False, access_token=None, cache_dir="/share/LMs", accelerator: Accelerator=None) -> None:
33 | super().__init__()
34 |
35 | logger.info(f"loading tokenizer from {tokenizer_name_or_path}...")
36 | tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path, cache_dir=cache_dir, padding_side=padding_side, trust_remote_code=True)
37 | if tokenizer.pad_token is None:
38 | if tokenizer.eos_token is None:
39 | pad_token = "<|endoftext|>"
40 | else:
41 | pad_token = tokenizer.eos_token
42 | tokenizer.pad_token = pad_token
43 |
44 | if dtype == "bf16":
45 | dtype = torch.bfloat16
46 | elif dtype == "fp16":
47 | dtype = torch.float16
48 | else:
49 | dtype = torch.float32
50 |
51 | if device_map is None:
52 | if accelerator is not None:
53 | device_map = {"": accelerator.device}
54 | else:
55 | device_map = {"": "cpu"}
56 |
57 | logger.info(f"loading model from {model_name_or_path}...")
58 | config = AutoConfig.from_pretrained(model_name_or_path, cache_dir=cache_dir)
59 | if config.is_encoder_decoder:
60 | model = AutoModelForSeq2SeqLM.from_pretrained(
61 | model_name_or_path,
62 | cache_dir=cache_dir,
63 | torch_dtype=dtype,
64 | trust_remote_code=True,
65 | device_map=device_map,
66 | token=access_token,
67 | )
68 | else:
69 | model = AutoModelForCausalLM.from_pretrained(
70 | model_name_or_path,
71 | cache_dir=cache_dir,
72 | torch_dtype=dtype,
73 | trust_remote_code=True,
74 | device_map=device_map,
75 | use_flash_attention_2=use_flash_attention_2,
76 | token=access_token,
77 | )
78 |
79 | self.config = model.config
80 | self.tokenizer = tokenizer
81 |
82 | if accelerator is not None:
83 | self.model = accelerator.prepare_model(model, device_placement=True, evaluation_mode=True)
84 | else:
85 | self.model = model
86 |
87 | self.rng = np.random.default_rng(42)
88 | self.eval()
89 |
90 | @property
91 | def device(self):
92 | return self.model.device
93 |
94 | def prepare_inputs_for_generation(self, *args, **kwargs):
95 | return self.model.prepare_inputs_for_generation(*args, **kwargs)
96 |
97 | def _reorder_cache(self, *args, **kwargs):
98 | return self.model._reorder_cache(*args, **kwargs)
99 |
100 | def _move_to_device(self, data):
101 | """
102 | Prepares one `data` before feeding it to the model, be it a tensor or a nested list/dictionary of tensors.
103 | """
104 | if isinstance(data, Mapping):
105 | return type(data)({k: self._move_to_device(v) for k, v in data.items()})
106 | elif isinstance(data, (tuple, list)):
107 | return type(data)(self._move_to_device(v) for v in data)
108 | elif isinstance(data, torch.Tensor):
109 | kwargs = {"device": self.device}
110 | return data.to(**kwargs)
111 | else:
112 | return data
113 |
114 | def forward(self, *args, **kwargs):
115 | return self.model(*args, **kwargs)
116 |
117 | @torch.no_grad()
118 | def generate(self, return_new_tokens_only=True, decode=True, accelerator:Optional[Accelerator]=None, **inputs):
119 | outputs = self.model.generate(**inputs)
120 |
121 | if return_new_tokens_only:
122 | if self.model.config.is_encoder_decoder:
123 | if "decoder_input_ids" in inputs:
124 | start_idx = inputs["decoder_input_ids"].shape[1] + 1
125 | else:
126 | start_idx = 1
127 | else:
128 | start_idx = inputs["input_ids"].shape[1]
129 | outputs = outputs[:, start_idx:]
130 |
131 | if accelerator is not None:
132 | # must be contiguous
133 | outputs = outputs.contiguous()
134 | outputs = accelerator.pad_across_processes(outputs, pad_index=self.tokenizer.pad_token_id, dim=1)
135 | outputs = accelerator.gather_for_metrics(outputs)
136 |
137 | outputs = outputs.tolist()
138 | if decode:
139 | outputs = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
140 | return outputs
141 |
142 | @torch.no_grad()
143 | def rerank_pointwise(self, dataloader, accelerator):
144 | """
145 | Args:
146 | dataloader: yields batches of input_ids and attention_mask, where each item corresponds to a query-document pair.
147 | """
148 | if accelerator is not None and type(dataloader) is DataLoader:
149 | dataloader = accelerator.prepare(dataloader)
150 |
151 | is_encoder_decoder = self.model.config.is_encoder_decoder
152 | yes_id = self.tokenizer.encode("Yes", add_special_tokens=False)[0]
153 | no_id = self.tokenizer.encode("No", add_special_tokens=False)[0]
154 |
155 | rerank_results = defaultdict(list)
156 | for i, x in enumerate(tqdm(dataloader, desc="Pointwise Reranking", ncols=120)):
157 | query_ids = x.pop("query_id")
158 | doc_ids = x.pop("doc_id")
159 |
160 | if is_encoder_decoder:
161 | raise NotImplementedError
162 |
163 | else:
164 | logits = self.model(**x).logits[:, -1] # batch_size, vocab_size
165 |
166 | yes_and_no_logits = logits[:, [yes_id, no_id]] # batch_size, 2
167 | # NOTE: normalize so that different documents are comparable
168 | yes_and_no_logits = torch.softmax(yes_and_no_logits, dim=1)
169 | doc_scores = yes_and_no_logits[:, 0] # batch_size
170 |
171 | # gather outputs across devices
172 | if accelerator is not None:
173 | query_ids = accelerator.gather_for_metrics(query_ids)
174 | doc_ids = accelerator.gather_for_metrics(doc_ids)
175 | doc_scores = accelerator.gather_for_metrics(doc_scores)
176 |
177 | for query_id, doc_id, doc_score in zip(query_ids.tolist(), doc_ids.tolist(), doc_scores.tolist()):
178 | rerank_result = RerankResult(doc_id=doc_id, doc_score=doc_score)
179 | rerank_results[query_id].append(rerank_result)
180 |
181 | # sort candidates of each query
182 | for query_id, res in rerank_results.items():
183 | sorted_res = sorted(res, key=lambda x: x.doc_score, reverse=True)
184 | rerank_results[query_id] = sorted_res
185 |
186 | return dict(rerank_results)
187 |
188 | def compare(self, query: str, docs: List, prompt_template: str, fewshot_prompt: Optional[str]=None):
189 | doc1, doc2 = docs[0], docs[1]
190 | input_texts = [prompt_template.format(query=query, doc1=doc1, doc2=doc2) + "[1]", prompt_template.format(query=query, doc1=doc1, doc2=doc2) + "[2]",
191 | prompt_template.format(query=query, doc1=doc2, doc2=doc1) + "[1]", prompt_template.format(query=query, doc1=doc2, doc2=doc1) + "[2]"]
192 |
193 | # NOTE: add fewshot prompt
194 | if fewshot_prompt is not None:
195 | input_texts = [fewshot_prompt + x for x in input_texts]
196 |
197 | inputs = self.tokenizer(input_texts, return_tensors="pt", padding=True)
198 | #print("input: ", inputs.device())
199 |
200 | target_texts = ["[1]", "[2]", "[1]", "[2]"]
201 | targets = self.tokenizer(target_texts, add_special_tokens=False)["input_ids"]
202 | targets_length = [len(x) for x in targets]
203 |
204 | labels = inputs["input_ids"].clone()
205 | for i, label in enumerate(labels):
206 | labels[i, :-targets_length[i]] = -100
207 | # inputs["labels"] = labels
208 | inputs = inputs.to(self.device)
209 |
210 | outputs = self.model(**inputs)
211 | logits = outputs["logits"].detach()
212 | logits = logits[:, :-1, :].contiguous()
213 | labels = labels[:, 1:].contiguous()
214 | labels = labels.to(logits.device)
215 | batch_size = logits.shape[0]
216 | token_loss = torch.nn.functional.cross_entropy(
217 | logits.flatten(0, 1),
218 | labels.reshape(-1),
219 | reduction="none"
220 | ).reshape(batch_size, -1)
221 |
222 | valid_token_num = (labels != -100).sum(-1)
223 |
224 | token_prob = - token_loss.sum(-1) / valid_token_num
225 |
226 | if token_prob[0] > token_prob[1] and token_prob[2] < token_prob[3]:
227 | return 'change'
228 | return 'not change'
229 |
230 | @torch.no_grad()
231 | def rerank_pairwise(self, dataloader, prompt_template, accelerator):
232 | rerank_results = defaultdict(list)
233 | if accelerator is not None and type(dataloader) is DataLoader:
234 | dataloader = accelerator.prepare(dataloader)
235 |
236 | for _, ranking_result in enumerate(tqdm(dataloader, desc="Pairwise Reranking", ncols=120)):
237 | k = ranking_result["doc_ids"].size(dim=1)
238 | last_end = k - 1
239 | query = ranking_result["query"]
240 | batch_size = ranking_result["doc_ids"].size(dim=0)
241 |
242 | ranking_result["doc_ids"] = ranking_result["doc_ids"].tolist()
243 |
244 | fewshot_prompts = ranking_result["fewshot_prompt"]
245 |
246 | for i in range(batch_size):
247 | # NOTE: convert doc_ids to list
248 | pairs = list(zip(ranking_result["docs"][i], ranking_result["doc_ids"][i]))
249 | self.rng.shuffle(pairs)
250 |
251 | shuffled_docs, shuffled_doc_ids = zip(*pairs)
252 | shuffled_docs = list(shuffled_docs)
253 | shuffled_doc_ids = list(shuffled_doc_ids)
254 | ranking_result["docs"][i] = shuffled_docs
255 | ranking_result["doc_ids"][i] = shuffled_doc_ids
256 |
257 | if fewshot_prompts is not None:
258 | fewshot_prompt = fewshot_prompts[i]
259 | else:
260 | fewshot_prompt = None
261 |
262 | for j in range(k):
263 | current_ind = last_end
264 | is_change = False
265 | while True:
266 | if current_ind <= j:
267 | break
268 | doc1 = ranking_result["docs"][i][current_ind]
269 | doc2 = ranking_result["docs"][i][current_ind - 1]
270 | output = self.compare(query[i], [doc1, doc2], prompt_template=prompt_template, fewshot_prompt=fewshot_prompt)
271 | if output == 'change':
272 | ranking_result["docs"][i][current_ind - 1], ranking_result["docs"][i][current_ind] = ranking_result["docs"][i][current_ind], ranking_result["docs"][i][current_ind - 1]
273 | ranking_result["doc_ids"][i][current_ind - 1], ranking_result["doc_ids"][i][current_ind] = ranking_result["doc_ids"][i][current_ind], ranking_result["doc_ids"][i][current_ind - 1]
274 | if not is_change:
275 | is_change = True
276 | if last_end != k - 1: # skip unchanged pairs at the bottom
277 | last_end += 1
278 | if not is_change:
279 | last_end -= 1
280 | current_ind -= 1
281 | query_ids = ranking_result.pop("query_id")
282 | doc_ids = torch.tensor(ranking_result.pop("doc_ids"), device=self.device)
283 | if accelerator is not None:
284 | query_ids = accelerator.gather_for_metrics(query_ids)
285 | doc_ids = accelerator.gather_for_metrics(doc_ids)
286 | for query_id, doc_id in zip(query_ids.tolist(), doc_ids.tolist()):
287 | for doc_id_i in doc_id:
288 | rerank_result = RerankResult(doc_id=doc_id_i, doc_score=0)
289 | rerank_results[query_id].append(rerank_result)
290 | return dict(rerank_results)
291 |
292 | def permutation_pipeline(self, item=None, rank_start=0, rank_end=100, prompt_template=None, window_size=5):
293 | f_query = item["query"]
294 | num = len(item['hits'][rank_start: rank_end])
295 | fewshot_prompt = item["fewshot_prompt"]
296 |
297 | rank = 0
298 | docs = ""
299 | for hit in item['hits'][rank_start: rank_end]:
300 | rank += 1
301 | if isinstance(hit['document'], str):
302 | document = hit['document'].strip()
303 | else:
304 | print(item['hits'])
305 | raise ValueError("document should be a string")
306 | docs += f"[{rank}] " + document + "\n"
307 | messages = prompt_template.format(query=f_query, num=num, docs=docs)
308 | # messages = prompt_template.format(query=query, num=rank, docs=docs)
309 | if fewshot_prompt is not None:
310 | messages = fewshot_prompt + messages
311 |
312 | input_ids = self.tokenizer(messages, return_tensors="pt", padding='longest', truncation=False).input_ids.to(self.model.device)
313 | output_ids = self.model.generate(input_ids,
314 | do_sample=False,
315 | temperature=0.0,
316 | top_p=None,
317 | max_new_tokens=500,
318 | pad_token_id=self.tokenizer.eos_token_id,)
319 | permutation = self.tokenizer.decode(output_ids[0, input_ids.shape[1]:], skip_special_tokens=False)
320 |
321 | response = [int(match.group(1)) -1 for match in re.finditer(r"(\d+)", permutation)]
322 | response = [x for x in response if 0<= x < window_size]
323 | if len(response) == 0:
324 | return item
325 |
326 | new_response = []
327 | for x in response:
328 | if x not in new_response:
329 | new_response.append(x)
330 | response = new_response
331 |
332 | if len(response) < window_size:
333 | full_set = set(range(0, window_size))
334 | missing_number = full_set - set(response)
335 | response.extend(list(missing_number))
336 |
337 | candidates = copy.deepcopy(item['hits'][rank_start: rank_end])
338 | rerank_candidates = [candidates[x] for x in response]
339 | item['hits'][rank_start: rank_end] = rerank_candidates
340 |
341 | return item
342 |
343 | def sliding_windows(self, item=None, prompt_template=None, rank_start=0, rank_end=100, window_size=5, step=3):
344 | item = copy.deepcopy(item)
345 | end_pos = rank_end
346 | start_pos = rank_end - window_size
347 | while start_pos >= rank_start:
348 | start_pos = max(start_pos, rank_start)
349 | item = self.permutation_pipeline(item, start_pos, end_pos, prompt_template, window_size)
350 | end_pos = end_pos - step
351 | start_pos = start_pos - step
352 | return item
353 |
354 | @torch.no_grad()
355 | def rerank_listwise(self, dataloader, prompt_template, accelerator, window, stride):
356 | rerank_results = defaultdict(list)
357 | if accelerator is not None and type(dataloader) is DataLoader:
358 | dataloader = accelerator.prepare(dataloader)
359 |
360 | for _, ranking_data in enumerate(tqdm(dataloader, desc="Listwise Reranking", ncols=120)):
361 | batch_size = ranking_data["doc_ids"].size(dim=0)
362 | items = [{} for _ in range(batch_size)]
363 | doc_start = [0] * batch_size
364 | doc_end = []
365 | for i in range(batch_size):
366 | items[i]["query"] = ranking_data["query"][i]
367 | items[i]["query_id"] = ranking_data["query_id"][i].item()
368 | items[i]["fewshot_prompt"] = ranking_data["fewshot_prompt"][i]
369 | items[i]["hits"] = [{"document": doc, "docid": docid.item()}
370 | for doc, docid in zip(ranking_data["docs"][i], ranking_data["doc_ids"][i])]
371 | self.rng.shuffle(items[i]["hits"])
372 | doc_end.append(len(items[i]["hits"]))
373 | prompt_templates = [prompt_template] * batch_size
374 | windows = [window] * batch_size
375 | strides = [stride] * batch_size
376 | items = list(map(self.sliding_windows, items, prompt_templates, doc_start, doc_end, windows, strides))
377 | query_ids = ranking_data.pop("query_id")
378 | doc_ids = torch.tensor([[hit["docid"] for hit in item["hits"]] for item in items], device=self.device)
379 |
380 | if accelerator is not None:
381 | query_ids = accelerator.gather_for_metrics(query_ids)
382 | doc_ids = accelerator.gather_for_metrics(doc_ids)
383 | for query_id, doc_id in zip(query_ids.tolist(), doc_ids.tolist()):
384 | for doc_id_i in doc_id:
385 | ranking_result = RerankResult(doc_id=doc_id_i, doc_score=0)
386 | rerank_results[query_id].append(ranking_result)
387 |
388 | return dict(rerank_results)
389 |
390 |
391 | def get_model_and_tokenizer(model_args, accelerator=None, **kwargs):
392 | """Load model and tokenizer. Possibly load LoRA for the model."""
393 |
394 | from .args import ModelArgs
395 | model_args: ModelArgs
396 |
397 | model_args = asdict(model_args)
398 | model_args.update(**kwargs)
399 |
400 | model = LM(
401 | model_name_or_path=model_args["model_name_or_path"],
402 | tokenizer_name_or_path=model_args["tokenizer_name_or_path"],
403 | padding_side=model_args["padding_side"],
404 | dtype=model_args["dtype"],
405 | cache_dir=model_args["model_cache_dir"],
406 | device_map=model_args["device_map"],
407 | use_flash_attention_2=model_args["use_flash_attention_2"],
408 | access_token=model_args["access_token"],
409 | accelerator=accelerator
410 | )
411 |
412 | # load lora
413 | if model_args["lora"] is not None:
414 | from peft import PeftModel
415 | logger.info(f"loading lora from {model_args['lora']}...")
416 | model = PeftModel.from_pretrained(model, model_args["lora"])
417 | model = model.merge_and_unload()
418 |
419 | return model, model.tokenizer
420 |
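# Illustrative sketch (not part of modeling.py): the core of the pointwise reranking score used
# above, i.e. a softmax over the "Yes"/"No" token logits at the last position. `model` and
# `tokenizer` are assumed to be a causal LM and its tokenizer, loaded as in the LM class.
import torch

def pointwise_relevance(model, tokenizer, prompt: str) -> float:
    """Probability of "Yes" (vs. "No") as the next token of a query-document prompt."""
    yes_id = tokenizer.encode("Yes", add_special_tokens=False)[0]
    no_id = tokenizer.encode("No", add_special_tokens=False)[0]
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        logits = model(**inputs).logits[:, -1]               # (1, vocab_size)
    yes_no = torch.softmax(logits[:, [yes_id, no_id]], dim=-1)
    return yes_no[0, 0].item()

# Hypothetical usage, mirroring rerank_pointwise: score each candidate document and sort by the
# returned probability.
# scores = {doc_id: pointwise_relevance(model, tokenizer, template.format(query=q, doc=d))
#           for doc_id, d in candidates.items()}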
--------------------------------------------------------------------------------
/evaluation/qdu-tasks/src/utils.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import pytz
3 | import torch
4 | import pathlib
5 | import json
6 | import string
7 | import numpy as np
8 | from datetime import datetime
9 | from dataclasses import dataclass
10 | from collections import defaultdict
11 | from typing import List, Any, Dict
12 |
13 |
14 | def makedirs(path):
15 | p = pathlib.Path(path)
16 | p.parent.mkdir(parents=True, exist_ok=True)
17 | return path
18 |
19 |
20 | def split_file_dir_name_ext(path):
21 | """Return the directory, name, and extension of a given file."""
22 | p = pathlib.Path(path)
23 | assert p.is_file()
24 | return p.parent, p.stem, p.suffix
25 |
26 |
27 | def get_max_length_in_nested_lists(lst):
28 | if len(lst) and isinstance(lst[0], list):
29 | lengths = []
30 | for elem in lst:
31 | length = get_max_length_in_nested_lists(elem)
32 | lengths.append(length)
33 | max_length = max(lengths)
34 | return max_length
35 | else:
36 | return len(lst)
37 |
38 |
39 | def pad_nested_lists(lst, max_length, padding_value, padding_side="right"):
40 | if isinstance(lst, list) and len(lst) and isinstance(lst[0], list):
41 | masks = []
42 | for i, elem in enumerate(lst):
43 | lst[i], mask = pad_nested_lists(elem, max_length, padding_value, padding_side)
44 | masks.append(mask)
45 | return lst, masks
46 | elif isinstance(lst, list):
47 | if padding_side == "right":
48 | mask = [1] * len(lst) + [0] * (max_length - len(lst))
49 | lst = lst + [padding_value for _ in range(max_length - len(lst))]
50 | return lst, mask
51 | else:
52 | mask = [0] * (max_length - len(lst)) + [1] * len(lst)
53 | lst = [padding_value for _ in range(max_length - len(lst))] + lst
54 | return lst, mask
55 | else:
56 | raise NotImplementedError(f"Unrecognized type {lst}")
57 |
58 | def mask_nested_lists(lst, mask_target, mask_value=0):
59 | if isinstance(lst[0], list):
60 | for i, elem in enumerate(lst):
61 | lst[i] = mask_nested_lists(elem, mask_target, mask_value)
62 | return lst
63 | else:
64 | return [x if x != mask_target else mask_value for x in lst]
65 |
66 | def are_elements_of_same_length(lst: List):
67 | if not isinstance(lst[0], list):
68 | return False
69 |
70 | length = len(lst[0])
71 | return all(len(x) == length if isinstance(x, list) else False for x in lst)
72 |
73 |
74 | def normalize_text(text, ignore_case=True, ignore_punctuation=True, ignore_space=True, ignore_number=False):
75 | if isinstance(text, str):
76 | text = [text]
77 | unpack = True
78 | else:
79 | unpack = False
80 | if ignore_case:
81 | text = np.char.lower(text)
82 | if ignore_punctuation:
83 | repl_table = string.punctuation.maketrans("", "", string.punctuation)
84 | text = np.char.translate(text, table=repl_table)
85 | if ignore_number:
86 | repl_table = string.digits.maketrans("", "", string.digits)
87 | text = np.char.translate(text, table=repl_table)
88 | if ignore_space:
89 | for i, words in enumerate(np.char.split(text)):
90 | text[i] = " ".join(words)
91 | if isinstance(text, np.ndarray):
92 | text = text.tolist()
93 | if unpack:
94 | text = text[0]
95 | return text
96 |
97 |
98 | class DatasetProcessFn:
99 | """Wrapper for any user-defined process function for huggingface datasets.
100 |
101 | 1. Process batched examples by looping the process function over them;
102 | 2. Gather returned examples if any data augmentation happens with augment=True;
103 | 3. Pass indices of examples inside the process function with _index keywords if they exist.
104 |
105 | The wrapped function should take in any needed columns and return a dict with 1 or more samples.
106 | """
107 | def __init__(self, augment=False):
108 | self.augment = augment
109 |
110 | def __call__(self, _process_fn):
111 | def process(*args):
112 | sample_or_batch_sample = args[0]
113 | if len(args) == 1:
114 | pass
115 | elif len(args) == 2:
116 | indices = args[1]
117 | # detach the slice so that _index will not be set in the original data
118 | sample_or_batch_sample = sample_or_batch_sample.copy()
119 | sample_or_batch_sample["_index"] = indices
120 | else:
121 | raise NotImplementedError(f"Found more than 2 arguments {args}!")
122 |
123 | keys = list(sample_or_batch_sample.keys())
124 | func_args = [sample_or_batch_sample[k] for k in keys]
125 |
126 | # FIXME: if all values in a single (non-batched) sample are lists of the same length, it will be mistakenly treated as a batch
127 | if are_elements_of_same_length(func_args):
128 | outputs = defaultdict(list)
129 | for arg in zip(*func_args):
130 | # get each element in a batch
131 | kwargs = {keys[j]: arg[j] for j in range(len(arg))}
132 | output = _process_fn(**kwargs)
133 | if output is not None:
134 | for k, v in output.items():
135 | if self.augment:
136 | outputs[k].extend(v)
137 | else:
138 | outputs[k].append(v)
139 | else:
140 | outputs = _process_fn(**sample_or_batch_sample)
141 | if outputs is None:
142 | raise ValueError(f"Found None returned from process_fn. Make sure you set 'batched=True' when trying to augment/distract samples in the datasets!")
143 | return dict(outputs)
144 | return process
145 |
146 |
147 | @dataclass
148 | class DefaultDataCollator:
149 | """
150 | Data collator that can:
151 | 1. Dynamically pad all inputs received. The inputs must be dict of lists.
152 | 2. Add position_ids based on attention_mask if required.
153 | """
154 | tokenizer: Any = None
155 | attention_padding_value: int = 0
156 | label_padding_value: int = -100
157 | add_position_ids: bool = False
158 |
159 | def __call__(self, batch_elem: List) -> Dict[str, Any]:
160 | first_elem = batch_elem[0]
161 | return_batch = {}
162 |
163 | for key, value in first_elem.items():
164 | # HACK: any key containing attention_mask must be attention_mask
165 | # important to assign different pad token for different types of inputs
166 | if "attention_mask" in key:
167 | pad_token_id = self.attention_padding_value
168 | elif "label" in key:
169 | pad_token_id = self.label_padding_value
170 | else:
171 | pad_token_id = self.tokenizer.pad_token_id
172 |
173 | batch_value = [elem[key] for elem in batch_elem]
174 | # pad all lists and nested lists
175 | if isinstance(value, list):
176 | max_length = get_max_length_in_nested_lists(batch_value)
177 | batch_value, _ = pad_nested_lists(batch_value, max_length, pad_token_id, self.tokenizer.padding_side)
178 |
179 | try:
180 | return_batch[key] = torch.tensor(batch_value)
181 | except:
182 | # handle strings and None
183 | return_batch[key] = batch_value
184 |
185 | if "attention_mask" in key and self.add_position_ids:
186 | value = return_batch[key]
187 | position_ids = value.cumsum(-1) - 1
188 | position_ids = position_ids.masked_fill(value == 0, 0)
189 | return_batch[key.replace("attention_mask", "position_ids")] = position_ids
190 | return return_batch
191 |
192 |
193 | class FileLogger:
194 | def __init__(self, log_file) -> None:
195 | self.log_file = log_file
196 |
197 | def log(self, metrics, **kwargs):
198 | with open(self.log_file, "a+") as f:
199 | # get current time
200 | tz = pytz.timezone('Asia/Shanghai')
201 | time = f"{'Time': <10}: {json.dumps(datetime.now(tz).strftime('%Y-%m-%d, %H:%M:%S'), ensure_ascii=False)}\n"
202 | print(time)
203 | command = f"{'Command': <10}: {json.dumps(' '.join(sys.argv), ensure_ascii=False)}\n"
204 | print(command)
205 | metrics = f"{'Metrics': <10}: {json.dumps(metrics, ensure_ascii=False)}\n"
206 | msg = time + command
207 |
208 | for key, value in kwargs.items():
209 | x = f"{key: <10}: {json.dumps(value, ensure_ascii=False)}\n"
210 | print(x)
211 | msg += x
212 | msg += metrics
213 | print(metrics)
214 | f.write(str(msg) + "\n")
215 |
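# Illustrative sketch (not part of utils.py): how pad_nested_lists pads a ragged batch to the
# longest sequence and returns matching attention masks. The import assumes this is run from
# evaluation/qdu-tasks so that src.utils is importable; the token ids are toy values.
from src.utils import get_max_length_in_nested_lists, pad_nested_lists

batch = [[5, 6, 7], [8, 9]]
max_len = get_max_length_in_nested_lists(batch)                          # 3
padded, masks = pad_nested_lists(batch, max_len, padding_value=0, padding_side="left")
print(padded)  # [[5, 6, 7], [0, 8, 9]]
print(masks)   # [[1, 1, 1], [0, 1, 1]]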
--------------------------------------------------------------------------------
/evaluation/qu-du-tasks/eval_sampling.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import Sampler
3 | import torch.distributed as dist
4 |
5 |
6 | class DistributedEvalSampler(Sampler):
7 | r"""
8 | DistributedEvalSampler is different from DistributedSampler.
9 | It does NOT add extra samples to make it evenly divisible.
10 | DistributedEvalSampler should NOT be used for training. The distributed processes could hang forever.
11 | See this issue for details: https://github.com/pytorch/pytorch/issues/22584
12 | shuffle is disabled by default
13 | DistributedEvalSampler is for evaluation purpose where synchronization does not happen every epoch.
14 | Synchronization should be done outside the dataloader loop.
15 | Sampler that restricts data loading to a subset of the dataset.
16 | It is especially useful in conjunction with
17 | :class:`torch.nn.parallel.DistributedDataParallel`. In such a case, each
18 | process can pass a :class:`~torch.utils.data.DistributedSampler` instance as a
19 | :class:`~torch.utils.data.DataLoader` sampler, and load a subset of the
20 | original dataset that is exclusive to it.
21 | .. note::
22 | Dataset is assumed to be of constant size.
23 | Arguments:
24 | dataset: Dataset used for sampling.
25 | num_replicas (int, optional): Number of processes participating in
26 | distributed training. By default, :attr:`rank` is retrieved from the
27 | current distributed group.
28 | rank (int, optional): Rank of the current process within :attr:`num_replicas`.
29 | By default, :attr:`rank` is retrieved from the current distributed
30 | group.
31 | shuffle (bool, optional): If ``True`` (default), sampler will shuffle the
32 | indices.
33 | seed (int, optional): random seed used to shuffle the sampler if
34 | :attr:`shuffle=True`. This number should be identical across all
35 | processes in the distributed group. Default: ``0``.
36 | .. warning::
37 | In distributed mode, calling the :meth:`set_epoch(epoch)` method at
38 | the beginning of each epoch **before** creating the :class:`DataLoader` iterator
39 | is necessary to make shuffling work properly across multiple epochs. Otherwise,
40 | the same ordering will be always used.
41 | Example::
42 | >>> sampler = DistributedSampler(dataset) if is_distributed else None
43 | >>> loader = DataLoader(dataset, shuffle=(sampler is None),
44 | ... sampler=sampler)
45 | >>> for epoch in range(start_epoch, n_epochs):
46 | ... if is_distributed:
47 | ... sampler.set_epoch(epoch)
48 | ... train(loader)
49 | """
50 |
51 | def __init__(self, dataset, num_replicas=None, rank=None, shuffle=False, seed=0):
52 | if num_replicas is None:
53 | if not dist.is_available():
54 | raise RuntimeError("Requires distributed package to be available")
55 | num_replicas = dist.get_world_size()
56 | if rank is None:
57 | if not dist.is_available():
58 | raise RuntimeError("Requires distributed package to be available")
59 | rank = dist.get_rank()
60 | self.dataset = dataset
61 | self.num_replicas = num_replicas
62 | self.rank = rank
63 | self.epoch = 0
64 | # self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
65 | # self.total_size = self.num_samples * self.num_replicas
66 | self.total_size = len(self.dataset) # true value without extra samples
67 | indices = list(range(self.total_size))
68 | indices = indices[self.rank:self.total_size:self.num_replicas]
69 | self.num_samples = len(indices) # true value without extra samples
70 |
71 | self.shuffle = shuffle
72 | self.seed = seed
73 |
74 | def __iter__(self):
75 | if self.shuffle:
76 | # deterministically shuffle based on epoch and seed
77 | g = torch.Generator()
78 | g.manual_seed(self.seed + self.epoch)
79 | indices = torch.randperm(len(self.dataset), generator=g).tolist()
80 | else:
81 | indices = list(range(len(self.dataset)))
82 |
83 |
84 | # # add extra samples to make it evenly divisible
85 | # indices += indices[:(self.total_size - len(indices))]
86 | # assert len(indices) == self.total_size
87 |
88 | # subsample
89 | indices = indices[self.rank:self.total_size:self.num_replicas]
90 | assert len(indices) == self.num_samples
91 |
92 | return iter(indices)
93 |
94 | def __len__(self):
95 | return self.num_samples
96 |
97 | def set_epoch(self, epoch):
98 | r"""
99 | Sets the epoch for this sampler. When :attr:`shuffle=True`, this ensures all replicas
100 | use a different random ordering for each epoch. Otherwise, the next iteration of this
101 | sampler will yield the same ordering.
102 | Arguments:
103 | epoch (int): Epoch number.
104 | """
105 | self.epoch = epoch
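# Illustrative sketch (not part of eval_sampling.py): the strided sharding used in __iter__,
# which splits indices across ranks without adding padding samples, so some ranks may simply
# receive fewer examples. Toy dataset size and world size below.
dataset_size = 10
num_replicas = 3
for rank in range(num_replicas):
    print(rank, list(range(dataset_size))[rank:dataset_size:num_replicas])
# 0 [0, 3, 6, 9]
# 1 [1, 4, 7]
# 2 [2, 5, 8]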
--------------------------------------------------------------------------------
/evaluation/qu-du-tasks/inference_dataset.py:
--------------------------------------------------------------------------------
1 | import linecache
2 | import json
3 | import numpy as np
4 | from torch.utils.data import Dataset
5 |
6 |
7 | class InferenceDataset(Dataset):
8 | def __init__(self, filename, tokenizer, max_input_length):
9 | super(InferenceDataset, self).__init__()
10 | self._filename = filename
11 | self._tokenizer = tokenizer
12 | self._max_input_length = max_input_length
13 | with open(filename, "r", encoding="utf-8") as f:
14 | self._total_data = len(f.readlines())
15 |
16 | def __getitem__(self, idx):
17 | line = linecache.getline(self._filename, idx + 1)
18 | sample = json.loads(line)
19 |
20 | source = sample["prompt"]
21 | source_encode = self._tokenizer(source, padding="max_length", max_length=self._max_input_length, truncation=True)
22 |
23 | batch = {
24 | "input_ids": np.asarray(source_encode.input_ids),
25 | "attention_mask": np.asarray(source_encode.attention_mask),
26 | "input": sample["prompt"],
27 | "label": sample["completion"]
28 | }
29 |
30 | return batch
31 |
32 | def __len__(self):
33 | return self._total_data
--------------------------------------------------------------------------------
/evaluation/qu-du-tasks/inference_qu_du.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import json
4 | import numpy as np
5 | import torch
6 | import random
7 | import torch.multiprocessing as mp
8 | import torch.distributed as dist
9 | from transformers import AutoModelForCausalLM, AutoTokenizer
10 | from eval_sampling import DistributedEvalSampler
11 | from inference_dataset import InferenceDataset
12 | from torch.utils.data import DataLoader
13 | from inference_tasks.query_description import QueryDescription
14 | from inference_tasks.query_expansion import QueryExpansion
15 | from inference_tasks.query_suggestion import QuerySuggestion
16 | from inference_tasks.query_reformulation import QueryReformulation
17 | from inference_tasks.query_clarification import QueryClarification
18 | from inference_tasks.query_matching import QueryMatching
19 | from inference_tasks.summarization import Summarization
20 | from inference_tasks.fact_verification import FactVerification
21 | from inference_tasks.query_intent_classification import QueryIntentClassification
22 | from inference_tasks.reading_comprehension import ReadingComprehension
23 | from inference_tasks.query_subtopic_generation import QuerySubtopicGeneration
24 | from inference_tasks.conversational_qa import ConversationalQA
25 | import torch.nn.functional as F
26 | from datetime import datetime
27 | from tqdm import tqdm
28 | from tqdm.contrib import tzip
29 |
30 | parser = argparse.ArgumentParser()
31 | parser.add_argument("--model_name_or_path", default="model/", type=str)
32 | parser.add_argument("--tokenizer_name", default="model/llama-2-7b-hf", type=str)
33 | parser.add_argument("--setting", default="in-domain", type=str)
34 | parser.add_argument("--n_shots", default="zero_shot", type=str)
35 | parser.add_argument("--max_input_len", default=1792, type=int)
36 | parser.add_argument("--max_output_len", default=256, type=int)
37 | parser.add_argument("--seed", default=0, type=int)
38 | parser.add_argument('--nodes', type=int, default=1)
39 | parser.add_argument('--gpus', type=int,default=-1, help='num gpus per node')
40 | parser.add_argument('--nr', type=int,default=0, help='ranking within the nodes')
41 | args = parser.parse_args()
42 |
43 | tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, legacy=True)
44 | tokenizer.pad_token = tokenizer.eos_token
45 | tokenizer.padding_side = "left"
46 | tokenizer.truncation_side = "left"
47 |
48 | def set_seed(seed=args.seed):
49 | random.seed(seed)
50 | os.environ['PYTHONHASHSEED'] = str(seed)
51 | np.random.seed(seed)
52 | torch.manual_seed(seed)
53 | torch.cuda.manual_seed(seed)
54 | torch.cuda.manual_seed_all(seed)
55 | torch.backends.cudnn.deterministic = True
56 | torch.backends.cudnn.benchmark = False
57 |
58 | def test_model(local_gpu_rank, args, tasks, label_lists, save_result=False):
59 | set_seed(args.seed)
60 | args.rank = args.nr * args.gpus + local_gpu_rank
61 | torch.cuda.set_device(local_gpu_rank)
62 | dist.init_process_group(backend='nccl', init_method='env://', world_size=args.world_size, rank=args.rank)
63 |
64 | model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path, use_cache=True, torch_dtype=torch.bfloat16)
65 | args.device = torch.device("cuda", local_gpu_rank)
66 | model.to(args.device)
67 | model.eval()
68 |
69 | current_time = datetime.now()
70 | time_string = current_time.strftime("%Y-%m-%d-%H-%M-%S")
71 |
72 | if args.rank == 0:
73 | task_result = []
74 | fw = open(f"result/{time_string}.csv", "w", encoding="utf-8")
75 | fw.write(f"Model,{args.model_name_or_path}\n")
76 | for task, label_list in tzip(tasks, label_lists, ncols=120, desc="# Tasks"):
77 | test_data = task.get_path()
78 |
79 | test_dataset = InferenceDataset(test_data, tokenizer, args.max_input_len)
80 | test_sampler = DistributedEvalSampler(test_dataset, shuffle=False)
81 | test_dataloader = DataLoader(test_dataset, batch_size=args.test_batch_size, num_workers=2, sampler=test_sampler)
82 |
83 | all_decode_result = []
84 | all_labels = []
85 | count = 0
86 | with torch.no_grad():
87 | if label_list == []:
88 | if args.rank == 0:
89 | test_dataloader = tqdm(test_dataloader, ncols=120, leave=False)
90 | for test_data in test_dataloader:
91 | outputs = model.generate(
92 | input_ids=test_data["input_ids"].to(args.device),
93 | attention_mask=test_data["attention_mask"].to(args.device),
94 | max_length=args.max_input_len + args.max_output_len,
95 | do_sample=False,
96 | pad_token_id=tokenizer.eos_token_id,
97 | )
98 | if outputs.size(1) < args.max_input_len + args.max_output_len:
99 | batch_pred_padding = torch.ones((outputs.size(0), args.max_input_len + args.max_output_len - outputs.size(1)), dtype=outputs.dtype).cuda() * 2
100 | outputs = torch.cat([outputs, batch_pred_padding], dim=1)
101 |
102 | batch_out_sentences = tokenizer.batch_decode(outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False)
103 | batch_output = []
104 | for idx, output in enumerate(batch_out_sentences):
105 | output = output[len(test_data["input"][idx]):]
106 | batch_output.append(output.strip())
107 | all_decode_result.extend(batch_output)
108 | all_labels.extend(test_data["label"])
109 | if args.rank == 0:
110 | count += len(batch_output) * args.world_size
111 | else:
112 | if args.rank == 0:
113 | test_dataloader = tqdm(test_dataloader, ncols=120, leave=False)
114 | for test_data in test_dataloader:
115 | input_text = test_data["input"] # List[string] len=batch_size
116 | gold_label = test_data["label"] # List[string] len=batch_size
117 |
118 | new_reqs = []
119 | for i in input_text:
120 | for l in label_list:
121 | context_enc = tokenizer.encode(i)
122 | continuation_enc = tokenizer.encode(i + l)[len(context_enc):]
123 | key = (i, l)
124 | value = (context_enc, continuation_enc)
125 | new_reqs.append((key, value))
126 |
127 | input_texts = [x[0][0] + x[0][1] for x in new_reqs]
128 | inputs = tokenizer(
129 | input_texts,
130 | return_tensors='pt',
131 | padding="longest",
132 | max_length=args.max_input_len + args.max_output_len,
133 | truncation=True,
134 | return_attention_mask=True,
135 | )
136 |
137 | logits = model(inputs['input_ids'].to(model.device), attention_mask=inputs['attention_mask'].to(model.device))
138 | logits = F.log_softmax(logits[0], dim=-1).cpu()
139 |
140 | all_result = []
141 | for one_req, one_logits in zip(new_reqs, logits):
142 | key, value = one_req
143 | _, cont_toks = value
144 |
145 | # Slice to original seq length
146 | contlen = len(cont_toks)
147 | one_logits = one_logits[-contlen-1 : -1].unsqueeze(0) # [1, seq, vocab]
148 |
149 | # Check if per-token argmax is exactly equal to continuation
150 | greedy_tokens = one_logits.argmax(dim=-1)
151 | cont_toks = torch.tensor(cont_toks, dtype=torch.long).unsqueeze(0) # [1, seq]
152 | if len(greedy_tokens[0]) != len(cont_toks[0]):
153 | # the continuation was truncated by the maximum length limit
154 | answer = -float('inf')
155 | else:
156 | # Obtain log-probs at the corresponding continuation token indices
157 | # last_token_slice = logits[:, -1, :].squeeze(0).tolist()
158 | one_logits = torch.gather(one_logits, dim=2, index=cont_toks.unsqueeze(-1)).squeeze(-1) # [1, seq]
159 |
160 | # Answer: (log prob, is-exact-match)
161 | answer = float(one_logits.sum())
162 | all_result.append(answer)
163 | all_result = np.asarray(all_result)
164 | all_result = all_result.reshape(-1, len(label_list)) # bsz, num_labels
165 | gold_label_list = [x.split(", ") for x in gold_label]
166 | gold_idx = [[label_list.index(x) for x in y] for y in gold_label_list]
167 | all_result = np.argmax(all_result, axis=1)
168 |
169 | all_decode_result.extend(all_result)
170 | all_labels.extend(gold_idx)
171 | if args.rank == 0:
172 | count += len(all_result) * args.world_size
173 |
174 | dist.barrier()
175 | gather_data = [None for _ in range(args.world_size)]
176 | gather_label = [None for _ in range(args.world_size)]
177 | dist.all_gather_object(gather_data, all_decode_result)
178 | dist.all_gather_object(gather_label, all_labels)
179 | if args.rank == 0:
180 | preds = []
181 | labels = []
182 | for j in range(len(gather_data[0])):
183 | for i in range(len(gather_data)):
184 | if j < len(gather_data[i]):
185 | prediction = gather_data[i][j]
186 | label = gather_label[i][j]
187 | preds.append(prediction)
188 | labels.append(label)
189 | if save_result:
190 | current_time = datetime.now()
191 | save_time_string = current_time.strftime("%Y-%m-%d-%H-%M-%S")
192 | with open(f"result/{save_time_string}.{task._cluster}_{task._name}.result.txt", "w", encoding="utf-8") as fw:
193 | for p, l in zip(preds, labels):
194 | fw.write(json.dumps({"pred": p, "label": l}) + "\n")
195 | result = task.compute_metrics(preds, labels)
196 | tqdm.write(f"{task._cluster}_{task._name}:\t{json.dumps(result)}")
197 | task_result.append((f"{task._cluster}_{task._name}", result))
198 | for k in result.keys():
199 | fw.write(f"{task._cluster}_{task._name}, {k}, {result[k]}\n")
200 | fw.flush()
201 | if args.rank == 0:
202 | fw.close()
203 |
204 | if __name__ == '__main__':
205 | if args.gpus < 0:
206 | args.gpus = torch.cuda.device_count()
207 | args.world_size = args.nodes * args.gpus
208 | os.environ['MASTER_ADDR']='localhost'
209 | os.environ['MASTER_PORT']='8889'
210 | args.test_batch_size = 2
211 | tasks = []
212 | label_lists = []
213 |
214 | # query description
215 | task = QueryDescription(name="gov2", shot=args.n_shots, setting=args.setting)
216 | tasks.append(task)
217 | label_lists.append([])
218 | task = QueryDescription(name="trec_robust", shot=args.n_shots, setting=args.setting)
219 | tasks.append(task)
220 | label_lists.append([])
221 | task = QueryDescription(name="trec_covid", shot=args.n_shots, setting=args.setting)
222 | tasks.append(task)
223 | label_lists.append([])
224 | task = QueryDescription(name="fire", shot=args.n_shots, setting=args.setting)
225 | tasks.append(task)
226 | label_lists.append([])
227 |
228 | # query expansion
229 | task = QueryExpansion(name="gov2", shot=args.n_shots, setting=args.setting)
230 | tasks.append(task)
231 | label_lists.append([])
232 | task = QueryExpansion(name="trec_robust", shot=args.n_shots, setting=args.setting)
233 | tasks.append(task)
234 | label_lists.append([])
235 | task = QueryExpansion(name="trec_covid", shot=args.n_shots, setting=args.setting)
236 | tasks.append(task)
237 | label_lists.append([])
238 | task = QueryExpansion(name="fire", shot=args.n_shots, setting=args.setting)
239 | tasks.append(task)
240 | label_lists.append([])
241 | task = QueryExpansion(name="query2doc", shot=args.n_shots, setting=args.setting)
242 | tasks.append(task)
243 | label_lists.append([])
244 | task = QueryExpansion(name="trec_cast", shot=args.n_shots, setting=args.setting)
245 | tasks.append(task)
246 | label_lists.append([])
247 | task = QueryExpansion(name="trec_web", shot=args.n_shots, setting=args.setting)
248 | tasks.append(task)
249 | label_lists.append([])
250 |
251 | # query reformulation
252 | task = QueryReformulation(name="codec", shot=args.n_shots, setting=args.setting)
253 | tasks.append(task)
254 | label_lists.append([])
255 | task = QueryReformulation(name="qrecc", shot=args.n_shots, setting=args.setting)
256 | tasks.append(task)
257 | label_lists.append([])
258 | task = QueryReformulation(name="canard", shot=args.n_shots, setting=args.setting)
259 | tasks.append(task)
260 | label_lists.append([])
261 | task = QueryReformulation(name="trec_cast", shot=args.n_shots, setting=args.setting)
262 | tasks.append(task)
263 | label_lists.append([])
264 | task = QueryReformulation(name="gecor", shot=args.n_shots, setting=args.setting)
265 | tasks.append(task)
266 | label_lists.append([])
267 |
268 | # query clarification
269 | task = QueryClarification(name="mimics", shot=args.n_shots, setting=args.setting)
270 | tasks.append(task)
271 | label_lists.append([])
272 | task = QueryClarification(name="mimics_duo", shot=args.n_shots, setting=args.setting)
273 | tasks.append(task)
274 | label_lists.append([])
275 | task = QueryClarification(name="clariq_fkw", shot=args.n_shots, setting=args.setting)
276 | tasks.append(task)
277 | label_lists.append([])
278 | task = QueryClarification(name="raocq", shot=args.n_shots, setting=args.setting)
279 | tasks.append(task)
280 | label_lists.append([])
281 |
282 | # query subtopic generation
283 | task = QuerySubtopicGeneration(name="trec_web", shot=args.n_shots, setting=args.setting)
284 | tasks.append(task)
285 | label_lists.append([])
286 |
287 | # query suggestion
288 | task = QuerySuggestion(name="aol", shot=args.n_shots, setting=args.setting)
289 | tasks.append(task)
290 | label_lists.append([])
291 |
292 | # query matching
293 | task = QueryMatching(name="msrp", shot=args.n_shots, setting=args.setting)
294 | tasks.append(task)
295 | label_lists.append(["yes", "no"])
296 |
297 | # query intent classification
298 | task = QueryIntentClassification(name="mantis", shot=args.n_shots, setting=args.setting, multi_label=True)
299 | tasks.append(task)
300 | label_lists.append(["original question", "further details", "other", "information request", "potential answer", "positive feedback", "negative feedback", "greetings / gratitude", "follow up question"])
301 |
302 | task = QueryIntentClassification(name="orcas_i", shot=args.n_shots, setting=args.setting)
303 | tasks.append(task)
304 | label_lists.append(["factual", "abstain", "instrumental", "transactional", "navigational"])
305 | task = QueryIntentClassification(name="trec_web", shot=args.n_shots, setting=args.setting)
306 | tasks.append(task)
307 | label_lists.append(["faceted", "ambiguous", "navigational", "informational"])
308 |
309 | # fact verification
310 | task = FactVerification(name="fever", shot=args.n_shots, setting=args.setting)
311 | tasks.append(task)
312 | label_lists.append(["support", "refute"])
313 | task = FactVerification(name="climate_fever", shot=args.n_shots, setting=args.setting)
314 | tasks.append(task)
315 | label_lists.append(["support", "refute", "disputed", "not enough information"])
316 | task = FactVerification(name="scifact", shot=args.n_shots, setting=args.setting)
317 | tasks.append(task)
318 | label_lists.append(["support", "refute"])
319 |
320 | # conversational qa
321 | task = ConversationalQA(name="coqa", shot=args.n_shots, setting=args.setting)
322 | tasks.append(task)
323 | label_lists.append([])
324 | task = ConversationalQA(name="quac", shot=args.n_shots, setting=args.setting)
325 | tasks.append(task)
326 | label_lists.append([])
327 |
328 | # summarization
329 | task = Summarization(name="cnndm", shot=args.n_shots, setting=args.setting)
330 | tasks.append(task)
331 | label_lists.append([])
332 | task = Summarization(name="xsum", shot=args.n_shots, setting=args.setting)
333 | tasks.append(task)
334 | label_lists.append([])
335 | task = Summarization(name="wikisum", shot=args.n_shots, setting=args.setting)
336 | tasks.append(task)
337 | label_lists.append([])
338 | task = Summarization(name="multinews", shot=args.n_shots, setting=args.setting)
339 | tasks.append(task)
340 | label_lists.append([])
341 |
342 | # reading comprehension
343 | task = ReadingComprehension(name="squad", shot=args.n_shots, setting=args.setting)
344 | tasks.append(task)
345 | label_lists.append([])
346 | task = ReadingComprehension(name="hotpot_qa", shot=args.n_shots, setting=args.setting)
347 | tasks.append(task)
348 | label_lists.append([])
349 | task = ReadingComprehension(name="ms_marco", shot=args.n_shots, setting=args.setting)
350 | tasks.append(task)
351 | label_lists.append([])
352 | task = ReadingComprehension(name="boolq", shot=args.n_shots, setting=args.setting)
353 | tasks.append(task)
354 | label_lists.append(["true", "false"])
355 | task = ReadingComprehension(name="webglm_qa", shot=args.n_shots, setting=args.setting)
356 | tasks.append(task)
357 | label_lists.append([])
358 | task = ReadingComprehension(name="trivia_qa", shot=args.n_shots, setting=args.setting)
359 | tasks.append(task)
360 | label_lists.append([])
361 |
362 | assert len(tasks) == len(label_lists)
363 | mp.spawn(test_model, nprocs=args.gpus, args=(args, tasks, label_lists, False))
364 |
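# Illustrative sketch (not part of inference_qu_du.py): the label-scoring idea used for the
# classification tasks above. Each candidate label is scored by the summed log-probability of
# its tokens as a continuation of the prompt, and the highest-scoring label is predicted.
# `model` and `tokenizer` are assumed to be a loaded causal LM and its tokenizer.
import torch
import torch.nn.functional as F

def score_label(model, tokenizer, prompt: str, label: str) -> float:
    context_ids = tokenizer.encode(prompt)
    full_ids = tokenizer.encode(prompt + label)
    cont_ids = full_ids[len(context_ids):]
    input_ids = torch.tensor([full_ids], device=model.device)
    with torch.no_grad():
        logprobs = F.log_softmax(model(input_ids).logits, dim=-1)
    # log-probability of each continuation token, read from the position that predicts it
    cont = torch.tensor(cont_ids, device=model.device).unsqueeze(-1)
    return float(logprobs[0, -len(cont_ids) - 1:-1].gather(-1, cont).sum())

def classify(model, tokenizer, prompt: str, label_list):
    scores = [score_label(model, tokenizer, prompt, label) for label in label_list]
    return label_list[max(range(len(label_list)), key=lambda i: scores[i])]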
--------------------------------------------------------------------------------
/evaluation/qu-du-tasks/inference_tasks/conversational_qa.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import re
3 |
4 | class ConversationalQA:
5 | def __init__(self, name, shot, setting):
6 | self._cluster = "conversational_qa"
7 | self._name = name
8 | self._shot = shot
9 | self._setting = setting
10 |
11 | def get_path(self):
12 | return f"data/{self._setting}/{self._shot}/{self._cluster}_{self._name}.{self._shot}.test.jsonl"
13 |
14 | def compute_metrics(self, preds, labels):
15 | def normalize_results(text):
16 | pattern = r"(\d+\.|\[\d+\]|\(\d+\))\s*(.*?)\s*(?=\d\.|\[\d\]|\(\d\)|$)"
17 | extracted_text = re.findall(pattern, text)
18 | return extracted_text
19 |
20 | all_acc = []
21 | for prediction, label in zip(preds, labels):
22 | extracted_answer = normalize_results(label)
23 | extracted_prediction = normalize_results(prediction)
24 | if extracted_answer == []:
25 | extracted_answer = label
26 | answer_dict = {0: extracted_answer}
27 | else:
28 | answer_dict = {match[0]: match[1] for match in extracted_answer}
29 | if extracted_prediction == []:
30 | extracted_prediction = prediction
31 | extracted_dict = {0: extracted_prediction}
32 | else:
33 | extracted_dict = {match[0]: match[1] for match in extracted_prediction}
34 |
35 | acc = 0
36 | for k in extracted_dict.keys():
37 | if k not in answer_dict:
38 | continue
39 | if extracted_dict[k] == answer_dict[k]:
40 | acc += 1
41 | acc = acc / len(answer_dict.keys())
42 | all_acc.append(acc)
43 | all_acc = np.asarray(all_acc)
44 | results = {
45 | "Acc": np.mean(all_acc)
46 | }
47 | return results
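# Illustrative sketch (not part of conversational_qa.py): what the enumeration-extraction regex
# in normalize_results pulls out of a numbered answer string (toy text).
import re

pattern = r"(\d+\.|\[\d+\]|\(\d+\))\s*(.*?)\s*(?=\d\.|\[\d\]|\(\d\)|$)"
print(re.findall(pattern, "1. Paris 2. In 1889 3. Gustave Eiffel"))
# [('1.', 'Paris'), ('2.', 'In 1889'), ('3.', 'Gustave Eiffel')]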
--------------------------------------------------------------------------------
/evaluation/qu-du-tasks/inference_tasks/fact_verification.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from sklearn.metrics import f1_score, accuracy_score
3 |
4 | class FactVerification():
5 | def __init__(self, name, shot, setting):
6 | self._cluster = "fact_verification"
7 | self._name = name
8 | self._shot = shot
9 | self._setting = setting
10 |
11 | def get_path(self):
12 | return f"data/{self._setting}/{self._shot}/{self._cluster}_{self._name}.{self._shot}.test.jsonl"
13 |
14 | def compute_metrics(self, preds, labels):
15 | labels = np.asarray([x[0] for x in labels])
16 | acc = accuracy_score(labels, preds)
17 | f1 = f1_score(labels, preds, average='weighted')
18 |
19 | results = {
20 | "Acc": acc,
21 | "F1": f1
22 | }
23 |
24 | return results
25 |
26 |
--------------------------------------------------------------------------------
/evaluation/qu-du-tasks/inference_tasks/query_clarification.py:
--------------------------------------------------------------------------------
1 | from rouge_score import rouge_scorer
2 | from nltk.translate import bleu_score
3 | from nltk.translate.bleu_score import corpus_bleu
4 | import numpy as np
5 | import re
6 |
7 | class QueryClarification:
8 | def __init__(self, name, shot, setting):
9 | self._cluster = "query_clarification"
10 | self._name = name
11 | self._shot = shot
12 | self._setting = setting
13 |
14 | def get_path(self):
15 | return f"data/{self._setting}/{self._shot}/{self._cluster}_{self._name}.{self._shot}.test.jsonl"
16 |
17 | def compute_metrics(self, preds, labels):
18 | def normalize_results(text):
19 | pattern = r"(?:\d\.|\[\d\]|\(\d\))\s*(.*?)\s*(?=\d\.|\[\d\]|\(\d\)|$)"
20 | extracted_text = re.findall(pattern, text)
21 | return extracted_text
22 |
23 | if self._name == "clariq_fkw" or self._name == "raocq":
24 | result_rouge_scorer = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True)
25 | all_rouge_l = []
26 | split_labels = [[x.split()] for x in labels]
27 | split_preds = [x.split() for x in preds]
28 | bleu1 = corpus_bleu(split_labels, split_preds, weights=(1, 0, 0, 0), smoothing_function=bleu_score.SmoothingFunction().method4)
29 | bleu2 = corpus_bleu(split_labels, split_preds, weights=(0.5, 0.5, 0, 0), smoothing_function=bleu_score.SmoothingFunction().method4)
30 | for prediction, label in zip(preds, labels):
31 | scores = result_rouge_scorer.score(label, prediction)
32 | rougel = scores['rougeL'].fmeasure
33 | all_rouge_l.append(rougel)
34 | avg_rouge_l = np.mean(np.asarray(all_rouge_l))
35 | results = {
36 | "BLEU-1": bleu1,
37 | "BLEU-2": bleu2,
38 | "ROUGE-L": avg_rouge_l,
39 | }
40 | return results
41 | else:
42 | all_p, all_r, all_f1 = [], [], []
43 | for prediction, label in zip(preds, labels):
44 | label = label.split(" <=SEP=> ")
45 | extracted_prediction = normalize_results(prediction)
46 | if extracted_prediction == []:
47 | extracted_prediction = prediction
48 | TP = len(set(extracted_prediction).intersection(label))
49 | FP = len(set(extracted_prediction) - set(label))
50 | FN = len(set(label) - set(extracted_prediction))
51 | precision = TP / (TP + FP) if (TP + FP) != 0 else 0
52 | recall = TP / (TP + FN) if (TP + FN) != 0 else 0
53 | f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) != 0 else 0
54 | all_p.append(precision)
55 | all_r.append(recall)
56 | all_f1.append(f1)
57 | all_p = np.asarray(all_p)
58 | all_r = np.asarray(all_r)
59 | all_f1 = np.asarray(all_f1)
60 | results = {
61 | "precision": np.sum(all_p) / len(all_p),
62 | "recall": np.sum(all_r) / len(all_r),
63 | "f1": np.sum(all_f1) / len(all_f1),
64 | }
65 | return results
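# Illustrative sketch (not part of query_clarification.py): the set-overlap precision/recall/F1
# computed in the else-branch above, on made-up facet lists.
def set_prf(predicted, gold):
    tp = len(set(predicted) & set(gold))
    fp = len(set(predicted) - set(gold))
    fn = len(set(gold) - set(predicted))
    p = tp / (tp + fp) if tp + fp else 0.0
    r = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * p * r / (p + r) if p + r else 0.0
    return p, r, f1

print(tuple(round(x, 4) for x in set_prf(["price", "reviews", "colors"],
                                         ["price", "colors", "sizes", "brands"])))
# (0.6667, 0.5, 0.5714)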
--------------------------------------------------------------------------------
/evaluation/qu-du-tasks/inference_tasks/query_description.py:
--------------------------------------------------------------------------------
1 | from rouge_score import rouge_scorer
2 | import numpy as np
3 | from nltk.translate.bleu_score import corpus_bleu
4 | from nltk.translate import bleu_score
5 |
6 | class QueryDescription():
7 | def __init__(self, name, shot, setting):
8 | self._cluster = "query_description"
9 | self._name = name
10 | self._shot = shot
11 | self._setting = setting
12 |
13 |
14 | def get_path(self):
15 | return f"data/{self._setting}/{self._shot}/{self._cluster}_{self._name}.{self._shot}.test.jsonl"
16 |
17 | def compute_metrics(self, preds, labels):
18 | result_rouge_scorer = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True)
19 | all_rouge_l = []
20 | split_labels = [[x.split()] for x in labels]
21 | split_preds = [x.split() for x in preds]
22 | bleu1 = corpus_bleu(split_labels, split_preds, weights=(1, 0, 0, 0), smoothing_function=bleu_score.SmoothingFunction().method4)
23 | bleu2 = corpus_bleu(split_labels, split_preds, weights=(0.5, 0.5, 0, 0), smoothing_function=bleu_score.SmoothingFunction().method4)
24 | for prediction, label in zip(preds, labels):
25 | scores = result_rouge_scorer.score(label, prediction)
26 | rougel = scores['rougeL'].fmeasure
27 | all_rouge_l.append(rougel)
28 | avg_rouge_l = np.mean(np.asarray(all_rouge_l))
29 | results = {
30 | "BLEU-1": bleu1,
31 | "BLEU-2": bleu2,
32 | "ROUGE-L": avg_rouge_l,
33 | }
34 | return results
--------------------------------------------------------------------------------
/evaluation/qu-du-tasks/inference_tasks/query_expansion.py:
--------------------------------------------------------------------------------
1 | from rouge_score import rouge_scorer
2 | import numpy as np
3 | from nltk.translate.bleu_score import corpus_bleu
4 | from nltk.translate import bleu_score
5 |
6 | class QueryExpansion():
7 | def __init__(self, name, shot, setting):
8 | self._cluster = "query_expansion"
9 | self._name = name
10 | self._shot = shot
11 | self._setting = setting
12 |
13 |
14 | def get_path(self):
15 | return f"data/{self._setting}/{self._shot}/{self._cluster}_{self._name}.{self._shot}.test.jsonl"
16 |
17 | def compute_metrics(self, preds, labels):
18 | result_rouge_scorer = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True)
19 | all_rouge_l = []
20 | split_labels = [[x.split()] for x in labels]
21 | split_preds = [x.split() for x in preds]
22 | bleu1 = corpus_bleu(split_labels, split_preds, weights=(1, 0, 0, 0), smoothing_function=bleu_score.SmoothingFunction().method4)
23 | bleu2 = corpus_bleu(split_labels, split_preds, weights=(0.5, 0.5, 0, 0), smoothing_function=bleu_score.SmoothingFunction().method4)
24 | for prediction, label in zip(preds, labels):
25 | scores = result_rouge_scorer.score(label, prediction)
26 | rougel = scores['rougeL'].fmeasure
27 | all_rouge_l.append(rougel)
28 | avg_rouge_l = np.mean(np.asarray(all_rouge_l))
29 | results = {
30 | "BLEU-1": bleu1,
31 | "BLEU-2": bleu2,
32 | "ROUGE-L": avg_rouge_l,
33 | }
34 | return results
--------------------------------------------------------------------------------
/evaluation/qu-du-tasks/inference_tasks/query_intent_classification.py:
--------------------------------------------------------------------------------
1 | from sklearn.metrics import f1_score, accuracy_score
2 |
3 | class QueryIntentClassification():
4 | def __init__(self, name, shot, setting, multi_label=False):
5 | self._cluster = "query_intent_classification"
6 | self._name = name
7 | self._shot = shot
8 | self._setting = setting
9 | self._multi_label = multi_label
10 |
11 | def get_path(self):
12 | return f"data/{self._setting}/{self._shot}/{self._cluster}_{self._name}.{self._shot}.test.jsonl"
13 |
14 | def compute_metrics(self, preds, labels):
15 | # accuracy = np.sum(np.asarray(preds) == np.asarray(labels)) / len(preds)
16 | if self._multi_label:
17 | p = 0
18 | for pred, label in zip(preds, labels):
19 | if pred in label:
20 | p += 1
21 | results = {
22 | "P@1": p / len(preds)
23 | }
24 | else:
25 | acc = accuracy_score(labels, preds)
26 | f1 = f1_score(labels, preds, average='weighted')
27 | results = {
28 | "Acc": acc,
29 | "F1": f1
30 | }
31 |
32 | return results
33 |
--------------------------------------------------------------------------------
/evaluation/qu-du-tasks/inference_tasks/query_matching.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from sklearn.metrics import f1_score, accuracy_score
3 |
4 | class QueryMatching():
5 | def __init__(self, name, shot, setting):
6 | self._cluster = "query_matching"
7 | self._name = name
8 | self._shot = shot
9 | self._setting = setting
10 |
11 | def get_path(self):
12 | return f"data/{self._setting}/{self._shot}/{self._cluster}_{self._name}.{self._shot}.test.jsonl"
13 |
14 | def compute_metrics(self, preds, labels):
15 | # accuracy = np.sum(np.asarray(preds) == np.asarray(labels)) / len(preds)
16 | labels = np.asarray([x[0] for x in labels])
17 | acc = accuracy_score(labels, preds)
18 | f1 = f1_score(labels, preds, average='weighted')
19 |
20 | results = {
21 | "Acc": acc,
22 | "F1": f1
23 | }
24 |
25 | return results
26 |
27 |
--------------------------------------------------------------------------------
/evaluation/qu-du-tasks/inference_tasks/query_reformulation.py:
--------------------------------------------------------------------------------
1 | from rouge_score import rouge_scorer
2 | import numpy as np
3 | from nltk.translate.bleu_score import corpus_bleu
4 | from nltk.translate import bleu_score
5 | import re
6 |
7 | class QueryReformulation():
8 | def __init__(self, name, shot, setting):
9 | self._cluster = "query_reformulation"
10 | self._name = name
11 | self._shot = shot
12 | self._setting = setting
13 |
14 | def get_path(self):
15 | return f"data/{self._setting}/{self._shot}/{self._cluster}_{self._name}.{self._shot}.test.jsonl"
16 |
17 | def compute_metrics(self, preds, labels):
18 | def normalize_results(text):
19 | pattern = r"(?:\d\.|\[\d\]|\(\d\))\s*(.*?)\s*(?=\d\.|\[\d\]|\(\d\)|$)"
20 | extracted_text = re.findall(pattern, text)
21 | return extracted_text
22 |
23 | if self._name == "codec":
24 | all_p, all_r, all_f1 = [], [], []
25 | for prediction, label in zip(preds, labels):
26 | label = normalize_results(label)
27 | extracted_prediction = normalize_results(prediction)
28 | if extracted_prediction == []:
29 | extracted_prediction = prediction
30 | TP = len(set(extracted_prediction).intersection(label))
31 | FP = len(set(extracted_prediction) - set(label))
32 | FN = len(set(label) - set(extracted_prediction))
33 | precision = TP / (TP + FP) if (TP + FP) != 0 else 0
34 | recall = TP / (TP + FN) if (TP + FN) != 0 else 0
35 | f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) != 0 else 0
36 | all_p.append(precision)
37 | all_r.append(recall)
38 | all_f1.append(f1)
39 | all_p = np.asarray(all_p)
40 | all_r = np.asarray(all_r)
41 | all_f1 = np.asarray(all_f1)
42 | results = {
43 | "precision": np.mean(all_p),
44 | "recall": np.mean(all_r),
45 | "f1": np.mean(all_f1),
46 | }
47 | else:
48 | result_rouge_scorer = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True)
49 | all_rouge_l = []
50 |
51 | split_labels = [[x.split()] for x in labels]
52 | split_preds = [x.split() for x in preds]
53 | bleu1 = corpus_bleu(split_labels, split_preds, weights=(1, 0, 0, 0), smoothing_function=bleu_score.SmoothingFunction().method4)
54 | bleu2 = corpus_bleu(split_labels, split_preds, weights=(0.5, 0.5, 0, 0), smoothing_function=bleu_score.SmoothingFunction().method4)
55 |
56 | for prediction, label in zip(preds, labels):
57 | scores = result_rouge_scorer.score(label, prediction)
58 | rougel = scores['rougeL'].fmeasure
59 | all_rouge_l.append(rougel)
60 | avg_rouge_l = np.mean(np.asarray(all_rouge_l))
61 | results = {
62 | "BLEU-1": bleu1,
63 | "BLEU-2": bleu2,
64 | "ROUGE-L": avg_rouge_l,
65 | }
66 | return results
--------------------------------------------------------------------------------
/evaluation/qu-du-tasks/inference_tasks/query_subtopic_generation.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import re
3 |
4 | class QuerySubtopicGeneration():
5 | def __init__(self, name, shot, setting):
6 | self._cluster = "query_subtopic_generation"
7 | self._name = name
8 | self._shot = shot
9 | self._setting = setting
10 |
11 | def get_path(self):
12 | return f"data/{self._setting}/{self._shot}/{self._cluster}_{self._name}.{self._shot}.test.jsonl"
13 |
14 | def compute_metrics(self, preds, labels):
15 | def normalize_results(text):
16 | pattern = r"(?:\d\.|\[\d\]|\(\d\))\s*(.*?)\s*(?=\d\.|\[\d\]|\(\d\)|$)"
17 | extracted_text = re.findall(pattern, text)
18 | return extracted_text
19 |
20 | all_p, all_r, all_f1 = [], [], []
21 | for prediction, label in zip(preds, labels):
22 | label = normalize_results(label)
23 | extracted_prediction = normalize_results(prediction)
24 | if extracted_prediction == []:
25 |                 extracted_prediction = [prediction]  # treat the whole prediction as a single item, not a set of characters
26 | TP = len(set(extracted_prediction).intersection(label))
27 | FP = len(set(extracted_prediction) - set(label))
28 | FN = len(set(label) - set(extracted_prediction))
29 | precision = TP / (TP + FP) if (TP + FP) != 0 else 0
30 | recall = TP / (TP + FN) if (TP + FN) != 0 else 0
31 | f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) != 0 else 0
32 | all_p.append(precision)
33 | all_r.append(recall)
34 | all_f1.append(f1)
35 | all_p = np.asarray(all_p)
36 | all_r = np.asarray(all_r)
37 | all_f1 = np.asarray(all_f1)
38 | results = {
39 |             "precision": np.mean(all_p),
40 |             "recall": np.mean(all_r),
41 |             "f1": np.mean(all_f1),
42 | }
43 | return results
--------------------------------------------------------------------------------
/evaluation/qu-du-tasks/inference_tasks/query_suggestion.py:
--------------------------------------------------------------------------------
1 | from rouge_score import rouge_scorer
2 | import numpy as np
3 | from nltk.translate.bleu_score import corpus_bleu
4 | from nltk.translate import bleu_score
5 |
6 | class QuerySuggestion():
7 | def __init__(self, name, shot, setting):
8 | self._cluster = "query_suggestion"
9 | self._name = name
10 | self._shot = shot
11 | self._setting = setting
12 |
13 |
14 | def get_path(self):
15 | return f"data/{self._setting}/{self._shot}/{self._cluster}_{self._name}.{self._shot}.test.jsonl"
16 |
17 | def compute_metrics(self, preds, labels):
18 | result_rouge_scorer = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True)
19 | all_rouge_l = []
20 | split_labels = [[x.split()] for x in labels]
21 | split_preds = [x.split() for x in preds]
22 | bleu1 = corpus_bleu(split_labels, split_preds, weights=(1, 0, 0, 0), smoothing_function=bleu_score.SmoothingFunction().method4)
23 | bleu2 = corpus_bleu(split_labels, split_preds, weights=(0.5, 0.5, 0, 0), smoothing_function=bleu_score.SmoothingFunction().method4)
24 | for prediction, label in zip(preds, labels):
25 | scores = result_rouge_scorer.score(label, prediction)
26 | rougel = scores['rougeL'].fmeasure
27 | all_rouge_l.append(rougel)
28 | avg_rouge_l = np.mean(np.asarray(all_rouge_l))
29 | results = {
30 | "BLEU-1": bleu1,
31 | "BLEU-2": bleu2,
32 | "ROUGE-L": avg_rouge_l,
33 | }
34 | return results
--------------------------------------------------------------------------------
/evaluation/qu-du-tasks/inference_tasks/reading_comprehension.py:
--------------------------------------------------------------------------------
1 |
2 | import re
3 | import string
4 | from collections import Counter
5 | from typing import List
6 | import numpy as np
7 | from rouge_score import rouge_scorer
8 | from sklearn.metrics import f1_score, accuracy_score
9 | from nltk.translate.bleu_score import corpus_bleu
10 | from nltk.translate import bleu_score
11 |
12 |
13 | class ReadingComprehension():
14 | def __init__(self, name, shot, setting):
15 | self._cluster = "reading_comprehension"
16 | self._name = name
17 | self._shot = shot
18 | self._setting = setting
19 |
20 | def get_path(self):
21 | return f"data/{self._setting}/{self._shot}/{self._cluster}_{self._name}.{self._shot}.test.jsonl"
22 |
23 | def compute_metrics(self, preds, labels):
24 |
25 | def normalize_text(s):
26 |
27 | def remove_articles(text):
28 | return re.sub(r'\b(a|an|the)\b', ' ', text)
29 |
30 | def white_space_fix(text):
31 | return ' '.join(text.split())
32 |
33 | def remove_punc(text):
34 | exclude = set(string.punctuation)
35 | return ''.join(ch for ch in text if ch not in exclude)
36 |
37 | def lower(text):
38 | return text.lower()
39 |
40 | return white_space_fix(remove_articles(remove_punc(lower(s))))
41 |
42 |
43 | def calc_exact_match(text: str, answers: List[str]) -> bool:
44 | """Check if prediction is exactly the same as any of the answers."""
45 | norm_text = normalize_text(text)
46 | norm_answers = [normalize_text(ans) for ans in answers]
47 | return max([(norm_text == norm_ans) for norm_ans in norm_answers])
48 |
49 | def calc_soft_exact_match(text: str, answers: List[str]) -> bool:
50 | norm_text = normalize_text(text)
51 | norm_answers = [normalize_text(ans) for ans in answers]
52 | return max([(norm_ans in norm_text) for norm_ans in norm_answers])
53 |
54 | def calc_unigram_f1(text: str, answers: List[str]) -> float:
55 | """Calculate unigram f1 score between the text and reference answers."""
56 |             norm_pred = normalize_text(text).split()  # tokenize so F1 is computed over unigrams, not characters
57 |             norm_answers = [normalize_text(ans).split() for ans in answers]
58 | common_tokens = [
59 | Counter(norm_pred) & Counter(norm_ans) for norm_ans in norm_answers
60 | ]
61 | num_same = [sum(common.values()) for common in common_tokens]
62 |
63 | score_list = []
64 | for i, num in enumerate(num_same):
65 | if num == 0:
66 | score_list.append(0.0)
67 | else:
68 | p = 1.0 * num / len(norm_pred)
69 | r = 1.0 * num / len(norm_answers[i])
70 | f1 = 2 * p * r / (p + r)
71 | score_list.append(f1)
72 | return max(score_list)
73 |
74 | if self._name == "boolq":
75 | # accuracy = np.sum(np.asarray(preds) == np.asarray(labels)) / len(preds)
76 | labels = np.asarray([x[0] for x in labels])
77 | acc = accuracy_score(labels, preds)
78 | f1 = f1_score(labels, preds, average='weighted')
79 |
80 | results = {
81 | "Acc": acc,
82 | "F1": f1
83 | }
84 |
85 | elif self._name == "webglm_qa":
86 | result_rouge_scorer = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True)
87 | all_rouge_l = []
88 | split_labels = [[x.split()] for x in labels]
89 | split_preds = [x.split() for x in preds]
90 | bleu1 = corpus_bleu(split_labels, split_preds, weights=(1, 0, 0, 0), smoothing_function=bleu_score.SmoothingFunction().method4)
91 | bleu2 = corpus_bleu(split_labels, split_preds, weights=(0.5, 0.5, 0, 0), smoothing_function=bleu_score.SmoothingFunction().method4)
92 | for prediction, label in zip(preds, labels):
93 | scores = result_rouge_scorer.score(label, prediction)
94 | rougel = scores['rougeL'].fmeasure
95 | all_rouge_l.append(rougel)
96 | avg_rouge_l = np.mean(np.asarray(all_rouge_l))
97 | results = {
98 | "BLEU-1": bleu1,
99 | "BLEU-2": bleu2,
100 | "ROUGE-L": avg_rouge_l,
101 | }
102 | return results
103 | else:
104 | em_scores = [calc_exact_match(pred, label.split(" <=SEP=> ")) for pred, label in zip(preds, labels)]
105 | f1_scores = [calc_unigram_f1(pred, label.split(" <=SEP=> ")) for pred, label in zip(preds, labels)]
106 |             soft_em_scores = [calc_soft_exact_match(pred, label.split(" <=SEP=> ")) for pred, label in zip(preds, labels)]
107 |
108 | results = {
109 | # "EM": sum(em_scores) / len(em_scores),
110 | "F1": sum(f1_scores) / len(f1_scores),
111 |                 # "Soft-EM": sum(soft_em_scores) / len(soft_em_scores),
112 | }
113 |
114 | return results
--------------------------------------------------------------------------------
/evaluation/qu-du-tasks/inference_tasks/summarization.py:
--------------------------------------------------------------------------------
1 | from rouge_score import rouge_scorer
2 | import numpy as np
3 |
4 | class Summarization():
5 | def __init__(self, name, shot, setting):
6 | self._cluster = "summarization"
7 | self._name = name
8 | self._shot = shot
9 | self._setting = setting
10 |
11 |
12 | def get_path(self):
13 | return f"data/{self._setting}/{self._shot}/{self._cluster}_{self._name}.{self._shot}.test.jsonl"
14 |
15 | def compute_metrics(self, preds, labels):
16 | result_rouge_scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)
17 | all_rouge_l, all_rouge_1, all_rouge_2 = [], [], []
18 | for prediction, label in zip(preds, labels):
19 | scores = result_rouge_scorer.score(label, prediction)
20 | rouge1 = scores['rouge1'].fmeasure
21 | rouge2 = scores['rouge2'].fmeasure
22 | rougel = scores['rougeL'].fmeasure
23 | all_rouge_1.append(rouge1)
24 | all_rouge_2.append(rouge2)
25 | all_rouge_l.append(rougel)
26 | avg_rouge_l = np.mean(np.asarray(all_rouge_l))
27 | avg_rouge_1 = np.mean(np.asarray(all_rouge_1))
28 | avg_rouge_2 = np.mean(np.asarray(all_rouge_2))
29 | results = {
30 | "ROUGE-1": avg_rouge_1,
31 | "ROUGE-2": avg_rouge_2,
32 | "ROUGE-L": avg_rouge_l,
33 | }
34 | return results
--------------------------------------------------------------------------------
/evaluation/readme.md:
--------------------------------------------------------------------------------
1 | ## Required packages
2 | ```
3 | torch 2.0.0
4 | transformers 4.36.2
5 | numpy 1.26.3
6 | tqdm 4.66.1
7 | scikit-learn 1.4.0
8 | rouge_score 0.1.2
9 | nltk 3.8.1
10 | accelerate 0.26.1
11 | ```
12 |
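If any of these packages are missing, the pinned versions listed above can be installed with pip. This is only a convenience command; newer versions may also work but have not been pinned here:
```
pip install torch==2.0.0 transformers==4.36.2 numpy==1.26.3 tqdm==4.66.1 \
    scikit-learn==1.4.0 rouge_score==0.1.2 nltk==3.8.1 accelerate==0.26.1
```
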
13 | ## For query understanding tasks and document understanding tasks (qu-du-tasks)
14 | This evaluation script uses PyTorch DDP (DistributedDataParallel) for text generation.
15 |
16 | 1. Download the [test data](https://huggingface.co/datasets/yutaozhu94/INTERS/tree/main/test-qu-du-zero-shot) and save it to ``data/in-domain/zero_shot/``. The directory structure should look like this:
17 | ```
18 | qu-du-tasks
19 | ├── eval_sampling.py
20 | ├── inference_dataset.py
21 | ├── inference_qu_du.py
22 | ├── inference_tasks
23 | │ ├── conversational_qa.py
24 | │ ├── fact_verification.py
25 | │ └── ...
26 | └── data
27 | └── in-domain
28 |         └── zero_shot
29 | ├── conversational_qa_coqa.zero_shot.test.jsonl
30 | ├── conversational_qa_quac.zero_shot.test.jsonl
31 | ├── fact_verification_climate_fever.zero_shot.test.jsonl
32 | ├── fact_verification_fever.zero_shot.test.jsonl
33 | ├── fact_verification_scifact.zero_shot.test.jsonl
34 | └── ...
35 | ```
36 | 2. If you choose to place the test files in other directories, modify the path returned by the ``get_path()`` function in each task file under the ``inference_tasks`` directory (see the sketch at the end of this section).
37 |
38 | 3. Run the evaluation as follows:
39 | ```
40 | TOKENIZERS_PARALLELISM=True python3 inference_qu_du.py \
41 | --model_name_or_path your/model/path \
42 | --tokenizer_name your/tokenizer/path \
43 | --setting in-domain \
44 | --n_shots zero_shot
45 | ```
46 |
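As a minimal sketch of step 2, assuming the test files were moved to a hypothetical ``my_data/`` directory instead of the default ``data/`` (the directory name is an example, not part of the released code), the change amounts to editing ``get_path()`` in the corresponding task file, e.g. ``inference_tasks/query_matching.py``:
```
# Example only: get_path() pointed at a custom base directory "my_data/"
# instead of the default "data/".
def get_path(self):
    return f"my_data/{self._setting}/{self._shot}/{self._cluster}_{self._name}.{self._shot}.test.jsonl"
```
The same one-line change applies to every task class, since each one builds its path from the same setting/shot/task-name pattern.
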
47 | ## For query-document relationship understanding tasks (qdu-tasks)
48 | 1. Download the [test data](https://huggingface.co/datasets/yutaozhu94/INTERS/tree/main/test-qdu) and save it to ``data/``. The directory structure should look like this:
49 | ```
50 | qdu-tasks
51 | ├── cqa.sh
52 | ├── eval_rerank.py
53 | ├── postprocess_cqa.py
54 | ├── run_eval.sh
55 | └── data
56 | ├── cqadupstack
57 | │ ├── android
58 | │ │ └── test.pt.key.do-not-overwrite.json
59 | │ ├── english
60 | │ │ └── test.pt.key.do-not-overwrite.json
61 | │ └── ...
62 | ├── arguana.bm25.100.jsonl
63 | ├── climate_fever.bm25.100.jsonl
64 | └── ...
65 | ```
66 | 2. For datasets other than cqadupstack, modify the paths in ``run_eval.sh``, then run the script:
67 | ```
68 | MODEL_PATH="your/model/path"
69 | TOKENIZER_PATH="your/tokenizer/path"
70 | RESULT_PATH="your/result/path"
71 | EVAL_DATA_PATH="data"
72 |
73 | # ... then run:
74 | bash run_eval.sh
75 | ```
76 | 3. For the cqadupstack dataset, modify the paths in ``cqa.sh``, then run the script:
77 | ```
78 | MODEL_PATH="your/model/path"
79 | TOKENIZER_PATH="your/tokenizer/path"
80 | RESULT_PATH="your/result/path"
81 |
82 | # ... then run:
83 | bash cqa.sh
84 | ```
85 | 4. Both scripts support pointwise, pairwise, and listwise reranking. To switch methods, modify the arguments passed to ``eval_rerank.py`` in ``run_eval.sh`` or ``cqa.sh``:
86 | ```
87 | # pointwise: (default)
88 | --rerank_method pointwise
89 |
90 | # pairwise:
91 | --rerank_method pairwise
92 |
93 | # listwise:
94 | --rerank_method listwise \
95 | --listwise_window 5 \
96 | --listwise_stride 5
97 | ```
98 |
--------------------------------------------------------------------------------
/img/dataset.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DaoD/INTERS/fe00b00d8714be6245a5f3770ffb331e0f8e4f66/img/dataset.png
--------------------------------------------------------------------------------
/img/in-domain-google.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DaoD/INTERS/fe00b00d8714be6245a5f3770ffb331e0f8e4f66/img/in-domain-google.png
--------------------------------------------------------------------------------
/img/intro.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DaoD/INTERS/fe00b00d8714be6245a5f3770ffb331e0f8e4f66/img/intro.jpg
--------------------------------------------------------------------------------
/img/logo1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DaoD/INTERS/fe00b00d8714be6245a5f3770ffb331e0f8e4f66/img/logo1.jpg
--------------------------------------------------------------------------------
/img/process.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DaoD/INTERS/fe00b00d8714be6245a5f3770ffb331e0f8e4f66/img/process.jpg
--------------------------------------------------------------------------------