├── .flake8 ├── .github └── workflows │ └── makefile.yml ├── .gitignore ├── CONTRIBUTING.md ├── LICENSE ├── Makefile ├── README.md ├── benchmarks ├── ALCE │ ├── ASQA │ │ ├── README.md │ │ ├── asqa_benchmark.py │ │ ├── generate.py │ │ ├── prompts │ │ │ └── asqa_prompt.json │ │ ├── results │ │ │ ├── asqa_dpr_Llama_2_7b_chat_hf_vanilla_shot2_ndoc5_20240427.json │ │ │ ├── asqa_gtr_Llama_2_7b_chat_hf_snippet_shot2_ndoc10_20240430.json │ │ │ ├── asqa_gtr_Llama_2_7b_chat_hf_snippet_shot2_ndoc5_20240430.json │ │ │ ├── asqa_gtr_Llama_2_7b_chat_hf_summary_shot2_ndoc10_20240430.json │ │ │ ├── asqa_gtr_Llama_2_7b_chat_hf_summary_shot2_ndoc5_20240430.json │ │ │ ├── asqa_gtr_Llama_2_7b_chat_hf_vanilla_shot2_ndoc5_20240415.json │ │ │ └── asqa_oracle_Llama_2_7b_chat_hf_vanilla_shot2_ndoc5_20240428.json │ │ └── run.sh │ └── ELI5 │ │ ├── README.md │ │ ├── eli5_benchmark.py │ │ ├── generate.py │ │ ├── prompts │ │ └── eli5_prompt.json │ │ ├── results │ │ ├── eli5-bm25-Llama-2-7b-chat-hf-vanilla-shot2-ndoc5-20240310.json │ │ └── eli5_oracle_Llama_2_7b_chat_hf_vanilla_shot2_ndoc5_20240430.json │ │ └── run.sh ├── ASQA │ ├── README.md │ ├── asqa_benchmark.py │ ├── generate.py │ ├── prompts.py │ ├── run.sh │ └── run_generate.sh ├── HOTPOTQA │ ├── README.md │ ├── generate.py │ ├── hotpotqa_benchmark.py │ ├── prompts.py │ ├── run.sh │ └── run_generate.sh ├── WebGLM │ ├── generate.py │ ├── results │ │ └── webglm_Llama_2_7b_chat_hf_20240502.json │ ├── run.sh │ └── webglm_benchmark.py ├── __init__.py ├── auto │ ├── README.md │ ├── auto_benchmark.py │ ├── corpus │ │ ├── corpus.json │ │ └── few_shot_cases.json │ ├── output │ │ └── dataset.json │ ├── prompt.py │ └── run.sh ├── base.py └── utils.py ├── pyproject.toml ├── pytest.ini ├── rageval ├── __init__.py ├── evaluations.py ├── exceptions.py ├── metrics │ ├── __init__.py │ ├── answer_correctness │ │ ├── _answer_accuracy.py │ │ ├── _answer_bert_score.py │ │ ├── _answer_bleu.py │ │ ├── _answer_chrf.py │ │ ├── _answer_claim_recall.py │ │ ├── _answer_disambig_f1.py │ │ ├── _answer_edit_distance.py │ │ ├── _answer_exact_match.py │ │ ├── _answer_f1.py │ │ ├── _answer_lcs_ratio.py │ │ ├── _answer_relevancy.py │ │ ├── _answer_rouge_correctness.py │ │ └── _answer_ter.py │ ├── answer_groundedness │ │ ├── _answer_citation_precision.py │ │ ├── _answer_citation_recall.py │ │ ├── _claim_faithfulness.py │ │ └── _context_reject_rate.py │ ├── answer_informativeness │ │ ├── _answer_distinct12.py │ │ ├── _claim_num.py │ │ ├── _pairwise_accuracy.py │ │ ├── _repetitiveness.py │ │ └── _text_length.py │ ├── base.py │ ├── context_adequacy │ │ └── _context_recall.py │ └── context_relevance │ │ ├── _accuracy.py │ │ ├── _hit_rate.py │ │ ├── _mrr.py │ │ └── _ndcg.py ├── models │ ├── __init__.py │ ├── base.py │ ├── nli.py │ └── openai.py ├── tasks │ ├── __init__.py │ ├── _generate.py │ └── base.py ├── utils │ ├── RAGAS_prompt.py │ ├── __init__.py │ ├── check_utils.py │ ├── prompt.py │ └── utility.py ├── validation.py └── version.py ├── requirements.txt ├── setup.py ├── tests ├── demo.py ├── test_evaluation.py └── units │ ├── test_answer_accuracy.py │ ├── test_answer_bert_score.py │ ├── test_answer_bleu.py │ ├── test_answer_chrf.py │ ├── test_answer_citation_precision.py │ ├── test_answer_citation_recall.py │ ├── test_answer_claim_recall.py │ ├── test_answer_disambig_f1.py │ ├── test_answer_distinct.py │ ├── test_answer_edit_distance.py │ ├── test_answer_exect_match.py │ ├── test_answer_f1.py │ ├── test_answer_lcs_ratio.py │ ├── test_answer_rouge.py │ ├── test_answer_ter.py │ ├── 
test_context_recall.py │ ├── test_context_reject_rate.py │ ├── test_nli.py │ ├── test_openai_api.py │ └── test_text_length.py └── tutorials ├── README.md └── tutorial 1 ├── df_result_excel.xlsx ├── main.ipynb └── requirements.txt /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | ignore = 3 | # D401 First line should be in imperative mood 4 | D401, 5 | 6 | # D202 No blank lines allowed after function docstring 7 | D202, 8 | 9 | # For doctests: 10 | # D207 Docstring is under-indented 11 | D207, 12 | # D301 Use r""" if any backslashes in a docstring 13 | D301, 14 | # F401 'blah blah' imported but unused 15 | F401, 16 | 17 | # D101 Missing docstring in public class 18 | D101, 19 | 20 | # D100 Missing docstring in public module 21 | D100, 22 | 23 | # E501 line too long (88 > 79 characters) 24 | E501, 25 | 26 | # D400 First line should end with a period 27 | D400, 28 | 29 | # D103 Missing docstring in public function 30 | D103, 31 | -------------------------------------------------------------------------------- /.github/workflows/makefile.yml: -------------------------------------------------------------------------------- 1 | name: Makefile CI 2 | 3 | on: 4 | push: 5 | branches: [ "main" ] 6 | pull_request: 7 | branches: [ "main" ] 8 | 9 | jobs: 10 | build: 11 | 12 | runs-on: ubuntu-latest 13 | 14 | steps: 15 | - uses: actions/checkout@v3 16 | 17 | - name: Setup Python version 18 | uses: actions/setup-python@v1 19 | with: 20 | python-version: 3.10.15 21 | 22 | - name: Install requirements 23 | run: make init 24 | 25 | - name: Run check 26 | run: make test 27 | 28 | - name: Upload coverage reports to Codecov 29 | uses: codecov/codecov-action@v3 30 | env: 31 | CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} 32 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *.log 3 | *.swp 4 | *.bak 5 | *.weights 6 | *.DS_Store 7 | .vscode 8 | .coverage* 9 | RagEval.egg-info/* 10 | log/* 11 | build/* 12 | dist/* 13 | .idea/ 14 | .pytest_cache/ 15 | .cache 16 | htmlcov/* 17 | .rageval/ 18 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | Contributing to RAGEval 2 | ---------- 3 | 4 | > Note: RAGEval is developed under Python 3.8.18 5 | 6 | Welcome! RAGEval is a community project that aims to evaluate the different modules of a RAG system, including query rewriting, document ranking, information compression, evidence verification, answer generation, and result validation. Your experience and what you can contribute are important to the project's success. 7 | 8 | Discussion 9 | ---------- 10 | 11 | If you've run into behavior in RAGEval you don't understand, or you're having trouble working out a good way to apply it to your code, or you've found a bug or would like a feature it doesn't have, we want to hear from you! 12 | 13 | Our main forum for discussion is the project's [GitHub issue tracker](https://github.com/gomate-community/rageval/issues). This is the right place to start a discussion of any of the above or most any other topic concerning the project. 14 | 15 | First Time Contributors 16 | ----------------------- 17 | 18 | RAGEval appreciates your contribution! 
If you are interested in helping improve RAGEval, there are several ways to get started: 19 | 20 | * Work on [new metrics and datasets](https://github.com/gomate-community/rageval/tree/main/rageval). 21 | * Try to answer questions on [the issue tracker](https://github.com/gomate-community/rageval/issues). 22 | 23 | Submitting Changes 24 | ------------------ 25 | 26 | Even more excellent than a good bug report is a fix for a bug, or the implementation of much-needed new metrics or benchmarks. 27 | 28 | (*) We'd love to have your contributions. 29 | 30 | (*) If your new feature will be a lot of work, we recommend talking to us early -- see below. 31 | 32 | We use the usual GitHub pull-request flow, which may be familiar to you if you've contributed to other projects on GitHub -- see below. 33 | 34 | Anyone interested in RAGEval may review your code. One of the RAGEval core developers will merge your pull request when they think it's ready. 35 | For every pull request, we aim to promptly either merge it or say why it's not yet ready; if you go a few days without a reply, please feel 36 | free to ping the thread by adding a new comment. 37 | 38 | For a list of RAGEval core developers, see [Readme](https://github.com/gomate-community/rageval/blob/main/README.md). 39 | 40 | Contributing Flow 41 | ------------------ 42 | 43 | 1. Fork the latest version of [RAGEval](https://github.com/gomate-community/rageval) into your repo. 44 | 2. Create an issue under [gomate-Community/rageval](https://github.com/gomate-community/rageval/issues), and write a description of the bug/enhancement. 45 | 3. Clone your forked RAGEval to your machine, and add your changes together with associated tests. 46 | 4. Run `make test` in a terminal to ensure all unit tests & integration tests pass on your computer. 47 | 5. Push to your forked repo, then send the pull request to the official repo. In the pull request, you need to create a link to the issue you created using `#[issue_id]`, and describe what has been changed. 48 | 6. Wait for [continuous integration](https://github.com/gomate-community/rageval/blob/main/.github/workflows/makefile.yml) to pass. 49 | 7. Wait for [Codecov](https://app.codecov.io/gh/gomate-community/rageval) to generate the coverage report. 50 | 8. We'll assign reviewers to review your code. 51 | 52 | 53 | Your PR will be merged if: 54 | - It functionally benefits the project. 55 | - It passed Continuous Integration (all unit tests, integration tests and the [PEP8](https://www.python.org/dev/peps/pep-0008/) check passed). 56 | - Test coverage didn't decrease; we use [pytest](https://docs.pytest.org/en/latest/). 57 | - It has proper docstrings; see the codebase for examples. 58 | - It has type hints; see [typing](https://docs.python.org/3/library/typing.html). 59 | - All reviewers approved your changes. 60 | 61 | 62 | **Thanks and let's improve RAGEval together!** 63 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | # Usages: 2 | # 3 | # to install rageval dependencies: 4 | # $ make init 5 | # 6 | # to run all rageval tests, recommended for big PRs and new versions: 7 | # $ make test 8 | # 9 | # there are three kinds of tests: 10 | # 11 | # 1. "quick" tests 12 | # - run in seconds 13 | # - include all unit tests without marks and all doctests 14 | # - for rapid prototyping 15 | # - CI runs this for all PRs 16 | # 17 | # 2. 
"slow" tests 18 | # - run in minutes 19 | # - include all unit tests marked "slow" 20 | # - CI run this for all PRs 21 | # 22 | # 3. "cron" tests 23 | # - run in minutes 24 | # - involves underministic behavoirs (e.g. network connection) 25 | # - include all unit tests marked "cron" 26 | # - CI run this on a daily basis 27 | # 28 | # to run quick tests, excluding time consuming tests and crons: 29 | # $ make quick 30 | # 31 | # to run slow tests, excluding normal tests and crons: 32 | # $ make slow 33 | # 34 | # to run crons: 35 | # $ make cron 36 | # 37 | # to run all tests: 38 | # $ make test 39 | # 40 | # to run CI push/PR tests: 41 | # $ make push 42 | # 43 | # to run docstring style check: 44 | # $ make flake 45 | 46 | init: 47 | pip3 install -r requirements.txt 48 | 49 | TEST_ARGS = -v --full-trace -l --doctest-modules --doctest-continue-on-failure --cov rageval/ --cov-report term-missing --cov-report html --cov-config .coveragerc rageval/ tests/ -W ignore::DeprecationWarning 50 | FLAKE_ARGS = ./rageval --exclude=__init__.py 51 | 52 | test: 53 | python3 -m pytest $(TEST_ARGS) 54 | python3 -m flake8 $(FLAKE_ARGS) 55 | 56 | push: 57 | python3 -m pytest -m 'not cron' $(TEST_ARGS) ${ARGS} 58 | python3 -m flake8 $(FLAKE_ARGS) 59 | 60 | quick: 61 | python3 -m pytest -m 'not slow and not cron' $(TEST_ARGS) ${ARGS} 62 | 63 | slow: 64 | python3 -m pytest -m 'slow and not cron' $(TEST_ARGS) ${ARGS} 65 | 66 | cron: 67 | python3 -m pytest -m 'cron' $(TEST_ARGS) ${ARGS} 68 | 69 | flake: 70 | python3 -m flake8 $(FLAKE_ARGS) ${ARGS} 71 | -------------------------------------------------------------------------------- /benchmarks/ALCE/ASQA/asqa_benchmark.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import time 3 | from typing import Dict, Tuple, Any 4 | 5 | from datasets import Dataset 6 | 7 | import rageval as rl 8 | from benchmarks import BaseBenchmark 9 | from rageval.metrics import AnswerEMCorrectness, AnswerCitationRecall, AnswerCitationPrecision 10 | 11 | 12 | class ASQABenchmark(BaseBenchmark): 13 | 14 | name = "asqa_benchmark" 15 | 16 | def __init__(self, cache_path) -> None: 17 | super().__init__() 18 | nli_model = rl.models.NLIModel( 19 | "text2text-generation", 20 | cache_path + "/models/t5_xxl_true_nli_mixture", 21 | ) 22 | self.metrics = [ 23 | AnswerEMCorrectness(), 24 | AnswerCitationRecall(nli_model=nli_model), 25 | AnswerCitationPrecision(nli_model=nli_model) 26 | ] 27 | 28 | def _evaluate(self) -> Tuple[Dict[Any, Any], Dataset]: 29 | self.dataset = self.dataset.rename_column("output", "answers") 30 | self.dataset = self.dataset.map(lambda data: { 31 | "gt_answers": [ 32 | pair["short_answers"] 33 | for pair in data["qa_pairs"] 34 | ] 35 | }) 36 | 37 | results = {} 38 | for metric in self.metrics: 39 | results[metric.name], self.dataset = metric.compute(self.dataset, self.batch_size) 40 | return results, self.dataset 41 | 42 | 43 | if __name__ == "__main__": 44 | parser = argparse.ArgumentParser() 45 | parser.add_argument("--cache_path", type=str, default=None) 46 | parser.add_argument("--remote_split", type=str, default=None) 47 | parser.add_argument("--local_file", type=str, default=None) 48 | args = parser.parse_args() 49 | date = time.strftime("%Y%m%d", time.localtime()) 50 | 51 | benchmark = ASQABenchmark(cache_path=args.cache_path) 52 | if args.local_file: 53 | results = benchmark.evaluate( 54 | path="json", 55 | data_files={ 56 | "test": args.cache_path+"/results/"+args.local_file 57 | }, 58 | split="test" 59 | ) 
60 | benchmark.save_results(f"benchmarks/ALCE/ASQA/results/{args.local_file[:-5]}_{date}.json") 61 | else: 62 | results = benchmark.evaluate(path="golaxy/rag-bench", name="alce_asqa_gtr", split=args.remote_split) 63 | benchmark.save_results(f"benchmarks/ALCE/ASQA/results/{args.remote_split}_{date}.json") 64 | -------------------------------------------------------------------------------- /benchmarks/ALCE/ASQA/results/asqa_dpr_Llama_2_7b_chat_hf_vanilla_shot2_ndoc5_20240427.json: -------------------------------------------------------------------------------- 1 | { 2 | "answer_exact_match": 0.2921589310829817, 3 | "answer_citation_recall": 0.49225060277275473, 4 | "answer_citation_precision": 0.8100701083492671 5 | } -------------------------------------------------------------------------------- /benchmarks/ALCE/ASQA/results/asqa_gtr_Llama_2_7b_chat_hf_snippet_shot2_ndoc10_20240430.json: -------------------------------------------------------------------------------- 1 | { 2 | "answer_exact_match": 0.34504219409282705, 3 | "answer_citation_recall": 0.5594480062834494, 4 | "answer_citation_precision": 0.7248677248677249 5 | } -------------------------------------------------------------------------------- /benchmarks/ALCE/ASQA/results/asqa_gtr_Llama_2_7b_chat_hf_snippet_shot2_ndoc5_20240430.json: -------------------------------------------------------------------------------- 1 | { 2 | "answer_exact_match": 0.3428973277074543, 3 | "answer_citation_recall": 0.566895218002813, 4 | "answer_citation_precision": 0.8382519863791147 5 | } -------------------------------------------------------------------------------- /benchmarks/ALCE/ASQA/results/asqa_gtr_Llama_2_7b_chat_hf_summary_shot2_ndoc10_20240430.json: -------------------------------------------------------------------------------- 1 | { 2 | "answer_exact_match": 0.37202883263009845, 3 | "answer_citation_recall": 0.6047329457930724, 4 | "answer_citation_precision": 0.744484556758925 5 | } -------------------------------------------------------------------------------- /benchmarks/ALCE/ASQA/results/asqa_gtr_Llama_2_7b_chat_hf_summary_shot2_ndoc5_20240430.json: -------------------------------------------------------------------------------- 1 | { 2 | "answer_exact_match": 0.3699542897327708, 3 | "answer_citation_recall": 0.6328197207152902, 4 | "answer_citation_precision": 0.8281718281718282 5 | } -------------------------------------------------------------------------------- /benchmarks/ALCE/ASQA/results/asqa_gtr_Llama_2_7b_chat_hf_vanilla_shot2_ndoc5_20240415.json: -------------------------------------------------------------------------------- 1 | { 2 | "answer_exact_match": 0.3334739803094233, 3 | "answer_citation_recall": 0.5590039180229054, 4 | "answer_citation_precision": 0.8004201680672269 5 | } -------------------------------------------------------------------------------- /benchmarks/ALCE/ASQA/results/asqa_oracle_Llama_2_7b_chat_hf_vanilla_shot2_ndoc5_20240428.json: -------------------------------------------------------------------------------- 1 | { 2 | "answer_exact_match": 0.4170534458509142, 3 | "answer_citation_recall": 0.5808795459111914, 4 | "answer_citation_precision": 0.788681204569055 5 | } -------------------------------------------------------------------------------- /benchmarks/ALCE/ASQA/run.sh: -------------------------------------------------------------------------------- 1 | script_dir=$(cd $(dirname $0);pwd) 2 | cache_dir=$(dirname $(dirname $(dirname $script_dir)))/.rageval 3 | wget -cP $cache_dir/datasets 
https://huggingface.co/datasets/princeton-nlp/ALCE-data/resolve/main/ALCE-data.tar 4 | tar -xvf $cache_dir/datasets/ALCE-data.tar -C $cache_dir/datasets 5 | python3 setup.py install 6 | 7 | #python3 $script_dir/generate.py\ 8 | # --cache_path $cache_dir\ 9 | # --model Llama-2-7b-chat-hf\ 10 | # --dataset gtr\ 11 | # --method vanilla\ 12 | # --ndoc 5\ 13 | # --shot 2 14 | 15 | python3 $script_dir/asqa_benchmark.py\ 16 | --cache_path $cache_dir\ 17 | --remote_split Llama_2_7b_chat_hf_vanilla_shot2_ndoc5 18 | -------------------------------------------------------------------------------- /benchmarks/ALCE/ELI5/eli5_benchmark.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import time 3 | from typing import Dict, Tuple, Any 4 | 5 | from datasets import Dataset 6 | 7 | import rageval as rl 8 | from benchmarks import BaseBenchmark 9 | from rageval.metrics import AnswerNLICorrectness, AnswerCitationRecall, AnswerCitationPrecision 10 | 11 | 12 | class ELI5Benchmark(BaseBenchmark): 13 | 14 | name = "eli5_benchmark" 15 | 16 | def __init__(self, cache_path) -> None: 17 | super().__init__() 18 | nli_model = rl.models.NLIModel( 19 | "text2text-generation", 20 | cache_path + "/models/t5_xxl_true_nli_mixture", 21 | ) 22 | self.metrics = [ 23 | AnswerNLICorrectness(nli_model=nli_model, decompose_model="nltk"), 24 | AnswerCitationRecall(nli_model=nli_model), 25 | AnswerCitationPrecision(nli_model=nli_model) 26 | ] 27 | 28 | def _evaluate(self) -> Tuple[Dict[Any, Any], Dataset]: 29 | self.dataset = self.dataset.rename_column("output", "answers") 30 | self.dataset = self.dataset.rename_column("claims", "gt_answers") 31 | 32 | results = {} 33 | for metric in self.metrics: 34 | results[metric.name], self.dataset = metric.compute(self.dataset, self.batch_size) 35 | return results, self.dataset 36 | 37 | 38 | if __name__ == "__main__": 39 | parser = argparse.ArgumentParser() 40 | parser.add_argument("--cache_path", type=str, default=None) 41 | parser.add_argument("--remote_split", type=str, default=None) 42 | parser.add_argument("--local_file", type=str, default=None) 43 | args = parser.parse_args() 44 | date = time.strftime("%Y%m%d", time.localtime()) 45 | 46 | benchmark = ELI5Benchmark(cache_path=args.cache_path) 47 | if args.local_file: 48 | results = benchmark.evaluate( 49 | path="json", 50 | data_files={ 51 | "test": args.cache_path+"/results/"+args.local_file 52 | }, 53 | split="test" 54 | ) 55 | benchmark.save_results(f"benchmarks/ALCE/ELI5/results/{args.local_file[:-5]}_{date}.json") 56 | else: 57 | results = benchmark.evaluate(path="golaxy/rag-bench", name="alce_eli5_bm25", split=args.remote_split) 58 | benchmark.save_results(f"benchmarks/ALCE/ELI5/results/{args.remote_split}_{date}.json") -------------------------------------------------------------------------------- /benchmarks/ALCE/ELI5/generate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | 5 | from tqdm import tqdm 6 | from transformers import AutoModelForCausalLM, AutoTokenizer 7 | 8 | from rageval.models.openai import OpenAILLM 9 | 10 | 11 | def make_doc_prompt(doc, doc_id, doc_prompt, method): 12 | # For doc prompt: 13 | # - {ID}: doc id (starting from 1) 14 | # - {T}: title 15 | # - {P}: text 16 | # use_shorter: None, "summary", or "extraction" 17 | 18 | if method == "vanilla": 19 | text = doc['text'] 20 | elif method == "summary": 21 | text = doc["summary"] 22 | elif method == "snippet": 23 | 
text = doc["extraction"] 24 | else: 25 | raise ValueError("Don't support such method.") 26 | return doc_prompt.replace("{T}", doc["title"]).replace("{P}", text).replace("{ID}", str(doc_id+1)) 27 | 28 | 29 | def make_demo(item, prompt, ndoc, doc_prompt, instruction, method, test=False): 30 | # For demo prompt 31 | # - {INST}: the instruction 32 | # - {D}: the documents 33 | # - {Q}: the question 34 | # - {A}: the answers 35 | # ndoc: number of documents to put in context 36 | # use_shorter: None, "summary", or "extraction" 37 | 38 | prompt = prompt.replace("{INST}", instruction).replace("{Q}", item['question']) 39 | doc_texts = [] 40 | if "{D}" in prompt: 41 | if ndoc == 0: 42 | prompt = prompt.replace("{D}\n", "") # if there is no doc we also delete the empty line 43 | else: 44 | doc_list = item["docs"][:ndoc] 45 | for doc_id, doc in enumerate(doc_list): 46 | doc_texts.append(make_doc_prompt(doc, doc_id, doc_prompt, method=method)) 47 | text = "".join(doc_texts) 48 | prompt = prompt.replace("{D}", text) 49 | 50 | if not test: 51 | answer = "\n" + "\n".join(item["answer"]) if isinstance(item["answer"], list) else item["answer"] 52 | prompt = prompt.replace("{A}", "").rstrip() + answer 53 | else: 54 | prompt = prompt.replace("{A}", "").rstrip() # remove any space or \n 55 | 56 | return prompt, doc_texts 57 | 58 | 59 | if __name__ == "__main__": 60 | parser = argparse.ArgumentParser() 61 | parser.add_argument("--cache_path", type=str, default=None) 62 | parser.add_argument("--model", type=str, default="gpt-3.5-turbo") 63 | parser.add_argument("--api_key", type=str, default=None) 64 | parser.add_argument("--max_length", type=int, default=4096) 65 | parser.add_argument("--temperature", type=float, default=0.5) 66 | parser.add_argument("--top_p", type=float, default=1.0) 67 | parser.add_argument("--dataset", type=str, default="bm25") 68 | parser.add_argument("--method", type=str, default="vanilla") 69 | parser.add_argument("--ndoc", type=int, default=5) 70 | parser.add_argument("--shot", type=int, default=2) 71 | args = parser.parse_args() 72 | 73 | print("-" * 10 + "Loading dataset" + "-" * 10) 74 | 75 | if args.dataset == "bm25": 76 | eval_data = json.load( 77 | open(args.cache_path + "/datasets/ALCE-data/eli5_eval_bm25_top100.json", "r") 78 | ) 79 | elif args.dataset == "oracle": 80 | eval_data = json.load( 81 | open(args.cache_path + "/datasets/ALCE-data/eli5_eval_bm25_top100_reranked_oracle.json", "r") 82 | ) 83 | else: 84 | raise ValueError("Don't support such dataset.") 85 | 86 | print("-" * 10 + "Finish loading dataset" + "-" * 10) 87 | 88 | print("-" * 10 + "Generating prompts" + "-" * 10) 89 | 90 | eval_prompt = json.load( 91 | open("benchmarks/ALCE/ELI5/prompts/eli5_prompt.json", "r") 92 | ) 93 | 94 | head_prompt = "" 95 | for demo_id in range(args.shot): 96 | demo_item = eval_prompt["demos"][demo_id] 97 | prompt, _ = make_demo( 98 | demo_item, 99 | prompt=eval_prompt["demo_prompt"], 100 | ndoc=args.ndoc, 101 | doc_prompt=eval_prompt["doc_prompt"], 102 | instruction=eval_prompt["instruction"], 103 | method=args.method 104 | ) 105 | head_prompt += prompt 106 | head_prompt += eval_prompt["demo_sep"] 107 | 108 | for idx, eval_item in enumerate(tqdm(eval_data)): 109 | prompt, doc_texts = make_demo( 110 | eval_item, 111 | prompt=eval_prompt["demo_prompt"], 112 | ndoc=args.ndoc, 113 | doc_prompt=eval_prompt["doc_prompt"], 114 | instruction=eval_prompt["instruction"], 115 | method=args.method, 116 | test=True 117 | ) 118 | eval_data[idx]['prompt'] = head_prompt + prompt 119 | 
eval_data[idx]['contexts'] = doc_texts 120 | eval_data[idx]['docs'] = eval_item["docs"][:args.ndoc] 121 | 122 | print("-" * 10 + "Finish generating prompts" + "-" * 10) 123 | 124 | print("-" * 10 + "Loading model" + "-" * 10) 125 | 126 | model_name = args.model.split("/")[-1] 127 | if "gpt" in model_name: 128 | os.environ["OPENAI_API_KEY"] = args.api_key 129 | model = OpenAILLM( 130 | model=model_name, 131 | _api_key_env_var="OPENAI_API_KEY", 132 | max_tokens=args.max_length, 133 | temperature=args.temperature, 134 | top_p=args.top_p 135 | ) 136 | else: 137 | tokenizer = AutoTokenizer.from_pretrained(args.cache_path+"/models/"+args.model, use_fast=False) 138 | model = AutoModelForCausalLM.from_pretrained( 139 | args.cache_path+"/models/"+args.model, 140 | device_map='auto' 141 | ) 142 | 143 | print("-" * 10 + "Finish loading model" + "-" * 10) 144 | 145 | print("-" * 10 + "Predict" + "-" * 10) 146 | 147 | for idx, item in enumerate(tqdm(eval_data)): 148 | prompt = item['prompt'] 149 | if "gpt" in model_name: 150 | output = model.generate( 151 | inputs=[prompt], 152 | system_role="You are a helpful assistant that answers the following questions with proper citations." 153 | ) 154 | item['output'] = output.generations[0][0].text 155 | else: 156 | inputs = tokenizer([prompt], return_tensors="pt").to(model.device) 157 | stop = ["\n", "Ċ", "ĊĊ", "<0x0A>"] # In Llama \n is <0x0A>; In OPT \n is Ċ 158 | stop_token_ids = list( 159 | set( 160 | [tokenizer._convert_token_to_id(stop_token) for stop_token in stop] 161 | + [model.config.eos_token_id] 162 | ) 163 | ) 164 | if "llama" in model_name.lower(): 165 | stop_token_ids.remove(tokenizer.unk_token_id) 166 | 167 | generation = model.generate( 168 | **inputs, 169 | max_length=args.max_length, 170 | temperature=args.temperature, 171 | top_p=args.top_p, 172 | eos_token_id=stop_token_ids, 173 | do_sample=True 174 | ) 175 | output = tokenizer.decode(generation[0][inputs['input_ids'].size(1):], skip_special_tokens=True) 176 | item['output'] = output 177 | 178 | print("-" * 10 + "Finish predicting" + "-" * 10) 179 | 180 | file_name = f"eli5-{args.dataset}-{model_name}-{args.method}-shot{args.shot}-ndoc{args.ndoc}" 181 | file_name = file_name.replace("-", "_") 182 | result_path = args.cache_path + "/results/" + file_name + ".json" 183 | json.dump(eval_data, open(result_path, "w"), indent=4) 184 | 185 | print(f"\nResult file saved as {result_path}") 186 | -------------------------------------------------------------------------------- /benchmarks/ALCE/ELI5/results/eli5-bm25-Llama-2-7b-chat-hf-vanilla-shot2-ndoc5-20240310.json: -------------------------------------------------------------------------------- 1 | { 2 | "nli_claim": 11.5, 3 | "citation_recall": 26.623809523809523, 4 | "citation_precision": 74.54545454545455 5 | } -------------------------------------------------------------------------------- /benchmarks/ALCE/ELI5/results/eli5_oracle_Llama_2_7b_chat_hf_vanilla_shot2_ndoc5_20240430.json: -------------------------------------------------------------------------------- 1 | { 2 | "answer_claim_recall": 0.17766666666666667, 3 | "answer_citation_recall": 0.34007738095238094, 4 | "answer_citation_precision": 0.7563654518222666 5 | } -------------------------------------------------------------------------------- /benchmarks/ALCE/ELI5/run.sh: -------------------------------------------------------------------------------- 1 | script_dir=$(cd $(dirname $0);pwd) 2 | cache_dir=$(dirname $(dirname $(dirname $script_dir)))/.rageval 3 | wget -cP 
$cache_dir/datasets https://huggingface.co/datasets/princeton-nlp/ALCE-data/resolve/main/ALCE-data.tar 4 | tar -xvf $cache_dir/datasets/ALCE-data.tar -C $cache_dir/datasets 5 | python3 setup.py install 6 | 7 | #python3 $script_dir/generate.py\ 8 | # --cache_path $cache_dir\ 9 | # --model Llama-2-7b-chat-hf\ 10 | # --dataset bm25\ 11 | # --method vanilla\ 12 | # --ndoc 5\ 13 | # --shot 2 14 | 15 | python3 $script_dir/eli5_benchmark.py\ 16 | --cache_path $cache_dir\ 17 | --remote_split Llama_2_7b_chat_hf_vanilla_shot2_ndoc5 18 | -------------------------------------------------------------------------------- /benchmarks/ASQA/asqa_benchmark.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Tuple, Any, Optional 2 | from datasets import Dataset 3 | import numpy as np 4 | import os 5 | import argparse 6 | from benchmarks import BaseBenchmark 7 | from rageval.metrics import (AnswerRougeCorrectness, AnswerEMCorrectness, AnswerDisambigF1Correctness) 8 | 9 | 10 | class ASQABenchmark(BaseBenchmark): 11 | """Benchmark for ASQA dataset. 12 | 13 | The ASQA dataset is a question-answering dataset that contains factoid questions and long-form answers. The benchmark evaluates the correctness of the answers in the dataset. 14 | """ 15 | 16 | name = "asqa_benchmark" 17 | metrics = [AnswerRougeCorrectness(rouge_type="rougeL"), 18 | AnswerEMCorrectness(ignore_case=True), 19 | AnswerDisambigF1Correctness()] 20 | 21 | ground_truths = { 22 | "answer_disambig_f1": "long_answers", 23 | "answer_rouge_correctness": "long_answers", 24 | "answer_exact_match": "short_answers" 25 | } 26 | 27 | def __init__(self) -> None: 28 | """Initialization.""" 29 | super().__init__() 30 | 31 | def is_existed(self, column_name: str) -> bool: 32 | """Check if the column exists in the dataset.""" 33 | return column_name in self.dataset.column_names 34 | 35 | def _evaluate(self, ) -> Tuple[Dict[Any, Any], Dataset]: 36 | """Evaluate the dataset and return the dataset with scores. 37 | 38 | For the ASQA dataset, the `short_answers` and `long_answers` are stored in the "qa_pairs" and "annotations" columns, respectively. We need to extract them and add them to the dataset. 39 | 40 | We use the `short_answers` as the `gt_answers` to evaluate the string Exact Match correctness and the `long_answers` to evaluate the RougeL and DisambigF1 score. And then we calculate the `DR score` as the geometric mean of the RougeL and DisambigF1 scores. 41 | """ 42 | if not self.is_existed("short_answers"): 43 | self.dataset = self.dataset.map(lambda example: {"short_answers": [ann["short_answers"] for ann in example["qa_pairs"]]}) 44 | if not self.is_existed("long_answers"): 45 | self.dataset = self.dataset.map(lambda example: {"long_answers": [ann["long_answer"] for ann in example["annotations"]]}) 46 | 47 | results = {} 48 | for m in self.metrics: 49 | if m.name in self.ground_truths: 50 | print(f"Calculating {m.name}...") 51 | 52 | if self.is_existed(m.name): 53 | # Remove the metric column if it already exists 54 | self.dataset = self.dataset.remove_columns(m.name) 55 | if not self.is_existed(self.ground_truths[m.name]): 56 | # Check if the ground truth column exists 57 | raise ValueError(f"The column {self.ground_truths[m.name]} is not in the dataset. 
Please check the column names.") 58 | 59 | avg_scores, scores = m.compute( 60 | self.dataset["answers"], 61 | self.dataset[self.ground_truths[m.name]] 62 | ) 63 | results[m.name] = avg_scores 64 | self.dataset = self.dataset.add_column(m.name, scores) 65 | 66 | print(f"{m.name}: {avg_scores}") 67 | 68 | if self.is_existed("answer_rouge_correctness") and self.is_existed("answer_disambig_f1"): 69 | # Notice that DR score is an overall geometric mean of RougeL and DisambigF1 scores, which is calculated as sqrt(RougeL * DisambigF1) for whole dataset instead of average of each sample. 70 | print("Calculating DR score...") 71 | results["DR_score"] = np.sqrt(np.average(self.dataset["answer_disambig_f1"]) * np.average(self.dataset["answer_rouge_correctness"])) 72 | print(f"DR_score: {results['DR_score']}") 73 | 74 | return results, self.dataset 75 | 76 | if __name__ == "__main__": 77 | parser = argparse.ArgumentParser() 78 | parser.add_argument("--output_dir", type=str, default=".rageval/benchmark") 79 | parser.add_argument("--split", type=str, default="llama2_7b_chat") 80 | args = parser.parse_args() 81 | 82 | benchmark = ASQABenchmark() 83 | 84 | results = benchmark.evaluate(path="golaxy/rag-bench", name="asqa", split=args.split) 85 | print(f"Results:\n {results}") 86 | 87 | benchmark.save_results(os.path.join(args.output_dir,"results", f"{args.split}.jsonl")) 88 | benchmark.save_dataset(os.path.join(args.output_dir,"dataset", f"{args.split}.jsonl")) 89 | 90 | benchmark.set_metric([AnswerEMCorrectness(ignore_case=False)]) 91 | results = benchmark.evaluate() 92 | print(f"Results:\n {results}") 93 | -------------------------------------------------------------------------------- /benchmarks/ASQA/generate.py: -------------------------------------------------------------------------------- 1 | from datasets import Dataset, load_dataset 2 | import re 3 | from rageval.models import OpenAILLM 4 | import os 5 | import logging 6 | import argparse 7 | 8 | from prompts import (FEW_SHOT_EXAMPLES, PROMPT) 9 | 10 | def extract_key_information(pred: str) -> str: 11 | '''Extract key information from the response.''' 12 | pattern = r"(?:1\.|\(1\)).*?((?:1\.|\(1\)).*)" # find the second list starting with 1. 
or (1) 13 | pred = pred.strip().strip() 14 | matches = re.findall(pattern, pred, re.DOTALL) 15 | if matches: 16 | pred = matches[0] 17 | else: 18 | print(f"Cannot extract key information from the response: {pred}") 19 | return pred 20 | pred = re.sub(r'\(\d+\)\s', '', pred) # remove the index numbers 21 | return pred 22 | 23 | def generate_answers(engine: OpenAILLM, dataset: Dataset) -> Dataset: 24 | prompts = [ 25 | PROMPT.format(few_shot_examples=FEW_SHOT_EXAMPLES, 26 | question=data['ambiguous_question']) 27 | for data in dataset 28 | ] 29 | responses = engine.batch_generate(prompts) 30 | response_texts = [r.generations[0][0].text for r in responses] 31 | answers = [extract_key_information(response) for response in response_texts] 32 | dataset = dataset.add_column("responses", response_texts) 33 | return dataset.add_column("answers", answers) 34 | 35 | if __name__ == "__main__": 36 | parser = argparse.ArgumentParser() 37 | parser.add_argument("--max_num_examples", type=int, default=5) 38 | parser.add_argument("--max_new_tokens", type=int, default=256) 39 | parser.add_argument("--output_path", type=str, default="benchmarks/ASQA/output") 40 | parser.add_argument("--model", type=str, default="gpt-3.5-turbo-instruct") 41 | parser.add_argument("--api_key", type=str, default=None) 42 | 43 | args = parser.parse_args() 44 | 45 | print("\nLoad ASQA dataset...") 46 | dataset = load_dataset("din0s/asqa") 47 | dataset = dataset['dev'].select(range(args.max_num_examples)) 48 | 49 | print("Init ASQA dataset...") 50 | os.environ['OPENAI_API_KEY'] = args.api_key 51 | engine = OpenAILLM(args.model, 52 | _api_key_env_var = 'OPENAI_API_KEY', 53 | max_tokens=args.max_new_tokens) 54 | 55 | print("Start generate answers...") 56 | dataset = generate_answers(engine, dataset) 57 | 58 | file_path = os.path.join(args.output_path, f"{args.model}.jsonl") 59 | dataset.to_json(file_path) 60 | print(f"\nFinish generate dataset. Dataset saved as {file_path}") 61 | 62 | engine.calculate_api_cost() 63 | -------------------------------------------------------------------------------- /benchmarks/ASQA/prompts.py: -------------------------------------------------------------------------------- 1 | FEW_SHOT_EXAMPLES = """Given an ambiguous question, figure out its interpretations and answer them one by one. 2 | Question: Who played bonnie in gone with the wind? 3 | Answer: This question is ambiguous in terms of which version or adaptation of Gone with the Wind is being referred to. In order to figure out its interpretations, we need to consider different versions or adaptations of Gone with the Wind. Gone with the Wind has two versions or adaptations: the 1939 film Gone with the Wind or the 2008 musical Gone with the Wind. Therefore, this question has 2 interpretations: (1) Who played Bonnie in the 1939 film Gone with the Wind? (2) Who played Bonnie in the 2008 musical Gone with the Wind? The answers to all interpretations are: (1) The 1939 film Gone with the Wind’s character Bonnie was played by Eleanore Cammack "Cammie" King. (2) The 2008 musical Gone with the Wind’s character Bonnie was played by Leilah de Meza. 4 | 5 | Given an ambiguous question, figure out its interpretations and answer them one by one. 6 | Question: What is the second largest city in the usa? 7 | Answer: This question is ambiguous in terms of the criteria being used to determine the second largest city in the USA. In order to figure out its interpretations, we need to consider different criteria to determine a city’s size. 
City size can be measured by two criteria: population or area. Therefore, this question has 2 interpretations: (1) What is the second largest city in the USA by population? (2) What is the second largest city in the USA by area? The answers to all interpretations are: (1) The second largest city in the USA by population is Los Angeles, California. (2) The second largest city in the USA by area is Juneau, Alaska. 8 | 9 | Given an ambiguous question, figure out its interpretations and answer them one by one. 10 | Question: When was bohemian rhapsody released as a single? 11 | Answer: This question is ambiguous in terms of which country’s release of the single is being referred to. In order to figure out its interpretations, we need to consider different countries where Bohemian Rhapsody is released. Bohemian Rhapsody was released in the United Kingdom and in the United States on different dates. Therefore, this question has 2 interpretations: (1) When was Bohemian Rhapsody released as a single in the United Kingdom? (2) When was Bohemian Rhapsody released as a single in the United States? The answers to all interpretations are: (1) Bohemian Rhapsody was released as a single in the United Kingdom on 31 October 1975. (2) Bohemian Rhapsody was released as a single in the United States on December 1975. 12 | 13 | Given an ambiguous question, figure out its interpretations and answer them one by one. 14 | Question: Where do the philadelphia eagles play their home games? 15 | Answer: This question is ambiguous in terms of which specific location or venue is being referred to. In order to figure out its interpretations, we need to consider the different possible locations or venues that could be considered the home field of the Philadelphia Eagles. These include the city, the sports complex, or the stadium. Therefore, this question has 3 interpretations: (1) What city do the Philadelphia Eagles play their home games? (2) In what sports complex do the Philadelphia Eagles play their home games? (3) What stadium do the Philadelphia Eagles play their home games? The answers to all interpretations are: (1) Philadelphia Eagles play their home games in the city Philadelphia. (2) Philadelphia Eagles play their home games in the South Philadelphia Sports Complex. (3) Philadelphia Eagles play their home games in the Lincoln Financial Field stadium. 16 | 17 | Given an ambiguous question, figure out its interpretations and answer them one by one. 18 | Question: When did xbox one come out in australia? 19 | Answer: This question is ambiguous in terms of which specific version of the Xbox One is being referred to. In order to figure out its interpretations, we need to consider the different versions of the Xbox One that have been released. Xbox One has two versions: the Xbox One video game console or the Xbox One X high-end model. Therefore, this question has 2 interpretations: (1) When did the Xbox One release in Australia? (2) When did the Xbox One X release in Australia? The answers to all interpretations are: (1) The Xbox One video game console was released in Australia on November 22, 2013. (2) The Xbox One X video game console was released in Australia on November 7, 2017. 20 | 21 | Given an ambiguous question, figure out its interpretations and answer them one by one. 22 | Question: When does the movie summer of 84 come out? 23 | Answer: This question is ambiguous in terms of which release of the movie is being referred to. 
In order to figure out its interpretations, we need to consider different releases of the movie Summer of ’84. The movie Summer of ’84 is first released at the Sundance Festival before it’s released throughout the US. Therefore, this question has 2 interpretations: (1) When did the movie Summer of ’84 first release at the Sundance Festival? (2) When did the movie Summer of ’84 first release throughout the US? The answers to all interpretations are: (1) Summer of ’84 was released at the Sundance Festival on January 22, 2018. (2) Summer of ’84 was released throughout the US on August 10, 2018. 24 | 25 | Given an ambiguous question, figure out its interpretations and answer them one by one. 26 | Question: What was roy orbison’s first number one hit? 27 | Answer: This question is ambiguous in terms of which specific chart or region is being referred to. In order to figure out its interpretations, we need to consider the different charts and regions where Roy Orbison’s music was popular. Roy Orbison is popular in both the US Hot 100 and Canada, and the UK and Ireland. Therefore, this question has 2 interpretations: (1) What was Roy Orbison’s first number one hit in the US Hot 100 and Canada? (2) What was Roy Orbison’s first number one hit in the UK and Ireland? The answers to all interpretations are: (1) Running Scared was the first number one hit for Roy Orbison in the US Hot 100 and Canada. (2) Only the Lonely (Know the Way I Feel) was the first number one hit for Roy Orbison in the UK and Ireland. 28 | 29 | Given an ambiguous question, figure out its interpretations and answer them one by one. 30 | Question: What is the criminal’s name in the breakfast club? 31 | Answer: This question is ambiguous in terms of which specific name is being referred to - the character’s name or the actor’s name. In order to figure out its interpretations, we need to consider both possibilities: the character’s name or the actor’s name. Therefore, this question has 2 interpretations: (1) What is the criminal’s character name in The Breakfast Club? (2) What is the the name of the actor who played the criminal in The Breakfast Club? The answers to all interpretations are: (1) John Bender was the name of the criminal’s character in The Breakfast Club. 
(2) Judd Nelson was the actor of the criminal in The Breakfast Club.""" 32 | 33 | PROMPT = """{few_shot_examples}\n\nGiven an ambiguous question, figure out its interpretations and answer them one by one.\nQuestion: {question}\nAnswer: """ -------------------------------------------------------------------------------- /benchmarks/ASQA/run.sh: -------------------------------------------------------------------------------- 1 | rageval_dir=$(dirname $(dirname $(dirname $(realpath $0)))) 2 | cd $rageval_dir 3 | echo "Running ASQA Benchmark" 4 | python3 benchmarks/ASQA/asqa_benchmark.py --output_dir ".rageval/benchmark" --split "gpt_3.5_turbo_instruct" 5 | echo "ASQA Benchmark Complete" -------------------------------------------------------------------------------- /benchmarks/ASQA/run_generate.sh: -------------------------------------------------------------------------------- 1 | rageval_dir=$(dirname $(dirname $(dirname $(realpath $0)))) 2 | cd $rageval_dir 3 | echo "Generating ASQA examples" 4 | python3 benchmarks/ASQA/generate.py \ 5 | --max_num_examples 500 \ 6 | --max_new_tokens 256 \ 7 | --output_path "benchmarks/ASQA/output" \ 8 | --model "gpt-3.5-turbo-instruct" \ 9 | --api_key "YOUR_API_KEY" -------------------------------------------------------------------------------- /benchmarks/HOTPOTQA/hotpotqa_benchmark.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import time 4 | from typing import Dict, Tuple, Any 5 | 6 | from datasets import Dataset 7 | 8 | from benchmarks import BaseBenchmark 9 | from rageval.metrics import (AnswerEMCorrectness, AnswerF1Correctness) 10 | 11 | 12 | class HOTPOTQABenchmark(BaseBenchmark): 13 | """Benchmark for HotPotQA dataset. 14 | 15 | HotPotQA is a new dataset with 113k Wikipedia-based question-answer pairs with four key features: (1) the questions require finding and reasoning over multiple supporting documents to answer; (2) the questions are diverse and not constrained to any pre-existing knowledge bases or knowledge schemas; (3) we provide sentence-level supporting facts required for reasoning, allowingQA systems to reason with strong supervision and explain the predictions; (4) we offer a new type of factoid comparison questions to test QA systems’ ability to extract relevant facts and perform necessary comparison. 16 | 17 | """ 18 | 19 | name = "hotpot_qa_benchmark" 20 | 21 | def __init__(self) -> None: 22 | """Initialization.""" 23 | super().__init__() 24 | 25 | def _recode_gt_supporting_facts(self, data: object) -> object: 26 | """To calculate f1 recode gt_sent_ids by linking title and index""" 27 | recode_answers = [] 28 | for title, sent_id in zip(data['supporting_facts']['title'], data['supporting_facts']['sent_id']): 29 | recode = title.replace(" ","")+ str(sent_id) 30 | recode_answers.append(recode) 31 | recode_answers = [' '.join(recode_answers)] 32 | data["gt_sent_ids"] = recode_answers 33 | return data 34 | 35 | def _evaluate(self) -> Tuple[Dict[Any, Any], Dataset]: 36 | """Evaluate the dataset and return the dataset with scores. 37 | 38 | For the HotPotQA dataset(Distractor Setting), we evaluate models by using the `short_answer` and `supporting_answer`. 39 | 40 | For the HotPotQA dataset(Fullwiki Setting), we evaluate models by using the `response`. 41 | 42 | In Distractor Setting,we use the `answer` as the `gt_answers` to evaluate the string Exact Match correctness and the `supporting_facts` to make "gt_sent_ids" to evaluate the F1. 
43 | 44 | In Fullwiki Setting, we use the `answer` as the `gt_answers` to evaluate the string Exact Match correctness. 45 | """ 46 | 47 | self.metrics = [AnswerEMCorrectness(ignore_case=True), 48 | AnswerF1Correctness() 49 | ] 50 | if (("supporting_answer" in self.dataset.column_names) and "short_answer" in self.dataset.column_names): 51 | self.dataset = self.dataset.map(self._recode_gt_supporting_facts) 52 | self.dataset = self.dataset.map(lambda example: {"answer": [[example['answer']]]}) 53 | ground_truths = { 54 | "answer_f1": ("supporting_answer", "gt_sent_ids"), 55 | "answer_exact_match": ("short_answer", "answer") 56 | } 57 | else: 58 | self.dataset = self.dataset.map(lambda example: {"answer": [[example['answer']]]}) 59 | ground_truths = { 60 | "answer_exact_match": ("response", "answer") 61 | } 62 | 63 | results = {} 64 | 65 | for metric in self.metrics: 66 | if metric.name in ground_truths: 67 | print(f"Calculating {metric.name}...") 68 | 69 | if metric.name in self.dataset.column_names: 70 | self.dataset = self.dataset.remove_columns(metric.name) 71 | 72 | an, gtan = ground_truths[metric.name] 73 | self.dataset = self.dataset.rename_column(an, "answers") 74 | self.dataset = self.dataset.rename_column(gtan, "gt_answers") 75 | 76 | results[metric.name], self.dataset = metric.compute(self.dataset, self.batch_size) 77 | 78 | self.dataset = self.dataset.rename_column("answers", an) 79 | self.dataset = self.dataset.rename_column("gt_answers", gtan) 80 | self.dataset = self.dataset.map(lambda example: {"answer": example['answer'][0][0]}) 81 | return results, self.dataset 82 | 83 | 84 | if __name__ == "__main__": 85 | parser = argparse.ArgumentParser() 86 | 87 | parser.add_argument("--output_dir", type=str, default="benchmarks/HOTPOTQA") 88 | parser.add_argument("--remote_split", type=str, default="gpt_3.5_turbo") 89 | parser.add_argument("--local_file", type=str, default=None) 90 | 91 | args = parser.parse_args() 92 | date = time.strftime("%Y%m%d", time.localtime()) 93 | 94 | benchmark = HOTPOTQABenchmark() 95 | if args.local_file: 96 | data_file = os.path.join(args.output_dir, 'output', args.local_file) 97 | results = benchmark.evaluate( 98 | path='json', 99 | data_files={"test": data_file}, 100 | split="test" 101 | ) 102 | print(f"Results:\n {results}") 103 | benchmark.save_results(os.path.join(args.output_dir, 'results', f"{args.local_file[:-5]}_{date}.jsonl")) 104 | benchmark.save_dataset(os.path.join(args.output_dir, 'output', f"{args.local_file[:-5]}_{date}.jsonl")) 105 | else: 106 | results = benchmark.evaluate(path='golaxy/rag-bench', name='hotpot_qa', split=args.remote_split) 107 | print(f"Results:\n {results}") 108 | benchmark.save_results(os.path.join(args.output_dir, 'results', f"{args.remote_split}_{date}.jsonl")) 109 | benchmark.save_dataset(os.path.join(args.output_dir, 'output', f"{args.remote_split}_{date}.jsonl")) 110 | -------------------------------------------------------------------------------- /benchmarks/HOTPOTQA/run.sh: -------------------------------------------------------------------------------- 1 | rageval_dir=$(dirname $(dirname $(dirname $(realpath $0)))) 2 | cd $rageval_dir 3 | 4 | echo "Running HotPotQA Benchmark" 5 | 6 | python3 benchmarks/HOTPOTQA/hotpotqa_benchmark.py \ 7 | --output_dir "benchmarks/HOTPOTQA" \ 8 | --remote_split "gpt_3.5_turbo" 9 | 10 | # Check return status code 11 | if [ $? -eq 0 ]; then 12 | echo "HotPotQA Benchmark Complete" 13 | else 14 | echo "Error: HotPotQA Benchmark failed to execute." 
15 | fi 16 | 17 | -------------------------------------------------------------------------------- /benchmarks/HOTPOTQA/run_generate.sh: -------------------------------------------------------------------------------- 1 | rageval_dir=$(dirname $(dirname $(dirname $(realpath $0)))) 2 | cache_dir="$rageval_dir/HOTPOTQA" 3 | cd $rageval_dir 4 | echo "Generating HOTPOTQA examples" 5 | python3 benchmarks/HOTPOTQA/generate.py \ 6 | --subset "distractor"\ 7 | --num_documents 10 \ 8 | --max_num_examples 500 \ 9 | --max_length 4096 \ 10 | --output_path "benchmarks/HOTPOT/output" \ 11 | --cache_path $cache_dir \ 12 | --model "gpt-3.5-turbo" \ 13 | --api_key "YOUR_API_KEY" -------------------------------------------------------------------------------- /benchmarks/WebGLM/generate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | 5 | from tqdm import tqdm 6 | from transformers import AutoModelForCausalLM, AutoTokenizer 7 | 8 | from rageval.models.openai import OpenAILLM 9 | 10 | 11 | PROMPT = "Answer the question based on the following references with citations. Use a mark for each helpful reference you cited, such as [1]. If there are multiple citations at one position, please use a format like [1][2][3]. If a reference is useless, do not cite it." 12 | 13 | if __name__ == "__main__": 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument("--cache_path", type=str, default=None) 16 | parser.add_argument("--model", type=str, default="gpt-3.5-turbo") 17 | parser.add_argument("--api_key", type=str, default=None) 18 | parser.add_argument("--max_length", type=int, default=4096) 19 | parser.add_argument("--temperature", type=float, default=0.5) 20 | parser.add_argument("--top_p", type=float, default=1.0) 21 | args = parser.parse_args() 22 | 23 | print("-" * 10 + "Loading dataset" + "-" * 10) 24 | 25 | eval_data = [] 26 | with open(args.cache_path + "/datasets/webglm-test.jsonl", "r") as read_file: 27 | for line in tqdm(read_file): 28 | eval_data.append(json.loads(line)) 29 | 30 | print("-" * 10 + "Finish loading dataset" + "-" * 10) 31 | 32 | print("-" * 10 + "Generating prompts" + "-" * 10) 33 | 34 | for idx, item in enumerate(tqdm(eval_data)): 35 | prompt = PROMPT + '\n' 36 | for ix, ref in enumerate(item["references"]): 37 | prompt += f'Reference [{ix + 1}]: {ref}\n' 38 | prompt += f'Question: {item["question"]}\nAnswer: ' 39 | eval_data[idx]["prompt"] = prompt 40 | 41 | print("-" * 10 + "Finish generating prompts" + "-" * 10) 42 | 43 | print("-" * 10 + "Loading model" + "-" * 10) 44 | 45 | model_name = args.model.split("/")[-1] 46 | if "gpt" in model_name: 47 | os.environ["OPENAI_API_KEY"] = args.api_key 48 | model = OpenAILLM( 49 | model=model_name, 50 | _api_key_env_var="OPENAI_API_KEY", 51 | max_tokens=args.max_length, 52 | temperature=args.temperature, 53 | top_p=args.top_p 54 | ) 55 | else: 56 | tokenizer = AutoTokenizer.from_pretrained(args.cache_path + "/models/" + args.model, use_fast=False) 57 | model = AutoModelForCausalLM.from_pretrained( 58 | args.cache_path + "/models/" + args.model, 59 | device_map='auto' 60 | ) 61 | 62 | print("-" * 10 + "Finish loading model" + "-" * 10) 63 | 64 | print("-" * 10 + "Predict" + "-" * 10) 65 | 66 | for idx, item in enumerate(tqdm(eval_data)): 67 | prompt = item['prompt'] 68 | if "gpt" in model_name: 69 | output = model.generate( 70 | inputs=[prompt], 71 | system_role="You are a helpful assistant that answers the following questions with proper citations." 
72 | ) 73 | item['output'] = output.generations[0][0].text 74 | else: 75 | inputs = tokenizer([prompt], return_tensors="pt").to(model.device) 76 | stop = ["\n", "Ċ", "ĊĊ", "<0x0A>"] # In Llama \n is <0x0A>; In OPT \n is Ċ 77 | stop_token_ids = list( 78 | set( 79 | [tokenizer._convert_token_to_id(stop_token) for stop_token in stop] 80 | + [model.config.eos_token_id] 81 | ) 82 | ) 83 | if "llama" in model_name.lower(): 84 | stop_token_ids.remove(tokenizer.unk_token_id) 85 | 86 | generation = model.generate( 87 | **inputs, 88 | max_length=args.max_length, 89 | temperature=args.temperature, 90 | top_p=args.top_p, 91 | eos_token_id=stop_token_ids, 92 | do_sample=True 93 | ) 94 | output = tokenizer.decode(generation[0][inputs['input_ids'].size(1):], skip_special_tokens=True) 95 | item['output'] = output 96 | 97 | print("-" * 10 + "Finish predicting" + "-" * 10) 98 | 99 | file_name = f"webglm-{model_name}" 100 | file_name = file_name.replace("-", "_") 101 | result_path = args.cache_path + "/results/" + file_name + ".json" 102 | json.dump(eval_data, open(result_path, "w"), indent=4) 103 | 104 | print(f"\nResult file saved as {result_path}") 105 | -------------------------------------------------------------------------------- /benchmarks/WebGLM/results/webglm_Llama_2_7b_chat_hf_20240502.json: -------------------------------------------------------------------------------- 1 | { 2 | "answer_rouge_correctness": 0.2548292585536276, 3 | "answer_citation_recall": 0.10667063492063492, 4 | "answer_citation_precision": 0.9397590361445783 5 | } -------------------------------------------------------------------------------- /benchmarks/WebGLM/run.sh: -------------------------------------------------------------------------------- 1 | script_dir=$(cd $(dirname $0);pwd) 2 | cache_dir=$(dirname $(dirname $script_dir))/.rageval 3 | wget -c https://huggingface.co/datasets/THUDM/webglm-qa/resolve/main/data/test.jsonl -O $cache_dir/datasets/webglm-test.jsonl 4 | python3 setup.py install 5 | 6 | #python3 $script_dir/generate.py\ 7 | # --cache_path $cache_dir\ 8 | # --model Llama-2-7b-chat-hf 9 | 10 | python3 $script_dir/webglm_benchmark.py\ 11 | --cache_path $cache_dir\ 12 | --remote_split Llama_2_7b_chat_hf -------------------------------------------------------------------------------- /benchmarks/WebGLM/webglm_benchmark.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import time 3 | from typing import Dict, Tuple, Any 4 | 5 | from datasets import Dataset 6 | 7 | import rageval as rl 8 | from benchmarks import BaseBenchmark 9 | from rageval.metrics import AnswerRougeCorrectness, AnswerCitationRecall, AnswerCitationPrecision 10 | 11 | 12 | class WebGLMBenchmark(BaseBenchmark): 13 | 14 | name = "webglm_benchmark" 15 | 16 | def __init__(self, cache_path) -> None: 17 | super().__init__() 18 | nli_model = rl.models.NLIModel( 19 | "text2text-generation", 20 | cache_path + "/models/t5_xxl_true_nli_mixture", 21 | ) 22 | self.metrics = [ 23 | AnswerRougeCorrectness(rouge_type="rougeL"), 24 | AnswerCitationRecall(nli_model=nli_model), 25 | AnswerCitationPrecision(nli_model=nli_model) 26 | ] 27 | 28 | def _evaluate(self) -> Tuple[Dict[Any, Any], Dataset]: 29 | self.dataset = self.dataset.rename_column("output", "answers") 30 | self.dataset = self.dataset.rename_column("answer", "gt_answers") 31 | self.dataset = self.dataset.rename_column("references", "contexts") 32 | 33 | results = {} 34 | 35 | for metric in self.metrics: 36 | if metric.name == 
"answer_rouge_correctness": 37 | self.dataset = self.dataset.map(lambda data: {"gt_answers": [data["gt_answers"]]}) 38 | 39 | results[metric.name], self.dataset = metric.compute(self.dataset, self.batch_size) 40 | 41 | if metric.name == "answer_rouge_correctness": 42 | self.dataset = self.dataset.map(lambda data: {"gt_answers": data["gt_answers"][0]}) 43 | 44 | return results, self.dataset 45 | 46 | 47 | if __name__ == "__main__": 48 | parser = argparse.ArgumentParser() 49 | parser.add_argument("--cache_path", type=str, default=None) 50 | parser.add_argument("--remote_split", type=str, default=None) 51 | parser.add_argument("--local_file", type=str, default=None) 52 | args = parser.parse_args() 53 | date = time.strftime("%Y%m%d", time.localtime()) 54 | 55 | benchmark = WebGLMBenchmark(cache_path=args.cache_path) 56 | if args.local_file: 57 | results = benchmark.evaluate( 58 | path="json", 59 | data_files={ 60 | "test": args.cache_path+"/results/"+args.local_file 61 | }, 62 | split="test" 63 | ) 64 | benchmark.save_results(f"benchmarks/WebGLM/results/{args.local_file[:-5]}_{date}.json") 65 | else: 66 | results = benchmark.evaluate(path="golaxy/rag-bench", name="webglm", split=args.remote_split) 67 | benchmark.save_results(f"benchmarks/WebGLM/results/{args.remote_split}_{date}.json") 68 | -------------------------------------------------------------------------------- /benchmarks/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import BaseBenchmark -------------------------------------------------------------------------------- /benchmarks/auto/README.md: -------------------------------------------------------------------------------- 1 | # AUTO BENCHMARK 2 | 3 | Auto benchmark aims to generate testsets with synthetic Q&A pairs from raw passages, using as few human annotations as possible. To achieve this, we use LLM as a annotator and dataset quality inspector. 4 | 5 | ## Usage 6 | 7 | 1. Prepare your corpus in a JSON file, following the format 8 | 9 | ``` 10 | corpus.json: ["Document 1", "Document 2", ...] 11 | few_shot_cases.json: [ 12 | {"document": "Sample document", 13 | "Query": "Sample question"} 14 | ] 15 | ``` 16 | 17 | 2. Place the corpus JSON file(s) in `corpus` directory. 18 | 3. Run `run.sh` to start dataset generation. The result will saved in `output` directory. 19 | 20 | ### Arguments: 21 | 22 | `--corpus_dir`: Directory containing the corpus and few-shot case JSON files. 23 | 24 | `--output_dir`: Directory where the generated dataset JSON will be saved. 25 | 26 | `--model`: The OpenAI GPT model to use, e.g., gpt-3.5-turbo-16k. 27 | 28 | `--api_key`: Your OpenAI API key. 29 | 30 | ## Citations 31 | 32 | In this case, `documents` refers to a list of news articles, and `few-shot cases` are derived from a random split of the Multi-RC dataset. And we refer to the prompt from ARES: An Automated Evaluation Framework for Retrieval-Augmented Generation Systems. 
33 | 34 | ``` bibtex 35 | @misc{saadfalcon2023ares, 36 | title={ARES: An Automated Evaluation Framework for Retrieval-Augmented Generation Systems}, 37 | author={Jon Saad-Falcon and Omar Khattab and Christopher Potts and Matei Zaharia}, 38 | year={2023}, 39 | eprint={2311.09476}, 40 | archivePrefix={arXiv}, 41 | primaryClass={cs.CL} 42 | } 43 | 44 | @inproceedings{MultiRC2018, 45 | author = {Daniel Khashabi and Snigdha Chaturvedi and Michael Roth and Shyam Upadhyay and Dan Roth}, 46 | title = {Looking Beyond the Surface:A Challenge Set for Reading Comprehension over Multiple Sentences}, 47 | booktitle = {Proceedings of North American Chapter of the Association for Computational Linguistics (NAACL)}, 48 | year = {2018} 49 | } 50 | ``` 51 | -------------------------------------------------------------------------------- /benchmarks/auto/auto_benchmark.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import random 5 | import pandas as pd 6 | from datasets import Dataset 7 | from typing import List, Any, Tuple, Dict 8 | 9 | from rageval.models import OpenAILLM 10 | from prompt import (SYNTHETIC_QUERY_FEW_SHOT, SYNTHETIC_QUERY_SYSTEM, SYNTHETIC_QUERY_USER, SYNTHETIC_ANSWER_SYSTEM, SYNTHETIC_ANSWER_USER) 11 | 12 | 13 | def load_corpus(corpus_dir): 14 | with open(f"{corpus_dir}/corpus.json", "r", encoding="utf-8") as f: 15 | docs = json.load(f) 16 | df = pd.DataFrame(docs) 17 | df.drop_duplicates(inplace=True) 18 | dataset = Dataset.from_dict({'document':df[0].apply(lambda x: x.strip())}) 19 | 20 | with open(f"{corpus_dir}/few_shot_cases.json", "r", encoding="utf-8") as f: 21 | cases = json.load(f) 22 | cases = random.sample(cases, 3) 23 | return dataset, cases 24 | 25 | def generate_responses(engine: OpenAILLM, user_prompts: List[List[str]], system_prompt: List[str]) -> List[str]: 26 | '''Generate responses from the OpenAILLM model.''' 27 | responses = engine.batch_generate(user_prompts, system_roles=system_prompt * len(user_prompts)) 28 | response_texts = [r.generations[0][0].text for r in responses] 29 | 30 | return response_texts 31 | 32 | def generate_questions(engine: OpenAILLM, dataset: Dataset, cases) -> Dataset: 33 | system_prompt = [SYNTHETIC_QUERY_SYSTEM] 34 | few_shot_cases = "" 35 | for i in range(len(cases)): 36 | few_shot_cases += SYNTHETIC_QUERY_FEW_SHOT.format( 37 | document=cases[i]["document"], question=cases[i]["Query"]) 38 | user_prompts = [[SYNTHETIC_QUERY_USER.format( 39 | few_shot_cases=few_shot_cases, document=d['document'])] for d in dataset] 40 | 41 | questions = generate_responses(engine, user_prompts, system_prompt) 42 | return dataset.add_column("question", questions) 43 | 44 | def generate_answers(engine: OpenAILLM, dataset: Dataset) -> Dataset: 45 | system_prompt = [SYNTHETIC_ANSWER_SYSTEM] 46 | user_prompts = [[SYNTHETIC_ANSWER_USER.format( 47 | question=d['question'], document=d['document']) + "\n"] for d in dataset] 48 | 49 | answers = generate_responses(engine, user_prompts, system_prompt) 50 | return dataset.add_column("answer", answers) 51 | 52 | def validate_question_with_answer(dataset: Dataset) -> Dataset: 53 | def check_generated_answer(answer: str): 54 | problematic_phrases = ["I don't know", "i don't know"] 55 | for phrase in problematic_phrases: 56 | if phrase in answer.lower(): 57 | return False 58 | return True 59 | # dataset.filter(lambda x: not check_generated_answer(x["answer"])).to_json(f"{args.output_dir}/filtered_dataset.json") 60 | return 
dataset.filter(lambda x: check_generated_answer(x["answer"])) 61 | 62 | if __name__ == "__main__": 63 | parser = argparse.ArgumentParser() 64 | parser.add_argument("--corpus_dir", type=str, default="benchmarks/auto/corpus") 65 | parser.add_argument("--output_dir", type=str, default="benchmarks/auto/output") 66 | parser.add_argument("--model", type=str, default="gpt-3.5-turbo") 67 | parser.add_argument("--api_key", type=str, default=None) 68 | args = parser.parse_args() 69 | 70 | os.environ["OPENAI_API_KEY"] = args.api_key 71 | engine = OpenAILLM(args.model, "OPENAI_API_KEY") 72 | 73 | print(f"\nLoad corpus from {args.corpus_dir}") 74 | dataset, cases = load_corpus(args.corpus_dir) 75 | 76 | print("Start generate questions...") 77 | dataset = generate_questions(engine, dataset, cases) 78 | 79 | print("Start generate answers...") 80 | dataset = generate_answers(engine, dataset) 81 | 82 | print("Validate questions...") 83 | dataset = validate_question_with_answer(dataset) 84 | 85 | engine.calculate_api_cost() 86 | 87 | dataset.to_json(f"{args.output_dir}/dataset.json") 88 | print(f"\nFinish generate dataset. Dataset saved as {args.output_dir}/dataset.json") 89 | -------------------------------------------------------------------------------- /benchmarks/auto/prompt.py: -------------------------------------------------------------------------------- 1 | SYNTHETIC_QUERY_SYSTEM = '''You are an expert question-answering system. You must create a question for the provided document. The question must be answerable within the context of the document.''' # system prompt 2 | 3 | SYNTHETIC_QUERY_FEW_SHOT = '''Document: {document} 4 | Question: {question} 5 | 6 | ''' # few-shot 7 | 8 | SYNTHETIC_QUERY_USER = '''{few_shot_cases}Document: {document} 9 | Question: ''' # user prompt 10 | 11 | SYNTHETIC_ANSWER_SYSTEM='''You are a helpful assistant that are good at helping to answer a query based on the context step by step, the context is a document. If there is a good answer from the context, try to summarize the context as the answer. 
If the query doesn't form a complete question, or you don't know the answer, or there is not enough information to determine the answer, or the context is irrelevant to the question, just say I DON'T KNOW.''' # system prompt 12 | 13 | SYNTHETIC_ANSWER_USER = '''Here is the question: {question} 14 | Here is the context: {document}''' # user prompt -------------------------------------------------------------------------------- /benchmarks/auto/run.sh: -------------------------------------------------------------------------------- 1 | rageval_dir=$(dirname $(dirname $(dirname $(realpath $0)))) 2 | cd $rageval_dir 3 | echo "Running Auto Benchmark" 4 | python3 $rageval_dir/benchmarks/auto/auto_benchmark.py --corpus_dir "benchmarks/auto/corpus"\ 5 | --output_dir "benchmarks/auto/output"\ 6 | --model "gpt-3.5-turbo"\ 7 | --api_key "YOUR_API_KEY" 8 | echo "Auto Benchmark Complete" -------------------------------------------------------------------------------- /benchmarks/base.py: -------------------------------------------------------------------------------- 1 | from typing import List, Union, Dict, Any, Tuple, Optional 2 | from abc import abstractmethod, ABC 3 | from dataclasses import dataclass 4 | # import importlib 5 | from datasets import Dataset, load_dataset 6 | from rageval.metrics import Metric 7 | from .utils import save_json 8 | 9 | class BaseBenchmark(ABC): 10 | """Base class for benchmarks.""" 11 | 12 | metrics: List[Metric] = [] 13 | dataset: Dataset 14 | 15 | def __init__(self, batch_size: int = 1) -> None: 16 | """Initialization.""" 17 | self.batch_size = batch_size 18 | 19 | @property 20 | @abstractmethod 21 | def name(self) -> str: 22 | """The benchmark name.""" 23 | ... 24 | 25 | @property 26 | def metric_names(self) -> List[str]: 27 | """The metric names.""" 28 | return [m.name for m in self.metrics] 29 | 30 | def load_data(self, **kwargs) -> None: 31 | """Load the dataset with answers to evaluate.""" 32 | print("Load dataset...") 33 | self.dataset = load_dataset(**kwargs) 34 | print("Dataset loaded.") 35 | 36 | @abstractmethod 37 | def _evaluate(self) -> Tuple[Dict[Any, Any], Dataset]: 38 | """Evaluate the dataset and return the results and the detailed dataset with per-sample scores.""" 39 | ...
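    # A minimal subclass sketch (illustrative only; `MyBenchmark` and `SomeMetric` are
    # hypothetical names — see benchmarks/WebGLM/webglm_benchmark.py for a real implementation).
    # A concrete benchmark is expected to set `self.metrics` and implement `_evaluate`,
    # typically looping over the metrics like this:
    #
    #     class MyBenchmark(BaseBenchmark):
    #         name = "my_benchmark"
    #
    #         def __init__(self) -> None:
    #             super().__init__()
    #             self.metrics = [SomeMetric()]  # hypothetical metric instance
    #
    #         def _evaluate(self) -> Tuple[Dict[Any, Any], Dataset]:
    #             results = {}
    #             for metric in self.metrics:
    #                 results[metric.name], self.dataset = metric.compute(self.dataset, self.batch_size)
    #             return results, self.dataset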
40 | 41 | def evaluate(self, **kwargs) -> Dict[Any, Any]: 42 | """Load datasets and evaluate it, return a result dict.""" 43 | if not hasattr(self, "dataset"): 44 | self.load_data(**kwargs) 45 | print("Start evaluating...") 46 | self.results, self.dataset = self._evaluate() 47 | print("Evaluation finished.") 48 | return self.results 49 | 50 | def set_metric(self, metrics: List[Metric]) -> None: 51 | """Reset the metrics.""" 52 | if all(isinstance(m, Metric) for m in metrics): 53 | self.metrics = metrics 54 | else: 55 | raise ValueError("The metrics should be a list of Metric objects.") 56 | 57 | def save_dataset(self, file_path: str) -> None: 58 | """Save the result to files.""" 59 | if not hasattr(self, "dataset"): 60 | raise ValueError("Please load the dataset and evaluate it first.") 61 | self.dataset.to_json(file_path, orient="records") 62 | print(f"Dataset saved to {file_path}.") 63 | 64 | def save_results(self, file_path: str) -> None: 65 | """Save the result to files.""" 66 | if not hasattr(self, "results"): 67 | raise ValueError("Please run evaluation first.") 68 | save_json(self.results, file_path) 69 | print(f"Results saved to {file_path}.") 70 | 71 | # def get_metric(self, name: str, **kwargs) -> Union[Metric, MetricWithLLM]: 72 | # """Get the metric by name.""" 73 | # module = importlib.import_module(f"rageval.metrics") 74 | # return getattr(module, name)(**kwargs) 75 | -------------------------------------------------------------------------------- /benchmarks/utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | 4 | def save_json(data, file_path: str): 5 | os.makedirs(os.path.dirname(file_path), exist_ok=True) 6 | with open(file_path, "w+") as f: 7 | json.dump(data, f, indent=4) 8 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools >= 42", "wheel"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "RagEval" 7 | version = "0.1.0" 8 | description = "Evaluation tools for Retrieval-augmented Generation (RAG) methods." 
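# Note (illustrative, based on the optional dependency groups declared under
# [project.optional-dependencies] below): a source checkout can typically be installed with, e.g.:
#   pip install .                 # core library only
#   pip install ".[tests]"        # core + test tooling
#   pip install ".[benchmarks]"   # core + benchmark extras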
9 | keywords = ["RAG evaluation tools"] 10 | readme = {file = "README.md", content-type = "text/markdown"} 11 | license = {file = "LICENSE"} 12 | authors = [ 13 | { name="Wenshan Wang, Yixing Fan, etc.", email="wangwenshan@ict.ac.cn" }, 14 | ] 15 | requires-python = ">=3.10" 16 | classifiers = [ 17 | "Development Status :: 3 - Alpha", 18 | "Environment :: Console", 19 | "Operating System :: POSIX :: Linux", 20 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 21 | "License :: OSI Approved :: Apache Software License", 22 | "Programming Language :: Python :: 3" 23 | ] 24 | dependencies = [ 25 | "refchecker == 0.2.13", 26 | "numpy >= 1.26", 27 | "tqdm >= 4.66", 28 | "hyperopt >= 0.1.1", 29 | "h5py >= 2.8.0", 30 | "openai >= 1.10.0", 31 | "datasets >= 3.0.1", 32 | "langchain >= 0.3.1", 33 | "langchain-community >= 0.3.1", 34 | "transformers >= 4.37.2", 35 | "torch >= 2.2.0", 36 | "pandas >= 2.0.0", 37 | "nltk >= 3.9.1", 38 | "spacy >= 3.7.4", 39 | "en_core_web_sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.7.1/en_core_web_sm-3.7.1-py3-none-any.whl", 40 | "rouge_score >= 0.1.2", 41 | "jieba >= 0.42.1", 42 | "evaluate >= 0.4.3" 43 | ] 44 | 45 | [project.optional-dependencies] 46 | tests = [ 47 | "coverage >= 4.3.4", 48 | "codecov >= 2.0.15", 49 | "pytest >= 3.7.4", 50 | "pytest-cov >= 2.4.0", 51 | "flake8 == 7.0.0", 52 | "pydocstyle == 6.1", 53 | "flake8_docstrings >= 1.7.0" 54 | ] 55 | benchmarks = [ 56 | "accelerate == 0.27.2", 57 | "sentencepiece == 0.2.0", 58 | "protobuf == 4.25.3" 59 | ] 60 | 61 | [project.urls] 62 | Homepage = "https://github.com/gomate-community/rageval" 63 | Issues = "https://github.com/gomate-community/rageval/issues" 64 | -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | markers = 3 | api: mark a test as a web api 4 | slow: mark a test as whole testing 5 | quick: mark a test as quick testing 6 | -------------------------------------------------------------------------------- /rageval/__init__.py: -------------------------------------------------------------------------------- 1 | try: 2 | from .version import version as __version__ 3 | except ImportError: 4 | __version__ = "unknown version" 5 | 6 | from . import tasks 7 | from . import metrics 8 | from . import models 9 | from . 
import utils 10 | 11 | from .evaluations import evaluate 12 | 13 | # __all__ = ["evaluate", "__version__"] 14 | -------------------------------------------------------------------------------- /rageval/evaluations.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from typing import Callable, List 3 | 4 | from datasets import Dataset, concatenate_datasets 5 | 6 | from rageval.metrics import Metric 7 | 8 | 9 | def evaluate( 10 | testset: Dataset, 11 | metrics: List[Metric] | None = None, 12 | models: List[Callable] | None = None) -> (Dataset, Dataset): 13 | """Conduct the evaluation on testset.""" 14 | 15 | # run evaluation 16 | assert (len(metrics) == len(models)) 17 | [metrics[i].init_model(models[i]) for i in range(len(metrics))] 18 | avg_scores = [] 19 | instance_scores = [testset] 20 | for metric in metrics: 21 | print(f"evaluating with [{metric.name}]") 22 | avg_score, _testset = metric.compute(testset) 23 | avg_scores.append(Dataset.from_dict({metric.name: [avg_score]})) 24 | instance_scores.append(_testset.select_columns(metric.name)) 25 | 26 | return concatenate_datasets(avg_scores), concatenate_datasets(instance_scores) 27 | -------------------------------------------------------------------------------- /rageval/exceptions.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gomate-community/rageval/01e2580efe714b3130602e02a7e7b017d16d79a2/rageval/exceptions.py -------------------------------------------------------------------------------- /rageval/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import Metric, MetricWithLLM, add_attribute 2 | 3 | # Metrics about the answer correctness 4 | from .answer_correctness._answer_accuracy import AnswerAccuracy 5 | from .answer_correctness._answer_bleu import AnswerBleuScore 6 | from .answer_correctness._answer_chrf import AnswerCHRFCorrectness 7 | from .answer_correctness._answer_exact_match import AnswerEMCorrectness 8 | from .answer_correctness._answer_f1 import AnswerF1Correctness 9 | from .answer_correctness._answer_rouge_correctness import AnswerRougeCorrectness 10 | from .answer_correctness._answer_bert_score import AnswerBERTScore 11 | from .answer_correctness._answer_edit_distance import AnswerEditDistance 12 | from .answer_correctness._answer_claim_recall import AnswerNLICorrectness 13 | from .answer_correctness._answer_disambig_f1 import AnswerDisambigF1Correctness 14 | from .answer_correctness._answer_lcs_ratio import AnswerLCSRatio 15 | from .answer_correctness._answer_ter import AnswerTERCorrectness 16 | ##from .answer_correctness._answer_relevancy import AnswerRelevancy 17 | 18 | # Metrics about the answer groundedness 19 | from .answer_groundedness._answer_citation_precision import AnswerCitationPrecision 20 | from .answer_groundedness._answer_citation_recall import AnswerCitationRecall 21 | from .answer_groundedness._context_reject_rate import ContextRejectRate 22 | ##from .answer_groundedness._claim_faithfulness import ClaimFaithfulness 23 | 24 | # Metrics about the answer informativeness 25 | ##from .answer_informative._claim_num import ClaimNum 26 | from .answer_informativeness._text_length import TextLength 27 | ##from .answer_informativeness._repetitiveness import Repetitiveness 28 | ##from .answer_informativeness._pairwise_accuracy import PairwiseAccuracy 29 | from .answer_informativeness._answer_distinct12 
import AnswerDistinct 30 | 31 | # Metrics about the context relevancy 32 | 33 | # Metrics about the context aduquacy 34 | from .context_adequacy._context_recall import ContextRecall 35 | -------------------------------------------------------------------------------- /rageval/metrics/answer_correctness/_answer_accuracy.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import List 3 | import evaluate 4 | 5 | import datasets 6 | 7 | from rageval.metrics import Metric, add_attribute 8 | 9 | 10 | _DESCRIPTION = """\ 11 | The AnswerAccuracy is to measure the correctness of answers. 12 | 13 | This metric is applicable in scenarios where the LLM is required to output a unique short answer, such as options for \ 14 | multiple-choice questions or a single entity name. 15 | The renowned MMLU dataset utilizes this metric for evaluation. In the evaluation of the MMLU dataset, probabilities \ 16 | for each answer are first calculated, and the answer with the highest probability is selected as the predicted result. 17 | In our tool, we assume that the prediction result has already been obtained, and only perform the final score \ 18 | calculation. 19 | """ 20 | 21 | _KWARGS_DESCRIPTION = """\ 22 | Args: 23 | name : str 24 | batch_size : int, Batch size for openai completion. 25 | 26 | Optional Args: 27 | None 28 | 29 | Functions: 30 | _compute_one: Evaluating the correctness of answer. 31 | 32 | Examples: 33 | >>> from datasets import Dataset 34 | >>> import rageval as rl 35 | >>> sample = { 36 | ... "answers": [ 37 | ... "A", 38 | ... "B", 39 | ... "C" 40 | ... ], 41 | ... "gt_answers": [ 42 | ... "A", 43 | ... "C", 44 | ... "C" 45 | ... ] 46 | ... } 47 | >>> dataset = Dataset.from_dict(sample) 48 | >>> metric = rl.metrics.AnswerAccuracy() 49 | >>> metric.mtype 50 | 'AnswerCorrectness' 51 | >>> score, results = metric.compute(dataset["answers"], dataset["gt_answers"], 1) 52 | >>> score 53 | 0.6666666666666666 54 | >>> results[0] 55 | True 56 | """ 57 | 58 | _CITATION = """\ 59 | @misc{hendrycks2021measuring, 60 | title={Measuring Massive Multitask Language Understanding}, 61 | author={Dan Hendrycks and Collin Burns and Steven Basart and Andy Zou and Mantas Mazeika and Dawn Song and Jacob Steinhardt}, 62 | year={2021}, 63 | eprint={2009.03300}, 64 | archivePrefix={arXiv}, 65 | primaryClass={cs.CY} 66 | } 67 | """ 68 | 69 | 70 | @dataclass 71 | @add_attribute('mtype', 'AnswerCorrectness') 72 | @datasets.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION) 73 | class AnswerAccuracy(Metric): 74 | """Estimates the correctness of answers.""" 75 | 76 | name = "answer_accuracy" 77 | 78 | ALIAS = ['answer_accuracy'] 79 | 80 | def __init__(self): 81 | """ 82 | Explicitly initialize AnswerAccuracy. 83 | 84 | Ensure all parent classes are initialized. 
85 | """ 86 | super().__init__() 87 | self.info = evaluate.MetricInfo( 88 | description=_DESCRIPTION, 89 | inputs_description=_KWARGS_DESCRIPTION, 90 | citation=_CITATION, 91 | homepage="", 92 | features=datasets.Features( 93 | { 94 | "answers": datasets.Value("string"), 95 | "gt_answers": datasets.Value("string") 96 | } 97 | ), 98 | codebase_urls=["https://github.com/hendrycks/test"], 99 | reference_urls=["https://arxiv.org/abs/2009.03300"] 100 | ) 101 | 102 | def __repr__(self) -> str: 103 | """:return: Formatted string representation of the metric.""" 104 | return f"{self.ALIAS[0]}" 105 | 106 | def _compute_one( 107 | self, 108 | answer: str, 109 | gt_answer: str 110 | ) -> float: 111 | """Evaluating the correctness of answer.""" 112 | return answer == gt_answer 113 | -------------------------------------------------------------------------------- /rageval/metrics/answer_correctness/_answer_bert_score.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import List, Tuple 3 | import evaluate 4 | 5 | import datasets 6 | from rageval.metrics import Metric, add_attribute 7 | from bert_score import BERTScorer 8 | import logging 9 | import transformers 10 | transformers.tokenization_utils.logger.setLevel(logging.ERROR) 11 | transformers.configuration_utils.logger.setLevel(logging.ERROR) 12 | transformers.modeling_utils.logger.setLevel(logging.ERROR) 13 | 14 | _DESCRIPTION = """\ 15 | BERTScore leverages the pre-trained contextual embeddings from BERT and matches words in candidate and reference sentences by cosine similarity. It has been shown to correlate with human judgment on sentence-level and system-level evaluation. Moreover, BERTScore computes precision, recall, and F1 measure, which can be useful for evaluating different language generation tasks. 16 | 17 | For details, see the paper: https://openreview.net/forum?id=SkeHuCVFDr 18 | """ 19 | 20 | _KWARGS_DESCRIPTION = """\ 21 | Args: 22 | name : str 23 | lang : str, Language of the text. Default is "en". 24 | rescale_with_baseline : bool, Whether to rescale the score with pre-computed baseline. Not affect BERTScore's correlation with human judgment. Default is False. For more details, see https://github.com/Tiiiger/bert_score/blob/master/journal/rescale_baseline.md 25 | 26 | Optional Args: 27 | None 28 | 29 | Functions: 30 | _clean: clean special word in sentence. 31 | _compute_one: compute bleu score for single prediction with its references 32 | _compute_batch: compute bleu score for a batch of predictions with their references 33 | 34 | Examples: 35 | >>> from datasets import Dataset 36 | >>> import rageval as rl 37 | >>> sample = { 38 | ... "answers": [ 39 | ... "It is a guide to action which ensures that the military always obeys the commands of the party.", 40 | ... "It is to insure the troops forever hearing the activity guidebook that party direct.", 41 | ... ], 42 | ... "gt_answers": [ 43 | ... [ 44 | ... "It is a guide to action that ensures that the military will forever heed Party commands.", 45 | ... "It is the guiding principle which guarantees the military forces always being under the command of the Party.", 46 | ... "It is the practical guide for the army always to heed the directions of the party.", 47 | ... ], 48 | ... [ 49 | ... "It is a guide to action that ensures that the military will forever heed Party commands.", 50 | ... 
"It is the guiding principle which guarantees the military forces always being under the command of the Party.", 51 | ... "It is the practical guide for the army always to heed the directions of the party.", 52 | ... ] 53 | ... ], 54 | ... } 55 | >>> dataset = Dataset.from_dict(sample) 56 | >>> metric = rl.metrics.AnswerBERTScore(lang='en', rescale_with_baseline=True) 57 | >>> metric.mtype 58 | 'AnswerCorrectness' 59 | >>> score, results = metric.compute(dataset["answers"], dataset["gt_answers"], 1) 60 | >>> round(score, 2) 61 | 0.55 62 | >>> round(results[0], 1) 63 | 0.7 64 | """ 65 | 66 | 67 | _CITATION = """\ 68 | @inproceedings{bert-score, 69 | title={BERTScore: Evaluating Text Generation with BERT}, 70 | author={Tianyi Zhang* and Varsha Kishore* and Felix Wu* and Kilian Q. Weinberger and Yoav Artzi}, 71 | booktitle={International Conference on Learning Representations}, 72 | year={2020}, 73 | url={https://openreview.net/forum?id=SkeHuCVFDr} 74 | } 75 | """ 76 | 77 | 78 | @dataclass 79 | @add_attribute('mtype', 'AnswerCorrectness') 80 | @datasets.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION) 81 | class AnswerBERTScore(Metric): 82 | """BERTScore depends on the model and language pair selected.""" 83 | 84 | name = "answer_bert_score" 85 | 86 | ALIAS = ['answer_bert_score'] 87 | 88 | def __init__(self, lang: str = "en", rescale_with_baseline=False): 89 | """Explicitly initialize the AnswerBERTScore to ensure all parent class initialized.""" 90 | super().__init__() 91 | self.scorer = BERTScorer(lang=lang, rescale_with_baseline=rescale_with_baseline) 92 | self.info = evaluate.MetricInfo( 93 | description=_DESCRIPTION, 94 | inputs_description=_KWARGS_DESCRIPTION, 95 | citation=_CITATION, 96 | homepage="", 97 | features=datasets.Features( 98 | { 99 | "answers": datasets.Value("string"), 100 | "gt_answers": datasets.Sequence(datasets.Value("string")) 101 | } 102 | ), 103 | codebase_urls=[ 104 | "https://github.com/Tiiiger/bert_score/tree/master", 105 | ], 106 | reference_urls=["https://openreview.net/forum?id=SkeHuCVFDr"] 107 | ) 108 | 109 | def __repr__(self) -> str: 110 | """:return: Formatted string representation of the metric.""" 111 | return f"{self.ALIAS[0]}" 112 | 113 | def _compute_one( 114 | self, 115 | pred_answers: str, 116 | ref_answers: List[str] 117 | ) -> float: 118 | """Compute the BERTscore for a pair of predictions and references.""" 119 | P, R, F1 = self.scorer.score([pred_answers] * len(ref_answers), ref_answers) 120 | return F1.max().tolist() 121 | -------------------------------------------------------------------------------- /rageval/metrics/answer_correctness/_answer_bleu.py: -------------------------------------------------------------------------------- 1 | import re 2 | from dataclasses import dataclass 3 | from typing import List, Tuple 4 | import evaluate 5 | import datasets 6 | from rageval.metrics import Metric, add_attribute 7 | from tqdm import tqdm 8 | 9 | 10 | _DESCRIPTION = """\ 11 | BLEU (Bilingual Evaluation Understudy) is an algorithm for evaluating the quality of text which has been machine-translated from one natural language to another. 12 | Scores are calculated by comparing individual translated segments, e.g., sentences, with a set of high-quality reference translations. 13 | Those scores are then averaged over the whole corpus to reach an estimate of the translation's overall quality. 14 | Neither intelligibility nor grammatical correctness are not taken into account. 
15 | 16 | For details, see the paper: http://www.aclweb.org/anthology/P02-1040.pdf 17 | """ 18 | 19 | _KWARGS_DESCRIPTION = """\ 20 | Args: 21 | name : str 22 | batch_size : int, Batch size for openai completion. 23 | 24 | Optional Args: 25 | None 26 | 27 | Functions: 28 | _clean: clean special word in sentence. 29 | _compute_one: compute bleu score for single prediction with its references 30 | 31 | Examples: 32 | >>> from datasets import Dataset 33 | >>> import rageval as rl 34 | >>> sample = { 35 | ... "answers": [ 36 | ... "It is a guide to action which ensures that the military always obeys the commands of the party.", 37 | ... "It is to insure the troops forever hearing the activity guidebook that party direct.", 38 | ... ], 39 | ... "gt_answers": [ 40 | ... [ 41 | ... "It is a guide to action that ensures that the military will forever heed Party commands.", 42 | ... "It is the guiding principle which guarantees the military forces always being under the command of the Party.", 43 | ... "It is the practical guide for the army always to heed the directions of the party.", 44 | ... ], 45 | ... [ 46 | ... "It is a guide to action that ensures that the military will forever heed Party commands.", 47 | ... "It is the guiding principle which guarantees the military forces always being under the command of the Party.", 48 | ... "It is the practical guide for the army always to heed the directions of the party.", 49 | ... ] 50 | ... ], 51 | ... } 52 | >>> dataset = Dataset.from_dict(sample) 53 | >>> metric = rl.metrics.AnswerBleuScore() 54 | >>> metric.mtype 55 | 'AnswerCorrectness' 56 | >>> score, results = metric.compute(dataset["answers"], dataset["gt_answers"], 1) 57 | >>> score 58 | 0.3450835085970013 59 | >>> results[0] 60 | 0.5401725898595141 61 | """ 62 | 63 | 64 | _CITATION = """\ 65 | @misc{Kishore2002bleu, 66 | title={Bleu: a method for automatic evaluation of machine translation}, 67 | author={Kishore Papineni and Salim Roukos and Todd Ward and Wei-Jing Zhu}, 68 | year={2002}, 69 | page={311-318}, 70 | primaryClass={cs.CL} 71 | } 72 | """ 73 | 74 | 75 | @dataclass 76 | @add_attribute('mtype', 'AnswerCorrectness') 77 | @datasets.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION) 78 | class AnswerBleuScore(Metric): 79 | """Bleu score computing with good quality reference.""" 80 | 81 | """Note: this metric is just fit for English data by now(24/03/12)""" 82 | 83 | name = "answer_bleu" 84 | 85 | ALIAS = ['answer_bleu'] 86 | 87 | def __init__(self): 88 | """Explicitly initialize the AnswerBleuScore to ensure all parent class initialized.""" 89 | super().__init__() 90 | self.info = evaluate.MetricInfo( 91 | description=_DESCRIPTION, 92 | inputs_description=_KWARGS_DESCRIPTION, 93 | citation=_CITATION, 94 | homepage="", 95 | features=datasets.Features( 96 | { 97 | "answers": datasets.Value("string"), 98 | "gt_answers": datasets.Sequence(datasets.Value("string")) 99 | } 100 | ), 101 | codebase_urls=[ 102 | "https://github.com/tensorflow/nmt/blob/master/nmt/scripts/bleu.py", 103 | "https://github.com/huggingface/datasets/blob/main/metrics/bleu/bleu.py" 104 | ], 105 | reference_urls=["https://www.aclweb.org/anthology/P02-1040.pdf"] 106 | ) 107 | 108 | def __repr__(self) -> str: 109 | """:return: Formatted string representation of the metric.""" 110 | return f"{self.ALIAS[0]}" # pragma: no cover 111 | 112 | def compute( 113 | self, 114 | pred_answers: List[str], 115 | ref_answers: List[List[str]], 116 | batch_size: int, 117 | ) -> Tuple[float, List[float]]: 118 | 
"""Compute the bleu score on both corpus level and instance level.""" 119 | bleu = evaluate.load("bleu") 120 | # corpus level 121 | bleu_result = bleu.compute(predictions=pred_answers, references=ref_answers) 122 | score = bleu_result['bleu'] 123 | # instance level 124 | scores = [] 125 | for pred_answer, ref_answer in tqdm(zip(pred_answers, ref_answers), 126 | desc=f"Computing {self.name}", 127 | total=len(pred_answers)): 128 | scores.append(self._compute_one(pred_answer, ref_answer)) 129 | return score, scores 130 | 131 | def _compute_one( 132 | self, 133 | pred_answers: List[str], 134 | ref_answers: List[List[str]] 135 | ) -> List[float]: 136 | """Compute the bleu score on an instance level.""" 137 | 138 | bleu = evaluate.load("bleu") 139 | bleu_result = bleu.compute(predictions=[pred_answers], references=[ref_answers]) 140 | bleu_score = bleu_result['bleu'] 141 | return bleu_score 142 | -------------------------------------------------------------------------------- /rageval/metrics/answer_correctness/_answer_claim_recall.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import List, Callable, Tuple 3 | import evaluate 4 | 5 | import datasets 6 | import numpy as np 7 | from tqdm import tqdm 8 | 9 | from rageval.metrics import Metric, add_attribute 10 | from rageval.utils.check_utils import text_to_sents 11 | 12 | _DESCRIPTION = """\ 13 | The AnswerNLICorrectness is to measure the correctness of long-form answers. In the original paper, the author first \ 14 | use Instruct-GPT(text-davinci-003) to generate three "sub-claims" (based on gold answers) and use a state-of-the-art \ 15 | natural-language inference (NLI) model TRUE(Honovich et al., 2022) to check whether the model output entails the \ 16 | sub-claims (claim recall). 17 | 18 | For details, see the paper: https://arxiv.org/abs/2305.14627. 19 | """ 20 | 21 | _KWARGS_DESCRIPTION = """\ 22 | Args: 23 | name : str 24 | batch_size : int, Batch size for openai completion. 25 | 26 | Optional Args: 27 | None 28 | 29 | Functions: 30 | _verify_by_stance: verify whether the stance of args:`claims` can be supported by args:`answer`. 31 | _compute_one: compute the score by measure whether the args:`claims` can be supported by args:`answers`. 32 | 33 | Examples: 34 | >>> from datasets import Dataset 35 | >>> import rageval as rl 36 | >>> sample = { 37 | ... "answers": [ 38 | ... "They went a while before introducing ads, so they could make money, as they needed to establish " 39 | ... "their brand and amass users. Once you have dedicated users, introducing ads won't deter most, but if " 40 | ... "you are still new, having ads will deter a lot. The same goes for Uber, it's not that they aren't " 41 | ... "making money, it's that they are reinvesting a ton of it to make their service better." 42 | ... ], 43 | ... "gt_answers": [ 44 | ... [ 45 | ... "Firms like Snapchat and Uber need to establish their brand and amass users before introducing " 46 | ... "ads.", 47 | ... "Introducing ads too early can deter potential users.", 48 | ... "Uber is reinvesting a lot of money to make their service better." 49 | ... ] 50 | ... ] 51 | ... } 52 | >>> dataset = Dataset.from_dict(sample) 53 | >>> nli_model = rl.models.NLIModel( 54 | ... 'text2text-generation', 55 | ... 'hf-internal-testing/tiny-random-T5ForConditionalGeneration' 56 | ... 
) 57 | >>> metric = rl.metrics.AnswerNLICorrectness(nli_model=nli_model, decompose_model="nltk") 58 | >>> metric.mtype 59 | 'AnswerCorrectness' 60 | >>> score, results = metric.compute(dataset['answers'], dataset['gt_answers'], 1) 61 | >>> assert score == 0 or score == 1 62 | """ 63 | 64 | _CITATION = """\ 65 | @misc{gao2023enabling, 66 | title={Enabling Large Language Models to Generate Text with Citations}, 67 | author={Tianyu Gao and Howard Yen and Jiatong Yu and Danqi Chen}, 68 | year={2023}, 69 | eprint={2305.14627}, 70 | archivePrefix={arXiv}, 71 | primaryClass={cs.CL} 72 | } 73 | """ 74 | 75 | 76 | @dataclass 77 | @add_attribute('mtype', 'AnswerCorrectness') 78 | @datasets.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION) 79 | class AnswerNLICorrectness(Metric): 80 | """Estimates the correctness of long-form answers based on the NLI model.""" 81 | 82 | name = "answer_claim_recall" 83 | 84 | ALIAS = ['answer_claim_recall'] 85 | 86 | def __init__(self, nli_model: Callable, decompose_model: str = "gpt-3.5-turbo"): 87 | """ 88 | Explicitly initialize AnswerNLICorrectness. 89 | 90 | Ensure all parent classes are initialized. 91 | Ensure nli_model and decompose_model is initialized. 92 | """ 93 | super().__init__() 94 | self.nli_model = nli_model 95 | self.decompose_model = decompose_model 96 | self.info = evaluate.MetricInfo( 97 | description=_DESCRIPTION, 98 | inputs_description=_KWARGS_DESCRIPTION, 99 | citation=_CITATION, 100 | homepage="", 101 | features=datasets.Features( 102 | { 103 | "answers": datasets.Value("string"), 104 | "gt_answers": datasets.Value("string") 105 | } 106 | ), 107 | codebase_urls=["https://github.com/princeton-nlp/ALCE"], 108 | reference_urls=["https://arxiv.org/abs/2305.14627"] 109 | ) 110 | 111 | def __repr__(self) -> str: 112 | """:return: Formatted string representation of the metric.""" 113 | return f"{self.ALIAS[0]}" # pragma: no cover 114 | 115 | def _compute_one( 116 | self, 117 | answer: str, 118 | claims: List[str] 119 | ) -> float: 120 | """ 121 | Evaluate the correctness of an answer. 122 | 123 | Firstly, split the gt_answer into a set of claims. 124 | Then, compute the faithfulness score of each claim. The faithfulness is a binary score. 125 | Finally, aggregate all faithfulness score of each claim. 126 | """ 127 | 128 | detail_results = [] 129 | scores = [] 130 | 131 | for i, claim in enumerate(claims): 132 | # obtain the faithfulness of each claim by language inference model. 133 | label = self.nli_model.generate_infer(premise=answer, hypothesis=claim) 134 | detail_results.append({ 135 | "answer": answer, 136 | "claim": claim, 137 | "reasoning": "", 138 | "error": "", 139 | "factuality": label, 140 | }) 141 | scores.append(label) 142 | # Note that the detail_results can be recorded by logger.info 143 | return np.average(scores) 144 | 145 | def _compute_batch( 146 | self, 147 | pred_answers: List[str], 148 | ref_answers: List[List[str]] 149 | ) -> List[float]: 150 | """ 151 | Evaluate the correctness of a batch of answers. 152 | 153 | Firstly, split the gt_answer into a set of claims. 154 | Then, compute the faithfulness score of each claim. The faithfulness is a binary score. 155 | Finally, aggregate all faithfulness score of each claim. 
156 | """ 157 | 158 | if isinstance(ref_answers, list): 159 | if isinstance(ref_answers[0], list): 160 | # gt_answers has been decomposed into claims list 161 | claims = ref_answers 162 | elif isinstance(ref_answers[0], str): 163 | # use decompose_model to decompose the gt_answers into claims list 164 | claims = [text_to_sents(gt_answer, self.decompose_model) for gt_answer in ref_answers] 165 | else: 166 | raise ValueError("The type of gt_answers element should be list or string.") # pragma: no cover 167 | else: 168 | raise ValueError("The type of gt_answers should be list.") # pragma: no cover 169 | 170 | results = [] 171 | for i, answer in tqdm(enumerate(pred_answers)): 172 | r = self._compute_one(answer, claims[i]) 173 | results.append(r) 174 | return results 175 | -------------------------------------------------------------------------------- /rageval/metrics/answer_correctness/_answer_disambig_f1.py: -------------------------------------------------------------------------------- 1 | import re 2 | import string 3 | from collections import Counter 4 | from dataclasses import dataclass 5 | from typing import List 6 | import evaluate 7 | 8 | import datasets 9 | import numpy as np 10 | import spacy 11 | 12 | from rageval.metrics import Metric, add_attribute 13 | 14 | _DESCRIPTION = """\ 15 | The Disambig-F1 is a variant of the F1 score, estimates the similarity between the disambiguation of the answer and the ground truth answer. 16 | 17 | The original metric was presented in [ASQA paper](https://aclanthology.org/2022.emnlp-main.566/), and implemented through [this code](https://github.com/google-research/language/blob/master/language/asqa/scoring.py#L273). And we adopted an [alternative implementation](https://github.com/jzbjyb/FLARE/tree/main/src/datasets.py#L29) from the paper [Active Retrieval Augmented Generation](https://arxiv.org/abs/2305.06983). 18 | """ 19 | 20 | _KWARGS_DESCRIPTION = """\ 21 | Args: 22 | name : str 23 | model : str, model name of spacy model to ner. 24 | 25 | Optional Args: 26 | None 27 | 28 | Functions: 29 | _normalize_text: normalize the text by removing articles, white spaces, punctuations and lowercasing. 30 | _ner: extract named entities from the text. 31 | _validate_data: validate the dataset format. 32 | _f1_score: compute the f1 score between `pred` string and `ref` string. 33 | _compute_one: evaluate the disambig f1 score of between `answer` and `gt_answers`, return the highest score in all pairs. 34 | 35 | Examples: 36 | >>> from datasets import Dataset 37 | >>> import rageval as rl 38 | >>> sample = { 39 | ... "answers": [ 40 | ... "Democrat rick kriseman won the 2016 mayoral election, while re- publican former mayor rick baker did so in the 2017 mayoral election." 41 | ... ], 42 | ... "gt_answers": [ 43 | ... [ 44 | ... "Kriseman", 45 | ... "Rick Kriseman" 46 | ... ] 47 | ... ] 48 | ... 
} 49 | >>> dataset = Dataset.from_dict(sample) 50 | >>> metric = rl.metrics.AnswerDisambigF1Correctness(model="en_core_web_sm") 51 | >>> metric.mtype 52 | 'AnswerCorrectness' 53 | >>> score, results = metric.compute(dataset['answers'], dataset['gt_answers'], 1) 54 | >>> assert 0 <= score <= 1 55 | """ 56 | 57 | _CITATION = """\ 58 | @inproceedings{stelmakh-etal-2022-asqa, 59 | title = "{ASQA}: Factoid Questions Meet Long-Form Answers", 60 | author = "Stelmakh, Ivan and 61 | Luan, Yi and 62 | Dhingra, Bhuwan and 63 | Chang, Ming-Wei", 64 | editor = "Goldberg, Yoav and 65 | Kozareva, Zornitsa and 66 | Zhang, Yue", 67 | booktitle = "Proceedings of the 2022 Conference on Empirical Methods in Natural Language Processing", 68 | month = dec, 69 | year = "2022", 70 | address = "Abu Dhabi, United Arab Emirates", 71 | publisher = "Association for Computational Linguistics", 72 | url = "https://aclanthology.org/2022.emnlp-main.566", 73 | doi = "10.18653/v1/2022.emnlp-main.566", 74 | pages = "8273--8288", 75 | } 76 | @misc{jiang2023active, 77 | title={Active Retrieval Augmented Generation}, 78 | author={Zhengbao Jiang and Frank F. Xu and Luyu Gao and Zhiqing Sun and Qian Liu and Jane Dwivedi-Yu and Yiming Yang and Jamie Callan and Graham Neubig}, 79 | year={2023}, 80 | eprint={2305.06983}, 81 | archivePrefix={arXiv}, 82 | primaryClass={cs.CL} 83 | } 84 | """ 85 | 86 | 87 | @dataclass 88 | @add_attribute('mtype', 'AnswerCorrectness') 89 | @datasets.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION) 90 | class AnswerDisambigF1Correctness(Metric): 91 | """Estimates the Disambig-F1 between answers and ground truth answers.""" 92 | 93 | name = "answer_disambig_f1" 94 | 95 | ALIAS = ['answer_disambig_f1'] 96 | 97 | def __init__(self, model: str = "en_core_web_sm"): 98 | """ 99 | Explicitly initialize AnswerDisambigF1Correctness. 100 | 101 | Ensure all parent classes are initialized. 102 | Ensure spacy ner model is initialized. 
103 | """ 104 | super().__init__() 105 | self.model = model 106 | self.nlp = spacy.load(model) 107 | self.info = evaluate.MetricInfo( 108 | description=_DESCRIPTION, 109 | inputs_description=_KWARGS_DESCRIPTION, 110 | citation=_CITATION, 111 | homepage="", 112 | features=datasets.Features( 113 | { 114 | "answers": datasets.Value("string"), 115 | "gt_answers": datasets.Sequence(datasets.Value("string")) 116 | } 117 | ), 118 | codebase_urls=[ 119 | "https://github.com/google-research/language/blob/master/language/asqa", 120 | "https://github.com/jzbjyb/FLARE" 121 | ], 122 | reference_urls=[ 123 | "https://aclanthology.org/2022.emnlp-main.566", 124 | "https://arxiv.org/abs/2305.06983" 125 | ] 126 | ) 127 | 128 | def __repr__(self) -> str: 129 | """:return: Formatted string representation of the metric.""" 130 | return f"{self.ALIAS[0]}" # pragma: no cover 131 | 132 | def _normalize_text(self, s: str) -> str: 133 | def remove_articles(text): 134 | return re.sub(r'\b(a|an|the)\b', ' ', text) 135 | 136 | def white_space_fix(text): 137 | return ' '.join(text.split()) 138 | 139 | def remove_punc(text): 140 | exclude = set(string.punctuation) 141 | return ''.join(ch for ch in text if ch not in exclude) 142 | 143 | def lower(text): 144 | return text.lower() 145 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 146 | 147 | def _ner(self, s: str) -> List[str]: 148 | """Extract named entities from the text.""" 149 | doc = self.nlp(s) 150 | ents = doc.ents 151 | return [self._normalize_text(e.text) for e in ents] 152 | 153 | def _f1_score(self, pred: str, ref: str) -> float: 154 | """Compute the f1 score between pred and ref.""" 155 | pred_ents = self._ner(pred) 156 | ref_ents = self._ner(ref) 157 | 158 | pred_counter = Counter(pred_ents) 159 | ref_counter = Counter(ref_ents) 160 | 161 | tp = sum((pred_counter & ref_counter).values()) 162 | fp = sum((pred_counter - ref_counter).values()) 163 | fn = sum((ref_counter - pred_counter).values()) 164 | 165 | precision = (tp / (tp + fp)) if (tp + fp) > 0 else 1 166 | recall = (tp / (tp + fn)) if (tp + fn) > 0 else 1 167 | 168 | if precision + recall == 0: 169 | return 0 170 | return 2 * (precision * recall) / (precision + recall) 171 | 172 | def _compute_one( 173 | self, 174 | pred_answer: str, 175 | ref_answers: List[str] 176 | ) -> float: 177 | """Evaluate the disambig f1 score of an answer.""" 178 | return np.max([self._f1_score(pred_answer, ref_answer) for ref_answer in ref_answers]) 179 | -------------------------------------------------------------------------------- /rageval/metrics/answer_correctness/_answer_edit_distance.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import List 3 | import datasets 4 | 5 | from rageval.metrics import Metric, add_attribute 6 | 7 | _DESCRIPTION = """\ 8 | The AnswerEditDistance is to measure the similarity between answer and gt_answer by calculating the edit distance. 9 | 10 | This is a very traditional method, but to this day, some work is still being carried out using it, such as \ 11 | https://ieeexplore.ieee.org/abstract/document/10172590. 12 | """ 13 | 14 | _KWARGS_DESCRIPTION = """\ 15 | Args: 16 | name : str 17 | batch_size : int, Batch size for openai completion. 18 | 19 | Optional Args: 20 | None 21 | 22 | Functions: 23 | _compute_one: evaluating the similarity between answer and gt_answer by calculating the edit distance. 
24 | 25 | Examples: 26 | >>> from datasets import Dataset 27 | >>> import rageval as rl 28 | >>> sample = { 29 | ... "answers": [ 30 | ... "Language models trained on massive code corpora can generalize to tasks without the need " 31 | ... "for task-specific fine tuning." 32 | ... ], 33 | ... "gt_answers": [ 34 | ... "Large language models trained on massive code corpora can generalize to new tasks without the need " 35 | ... "for task-specific fine-tuning." 36 | ... ] 37 | ... } 38 | >>> dataset = Dataset.from_dict(sample) 39 | >>> metric = rl.metrics.AnswerEditDistance() 40 | >>> metric.mtype 41 | 'AnswerCorrectness' 42 | >>> score, results = metric.compute(dataset['answers'], dataset['gt_answers'], 1) 43 | >>> assert score == 5 / 18 44 | """ 45 | 46 | _CITATION = """\ 47 | @INPROCEEDINGS{10172590, 48 | author={Nashid, Noor and Sintaha, Mifta and Mesbah, Ali}, 49 | booktitle={2023 IEEE/ACM 45th International Conference on Software Engineering (ICSE)}, 50 | title={Retrieval-Based Prompt Selection for Code-Related Few-Shot Learning}, 51 | year={2023}, 52 | volume={}, 53 | number={}, 54 | pages={2450-2462}, 55 | doi={10.1109/ICSE48619.2023.00205} 56 | } 57 | """ 58 | 59 | 60 | @dataclass 61 | @add_attribute('mtype', 'AnswerCorrectness') 62 | @datasets.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION) 63 | class AnswerEditDistance(Metric): 64 | """Estimates the similarity between answers and gt_answers.""" 65 | 66 | name = "answer_edit_distance" 67 | 68 | ALIAS = ['answer_edit_distance'] 69 | 70 | def __init__(self): 71 | """ 72 | Explicitly initialize AnswerEditDistance. 73 | 74 | Ensure all parent classes are initialized. 75 | """ 76 | super().__init__() 77 | 78 | def __repr__(self) -> str: 79 | """:return: Formatted string representation of the metric.""" 80 | return f"{self.ALIAS[0]}" 81 | 82 | def _info(self): 83 | return datasets.MetricInfo( 84 | description=_DESCRIPTION, 85 | inputs_description=_KWARGS_DESCRIPTION, 86 | citation=_CITATION, 87 | homepage="", 88 | features=datasets.Features( 89 | { 90 | "answers": datasets.Value("string"), 91 | "gt_answers": datasets.Value("string") 92 | } 93 | ), 94 | codebase_urls=[], 95 | reference_urls=["https://ieeexplore.ieee.org/abstract/document/10172590"] 96 | ) 97 | 98 | def _compute_one( 99 | self, 100 | pred_answer: str, 101 | ref_answer: str 102 | ) -> float: 103 | """Evaluating the similarity between answer and gt_answer by calculating the edit distance.""" 104 | pred_answer = pred_answer.split() 105 | ref_answer = ref_answer.split() 106 | m, n = len(pred_answer), len(ref_answer) 107 | 108 | if m == 0 or n == 0: 109 | return 0 110 | 111 | dp = [[0] * (n + 1) for _ in range(m + 1)] 112 | for i in range(m + 1): 113 | dp[i][0] = i 114 | for j in range(n + 1): 115 | dp[0][j] = j 116 | 117 | for i in range(1, m + 1): 118 | for j in range(1, n + 1): 119 | dp[i][j] = min(dp[i - 1][j] + 1, dp[i][j - 1] + 1) 120 | if pred_answer[i - 1] != ref_answer[j - 1]: 121 | dp[i][j] = min(dp[i][j], dp[i - 1][j - 1] + 1) 122 | else: 123 | dp[i][j] = min(dp[i][j], dp[i - 1][j - 1]) 124 | 125 | return dp[m][n] / m 126 | -------------------------------------------------------------------------------- /rageval/metrics/answer_correctness/_answer_exact_match.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import List 3 | 4 | import datasets 5 | import numpy as np 6 | 7 | from rageval.metrics import Metric, add_attribute 8 | 9 | 10 | _DESCRIPTION = """\ 
11 | AnswerEMCorrectness evaluates answer correctness based on exact matching of annotated short answers. 12 | 13 | For details, see the paper: https://arxiv.org/abs/2204.06092. 14 | """ 15 | 16 | _KWARGS_DESCRIPTION = """\ 17 | Args: 18 | name : str 19 | batch_size : int, Batch size for openai completion. 20 | ignore_case : bool, whether to ignore case when comparing the answer and ground truth answers. 21 | 22 | Optional Args: 23 | None 24 | 25 | Functions: 26 | _compute_one: compute the score by measure whether the args:`answer` contains short answer in list:`gt_answers`. 27 | 28 | Examples: 29 | >>> from datasets import Dataset 30 | >>> import rageval as rl 31 | >>> sample = { 32 | ... "answers": [ 33 | ... "Ali Dael has the highest goals in men's world international football with 109 goals. Josef Bican has " 34 | ... "the highest goals all-time in men's football and Christine Sinclair has the highest goals in women's " 35 | ... "world international football.", 36 | ... "A supercentenarian is someone who has reached the age of 110. Sarah Knauss, whose age is undisputed, " 37 | ... "was the oldest person ever from the United States and the second-oldest fully documented person ever. " 38 | ... "Jeanne Calment was a French supercentenarian and the oldest human whose age is well-documented, with " 39 | ... "a lifespan of 122 years and 164 days, and was the oldest person in the world as of 1997. In 1985, " 40 | ... "the oldest living person was Mathew Beard and in 1986 it was Augusta Holtz, who lived 115 years and " 41 | ... "79 days, from 1871 to 1986." 42 | ... ], 43 | ... "gt_answers": [ 44 | ... [ 45 | ... ["Daei", "Ali Daei"], 46 | ... ["Bican", "Josef Bican"], 47 | ... ["Sinclair","Christine Sinclair"] 48 | ... ], 49 | ... [ 50 | ... ["Jeanne Calment"], 51 | ... ["Sarah Knauss"], 52 | ... ["Augusta-Holtz"], 53 | ... ] 54 | ... ], 55 | ... 
} 56 | >>> dataset = Dataset.from_dict(sample) 57 | >>> metric = rl.metrics.AnswerEMCorrectness() 58 | >>> metric.mtype 59 | 'AnswerCorrectness' 60 | >>> score, results = metric.compute(dataset['answers'], dataset['gt_answers'], 1) 61 | >>> assert 0 <= score <= 1 62 | """ 63 | 64 | _CITATION = """\ 65 | @misc{stelmakh2023asqa, 66 | title={ASQA: Factoid Questions Meet Long-Form Answers}, 67 | author={Ivan Stelmakh and Yi Luan and Bhuwan Dhingra and Ming-Wei Chang}, 68 | year={2023}, 69 | eprint={2204.06092}, 70 | archivePrefix={arXiv}, 71 | primaryClass={cs.CL} 72 | } 73 | """ 74 | 75 | 76 | @dataclass 77 | @add_attribute('mtype', 'AnswerCorrectness') 78 | @datasets.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION) 79 | class AnswerEMCorrectness(Metric): 80 | """Estimates correctness using annotated short answers.""" 81 | 82 | name = "answer_exact_match" 83 | 84 | ALIAS = ['answer_exact_match'] 85 | 86 | def __init__(self, ignore_case: bool = False): 87 | """Explicitly initialize the AnswerEMCorrectness to ensure all parent class initialized.""" 88 | super().__init__() 89 | self.ignore_case = ignore_case 90 | 91 | def __repr__(self) -> str: 92 | """:return: Formatted string representation of the metric.""" 93 | return f"{self.ALIAS[0]}" 94 | 95 | def _info(self): 96 | return datasets.MetricInfo( 97 | description=_DESCRIPTION, 98 | inputs_description=_KWARGS_DESCRIPTION, 99 | citation=_CITATION, 100 | homepage="", 101 | features=datasets.Features( 102 | { 103 | "answers": datasets.Value("string"), 104 | "gt_answers": datasets.Sequence(datasets.Value("string")) 105 | } 106 | ), 107 | codebase_urls=[], 108 | reference_urls=["https://arxiv.org/abs/2204.06092"] 109 | ) 110 | 111 | def _compute_one(self, pred_answer: str, short_answers: List[List[str]]) -> float: 112 | """Compute the correctness of a single answer.""" 113 | acc = [] 114 | if self.ignore_case: 115 | pred_answer = pred_answer.lower() 116 | short_answers = [[a.lower() for a in candidate_short_answers] for candidate_short_answers in short_answers] 117 | for candidate_short_answers in short_answers: 118 | for candidate_short_answer in candidate_short_answers: 119 | if candidate_short_answer in pred_answer: 120 | acc.append(True) 121 | break 122 | else: 123 | acc.append(False) 124 | return np.average(acc) 125 | -------------------------------------------------------------------------------- /rageval/metrics/answer_correctness/_answer_lcs_ratio.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import List 3 | import datasets 4 | 5 | from rageval.metrics import Metric, add_attribute 6 | 7 | _DESCRIPTION = """\ 8 | The AnswerLCSRatio is to measure the similarity between answer and gt_answer by calculating the longest common subsequence. 9 | 10 | This is a very traditional method, but to this day, some work is still being carried out using it, such as \ 11 | https://ieeexplore.ieee.org/abstract/document/10172590. 12 | """ 13 | 14 | _KWARGS_DESCRIPTION = """\ 15 | Args: 16 | name : str 17 | batch_size : int, Batch size for openai completion. 18 | 19 | Optional Args: 20 | None 21 | 22 | Functions: 23 | _compute_one: evaluating the similarity between answer and gt_answer by calculating the longest common subsequence. 24 | 25 | Examples: 26 | >>> from datasets import Dataset 27 | >>> import rageval as rl 28 | >>> sample = { 29 | ... "answers": [ 30 | ... 
"Language models trained on massive code corpora can generalize to tasks without the need " 31 | ... "for task-specific fine-tuning." 32 | ... ], 33 | ... "gt_answers": [ 34 | ... "Large language models trained on massive code corpora can generalize to new tasks without the need " 35 | ... "for task-specific fine-tuning." 36 | ... ] 37 | ... } 38 | >>> dataset = Dataset.from_dict(sample) 39 | >>> metric = rl.metrics.AnswerLCSRatio() 40 | >>> metric.mtype 41 | 'AnswerCorrectness' 42 | >>> score, results = metric.compute(dataset['answers'], dataset['gt_answers'], 1) 43 | >>> assert score == 16 / 17 44 | """ 45 | 46 | _CITATION = """\ 47 | @INPROCEEDINGS{10172590, 48 | author={Nashid, Noor and Sintaha, Mifta and Mesbah, Ali}, 49 | booktitle={2023 IEEE/ACM 45th International Conference on Software Engineering (ICSE)}, 50 | title={Retrieval-Based Prompt Selection for Code-Related Few-Shot Learning}, 51 | year={2023}, 52 | volume={}, 53 | number={}, 54 | pages={2450-2462}, 55 | doi={10.1109/ICSE48619.2023.00205} 56 | } 57 | """ 58 | 59 | 60 | @dataclass 61 | @add_attribute('mtype', 'AnswerCorrectness') 62 | @datasets.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION) 63 | class AnswerLCSRatio(Metric): 64 | """Estimates the similarity between answers and gt_answers.""" 65 | 66 | name = "answer_lcs_ratio" 67 | 68 | ALIAS = ['answer_lcs_ratio'] 69 | 70 | def __init__(self): 71 | """ 72 | Explicitly initialize AnswerLCSRatio. 73 | 74 | Ensure all parent classes are initialized. 75 | """ 76 | super().__init__() 77 | 78 | def __repr__(self) -> str: 79 | """:return: Formatted string representation of the metric.""" 80 | return f"{self.ALIAS[0]}" 81 | 82 | def _info(self): 83 | return datasets.MetricInfo( 84 | description=_DESCRIPTION, 85 | inputs_description=_KWARGS_DESCRIPTION, 86 | citation=_CITATION, 87 | homepage="", 88 | features=datasets.Features( 89 | { 90 | "answers": datasets.Value("string"), 91 | "gt_answers": datasets.Value("string") 92 | } 93 | ), 94 | codebase_urls=[], 95 | reference_urls=["https://ieeexplore.ieee.org/abstract/document/10172590"] 96 | ) 97 | 98 | def _compute_one( 99 | self, 100 | pred_answer: str, 101 | ref_answer: str 102 | ) -> float: 103 | """Evaluating the similarity between answer and gt_answer by calculating the longest common subsequence.""" 104 | pred_answer = pred_answer.split() 105 | ref_answer = ref_answer.split() 106 | m, n = len(pred_answer), len(ref_answer) 107 | 108 | if m == 0 or n == 0: 109 | return 0 110 | 111 | dp = [0] * (n + 1) 112 | for i in range(m): 113 | pre = 0 114 | for j in range(n): 115 | tmp = dp[j + 1] 116 | dp[j + 1] = pre + 1 if pred_answer[i] == ref_answer[j] else max(dp[j + 1], dp[j]) 117 | pre = tmp 118 | 119 | return dp[-1] / m 120 | -------------------------------------------------------------------------------- /rageval/metrics/answer_correctness/_answer_relevancy.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gomate-community/rageval/01e2580efe714b3130602e02a7e7b017d16d79a2/rageval/metrics/answer_correctness/_answer_relevancy.py -------------------------------------------------------------------------------- /rageval/metrics/answer_correctness/_answer_rouge_correctness.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from dataclasses import dataclass 4 | from typing import List, Callable, Optional 5 | 6 | import datasets 7 | from datasets import Dataset 8 | from 
rouge_score import rouge_scorer 9 | 10 | from rageval.metrics import Metric, add_attribute 11 | 12 | _DESCRIPTION = """Estimates ROUGE score by estimating answer and groundtruth answers. 13 | 14 | ROUGE is case insensitive, so the input text is converted to lower case before computing the score. This metrics is a wrapper around the https://github.com/google-research/google-research/blob/master/rouge/rouge_scorer.py 15 | 16 | """ 17 | 18 | _KWARGS_DESCRIPTION = """\ 19 | Args: 20 | name : str 21 | rouge_type : str, the rouge type to calculate. Defaults to 'rouge1', 'rouge2', 'rougeL', 'rougeLsum' 22 | "rouge1": unigram (1-gram) based scoring 23 | "rouge2": bigram (2-gram) based scoring 24 | "rougeL": Longest common subsequence based scoring. 25 | "rougeLSum": splits text using "\n". 26 | 27 | Optional Args: 28 | tokenizer : Callable, a tokenizer can be passed to the scorer, replacing the default tokenizer which tokenizes on whitespace, especially for non-latin languages. For example, the `jieba.cut` can be used for Chinese. 29 | 30 | Functions: 31 | _compute_one: compute the score by measure whether the args:`answer` contains short answer in list:`gt_answers`. 32 | 33 | Examples: 34 | >>> from datasets import Dataset 35 | >>> import rageval as rl 36 | >>> sample = { 37 | ... "answers": [ 38 | ... "Some nanomaterials may give rise to various kinds of lung damage." 39 | ... ], 40 | ... "gt_answers":[ 41 | ... [ 42 | ... "Nanoparticles can penetrate the body, affecting the lungs, brain, and other organs,\ 43 | ... leading to possible respiratory, cardiovascular, and brain health problems.", 44 | ... "Due to their small size, nanoparticles can infiltrate the body and impact vital organs,\ 45 | ... posing risks to respiratory, heart, and neurological health." 46 | ... ] 47 | ... ] 48 | ... 
} 49 | >>> dataset = Dataset.from_dict(sample) 50 | >>> metric = rl.metrics.AnswerRougeCorrectness('rougeL') 51 | >>> score, results = metric.compute(dataset['answers'], dataset['gt_answers'], 1) 52 | >>> assert 0 <= score <= 1 53 | """ 54 | 55 | _CITATION = """\ 56 | @inproceedings{lin-2004-rouge, 57 | title = "{ROUGE}: A Package for Automatic Evaluation of Summaries", 58 | author = "Lin, Chin-Yew", 59 | booktitle = "Text Summarization Branches Out", 60 | month = jul, 61 | year = "2004", 62 | address = "Barcelona, Spain", 63 | publisher = "Association for Computational Linguistics", 64 | url = "https://aclanthology.org/W04-1013", 65 | pages = "74--81", 66 | } 67 | @article{lewis2020retrieval, 68 | title={Retrieval-augmented generation for knowledge-intensive nlp tasks}, 69 | author={Lewis, Patrick and Perez, Ethan and Piktus, Aleksandra and Petroni, Fabio and Karpukhin, Vladimir and Goyal, Naman and K{\"u}ttler, Heinrich and Lewis, Mike and Yih, Wen-tau and Rockt{\"a}schel, Tim and others}, 70 | journal={Advances in Neural Information Processing Systems}, 71 | volume={33}, 72 | pages={9459--9474}, 73 | year={2020} 74 | } 75 | """ 76 | 77 | 78 | @dataclass 79 | @add_attribute('mtype', 'AnswerCorrectness') 80 | @datasets.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION) 81 | class AnswerRougeCorrectness(Metric): 82 | 83 | name = "answer_rouge_correctness" 84 | 85 | ALIAS = ['answer_rouge_correctness'] 86 | 87 | def __init__(self, rouge_type: str, tokenizer: Optional[Callable] = None): 88 | """Explicitly initialize the AnswerRougeCorrectness to ensure all parent class initialized as well as initialize the rouge type and tokenizer.""" 89 | self.rouge_type = rouge_type 90 | self.scorer = rouge_scorer.RougeScorer([rouge_type], use_stemmer=True, tokenizer=tokenizer) 91 | super().__init__() 92 | 93 | def __repr__(self) -> str: 94 | """:return: Formated string representation of the metric.""" 95 | return f"{self.ALIAS[0]}" 96 | 97 | def _info(self): 98 | return datasets.MetricInfo( 99 | description=_DESCRIPTION, 100 | inputs_description=_KWARGS_DESCRIPTION, 101 | citation=_CITATION, 102 | homepage="", 103 | features=datasets.Features( 104 | { 105 | "answers": datasets.Value("string", id="sequence"), 106 | "gt_answers": datasets.Value("string", id="sequence"), 107 | } 108 | ), 109 | codebase_urls=["https://github.com/mim-solutions/rouge_score"], 110 | reference_urls=[ 111 | "https://aclanthology.org/W04-1013/", 112 | "https://arxiv.org/abs/2005.11401" 113 | ] 114 | ) 115 | 116 | def _compute_one(self, pred_answer: str, ref_answers: List[str]) -> float: 117 | """Evaluate the ROUGE between a single answer and groundtruth answers.""" 118 | score = self.scorer.score_multi(ref_answers, pred_answer) 119 | return score[self.rouge_type].fmeasure 120 | -------------------------------------------------------------------------------- /rageval/metrics/answer_correctness/_answer_ter.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import List, Tuple 3 | 4 | import datasets 5 | from sacrebleu.metrics import TER 6 | import numpy as np 7 | 8 | from rageval.metrics import Metric, add_attribute 9 | 10 | _DESCRIPTION = """\ 11 | TER (Translation Edit Rate, also called Translation Error Rate) is a metric to quantify the edit operations that a 12 | hypothesis requires to match a reference translation. 
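Informally, the score is the number of edits (insertions, deletions, substitutions and shifts of word sequences) needed to turn the hypothesis into a reference, divided by the average number of reference words; sacrebleu reports it on a 0-100 scale, so lower values indicate a closer match.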
The implementation is already present in sacrebleu 13 | (https://github.com/mjpost/sacreBLEU#ter), which in turn is inspired by the TERCOM implementation, which can be found 14 | here: https://github.com/jhclark/tercom. 15 | """ 16 | 17 | _KWARGS_DESCRIPTION = """\ 18 | Args: 19 | name : str 20 | normalized (boolean): If `True`, applies basic tokenization and normalization to sentences. Defaults to `False`. 21 | ignore_punct (boolean): If `True`, applies basic tokenization and normalization to sentences. Defaults to `False`. 22 | support_zh_ja_chars (boolean): If `True`, tokenization/normalization supports processing of Chinese characters, 23 | as well as Japanese Kanji, Hiragana, Katakana, and Phonetic Extensions of Katakana. 24 | Only applies if `normalized = True`. Defaults to `False`. 25 | case_sensitive (boolean): If `False`, makes all predictions and references lowercase to ignore differences in case. Defaults to `False`. 26 | 27 | Optional Args: 28 | None 29 | 30 | Functions: 31 | _validate_data: validate the dataset format. 32 | 33 | Examples: 34 | >>> from datasets import Dataset 35 | >>> import rageval as rl 36 | >>> sample = { 37 | ... "answers": [ 38 | ... "does this sentence match??", 39 | ... "what about this sentence?", 40 | ... "What did the TER metric user say to the developer?" 41 | ... ], 42 | ... "gt_answers": [ 43 | ... ["does this sentence match", "does this sentence match!?!"], 44 | ... ["wHaT aBoUt ThIs SeNtEnCe?", "wHaT aBoUt ThIs SeNtEnCe?"], 45 | ... ["Your jokes are...", "...TERrible"] 46 | ... ] 47 | ... } 48 | >>> dataset = Dataset.from_dict(sample) 49 | >>> metric = rl.metrics.AnswerTERCorrectness() 50 | >>> metric.mtype 51 | 'AnswerCorrectness' 52 | >>> score, results = metric.compute(dataset['answers'], dataset['gt_answers']) 53 | >>> assert score == 110.00000000000001 54 | >>> assert results[0] == 25.0 55 | """ 56 | 57 | _CITATION = """\ 58 | @inproceedings{snover-etal-2006-study, 59 | title = "A Study of Translation Edit Rate with Targeted Human Annotation", 60 | author = "Snover, Matthew and 61 | Dorr, Bonnie and 62 | Schwartz, Rich and 63 | Micciulla, Linnea and 64 | Makhoul, John", 65 | booktitle = "Proceedings of the 7th Conference of the Association for Machine Translation in the Americas: Technical Papers", 66 | month = aug # " 8-12", 67 | year = "2006", 68 | address = "Cambridge, Massachusetts, USA", 69 | publisher = "Association for Machine Translation in the Americas", 70 | url = "https://aclanthology.org/2006.amta-papers.25", 71 | pages = "223--231", 72 | } 73 | @inproceedings{post-2018-call, 74 | title = "A Call for Clarity in Reporting {BLEU} Scores", 75 | author = "Post, Matt", 76 | booktitle = "Proceedings of the Third Conference on Machine Translation: Research Papers", 77 | month = oct, 78 | year = "2018", 79 | address = "Belgium, Brussels", 80 | publisher = "Association for Computational Linguistics", 81 | url = "https://www.aclweb.org/anthology/W18-6319", 82 | pages = "186--191", 83 | } 84 | """ 85 | 86 | 87 | @dataclass 88 | @add_attribute('mtype', 'AnswerCorrectness') 89 | @datasets.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION) 90 | class AnswerTERCorrectness(Metric): 91 | """Estimates the TER between answers and ground truth answers.""" 92 | 93 | name = "answer_ter" 94 | 95 | ALIAS = ['answer_ter'] 96 | 97 | def __init__( 98 | self, 99 | normalized: bool = False, 100 | ignore_punct: bool = False, 101 | support_zh_ja_chars: bool = False, 102 | case_sensitive: bool = False 103 | ): 104 | """ 105 | Explicitly 
initialize AnswerTERCorrectness. 106 | 107 | Ensure all parent classes are initialized. 108 | """ 109 | super().__init__() 110 | self.ter = TER( 111 | normalized=normalized, 112 | no_punct=ignore_punct, 113 | asian_support=support_zh_ja_chars, 114 | case_sensitive=case_sensitive 115 | ) 116 | 117 | def __repr__(self) -> str: 118 | """:return: Formatted string representation of the metric.""" 119 | return f"{self.ALIAS[0]}" 120 | 121 | def _info(self): 122 | return datasets.MetricInfo( 123 | description=_DESCRIPTION, 124 | inputs_description=_KWARGS_DESCRIPTION, 125 | citation=_CITATION, 126 | features=datasets.Features( 127 | { 128 | "answers": datasets.Value("string"), 129 | "gt_answers": datasets.Sequence(datasets.Value("string")) 130 | } 131 | ), 132 | codebase_urls=["https://github.com/mjpost/sacreBLEU#ter"], 133 | reference_urls=["https://aclanthology.org/2006.amta-papers.25", "https://www.aclweb.org/anthology/W18-6319"] 134 | ) 135 | 136 | def _validate_data( 137 | self, 138 | pred_answers: List[str], 139 | ref_answers: List[List[str]] 140 | ) -> None: 141 | """Validate the input predictions and references.""" 142 | super()._validate_data(pred_answers, ref_answers) 143 | if not all(isinstance(pred_answer, str) for pred_answer in pred_answers): 144 | raise ValueError("The type of pred_answers should be a list of strings.") 145 | if not all(isinstance(reference_list, list) and all(isinstance(reference, str) for reference in reference_list) for reference_list in ref_answers): 146 | raise ValueError("The type of ref_answers should be a list of lists of strings.") 147 | 148 | def _compute_one( 149 | self, 150 | pred_answer: str, 151 | ref_answers: List[str] 152 | ) -> float: 153 | """Compute the TER score of a single answer.""" 154 | return self.ter.sentence_score(pred_answer, ref_answers).score 155 | 156 | def compute( 157 | self, 158 | pred_answers: List[str], 159 | ref_answers: List[List[str]], 160 | ) -> Tuple[float, List[float]]: 161 | """Evaluate the dataset.""" 162 | self._validate_data(pred_answers, ref_answers) 163 | scores = self._compute_batch(pred_answers, ref_answers) 164 | ref_answers = np.array(ref_answers) 165 | ref_answers = ref_answers.T.tolist() 166 | return self.ter.corpus_score(pred_answers, ref_answers).score, scores 167 | -------------------------------------------------------------------------------- /rageval/metrics/answer_groundedness/_claim_faithfulness.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gomate-community/rageval/01e2580efe714b3130602e02a7e7b017d16d79a2/rageval/metrics/answer_groundedness/_claim_faithfulness.py -------------------------------------------------------------------------------- /rageval/metrics/answer_informativeness/_answer_distinct12.py: -------------------------------------------------------------------------------- 1 | from collections import Counter 2 | from dataclasses import dataclass 3 | from typing import List, Optional, Iterable, Tuple 4 | import datasets 5 | from nltk import ngrams 6 | from rageval.metrics import Metric, add_attribute 7 | 8 | _DESCRIPTION = """\ 9 | Distinct 1/2 measures the diversity of generated text by calculating the ratio of unique n-grams to the total number of n-grams. 10 | """ 11 | 12 | _KWARGS_DESCRIPTION = """\ 13 | Args: 14 | pred_answers (list of str): List of generated texts for which distinct metrics are computed. 15 | n_grams (int): The n-gram order for which distinct metrics are computed. 
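For example, `n_grams=1` gives Distinct-1 (the ratio of unique unigrams to all unigrams) and `n_grams=2` gives Distinct-2.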
16 | 17 | Returns: 18 | dict: Dictionary containing Distinct-1 and Distinct-2 scores. 19 | 20 | Examples: 21 | >>> from datasets import Dataset 22 | >>> import rageval as rl 23 | >>> sample = { 24 | ... "answers": [ 25 | ... "This is the first sentence.", 26 | ... "This is the second sentence." 27 | ... ] 28 | ... } 29 | >>> dataset = Dataset.from_dict(sample) 30 | >>> metric = rl.metrics.AnswerDistinct(1) 31 | >>> metric.mtype 32 | 'AnswerInformativeness' 33 | >>> score, results = metric.compute(dataset['answers']) 34 | >>> score 35 | 0.6 36 | """ 37 | 38 | _CITATION = """\ 39 | @misc{selfmemory2023, 40 | title={Lift Yourself Up: Retrieval-augmented Text Generation with Self Memory}, 41 | author={Xin Cheng and Di Luo and Xiuying Chen and Lemao Liu and Dongyan Zhao and Rui Yan}, 42 | year={2023}, 43 | eprint={2305.02437}, 44 | archivePrefix={arXiv}, 45 | primaryClass={cs.CL} 46 | } 47 | """ 48 | 49 | 50 | def get_distinct_score(pred_answers: List[str], n_grams: int) -> dict: 51 | """Compute Distinct-1 and Distinct-2 metrics.""" 52 | c = Counter() 53 | for answer in pred_answers: 54 | tokens = answer.split() 55 | c.update(ngrams(tokens, n_grams)) 56 | 57 | return len(c) / sum(c.values()) 58 | 59 | 60 | @dataclass 61 | @add_attribute('mtype', 'AnswerInformativeness') 62 | @datasets.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION) 63 | class AnswerDistinct(Metric): 64 | """Distinct 1/2 metric for text generation.""" 65 | 66 | name = "answer_distinct" 67 | 68 | ALIAS = ['answer_distinct'] 69 | 70 | def __init__(self, n_grams: int = 1): 71 | """ 72 | Explicitly initialize Distinct. 73 | 74 | Ensure all parent classes are initialized. 75 | """ 76 | super().__init__() 77 | self.n_grams = n_grams 78 | 79 | def __repr__(self) -> str: 80 | """:return: Formatted string representation of the metric.""" 81 | return f"{self.ALIAS[0]}" 82 | 83 | def _info(self): 84 | return datasets.MetricInfo( 85 | description=_DESCRIPTION, 86 | inputs_description=_KWARGS_DESCRIPTION, 87 | citation=_CITATION, 88 | features=datasets.Features( 89 | { 90 | "pred_answers": datasets.Value("string"), 91 | } 92 | ), 93 | codebase_urls=["https://github.com/Hannibal046/SelfMemory/blob/main/src/utils/metrics_utils.py"], 94 | reference_urls=["https://arxiv.org/abs/2305.02437"] 95 | ) 96 | 97 | def _validate_data( 98 | self, 99 | pred_answers: Optional[Iterable] = None, 100 | ref_answers: Optional[Iterable] = None, 101 | ) -> bool: 102 | """Validate the input data.""" 103 | assert isinstance(pred_answers, str) or isinstance(pred_answers, list) # pragma: no cover 104 | 105 | def compute( 106 | self, 107 | pred_answers: Optional[Iterable] = None, 108 | ) -> Tuple[float, List[float]]: 109 | """ 110 | Evaluate the dataset. 111 | 112 | Return average scores of all inputs and a score list for each example. 
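As a minimal sketch of the underlying ratio (assuming the whitespace tokenization used by `get_distinct_score` in this module):

>>> get_distinct_score(["the cat sat", "the cat ran"], n_grams=1)
0.6666666666666666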
113 | """ 114 | return get_distinct_score(pred_answers, self.n_grams), [get_distinct_score([pred_answer], self.n_grams) for pred_answer in pred_answers] 115 | -------------------------------------------------------------------------------- /rageval/metrics/answer_informativeness/_claim_num.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gomate-community/rageval/01e2580efe714b3130602e02a7e7b017d16d79a2/rageval/metrics/answer_informativeness/_claim_num.py -------------------------------------------------------------------------------- /rageval/metrics/answer_informativeness/_pairwise_accuracy.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gomate-community/rageval/01e2580efe714b3130602e02a7e7b017d16d79a2/rageval/metrics/answer_informativeness/_pairwise_accuracy.py -------------------------------------------------------------------------------- /rageval/metrics/answer_informativeness/_repetitiveness.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gomate-community/rageval/01e2580efe714b3130602e02a7e7b017d16d79a2/rageval/metrics/answer_informativeness/_repetitiveness.py -------------------------------------------------------------------------------- /rageval/metrics/answer_informativeness/_text_length.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Optional, Iterable 3 | from transformers import AutoTokenizer 4 | import evaluate 5 | 6 | import datasets 7 | 8 | from rageval.metrics import Metric, add_attribute 9 | 10 | 11 | _DESCRIPTION = """\ 12 | Textlength is a metric used to evaluate the length of a model-generated response. 13 | 14 | It measures the number of tokens in the generated text by first converting the text into tokens and then counting the total number. This metric provides insight into the verbosity or conciseness of the model's output, offering a standardized way to compare text length across different responses. 15 | """ 16 | 17 | _KWARGS_DESCRIPTION = """\ 18 | Args: 19 | name : str 20 | 21 | Optional Args: 22 | None 23 | 24 | Functions: 25 | _compute_one: Evaluating the length of answer. 26 | 27 | Examples: 28 | >>> from datasets import Dataset 29 | >>> import rageval as rl 30 | >>> sample = { 31 | ... "answers": [ 32 | ... "A", 33 | ... "C", 34 | ... ] 35 | ... } 36 | >>> dataset = Dataset.from_dict(sample) 37 | >>> metric = TextLength(tokenize_model="Qwen/Qwen2-0.5B-Instruct") 38 | >>> metric.mtype 39 | 'answer_informativeness' 40 | """ 41 | 42 | 43 | @dataclass 44 | @add_attribute('mtype', 'answer_informativeness') 45 | @datasets.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION) 46 | class TextLength(Metric): 47 | """Estimates the text length of answers.""" 48 | 49 | name = "text_length" 50 | 51 | ALIAS = ['text_length'] 52 | 53 | def __init__(self, tokenize_model: str = "Qwen/Qwen2-0.5B-Instruct"): 54 | """ 55 | Explicitly initialize TextLength. 56 | 57 | Ensure all parent classes are initialized. 
58 | """ 59 | self.tokenizer = AutoTokenizer.from_pretrained(tokenize_model) 60 | super().__init__() 61 | self.info = evaluate.MetricInfo( 62 | description=_DESCRIPTION, 63 | inputs_description=_KWARGS_DESCRIPTION, 64 | citation="", 65 | homepage="", 66 | features=datasets.Features( 67 | { 68 | "answers": datasets.Value("string"), 69 | } 70 | ), 71 | codebase_urls=[], 72 | reference_urls=[] 73 | ) 74 | 75 | def __repr__(self) -> str: 76 | """:return: Formatted string representation of the metric.""" 77 | return f"{self.ALIAS[0]}" # pragma: no cover 78 | 79 | def _compute_one( 80 | self, 81 | answer: str, 82 | *args: Optional[Iterable], 83 | ) -> float: 84 | """Evaluating the text length of answer.""" 85 | length = len(self.tokenizer(answer, return_tensors="pt")['input_ids'][0]) 86 | return length 87 | -------------------------------------------------------------------------------- /rageval/metrics/base.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple, Callable, Optional, Iterable 2 | from abc import abstractmethod 3 | from dataclasses import dataclass 4 | 5 | import numpy as np 6 | from langchain.schema import LLMResult 7 | from tqdm import tqdm 8 | 9 | 10 | def add_attribute(attribute_name, attribute_value): 11 | """ 12 | This decorate is used to set attribute for Class. 13 | 14 | Currently, this decorate can be used to set attr:metric_type for each metric. 15 | There are four types, i.e., 'AnswerCorrectness', 'AnswerGroundedness', 'ContextRelevancy', 'ContextAdequacy', \ 16 | for all RAG metrics. 17 | """ 18 | def decorator(cls): 19 | setattr(cls, attribute_name, attribute_value) 20 | return cls 21 | return decorator 22 | 23 | 24 | @dataclass 25 | class Metric(): 26 | """Metric base class without LLM.""" 27 | 28 | def __init__( 29 | self, 30 | config_name: Optional[str] = None, 31 | experiment_id: Optional[str] = None 32 | ): 33 | """Initialization. 34 | 35 | Args: 36 | config_name: type(string), Optional. 37 | experiment_id: type(string), Optional. 38 | """ # pragma: no cover 39 | 40 | @property 41 | @abstractmethod 42 | def name(self) -> str: 43 | """The metric name.""" 44 | ... # pragma: no cover 45 | 46 | def _validate_data( 47 | self, 48 | pred_answers: Optional[Iterable] = None, 49 | ref_answers: Optional[Iterable] = None, 50 | *args: Optional[Iterable] 51 | ) -> None: 52 | """Validate the of the input dataset.""" 53 | if (pred_answers and ref_answers): 54 | if len(pred_answers) != len(ref_answers) or any(len(pred_answers) != len(arg) for arg in args): 55 | raise ValueError("The length of predictions and references should be the same.") # pragma: no cover 56 | 57 | def compute( 58 | self, 59 | pred_answers: Optional[Iterable] = None, 60 | ref_answers: Optional[Iterable] = None, 61 | batch_size: Optional[int] = None, 62 | *args: Optional[Iterable], 63 | ) -> Tuple[float, List[float]]: 64 | """ 65 | Evaluate the dataset. 66 | 67 | Return average scores of all inputs and a score list for each example. 68 | """ 69 | self._validate_data(pred_answers, ref_answers, *args) 70 | scores = self._compute_batch(pred_answers, ref_answers, *args) 71 | 72 | return np.average(scores), scores 73 | 74 | @abstractmethod 75 | def _compute_one( 76 | self, 77 | pred_answer: Optional[Iterable] = None, 78 | ref_answer: Optional[Iterable] = None, 79 | *args: Optional[Iterable] 80 | ) -> float: 81 | ... 
# pragma: no cover 82 | 83 | def _compute_batch( 84 | self, 85 | pred_answers: Optional[Iterable] = None, 86 | ref_answers: Optional[Iterable] = None, 87 | *args: Optional[Iterable] 88 | ) -> List[float]: 89 | """Compute the metric for a batch of predictions and references.""" 90 | scores = [] 91 | if (pred_answers and ref_answers): # if both columns exist 92 | for pred_answer, ref_answer in tqdm(zip(pred_answers, ref_answers), 93 | desc=f"Computing {self.name}", 94 | total=len(pred_answers)): 95 | scores.append(self._compute_one(pred_answer, ref_answer)) 96 | else: 97 | for pred_answer in tqdm(pred_answers, 98 | desc=f"Computing {self.name}", 99 | total=len(pred_answers)): 100 | scores.append(self._compute_one(pred_answer)) 101 | return scores 102 | 103 | 104 | @dataclass 105 | class MetricWithLLM(Metric): 106 | """Metrics based on LLM.""" 107 | 108 | def __init__(self, model: Callable): 109 | """Initialization.""" 110 | super().__init__() 111 | self.llm = model 112 | 113 | @abstractmethod 114 | def parse_llm_result(self, prompts: List[str], result: LLMResult): 115 | """Parse the LLM Result based on the Prompt.""" 116 | ... # pragma: no cover 117 | -------------------------------------------------------------------------------- /rageval/metrics/context_relevance/_accuracy.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gomate-community/rageval/01e2580efe714b3130602e02a7e7b017d16d79a2/rageval/metrics/context_relevance/_accuracy.py -------------------------------------------------------------------------------- /rageval/metrics/context_relevance/_hit_rate.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gomate-community/rageval/01e2580efe714b3130602e02a7e7b017d16d79a2/rageval/metrics/context_relevance/_hit_rate.py -------------------------------------------------------------------------------- /rageval/metrics/context_relevance/_mrr.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gomate-community/rageval/01e2580efe714b3130602e02a7e7b017d16d79a2/rageval/metrics/context_relevance/_mrr.py -------------------------------------------------------------------------------- /rageval/metrics/context_relevance/_ndcg.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gomate-community/rageval/01e2580efe714b3130602e02a7e7b017d16d79a2/rageval/metrics/context_relevance/_ndcg.py -------------------------------------------------------------------------------- /rageval/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .openai import OpenAILLM 2 | from .nli import NLIModel -------------------------------------------------------------------------------- /rageval/models/base.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import typing as t 4 | from abc import ABC 5 | 6 | from langchain.schema.output import LLMResult 7 | 8 | if t.TYPE_CHECKING: 9 | from langchain.callbacks.base import Callbacks 10 | from langchain.prompts import ChatPromptTemplate 11 | 12 | 13 | class BaseLLM(ABC): 14 | """ 15 | BaseLLM is the base class for all LLMs. 16 | 17 | It provides a consistent interface for other classes that interact with LLMs like Langchains, LlamaIndex, \ 18 | LiteLLM etc. 
Handles multiple_completions even if not supported by the LLM. 19 | 20 | It currently takes in ChatPromptTemplates and returns LLMResults which are Langchain primitives. 21 | """ 22 | 23 | # supports multiple completions for the given prompt 24 | n_completions_supported: bool = False 25 | 26 | @property 27 | def llm(self) -> t.Any: 28 | """LLM model.""" 29 | ... 30 | 31 | def validate_api_key(self): 32 | """Validates that the api key is set for the LLM.""" 33 | pass 34 | 35 | def generate( 36 | self, 37 | prompts: list[ChatPromptTemplate], 38 | n: int = 1, 39 | temperature: float = 1e-8, 40 | callbacks: t.Optional[Callbacks] = None, 41 | ) -> LLMResult: 42 | """Call the llm model to generate results.""" 43 | ... 44 | -------------------------------------------------------------------------------- /rageval/models/nli.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from abc import ABC 3 | 4 | import pytest 5 | from transformers import pipeline 6 | 7 | logger = logging.getLogger(__name__) 8 | 9 | 10 | class NLIModel(ABC): 11 | """This is the Roberta-based NLI model.""" 12 | 13 | def __init__(self, task: str = "sentiment-analysis", model: str = "roberta-large-mnli") -> None: 14 | """Init the Roberta Model.""" 15 | self._model_name = model 16 | self._model = pipeline(task=task, model=model, device_map="auto") 17 | 18 | self._labelmap = { 19 | "NEUTRAL": 3, 20 | "CONTRADICTION": 2, 21 | "ENTAILMENT": 1 22 | } 23 | self._nli2stance = { 24 | "NEUTRAL": "irrelevant", 25 | "CONTRADICTION": "refute", 26 | "ENTAILMENT": "support" 27 | } 28 | self._stancemap = { 29 | 'irrelevant': 3, 30 | 'refute': 2, 31 | 'partially-support': 1, 32 | 'completely-support': 1 33 | } 34 | 35 | @property 36 | def model(self): 37 | """Construct the OpenAI LLM model.""" 38 | return self._model 39 | 40 | @pytest.mark.api 41 | def infer_prob(self, premise, hypothesis): 42 | """Predict one sample with NLI model.""" 43 | try: 44 | if len(premise) > 200: 45 | premise = premise[:200] 46 | if len(hypothesis) > 200: 47 | hypothesis = hypothesis[:200] 48 | input = "{}{}".format(premise, hypothesis) 49 | pred = self._model(input) 50 | # print(pred) 51 | except Exception as e: 52 | # token length > 514 53 | L = len(premise) 54 | premise = premise[:int(L / 2)] 55 | input = "{}{}".format(premise, hypothesis) 56 | pred = self._model(input) 57 | logger.info(f"An exception occurred during nli inference: {e}") 58 | return pred 59 | 60 | @pytest.mark.api 61 | def infer(self, premise, hypothesis): 62 | """Predict one sample with NLI model.""" 63 | pred = self.infer_prob(premise, hypothesis) 64 | # [{'label': 'CONTRADICTION', 'score': 0.9992701411247253}] 65 | if 'mnli' in self._model_name: 66 | return self._nli2stance[pred[0]['label']] 67 | else: 68 | nli2stance = { 69 | "LABEL_0": "irrelevant", 70 | "LABEL_1": "support" 71 | } 72 | return nli2stance[pred[0]['label']] 73 | 74 | @pytest.mark.api 75 | def generate_infer(self, premise, hypothesis): 76 | """Predict one sample with NLI model.""" 77 | input_text = "premise: {} hypothesis: {}".format(premise, hypothesis) 78 | pred = self._model(input_text, max_new_tokens=10) 79 | # [{'generated_text': 'support'}] 80 | if pred[0]["generated_text"] == "1": 81 | return 1 82 | return 0 83 | -------------------------------------------------------------------------------- /rageval/tasks/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import BaseTask 2 | from ._generate import Generate 3 
| -------------------------------------------------------------------------------- /rageval/tasks/_generate.py: -------------------------------------------------------------------------------- 1 | from typing import Union, List 2 | 3 | from rageval.metrics import Metric 4 | from rageval.tasks import BaseTask 5 | 6 | 7 | class Generate(BaseTask): 8 | name = 'Generator' 9 | # Define required columns in testset for the evaluation of the task 10 | required_columns = ['questions', 'answers', 'gt_answers'] 11 | 12 | def __init__(self, metrics: Union[str, List[str], List[Metric]]): 13 | """Init task""" 14 | 15 | super().__init__(metrics) 16 | -------------------------------------------------------------------------------- /rageval/tasks/base.py: -------------------------------------------------------------------------------- 1 | """Base task.""" 2 | 3 | from abc import ABC, abstractmethod 4 | from dataclasses import dataclass 5 | from typing import Union, Type, List 6 | 7 | from datasets import Dataset, concatenate_datasets 8 | 9 | from rageval.metrics import Metric 10 | 11 | 12 | @dataclass 13 | class BaseTask(ABC): 14 | """Base Task, shouldn't be used directly.""" 15 | 16 | def __init__(self, metrics: Union[str, List[str], List[Metric]]): 17 | """Base task constructor.""" 18 | self.detailed_result = [] 19 | self.result = {} 20 | self.metrics = self._init_metrics(metrics) 21 | 22 | @property 23 | @abstractmethod 24 | def name(self) -> str: 25 | """The task name.""" 26 | ... 27 | 28 | def _init_metrics(self, metrics): 29 | if not metrics: 30 | raise ValueError("metrics should not be empty") 31 | if isinstance(metrics, str): 32 | metrics = [metrics] 33 | return [self._parse_metric(m) for m in metrics] 34 | 35 | def _parse_metric(self, metric: Union[str, Type[Metric], Metric]): 36 | """ 37 | Parse input metric in any form into a :class:`Metric` instance. 38 | 39 | :param metric: Input metric in any form. 
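For instance, `self._parse_metric(AnswerAccuracy())` returns the instance unchanged, while `self._parse_metric(AnswerAccuracy)` instantiates the class; passing an alias string such as 'answer_accuracy' is lower-cased but not yet resolved (see the TODO below). The metric names here are purely illustrative.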
40 | :return: A :class:`Metric` instance 41 | 42 | """ 43 | 44 | if isinstance(metric, str): 45 | metric = metric.lower() # ignore case 46 | 47 | # TODO: parse metrics in str form 48 | """ 49 | for subclass in Metric.__subclasses__(): 50 | if metric in subclass.ALIAS: 51 | return subclass() 52 | """ 53 | 54 | elif isinstance(metric, Metric): 55 | return metric 56 | elif issubclass(metric, Metric): 57 | return metric() 58 | else: 59 | raise ValueError(metric) 60 | 61 | def evaluate(self, testset) -> Dataset: 62 | """Evaluation each metrics.""" 63 | 64 | self._validate_columns(testset) 65 | for m in self.metrics: 66 | res, de_res = m.compute(testset) 67 | self.result[m.name] = res 68 | self.detailed_result.append(de_res) 69 | return self.result 70 | 71 | def _validate_columns(self, testset: Dataset): 72 | """Make sure columns in testset is subset of required columns.""" 73 | 74 | if not set(self.required_columns).issubset(set(testset.column_names)): 75 | print("Testset should contain following columns: ", ', '.join(self.required_columns)) 76 | raise ValueError(testset) 77 | 78 | def obtain_detailed_result(self): 79 | """Obtain instance level result for the test case.""" 80 | 81 | if not self.detailed_result: 82 | raise NameError(self.detailed_result) 83 | colnames = self.required_columns + [m.name for m in self.metrics] 84 | self.detailed_result = concatenate_datasets(self.detailed_result).select_columns(colnames) 85 | return self.detailed_result 86 | -------------------------------------------------------------------------------- /rageval/utils/RAGAS_prompt.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """All prompts used for RAGAS related metrics.""" 3 | 4 | # Most of the prompt in this file comes from the RAGAS project: 5 | # https://github.com/explodinggradients/ragas 6 | 7 | CONTEXT_RECALL_RA = """ 8 | Given a context, and an answer, analyze each sentence in the answer and classify if the sentence can be attributed to the given context or not. Use only "Yes" (1) or "No" (0) as a binary classification. Output json with reason. 9 | 10 | 11 | question: What can you tell me about albert Albert Einstein? 12 | context: Albert Einstein (14 March 1879 – 18 April 1955) was a German-born theoretical physicist,widely held to be one of the greatest and most influential scientists of all time. Best known for developing the theory of relativity, he also made important contributions to quantum mechanics, and was thus a central figure in the revolutionary reshaping of the scientific understanding of nature that modern physics accomplished in the first decades of the twentieth century. His mass–energy equivalence formula E = mc2, which arises from relativity theory, has been called "the world's most famous equation". He received the 1921 Nobel Prize in Physics "for his services to theoretical physics, and especially for his discovery of the law of the photoelectric effect", a pivotal step in the development of quantum theory. His work is also known for its influence on the philosophy of science. In a 1999 poll of 130 leading physicists worldwide by the British journal Physics World, Einstein was ranked the greatest physicist of all time. His intellectual achievements and originality have made Einstein synonymous with genius. 13 | answer: Albert Einstein born in 14 March 1879 was German-born theoretical physicist, widely held to be one of the greatest and most influential scientists of all time. 
He received the 1921 Nobel Prize in Physics "for his services to theoretical physics. He published 4 papers in 1905. Einstein moved to Switzerland in 1895 14 | classification: 15 | [ 16 | {{ "statement_1":"Albert Einstein, born on 14 March 1879, was a German-born theoretical physicist, widely held to be one of the greatest and most influential scientists of all time.", 17 | "reason": "The date of birth of Einstein is mentioned clearly in the context.", 18 | "Attributed": "1" 19 | }}, 20 | {{ 21 | "statement_2":"He received the 1921 Nobel Prize in Physics 'for his services to theoretical physics.", 22 | "reason": "The exact sentence is present in the given context.", 23 | "Attributed": "1" 24 | }}, 25 | {{ 26 | "statement_3": "He published 4 papers in 1905.", 27 | "reason": "There is no mention about papers he wrote in the given context.", 28 | "Attributed": "0" 29 | }}, 30 | {{ 31 | "statement_4":"Einstein moved to Switzerland in 1895.", 32 | "reason": "There is no supporting evidence for this in the given context.", 33 | "Attributed": "0" 34 | }} 35 | ] 36 | 37 | question: who won 2020 icc world cup? 38 | context: Who won the 2022 ICC Men's T20 World Cup? 39 | The 2022 ICC Men's T20 World Cup, held from October 16 to November 13, 2022, in Australia, was the eighth edition of the tournament. Originally scheduled for 2020, it was postponed due to the COVID-19 pandemic. England emerged victorious, defeating Pakistan by five wickets in the final to clinch their second ICC Men's T20 World Cup title. 40 | answer: England 41 | classification: 42 | [ 43 | {{ 44 | "statement_1":"England won the 2022 ICC Men's T20 World Cup.", 45 | "reason": "From context it is clear that England defeated Pakistan to win the World Cup.", 46 | "Attributed": "1" 47 | }} 48 | ] 49 | 50 | question:{question} 51 | context:{context} 52 | answer:{answer} 53 | classification: 54 | """ 55 | -------------------------------------------------------------------------------- /rageval/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .check_utils import text_to_sents, remove_citations 2 | from .RAGAS_prompt import CONTEXT_RECALL_RA 3 | -------------------------------------------------------------------------------- /rageval/utils/check_utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import re 3 | from typing import List 4 | 5 | import nltk 6 | from nltk.downloader import Downloader 7 | 8 | from rageval.models import OpenAILLM 9 | from .prompt import DOC_TO_SENTENCES_PROMPT 10 | 11 | logger = logging.getLogger(__name__) 12 | if not Downloader().is_installed('punkt_tab'): 13 | nltk.download('punkt_tab') 14 | 15 | 16 | def text_to_sents(text: str, model_name="nltk") -> List[str]: 17 | """Convert the text into a set of sentences.""" 18 | sentences = [] 19 | if model_name == "nltk": 20 | sentences = nltk.sent_tokenize(text) 21 | sentences = [s.strip() for s in sentences if len(s.strip()) >= 3] 22 | 23 | elif model_name == "gpt-3.5-turbo": 24 | model = OpenAILLM("gpt-3.5-turbo", "OPENAI_API_KEY") # pragma: no cover 25 | prompt = DOC_TO_SENTENCES_PROMPT # pragma: no cover 26 | input_str = prompt.format(doc=text).strip() # pragma: no cover 27 | r = model.generate([input_str]) # pragma: no cover 28 | sentences = eval(r) # pragma: no cover 29 | else: 30 | logger.info("The parameter `model_name` should be in [`nltk`, `gpt-3.5-turbo`]. 
") # pragma: no cover 31 | assert isinstance(sentences, list) 32 | 33 | return sentences 34 | 35 | 36 | def remove_citations(text: str) -> str: 37 | """Remove the citation in the text.""" 38 | return re.sub(r"\[\d+", "", re.sub(r" \[\d+", "", text)).replace(" |", "").replace("]", "") 39 | -------------------------------------------------------------------------------- /rageval/utils/utility.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import json 4 | import os 5 | import typing as t 6 | import warnings 7 | from dataclasses import dataclass 8 | from functools import lru_cache 9 | 10 | from langchain.callbacks.manager import CallbackManager, trace_as_chain_group 11 | from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate 12 | 13 | if t.TYPE_CHECKING: 14 | from rageval.models import ragevalLLM 15 | 16 | DEBUG_ENV_VAR = "rageval_DEBUG" 17 | # constant to tell us that there is no key passed to the llm/embeddings 18 | NO_KEY = "no-key" 19 | 20 | 21 | @lru_cache(maxsize=1) 22 | def get_debug_mode() -> bool: 23 | """Get debug mode.""" 24 | if os.environ.get(DEBUG_ENV_VAR, str(False)).lower() == "true": 25 | return True 26 | else: 27 | return False 28 | 29 | 30 | def load_as_json(text): 31 | """Validate and return given text as json.""" 32 | 33 | try: 34 | return json.loads(text) 35 | except ValueError as e: 36 | warnings.warn(f"Invalid json: {e}") 37 | 38 | return {} 39 | 40 | 41 | JSON_PROMPT = HumanMessagePromptTemplate.from_template( 42 | """ 43 | 44 | Rewrite the input into valid json 45 | 46 | 47 | Input: 48 | {{ 49 | "name": "John Doe", 50 | "age": 30, 51 | "isStudent": false 52 | "address": {{ 53 | "street": "123 Main St", 54 | "city": "Anytown", 55 | "state": "CA", 56 | }} 57 | "hobbies": ["reading", "swimming", "cycling"] 58 | }} 59 | Output: 60 | {{ 61 | "name": "John Doe", 62 | "age": 30, 63 | "isStudent": false, 64 | "address": {{ 65 | "street": "123 Main St", 66 | "city": "Anytown", 67 | "state": "CA" 68 | }}, 69 | "hobbies": ["reading", "swimming", "cycling"] 70 | }} 71 | 72 | 73 | Input: 74 | {{ 75 | "statement": "The Earth is also known as "Terra" " 76 | }} 77 | Output: 78 | {{ 79 | "statement": "The Earth is also known as 'Terra'" 80 | }} 81 | 82 | Input: 83 | {input} 84 | 85 | Output: 86 | """ 87 | ) 88 | 89 | 90 | @dataclass 91 | class JsonLoader: 92 | """This class is for .... 
(wenshan fix).""" 93 | 94 | max_retries: int = 2 95 | 96 | def safe_load(self, text: str, llm: ragevalLLM): 97 | """Load json in safety mode.""" 98 | retry = 0 99 | while retry <= self.max_retries: 100 | try: 101 | start, end = self._find_outermost_json(text) 102 | return json.loads(text[start:end]) 103 | except ValueError: 104 | text = self._fix_to_json(text, llm) 105 | retry += 1 106 | 107 | return {} 108 | 109 | def _fix_to_json( 110 | self, 111 | text, 112 | llm, 113 | callbacks: t.Optional[CallbackManager] = None, 114 | callback_group_name: str = "batch", 115 | ): 116 | # TODO (executor) 117 | with trace_as_chain_group( 118 | callback_group_name, callback_manager=callbacks 119 | ) as batch_group: 120 | human_prompt = ChatPromptTemplate.from_messages( 121 | [JSON_PROMPT.format(input=text)] 122 | ) 123 | results = llm.generate( 124 | [human_prompt], 125 | n=1, 126 | callbacks=batch_group, 127 | ) 128 | return results.generations[0][0].text 129 | 130 | def _find_outermost_json(self, text): 131 | stack = [] 132 | start_index = -1 133 | 134 | for i, char in enumerate(text): 135 | if char in "{[": 136 | if len(stack) == 0: 137 | start_index = i 138 | stack.append(char) 139 | 140 | elif char in "}]": 141 | if len(stack) > 0: 142 | last = stack.pop() 143 | if (char == "}" and last != "{") or (char == "]" and last != "["): 144 | # Mismatched closing brace/bracket, invalid JSON 145 | break 146 | 147 | if len(stack) == 0 and start_index != -1: 148 | # Found a valid outermost JSON 149 | return ( 150 | start_index, 151 | i + 1, 152 | ) # Add 1 to include the closing brace/bracket in the range 153 | 154 | return -1, -1 # No valid JSON found 155 | 156 | 157 | json_loader = JsonLoader() 158 | -------------------------------------------------------------------------------- /rageval/validation.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gomate-community/rageval/01e2580efe714b3130602e02a7e7b017d16d79a2/rageval/validation.py -------------------------------------------------------------------------------- /rageval/version.py: -------------------------------------------------------------------------------- 1 | """Rageval version file.""" 2 | 3 | __version__ = '0.0.1' 4 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | refchecker == 0.2.13 2 | numpy >= 1.26 3 | tqdm >= 4.66 4 | hyperopt >= 0.1.1 5 | h5py >= 2.8.0 6 | coverage >= 4.3.4 7 | codecov >= 2.0.15 8 | pytest >= 3.7.4 9 | pytest-cov >= 2.4.0 10 | flake8 >= 7.0.0 11 | flake8_docstrings >= 1.7.0 12 | pydocstyle >= 6.1 13 | openai >= 1.10.0 14 | datasets >= 3.0.1 15 | langchain >= 0.3.1 16 | langchain-community >= 0.3.1 17 | transformers >= 4.37.2 18 | torch >= 2.2.0 19 | pandas >= 2.0.0 20 | nltk >= 3.9.1 21 | spacy >= 3.7.4 22 | en_core_web_sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.7.1/en_core_web_sm-3.7.1-py3-none-any.whl 23 | rouge_score >= 0.1.2 24 | accelerate >= 0.27.2 25 | sentencepiece >= 0.2.0 26 | protobuf >= 4.25.3 27 | sacrebleu >= 2.3.3 28 | bert_score >= 0.3.13 29 | jieba >= 0.42.1 30 | evaluate >= 0.4.3 31 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import io 2 | import os 3 | 4 | from setuptools import setup, find_packages 5 | 6 | 7 | here = 
os.path.abspath(os.path.dirname(__file__)) 8 | 9 | # Avoids IDE errors, but actual version is read from version.py 10 | __version__ = None 11 | exec(open('rageval/version.py').read()) 12 | 13 | short_description = 'Evaluation tools for Retrieval-augmented Generation (RAG) methods.' 14 | 15 | # Get the long description from the README file 16 | with io.open(os.path.join(here, 'README.md'), encoding='utf-8') as f: 17 | long_description = f.read() 18 | 19 | install_requires = [ 20 | 'numpy >= 1.14', 21 | 'tqdm >= 4.23.4', 22 | 'hyperopt >= 0.1.1', 23 | 'h5py >= 2.8.0', 24 | 'openai == 1.10.0', 25 | 'datasets == 2.16.1', 26 | 'langchain == 0.1.4', 27 | 'transformers == 4.37.2', 28 | 'torch == 2.2.0', 29 | 'pandas == 2.0.0', 30 | 'nltk == 3.8.1', 31 | 'spacy == 3.7.4', 32 | 'en_core_web_sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.7.1/en_core_web_sm-3.7.1-py3-none-any.whl', 33 | 'rouge_score == 0.1.2', 34 | 'sacrebleu == 2.3.3' 35 | ] 36 | 37 | extras_requires = { 38 | 'tests': [ 39 | 'coverage >= 4.3.4', 40 | 'codecov >= 2.0.15', 41 | 'pytest >= 3.7.4', 42 | 'pytest-cov >= 2.4.0', 43 | 'flake8 == 7.0.0', 44 | 'pydocstyle == 6.1', 45 | 'flake8_docstrings >= 1.7.0' 46 | ], 47 | 'benchmarks': [ 48 | 'accelerate == 0.27.2', 49 | 'sentencepiece == 0.2.0', 50 | 'protobuf == 4.25.3' 51 | ] 52 | } 53 | 54 | 55 | setup( 56 | name="RagEval", 57 | version=__version__, 58 | author="Wenshan Wang, Yixing Fan, etc.", 59 | author_email="wangwenshan@ict.ac.cn", 60 | description=short_description, 61 | license="Apache 2.0", 62 | keywords="RAG evaluation tools", 63 | url="https://github.com/gomate-community/rageval", 64 | packages=find_packages(), 65 | long_description=long_description, 66 | long_description_content_type='text/markdown', 67 | classifiers=[ 68 | "Development Status :: 3 - Alpha", 69 | 'Environment :: Console', 70 | 'Operating System :: POSIX :: Linux', 71 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 72 | "License :: OSI Approved :: Apache Software License", 73 | 'Programming Language :: Python :: 3.6' 74 | ], 75 | install_requires=install_requires, 76 | extras_require=extras_requires 77 | ) 78 | -------------------------------------------------------------------------------- /tests/demo.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # import sys 4 | # sys.path.insert(0, '../src') 5 | 6 | # from rageval.evaluations import evaluate 7 | """ 8 | from .src.tasks import ( 9 | retriever, 10 | ) 11 | from .src.metrics import ( 12 | context_recall, 13 | ) 14 | """ 15 | 16 | 17 | # from rageval.metrics import ContextRecall 18 | from datasets import Dataset 19 | # import os 20 | 21 | 22 | # 准备您的huggingface数据集,格式如下 23 | # Dataset({ 24 | # features: ['question', 'contexts', 'answer', 'ground_truths'], 25 | # num_rows: 25 26 | # }) 27 | 28 | # 模拟生成待测评数据 29 | questions = ["恐龙是怎么被命名的?"] 30 | ground_truths = [ 31 | [ 32 | "1841年,英国科学家理查德·欧文在研究几块样子像蜥蜴骨头化石时,认为它们是某种史前动物留下来的,并命名为恐龙,意思是“恐怖的蜥蜴”。" 33 | ] 34 | ] 35 | answers = ["人从恐龙进化而来"] 36 | contexts = [ 37 | [ 38 | "[12]恐龙是 介于冷血和温血之间的动物2014年6月,有关恐龙究竟是像鸟类和哺乳动物一样的温血动物,还是类似爬行动物、鱼类和两栖动物的冷血动物的问题终于有了答案——恐龙其实是介于冷血和温血之间的动物。 " 39 | "[12]“我们的结果显示恐龙所具有的生长速率和新陈代谢速率,既不是冷血生物体也不是温血生物体所具有的特征。它们既不像哺乳动物或者鸟类,也不像爬行动物或者鱼类,而是介于现代冷血动物和温血动物之间。" 40 | "简言之,它们的生理机能在现代社会并不常见。”美国亚利桑那大学进化生物学家和生态学家布莱恩·恩奎斯特说。墨西哥生物学家表示,正是这种中等程度的新陈代谢使得恐龙可以长得比任何哺乳动物都要大。" 41 | "温血动物需要大量进食,因此它们频繁猎捕和咀嚼植物。“很难想象霸王龙大小的狮子能够吃饱以 存活下来。","[12]哺乳动物起源于爬行动物,它们的前身是“似哺乳类的爬行动物”,即兽孔目,早期则是“似爬行类的哺乳动物”," 42 | 
"即哺乳型动物。 [12]中生代的爬行动物,大部分在中生代的末期灭绝了;一部分适应了变化的环境被保留下来,即现存的爬行动物(如龟鳖类、蛇类、鳄类等);还有一部分沿着不同的进化方向,进化成了现今的鸟类和哺乳类。 " 43 | "[12]恐龙是 介于冷血和温血之间的动物2014年6月,有关恐龙究竟是像鸟类和哺乳动物一样的温血动物,还是类似爬行动物、鱼类和两栖动物的冷血动物的问题终于有了答案——恐龙其实是介于冷血和温血之间的动物。" 44 | ] 45 | ] 46 | 47 | # To dict 48 | data = { 49 | "question": questions, 50 | "answer": answers, 51 | "contexts": contexts, 52 | "ground_truths": ground_truths 53 | } 54 | 55 | # Convert dict to dataset 56 | dataset = Dataset.from_dict(data) 57 | 58 | # dataset: Dataset 59 | 60 | # results = evaluate(dataset, task='retriever', metrics=['context_recall']) 61 | # results = evaluate(dataset, metrics=[ContextRecall()]) 62 | # print(results) 63 | -------------------------------------------------------------------------------- /tests/test_evaluation.py: -------------------------------------------------------------------------------- 1 | """Test the evaluation function.""" 2 | 3 | import sys 4 | 5 | import pytest 6 | from datasets import load_dataset, Dataset 7 | from langchain.llms.fake import FakeListLLM 8 | 9 | import rageval as rl 10 | from rageval.models import NLIModel 11 | 12 | sys.path.insert(0, '../src') 13 | 14 | 15 | @pytest.mark.skip 16 | def test_evaluation(): 17 | """ 18 | This is test unit for testing the load_dataset function. 19 | """ 20 | 21 | # 1) init test task: task type, metrics 22 | 23 | # 2) load dataset, and extract testset 24 | # train_data = rageval.datasets.load_data('', task='') 25 | # assert len(train_data) == 300 26 | 27 | # 3) run evaluation 28 | # result = evaluate(testset, info) 29 | 30 | ds = load_dataset("explodinggradients/fiqa", "ragas_eval")["baseline"] 31 | ds = ds.rename_column("question", "questions") 32 | ds = ds.rename_column("answer", "answers") 33 | ds = ds.rename_column("ground_truths", "gt_answers") 34 | 35 | # crop answers longer than 300 words, since tiny nli model has maximum sequence length of 500 36 | def truncate_answer(example): 37 | max_length = 100 38 | answers = [] 39 | gt_answers = [] 40 | """ 41 | for a in example["answers"]: 42 | answers.append([c[:max_length] if len(c) > max_length else c for c in a]) 43 | example["answers"] = answers 44 | """ 45 | for ga in example["gt_answers"]: 46 | gt_answers.append([q[:max_length] if len(q) > max_length else q for q in ga]) 47 | example["gt_answers"] = gt_answers 48 | return example 49 | ds = ds.map(truncate_answer, batched=True) 50 | 51 | # define model for each metric 52 | cr_model = FakeListLLM( 53 | responses=[ 54 | '[\n {\n "statement_1":"恐龙的命名始于1841年,由英国科学家理查德·欧文命名。",\n "reason": "The answer provides ' 55 | 'the exact year and the scientist who named the dinosaurs.",\n "Attributed": "1"\n },\n {\n' 56 | ' "statement_2":"欧文在研究几块样子像蜥蜴骨头化石时,认为它们是某种史前动物留下来的,并命名为恐龙。",\n "reason": "The answer ' 57 | 'accurately describes the process of how dinosaurs were named.",\n "Attributed": "1"\n }\n]' 58 | ] 59 | ) 60 | ag_model = NLIModel( 61 | 'text2text-generation', 62 | 'hf-internal-testing/tiny-random-T5ForConditionalGeneration' 63 | ) 64 | 65 | # define metrics 66 | metrics = [ 67 | rl.metrics.ContextRecall(cr_model), 68 | rl.metrics.AnswerNLICorrectness(nli_model=ag_model, decompose_model="nltk") 69 | ] 70 | 71 | # define task 72 | task = rl.tasks.Generate(metrics=metrics) 73 | 74 | # run evaluate 75 | result = task.evaluate(ds) 76 | assert isinstance(result, dict) 77 | detailed_result = task.obtain_detailed_result() 78 | assert isinstance(detailed_result, Dataset) 79 | -------------------------------------------------------------------------------- /tests/units/test_answer_accuracy.py: 
-------------------------------------------------------------------------------- 1 | import pytest 2 | from datasets import Dataset 3 | 4 | from rageval.metrics import AnswerAccuracy 5 | 6 | 7 | @pytest.fixture(scope='module') 8 | def sample(): 9 | test_case = { 10 | "answers": [ 11 | "A", 12 | "B", 13 | "C" 14 | ], 15 | "gt_answers": [ 16 | "A", 17 | "C", 18 | "C" 19 | ] 20 | } 21 | return test_case 22 | 23 | 24 | @pytest.fixture(scope='module') 25 | def testset(sample): 26 | ds = Dataset.from_dict(sample) 27 | return ds 28 | 29 | 30 | @pytest.mark.slow 31 | def test_case_on_answer_accuracy(testset): 32 | metric = AnswerAccuracy() 33 | assert metric.name == "answer_accuracy" 34 | assert metric.mtype == 'AnswerCorrectness' 35 | assert repr(metric) == "answer_accuracy" 36 | score, results = metric.compute(testset["answers"], testset["gt_answers"], 1) 37 | assert score == 2 / 3 38 | assert results[0] is True 39 | -------------------------------------------------------------------------------- /tests/units/test_answer_bert_score.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from datasets import Dataset 3 | 4 | from rageval.metrics import AnswerBERTScore 5 | 6 | 7 | @pytest.fixture(scope='module') 8 | def sample(): 9 | test_case = { 10 | "answers": [ 11 | "It is a guide to action which ensures that the military always obeys the commands of the party.", 12 | "It is to insure the troops forever hearing the activity guidebook that party direct." 13 | ], 14 | "gt_answers": [ 15 | [ 16 | "It is a guide to action that ensures that the military will forever heed Party commands.", 17 | "It is the guiding principle which guarantees the military forces always being under the command of the Party.", 18 | "It is the practical guide for the army always to heed the directions of the party." 19 | ], 20 | [ 21 | "It is a guide to action that ensures that the military will forever heed Party commands.", 22 | "It is the guiding principle which guarantees the military forces always being under the command of the Party.", 23 | "It is the practical guide for the army always to heed the directions of the party." 24 | ] 25 | ] 26 | } 27 | return test_case 28 | 29 | 30 | @pytest.fixture(scope='module') 31 | def testset(sample): 32 | ds = Dataset.from_dict(sample) 33 | return ds 34 | 35 | 36 | @pytest.mark.slow 37 | def test_case_on_answer_bert_score(testset): 38 | metric = AnswerBERTScore(lang='en', rescale_with_baseline=True) 39 | assert metric.name == "answer_bert_score" 40 | assert metric.mtype == 'AnswerCorrectness' 41 | assert repr(metric) == "answer_bert_score" 42 | score, results = metric.compute(testset['answers'], testset['gt_answers'], 1) 43 | assert round(score, 2) == 0.55 44 | assert round(results[0], 1) == 0.7 45 | -------------------------------------------------------------------------------- /tests/units/test_answer_bleu.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from datasets import Dataset 3 | 4 | from rageval.metrics import AnswerBleuScore 5 | 6 | 7 | @pytest.fixture(scope='module') 8 | def sample(): 9 | test_case = { 10 | "answers": [ 11 | "It is a guide to action which ensures that the military always obeys the commands of the party.", 12 | "It is to insure the troops forever hearing the activity guidebook that party direct." 
13 | ], 14 | "gt_answers": [ 15 | [ 16 | "It is a guide to action that ensures that the military will forever heed Party commands.", 17 | "It is the guiding principle which guarantees the military forces always being under the command of the Party.", 18 | "It is the practical guide for the army always to heed the directions of the party." 19 | ], 20 | [ 21 | "It is a guide to action that ensures that the military will forever heed Party commands.", 22 | "It is the guiding principle which guarantees the military forces always being under the command of the Party.", 23 | "It is the practical guide for the army always to heed the directions of the party." 24 | ] 25 | ] 26 | } 27 | return test_case 28 | 29 | 30 | @pytest.fixture(scope='module') 31 | def testset(sample): 32 | ds = Dataset.from_dict(sample) 33 | return ds 34 | 35 | 36 | @pytest.mark.slow 37 | def test_case_on_answer_bleu(testset): 38 | metric = AnswerBleuScore() 39 | assert metric.name == "answer_bleu" 40 | assert metric.mtype == 'AnswerCorrectness' 41 | assert repr(metric) == "answer_bleu" 42 | score, results = metric.compute(testset['answers'], testset['gt_answers'], 1) 43 | assert score == 0.3450835085970013 44 | assert results[0] == 0.5401725898595141 45 | -------------------------------------------------------------------------------- /tests/units/test_answer_chrf.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from datasets import Dataset 3 | 4 | from rageval.metrics import AnswerCHRFCorrectness 5 | 6 | 7 | @pytest.fixture(scope='module') 8 | def sample(): 9 | test_case = { 10 | "answers": [ 11 | "The relationship between cats and dogs is not exactly friendly.", 12 | "a good bookshop is just a genteel black hole that knows how to read." 13 | ], 14 | "gt_answers": [ 15 | ["The relationship between dogs and cats is not exactly friendly.", ], 16 | ["A good bookshop is just a genteel Black Hole that knows how to read."] 17 | ] 18 | } 19 | return test_case 20 | 21 | 22 | @pytest.fixture(scope='module') 23 | def testset(sample): 24 | ds = Dataset.from_dict(sample) 25 | return ds 26 | 27 | 28 | @pytest.mark.slow 29 | def test_case_on_answer_chrf(testset): 30 | metric = AnswerCHRFCorrectness() 31 | assert metric.name == "answer_chrf" 32 | assert metric.mtype == 'AnswerCorrectness' 33 | assert repr(metric) == "answer_chrf" 34 | score, results = metric.compute(testset['answers'], testset['gt_answers'], 1) 35 | assert score == 84.64214891738334 36 | assert results[0] == 84.41131092011067 37 | -------------------------------------------------------------------------------- /tests/units/test_answer_citation_precision.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from datasets import Dataset 3 | 4 | from rageval.models import NLIModel 5 | from rageval.metrics import AnswerCitationPrecision 6 | 7 | 8 | @pytest.fixture(scope='module') 9 | def sample(): 10 | test_case = { 11 | "answers": [ 12 | "Several places on Earth claim to be the most rainy, such as Lloró, Colombia, which reported an average " 13 | "annual rainfall of 12,717 mm between 1952 and 1989, and López de Micay, Colombia, which reported an " 14 | "annual 12,892 mm between 1960 and 2012 [3]. 
However, the official record is held by Mawsynram, India " 15 | "with an average annual rainfall of 11,872 mm [3], although nearby town Sohra, India, also known as " 16 | "Cherrapunji, holds the record for most rain in a calendar month for July 1861 and most rain in a year " 17 | "from August 1860 to July 1861 [1]." 18 | ], 19 | "contexts": [ 20 | [ 21 | "Cherrapunji Cherrapunji (; with the native name Sohra being more commonly used, and can also be " 22 | "spelled Cherrapunjee or Cherrapunji) is a subdivisional town in the East Khasi Hills district in " 23 | "the Indian state of Meghalaya. It is the traditional capital of aNongkhlaw \"hima\" (Khasi tribal " 24 | "chieftainship constituting a petty state), both known as Sohra or Churra. Cherrapunji has often been " 25 | "credited as being the wettest place on Earth, but for now nearby Mawsynram currently holds that " 26 | "distinction. Cherrapunji still holds the all-time record for the most rainfall in a calendar month " 27 | "for July 1861 and most rain in a year from August 1860 to July 1861, however: it received in", 28 | "Radio relay station known as Akashvani Cherrapunji. It broadcasts on FM frequencies. Cherrapunji " 29 | "Cherrapunji (; with the native name Sohra being more commonly used, and can also be spelled " 30 | "Cherrapunjee or Cherrapunji) is a subdivisional town in the East Khasi Hills district in the Indian " 31 | "state of Meghalaya. It is the traditional capital of aNongkhlaw \"hima\" (Khasi tribal chieftainship " 32 | "constituting a petty state), both known as Sohra or Churra. Cherrapunji has often been credited as " 33 | "being the wettest place on Earth, but for now nearby Mawsynram currently holds that distinction. " 34 | "Cherrapunji still holds the all-time record for the most rainfall", 35 | "Mawsynram Mawsynram () is a village in the East Khasi Hills district of Meghalaya state in " 36 | "north-eastern India, 65 kilometres from Shillong. Mawsynram receives one of the highest rainfalls " 37 | "in India. It is reportedly the wettest place on Earth, with an average annual rainfall of 11,872 mm, " 38 | "but that claim is disputed by Lloró, Colombia, which reported an average yearly rainfall of 12,717 mm " 39 | "between 1952 and 1989 and López de Micay, also in Colombia, which reported an annual 12,892 mm per " 40 | "year between 1960 and 2012. According to the \"Guinness Book of World Records\", Mawsynram received " 41 | "of rainfall in 1985. 
Mawsynram is located at 25° 18′" 42 | ] 43 | ] 44 | } 45 | return test_case 46 | 47 | 48 | @pytest.fixture(scope='module') 49 | def testset(sample): 50 | ds = Dataset.from_dict(sample) 51 | return ds 52 | 53 | 54 | @pytest.mark.slow 55 | def test_answer_citation_precision(testset): 56 | nli_model = NLIModel( 57 | 'text2text-generation', 58 | 'hf-internal-testing/tiny-random-T5ForConditionalGeneration' 59 | ) 60 | metric = AnswerCitationPrecision(nli_model=nli_model) 61 | assert metric.name == "answer_citation_precision" 62 | assert metric.mtype == 'AnswerGroundedness' 63 | assert repr(metric) == "answer_citation_precision" 64 | score, results = metric.compute(testset['answers'], testset['contexts'], 1) 65 | assert 0 <= score <= 1 66 | -------------------------------------------------------------------------------- /tests/units/test_answer_citation_recall.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from datasets import Dataset 3 | 4 | from rageval.models import NLIModel 5 | from rageval.metrics import AnswerCitationRecall 6 | 7 | 8 | @pytest.fixture(scope='module') 9 | def sample(): 10 | test_case = { 11 | "answers": [ 12 | "Several places on Earth claim to be the most rainy, such as Lloró, Colombia, which reported an average " 13 | "annual rainfall of 12,717 mm between 1952 and 1989, and López de Micay, Colombia, which reported an " 14 | "annual 12,892 mm between 1960 and 2012 [3]. However, the official record is held by Mawsynram, India " 15 | "with an average annual rainfall of 11,872 mm [3], although nearby town Sohra, India, also known as " 16 | "Cherrapunji, holds the record for most rain in a calendar month for July 1861 and most rain in a year " 17 | "from August 1860 to July 1861 [1]." 18 | ], 19 | "contexts": [ 20 | [ 21 | "Cherrapunji Cherrapunji (; with the native name Sohra being more commonly used, and can also be " 22 | "spelled Cherrapunjee or Cherrapunji) is a subdivisional town in the East Khasi Hills district in " 23 | "the Indian state of Meghalaya. It is the traditional capital of aNongkhlaw \"hima\" (Khasi tribal " 24 | "chieftainship constituting a petty state), both known as Sohra or Churra. Cherrapunji has often been " 25 | "credited as being the wettest place on Earth, but for now nearby Mawsynram currently holds that " 26 | "distinction. Cherrapunji still holds the all-time record for the most rainfall in a calendar month " 27 | "for July 1861 and most rain in a year from August 1860 to July 1861, however: it received in", 28 | "Radio relay station known as Akashvani Cherrapunji. It broadcasts on FM frequencies. Cherrapunji " 29 | "Cherrapunji (; with the native name Sohra being more commonly used, and can also be spelled " 30 | "Cherrapunjee or Cherrapunji) is a subdivisional town in the East Khasi Hills district in the Indian " 31 | "state of Meghalaya. It is the traditional capital of aNongkhlaw \"hima\" (Khasi tribal chieftainship " 32 | "constituting a petty state), both known as Sohra or Churra. Cherrapunji has often been credited as " 33 | "being the wettest place on Earth, but for now nearby Mawsynram currently holds that distinction. " 34 | "Cherrapunji still holds the all-time record for the most rainfall", 35 | "Mawsynram Mawsynram () is a village in the East Khasi Hills district of Meghalaya state in " 36 | "north-eastern India, 65 kilometres from Shillong. Mawsynram receives one of the highest rainfalls " 37 | "in India. 
It is reportedly the wettest place on Earth, with an average annual rainfall of 11,872 mm, " 38 | "but that claim is disputed by Lloró, Colombia, which reported an average yearly rainfall of 12,717 mm " 39 | "between 1952 and 1989 and López de Micay, also in Colombia, which reported an annual 12,892 mm per " 40 | "year between 1960 and 2012. According to the \"Guinness Book of World Records\", Mawsynram received " 41 | "of rainfall in 1985. Mawsynram is located at 25° 18′" 42 | ] 43 | ] 44 | } 45 | return test_case 46 | 47 | 48 | @pytest.fixture(scope='module') 49 | def testset(sample): 50 | ds = Dataset.from_dict(sample) 51 | return ds 52 | 53 | 54 | @pytest.mark.slow 55 | def test_answer_citation_recall(testset): 56 | nli_model = NLIModel( 57 | 'text2text-generation', 58 | 'hf-internal-testing/tiny-random-T5ForConditionalGeneration' 59 | ) 60 | metric = AnswerCitationRecall(nli_model=nli_model) 61 | assert metric.name == "answer_citation_recall" 62 | assert metric.mtype == 'AnswerGroundedness' 63 | score, results = metric.compute(testset['answers'], testset['contexts'], 1) 64 | assert 0 <= score <= 1 65 | -------------------------------------------------------------------------------- /tests/units/test_answer_claim_recall.py: -------------------------------------------------------------------------------- 1 | """Test the AnswerClaimRecall Metric.""" 2 | 3 | import pytest 4 | from datasets import Dataset 5 | 6 | from rageval.models import NLIModel 7 | from rageval.metrics import AnswerNLICorrectness 8 | 9 | 10 | @pytest.fixture(scope='module') 11 | def sample(): 12 | test_case = { 13 | "answers": [ 14 | "Yes. Did you watch The Social Network? They went a while before introducing ads, so they could make " 15 | "money, as they needed to establish their brand and amass users. Once you have dedicated users, " 16 | "introducing ads won't deter most, but if you are still new, having ads will deter a lot. The same goes " 17 | "for Uber, it's not that they aren't making money, it's that they are reinvesting a ton of it to make " 18 | "their service better." 19 | ], 20 | "gt_answers": [ 21 | [ 22 | "Firms like Snapchat and Uber need to establish their brand and amass users before introducing ads.", 23 | "Introducing ads too early can deter potential users.", 24 | "Uber is reinvesting a lot of money to make their service better." 25 | ] 26 | ] 27 | } 28 | return test_case 29 | 30 | 31 | @pytest.fixture(scope='module') 32 | def sample_with_decompose(): 33 | test_case = { 34 | "answers": [ 35 | "Yes. Did you watch The Social Network? They went a while before introducing ads, so they could make \ 36 | money, as they needed to establish their brand and amass users. Once you have dedicated users, \ 37 | introducing ads won't deter most, but if you are still new, having ads will deter a lot. The same goes \ 38 | for Uber, it's not that they aren't making money, it's that they are reinvesting a ton of it to make \ 39 | their service better." 40 | ], 41 | "gt_answers": [ 42 | "Firms like Snapchat and Uber need to establish their brand and amass users before introducing ads. \ 43 | Introducing ads too early can deter potential users. Uber is reinvesting a lot of money to make their \ 44 | service better." 
45 | ] 46 | } 47 | return test_case 48 | 49 | 50 | @pytest.fixture(scope='module') 51 | def testset(sample): 52 | ds = Dataset.from_dict(sample) 53 | return ds 54 | 55 | 56 | @pytest.fixture(scope='module') 57 | def testset_with_decompose(sample_with_decompose): 58 | ds = Dataset.from_dict(sample_with_decompose) 59 | return ds 60 | 61 | 62 | @pytest.mark.slow 63 | def test_case_on_answer_claim_recall_metric(testset): 64 | nli_model = NLIModel( 65 | 'text2text-generation', 66 | 'hf-internal-testing/tiny-random-T5ForConditionalGeneration' 67 | ) 68 | metric = AnswerNLICorrectness(nli_model=nli_model) 69 | assert metric.name == "answer_claim_recall" 70 | assert metric.mtype == 'AnswerCorrectness' 71 | score, results = metric.compute(testset['answers'], testset['gt_answers'], 1) 72 | assert score == 0 or score == 1 73 | 74 | 75 | @pytest.mark.slow 76 | def test_case_on_answer_claim_recall_metric_with_decompose(testset_with_decompose): 77 | nli_model = NLIModel( 78 | 'text2text-generation', 79 | 'hf-internal-testing/tiny-random-T5ForConditionalGeneration' 80 | ) 81 | metric = AnswerNLICorrectness(nli_model=nli_model, decompose_model="nltk") 82 | assert metric.name == "answer_claim_recall" 83 | assert metric.mtype == 'AnswerCorrectness' 84 | score, results = metric.compute(testset_with_decompose['answers'], testset_with_decompose['gt_answers'], 1) 85 | assert score == 0 or score == 1 86 | -------------------------------------------------------------------------------- /tests/units/test_answer_disambig_f1.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from datasets import Dataset 3 | 4 | from rageval.metrics import AnswerDisambigF1Correctness 5 | 6 | 7 | @pytest.fixture(scope='module') 8 | def sample(): 9 | test_case = { 10 | "answers": [ 11 | "Ali Daei holds the record for the most international goals in world football according to FIFA. Josef Bican holds the record for the most total goals in world football according to UEFA.", 12 | ], 13 | "gt_answers": [ 14 | ["Ali Dael has the highest goals in men's world international football with 109 goals. Josef Bican has the highest goals all-time in men's football and Christine Sinclair has the highest goals in women's world international football.", 15 | "The players with the highest all-time goals and highest men's and women's international football goals differ. The player with the highest all-time men's football goals is Josef Bican, who in 2020 was recognized by FIFA, the international governing body of football, as the record scorer with an estimated 805 goals. Christine Sinclair has the highest goals in women's international football with 187 and is the all-time leader for international goals scored for men or women. 
Cristiano Ronaldo and Ali Daei are currently tied for leading goalscorer in the history of men's international football with 109."], 16 | ] 17 | } 18 | return test_case 19 | 20 | 21 | @pytest.fixture(scope='module') 22 | def testset(sample): 23 | ds = Dataset.from_dict(sample) 24 | return ds 25 | 26 | 27 | @pytest.mark.slow 28 | def test_case_on_answer_disambig_f1(testset): 29 | metric = AnswerDisambigF1Correctness() 30 | assert metric.name == "answer_disambig_f1" 31 | assert metric.mtype == 'AnswerCorrectness' 32 | score, results = metric.compute(testset['answers'], testset['gt_answers'], 1) 33 | assert 0 <= score <= 1 34 | -------------------------------------------------------------------------------- /tests/units/test_answer_distinct.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from datasets import Dataset 3 | 4 | from rageval.metrics import AnswerDistinct 5 | 6 | 7 | @pytest.fixture(scope='module') 8 | def sample(): 9 | test_case = { 10 | "answers": [ 11 | "Ali Dael has the highest goals in men's world international football with 109 goals. Josef Bican has the \ 12 | highest goals all-time in men's football and Christine Sinclair has the highest goals in women's world \ 13 | international football.", 14 | "A supercentenarian is someone who has reached the age of 110. Sarah Knauss, whose age is undisputed, was \ 15 | the oldest person ever from the United States and the second-oldest fully documented person ever. Jeanne \ 16 | Calment was a French supercentenarian and the oldest human whose age is well-documented, with a lifespan \ 17 | of 122 years and 164 days, and was the oldest person in the world as of 1997." 18 | ] 19 | } 20 | return test_case 21 | 22 | 23 | @pytest.fixture(scope='module') 24 | def testset(sample): 25 | ds = Dataset.from_dict(sample) 26 | return ds 27 | 28 | 29 | @pytest.mark.slow 30 | def test_case_on_answer_distinct(testset): 31 | metric = AnswerDistinct(n_grams=1) 32 | assert metric.name == "answer_distinct" 33 | assert repr(metric) == 'answer_distinct' 34 | assert metric.mtype == 'AnswerInformativeness' 35 | score, results = metric.compute(pred_answers=testset['answers']) 36 | assert 0 <= score <= 1 37 | 38 | metric = AnswerDistinct(n_grams=2) 39 | assert metric.name == "answer_distinct" 40 | assert repr(metric) == 'answer_distinct' 41 | assert metric.mtype == 'AnswerInformativeness' 42 | score, results = metric.compute(pred_answers=testset['answers']) 43 | assert 0 <= score <= 1 -------------------------------------------------------------------------------- /tests/units/test_answer_edit_distance.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from datasets import Dataset 3 | 4 | from rageval.metrics import AnswerEditDistance 5 | 6 | 7 | @pytest.fixture(scope='module') 8 | def sample(): 9 | test_case = { 10 | "answers": [ 11 | "Language models trained on massive code corpora can generalize to tasks without the need " 12 | "for task-specific fine tuning." 13 | ], 14 | "gt_answers": [ 15 | "Large language models trained on massive code corpora can generalize to new tasks without the need " 16 | "for task-specific fine-tuning." 
17 | ] 18 | } 19 | return test_case 20 | 21 | 22 | @pytest.fixture(scope='module') 23 | def testset(sample): 24 | ds = Dataset.from_dict(sample) 25 | return ds 26 | 27 | 28 | @pytest.mark.slow 29 | def test_case_on_answer_edit_distance(testset): 30 | metric = AnswerEditDistance() 31 | assert metric.name == "answer_edit_distance" 32 | assert metric.mtype == 'AnswerCorrectness' 33 | score, results = metric.compute(testset['answers'], testset['gt_answers'], 1) 34 | assert score == 5 / 18 35 | -------------------------------------------------------------------------------- /tests/units/test_answer_exect_match.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from datasets import Dataset 3 | 4 | from rageval.metrics import AnswerEMCorrectness 5 | 6 | 7 | @pytest.fixture(scope='module') 8 | def sample(): 9 | test_case = { 10 | "answers": [ 11 | "Ali Dael has the highest goals in men's world international football with 109 goals. Josef Bican has the " 12 | "highest goals all-time in men's football and Christine Sinclair has the highest goals in women's world " 13 | "international football.", 14 | "A supercentenarian is someone who has reached the age of 110. Sarah Knauss, whose age is undisputed, was " 15 | "the oldest person ever from the United States and the second-oldest fully documented person ever. Jeanne " 16 | "Calment was a French supercentenarian and the oldest human whose age is well-documented, with a lifespan " 17 | "of 122 years and 164 days, and was the oldest person in the world as of 1997. In 1985, the oldest living " 18 | "person was Mathew Beard and in 1986 it was Augusta Holtz, who lived 115 years and 79 days, from 1871 to " 19 | "1986." 20 | ], 21 | "gt_answers": [ 22 | [ 23 | ["Daei", "Ali Daei"], 24 | ["Bican", "Josef Bican"], 25 | ["Sinclair", "Christine Sinclair"] 26 | ], 27 | [ 28 | ["Jeanne Calment"], 29 | ["Sarah Knauss"], 30 | ["Augusta-Holtz"], 31 | ] 32 | ] 33 | } 34 | return test_case 35 | 36 | 37 | @pytest.fixture(scope='module') 38 | def testset(sample): 39 | ds = Dataset.from_dict(sample) 40 | return ds 41 | 42 | 43 | @pytest.mark.slow 44 | def test_case_on_answer_exact_match(testset): 45 | metric = AnswerEMCorrectness() 46 | assert metric.name == "answer_exact_match" 47 | assert metric.mtype == 'AnswerCorrectness' 48 | score, results = metric.compute(testset['answers'], testset['gt_answers'], 1) 49 | assert 0 <= score <= 1 50 | 51 | metric = AnswerEMCorrectness(ignore_case=True) 52 | assert metric.name == "answer_exact_match" 53 | assert metric.mtype == 'AnswerCorrectness' 54 | score, results = metric.compute(testset['answers'], testset['gt_answers'], 1) 55 | assert 0 <= score <= 1 56 | -------------------------------------------------------------------------------- /tests/units/test_answer_f1.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from datasets import Dataset 3 | 4 | from rageval.metrics import AnswerF1Correctness 5 | 6 | 7 | @pytest.fixture(scope='module') 8 | def sample(): 9 | test_case = { 10 | "answers": [ 11 | "Ali Dael has the highest goals in men's world international football with 109 goals. Josef Bican has the \ 12 | highest goals all-time in men's football and Christine Sinclair has the highest goals in women's world \ 13 | international football.", 14 | "A supercentenarian is someone who has reached the age of 110. 
Sarah Knauss, whose age is undisputed, was \ 15 | the oldest person ever from the United States and the second-oldest fully documented person ever. Jeanne \ 16 | Calment was a French supercentenarian and the oldest human whose age is well-documented, with a lifespan \ 17 | of 122 years and 164 days, and was the oldest person in the world as of 1997." 18 | ], 19 | "gt_answers": [ 20 | ["Daei", "Ali Daei"], 21 | ["Jeanne Calment"] 22 | ], 23 | "answers_zh": [ 24 | "魏晋", 25 | "北齐只设于清都郡。", 26 | ], 27 | "gt_answers_zh": [ 28 | ["魏晋", "魏晋时期"], 29 | ["北齐只设于清都郡。", "清都郡"] 30 | ], 31 | "answers_num":[[1,2,3], [4,5,6]], 32 | "gt_answers_num":[[2,3,4,5,6], [1,2,3,4,5]] 33 | } 34 | return test_case 35 | 36 | @pytest.fixture(scope='module') 37 | def testset(sample): 38 | ds = Dataset.from_dict(sample) 39 | return ds 40 | 41 | @pytest.mark.slow 42 | def test_case_on_answer_f1(testset): 43 | metric = AnswerF1Correctness(normalize=True, language='en') 44 | assert metric.name == "answer_f1" 45 | assert metric.mtype == 'AnswerCorrectness' 46 | score, results = metric.compute(testset['answers'], testset['gt_answers']) 47 | assert 0 <= score <= 1 48 | 49 | metric = AnswerF1Correctness(normalize=True, language='zh') 50 | assert metric.name == "answer_f1" 51 | assert metric.mtype == 'AnswerCorrectness' 52 | score_zh, results_zh = metric.compute(testset['answers_zh'], testset['gt_answers_zh']) 53 | assert 0 <= score_zh <= 1 54 | 55 | metric = AnswerF1Correctness(normalize=False) 56 | assert metric.name == "answer_f1" 57 | assert metric.mtype == 'AnswerCorrectness' 58 | score, results = metric.compute(testset['answers_num'], testset['gt_answers_num']) 59 | assert 0 <= score <= 1 60 | -------------------------------------------------------------------------------- /tests/units/test_answer_lcs_ratio.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from datasets import Dataset 3 | 4 | from rageval.metrics import AnswerLCSRatio 5 | 6 | 7 | @pytest.fixture(scope='module') 8 | def sample(): 9 | test_case = { 10 | "answers": [ 11 | "Language models trained on massive code corpora can generalize to tasks without the need " 12 | "for task-specific fine-tuning." 13 | ], 14 | "gt_answers": [ 15 | "Large language models trained on massive code corpora can generalize to new tasks without the need " 16 | "for task-specific fine-tuning." 17 | ] 18 | } 19 | return test_case 20 | 21 | 22 | @pytest.fixture(scope='module') 23 | def testset(sample): 24 | ds = Dataset.from_dict(sample) 25 | return ds 26 | 27 | 28 | @pytest.mark.slow 29 | def test_case_on_answer_lcs_ratio(testset): 30 | metric = AnswerLCSRatio() 31 | assert metric.name == "answer_lcs_ratio" 32 | assert metric.mtype == 'AnswerCorrectness' 33 | score, results = metric.compute(testset['answers'], testset['gt_answers'], 1) 34 | assert score == 16 / 17 35 | -------------------------------------------------------------------------------- /tests/units/test_answer_rouge.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from datasets import Dataset 3 | from typing import List 4 | from rageval.metrics import AnswerRougeCorrectness 5 | 6 | 7 | class CharTokenizer: 8 | """Tokenize text into characters.""" 9 | def tokenize(self, text: str) -> List[str]: 10 | # Tokenize by characters to avoid a dependency on word segmentation methods. 
11 | return [c for c in text] 12 | 13 | 14 | @pytest.fixture(scope='module') 15 | def sample(): 16 | test_case = { 17 | "answers": [ 18 | "###刚刚发声,A股这种情况十分罕见!大聪明逆市抄底330亿,一篇研报引爆全球,市场逻辑生变?", 19 | "The quick brown fox jumps over the lazy dog." 20 | ], 21 | "gt_answers": [ 22 | [ 23 | "刚刚过去的这个月,美股总市值暴跌了将近6万亿美元(折合人民币超过40万亿),这背后的原因可能不仅仅是加息这么简单。最近瑞士信贷知名分析师Zoltan Polzsar撰写了一篇极其重要的文章,详细分析了现有世界秩序的崩坏本质以及美国和西方将要采取的应对策略。在该文中,Zoltan Polzsar直指美国通胀的本质和其长期性。同期,A股市场亦出现了大幅杀跌的情况。" 24 | ], 25 | [ 26 | "The quick brown fox jumps over the lazy dog.", 27 | "The brown fox jumps over the lazy dog." 28 | ] 29 | ] 30 | } 31 | return test_case 32 | 33 | 34 | @pytest.fixture(scope='module') 35 | def testset(sample): 36 | ds = Dataset.from_dict(sample) 37 | return ds 38 | 39 | 40 | def test_case_on_answer_rouge(testset): 41 | 42 | # Test with Chinese tokenizer 43 | chinese_tokenizer = CharTokenizer() 44 | metric = AnswerRougeCorrectness('rouge1', chinese_tokenizer) 45 | score, results = metric.compute(testset['answers'], testset['gt_answers'], 1) 46 | assert metric.mtype == 'AnswerCorrectness' 47 | assert 0 <= score <= 1 48 | 49 | # Test with English tokenizer 50 | metric = AnswerRougeCorrectness('rouge1') 51 | score, results = metric.compute(testset['answers'], testset['gt_answers'], 1) 52 | assert metric.mtype == 'AnswerCorrectness' 53 | assert 0 <= score <= 1 54 | -------------------------------------------------------------------------------- /tests/units/test_answer_ter.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from datasets import Dataset 3 | 4 | from rageval.metrics import AnswerTERCorrectness 5 | 6 | 7 | @pytest.fixture(scope='module') 8 | def sample(): 9 | test_case = { 10 | "answers": [ 11 | "does this sentence match??", 12 | "what about this sentence?", 13 | "What did the TER metric user say to the developer?" 
14 | ], 15 | "gt_answers": [ 16 | ["does this sentence match", "does this sentence match!?!"], 17 | ["wHaT aBoUt ThIs SeNtEnCe?", "wHaT aBoUt ThIs SeNtEnCe?"], 18 | ["Your jokes are...", "...TERrible"] 19 | ] 20 | } 21 | return test_case 22 | 23 | 24 | @pytest.fixture(scope='module') 25 | def testset(sample): 26 | ds = Dataset.from_dict(sample) 27 | return ds 28 | 29 | 30 | @pytest.mark.slow 31 | def test_case_on_answer_ter(testset): 32 | metric = AnswerTERCorrectness() 33 | assert metric.name == "answer_ter" 34 | assert metric.mtype == 'AnswerCorrectness' 35 | score, results = metric.compute(testset['answers'], testset['gt_answers']) 36 | assert score == 110.00000000000001 37 | assert results[0] == 25.0 38 | 39 | -------------------------------------------------------------------------------- /tests/units/test_context_recall.py: -------------------------------------------------------------------------------- 1 | """Test the ContextReCall Metric.""" 2 | 3 | # -*- coding: utf-8 -*- 4 | 5 | import pytest 6 | from datasets import Dataset 7 | from langchain.llms.fake import FakeListLLM 8 | 9 | from rageval.models.openai import OpenAILLM 10 | from rageval.metrics import ContextRecall 11 | 12 | 13 | @pytest.fixture(scope='module') 14 | def sample(): 15 | test_case = { 16 | "questions": ["恐龙是怎么被命名的?"], 17 | "gt_answers": [ 18 | [ 19 | "1841年,英国科学家理查德·欧文在研究几块样子像蜥蜴骨头化石时,认为它们是某种史前动物留下来的,并命名为恐龙,意思是“恐怖的蜥蜴”。" 20 | ] 21 | ], 22 | "contexts": [ 23 | [ 24 | "[12]恐龙是 介于冷血和温血之间的动物2014年6月,有关恐龙究竟是像鸟类和哺乳动物一样的温血动物,还是类似爬行动物、鱼类和两栖动物的冷血动物的问题终于有了答案——恐龙其实是介于冷血和温血之间的动物。 " 25 | "[12]“我们的结果显示恐龙所具有的生长速率和新陈代谢速率,既不是冷血生物体也不是温血生物体所具有的特征。它们既不像哺乳动物或者鸟类,也不像爬行动物或者鱼类,而是介于现代冷血动物和温血动物之间。" 26 | "简言之,它们的生理机能在现代社会并不常见。”美国亚利桑那大学进化生物学家和生态学家布莱恩·恩奎斯特说。墨西哥生物学家表示,正是这种中等程度的新陈代谢使得恐龙可以长得比任何哺乳动物都要大。" 27 | "温血动物需要大量进食,因此它们频繁猎捕和咀嚼植物。“很难想象霸王龙大小的狮子能够吃饱以 存活下来。", 28 | "[12]哺乳动物起源于爬行动物,它们的前身是“似哺乳类的爬行动物”,即兽孔目,早期则是“似爬行类的哺乳动物”,即哺乳型动物。 [12]中生代的爬行动物,大部分在中生代的末期灭绝了;一部分适应了变化的环境" 29 | "被保留下来,即现存的爬行动物(如龟鳖类、蛇类、鳄类等);还有一部分沿着不同的进化方向,进化成了现今的鸟类和哺乳类。 [12]恐龙是 介于冷血和温血之间的动物2014年6月,有关恐龙究竟是像鸟类和哺乳动" 30 | "物一样的温血动物,还是类似爬行动物、鱼类和两栖动物的冷血动物的问题终于有了答案——恐龙其实是介于冷血和温血之间的动物。" 31 | ] 32 | ] 33 | } 34 | return test_case 35 | 36 | 37 | @pytest.fixture(scope='module') 38 | def testset(sample): 39 | ds = Dataset.from_dict(sample) 40 | return ds 41 | 42 | 43 | @pytest.mark.skip 44 | def test_batch_on_context_recall_metric(testset): 45 | model = OpenAILLM('gpt-3.5-turbo-16k', 'OPENAI_API_KEY') 46 | metric = ContextRecall(model) 47 | score, results = metric.compute(testset['questions'], testset['gt_answers'], testset['contexts'], 1) 48 | assert score == 0 or score == 1 49 | 50 | 51 | @pytest.mark.slow 52 | def test_batch_on_context_recall_metric_fakellm1(testset): 53 | model = FakeListLLM( 54 | responses=[ 55 | '[\n {\n "statement_1":"恐龙的命名始于1841年,由英国科学家理查德·欧文命名。",\n "reason": "The answer provides ' 56 | 'the exact year and the scientist who named the dinosaurs.",\n "Attributed": "1"\n },\n {\n' 57 | ' "statement_2":"欧文在研究几块样子像蜥蜴骨头化石时,认为它们是某种史前动物留下来的,并命名为恐龙。",\n "reason": "The answer ' 58 | 'accurately describes the process of how dinosaurs were named.",\n "Attributed": "1"\n }\n]' 59 | ] 60 | ) 61 | metric = ContextRecall(model) 62 | score, results = metric.compute(testset['questions'], testset['gt_answers'], testset['contexts'], 1) 63 | assert metric.mtype == 'ContextRelevancy' 64 | assert 0 <= score <= 1 65 | 66 | 67 | @pytest.mark.slow 68 | def test_batch_on_context_recall_metric_fakellm2(testset): 69 | model = FakeListLLM(responses=['wrong response format']) 70 | metric = 
ContextRecall(model) 71 | score, results = metric.compute(testset['questions'], testset['gt_answers'], testset['contexts'], 1) 72 | assert metric.mtype == 'ContextRelevancy' 73 | assert 0 <= score <= 1 74 | -------------------------------------------------------------------------------- /tests/units/test_context_reject_rate.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from datasets import Dataset 3 | from langchain.llms.fake import FakeListLLM 4 | 5 | from rageval.metrics import ContextRejectRate 6 | 7 | 8 | @pytest.fixture(scope='module') 9 | def sample(): 10 | test_case = { 11 | "questions": [ 12 | "Why did Bushnell set himself on fire?", 13 | "Did Bushnell have a wife?" 14 | ], 15 | "contexts": [ 16 | [ 17 | [ 18 | "An active-duty member of the U.S. Air Force has died after he set himself ablaze outside the " 19 | "Israeli Embassy in Washington, D.C., while declaring that he “will no longer be complicit in " 20 | "genocide.”" 21 | ], 22 | [ 23 | "The 25-year-old airman, Aaron Bushnell, of San Antonio, Texas, died from his injuries, the " 24 | "Metropolitan Police Department said Monday." 25 | ], 26 | [ 27 | "Bushnell had walked up to the embassy shortly before 1 p.m. Sunday and began livestreaming on " 28 | "the video streaming platform Twitch, a person familiar with the matter told The Associated " 29 | "Press. Law enforcement officials believe he set his phone down and then doused himself in " 30 | "accelerant and ignited the flames. At one point, he said he “will no longer be complicit in " 31 | "genocide,” the person said. The video was later removed from the platform, but law enforcement " 32 | "officials have obtained and reviewed a copy." 33 | ] 34 | ], 35 | [ 36 | [ 37 | "An active-duty member of the U.S. Air Force has died after he set himself ablaze outside the " 38 | "Israeli Embassy in Washington, D.C., while declaring that he “will no longer be complicit in " 39 | "genocide.”" 40 | ], 41 | [ 42 | "The 25-year-old airman, Aaron Bushnell, of San Antonio, Texas, died from his injuries, the " 43 | "Metropolitan Police Department said Monday." 44 | ], 45 | [ 46 | "Bushnell had walked up to the embassy shortly before 1 p.m. Sunday and began livestreaming on " 47 | "the video streaming platform Twitch, a person familiar with the matter told The Associated " 48 | "Press. Law enforcement officials believe he set his phone down and then doused himself in " 49 | "accelerant and ignited the flames. At one point, he said he “will no longer be complicit in " 50 | "genocide,” the person said. The video was later removed from the platform, but law enforcement " 51 | "officials have obtained and reviewed a copy." 
52 | ] 53 | ] 54 | ] 55 | } 56 | return test_case 57 | 58 | 59 | @pytest.fixture(scope='module') 60 | def testset(sample): 61 | ds = Dataset.from_dict(sample) 62 | return ds 63 | 64 | 65 | @pytest.mark.slow 66 | def test_case_on_context_reject_rate(testset): 67 | model = FakeListLLM( 68 | responses=[ 69 | "Answer: wrong response format", 70 | "Answer: sorry, cannot answer the question" 71 | ] 72 | ) 73 | metric = ContextRejectRate(model) 74 | assert metric.name == "context_reject_rate" 75 | assert metric.mtype == 'AnswerGroundedness' 76 | score, results = metric.compute(testset['questions'], testset['contexts'], 1) 77 | assert score == 0.5 78 | -------------------------------------------------------------------------------- /tests/units/test_nli.py: -------------------------------------------------------------------------------- 1 | # import sys 2 | # sys.path.insert(0, '../src') 3 | 4 | import pytest 5 | 6 | from rageval.models import NLIModel 7 | 8 | 9 | @pytest.fixture(scope='module') 10 | def test_case(): 11 | sample = { 12 | "claim": "In 1980, the oldest justice on the United States Supreme Court was Justice William O. Douglas.", 13 | "evidence": "August 3, 1994 \u2013 June 30, 2022 (27 years, 10 months, 27 days) photo source: Wikimedia " 14 | "Commons After the passing of Ruth Bader Ginsberg in 2020, Stephen Breyer was the oldest " 15 | "sitting member of the Supreme Court until his retirement in 2022. Stepping down at the age " 16 | "of 83, Breyer is now one of the oldest Supreme Court justices ever. Breyer was nominated by " 17 | "Bill Clinton and served on the Court for more than 27 years. During his tenure, Breyer fell " 18 | "in line with the liberal wing of the court. Before he was appointed to the Supreme Court, " 19 | "Breyer served as a judge on the U.S. 
Court of Appeals for the First Circuit; he was the " 20 | "Chief Judge for the last four years of his appointment.", 21 | "stance": "irrelevant" 22 | } 23 | return sample 24 | 25 | 26 | @pytest.mark.slow 27 | def test_nli(test_case): 28 | 29 | # model = NLIModel('sentiment-analysis', 'roberta-large-mnli') 30 | model = NLIModel( 31 | 'text-classification', 32 | 'hf-internal-testing/tiny-random-RobertaPreLayerNormForSequenceClassification' 33 | ) 34 | 35 | # test request 36 | result = model.infer_prob(test_case['evidence'], test_case['claim']) 37 | # print(result) 38 | assert result[0]['label'] in ['LABEL_0', 'LABEL_1'] 39 | assert 'score' in result[0] 40 | 41 | # case = test_case() 42 | # test_nli(case) 43 | -------------------------------------------------------------------------------- /tests/units/test_openai_api.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import os 4 | import pytest 5 | from unittest.mock import patch, MagicMock 6 | import openai 7 | import httpx 8 | 9 | from rageval.models import OpenAILLM 10 | 11 | from langchain.schema import Generation, LLMResult 12 | 13 | 14 | @pytest.fixture(scope='module') 15 | def test_case(): 16 | questions = ["截止2001年岛上的人口有多少?", "墨西哥土拨鼠栖息的地区海拔约是多少米?"] 17 | ground_truths = [ 18 | [ 19 | "总人口9495(2001年)。", 20 | "墨西哥土拨鼠栖息在海拔1600-2200米的平原上。" 21 | ] 22 | ] 23 | contexts = [ 24 | [ 25 | "米尼科伊岛(Minicoy)位于印度拉克沙群岛中央直辖区最南端,是Lakshadweep县的一个城镇。它与拉克沙群岛隔九度海峡相望,与马尔代夫伊哈万迪富卢环礁隔八度海峡相望。总人口9495(2001年)。米尼科伊岛位于盐水胡东南部的一个有人居住的岛屿,全岛几乎被椰子树覆盖,唯一的地标是一座灯塔。Viringili位于米尼科伊岛西南侧,是一个长度不到两百米的小岛,曾被用作麻风病患者的驱逐地。该地2001年总人口9495人,其中男性4616人,女性4879人;0—6岁人口1129人,其中男571人,女558人;识字率81.95%,其中男性为83.51%,女性为80.47%。", 26 | "墨西哥土拨鼠(\"Cynomys mexicanus\"),又名墨西哥草原松鼠或墨西哥草原犬鼠,是原住于墨西哥的一种啮齿目。牠们是日间活动的。由于牠们被看为害虫,估其数量下降至濒危水平。墨西哥土拨鼠栖息在海拔1600-2200米的平原上。牠们分布在墨西哥的圣路易斯波托西州北部及科阿韦拉州。牠们主要吃草,并透过食物来吸收水份。牠们有时也会吃昆虫。牠们的天敌有郊狼、短尾猫、鹰、獾及鼬。墨西哥土拨鼠是会冬眠的,其繁殖季节也较短,一般只在1月至4月。妊娠期为1个月,雌鼠每年只会产一胎,一胎平均有四子。幼鼠出生时眼睛闭合,会先以尾巴作为辅助,直至出生后40日才能看见。于5月至6月会断奶,到了1岁就会离开巢穴。冬天前幼鼠就会离开母鼠。幼鼠之间会互咬、嘶叫及扭住来玩耍。牠们1岁后就达至性成熟,寿命约3-5年。成年重约1公斤及长14-17吋,雄性较雌性大只。牠们呈黄色,耳朵较深色,腹部较浅色。墨西哥土拨鼠的语言最为复杂,能奔跑达每小时55公里。所以当受到威胁时,牠们会大叫作为警报,并且高速逃走。墨西哥土拨鼠的巢穴是挖掘出来的。巢穴的入口像漏斗,通道长达100呎,两侧有空间储存食物及休息。巢穴可以多达几百只墨西哥土拨鼠,但一般少于50只,群族有一只雄性的领袖。牠们有时会与斑点黄鼠及穴鸮分享他们的洞穴。于1956年,墨西哥土拨鼠曾在科阿韦拉州、新莱昂州及圣路易斯波托西州出没。到了1980年代,牠们从新莱昂州消失,其分布地少于800平方米。由于牠们被认为是害虫,故经常被毒杀,到了1994年到达濒危的状况。" 27 | ] 28 | ] 29 | return {'questions': questions, 30 | 'ground_truths': ground_truths, 31 | 'contexts': contexts} 32 | 33 | @pytest.fixture 34 | def openai_llm(): 35 | return OpenAILLM(model="gpt-3.5-turbo", _api_key_env_var="OPENAI_API_KEY") 36 | 37 | def test_init(openai_llm): 38 | assert openai_llm.model == "gpt-3.5-turbo" 39 | assert openai_llm.base_url == "https://api.openai.com/v1" 40 | assert openai_llm.num_retries == 3 41 | assert openai_llm.timeout == 60 42 | assert openai_llm.api_key == os.getenv("OPENAI_API_KEY", "NO_KEY") 43 | 44 | def test_build_request(openai_llm): 45 | request = openai_llm.build_request() 46 | assert request["model"] == "gpt-3.5-turbo" 47 | assert request["max_tokens"] is None 48 | assert request["n"] is None 49 | assert request["temperature"] is None 50 | assert request["top_p"] is None 51 | assert request["logprobs"] is None 52 | 53 | def test_is_chat_model_engine(openai_llm): 54 | assert openai_llm._is_chat_model_engine is True 55 | 56 | @patch("rageval.models.openai.openai.OpenAI") 57 | def test_llm(mock_openai, openai_llm): 58 | llm = openai_llm.llm 59 | mock_openai.assert_called_once_with( 60 | 
api_key=openai_llm.api_key, 61 | base_url=openai_llm.base_url, 62 | max_retries=openai_llm.num_retries, 63 | timeout=openai_llm.timeout 64 | ) 65 | 66 | @patch("rageval.models.openai.openai.OpenAI") 67 | def test_get_chat_model_response(mock_openai, openai_llm): 68 | mock_response = MagicMock() 69 | mock_openai().chat.completions.create.return_value = mock_response 70 | prompt = [{"role": "user", "content": "Hello"}] 71 | response = openai_llm._get_chat_model_response(prompt) 72 | assert response == mock_response 73 | mock_openai().chat.completions.create.assert_called_once() 74 | 75 | @patch("rageval.models.openai.openai.OpenAI") 76 | def test_get_instruct_model_response(mock_openai, openai_llm): 77 | mock_response = MagicMock() 78 | mock_openai().completions.create.return_value = mock_response 79 | prompt = "Hello" 80 | response = openai_llm._get_instruct_model_response(prompt) 81 | assert response == mock_response 82 | mock_openai().completions.create.assert_called_once() 83 | 84 | @patch("rageval.models.openai.OpenAILLM._get_chat_model_response") 85 | @patch("rageval.models.openai.OpenAILLM._get_instruct_model_response") 86 | def test_generate(mock_instruct_response, mock_chat_response, openai_llm): 87 | mock_chat_response.return_value = {"choices": [{"message": {"content": "Hi"}}]} 88 | mock_instruct_response.return_value = {"choices": [{"text": "Hi"}]} 89 | 90 | prompt_chat = [{"role": "user", "content": "Hello"}] 91 | result_chat = openai_llm.generate(prompt_chat) 92 | assert isinstance(result_chat, LLMResult) 93 | assert result_chat.generations[0][0].text == "Hi" 94 | 95 | prompt_instruct = "Hello" 96 | openai_llm.model = "gpt-3.5-turbo-instruct" 97 | result_instruct = openai_llm.generate(prompt_instruct) 98 | assert isinstance(result_instruct, LLMResult) 99 | assert result_instruct.generations[0][0].text == "Hi" 100 | 101 | def test_create_llm_result(openai_llm): 102 | response = { 103 | "usage": {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30}, 104 | "choices": [{"message": {"content": "Hi"}, "finish_reason": "stop", "logprobs": None}] 105 | } 106 | result = openai_llm.create_llm_result(response) 107 | assert isinstance(result, LLMResult) 108 | assert result.llm_output["token_usage"]["prompt_tokens"] == 10 109 | assert result.llm_output["token_usage"]["completion_tokens"] == 20 110 | assert result.llm_output["token_usage"]["total_tokens"] == 30 111 | assert result.generations[0][0].text == "Hi" 112 | 113 | @patch.object(OpenAILLM, 'generate') 114 | def test_batch_generate(mock_generate, openai_llm): 115 | # Mock the generate method to return a simple LLMResult 116 | mock_generate.return_value = LLMResult(generations=[[Generation(text="Hi")]]) 117 | 118 | # Define prompts for testing 119 | prompts = [ 120 | [{"role": "user", "content": "Hello"}], 121 | [{"role": "user", "content": "How are you?"}], 122 | ] 123 | 124 | # Call batch_generate 125 | results = openai_llm.batch_generate(prompts, max_workers=2) 126 | 127 | # Verify generate was called with each prompt 128 | assert mock_generate.call_count == len(prompts) 129 | mock_generate.assert_any_call(prompts[0]) 130 | mock_generate.assert_any_call(prompts[1]) 131 | 132 | # Check results 133 | assert len(results) == len(prompts) 134 | for result in results: 135 | assert isinstance(result, LLMResult) 136 | assert result.generations[0][0].text == "Hi" 137 | 138 | @patch.object(OpenAILLM, 'generate') 139 | def test_batch_generate_order(mock_generate, openai_llm): 140 | # Mock the generate method to return different 
results based on input 141 | def side_effect(prompt): 142 | if prompt == [{"role": "user", "content": "Hello"}]: 143 | return LLMResult(generations=[[Generation(text="Hi")]]) 144 | elif prompt == [{"role": "user", "content": "How are you?"}]: 145 | return LLMResult(generations=[[Generation(text="I am fine")]]) 146 | 147 | mock_generate.side_effect = side_effect 148 | 149 | # Define prompts for testing 150 | prompts = [ 151 | [{"role": "user", "content": "Hello"}], 152 | [{"role": "user", "content": "How are you?"}], 153 | ] 154 | 155 | # Call batch_generate 156 | results = openai_llm.batch_generate(prompts, max_workers=2) 157 | 158 | # Check results are in the correct order 159 | assert results[0].generations[0][0].text == "Hi" 160 | assert results[1].generations[0][0].text == "I am fine" 161 | -------------------------------------------------------------------------------- /tests/units/test_text_length.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from datasets import Dataset 3 | 4 | from rageval.metrics import TextLength 5 | import rageval as rl 6 | 7 | 8 | @pytest.fixture(scope='module') 9 | def sample(): 10 | test_case = { 11 | #"questions": [ 12 | # "习近平主席在何时何地会见了埃塞俄比亚总理海尔马里亚姆?", 13 | # "埃塞俄比亚希望与中国在哪些领域加强合作?" 14 | #], 15 | "answers": [ 16 | "习近平主席在2017年5月12日于北京人民大会堂会见了埃塞俄比亚总理海尔马里亚姆。", 17 | "埃塞俄比亚希望与中国在以下领域加强合作:\n1. **共建“一带一路”框架下合作**:埃塞俄比亚表示希望能够积极参与“一带一路”倡议,深化与中国在基础设施建设、产能合作、互联互通等领域的合作。\n2. **提高工业化水平和出口创汇能力**:埃塞俄比亚期待中国在推动其工业化进程以及提升出口创汇能力方面提供帮助和合作。\n3. **安全、有序、有效推进经贸合作**:希望与中方在贸易和投资合作领域取得进展,实现稳定、有序和高效的合作。" 18 | ] 19 | } 20 | return test_case 21 | 22 | 23 | @pytest.fixture(scope='module') 24 | def testset(sample): 25 | ds = Dataset.from_dict(sample) 26 | return ds 27 | 28 | 29 | @pytest.mark.slow 30 | def test_case_on_text_length(testset): 31 | metric = TextLength(tokenize_model="Qwen/Qwen2-0.5B-Instruct") 32 | assert metric.name == "text_length" 33 | score, results = metric.compute(testset["answers"]) 34 | print(score, results) 35 | assert score == 75.0 36 | -------------------------------------------------------------------------------- /tutorials/README.md: -------------------------------------------------------------------------------- 1 | # Tutorials 2 | 3 | Welcome to the tutorials directory for the RAGEval project! This directory contains step-by-step guides and resources to help you get started with RAG evaluation methods, best practices, and practical implementations. 4 | 5 | 6 | ## Table of Contents 7 | 8 | 1. [Overview](#overview) 9 | 2. [Getting Started](#getting-started) 10 | 3. [Tutorials](#tutorials) 11 | - [Basic RAG Evaluation](#basic-rag-evaluation) 12 | - [Advanced Techniques](#advanced-techniques) 13 | - [Custom Dataset Evaluation](#custom-dataset-evaluation) 14 | 4. [Contributing](#contributing) 15 | 5. [License](#license) 16 | 17 | ## Overview 18 | 19 | The RAG evaluation project aims to provide a comprehensive framework for evaluating the performance of retrieval-augmented models. This directory is dedicated to tutorials that will guide you through various aspects of RAG evaluation, helping you understand both the theory and practical applications. 20 | 21 | ## Getting Started 22 | 23 | To begin, ensure you have the necessary dependencies installed. Follow the [installation instructions](../INSTALL.md) to set up your environment. 24 | 25 | ## Tutorials 26 | 27 | ### Basic RAG Evaluation 28 | 29 | - **Description**: Learn how to perform a simple evaluation of RAG models using standard metrics. 
30 | - **Link**: [Basic RAG Evaluation Tutorial](basic_rag_evaluation.md) 31 | 32 | ### Advanced Techniques 33 | 34 | - **Description**: Explore advanced evaluation techniques that enhance your understanding of model performance. 35 | - **Link**: [Advanced Techniques Tutorial](advanced_techniques.md) 36 | 37 | ### Custom Dataset Evaluation 38 | 39 | - **Description**: A guide on how to evaluate RAG models on your custom datasets, including preprocessing steps and metric selection. 40 | - **Link**: [Custom Dataset Evaluation Tutorial](custom_dataset_evaluation.md) 41 | 42 | ## Contributing 43 | 44 | We welcome contributions! If you have ideas for additional tutorials or improvements to existing ones, please submit a pull request or open an issue. For more information, check our [Contributing Guidelines](../CONTRIBUTING.md). 45 | 46 | ## License 47 | 48 | This project is licensed under the Apache License. See the [LICENSE](../LICENSE) file for details. 49 | 50 | --- 51 | 52 | If you have any questions or need further assistance, feel free to open an issue in the main repository. Happy evaluating! -------------------------------------------------------------------------------- /tutorials/tutorial 1/df_result_excel.xlsx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gomate-community/rageval/01e2580efe714b3130602e02a7e7b017d16d79a2/tutorials/tutorial 1/df_result_excel.xlsx -------------------------------------------------------------------------------- /tutorials/tutorial 1/requirements.txt: -------------------------------------------------------------------------------- 1 | openpyxl 2 | pandas 3 | transformers 4 | matplotlib 5 | datasets 6 | ragchecker 7 | vllm --------------------------------------------------------------------------------