├── LICENSE
├── README.md
├── evaluation
│   ├── qdu-tasks
│   │   ├── cqa.sh
│   │   ├── eval_rerank.py
│   │   ├── postprocess_cqa.py
│   │   ├── run_eval.sh
│   │   └── src
│   │       ├── __init__.py
│   │       ├── args.py
│   │       ├── metrics.py
│   │       ├── modeling.py
│   │       └── utils.py
│   ├── qu-du-tasks
│   │   ├── eval_sampling.py
│   │   ├── inference_dataset.py
│   │   ├── inference_qu_du.py
│   │   └── inference_tasks
│   │       ├── conversational_qa.py
│   │       ├── fact_verification.py
│   │       ├── query_clarification.py
│   │       ├── query_description.py
│   │       ├── query_expansion.py
│   │       ├── query_intent_classification.py
│   │       ├── query_matching.py
│   │       ├── query_reformulation.py
│   │       ├── query_subtopic_generation.py
│   │       ├── query_suggestion.py
│   │       ├── reading_comprehension.py
│   │       └── summarization.py
│   └── readme.md
├── img
│   ├── dataset.png
│   ├── in-domain-google.png
│   ├── intro.jpg
│   ├── logo1.jpg
│   └── process.jpg
└── instruct_templates.py

/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Yutao ZHU 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
4 | 5 | 6 | ## INTERS: Unlocking the Power of Large Language Models in Search with Instruction Tuning 7 |


12 | 13 | **Authors**: Yutao Zhu, Peitian Zhang, Chenghao Zhang, Yifei Chen, Binyu Xie, Zhicheng Dou, Zheng Liu, and Ji-Rong Wen 14 | 15 |

16 | 📃 Paper 17 | • 18 | 📚 Dataset 19 |


21 | 🤗 HuggingFace Model List 22 |

23 | 24 | | Model | Backbone Model | 25 | |:---------------------------------------------------------------------------------|:------------------------------------------------------------------------| 26 | | [INTERS-LLaMA-7b-Chat](https://huggingface.co/yutaozhu94/INTERS-LLaMA-7b-chat) | [LLaMA-2-7b-chat](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf) | 27 | | [INTERS-LLaMA-7b-Base](https://huggingface.co/yutaozhu94/INTERS-LLaMA-7b-base) | [LLaMA-2-7b](https://huggingface.co/meta-llama/Llama-2-7b-hf) | 28 | | [INTERS-Mistral-7b](https://huggingface.co/yutaozhu94/INTERS-Mistral-7b) | [Mistral-7b](https://huggingface.co/mistralai/Mistral-7B-v0.1) | 29 | | [INTERS-Minima-3b](https://huggingface.co/yutaozhu94/INTERS-Minima-3b) | [MiniMA-2-3B](https://huggingface.co/GeneZC/MiniMA-2-3B) | 30 | | [INTERS-Falcon-1b](https://huggingface.co/yutaozhu94/INTERS-Falcon-1b) | [Falcon-rw-1b](https://huggingface.co/tiiuae/falcon-rw-1b) | 31 | 32 | ## News 33 | - May 2024: We are happy to announce that INTERS has been accepted to the ACL 2024 main conference! 34 | - Feb 2024: We have released the dataset, instruction templates, fine-tuned models, and evaluation scripts. 35 | 36 | ## Introduction 37 |
40 | 41 | Large language models (LLMs) have demonstrated impressive capabilities in various natural language processing tasks. Despite this, their application to information retrieval (IR) tasks is still challenging due to the infrequent occurrence of many IR-specific concepts in natural language. While prompt-based methods can provide task descriptions to LLMs, they often fall short in facilitating a comprehensive understanding and execution of IR tasks, thereby limiting LLMs' applicability. To address this gap, in this work, we explore the potential of instruction tuning to enhance LLMs' proficiency in IR tasks. We introduce a novel instruction tuning dataset, INTERS, encompassing 20 tasks across three fundamental IR categories: query understanding, document understanding, and query-document relationship understanding. The data are derived from 43 distinct datasets with manually written templates. Our empirical results reveal that INTERS significantly boosts the performance of various publicly available LLMs, such as LLaMA, Mistral, and Phi, in IR tasks. Furthermore, we conduct extensive experiments to analyze the effects of instruction design, template diversity, few-shot demonstrations, and the volume of instructions on performance. 42 | 43 | ## Tasks & Datasets 44 | We consider tasks under the categories of query understanding, document understanding, and query-document relationship understanding. Our dataset consists of 20 tasks derived from 43 datasets. All tasks and datasets used are shown in the figure below. 45 |
48 | 49 | ## Dataset Construction 50 |
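Each INTERS instance combines a manually written template with an example from one of the source datasets. The sketch below illustrates the idea for pointwise reranking by mirroring how `eval_rerank.py` composes its evaluation prompts: the task description is prepended to the template, and any few-shot demonstrations (each ending with its " Yes." or " No." answer) come first, separated by blank lines. It is an illustration under stated assumptions, not the authors' data-construction pipeline: `build_pointwise_prompt` is a hypothetical helper, the two constants are shortened stand-ins for `TASK_DESCRIPTION` and `PROMPT_TEMPLATE["pointwise"]["msmarco"]`, and tokenizer-based truncation is omitted.

```python
from typing import List, Optional, Tuple

# Hypothetical, shortened stand-ins for TASK_DESCRIPTION and the "msmarco"
# pointwise entry of PROMPT_TEMPLATE in evaluation/qdu-tasks/eval_rerank.py.
TASK_DESCRIPTION = (
    "In the reranking task, search engines must understand the relationship "
    "between the user's query and the potential documents."
)
POINTWISE_TEMPLATE = (
    "Assess the relevance between the provided document:\n{text}\n"
    "and the query: \"{query}\". Respond with 'Yes' if the document is "
    "relevant to the query or 'No' if not."
)


def build_pointwise_prompt(
    query: str,
    text: str,
    template: str = POINTWISE_TEMPLATE,
    description: Optional[str] = TASK_DESCRIPTION,
    demos: Optional[List[Tuple[str, str, bool]]] = None,
) -> str:
    """Fill a pointwise template, optionally with a description and demonstrations.

    `demos` holds (query, document, is_relevant) triples. Token-level truncation
    (query_max_length / doc_max_length in eval_rerank.py) is assumed to have
    been applied to all strings beforehand.
    """
    # As in eval_rerank.py, the description is prepended to the template itself,
    # so each demonstration and the test instance repeat the same instruction.
    if description:
        template = description + " " + template

    prefix = ""
    if demos:
        # Each demonstration ends with its gold answer, " Yes." or " No.".
        shots = [
            template.format(query=q, text=d) + (" Yes." if relevant else " No.")
            for q, d, relevant in demos
        ]
        # Blank lines separate demonstrations from each other and from the test instance.
        prefix = "\n\n".join(shots) + "\n\n"

    return prefix + template.format(query=query, text=text)


if __name__ == "__main__":
    print(build_pointwise_prompt(
        "what is bm25",
        "BM25 is a bag-of-words ranking function used by search engines.",
        demos=[("capital of france", "Paris is the capital of France.", True)],
    ))
```

Because the description is attached to the template rather than to the whole prompt, it is repeated once per demonstration, which matches the evaluation script when `--with_description` keeps its default value of `True`.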
53 | 54 | ## General Performance 55 |
58 | 59 | ## Zero-shot Evaluation 60 | The evaluation scripts are under the ``evaluation`` directory. 61 | 62 | ### Required packages 63 | ``` 64 | torch 2.0.0 65 | transformers 4.36.2 66 | numpy 1.26.3 67 | tqdm 4.66.1 68 | scikit-learn 1.4.0 69 | rouge_score 0.1.2 70 | nltk 3.8.1 71 | accelerate 0.26.1 72 | ``` 73 | 74 | ### For query understanding tasks and document understanding tasks (qu-du-tasks) 75 | This evaluation script uses PyTorch DDP for text generation. 76 | 77 | 1. Download the [test data](https://huggingface.co/datasets/yutaozhu94/INTERS/tree/main/test-qu-du-zero-shot) and save it to ``data/in-domain/zero_shot/``. The directory structure should look like this: 78 | ``` 79 | qu-du-tasks 80 | ├── eval_sampling.py 81 | ├── inference_dataset.py 82 | ├── inference_qu_du.py 83 | ├── inference_tasks 84 | │ ├── conversational_qa.py 85 | │ ├── fact_verification.py 86 | │ └── ... 87 | └── data 88 | └── in-domain 89 | └── zero-shot 90 | ├── conversational_qa_coqa.zero_shot.test.jsonl 91 | ├── conversational_qa_quac.zero_shot.test.jsonl 92 | ├── fact_verification_climate_fever.zero_shot.test.jsonl 93 | ├── fact_verification_fever.zero_shot.test.jsonl 94 | ├── fact_verification_scifact.zero_shot.test.jsonl 95 | └── ... 96 | ``` 97 | 2. If you place the test files in a different directory, modify the path in each task file under the ``inference_tasks`` directory (in the ``get_path()`` function). 98 | 99 | 3. Run the evaluation: 100 | ``` 101 | TOKENIZERS_PARALLELISM=True python3 inference_qu_du.py \ 102 | --model_name_or_path your/model/path \ 103 | --tokenizer_name your/tokenizer/path \ 104 | --setting in-domain \ 105 | --n_shots zero_shot 106 | ``` 107 | 108 | ### For query-document relationship understanding tasks (qdu-tasks) 109 | 1. Download the [test data](https://huggingface.co/datasets/yutaozhu94/INTERS/tree/main/test-qdu) and save it to ``data/``.
The directory structure should look like this: 110 | ``` 111 | qdu-tasks 112 | ├── cqa.sh 113 | ├── eval_rerank.py 114 | ├── postprocess_cqa.py 115 | ├── run_eval.sh 116 | └── data 117 | ├── cqadupstack 118 | │ ├── android 119 | │ │ └── test.pt.key.do-not-overwrite.json 120 | │ ├── english 121 | │ │ └── test.pt.key.do-not-overwrite.json 122 | │ └── ... 123 | ├── arguana.bm25.100.jsonl 124 | ├── climate_fever.bm25.100.jsonl 125 | └── ... 126 | ``` 127 | 2. For datasets other than cqadupstack, modify the paths in ``run_eval.sh``, then run the script: 128 | ``` 129 | MODEL_PATH="your/model/path" 130 | TOKENIZER_PATH="your/tokenizer/path" 131 | RESULT_PATH="your/result/path" 132 | EVAL_DATA_PATH="data" 133 | 134 | ----------------------- 135 | bash run_eval.sh 136 | ``` 137 | 3. For the cqadupstack dataset, modify the paths in ``cqa.sh``, then run the script: 138 | ``` 139 | MODEL_PATH="your/model/path" 140 | TOKENIZER_PATH="your/tokenizer/path" 141 | RESULT_PATH="your/result/path" 142 | 143 | ----------------------- 144 | bash cqa.sh 145 | ``` 146 | 4. ``eval_rerank.py`` supports pointwise, pairwise, and listwise reranking.
Select the reranking method with the ``--rerank_method`` argument of ``eval_rerank.py`` in ``run_eval.sh`` or ``cqa.sh``: 147 | ``` 148 | # pointwise: (default) 149 | --rerank_method pointwise 150 | 151 | # pairwise: 152 | --rerank_method pairwise 153 | 154 | # listwise: 155 | --rerank_method listwise \ 156 | --listwise_window 5 \ 157 | --listwise_stride 5 158 | ``` 159 | 160 | ## Citation 161 | Please kindly cite our paper if it helps your research: 162 | ```BibTex 163 | @inproceedings{INTERS, 164 | author = {Yutao Zhu and 165 | Peitian Zhang and 166 | Chenghao Zhang and 167 | Yifei Chen and 168 | Binyu Xie and 169 | Zheng Liu and 170 | Ji{-}Rong Wen and 171 | Zhicheng Dou}, 172 | editor = {Lun{-}Wei Ku and 173 | Andre Martins and 174 | Vivek Srikumar}, 175 | title = {{INTERS:} Unlocking the Power of Large Language Models in Search with 176 | Instruction Tuning}, 177 | booktitle = {Proceedings of the 62nd Annual Meeting of the Association for Computational 178 | Linguistics (Volume 1: Long Papers), {ACL} 2024, Bangkok, Thailand, 179 | August 11-16, 2024}, 180 | pages = {2782--2809}, 181 | publisher = {Association for Computational Linguistics}, 182 | year = {2024}, 183 | url = {https://doi.org/10.18653/v1/2024.acl-long.154}, 184 | doi = {10.18653/V1/2024.ACL-LONG.154}, 185 | } 186 | ``` 187 | -------------------------------------------------------------------------------- /evaluation/qdu-tasks/cqa.sh: -------------------------------------------------------------------------------- 1 | MODEL_PATH="your/model/path" 2 | TOKENIZER_PATH="your/tokenizer/path" 3 | RESULT_PATH="your/result/path" 4 | 5 | RESULT_PATH=${RESULT_PATH}/cqadupstack 6 | EVAL_DATA_PATH="data" 7 | 8 | mkdir -p ${RESULT_PATH} 9 | 10 | # to gather metrics from all sub-datasets 11 | TMP_PATH=${RESULT_PATH}/tmp.log 12 | 13 | ############# MODIFY PATH HERE ############# 14 | # containing all sub-dataset folders 15 | CQA_ROOT=${EVAL_DATA_PATH}/cqadupstack/ 16 | ################################################ 17 | 18 | COUNTER=0 19 | for dataset in
$CQA_ROOT/* 20 | do 21 | 22 | # get fewshot data files 23 | fewshot_data=($dataset/*train.pt.neg.do-not-overwrite.fewshot*) 24 | fewshot_data=${fewshot_data[0]} 25 | 26 | eval_data="$dataset/test.pt.key.do-not-overwrite.json" 27 | 28 | ############# MODIFY COMMANDS HERE ############# 29 | outputString=`torchrun --nproc_per_node 8 eval_rerank.py \ 30 | --eval_data $eval_data \ 31 | --output_dir $RESULT_PATH \ 32 | --model_name_or_path $MODEL_PATH \ 33 | --tokenizer_name_or_path $TOKENIZER_PATH \ 34 | --hits 10 \ 35 | --rerank_method pointwise \ 36 | --dataset_name cqadupstack \ 37 | --batch_size 4` 38 | 39 | # to add 1-shot 40 | # --fewshot_data $fewshot_data \ 41 | # --shots 1 42 | ################################################ 43 | 44 | if [[ $COUNTER == 0 ]] 45 | then 46 | echo $outputString > $TMP_PATH 47 | else 48 | echo $outputString >> $TMP_PATH 49 | fi 50 | 51 | COUNTER=$[$COUNTER +1] 52 | done 53 | 54 | python postprocess_cqa.py -t $TMP_PATH -------------------------------------------------------------------------------- /evaluation/qdu-tasks/eval_rerank.py: -------------------------------------------------------------------------------- 1 | import os 2 | import datasets 3 | import numpy as np 4 | from typing import List 5 | from dataclasses import dataclass, field, asdict 6 | from torch.utils.data import DataLoader 7 | from collections import defaultdict 8 | from accelerate import Accelerator 9 | from transformers import HfArgumentParser 10 | from transformers.utils import logging 11 | 12 | from src import ModelArgs, Metric, DatasetProcessFn, DefaultDataCollator, FileLogger, get_model_and_tokenizer, makedirs 13 | 14 | 15 | logger = logging.get_logger(__name__) 16 | 17 | 18 | @dataclass 19 | class Args(ModelArgs): 20 | eval_data: str = field( 21 | default=None, 22 | metadata={'help': 'The evaluation json data path.'} 23 | ) 24 | fewshot_data: str = field( 25 | default=None, 26 | metadata={'help': 'The fewshot json data path.'} 27 | ) 28 | output_dir: str 
= field( 29 | default="data/results/rerank", 30 | metadata={'help': 'Output directory for results and logs.'} 31 | ) 32 | batch_size: int = field( 33 | default=16, 34 | metadata={'help': 'Evaluation batch size.'} 35 | ) 36 | 37 | query_max_length: int = field( 38 | default=64, 39 | metadata={'help': 'Maximum number of query tokens; longer queries are truncated.'} 40 | ) 41 | doc_max_length: int = field( 42 | default=512, 43 | metadata={'help': 'Maximum number of document tokens; longer documents are truncated.'} 44 | ) 45 | 46 | with_description: bool = field( 47 | default=True, 48 | metadata={'help': "Whether to prepend the task description to the prompt."} 49 | ) 50 | rerank_method: str = field( 51 | default="pointwise", 52 | metadata={'help': 'How to evaluate reranking? {pointwise, pairwise, listwise, no}'} 53 | ) 54 | dataset_name: str = field( 55 | default="msmarco", 56 | metadata={'help': 'Select one from [msmarco | trec_covid | nfcorpus | nq | fiqa | hotpot_qa | arguana | touche | cqadupstack | quora | dbpedia | scidocs | climate_fever | fever | scifact]'}, 57 | ) 58 | hits: int = field( 59 | default=100, 60 | metadata={'help': 'How many candidates to rerank?'} 61 | ) 62 | shots: int = field( 63 | default=None, 64 | metadata={'help': 'How many shots to use for fewshot testing?'} 65 | ) 66 | 67 | metrics: List[str] = field( 68 | default_factory=lambda: ["mrr", "recall", "ndcg"], 69 | metadata={'help': 'List of metrics, e.g. mrr, recall, ndcg.'} 70 | ) 71 | cutoffs: List[int] = field( 72 | default_factory=lambda: [1, 5, 10], 73 | metadata={'help': 'Cutoffs to evaluate retrieval metrics.'} 74 | ) 75 | listwise_window: int = field( 76 | default=10, 77 | metadata={'help': 'Window size (number of candidates ranked together) for listwise reranking.'} 78 | ) 79 | listwise_stride: int = field( 80 | default=5, 81 | metadata={'help': 'Step by which the listwise window slides.'} 82 | ) 83 | 84 | 85 | TASK_DESCRIPTION = "In the reranking task, search engines must understand the relationship between the user's query, which may be keywords or a sentence, and the potential documents.
The goal is to ensure that the most relevant documents, those that best cover the user's information needs, are ranked highest. This requires a nuanced understanding of both the query's intent and the content of the documents." 86 | 87 | PROMPT_TEMPLATE = { 88 | "pointwise": { 89 | "msmarco": "Assess the relevance between the provided document:\n{text}\nand the query: \"{query}\". Respond with 'Yes' if the document is relevant to the query or 'No' if not.", 90 | "trec_covid": "Document: {text}\n\nQuery: {query}\n\nAssess the relevance of the provided document in the context of the COVID-19-related query. Answer 'Yes' if the document explicitly addresses or pertains to the query, or 'No' if it is unrelated.", 91 | "nfcorpus": "Assess the relevance of the medical document:\n{text}\nin relation to the search query \"{query}\". Determine if the document is relevant by responding with 'Yes' for relevance or 'No' for irrelevance.", 92 | "nq": "Review the content of the document:\n{text}\nand ascertain its relevance to the topic: \"{query}\". Provide your determination by responding with either 'Yes' for relevance or 'No' for irrelevance.", 93 | "fiqa": "Evaluate the financial document:\n{text}\nin the context of the query: \"{query}\". Determine if the document is relevant to the query and provide your judgment with 'Yes' for relevance or 'No' for irrelevance.", 94 | "hotpot_qa": "Analyze the relationship between the document:\n{text}\nand the query: \"{query}\". 
Offer a definitive response by indicating whether they are relevant, and reply with either 'Yes' for relevance or 'No' for irrelevance.", 95 | "arguana": "Document:\n\n{text}\n\nQuery:\n\n{query}\n\nDetermine their relevance and return the judgment about the document's relevance to the query by responding with either 'Yes' for relevance or 'No' for irrelevance.", 96 | "touche": "Evaluate the relevance between the given document on controversial subjects:\n{text}\nand the query: \"{query}\", and provide a clear 'Yes' for relevance or 'No' for irrelevance judgment regarding their relevance.", 97 | "cqadupstack": "Evaluate the relevance between the given document on controversial subjects:\n{text}\nand the query: \"{query}\", and provide a clear 'Yes' for relevance or 'No' for irrelevance judgment regarding their relevance.", 98 | "quora": "Determine the relevance between the document:\n{text}\nand the query: {query}. Judge whether they are related by responding with 'Yes' for relevance or 'No' for irrelevance.", 99 | "dbpedia": "Determine the relevance of the document:\n{text}\nto the query: \"{query}\". Conclude with 'Yes' for relevance or 'No' for irrelevance.", 100 | "scidocs": "Determine the correlation between the provided literature-related document:\n{text}\nand the query: \"{query}\". Conclude if they are closely connected with a 'Yes' for relevance or 'No' for irrelevance.", 101 | "climate_fever": "Analyze the correlation between this document on the topic of climate change:\n{text}\nand the following query: {query}. Determine if the document is relevant to the query. Respond with 'Yes' for relevance or 'No' for irrelevance.", 102 | "fever": "Assess the relationship between the given document:\n{text}\nand the query: \"{query}\" to determine if the document is relevant. 
Answer with 'Yes' for relevance or 'No' for irrelevance.", 103 | "scifact": "Assess the relevance between the scientific document:\n{text}\nand query: \"{query}\", and state 'Yes' or 'No' to reflect the relevance judgment. Answer 'Yes' for relevance and 'No' for irrelevance." 104 | }, 105 | "pairwise": { 106 | "msmarco": "Consider a query \"{query}\" alongside two documents:\n\n[1] {doc1}\n\n[2] {doc2}\n\nDecide which document is more relevant to the given query by providing the corresponding document identifier.", 107 | "trec_covid": "Given a query: \"{query}\" and two documents:\n\n[1] {doc1}\n\n[2] {doc2}\n\nEach marked with a unique identifier and related to COVID-19, assess and identify which document is more closely related to the query. Respond with the identifier of the more relevant document.", 108 | "nfcorpus": "Assess which of the two medical field documents is more relevant to the provided query \"{query}\". Each document is presented with a unique identifier.\nDocuments:\n\n[1] {doc1}\n\n[2] {doc2}\n\nIdentify and return the identifier of the document that best aligns with the query.", 109 | "nq": "Evaluate the relevance of the provided query \"{query}\" to a pair of documents:\n\n[1] {doc1}\n\n[2] {doc2}\n\neach identified separately. Determine which document is more relevant to the query and return the identifier of the more relevant document.", 110 | "fiqa": "Evaluate the relevance of the query: \"{query}\" and the pair of financial documents:\n\n[1] {doc1}\n\n[2] {doc2}\n\neach assigned a unique identifier. Determine which document is more relevant to the provided query and specify its document identifier.", 111 | "hotpot_qa": "Compare the relevance of two documents:\n\n[1] {doc1}\n\n[2] {doc2}\n\nto the provided query: \"{query}\". Identify the document with the higher relevance by specifying its identifier.", 112 | "arguana": "Evaluate the relevance of two documents:\n\n[1] {doc1}\n\n[2] {doc2}\n\nto the provided query \"{query}\". 
Express the identifier of the document that is more relevant to the query.", 113 | "touche": "Given a query: \"{query}\" and two documents:\n\n[1] {doc1}\n\n[2] {doc2}\n\neach with its unique identifier, determine which document is more relevant to the given query by providing the document identifier.", 114 | "cqadupstack": "Compare the relevance of two documents:\n\n[1] {doc1}\n\n[2] {doc2}\n\nto the query: \"{query}\", and specify the document identifier that has higher relevance.", 115 | "quora": "With a given query: \"{query}\" and two unique documents:\n\n[1] {doc1}\n\n[2] {doc2}\n\nIdentify the document that aligns more closely with the query by stating its identifier.", 116 | "dbpedia": "Given a query: \"{query}\" and two documents:\n\n[1] {doc1}\n\n[2] {doc2}\n\neach identified by a distinct number. Determine the document identifier for the one more relevant to the provided query.", 117 | "scidocs": "Given a query:\"{query}\" and two literature-related documents:\n\n[1] {doc1}\n\n[2] {doc2}\n\neach with a unique identifier, identify which document is more relevant to the query by specifying its identifier.", 118 | "climate_fever": "Given this query:\"{query}\" and two climate change documents:\n\n[1] {doc1}\n\n[2] {doc2}\n\neach with a unique identifier, identify which document aligns more closely with the query by indicating the document's identifier.", 119 | "fever": "Evaluate two documents:\n\n[1] {doc1}\n\n[2] {doc2}\n\nagainst a given query: \"{query}\" and identify the document that is more relevant by its identifier.", 120 | "scifact": "Analyze the relevance of two scientific documents:\n\n[1] {doc1}\n\n[2] {doc2}\n\nto the query: \"{query}\" and identify the more relevant document by its identifier." 121 | }, 122 | "listwise": { 123 | "msmarco": "Here are {num} documents:\n{docs}\nand a query \"{query}\", and each document is indicated by a number identifier. 
Please sort the documents in an order based on their relevance to the above query by returning an identifier list. Be careful to sort documents in order of their relevance to the query from highest to lowest. Result: ", 124 | "trec_covid": "Given {num} documents:\n{docs}\neach pertaining to COVID-19, and a specific query: \"{query}\", organize the documents in order of relevance to the query. List the identifiers of these documents starting from the most relevant to the least relevant. Result: ", 125 | "nfcorpus": "Arrange the given {num} medical field documents:\n{docs}\nin order of relevance to the specified query \"{query}\", with the most relevant document at the top and the least relevant at the bottom. Use their unique number identifiers to indicate the sequence of relevance. Result: ", 126 | "nq": "Arrange the provided {num} documents:\n{docs}\nin accordance with their relevance to the specified query \"{query}\". Utilize number identifiers to denote the sequence of relevance, with the most relevant document at the top and the least relevant document at the end. Result: ", 127 | "fiqa": "Assess the relevance of the query: \"{query}\" to a set of {num} financial documents:\n{docs}\nGenerate a list of document identifiers, arranging them from the most relevant to the least relevant in relation to the query, and return the identifiers list. Result: ", 128 | "hotpot_qa": "Rank the following {num} documents in descending order of relevance to the query: \"{query}\"\n{docs}\nProvide the list of identifiers. Result: ", 129 | "arguana": "Rank the {num} documents:\n{docs}\n in descending order of relevance to the provided query: \"{query}\". Return the list of identifiers associated with each document in order. Result:", 130 | "touche": "Rerank the provided {num} documents on a controversial topic:\n{docs}\nbased on their relevance to the query \"{query}\". Return a list of identifiers in the order of descending relevance. 
Result: ", 131 | "cqadupstack": "Rerank the provided {num} documents on a controversial topic:\n{docs}\nbased on their relevance to the query \"{query}\". Return a list of identifiers in the order of descending relevance. Result: ", 132 | "quora": "Rank {num} documents:\n{docs}\nby their relevance to the query: \"{query}\" in descending order. List their identifiers. Result: ", 133 | "dbpedia": "Rank {num} documents:\n{docs}\nbased on their relevance to the specified query \"{query}\". Return the list of identifiers in descending order. Result: ", 134 | "scidocs": "Rank {num} documents:\n{docs}\nin order of their relevance to the query: {query}. List the identifiers starting with the most relevant document. Result: ", 135 | "climate_fever": "Rerank these {num} climate change documents:\n{docs}\nby their relevance to the query: \"{query}\" and list the identifiers. Rank from most to least relevant. Result: ", 136 | "fever": "Rank a set of {num} documents:\n{docs}\nby their relevance to a query: \"{query}\". List their identifiers in order of decreasing relevance. Result: ", 137 | "scifact": "Order the provided {num} scientific documents:\n{docs}\n based on their relevance to the query: \"{query}\", from most to least relevant. Only output a list of identifiers. Result: " 138 | # "Here is a query:\n\n{query}\n\nHere are {num} documents, and each document has its own identifier:\n\n{docs}\n\nRank the {num} documents above based on their relevance to the given query. The documents should be listed in descending order using identifiers. The most relevant documents should be listed first. The output format should be [] > [], e.g., [1] > [2]. Each two identifiers are separated by ' > '. 
Only response the ranking results, do not say any word or explain.", 139 | }, 140 | "no": defaultdict(lambda: ""), 141 | } 142 | 143 | 144 | def truncate(text, tokenizer, max_length): 145 | if tokenizer is not None: 146 | return tokenizer.decode(tokenizer.encode(text, add_special_tokens=False, max_length=max_length, truncation=True)) 147 | else: 148 | return text 149 | 150 | 151 | def process_rerank(tokenizer, rerank_method, prompt_template, query_max_length=64, doc_max_length=512, hits=10, fewshot_data=None, shots=None, cache_dir=None): 152 | if fewshot_data is not None: 153 | fewshot_data = datasets.load_dataset("json", data_files=fewshot_data, split="train", cache_dir=cache_dir) 154 | rng = np.random.default_rng(42) 155 | indices = rng.choice(range(len(fewshot_data)), size=shots, replace=False).tolist() 156 | 157 | fewshot_prompt = [] 158 | 159 | for index in indices: 160 | item = fewshot_data[index] 161 | # NOTE: do not use query, pos, neg here, because they are the arguments for the _process function 162 | fs_query = item["query"] 163 | fs_pos = item["pos"] 164 | fs_neg = item["neg"] 165 | 166 | fs_query = truncate(fs_query, tokenizer, query_max_length) 167 | 168 | if rerank_method == "pointwise": 169 | # sample 1 candidate from the union of pos and neg 170 | candidate = fs_pos + fs_neg 171 | candidate_range = range(len(candidate)) 172 | candidate_index = rng.choice(candidate_range).tolist() 173 | candidate = truncate(candidate[candidate_index], tokenizer, doc_max_length) 174 | prompt = prompt_template.format(query=fs_query, text=candidate) 175 | 176 | if candidate_index < len(fs_pos): 177 | # sampled from pos 178 | prompt += " Yes." 179 | else: 180 | prompt += " No." 
181 | 182 | elif rerank_method == "pairwise": 183 | # sample 1 positive and 1 negative for comparison 184 | fs_pos = rng.choice(fs_pos).tolist() 185 | fs_neg = rng.choice(fs_neg).tolist() 186 | 187 | reverse = rng.choice([True, False]) 188 | if reverse: 189 | prompt = prompt_template.format(query=fs_query, doc1=fs_neg, doc2=fs_pos) 190 | prompt += "[2]" 191 | else: 192 | prompt = prompt_template.format(query=fs_query, doc1=fs_pos, doc2=fs_neg) 193 | prompt += "[1]" 194 | 195 | elif rerank_method == "listwise": 196 | len_pos = len(fs_pos) 197 | len_neg = len(fs_neg) 198 | len_all = len_pos + len_neg 199 | all_documents = fs_pos + fs_neg 200 | result = list(range(1, len_all + 1)) 201 | rng.shuffle(result) 202 | idx_document = {idx: document for idx, document in zip(result, all_documents)} 203 | idx_document = dict(sorted(idx_document.items())) 204 | result = [f"[{num}]" for num in result] 205 | result = " > ".join(result) 206 | docs = "" 207 | rank = 0 208 | for value in idx_document.values(): 209 | rank += 1 210 | docs += f"[{rank}] " + value + "\n" 211 | prompt = prompt_template.format(query=fs_query, num=rank, docs=docs) 212 | prompt += " " + result 213 | 214 | else: 215 | raise NotImplementedError(f"Rerank method {rerank_method} not implemented for fewshot!") 216 | 217 | fewshot_prompt.append(prompt) 218 | 219 | fewshot_prompt = "\n\n".join(fewshot_prompt) + "\n\n" 220 | 221 | else: 222 | fewshot_prompt = None 223 | 224 | 225 | @DatasetProcessFn(augment=True) 226 | def _process_pointwise(query, pos=None, neg=None, pos_index=None, neg_index=None, pos_score=None, key=None, key_index=None, query_id=None, _index=None, **kwds): 227 | outputs = defaultdict(list) 228 | if pos_score is None: 229 | pos_score = [1 for _ in pos] 230 | 231 | # rerank positive and negative documents when there are no pre-defined candidates 232 | if neg is not None: 233 | key = pos + neg 234 | key_index = pos_index + neg_index 235 | 236 | if query_id is None: 237 | assert _index is not None,
f"Make sure to set with_indices=True when there is no given query_id!" 238 | query_id = _index 239 | 240 | # truncate query 241 | query = truncate(query, tokenizer, query_max_length) 242 | # only rerank the top-hits candidates 243 | key = key[:hits] 244 | key_index = key_index[:hits] 245 | 246 | for doc_id, doc_text in zip(key_index, key): 247 | # truncate document 248 | doc_text = truncate(doc_text, tokenizer, doc_max_length) 249 | 250 | prompt = prompt_template.format(query=query, text=doc_text) 251 | # NOTE: prepend fewshot prompt 252 | if fewshot_prompt is not None: 253 | prompt = fewshot_prompt + prompt 254 | 255 | output = tokenizer(prompt) 256 | 257 | output["query_id"] = query_id 258 | output["doc_id"] = doc_id 259 | 260 | for k, v in output.items(): 261 | outputs[k].append(v) 262 | return outputs 263 | 264 | 265 | @DatasetProcessFn() 266 | def _process_other(query, pos=None, neg=None, pos_index=None, neg_index=None, pos_score=None, key=None, key_index=None, query_id=None, _index=None, **kwds): 267 | if pos_score is None: 268 | pos_score = [1 for _ in pos] 269 | 270 | # rerank positive and negative documents when there are no pre-defined candidates 271 | if neg is not None: 272 | key = pos + neg 273 | key_index = pos_index + neg_index 274 | 275 | if query_id is None: 276 | assert _index is not None, f"Make sure to set with_indices=True when there is no given query_id!" 
277 | query_id = _index 278 | 279 | # truncate query 280 | query = truncate(query, tokenizer, query_max_length) 281 | 282 | if len(key) < hits: 283 | return None 284 | 285 | # only rerank the top-hits candidates 286 | key = key[:hits] 287 | key_index = key_index[:hits] 288 | for i, k in enumerate(key): 289 | key[i] = truncate(k, tokenizer, doc_max_length) 290 | 291 | outputs = { 292 | "query": query, 293 | "query_id": query_id, 294 | "docs": key, 295 | "doc_ids": key_index 296 | } 297 | 298 | # always add fewshot prompt 299 | outputs["fewshot_prompt"] = fewshot_prompt 300 | return outputs 301 | 302 | if rerank_method == "pointwise": 303 | return _process_pointwise 304 | else: 305 | return _process_other 306 | 307 | 308 | def main(): 309 | parser = HfArgumentParser([Args]) 310 | args: Args = parser.parse_args_into_dataclasses()[0] 311 | 312 | # NOTE: if skip, we just use CPU to load model 313 | # import json 314 | # with open("/share/peitian/Data/Datasets/searchgpt/qrels.train.pt.neg.do-not-overwrite.fewshot.jsonl", 'r', encoding='utf-8') as fr: 315 | # lines = [json.loads(line.strip()) for line in fr.readlines()[:100]] 316 | # with open("/share/yutao/yifei/data/test_fewshot_100.jsonl", 'w', encoding='utf-8') as fw: 317 | # for line in lines: 318 | # json.dump(line, fw) 319 | # fw.write('\n') 320 | # with open("/share/peitian/Data/Datasets/searchgpt/qrels.dev.pt.key.do-not-overwrite.jsonl", 'r', encoding='utf-8') as fr: 321 | # lines = [json.loads(line.strip()) for line in fr.readlines()[:100]] 322 | # with open("/share/yutao/yifei/data/test_100.jsonl", 'w', encoding='utf-8') as fw: 323 | # for line in lines: 324 | # json.dump(line, fw) 325 | # fw.write('\n') 326 | # return 327 | 328 | accelerator = Accelerator(cpu=args.cpu or args.rerank_method == "no") 329 | 330 | if args.rerank_method == "no": 331 | model = None 332 | tokenizer = None 333 | logger.info(f"directly evaluating {args.eval_data}...") 334 | else: 335 | model, tokenizer = get_model_and_tokenizer(args, 
accelerator=accelerator) 336 | 337 | with accelerator.main_process_first(): 338 | prompt_template = PROMPT_TEMPLATE[args.rerank_method][args.dataset_name] 339 | if args.with_description: 340 | prompt_template = TASK_DESCRIPTION + " " + prompt_template 341 | process_fn = process_rerank( 342 | tokenizer=tokenizer, 343 | rerank_method=args.rerank_method, 344 | prompt_template=prompt_template, 345 | query_max_length=args.query_max_length, 346 | doc_max_length=args.doc_max_length, 347 | hits=args.hits, 348 | fewshot_data=args.fewshot_data, 349 | shots=args.shots, 350 | cache_dir=args.dataset_cache_dir, 351 | ) 352 | dataset = datasets.load_dataset("json", data_files=args.eval_data, cache_dir=args.dataset_cache_dir, split="train") 353 | dataset = dataset.map(process_fn, batched=True, num_proc=32, remove_columns=dataset.column_names, with_indices=True) 354 | 355 | if args.rerank_method == "no": 356 | # directly compose rerank results from retrieval results 357 | from src import RerankResult 358 | results = {} 359 | for x in dataset: 360 | query_id = x["query_id"] 361 | results[query_id] = [RerankResult(doc_id, 0) for doc_id in x["doc_ids"]] 362 | 363 | else: 364 | data_collator = DefaultDataCollator(tokenizer=tokenizer) 365 | dataloader = DataLoader( 366 | dataset, 367 | batch_size=args.batch_size, 368 | collate_fn=data_collator, 369 | pin_memory=model.device.index is not None, 370 | ) 371 | 372 | if args.rerank_method == "pointwise": 373 | results = model.rerank_pointwise(dataloader, accelerator=accelerator) 374 | elif args.rerank_method == "pairwise": 375 | results = model.rerank_pairwise(dataloader, prompt_template=prompt_template, accelerator=accelerator) 376 | elif args.rerank_method == "listwise": 377 | results = model.rerank_listwise(dataloader, prompt_template=prompt_template, accelerator=accelerator, window=args.listwise_window, stride=args.listwise_stride) 378 | else: 379 | raise NotImplementedError(f"Rerank method {args.rerank_method} not implemented!") 380 | 
381 | if accelerator.process_index == 0: 382 | result_path = Metric._get_save_path(args.eval_data, args.output_dir) 383 | Metric._save_rerank_result(results, result_path, eval_data=args.eval_data) 384 | metrics = Metric.get_metric_fn(metric_names=args.metrics, eval_data=args.eval_data, cutoffs=args.cutoffs)(results) 385 | 386 | file_logger = FileLogger(makedirs(os.path.join(args.output_dir, "metrics.log"))) 387 | file_logger.log(metrics, Args=asdict(args)) 388 | 389 | 390 | if __name__ == "__main__": 391 | main() -------------------------------------------------------------------------------- /evaluation/qdu-tasks/postprocess_cqa.py: -------------------------------------------------------------------------------- 1 | import re 2 | import json 3 | import argparse 4 | from collections import defaultdict 5 | 6 | def main(args): 7 | all_results = defaultdict(list) 8 | with open(args.tmp_path, encoding="utf-8") as f: 9 | for line in f: 10 | # skip log lines that carry no metrics dict 11 | match = re.search(r"Metrics : (\{.*\})", line) 12 | if match is None: 13 | continue 14 | result = json.loads(match.group(1)) 15 | for k, v in result.items(): 16 | all_results[k].append(v) 17 | for k, v in all_results.items(): 18 | all_results[k] = sum(v) / len(v) 19 | print(dict(all_results)) 20 | 21 | 22 | if __name__ == "__main__": 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument("-t", "--tmp_path", default="data/results/cqadupstack/tmp.log") 25 | args = parser.parse_args() 26 | 27 | main(args) 28 | -------------------------------------------------------------------------------- /evaluation/qdu-tasks/run_eval.sh: -------------------------------------------------------------------------------- 1 | GPU_NUM=8 2 | MODEL_PATH="your/model/path" 3 | TOKENIZER_PATH="your/tokenizer/path" 4 | RESULT_PATH="your/result/path" 5 | EVAL_DATA_PATH="data" 6 | 7 | torchrun --nproc_per_node ${GPU_NUM} eval_rerank.py \ 8 | --eval_data ${EVAL_DATA_PATH}/msmarco.bm25.100.jsonl \ 9 | --output_dir ${RESULT_PATH} \ 10 | --model_name_or_path ${MODEL_PATH} \ 11 | 
--tokenizer_name_or_path ${TOKENIZER_PATH} \ 12 | --dataset_cache_dir hf_cache/dataset/ \ 13 | --use_flash_attention_2 False \ 14 | --max_length 2048 \ 15 | --batch_size 4 \ 16 | --with_description True \ 17 | --dataset_name msmarco 18 | 19 | torchrun --nproc_per_node ${GPU_NUM} eval_rerank.py \ 20 | --eval_data ${EVAL_DATA_PATH}/touche.bm25.100.jsonl \ 21 | --output_dir ${RESULT_PATH} \ 22 | --model_name_or_path ${MODEL_PATH} \ 23 | --tokenizer_name_or_path ${TOKENIZER_PATH} \ 24 | --dataset_cache_dir hf_cache/dataset/ \ 25 | --use_flash_attention_2 False \ 26 | --max_length 2048 \ 27 | --batch_size 4 \ 28 | --with_description True \ 29 | --dataset_name touche 30 | 31 | torchrun --nproc_per_node ${GPU_NUM} eval_rerank.py \ 32 | --eval_data ${EVAL_DATA_PATH}/arguana.bm25.100.jsonl \ 33 | --output_dir ${RESULT_PATH} \ 34 | --model_name_or_path ${MODEL_PATH} \ 35 | --tokenizer_name_or_path ${TOKENIZER_PATH} \ 36 | --dataset_cache_dir hf_cache/dataset/ \ 37 | --use_flash_attention_2 False \ 38 | --max_length 2048 \ 39 | --batch_size 4 \ 40 | --with_description True \ 41 | --dataset_name arguana 42 | 43 | torchrun --nproc_per_node ${GPU_NUM} eval_rerank.py \ 44 | --eval_data ${EVAL_DATA_PATH}/trec_covid.bm25.100.jsonl \ 45 | --output_dir ${RESULT_PATH} \ 46 | --model_name_or_path ${MODEL_PATH} \ 47 | --tokenizer_name_or_path ${TOKENIZER_PATH} \ 48 | --dataset_cache_dir hf_cache/dataset/ \ 49 | --use_flash_attention_2 False \ 50 | --max_length 2048 \ 51 | --batch_size 4 \ 52 | --with_description True \ 53 | --dataset_name trec_covid 54 | 55 | torchrun --nproc_per_node ${GPU_NUM} eval_rerank.py \ 56 | --eval_data ${EVAL_DATA_PATH}/nfcorpus.bm25.100.jsonl \ 57 | --output_dir ${RESULT_PATH} \ 58 | --model_name_or_path ${MODEL_PATH} \ 59 | --tokenizer_name_or_path ${TOKENIZER_PATH} \ 60 | --dataset_cache_dir hf_cache/dataset/ \ 61 | --use_flash_attention_2 False \ 62 | --max_length 2048 \ 63 | --batch_size 4 \ 64 | --with_description True \ 65 | --dataset_name nfcorpus 66 | 67 | torchrun --nproc_per_node ${GPU_NUM} 
eval_rerank.py \ 68 | --eval_data ${EVAL_DATA_PATH}/scidocs.bm25.100.jsonl \ 69 | --output_dir ${RESULT_PATH} \ 70 | --model_name_or_path ${MODEL_PATH} \ 71 | --tokenizer_name_or_path ${TOKENIZER_PATH} \ 72 | --dataset_cache_dir hf_cache/dataset/ \ 73 | --use_flash_attention_2 False \ 74 | --max_length 2048 \ 75 | --batch_size 4 \ 76 | --with_description True \ 77 | --dataset_name scidocs 78 | 79 | torchrun --nproc_per_node ${GPU_NUM} eval_rerank.py \ 80 | --eval_data ${EVAL_DATA_PATH}/quora.bm25.100.jsonl \ 81 | --output_dir ${RESULT_PATH} \ 82 | --model_name_or_path ${MODEL_PATH} \ 83 | --tokenizer_name_or_path ${TOKENIZER_PATH} \ 84 | --dataset_cache_dir hf_cache/dataset/ \ 85 | --use_flash_attention_2 False \ 86 | --max_length 2048 \ 87 | --batch_size 4 \ 88 | --with_description True \ 89 | --dataset_name quora 90 | 91 | torchrun --nproc_per_node ${GPU_NUM} eval_rerank.py \ 92 | --eval_data ${EVAL_DATA_PATH}/dbpedia.bm25.100.jsonl \ 93 | --output_dir ${RESULT_PATH} \ 94 | --model_name_or_path ${MODEL_PATH} \ 95 | --tokenizer_name_or_path ${TOKENIZER_PATH} \ 96 | --dataset_cache_dir hf_cache/dataset/ \ 97 | --use_flash_attention_2 False \ 98 | --max_length 2048 \ 99 | --batch_size 4 \ 100 | --with_description True \ 101 | --rerank_method listwise \ 102 | --listwise_window 5 \ 103 | --listwise_stride 5 \ 104 | --dataset_name dbpedia 105 | 106 | torchrun --nproc_per_node ${GPU_NUM} eval_rerank.py \ 107 | --eval_data ${EVAL_DATA_PATH}/fever.bm25.100.jsonl \ 108 | --output_dir ${RESULT_PATH} \ 109 | --model_name_or_path ${MODEL_PATH} \ 110 | --tokenizer_name_or_path ${TOKENIZER_PATH} \ 111 | --dataset_cache_dir hf_cache/dataset/ \ 112 | --use_flash_attention_2 False \ 113 | --max_length 2048 \ 114 | --batch_size 4 \ 115 | --with_description True \ 116 | --dataset_name fever 117 | 118 | torchrun --nproc_per_node ${GPU_NUM} eval_rerank.py \ 119 | --eval_data ${EVAL_DATA_PATH}/climate_fever.bm25.100.jsonl \ 120 | --output_dir ${RESULT_PATH} \ 121 | --model_name_or_path ${MODEL_PATH} \ 122 | 
--tokenizer_name_or_path ${TOKENIZER_PATH} \ 123 | --dataset_cache_dir hf_cache/dataset/ \ 124 | --use_flash_attention_2 False \ 125 | --max_length 2048 \ 126 | --batch_size 4 \ 127 | --with_description True \ 128 | --dataset_name climate_fever 129 | 130 | torchrun --nproc_per_node ${GPU_NUM} eval_rerank.py \ 131 | --eval_data ${EVAL_DATA_PATH}/scifact.bm25.100.jsonl \ 132 | --output_dir ${RESULT_PATH} \ 133 | --model_name_or_path ${MODEL_PATH} \ 134 | --tokenizer_name_or_path ${TOKENIZER_PATH} \ 135 | --dataset_cache_dir hf_cache/dataset/ \ 136 | --use_flash_attention_2 False \ 137 | --max_length 2048 \ 138 | --batch_size 4 \ 139 | --with_description True \ 140 | --dataset_name scifact 141 | 142 | torchrun --nproc_per_node ${GPU_NUM} eval_rerank.py \ 143 | --eval_data ${EVAL_DATA_PATH}/nq.bm25.100.jsonl \ 144 | --output_dir ${RESULT_PATH} \ 145 | --model_name_or_path ${MODEL_PATH} \ 146 | --tokenizer_name_or_path ${TOKENIZER_PATH} \ 147 | --dataset_cache_dir hf_cache/dataset/ \ 148 | --use_flash_attention_2 False \ 149 | --max_length 2048 \ 150 | --batch_size 4 \ 151 | --with_description True \ 152 | --dataset_name nq 153 | 154 | torchrun --nproc_per_node ${GPU_NUM} eval_rerank.py \ 155 | --eval_data ${EVAL_DATA_PATH}/fiqa.bm25.100.jsonl \ 156 | --output_dir ${RESULT_PATH} \ 157 | --model_name_or_path ${MODEL_PATH} \ 158 | --tokenizer_name_or_path ${TOKENIZER_PATH} \ 159 | --dataset_cache_dir hf_cache/dataset/ \ 160 | --use_flash_attention_2 False \ 161 | --max_length 2048 \ 162 | --batch_size 4 \ 163 | --with_description True \ 164 | --dataset_name fiqa 165 | 166 | torchrun --nproc_per_node ${GPU_NUM} eval_rerank.py \ 167 | --eval_data ${EVAL_DATA_PATH}/hotpot_qa.bm25.100.jsonl \ 168 | --output_dir ${RESULT_PATH} \ 169 | --model_name_or_path ${MODEL_PATH} \ 170 | --tokenizer_name_or_path ${TOKENIZER_PATH} \ 171 | --dataset_cache_dir hf_cache/dataset/ \ 172 | --use_flash_attention_2 False \ 173 | --max_length 2048 \ 174 | --batch_size 4 \ 175 | --with_description True \ 176 | --dataset_name 
hotpot_qa -------------------------------------------------------------------------------- /evaluation/qdu-tasks/src/__init__.py: -------------------------------------------------------------------------------- 1 | from .modeling import get_model_and_tokenizer, RerankResult 2 | from .args import ModelArgs 3 | from .utils import FileLogger, DatasetProcessFn, DefaultDataCollator, makedirs, split_file_dir_name_ext, mask_nested_lists 4 | from .metrics import Metric 5 | 6 | import logging 7 | logging.basicConfig( 8 | level=logging.INFO, 9 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 10 | datefmt="%m/%d/%Y %H:%M:%S", 11 | ) -------------------------------------------------------------------------------- /evaluation/qdu-tasks/src/args.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dataclasses import dataclass, field 3 | from typing import Optional, List, Union 4 | 5 | 6 | @dataclass 7 | class ModelArgs: 8 | model_cache_dir: Optional[str] = field( 9 | # default=None, 10 | default="/share/LMs", 11 | metadata={'help': 'Default path to save language models.'} 12 | ) 13 | dataset_cache_dir: Optional[str] = field( 14 | # default=None, 15 | default="/share/peitian/Data/Datasets/huggingface", 16 | metadata={'help': 'Default path to save huggingface datasets.'} 17 | ) 18 | eval_data: Optional[str] = field( 19 | default=None, 20 | metadata={'help': 'Evaluation json file.'}, 21 | ) 22 | 23 | model_name_or_path: str = field( 24 | default='meta-llama/Llama-2-7b-chat-hf', 25 | metadata={'help': 'Path to pretrained model or model identifier from huggingface.co/models'} 26 | ) 27 | tokenizer_name_or_path: str = field( 28 | default='meta-llama/Llama-2-7b-chat-hf', 29 | metadata={'help': 'Path to pretrained tokenizer or tokenizer identifier from huggingface.co/models'} 30 | ) 31 | padding_side: str = field( 32 | default="left", 33 | metadata={'help': 'Tokenizer padding side.'} 34 | ) 35 | access_token: 
Optional[str] = field( 36 | default=None, 37 | metadata={'help': 'Huggingface access token.'} 38 | ) 39 | max_length: int = field( 40 | default=2048, 41 | metadata={'help': 'How many tokens at maximum for each input?'}, 42 | ) 43 | 44 | lora: Optional[str] = field( 45 | default=None, 46 | metadata={'help': 'LoRA ID.'}, 47 | ) 48 | 49 | dtype: str = field( 50 | default="bf16", 51 | metadata={'help': 'Data type for the model weights. {bf16, fp16, fp32}'} 52 | ) 53 | device_map: Optional[str] = field( 54 | default=None, 55 | metadata={'help': 'Device map for loading the model. Set to auto to load across devices.'} 56 | ) 57 | use_flash_attention_2: bool = field( 58 | default=True, 59 | metadata={'help': 'Use flash attention?'} 60 | ) 61 | cpu: bool = field( 62 | default=False, 63 | metadata={'help': 'Use cpu?'} 64 | ) 65 | 66 | metrics: List[str] = field( 67 | default_factory=lambda: [], 68 | metadata={'help': 'List of metrics. {mrr, recall, ndcg}'} 69 | ) 70 | -------------------------------------------------------------------------------- /evaluation/qdu-tasks/src/metrics.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import logging 4 | import inspect 5 | import numpy as np 6 | from tqdm import tqdm 7 | from typing import List, Dict, Tuple 8 | from .modeling import RerankResult 9 | from .utils import makedirs, split_file_dir_name_ext, normalize_text 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | 14 | class Metric: 15 | """Class for computing metrics and some post-processing.""" 16 | @classmethod 17 | def get_metric_fn(cls, metric_names, **kwds): 18 | assert isinstance(metric_names, list) or isinstance(metric_names, tuple), "You must pass metric_names in a list or tuple!" 
19 | all_metrics = {} 20 | # get all methods 21 | all_implemented_fns = [x[0] for x in inspect.getmembers(cls, predicate=inspect.isfunction) if not x[0].startswith("_")] 22 | 23 | def compute_metrics(*args, **kwargs): 24 | for metric_name in metric_names: 25 | # call corresponding method 26 | if metric_name in all_implemented_fns: 27 | metric_fn = getattr(cls, metric_name) 28 | metric = metric_fn(**kwds)(*args, **kwargs) 29 | # NOTE: some metric_fn are only used for post-processing and saving results, which return None by default 30 | if metric is not None: 31 | all_metrics.update(metric) 32 | else: 33 | raise NotImplementedError(f"Metric {metric_name} not implemented!") 34 | return all_metrics 35 | return compute_metrics 36 | 37 | @staticmethod 38 | def _get_save_path(eval_data, output_dir=None, field="result", save_name=None): 39 | """ 40 | if output_dir is None: 41 | -> {eval_data_dir}/{eval_data_name}.{field}.{save_name}.{eval_data_ext} 42 | else: 43 | -> {output_dir}/{eval_data_name}.{field}.{save_name}.{eval_data_ext} 44 | """ 45 | eval_data_dir, eval_data_name, eval_data_ext = split_file_dir_name_ext(eval_data) 46 | if output_dir is None: 47 | output_dir = eval_data_dir 48 | fields = [eval_data_name, field] 49 | if save_name is not None: 50 | fields.append(save_name) 51 | save_path = os.path.join(output_dir, ".".join(fields) + eval_data_ext) 52 | makedirs(save_path) 53 | return save_path 54 | 55 | @staticmethod 56 | def _save_generation_result(result, path, eval_data=None): 57 | if eval_data is not None: 58 | items = {} 59 | with open(eval_data, encoding="utf-8") as f: 60 | for i, line in enumerate(f): 61 | item = json.loads(line) 62 | if "query_id" in item: 63 | index = item["query_id"] 64 | else: 65 | index = i 66 | items[index] = item 67 | with open(path, "w", encoding="utf-8") as f: 68 | for index, pred in result.items(): 69 | res = {"query_id": index} 70 | if eval_data is not None: 71 | item = items[index] 72 | res.update(item) 73 | res["pred"] = pred 74 | 
f.write(json.dumps(res, ensure_ascii=False) + "\n") 75 | 76 | @staticmethod 77 | def _save_rerank_result(results: Dict[int, List[RerankResult]], path, eval_data=None): 78 | if eval_data is not None: 79 | items = {} 80 | with open(eval_data, encoding="utf-8") as f: 81 | for i, line in enumerate(f): 82 | full_item = json.loads(line) 83 | # we only save the index and score of positive samples 84 | item = { 85 | "pos_index": full_item["pos_index"], 86 | "pos_score": full_item["pos_score"], 87 | } 88 | if "query_id" in full_item: 89 | index = full_item["query_id"] 90 | else: 91 | index = i 92 | items[index] = item 93 | 94 | with open(path, "w", encoding="utf-8") as f: 95 | for index, preds in results.items(): 96 | item = { 97 | "query_id": index, 98 | "doc_index": [x.doc_id for x in preds], 99 | "doc_score": [x.doc_score for x in preds], 100 | } 101 | if eval_data is not None: 102 | data_item = items[index] 103 | item.update(data_item) 104 | f.write(json.dumps(item, ensure_ascii=False) + "\n") 105 | 106 | @staticmethod 107 | def _prepare_label_for_generation(eval_data): 108 | labels = {} 109 | with open(eval_data, encoding="utf-8") as f: 110 | for i, line in enumerate(f): 111 | item = json.loads(line) 112 | if "query_id" in item: 113 | index = item["query_id"] 114 | else: 115 | index = i 116 | label = item["completion"] 117 | labels[index] = label 118 | return labels 119 | 120 | @staticmethod 121 | def _prepare_label_for_retrieval(eval_data): 122 | labels = {} 123 | with open(eval_data, encoding="utf-8") as f: 124 | for i, line in enumerate(f): 125 | item = json.loads(line) 126 | if "query_id" in item: 127 | query_id = item["query_id"] 128 | else: 129 | query_id = i 130 | # save positive indices and their scores (for computing ndcg) 131 | labels[query_id] = (item["pos_index"], item["pos_score"]) 132 | return labels 133 | 134 | @staticmethod 135 | def _prepare_pred_for_retrieval(preds:List[RerankResult]) -> List[RerankResult]: 136 | valid_preds = [x for x in preds if x.doc_id > -1] 137 | return valid_preds 138 | 139 | 
@staticmethod 140 | def mrr(eval_data=None, cutoffs=[10], **kwds): 141 | if eval_data is not None: 142 | data_labels = Metric._prepare_label_for_retrieval(eval_data) 143 | 144 | def compute_metric(results: Dict[int, List[RerankResult]], labels: Dict[int, Tuple[List[int], List[int]]] = None, **kwargs): 145 | if labels is None: 146 | labels = data_labels 147 | 148 | mrrs = np.zeros(len(cutoffs)) 149 | counts = 0 150 | 151 | for query_id, preds in results.items(): 152 | pos_indices, pos_scores = labels[query_id] 153 | # remove irrelevant documents 154 | pos_indices = [pos_index for i, pos_index in enumerate(pos_indices) if pos_scores[i] > 0] 155 | if len(pos_indices) == 0: 156 | continue 157 | 158 | preds = Metric._prepare_pred_for_retrieval(preds) 159 | jump = False 160 | counts += 1 161 | 162 | for i, pred in enumerate(preds, 1): 163 | if pred.doc_id in pos_indices: 164 | for k, cutoff in enumerate(cutoffs): 165 | if i <= cutoff: 166 | mrrs[k] += 1 / i 167 | jump = True 168 | if jump: 169 | break 170 | 171 | mrrs /= counts 172 | 173 | metric = {} 174 | for i, cutoff in enumerate(cutoffs): 175 | mrr = mrrs[i] 176 | metric[f"mrr@{cutoff}"] = mrr 177 | 178 | return metric 179 | return compute_metric 180 | 181 | @staticmethod 182 | def recall(eval_data=None, cutoffs=[10], **kwds): 183 | if eval_data is not None: 184 | data_labels = Metric._prepare_label_for_retrieval(eval_data) 185 | 186 | def compute_metric(results: Dict[int, List[RerankResult]], labels: Dict[int, Tuple[List[int], List[int]]] = None, **kwargs): 187 | if labels is None: 188 | labels = data_labels 189 | 190 | recalls = np.zeros(len(cutoffs)) 191 | counts = 0 192 | 193 | for query_id, preds in results.items(): 194 | pos_indices, pos_scores = labels[query_id] 195 | # remove irrelevant documents 196 | pos_indices = [pos_index for i, pos_index in enumerate(pos_indices) if pos_scores[i] > 0] 197 | if len(pos_indices) == 0: 198 | continue 199 | 200 | preds = Metric._prepare_pred_for_retrieval(preds) 201 | 
preds_indices = [x.doc_id for x in preds] 202 | counts += 1 203 | 204 | for k, cutoff in enumerate(cutoffs): 205 | recall = np.intersect1d(pos_indices, preds_indices[:cutoff]) 206 | recalls[k] += len(recall) / len(pos_indices) 207 | 208 | recalls /= counts 209 | 210 | metric = {} 211 | for i, cutoff in enumerate(cutoffs): 212 | recall = recalls[i] 213 | metric[f"recall@{cutoff}"] = recall 214 | 215 | return metric 216 | return compute_metric 217 | 218 | @staticmethod 219 | def ndcg(eval_data=None, cutoffs=[10], **kwds): 220 | if eval_data is not None: 221 | data_labels = Metric._prepare_label_for_retrieval(eval_data) 222 | 223 | def compute_metric(results: Dict[int, List[RerankResult]], labels: Dict[int, Tuple[List[int], List[int]]] = None, **kwargs): 224 | if labels is None: 225 | labels = data_labels 226 | 227 | ndcgs = np.zeros(len(cutoffs)) 228 | counts = 0 229 | 230 | for query_id, preds in results.items(): 231 | pos_indices, pos_scores = labels[query_id] 232 | preds = Metric._prepare_pred_for_retrieval(preds) 233 | 234 | pos_indices_to_scores = {k: v for k, v in zip(pos_indices, pos_scores)} 235 | if len(pos_indices_to_scores) == 0 or max(pos_scores) <= 0: # skip queries without any relevant document (avoids dividing by a zero idcg) 236 | continue 237 | 238 | dcg = np.zeros(len(cutoffs)) 239 | idcg = np.zeros(len(cutoffs)) 240 | counts += 1 241 | 242 | for i, pred in enumerate(preds, 1): 243 | if pred.doc_id in pos_indices: 244 | for k, cutoff in enumerate(cutoffs): 245 | if i <= cutoff: 246 | # get the relevance score of the pred 247 | dcg[k] += (2 ** pos_indices_to_scores[pred.doc_id] - 1) / np.log2(i + 1) 248 | 249 | # descendingly sort positives to acquire the ideal ranking 250 | ideal_ranking = sorted(pos_scores, reverse=True) 251 | for j, y in enumerate(ideal_ranking, 1): 252 | for k, cutoff in enumerate(cutoffs): 253 | if j <= cutoff: 254 | idcg[k] += (2 ** y - 1) / np.log2(j + 1) 255 | 256 | ndcgs += dcg / idcg 257 | 258 | ndcgs /= counts 259 | 260 | metric = {} 261 | for i, cutoff in enumerate(cutoffs): 262 | ndcg = ndcgs[i] 263 | 
metric[f"ndcg@{cutoff}"] = ndcg 264 | return metric 265 | return compute_metric 266 | -------------------------------------------------------------------------------- /evaluation/qdu-tasks/src/modeling.py: -------------------------------------------------------------------------------- 1 | import re 2 | import torch 3 | import copy 4 | import time 5 | import random 6 | import numpy as np 7 | import torch.nn as nn 8 | from transformers import ( 9 | AutoTokenizer, 10 | AutoConfig, 11 | AutoModelForCausalLM, 12 | AutoModelForSeq2SeqLM, 13 | logging 14 | ) 15 | from torch.utils.data import DataLoader 16 | from accelerate import Accelerator 17 | from typing import Mapping, Tuple, List, Optional 18 | from tqdm import tqdm 19 | from collections import defaultdict 20 | from dataclasses import dataclass, asdict 21 | 22 | logger = logging.get_logger(__name__) 23 | 24 | 25 | @dataclass 26 | class RerankResult: 27 | doc_id: int 28 | doc_score: float 29 | 30 | 31 | class LM(nn.Module): 32 | def __init__(self, model_name_or_path, tokenizer_name_or_path, padding_side="left", dtype="bf16", device_map=None, use_flash_attention_2=False, access_token=None, cache_dir="/share/LMs", accelerator: Accelerator=None) -> None: 33 | super().__init__() 34 | 35 | logger.info(f"loading tokenizer from {tokenizer_name_or_path}...") 36 | tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path, cache_dir=cache_dir, padding_side=padding_side, trust_remote_code=True) 37 | if tokenizer.pad_token is None: 38 | if tokenizer.eos_token is None: 39 | pad_token = "<|endoftext|>" 40 | else: 41 | pad_token = tokenizer.eos_token 42 | tokenizer.pad_token = pad_token 43 | 44 | if dtype == "bf16": 45 | dtype = torch.bfloat16 46 | elif dtype == "fp16": 47 | dtype = torch.float16 48 | else: 49 | dtype = torch.float32 50 | 51 | if device_map is None: 52 | if accelerator is not None: 53 | device_map = {"": accelerator.device} 54 | else: 55 | device_map = {"": "cpu"} 56 | 57 | logger.info(f"loading model from 
{model_name_or_path}...") 58 | config = AutoConfig.from_pretrained(model_name_or_path, cache_dir=cache_dir) 59 | if config.is_encoder_decoder: 60 | model = AutoModelForSeq2SeqLM.from_pretrained( 61 | model_name_or_path, 62 | cache_dir=cache_dir, 63 | torch_dtype=dtype, 64 | trust_remote_code=True, 65 | device_map=device_map, 66 | token=access_token, 67 | ) 68 | else: 69 | model = AutoModelForCausalLM.from_pretrained( 70 | model_name_or_path, 71 | cache_dir=cache_dir, 72 | torch_dtype=dtype, 73 | trust_remote_code=True, 74 | device_map=device_map, 75 | use_flash_attention_2=use_flash_attention_2, 76 | token=access_token, 77 | ) 78 | 79 | self.config = model.config 80 | self.tokenizer = tokenizer 81 | 82 | if accelerator is not None: 83 | self.model = accelerator.prepare_model(model, device_placement=True, evaluation_mode=True) 84 | else: 85 | self.model = model 86 | 87 | self.rng = np.random.default_rng(42) 88 | self.eval() 89 | 90 | @property 91 | def device(self): 92 | return self.model.device 93 | 94 | def prepare_inputs_for_generation(self, *args, **kwargs): 95 | return self.model.prepare_inputs_for_generation(*args, **kwargs) 96 | 97 | def _reorder_cache(self, *args, **kwargs): 98 | return self.model._reorder_cache(*args, **kwargs) 99 | 100 | def _move_to_device(self, data): 101 | """ 102 | Prepares one `data` before feeding it to the model, be it a tensor or a nested list/dictionary of tensors. 
103 | """ 104 | if isinstance(data, Mapping): 105 | return type(data)({k: self._move_to_device(v) for k, v in data.items()}) 106 | elif isinstance(data, (tuple, list)): 107 | return type(data)(self._move_to_device(v) for v in data) 108 | elif isinstance(data, torch.Tensor): 109 | kwargs = {"device": self.device} 110 | return data.to(**kwargs) 111 | else: 112 | return data 113 | 114 | def forward(self, *args, **kwargs): 115 | return self.model(*args, **kwargs) 116 | 117 | @torch.no_grad() 118 | def generate(self, return_new_tokens_only=True, decode=True, accelerator:Optional[Accelerator]=None, **inputs): 119 | outputs = self.model.generate(**inputs) 120 | 121 | if return_new_tokens_only: 122 | if self.model.config.is_encoder_decoder: 123 | if "decoder_input_ids" in inputs: 124 | start_idx = inputs["decoder_input_ids"].shape[1] + 1 125 | else: 126 | start_idx = 1 127 | else: 128 | start_idx = inputs["input_ids"].shape[1] 129 | outputs = outputs[:, start_idx:] 130 | 131 | if accelerator is not None: 132 | # must be contiguous 133 | outputs = outputs.contiguous() 134 | outputs = accelerator.pad_across_processes(outputs, pad_index=self.tokenizer.pad_token_id, dim=1) 135 | outputs = accelerator.gather_for_metrics(outputs) 136 | 137 | outputs = outputs.tolist() 138 | if decode: 139 | outputs = self.tokenizer.batch_decode(outputs, skip_special_tokens=True) 140 | return outputs 141 | 142 | @torch.no_grad() 143 | def rerank_pointwise(self, dataloader, accelerator): 144 | """ 145 | Args: 146 | dataloader: return a batch of input_ids and attention_mask, each of which is a query-doc pair. 
147 | """ 148 | if accelerator is not None and type(dataloader) is DataLoader: 149 | dataloader = accelerator.prepare(dataloader) 150 | 151 | is_encoder_decoder = self.model.config.is_encoder_decoder 152 | yes_id = self.tokenizer.encode("Yes", add_special_tokens=False)[0] 153 | no_id = self.tokenizer.encode("No", add_special_tokens=False)[0] 154 | 155 | rerank_results = defaultdict(list) 156 | for i, x in enumerate(tqdm(dataloader, desc="Pointwise Reranking", ncols=120)): 157 | query_ids = x.pop("query_id") 158 | doc_ids = x.pop("doc_id") 159 | 160 | if is_encoder_decoder: 161 | raise NotImplementedError 162 | 163 | else: 164 | logits = self.model(**x).logits[:, -1] # batch_size, vocab_size 165 | 166 | yes_and_no_logits = logits[:, [yes_id, no_id]] # batch_size, 2 167 | # NOTE: normalize so that different documents are comparable 168 | yes_and_no_logits = torch.softmax(yes_and_no_logits, dim=1) 169 | doc_scores = yes_and_no_logits[:, 0] # batch_size 170 | 171 | # gather outputs across devices 172 | if accelerator is not None: 173 | query_ids = accelerator.gather_for_metrics(query_ids) 174 | doc_ids = accelerator.gather_for_metrics(doc_ids) 175 | doc_scores = accelerator.gather_for_metrics(doc_scores) 176 | 177 | for query_id, doc_id, doc_score in zip(query_ids.tolist(), doc_ids.tolist(), doc_scores.tolist()): 178 | rerank_result = RerankResult(doc_id=doc_id, doc_score=doc_score) 179 | rerank_results[query_id].append(rerank_result) 180 | 181 | # sort candidates of each query 182 | for query_id, res in rerank_results.items(): 183 | sorted_res = sorted(res, key=lambda x: x.doc_score, reverse=True) 184 | rerank_results[query_id] = sorted_res 185 | 186 | return dict(rerank_results) 187 | 188 | def compare(self, query: str, docs: List, prompt_template: str, fewshot_prompt: Optional[str]=None): 189 | doc1, doc2 = docs[0], docs[1] 190 | input_texts = [prompt_template.format(query=query, doc1=doc1, doc2=doc2) + "[1]", prompt_template.format(query=query, doc1=doc1, 
doc2=doc2) + "[2]", 191 | prompt_template.format(query=query, doc1=doc2, doc2=doc1) + "[1]", prompt_template.format(query=query, doc1=doc2, doc2=doc1) + "[2]"] 192 | 193 | # NOTE: add fewshot prompt 194 | if fewshot_prompt is not None: 195 | input_texts = [fewshot_prompt + x for x in input_texts] 196 | 197 | inputs = self.tokenizer(input_texts, return_tensors="pt", padding=True) # relies on left padding so the target tokens stay at the end 198 | 199 | 200 | target_texts = ["[1]", "[2]", "[1]", "[2]"] 201 | targets = self.tokenizer(target_texts, add_special_tokens=False)["input_ids"] 202 | targets_length = [len(x) for x in targets] 203 | 204 | labels = inputs["input_ids"].clone() 205 | for i, label in enumerate(labels): 206 | labels[i, :-targets_length[i]] = -100 207 | # inputs["labels"] = labels 208 | inputs = inputs.to(self.device) 209 | 210 | outputs = self.model(**inputs) 211 | logits = outputs["logits"].detach() 212 | logits = logits[:, :-1, :].contiguous() 213 | labels = labels[:, 1:].contiguous() 214 | labels = labels.to(logits.device) 215 | batch_size = logits.shape[0] 216 | token_loss = torch.nn.functional.cross_entropy( 217 | logits.flatten(0, 1), 218 | labels.reshape(-1), 219 | reduction="none" 220 | ).reshape(batch_size, -1) 221 | 222 | valid_token_num = (labels != -100).sum(-1) 223 | 224 | token_prob = - token_loss.sum(-1) / valid_token_num 225 | 226 | if token_prob[0] > token_prob[1] and token_prob[2] < token_prob[3]: 227 | return 'change' 228 | return 'not change' 229 | 230 | @torch.no_grad() 231 | def rerank_pairwise(self, dataloader, prompt_template, accelerator): 232 | rerank_results = defaultdict(list) 233 | if accelerator is not None and type(dataloader) is DataLoader: 234 | dataloader = accelerator.prepare(dataloader) 235 | 236 | for _, ranking_result in enumerate(tqdm(dataloader, desc="Pairwise Reranking", ncols=120)): 237 | k = ranking_result["doc_ids"].size(dim=1) 238 | last_end = k - 1 239 | query = ranking_result["query"] 240 | batch_size = ranking_result["doc_ids"].size(dim=0) 
241 | 242 | ranking_result["doc_ids"] = ranking_result["doc_ids"].tolist() 243 | 244 | fewshot_prompts = ranking_result["fewshot_prompt"] 245 | 246 | for i in range(batch_size): 247 | # NOTE: convert doc_ids to list 248 | pairs = list(zip(ranking_result["docs"][i], ranking_result["doc_ids"][i])) 249 | self.rng.shuffle(pairs) 250 | 251 | shuffled_docs, shuffled_doc_ids = zip(*pairs) 252 | shuffled_docs = list(shuffled_docs) 253 | shuffled_doc_ids = list(shuffled_doc_ids) 254 | ranking_result["docs"][i] = shuffled_docs 255 | ranking_result["doc_ids"][i] = shuffled_doc_ids 256 | 257 | if fewshot_prompts is not None: 258 | fewshot_prompt = fewshot_prompts[i] 259 | else: 260 | fewshot_prompt = None 261 | 262 | for j in range(k): 263 | current_ind = last_end 264 | is_change = False 265 | while True: 266 | if current_ind <= j: 267 | break 268 | doc1 = ranking_result["docs"][i][current_ind] 269 | doc2 = ranking_result["docs"][i][current_ind - 1] 270 | output = self.compare(query[i], [doc1, doc2], prompt_template=prompt_template, fewshot_prompt=fewshot_prompt) 271 | if output == 'change': 272 | ranking_result["docs"][i][current_ind - 1], ranking_result["docs"][i][current_ind] = ranking_result["docs"][i][current_ind], ranking_result["docs"][i][current_ind - 1] 273 | ranking_result["doc_ids"][i][current_ind - 1], ranking_result["doc_ids"][i][current_ind] = ranking_result["doc_ids"][i][current_ind], ranking_result["doc_ids"][i][current_ind - 1] 274 | if not is_change: 275 | is_change = True 276 | if last_end != k - 1: # skip unchanged pairs at the bottom 277 | last_end += 1 278 | if not is_change: 279 | last_end -= 1 280 | current_ind -= 1 281 | query_ids = ranking_result.pop("query_id") 282 | doc_ids = torch.tensor(ranking_result.pop("doc_ids"), device=self.device) 283 | if accelerator is not None: 284 | query_ids = accelerator.gather_for_metrics(query_ids) 285 | doc_ids = accelerator.gather_for_metrics(doc_ids) 286 | for query_id, doc_id in zip(query_ids.tolist(), 
doc_ids.tolist()): 287 | for doc_id_i in doc_id: 288 | ranking_result = RerankResult(doc_id=doc_id_i, doc_score=0) 289 | rerank_results[query_id].append(ranking_result) 290 | return dict(rerank_results) 291 | 292 | def permutation_pipeline(self, item=None, rank_start=0, rank_end=100, prompt_template=None, window_size=5): 293 | f_query = item["query"] 294 | num = len(item['hits'][rank_start: rank_end]) 295 | fewshot_prompt = item["fewshot_prompt"] 296 | 297 | rank = 0 298 | docs = "" 299 | for hit in item['hits'][rank_start: rank_end]: 300 | rank += 1 301 | if isinstance(hit['document'], str): 302 | document = hit['document'].strip() 303 | else: 304 | print(item['hits']) 305 | raise ValueError("document should be a string") 306 | docs += f"[{rank}] " + document + "\n" 307 | messages = prompt_template.format(query=f_query, num=num, docs=docs) 308 | # messages = prompt_template.format(query=query, num=rank, docs=docs) 309 | if fewshot_prompt is not None: 310 | messages = fewshot_prompt + messages 311 | 312 | input_ids = self.tokenizer(messages, return_tensors="pt", padding='longest', truncation=False).input_ids.to(self.model.device) 313 | output_ids = self.model.generate(input_ids, 314 | do_sample=False, 315 | temperature=0.0, 316 | top_p=None, 317 | max_new_tokens=500, 318 | pad_token_id=self.tokenizer.eos_token_id,) 319 | permutation = self.tokenizer.decode(output_ids[0, input_ids.shape[1]:], skip_special_tokens=False) 320 | 321 | response = [int(match.group(1)) -1 for match in re.finditer(r"(\d+)", permutation)] 322 | response = [x for x in response if 0<= x < window_size] 323 | if len(response) == 0: 324 | return item 325 | 326 | new_response = [] 327 | for x in response: 328 | if x not in new_response: 329 | new_response.append(x) 330 | response = new_response 331 | 332 | if len(response) < window_size: 333 | full_set = set(range(0, window_size)) 334 | missing_number = full_set - set(response) 335 | response.extend(list(missing_number)) 336 | 337 | candidates = 
copy.deepcopy(item['hits'][rank_start: rank_end]) 338 | rerank_candidates = [candidates[x] for x in response] 339 | item['hits'][rank_start: rank_end] = rerank_candidates 340 | 341 | return item 342 | 343 | def sliding_windows(self, item=None, prompt_template=None, rank_start=0, rank_end=100, window_size=5, step=3): 344 | item = copy.deepcopy(item) 345 | end_pos = rank_end 346 | start_pos = rank_end - window_size 347 | while start_pos >= rank_start: 348 | start_pos = max(start_pos, rank_start) 349 | item = self.permutation_pipeline(item, start_pos, end_pos, prompt_template, window_size) 350 | end_pos = end_pos - step 351 | start_pos = start_pos - step 352 | return item 353 | 354 | @torch.no_grad() 355 | def rerank_listwise(self, dataloader, prompt_template, accelerator, window, stride): 356 | rerank_results = defaultdict(list) 357 | if accelerator is not None and type(dataloader) is DataLoader: 358 | dataloader = accelerator.prepare(dataloader) 359 | 360 | for _, ranking_data in enumerate(tqdm(dataloader, desc="Listwise Reranking", ncols=120)): 361 | batch_size = ranking_data["doc_ids"].size(dim=0) 362 | items = [{} for _ in range(batch_size)] 363 | doc_start = [0] * batch_size 364 | doc_end = [] 365 | for i in range(batch_size): 366 | items[i]["query"] = ranking_data["query"][i] 367 | items[i]["query_id"] = ranking_data["query_id"][i].item() 368 | items[i]["fewshot_prompt"] = ranking_data["fewshot_prompt"][i] 369 | items[i]["hits"] = [{"document": doc, "docid": docid.item()} 370 | for doc, docid in zip(ranking_data["docs"][i], ranking_data["doc_ids"][i])] 371 | self.rng.shuffle(items[i]["hits"]) 372 | doc_end.append(len(items[i]["hits"])) 373 | prompt_templates = [prompt_template] * batch_size 374 | windows = [window] * batch_size 375 | strides = [stride] * batch_size 376 | items = list(map(self.sliding_windows, items, prompt_templates, doc_start, doc_end, windows, strides)) 377 | query_ids = ranking_data.pop("query_id") 378 | doc_ids = torch.tensor([[hit["docid"] 
for hit in item["hits"]] for item in items], device=self.device) 379 | 380 | if accelerator is not None: 381 | query_ids = accelerator.gather_for_metrics(query_ids) 382 | doc_ids = accelerator.gather_for_metrics(doc_ids) 383 | for query_id, doc_id in zip(query_ids.tolist(), doc_ids.tolist()): 384 | for doc_id_i in doc_id: 385 | ranking_result = RerankResult(doc_id=doc_id_i, doc_score=0) 386 | rerank_results[query_id].append(ranking_result) 387 | 388 | return dict(rerank_results) 389 | 390 | 391 | def get_model_and_tokenizer(model_args, accelerator=None, **kwargs): 392 | """Load model and tokenizer. Possibly load LoRA for the model.""" 393 | 394 | from .args import ModelArgs 395 | model_args: ModelArgs 396 | 397 | model_args = asdict(model_args) 398 | model_args.update(**kwargs) 399 | 400 | model = LM( 401 | model_name_or_path=model_args["model_name_or_path"], 402 | tokenizer_name_or_path=model_args["tokenizer_name_or_path"], 403 | padding_side=model_args["padding_side"], 404 | dtype=model_args["dtype"], 405 | cache_dir=model_args["model_cache_dir"], 406 | device_map=model_args["device_map"], 407 | use_flash_attention_2=model_args["use_flash_attention_2"], 408 | access_token=model_args["access_token"], 409 | accelerator=accelerator 410 | ) 411 | 412 | # load lora 413 | if model_args["lora"] is not None: 414 | from peft import PeftModel 415 | logger.info(f"loading lora from {model_args['lora']}...") 416 | model = PeftModel.from_pretrained(model, model_args["lora"]) 417 | model = model.merge_and_unload() 418 | 419 | return model, model.tokenizer 420 | -------------------------------------------------------------------------------- /evaluation/qdu-tasks/src/utils.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import pytz 3 | import torch 4 | import pathlib 5 | import json 6 | import string 7 | import numpy as np 8 | from datetime import datetime 9 | from dataclasses import dataclass 10 | from collections import 
defaultdict 11 | from typing import List, Any, Dict 12 | 13 | 14 | def makedirs(path): 15 | p = pathlib.Path(path) 16 | p.parent.mkdir(parents=True, exist_ok=True) 17 | return path 18 | 19 | 20 | def split_file_dir_name_ext(path): 21 | """Return the directory, name, and extension of a given file.""" 22 | p = pathlib.Path(path) 23 | assert p.is_file() 24 | return p.parent, p.stem, p.suffix 25 | 26 | 27 | def get_max_length_in_nested_lists(lst): 28 | if len(lst) and isinstance(lst[0], list): 29 | lengths = [] 30 | for elem in lst: 31 | length = get_max_length_in_nested_lists(elem) 32 | lengths.append(length) 33 | max_length = max(lengths) 34 | return max_length 35 | else: 36 | return len(lst) 37 | 38 | 39 | def pad_nested_lists(lst, max_length, padding_value, padding_side="right"): 40 | if isinstance(lst, list) and len(lst) and isinstance(lst[0], list): 41 | masks = [] 42 | for i, elem in enumerate(lst): 43 | lst[i], mask = pad_nested_lists(elem, max_length, padding_value, padding_side) 44 | masks.append(mask) 45 | return lst, masks 46 | elif isinstance(lst, list): 47 | if padding_side == "right": 48 | mask = [1] * len(lst) + [0] * (max_length - len(lst)) 49 | lst = lst + [padding_value for _ in range(max_length - len(lst))] 50 | return lst, mask 51 | else: 52 | mask = [0] * (max_length - len(lst)) + [1] * len(lst) 53 | lst = [padding_value for _ in range(max_length - len(lst))] + lst 54 | return lst, mask 55 | else: 56 | raise NotImplementedError(f"Unrecognized type {lst}") 57 | 58 | def mask_nested_lists(lst, mask_target, mask_value=0): 59 | if isinstance(lst[0], list): 60 | for i, elem in enumerate(lst): 61 | lst[i] = mask_nested_lists(elem, mask_target, mask_value) 62 | return lst 63 | else: 64 | return [x if x != mask_target else mask_value for x in lst] 65 | 66 | def are_elements_of_same_length(lst: List): 67 | if not isinstance(lst[0], list): 68 | return False 69 | 70 | length = len(lst[0]) 71 | return all(len(x) == length if isinstance(x, list) else False for 
x in lst) 72 | 73 | 74 | def normalize_text(text, ignore_case=True, ignore_punctuation=True, ignore_space=True, ignore_number=False): 75 | if isinstance(text, str): 76 | text = [text] 77 | unpack = True 78 | else: 79 | unpack = False 80 | if ignore_case: 81 | text = np.char.lower(text) 82 | if ignore_punctuation: 83 | repl_table = string.punctuation.maketrans("", "", string.punctuation) 84 | text = np.char.translate(text, table=repl_table) 85 | if ignore_number: 86 | repl_table = string.digits.maketrans("", "", string.digits) 87 | text = np.char.translate(text, table=repl_table) 88 | if ignore_space: 89 | for i, words in enumerate(np.char.split(text)): 90 | text[i] = " ".join(words) 91 | if isinstance(text, np.ndarray): 92 | text = text.tolist() 93 | if unpack: 94 | text = text[0] 95 | return text 96 | 97 | 98 | class DatasetProcessFn: 99 | """Wrapper for any user-defined process function for huggingface datasets. 100 | 101 | 1. Process batched examples by looping the process function over them; 102 | 2. Gather returned examples if any data augmentation happens with augment=True; 103 | 3. Pass indices of examples inside the process function with _index keywords if they exist. 104 | 105 | The wrapped function should take in any needed columns and return a dict with 1 or more samples. 
106 | """ 107 | def __init__(self, augment=False): 108 | self.augment = augment 109 | 110 | def __call__(self, _process_fn): 111 | def process(*args): 112 | sample_or_batch_sample = args[0] 113 | if len(args) == 1: 114 | pass 115 | elif len(args) == 2: 116 | indices = args[1] 117 | # detach the slice so that _index will not be set in the original data 118 | sample_or_batch_sample = sample_or_batch_sample.copy() 119 | sample_or_batch_sample["_index"] = indices 120 | else: 121 | raise NotImplementedError(f"Found more than 2 arguments {args}!") 122 | 123 | keys = list(sample_or_batch_sample.keys()) 124 | func_args = [sample_or_batch_sample[k] for k in keys] 125 | 126 | # FIXME: if all values in one sample are of the same length, this would fail 127 | if are_elements_of_same_length(func_args): 128 | outputs = defaultdict(list) 129 | for arg in zip(*func_args): 130 | # get each element in a batch 131 | kwargs = {keys[j]: arg[j] for j in range(len(arg))} 132 | output = _process_fn(**kwargs) 133 | if output is not None: 134 | for k, v in output.items(): 135 | if self.augment: 136 | outputs[k].extend(v) 137 | else: 138 | outputs[k].append(v) 139 | else: 140 | outputs = _process_fn(**sample_or_batch_sample) 141 | if outputs is None: 142 | raise ValueError(f"Found None returned from process_fn. Make sure you set 'batched=True' when trying to augment/distract samples in the datasets!") 143 | return dict(outputs) 144 | return process 145 | 146 | 147 | @dataclass 148 | class DefaultDataCollator: 149 | """ 150 | Data collator that can: 151 | 1. Dynamically pad all inputs received. The inputs must be dict of lists. 152 | 2. Add position_ids based on attention_mask if required. 
153 | """ 154 | tokenizer: Any = None 155 | attention_padding_value: int = 0 156 | label_padding_value: int = -100 157 | add_position_ids: bool = False 158 | 159 | def __call__(self, batch_elem: List) -> Dict[str, Any]: 160 | first_elem = batch_elem[0] 161 | return_batch = {} 162 | 163 | for key, value in first_elem.items(): 164 | # HACK: any key containing attention_mask must be attention_mask 165 | # important to assign different pad token for different types of inputs 166 | if "attention_mask" in key: 167 | pad_token_id = self.attention_padding_value 168 | elif "label" in key: 169 | pad_token_id = self.label_padding_value 170 | else: 171 | pad_token_id = self.tokenizer.pad_token_id 172 | 173 | batch_value = [elem[key] for elem in batch_elem] 174 | # pad all lists and nested lists 175 | if isinstance(value, list): 176 | max_length = get_max_length_in_nested_lists(batch_value) 177 | batch_value, _ = pad_nested_lists(batch_value, max_length, pad_token_id, self.tokenizer.padding_side) 178 | 179 | try: 180 | return_batch[key] = torch.tensor(batch_value) 181 | except Exception: 182 | # handle strings and None 183 | return_batch[key] = batch_value 184 | 185 | if "attention_mask" in key and self.add_position_ids: 186 | value = return_batch[key] 187 | position_ids = value.cumsum(-1) - 1 188 | position_ids = position_ids.masked_fill(value == 0, 0) 189 | return_batch[key.replace("attention_mask", "position_ids")] = position_ids 190 | return return_batch 191 | 192 | 193 | class FileLogger: 194 | def __init__(self, log_file) -> None: 195 | self.log_file = log_file 196 | 197 | def log(self, metrics, **kwargs): 198 | with open(self.log_file, "a+") as f: 199 | # get current time 200 | tz = pytz.timezone('Asia/Shanghai') 201 | time = f"{'Time': <10}: {json.dumps(datetime.now(tz).strftime('%Y-%m-%d, %H:%M:%S'), ensure_ascii=False)}\n" 202 | print(time) 203 | command = f"{'Command': <10}: {json.dumps(' '.join(sys.argv), ensure_ascii=False)}\n" 204 | print(command) 205 | metrics =
f"{'Metrics': <10}: {json.dumps(metrics, ensure_ascii=False)}\n" 206 | msg = time + command 207 | 208 | for key, value in kwargs.items(): 209 | x = f"{key: <10}: {json.dumps(value, ensure_ascii=False)}\n" 210 | print(x) 211 | msg += x 212 | msg += metrics 213 | print(metrics) 214 | f.write(str(msg) + "\n") 215 | -------------------------------------------------------------------------------- /evaluation/qu-du-tasks/eval_sampling.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Sampler 3 | import torch.distributed as dist 4 | 5 | 6 | class DistributedEvalSampler(Sampler): 7 | r""" 8 | DistributedEvalSampler is different from DistributedSampler. 9 | It does NOT add extra samples to make it evenly divisible. 10 | DistributedEvalSampler should NOT be used for training. The distributed processes could hang forever. 11 | See this issue for details: https://github.com/pytorch/pytorch/issues/22584 12 | Shuffling is disabled by default. 13 | DistributedEvalSampler is for evaluation purposes where synchronization does not happen every epoch. 14 | Synchronization should be done outside the dataloader loop. 15 | Sampler that restricts data loading to a subset of the dataset. 16 | It is especially useful in conjunction with 17 | :class:`torch.nn.parallel.DistributedDataParallel`. In such a case, each 18 | process can pass a :class:`~torch.utils.data.DistributedSampler` instance as a 19 | :class:`~torch.utils.data.DataLoader` sampler, and load a subset of the 20 | original dataset that is exclusive to it. 21 | .. note:: 22 | Dataset is assumed to be of constant size. 23 | Arguments: 24 | dataset: Dataset used for sampling. 25 | num_replicas (int, optional): Number of processes participating in 26 | distributed training. By default, :attr:`world_size` is retrieved from the 27 | current distributed group. 28 | rank (int, optional): Rank of the current process within :attr:`num_replicas`.
29 | By default, :attr:`rank` is retrieved from the current distributed 30 | group. 31 | shuffle (bool, optional): If ``True``, sampler will shuffle the 32 | indices. Default: ``False``. 33 | seed (int, optional): random seed used to shuffle the sampler if 34 | :attr:`shuffle=True`. This number should be identical across all 35 | processes in the distributed group. Default: ``0``. 36 | .. warning:: 37 | In distributed mode, calling the :meth:`set_epoch` method at 38 | the beginning of each epoch **before** creating the :class:`DataLoader` iterator 39 | is necessary to make shuffling work properly across multiple epochs. Otherwise, 40 | the same ordering will always be used. 41 | Example:: 42 | >>> sampler = DistributedSampler(dataset) if is_distributed else None 43 | >>> loader = DataLoader(dataset, shuffle=(sampler is None), 44 | ... sampler=sampler) 45 | >>> for epoch in range(start_epoch, n_epochs): 46 | ... if is_distributed: 47 | ... sampler.set_epoch(epoch) 48 | ... train(loader) 49 | """ 50 | 51 | def __init__(self, dataset, num_replicas=None, rank=None, shuffle=False, seed=0): 52 | if num_replicas is None: 53 | if not dist.is_available(): 54 | raise RuntimeError("Requires distributed package to be available") 55 | num_replicas = dist.get_world_size() 56 | if rank is None: 57 | if not dist.is_available(): 58 | raise RuntimeError("Requires distributed package to be available") 59 | rank = dist.get_rank() 60 | self.dataset = dataset 61 | self.num_replicas = num_replicas 62 | self.rank = rank 63 | self.epoch = 0 64 | # self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas)) 65 | # self.total_size = self.num_samples * self.num_replicas 66 | self.total_size = len(self.dataset) # true value without extra samples 67 | indices = list(range(self.total_size)) 68 | indices = indices[self.rank:self.total_size:self.num_replicas] 69 | self.num_samples = len(indices) # true value without extra samples 70 | 71 | self.shuffle = shuffle 72 | self.seed =
seed 73 | 74 | def __iter__(self): 75 | if self.shuffle: 76 | # deterministically shuffle based on epoch and seed 77 | g = torch.Generator() 78 | g.manual_seed(self.seed + self.epoch) 79 | indices = torch.randperm(len(self.dataset), generator=g).tolist() 80 | else: 81 | indices = list(range(len(self.dataset))) 82 | 83 | 84 | # # add extra samples to make it evenly divisible 85 | # indices += indices[:(self.total_size - len(indices))] 86 | # assert len(indices) == self.total_size 87 | 88 | # subsample 89 | indices = indices[self.rank:self.total_size:self.num_replicas] 90 | assert len(indices) == self.num_samples 91 | 92 | return iter(indices) 93 | 94 | def __len__(self): 95 | return self.num_samples 96 | 97 | def set_epoch(self, epoch): 98 | r""" 99 | Sets the epoch for this sampler. When :attr:`shuffle=True`, this ensures all replicas 100 | use a different random ordering for each epoch. Otherwise, the next iteration of this 101 | sampler will yield the same ordering. 102 | Arguments: 103 | epoch (int): Epoch number.
104 | """ 105 | self.epoch = epoch -------------------------------------------------------------------------------- /evaluation/qu-du-tasks/inference_dataset.py: -------------------------------------------------------------------------------- 1 | import linecache 2 | import json 3 | import numpy as np 4 | from torch.utils.data import Dataset 5 | 6 | 7 | class InferenceDataset(Dataset): 8 | def __init__(self, filename, tokenizer, max_input_length): 9 | super(InferenceDataset, self).__init__() 10 | self._filename = filename 11 | self._tokenizer = tokenizer 12 | self._max_input_length = max_input_length 13 | with open(filename, "r", encoding="utf-8") as f: 14 | self._total_data = len(f.readlines()) 15 | 16 | def __getitem__(self, idx): 17 | line = linecache.getline(self._filename, idx + 1) 18 | sample = json.loads(line) 19 | 20 | source = sample["prompt"] 21 | source_encode = self._tokenizer(source, padding="max_length", max_length=self._max_input_length, truncation=True) 22 | 23 | batch = { 24 | "input_ids": np.asarray(source_encode.input_ids), 25 | "attention_mask": np.asarray(source_encode.attention_mask), 26 | "input": sample["prompt"], 27 | "label": sample["completion"] 28 | } 29 | 30 | return batch 31 | 32 | def __len__(self): 33 | return self._total_data -------------------------------------------------------------------------------- /evaluation/qu-du-tasks/inference_qu_du.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import json 4 | import numpy as np 5 | import torch 6 | import random 7 | import torch.multiprocessing as mp 8 | import torch.distributed as dist 9 | from transformers import AutoModelForCausalLM, AutoTokenizer 10 | from eval_sampling import DistributedEvalSampler 11 | from inference_dataset import InferenceDataset 12 | from torch.utils.data import DataLoader 13 | from inference_tasks.query_description import QueryDescription 14 | from inference_tasks.query_expansion 
import QueryExpansion 15 | from inference_tasks.query_suggestion import QuerySuggestion 16 | from inference_tasks.query_reformulation import QueryReformulation 17 | from inference_tasks.query_clarification import QueryClarification 18 | from inference_tasks.query_matching import QueryMatching 19 | from inference_tasks.summarization import Summarization 20 | from inference_tasks.fact_verification import FactVerification 21 | from inference_tasks.query_intent_classification import QueryIntentClassification 22 | from inference_tasks.reading_comprehension import ReadingComprehension 23 | from inference_tasks.query_subtopic_generation import QuerySubtopicGeneration 24 | from inference_tasks.conversational_qa import ConversationalQA 25 | import torch.nn.functional as F 26 | from datetime import datetime 27 | from tqdm import tqdm 28 | from tqdm.contrib import tzip 29 | 30 | parser = argparse.ArgumentParser() 31 | parser.add_argument("--model_name_or_path", default="model/", type=str) 32 | parser.add_argument("--tokenizer_name", default="model/llama-2-7b-hf", type=str) 33 | parser.add_argument("--setting", default="in-domain", type=str) 34 | parser.add_argument("--n_shots", default="zero_shot", type=str) 35 | parser.add_argument("--max_input_len", default=1792, type=int) 36 | parser.add_argument("--max_output_len", default=256, type=int) 37 | parser.add_argument("--seed", default=0, type=int) 38 | parser.add_argument('--nodes', type=int, default=1, help='number of nodes') 39 | parser.add_argument('--gpus', type=int, default=-1, help='number of GPUs per node (-1 uses all visible GPUs)') 40 | parser.add_argument('--nr', type=int, default=0, help='rank of this node among all nodes') 41 | args = parser.parse_args() 42 | 43 | tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, legacy=True) 44 | tokenizer.pad_token = tokenizer.eos_token 45 | tokenizer.padding_side = "left" 46 | tokenizer.truncation_side = "left" 47 | 48 | def set_seed(seed=args.seed): 49 | random.seed(seed) 50 | os.environ['PYTHONHASHSEED'] = str(seed) 51 |
np.random.seed(seed) 52 | torch.manual_seed(seed) 53 | torch.cuda.manual_seed(seed) 54 | torch.cuda.manual_seed_all(seed) 55 | torch.backends.cudnn.deterministic = True 56 | torch.backends.cudnn.benchmark = False 57 | 58 | def test_model(local_gpu_rank, args, tasks, label_lists, save_result=False): 59 | set_seed(args.seed) 60 | args.rank = args.nr * args.gpus + local_gpu_rank 61 | torch.cuda.set_device(local_gpu_rank) 62 | dist.init_process_group(backend='nccl', init_method='env://', world_size=args.world_size, rank=args.rank) 63 | 64 | model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path, use_cache=True, torch_dtype=torch.bfloat16) 65 | args.device = torch.device("cuda", local_gpu_rank) 66 | model.to(args.device) 67 | model.eval() 68 | 69 | current_time = datetime.now() 70 | time_string = current_time.strftime("%Y-%m-%d-%H-%M-%S") 71 | 72 | if args.rank == 0: 73 | task_result = [] 74 | fw = open(f"result/{time_string}.csv", "w", encoding="utf-8") 75 | fw.write(f"Model,{args.model_name_or_path}\n") 76 | for task, label_list in tzip(tasks, label_lists, ncols=120, desc="# Tasks"): 77 | test_data = task.get_path() 78 | 79 | test_dataset = InferenceDataset(test_data, tokenizer, args.max_input_len) 80 | test_sampler = DistributedEvalSampler(test_dataset, shuffle=False) 81 | test_dataloader = DataLoader(test_dataset, batch_size=args.test_batch_size, num_workers=2, sampler=test_sampler) 82 | 83 | all_decode_result = [] 84 | all_labels = [] 85 | count = 0 86 | with torch.no_grad(): 87 | if label_list == []: 88 | if args.rank == 0: 89 | test_dataloader = tqdm(test_dataloader, ncols=120, leave=False) 90 | for test_data in test_dataloader: 91 | outputs = model.generate( 92 | input_ids=test_data["input_ids"].to(args.device), 93 | attention_mask=test_data["attention_mask"].to(args.device), 94 | max_length=args.max_input_len + args.max_output_len, 95 | do_sample=False, 96 | pad_token_id=tokenizer.eos_token_id, 97 | ) 98 | if outputs.size(1) < args.max_input_len +
args.max_output_len: 99 | batch_pred_padding = torch.ones((outputs.size(0), args.max_input_len + args.max_output_len - outputs.size(1)), dtype=outputs.dtype).cuda() * 2 100 | outputs = torch.cat([outputs, batch_pred_padding], dim=1) 101 | 102 | batch_out_sentences = tokenizer.batch_decode(outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False) 103 | batch_output = [] 104 | for idx, output in enumerate(batch_out_sentences): 105 | output = output[len(test_data["input"][idx]):] 106 | batch_output.append(output.strip()) 107 | all_decode_result.extend(batch_output) 108 | all_labels.extend(test_data["label"]) 109 | if args.rank == 0: 110 | count += len(batch_output) * args.world_size 111 | else: 112 | if args.rank == 0: 113 | test_dataloader = tqdm(test_dataloader, ncols=120, leave=False) 114 | for test_data in test_dataloader: 115 | input_text = test_data["input"] # List[string] len=batch_size 116 | gold_label = test_data["label"] # List[string] len=batch_size 117 | 118 | new_reqs = [] 119 | for i in input_text: 120 | for l in label_list: 121 | context_enc = tokenizer.encode(i) 122 | continuation_enc = tokenizer.encode(i + l)[len(context_enc):] 123 | key = (i, l) 124 | value = (context_enc, continuation_enc) 125 | new_reqs.append((key, value)) 126 | 127 | input_texts = [x[0][0] + x[0][1] for x in new_reqs] 128 | inputs = tokenizer( 129 | input_texts, 130 | return_tensors='pt', 131 | padding="longest", 132 | max_length=args.max_input_len + args.max_output_len, 133 | truncation=True, 134 | return_attention_mask=True, 135 | ) 136 | 137 | logits = model(inputs['input_ids'].to(model.device), attention_mask=inputs['attention_mask'].to(model.device)) 138 | logits = F.log_softmax(logits[0], dim=-1).cpu() 139 | 140 | all_result = [] 141 | for one_req, one_logits in zip(new_reqs, logits): 142 | key, value = one_req 143 | _, cont_toks = value 144 | 145 | # Slice to original seq length 146 | contlen = len(cont_toks) 147 | one_logits = one_logits[-contlen-1 : 
-1].unsqueeze(0) # [1, seq, vocab] 148 | 149 | # Per-token argmax; only its length is compared with the continuation below 150 | greedy_tokens = one_logits.argmax(dim=-1) 151 | cont_toks = torch.tensor(cont_toks, dtype=torch.long).unsqueeze(0) # [1, seq] 152 | if len(greedy_tokens[0]) != len(cont_toks[0]): 153 | # Exceeds the maximum length limit 154 | answer = -float('inf') 155 | else: 156 | # Obtain log-probs at the corresponding continuation token indices 157 | # last_token_slice = logits[:, -1, :].squeeze(0).tolist() 158 | one_logits = torch.gather(one_logits, dim=2, index=cont_toks.unsqueeze(-1)).squeeze(-1) # [1, seq] 159 | 160 | # Answer: summed log-prob of the continuation tokens 161 | answer = float(one_logits.sum()) 162 | all_result.append(answer) 163 | all_result = np.asarray(all_result) 164 | all_result = all_result.reshape(-1, len(label_list)) # [bsz, num_labels] 165 | gold_label_list = [x.split(", ") for x in gold_label] 166 | gold_idx = [[label_list.index(x) for x in y] for y in gold_label_list] 167 | all_result = np.argmax(all_result, axis=1) 168 | 169 | all_decode_result.extend(all_result) 170 | all_labels.extend(gold_idx) 171 | if args.rank == 0: 172 | count += len(all_result) * args.world_size 173 | 174 | dist.barrier() 175 | gather_data = [None for _ in range(args.world_size)] 176 | gather_label = [None for _ in range(args.world_size)] 177 | dist.all_gather_object(gather_data, all_decode_result) 178 | dist.all_gather_object(gather_label, all_labels) 179 | if args.rank == 0: 180 | preds = [] 181 | labels = [] 182 | for j in range(len(gather_data[0])): 183 | for i in range(len(gather_data)): 184 | if j < len(gather_data[i]): 185 | prediction = gather_data[i][j] 186 | label = gather_label[i][j] 187 | preds.append(prediction) 188 | labels.append(label) 189 | if save_result: 190 | current_time = datetime.now() 191 | save_time_string = current_time.strftime("%Y-%m-%d-%H-%M-%S") 192 | with open(f"result/{save_time_string}.{task._cluster}_{task._name}.result.txt", "w", encoding="utf-8") as fr: 193 | for p, l in
zip(preds, labels): 194 | fr.write(json.dumps({"pred": p, "label": l}) + "\n") 195 | result = task.compute_metrics(preds, labels) 196 | tqdm.write(f"{task._cluster}_{task._name}:\t{json.dumps(result)}") 197 | task_result.append((f"{task._cluster}_{task._name}", result)) 198 | for k in result.keys(): 199 | fw.write(f"{task._cluster}_{task._name}, {k}, {result[k]}\n") 200 | fw.flush() 201 | if args.rank == 0: 202 | fw.close() 203 | 204 | if __name__ == '__main__': 205 | if args.gpus < 0: 206 | args.gpus = torch.cuda.device_count() 207 | args.world_size = args.nodes * args.gpus 208 | os.environ['MASTER_ADDR']='localhost' 209 | os.environ['MASTER_PORT']='8889' 210 | args.test_batch_size = 2 211 | tasks = [] 212 | label_lists = [] 213 | 214 | # query description 215 | task = QueryDescription(name="gov2", shot=args.n_shots, setting=args.setting) 216 | tasks.append(task) 217 | label_lists.append([]) 218 | task = QueryDescription(name="trec_robust", shot=args.n_shots, setting=args.setting) 219 | tasks.append(task) 220 | label_lists.append([]) 221 | task = QueryDescription(name="trec_covid", shot=args.n_shots, setting=args.setting) 222 | tasks.append(task) 223 | label_lists.append([]) 224 | task = QueryDescription(name="fire", shot=args.n_shots, setting=args.setting) 225 | tasks.append(task) 226 | label_lists.append([]) 227 | 228 | # query expansion 229 | task = QueryExpansion(name="gov2", shot=args.n_shots, setting=args.setting) 230 | tasks.append(task) 231 | label_lists.append([]) 232 | task = QueryExpansion(name="trec_robust", shot=args.n_shots, setting=args.setting) 233 | tasks.append(task) 234 | label_lists.append([]) 235 | task = QueryExpansion(name="trec_covid", shot=args.n_shots, setting=args.setting) 236 | tasks.append(task) 237 | label_lists.append([]) 238 | task = QueryExpansion(name="fire", shot=args.n_shots, setting=args.setting) 239 | tasks.append(task) 240 | label_lists.append([]) 241 | task = QueryExpansion(name="query2doc", shot=args.n_shots,
setting=args.setting) 242 | tasks.append(task) 243 | label_lists.append([]) 244 | task = QueryExpansion(name="trec_cast", shot=args.n_shots, setting=args.setting) 245 | tasks.append(task) 246 | label_lists.append([]) 247 | task = QueryExpansion(name="trec_web", shot=args.n_shots, setting=args.setting) 248 | tasks.append(task) 249 | label_lists.append([]) 250 | 251 | # query reformulation 252 | task = QueryReformulation(name="codec", shot=args.n_shots, setting=args.setting) 253 | tasks.append(task) 254 | label_lists.append([]) 255 | task = QueryReformulation(name="qrecc", shot=args.n_shots, setting=args.setting) 256 | tasks.append(task) 257 | label_lists.append([]) 258 | task = QueryReformulation(name="canard", shot=args.n_shots, setting=args.setting) 259 | tasks.append(task) 260 | label_lists.append([]) 261 | task = QueryReformulation(name="trec_cast", shot=args.n_shots, setting=args.setting) 262 | tasks.append(task) 263 | label_lists.append([]) 264 | task = QueryReformulation(name="gecor", shot=args.n_shots, setting=args.setting) 265 | tasks.append(task) 266 | label_lists.append([]) 267 | 268 | # query clarification 269 | task = QueryClarification(name="mimics", shot=args.n_shots, setting=args.setting) 270 | tasks.append(task) 271 | label_lists.append([]) 272 | task = QueryClarification(name="mimics_duo", shot=args.n_shots, setting=args.setting) 273 | tasks.append(task) 274 | label_lists.append([]) 275 | task = QueryClarification(name="clariq_fkw", shot=args.n_shots, setting=args.setting) 276 | tasks.append(task) 277 | label_lists.append([]) 278 | task = QueryClarification(name="raocq", shot=args.n_shots, setting=args.setting) 279 | tasks.append(task) 280 | label_lists.append([]) 281 | 282 | # query subtopic generation 283 | task = QuerySubtopicGeneration(name="trec_web", shot=args.n_shots, setting=args.setting) 284 | tasks.append(task) 285 | label_lists.append([]) 286 | 287 | # query suggestion 288 | task = QuerySuggestion(name="aol", shot=args.n_shots, 
setting=args.setting) 289 | tasks.append(task) 290 | label_lists.append([]) 291 | 292 | # query matching 293 | task = QueryMatching(name="msrp", shot=args.n_shots, setting=args.setting) 294 | tasks.append(task) 295 | label_lists.append(["yes", "no"]) 296 | 297 | # query intent classification 298 | task = QueryIntentClassification(name="mantis", shot=args.n_shots, setting=args.setting, multi_label=True) 299 | tasks.append(task) 300 | label_lists.append(["original question", "further details", "other", "information request", "potential answer", "positive feedback", "negative feedback", "greetings / gratitude", "follow up question"]) 301 | 302 | task = QueryIntentClassification(name="orcas_i", shot=args.n_shots, setting=args.setting) 303 | tasks.append(task) 304 | label_lists.append(["factual", "abstain", "instrumental", "transactional", "navigational"]) 305 | task = QueryIntentClassification(name="trec_web", shot=args.n_shots, setting=args.setting) 306 | tasks.append(task) 307 | label_lists.append(["faceted", "ambiguous", "navigational", "informational"]) 308 | 309 | # fact verification 310 | task = FactVerification(name="fever", shot=args.n_shots, setting=args.setting) 311 | tasks.append(task) 312 | label_lists.append(["support", "refute"]) 313 | task = FactVerification(name="climate_fever", shot=args.n_shots, setting=args.setting) 314 | tasks.append(task) 315 | label_lists.append(["support", "refute", "disputed", "not enough information"]) 316 | task = FactVerification(name="scifact", shot=args.n_shots, setting=args.setting) 317 | tasks.append(task) 318 | label_lists.append(["support", "refute"]) 319 | 320 | # conversational qa 321 | task = ConversationalQA(name="coqa", shot=args.n_shots, setting=args.setting) 322 | tasks.append(task) 323 | label_lists.append([]) 324 | task = ConversationalQA(name="quac", shot=args.n_shots, setting=args.setting) 325 | tasks.append(task) 326 | label_lists.append([]) 327 | 328 | # summarization 329 | task = 
Summarization(name="cnndm", shot=args.n_shots, setting=args.setting) 330 | tasks.append(task) 331 | label_lists.append([]) 332 | task = Summarization(name="xsum", shot=args.n_shots, setting=args.setting) 333 | tasks.append(task) 334 | label_lists.append([]) 335 | task = Summarization(name="wikisum", shot=args.n_shots, setting=args.setting) 336 | tasks.append(task) 337 | label_lists.append([]) 338 | task = Summarization(name="multinews", shot=args.n_shots, setting=args.setting) 339 | tasks.append(task) 340 | label_lists.append([]) 341 | 342 | # reading comprehension 343 | task = ReadingComprehension(name="squad", shot=args.n_shots, setting=args.setting) 344 | tasks.append(task) 345 | label_lists.append([]) 346 | task = ReadingComprehension(name="hotpot_qa", shot=args.n_shots, setting=args.setting) 347 | tasks.append(task) 348 | label_lists.append([]) 349 | task = ReadingComprehension(name="ms_marco", shot=args.n_shots, setting=args.setting) 350 | tasks.append(task) 351 | label_lists.append([]) 352 | task = ReadingComprehension(name="boolq", shot=args.n_shots, setting=args.setting) 353 | tasks.append(task) 354 | label_lists.append(["true", "false"]) 355 | task = ReadingComprehension(name="webglm_qa", shot=args.n_shots, setting=args.setting) 356 | tasks.append(task) 357 | label_lists.append([]) 358 | task = ReadingComprehension(name="trivia_qa", shot=args.n_shots, setting=args.setting) 359 | tasks.append(task) 360 | label_lists.append([]) 361 | 362 | assert len(tasks) == len(label_lists) 363 | mp.spawn(test_model, nprocs=args.gpus, args=(args, tasks, label_lists, False)) 364 | -------------------------------------------------------------------------------- /evaluation/qu-du-tasks/inference_tasks/conversational_qa.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import re 3 | 4 | class ConversationalQA: 5 | def __init__(self, name, shot, setting): 6 | self._cluster = "conversational_qa" 7 | self._name = 
name 8 | self._shot = shot 9 | self._setting = setting 10 | 11 | def get_path(self): 12 | return f"data/{self._setting}/{self._shot}/{self._cluster}_{self._name}.{self._shot}.test.jsonl" 13 | 14 | def compute_metrics(self, preds, labels): 15 | def normalize_results(text): 16 | pattern = r"(\d+\.|\[\d+\]|\(\d+\))\s*(.*?)\s*(?=\d\.|\[\d\]|\(\d\)|$)" 17 | extracted_text = re.findall(pattern, text) 18 | return extracted_text 19 | 20 | all_acc = [] 21 | for prediction, label in zip(preds, labels): 22 | extracted_answer = normalize_results(label) 23 | extracted_prediction = normalize_results(prediction) 24 | if extracted_answer == []: 25 | extracted_answer = label 26 | answer_dict = {0: extracted_answer} 27 | else: 28 | answer_dict = {match[0]: match[1] for match in extracted_answer} 29 | if extracted_prediction == []: 30 | extracted_prediction = prediction 31 | extracted_dict = {0: extracted_prediction} 32 | else: 33 | extracted_dict = {match[0]: match[1] for match in extracted_prediction} 34 | 35 | acc = 0 36 | for k in extracted_dict.keys(): 37 | if k not in answer_dict: 38 | continue 39 | if extracted_dict[k] == answer_dict[k]: 40 | acc += 1 41 | acc = acc / len(answer_dict.keys()) 42 | all_acc.append(acc) 43 | all_acc = np.asarray(all_acc) 44 | results = { 45 | "Acc": np.mean(all_acc) 46 | } 47 | return results -------------------------------------------------------------------------------- /evaluation/qu-du-tasks/inference_tasks/fact_verification.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn.metrics import f1_score, accuracy_score 3 | 4 | class FactVerification(): 5 | def __init__(self, name, shot, setting): 6 | self._cluster = "fact_verification" 7 | self._name = name 8 | self._shot = shot 9 | self._setting = setting 10 | 11 | def get_path(self): 12 | return f"data/{self._setting}/{self._shot}/{self._cluster}_{self._name}.{self._shot}.test.jsonl" 13 | 14 | def compute_metrics(self, 
preds, labels): 15 | labels = np.asarray([x[0] for x in labels]) 16 | acc = accuracy_score(labels, preds) 17 | f1 = f1_score(labels, preds, average='weighted') 18 | 19 | results = { 20 | "Acc": acc, 21 | "F1": f1 22 | } 23 | 24 | return results 25 | 26 | -------------------------------------------------------------------------------- /evaluation/qu-du-tasks/inference_tasks/query_clarification.py: -------------------------------------------------------------------------------- 1 | from rouge_score import rouge_scorer 2 | from nltk.translate import bleu_score 3 | from nltk.translate.bleu_score import corpus_bleu 4 | import numpy as np 5 | import re 6 | 7 | class QueryClarification: 8 | def __init__(self, name, shot, setting): 9 | self._cluster = "query_clarification" 10 | self._name = name 11 | self._shot = shot 12 | self._setting = setting 13 | 14 | def get_path(self): 15 | return f"data/{self._setting}/{self._shot}/{self._cluster}_{self._name}.{self._shot}.test.jsonl" 16 | 17 | def compute_metrics(self, preds, labels): 18 | def normalize_results(text): 19 | pattern = r"(?:\d\.|\[\d\]|\(\d\))\s*(.*?)\s*(?=\d\.|\[\d\]|\(\d\)|$)" 20 | extracted_text = re.findall(pattern, text) 21 | return extracted_text 22 | 23 | if self._name == "clariq_fkw" or self._name == "raocq": 24 | result_rouge_scorer = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True) 25 | all_rouge_l = [] 26 | split_labels = [[x.split()] for x in labels] 27 | split_preds = [x.split() for x in preds] 28 | bleu1 = corpus_bleu(split_labels, split_preds, weights=(1, 0, 0, 0), smoothing_function=bleu_score.SmoothingFunction().method4) 29 | bleu2 = corpus_bleu(split_labels, split_preds, weights=(0.5, 0.5, 0, 0), smoothing_function=bleu_score.SmoothingFunction().method4) 30 | for prediction, label in zip(preds, labels): 31 | scores = result_rouge_scorer.score(label, prediction) 32 | rougel = scores['rougeL'].fmeasure 33 | all_rouge_l.append(rougel) 34 | avg_rouge_l = np.mean(np.asarray(all_rouge_l)) 35 | 
results = { 36 | "BLEU-1": bleu1, 37 | "BLEU-2": bleu2, 38 | "ROUGE-L": avg_rouge_l, 39 | } 40 | return results 41 | else: 42 | all_p, all_r, all_f1 = [], [], [] 43 | for prediction, label in zip(preds, labels): 44 | label = label.split(" <=SEP=> ") 45 | extracted_prediction = normalize_results(prediction) 46 | if extracted_prediction == []: 47 | extracted_prediction = [prediction] 48 | TP = len(set(extracted_prediction).intersection(label)) 49 | FP = len(set(extracted_prediction) - set(label)) 50 | FN = len(set(label) - set(extracted_prediction)) 51 | precision = TP / (TP + FP) if (TP + FP) != 0 else 0 52 | recall = TP / (TP + FN) if (TP + FN) != 0 else 0 53 | f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) != 0 else 0 54 | all_p.append(precision) 55 | all_r.append(recall) 56 | all_f1.append(f1) 57 | all_p = np.asarray(all_p) 58 | all_r = np.asarray(all_r) 59 | all_f1 = np.asarray(all_f1) 60 | results = { 61 | "precision": np.sum(all_p) / len(all_p), 62 | "recall": np.sum(all_r) / len(all_r), 63 | "f1": np.sum(all_f1) / len(all_f1), 64 | } 65 | return results -------------------------------------------------------------------------------- /evaluation/qu-du-tasks/inference_tasks/query_description.py: -------------------------------------------------------------------------------- 1 | from rouge_score import rouge_scorer 2 | import numpy as np 3 | from nltk.translate.bleu_score import corpus_bleu 4 | from nltk.translate import bleu_score 5 | 6 | class QueryDescription(): 7 | def __init__(self, name, shot, setting): 8 | self._cluster = "query_description" 9 | self._name = name 10 | self._shot = shot 11 | self._setting = setting 12 | 13 | 14 | def get_path(self): 15 | return f"data/{self._setting}/{self._shot}/{self._cluster}_{self._name}.{self._shot}.test.jsonl" 16 | 17 | def compute_metrics(self, preds, labels): 18 | result_rouge_scorer = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True) 19 | all_rouge_l = [] 20 | split_labels = 
[[x.split()] for x in labels] 21 | split_preds = [x.split() for x in preds] 22 | bleu1 = corpus_bleu(split_labels, split_preds, weights=(1, 0, 0, 0), smoothing_function=bleu_score.SmoothingFunction().method4) 23 | bleu2 = corpus_bleu(split_labels, split_preds, weights=(0.5, 0.5, 0, 0), smoothing_function=bleu_score.SmoothingFunction().method4) 24 | for prediction, label in zip(preds, labels): 25 | scores = result_rouge_scorer.score(label, prediction) 26 | rougel = scores['rougeL'].fmeasure 27 | all_rouge_l.append(rougel) 28 | avg_rouge_l = np.mean(np.asarray(all_rouge_l)) 29 | results = { 30 | "BLEU-1": bleu1, 31 | "BLEU-2": bleu2, 32 | "ROUGE-L": avg_rouge_l, 33 | } 34 | return results -------------------------------------------------------------------------------- /evaluation/qu-du-tasks/inference_tasks/query_expansion.py: -------------------------------------------------------------------------------- 1 | from rouge_score import rouge_scorer 2 | import numpy as np 3 | from nltk.translate.bleu_score import corpus_bleu 4 | from nltk.translate import bleu_score 5 | 6 | class QueryExpansion(): 7 | def __init__(self, name, shot, setting): 8 | self._cluster = "query_expansion" 9 | self._name = name 10 | self._shot = shot 11 | self._setting = setting 12 | 13 | 14 | def get_path(self): 15 | return f"data/{self._setting}/{self._shot}/{self._cluster}_{self._name}.{self._shot}.test.jsonl" 16 | 17 | def compute_metrics(self, preds, labels): 18 | result_rouge_scorer = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True) 19 | all_rouge_l = [] 20 | split_labels = [[x.split()] for x in labels] 21 | split_preds = [x.split() for x in preds] 22 | bleu1 = corpus_bleu(split_labels, split_preds, weights=(1, 0, 0, 0), smoothing_function=bleu_score.SmoothingFunction().method4) 23 | bleu2 = corpus_bleu(split_labels, split_preds, weights=(0.5, 0.5, 0, 0), smoothing_function=bleu_score.SmoothingFunction().method4) 24 | for prediction, label in zip(preds, labels): 25 | scores = 
result_rouge_scorer.score(label, prediction) 26 | rougel = scores['rougeL'].fmeasure 27 | all_rouge_l.append(rougel) 28 | avg_rouge_l = np.mean(np.asarray(all_rouge_l)) 29 | results = { 30 | "BLEU-1": bleu1, 31 | "BLEU-2": bleu2, 32 | "ROUGE-L": avg_rouge_l, 33 | } 34 | return results -------------------------------------------------------------------------------- /evaluation/qu-du-tasks/inference_tasks/query_intent_classification.py: -------------------------------------------------------------------------------- 1 | from sklearn.metrics import f1_score, accuracy_score 2 | 3 | class QueryIntentClassification(): 4 | def __init__(self, name, shot, setting, multi_label=False): 5 | self._cluster = "query_intent_classification" 6 | self._name = name 7 | self._shot = shot 8 | self._setting = setting 9 | self._multi_label = multi_label 10 | 11 | def get_path(self): 12 | return f"data/{self._setting}/{self._shot}/{self._cluster}_{self._name}.{self._shot}.test.jsonl" 13 | 14 | def compute_metrics(self, preds, labels): 15 | # accuracy = np.sum(np.asarray(preds) == np.asarray(labels)) / len(preds) 16 | if self._multi_label: 17 | p = 0 18 | for pred, label in zip(preds, labels): 19 | if pred in label: 20 | p += 1 21 | results = { 22 | "P@1": p / len(preds) 23 | } 24 | else: 25 | acc = accuracy_score(labels, preds) 26 | f1 = f1_score(labels, preds, average='weighted') 27 | results = { 28 | "Acc": acc, 29 | "F1": f1 30 | } 31 | 32 | return results 33 | -------------------------------------------------------------------------------- /evaluation/qu-du-tasks/inference_tasks/query_matching.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn.metrics import f1_score, accuracy_score 3 | 4 | class QueryMatching(): 5 | def __init__(self, name, shot, setting): 6 | self._cluster = "query_matching" 7 | self._name = name 8 | self._shot = shot 9 | self._setting = setting 10 | 11 | def get_path(self): 12 | return 
f"data/{self._setting}/{self._shot}/{self._cluster}_{self._name}.{self._shot}.test.jsonl" 13 | 14 | def compute_metrics(self, preds, labels): 15 | # accuracy = np.sum(np.asarray(preds) == np.asarray(labels)) / len(preds) 16 | labels = np.asarray([x[0] for x in labels]) 17 | acc = accuracy_score(labels, preds) 18 | f1 = f1_score(labels, preds, average='weighted') 19 | 20 | results = { 21 | "Acc": acc, 22 | "F1": f1 23 | } 24 | 25 | return results 26 | 27 | -------------------------------------------------------------------------------- /evaluation/qu-du-tasks/inference_tasks/query_reformulation.py: -------------------------------------------------------------------------------- 1 | from rouge_score import rouge_scorer 2 | import numpy as np 3 | from nltk.translate.bleu_score import corpus_bleu 4 | from nltk.translate import bleu_score 5 | import re 6 | 7 | class QueryReformulation(): 8 | def __init__(self, name, shot, setting): 9 | self._cluster = "query_reformulation" 10 | self._name = name 11 | self._shot = shot 12 | self._setting = setting 13 | 14 | def get_path(self): 15 | return f"data/{self._setting}/{self._shot}/{self._cluster}_{self._name}.{self._shot}.test.jsonl" 16 | 17 | def compute_metrics(self, preds, labels): 18 | def normalize_results(text): 19 | pattern = r"(?:\d\.|\[\d\]|\(\d\))\s*(.*?)\s*(?=\d\.|\[\d\]|\(\d\)|$)" 20 | extracted_text = re.findall(pattern, text) 21 | return extracted_text 22 | 23 | if self._name == "codec": 24 | all_p, all_r, all_f1 = [], [], [] 25 | for prediction, label in zip(preds, labels): 26 | label = normalize_results(label) 27 | extracted_prediction = normalize_results(prediction) 28 | if extracted_prediction == []: 29 | extracted_prediction = [prediction] 30 | TP = len(set(extracted_prediction).intersection(label)) 31 | FP = len(set(extracted_prediction) - set(label)) 32 | FN = len(set(label) - set(extracted_prediction)) 33 | precision = TP / (TP + FP) if (TP + FP) != 0 else 0 34 | recall = TP / (TP + FN) 
if (TP + FN) != 0 
else 0 35 | f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) != 0 else 0 36 | all_p.append(precision) 37 | all_r.append(recall) 38 | all_f1.append(f1) 39 | all_p = np.asarray(all_p) 40 | all_r = np.asarray(all_r) 41 | all_f1 = np.asarray(all_f1) 42 | results = { 43 | "precision": np.mean(all_p), 44 | "recall": np.mean(all_r), 45 | "f1": np.mean(all_f1), 46 | } 47 | else: 48 | result_rouge_scorer = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True) 49 | all_rouge_l = [] 50 | 51 | split_labels = [[x.split()] for x in labels] 52 | split_preds = [x.split() for x in preds] 53 | bleu1 = corpus_bleu(split_labels, split_preds, weights=(1, 0, 0, 0), smoothing_function=bleu_score.SmoothingFunction().method4) 54 | bleu2 = corpus_bleu(split_labels, split_preds, weights=(0.5, 0.5, 0, 0), smoothing_function=bleu_score.SmoothingFunction().method4) 55 | 56 | for prediction, label in zip(preds, labels): 57 | scores = result_rouge_scorer.score(label, prediction) 58 | rougel = scores['rougeL'].fmeasure 59 | all_rouge_l.append(rougel) 60 | avg_rouge_l = np.mean(np.asarray(all_rouge_l)) 61 | results = { 62 | "BLEU-1": bleu1, 63 | "BLEU-2": bleu2, 64 | "ROUGE-L": avg_rouge_l, 65 | } 66 | return results -------------------------------------------------------------------------------- /evaluation/qu-du-tasks/inference_tasks/query_subtopic_generation.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import re 3 | 4 | class QuerySubtopicGeneration(): 5 | def __init__(self, name, shot, setting): 6 | self._cluster = "query_subtopic_generation" 7 | self._name = name 8 | self._shot = shot 9 | self._setting = setting 10 | 11 | def get_path(self): 12 | return f"data/{self._setting}/{self._shot}/{self._cluster}_{self._name}.{self._shot}.test.jsonl" 13 | 14 | def compute_metrics(self, preds, labels): 15 | def normalize_results(text): 16 | pattern = 
r"(?:\d\.|\[\d\]|\(\d\))\s*(.*?)\s*(?=\d\.|\[\d\]|\(\d\)|$)" 17 | extracted_text = re.findall(pattern, text) 18 | return extracted_text 19 | 20 | all_p, all_r, all_f1 = [], [], [] 21 | for prediction, label in zip(preds, labels): 22 | label = normalize_results(label) 23 | extracted_prediction = normalize_results(prediction) 24 | if extracted_prediction == []: 25 | extracted_prediction = [prediction] 26 | TP = len(set(extracted_prediction).intersection(label)) 27 | FP = len(set(extracted_prediction) - set(label)) 28 | FN = len(set(label) - set(extracted_prediction)) 29 | precision = TP / (TP + FP) if (TP + FP) != 0 else 0 30 | recall = TP / (TP + FN) if (TP + FN) != 0 else 0 31 | f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) != 0 else 0 32 | all_p.append(precision) 33 | all_r.append(recall) 34 | all_f1.append(f1) 35 | all_p = np.asarray(all_p) 36 | all_r = np.asarray(all_r) 37 | all_f1 = np.asarray(all_f1) 38 | results = { 39 | "precision": np.sum(all_p) / len(all_p), 40 | "recall": np.sum(all_r) / len(all_r), 41 | "f1": np.sum(all_f1) / len(all_f1), 42 | } 43 | return results -------------------------------------------------------------------------------- /evaluation/qu-du-tasks/inference_tasks/query_suggestion.py: -------------------------------------------------------------------------------- 1 | from rouge_score import rouge_scorer 2 | import numpy as np 3 | from nltk.translate.bleu_score import corpus_bleu 4 | from nltk.translate import bleu_score 5 | 6 | class QuerySuggestion(): 7 | def __init__(self, name, shot, setting): 8 | self._cluster = "query_suggestion" 9 | self._name = name 10 | self._shot = shot 11 | self._setting = setting 12 | 13 | 14 | def get_path(self): 15 | return f"data/{self._setting}/{self._shot}/{self._cluster}_{self._name}.{self._shot}.test.jsonl" 16 | 17 | def compute_metrics(self, preds, labels): 18 | result_rouge_scorer = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True) 19 | all_rouge_l = [] 20 | 
split_labels = [[x.split()] for x in labels] 21 | split_preds = [x.split() for x in preds] 22 | bleu1 = corpus_bleu(split_labels, split_preds, weights=(1, 0, 0, 0), smoothing_function=bleu_score.SmoothingFunction().method4) 23 | bleu2 = corpus_bleu(split_labels, split_preds, weights=(0.5, 0.5, 0, 0), smoothing_function=bleu_score.SmoothingFunction().method4) 24 | for prediction, label in zip(preds, labels): 25 | scores = result_rouge_scorer.score(label, prediction) 26 | rougel = scores['rougeL'].fmeasure 27 | all_rouge_l.append(rougel) 28 | avg_rouge_l = np.mean(np.asarray(all_rouge_l)) 29 | results = { 30 | "BLEU-1": bleu1, 31 | "BLEU-2": bleu2, 32 | "ROUGE-L": avg_rouge_l, 33 | } 34 | return results -------------------------------------------------------------------------------- /evaluation/qu-du-tasks/inference_tasks/reading_comprehension.py: -------------------------------------------------------------------------------- 1 | 2 | import re 3 | import string 4 | from collections import Counter 5 | from typing import List 6 | import numpy as np 7 | from rouge_score import rouge_scorer 8 | from sklearn.metrics import f1_score, accuracy_score 9 | from nltk.translate.bleu_score import corpus_bleu 10 | from nltk.translate import bleu_score 11 | 12 | 13 | class ReadingComprehension(): 14 | def __init__(self, name, shot, setting): 15 | self._cluster = "reading_comprehension" 16 | self._name = name 17 | self._shot = shot 18 | self._setting = setting 19 | 20 | def get_path(self): 21 | return f"data/{self._setting}/{self._shot}/{self._cluster}_{self._name}.{self._shot}.test.jsonl" 22 | 23 | def compute_metrics(self, preds, labels): 24 | 25 | def normalize_text(s): 26 | 27 | def remove_articles(text): 28 | return re.sub(r'\b(a|an|the)\b', ' ', text) 29 | 30 | def white_space_fix(text): 31 | return ' '.join(text.split()) 32 | 33 | def remove_punc(text): 34 | exclude = set(string.punctuation) 35 | return ''.join(ch for ch in text if ch not in exclude) 36 | 37 | def 
lower(text): 38 | return text.lower() 39 | 40 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 41 | 42 | 43 | def calc_exact_match(text: str, answers: List[str]) -> bool: 44 | """Check if prediction is exactly the same as any of the answers.""" 45 | norm_text = normalize_text(text) 46 | norm_answers = [normalize_text(ans) for ans in answers] 47 | return max([(norm_text == norm_ans) for norm_ans in norm_answers]) 48 | 49 | def calc_soft_exact_match(text: str, answers: List[str]) -> bool: 50 | norm_text = normalize_text(text) 51 | norm_answers = [normalize_text(ans) for ans in answers] 52 | return max([(norm_ans in norm_text) for norm_ans in norm_answers]) 53 | 54 | def calc_unigram_f1(text: str, answers: List[str]) -> float: 55 | """Calculate token-level unigram F1 between the text and reference answers.""" 56 | norm_pred = normalize_text(text).split() 57 | norm_answers = [normalize_text(ans).split() for ans in answers] 58 | common_tokens = [ 59 | Counter(norm_pred) & Counter(norm_ans) for norm_ans in norm_answers 60 | ] 61 | num_same = [sum(common.values()) for common in common_tokens] 62 | 63 | score_list = [] 64 | for i, num in enumerate(num_same): 65 | if num == 0: 66 | score_list.append(0.0) 67 | else: 68 | p = 1.0 * num / len(norm_pred) 69 | r = 1.0 * num / len(norm_answers[i]) 70 | f1 = 2 * p * r / (p + r) 71 | score_list.append(f1) 72 | return max(score_list) 73 | 74 | if self._name == "boolq": 75 | # accuracy = np.sum(np.asarray(preds) == np.asarray(labels)) / len(preds) 76 | labels = np.asarray([x[0] for x in labels]) 77 | acc = accuracy_score(labels, preds) 78 | f1 = f1_score(labels, preds, average='weighted') 79 | 80 | results = { 81 | "Acc": acc, 82 | "F1": f1 83 | } 84 | 85 | elif self._name == "webglm_qa": 86 | result_rouge_scorer = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True) 87 | all_rouge_l = [] 88 | split_labels = [[x.split()] for x in labels] 89 | split_preds = [x.split() for x in preds] 90 | bleu1 = 
corpus_bleu(split_labels, split_preds, 
weights=(1, 0, 0, 0), smoothing_function=bleu_score.SmoothingFunction().method4) 91 | bleu2 = corpus_bleu(split_labels, split_preds, weights=(0.5, 0.5, 0, 0), smoothing_function=bleu_score.SmoothingFunction().method4) 92 | for prediction, label in zip(preds, labels): 93 | scores = result_rouge_scorer.score(label, prediction) 94 | rougel = scores['rougeL'].fmeasure 95 | all_rouge_l.append(rougel) 96 | avg_rouge_l = np.mean(np.asarray(all_rouge_l)) 97 | results = { 98 | "BLEU-1": bleu1, 99 | "BLEU-2": bleu2, 100 | "ROUGE-L": avg_rouge_l, 101 | } 102 | return results 103 | else: 104 | em_scores = [calc_exact_match(pred, label.split(" <=SEP=> ")) for pred, label in zip(preds, labels)] 105 | f1_scores = [calc_unigram_f1(pred, label.split(" <=SEP=> ")) for pred, label in zip(preds, labels)] 106 | soft_em_scores = [calc_soft_exact_match(pred, label.split(" <=SEP=> ")) for pred, label in zip(preds, labels)] 107 | 108 | results = { 109 | # "EM": sum(em_scores) / len(em_scores), 110 | "F1": sum(f1_scores) / len(f1_scores), 111 | # "Soft-EM": sum(soft_em_scores) / len(soft_em_scores), 112 | } 113 | 114 | return results -------------------------------------------------------------------------------- /evaluation/qu-du-tasks/inference_tasks/summarization.py: -------------------------------------------------------------------------------- 1 | from rouge_score import rouge_scorer 2 | import numpy as np 3 | 4 | class Summarization(): 5 | def __init__(self, name, shot, setting): 6 | self._cluster = "summarization" 7 | self._name = name 8 | self._shot = shot 9 | self._setting = setting 10 | 11 | 12 | def get_path(self): 13 | return f"data/{self._setting}/{self._shot}/{self._cluster}_{self._name}.{self._shot}.test.jsonl" 14 | 15 | def compute_metrics(self, preds, labels): 16 | result_rouge_scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True) 17 | all_rouge_l, all_rouge_1, all_rouge_2 = [], [], [] 18 | for prediction, label in zip(preds, labels): 19 | scores 
= result_rouge_scorer.score(label, prediction) 20 | rouge1 = scores['rouge1'].fmeasure 21 | rouge2 = scores['rouge2'].fmeasure 22 | rougel = scores['rougeL'].fmeasure 23 | all_rouge_1.append(rouge1) 24 | all_rouge_2.append(rouge2) 25 | all_rouge_l.append(rougel) 26 | avg_rouge_l = np.mean(np.asarray(all_rouge_l)) 27 | avg_rouge_1 = np.mean(np.asarray(all_rouge_1)) 28 | avg_rouge_2 = np.mean(np.asarray(all_rouge_2)) 29 | results = { 30 | "ROUGE-1": avg_rouge_1, 31 | "ROUGE-2": avg_rouge_2, 32 | "ROUGE-L": avg_rouge_l, 33 | } 34 | return results -------------------------------------------------------------------------------- /evaluation/readme.md: -------------------------------------------------------------------------------- 1 | ## Required packages 2 | ``` 3 | torch 2.0.0 4 | transformers 4.36.2 5 | numpy 1.26.3 6 | tqdm 4.66.1 7 | scikit-learn 1.4.0 8 | rouge_score 0.1.2 9 | nltk 3.8.1 10 | accelerate 0.26.1 11 | ``` 12 | 13 | ## For query understanding tasks and document understanding tasks (qu-du-tasks) 14 | This evaluation script uses PyTorch DDP for text generation. 15 | 16 | 1. Download [test data](https://huggingface.co/datasets/yutaozhu94/INTERS/tree/main/test-qu-du-zero-shot) and save it to ``data/in-domain/zero_shot/``. The directory structure should look like this: 17 | ``` 18 | qu-du-tasks 19 | ├── eval_sampling.py 20 | ├── inference_dataset.py 21 | ├── inference_qu_du.py 22 | ├── inference_tasks 23 | │ ├── conversational_qa.py 24 | │ ├── fact_verification.py 25 | │ └── ... 26 | └── data 27 | └── in-domain 28 | └── zero_shot 29 | ├── conversational_qa_coqa.zero_shot.test.jsonl 30 | ├── conversational_qa_quac.zero_shot.test.jsonl 31 | ├── fact_verification_climate_fever.zero_shot.test.jsonl 32 | ├── fact_verification_fever.zero_shot.test.jsonl 33 | ├── fact_verification_scifact.zero_shot.test.jsonl 34 | └── ... 35 | ``` 36 | 2. 
If you place the test files in a different directory, modify the path in each task file under the ``inference_tasks`` directory (in the ``get_path()`` function). 37 | 38 | 3. Run the evaluation: 39 | ``` 40 | TOKENIZERS_PARALLELISM=True python3 inference_qu_du.py \ 41 | --model_name_or_path your/model/path \ 42 | --tokenizer_name your/tokenizer/path \ 43 | --setting in-domain \ 44 | --n_shots zero_shot 45 | ``` 46 | 47 | ## For query-document relationship understanding tasks (qdu-tasks) 48 | 1. Download [test data](https://huggingface.co/datasets/yutaozhu94/INTERS/tree/main/test-qdu) and save it to ``data/``. The directory structure should look like this: 49 | ``` 50 | qdu-tasks 51 | ├── cqa.sh 52 | ├── eval_rerank.py 53 | ├── postprocess_cqa.py 54 | ├── run_eval.sh 55 | └── data 56 | ├── cqadupstack 57 | │ ├── android 58 | │ │ └── test.pt.key.do-not-overwrite.json 59 | │ ├── english 60 | │ │ └── test.pt.key.do-not-overwrite.json 61 | │ └── ... 62 | ├── arguana.bm25.100.jsonl 63 | ├── climate_fever.bm25.100.jsonl 64 | └── ... 65 | ``` 66 | 2. For datasets other than cqadupstack, modify the paths in ``run_eval.sh``, then run the script: 67 | ``` 68 | MODEL_PATH="your/model/path" 69 | TOKENIZER_PATH="your/tokenizer/path" 70 | RESULT_PATH="your/result/path" 71 | EVAL_DATA_PATH="data" 72 | 73 | ----------------------- 74 | bash run_eval.sh 75 | ``` 76 | 3. For the cqadupstack dataset, modify the paths in ``cqa.sh``, then run the script: 77 | ``` 78 | MODEL_PATH="your/model/path" 79 | TOKENIZER_PATH="your/tokenizer/path" 80 | RESULT_PATH="your/result/path" 81 | 82 | ----------------------- 83 | bash cqa.sh 84 | ``` 85 | 4. The reranking script supports pointwise, pairwise, and listwise methods. 
Select the method by setting the arguments of ``eval_rerank.py`` in ``run_eval.sh`` or ``cqa.sh``: 86 | ``` 87 | # pointwise: (default) 88 | --rerank_method pointwise 89 | 90 | # pairwise: 91 | --rerank_method pairwise 92 | 93 | # listwise: 94 | --rerank_method listwise \ 95 | --listwise_window 5 \ 96 | --listwise_stride 5 97 | ``` 98 | -------------------------------------------------------------------------------- /img/dataset.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DaoD/INTERS/fe00b00d8714be6245a5f3770ffb331e0f8e4f66/img/dataset.png -------------------------------------------------------------------------------- /img/in-domain-google.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DaoD/INTERS/fe00b00d8714be6245a5f3770ffb331e0f8e4f66/img/in-domain-google.png -------------------------------------------------------------------------------- /img/intro.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DaoD/INTERS/fe00b00d8714be6245a5f3770ffb331e0f8e4f66/img/intro.jpg -------------------------------------------------------------------------------- /img/logo1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DaoD/INTERS/fe00b00d8714be6245a5f3770ffb331e0f8e4f66/img/logo1.jpg -------------------------------------------------------------------------------- /img/process.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DaoD/INTERS/fe00b00d8714be6245a5f3770ffb331e0f8e4f66/img/process.jpg --------------------------------------------------------------------------------
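The set-based precision/recall/F1 computed by `QueryClarification`, `QuerySubtopicGeneration`, and the `codec` branch of `QueryReformulation` can be sketched as a standalone function. This is a minimal sketch, not the repository's code. The helper names `extract_items` and `set_prf` are hypothetical. The regex is copied from `normalize_results()`. When no numbered markers are found, the fallback wraps the whole prediction in a list, so that `set()` compares whole items rather than individual characters.

```python
import re

# Regex copied from normalize_results(): captures the text after each
# "1.", "[1]" or "(1)" marker (single-digit markers only).
PATTERN = r"(?:\d\.|\[\d\]|\(\d\))\s*(.*?)\s*(?=\d\.|\[\d\]|\(\d\)|$)"


def extract_items(text):
    """Split a numbered-list string into items (hypothetical helper)."""
    items = re.findall(PATTERN, text)
    # Fall back to the whole string as a single item; a bare string passed
    # to set() would be split into characters instead.
    return items if items else [text]


def set_prf(prediction, gold_items):
    """Set-based precision, recall and F1 between predicted and gold items."""
    pred = set(extract_items(prediction))
    gold = set(gold_items)
    tp = len(pred & gold)
    precision = tp / len(pred) if pred else 0.0  # TP / (TP + FP)
    recall = tp / len(gold) if gold else 0.0     # TP / (TP + FN)
    if precision + recall == 0:
        return precision, recall, 0.0
    f1 = 2 * precision * recall / (precision + recall)
    return precision, recall, f1


if __name__ == "__main__":
    print(set_prf("1. red 2. blue 3. green", ["red", "blue", "yellow"]))
```

With the bare-string fallback (`extracted_prediction = prediction`), `set("red")` would be `{'r', 'e', 'd'}`, so a correct but unnumbered answer would score zero against the gold item `"red"`.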