├── .flake8
├── .github
│   └── workflows
│       └── makefile.yml
├── .gitignore
├── CONTRIBUTING.md
├── LICENSE
├── Makefile
├── README.md
├── benchmarks
│   ├── ALCE
│   │   ├── ASQA
│   │   │   ├── README.md
│   │   │   ├── asqa_benchmark.py
│   │   │   ├── generate.py
│   │   │   ├── prompts
│   │   │   │   └── asqa_prompt.json
│   │   │   ├── results
│   │   │   │   ├── asqa_dpr_Llama_2_7b_chat_hf_vanilla_shot2_ndoc5_20240427.json
│   │   │   │   ├── asqa_gtr_Llama_2_7b_chat_hf_snippet_shot2_ndoc10_20240430.json
│   │   │   │   ├── asqa_gtr_Llama_2_7b_chat_hf_snippet_shot2_ndoc5_20240430.json
│   │   │   │   ├── asqa_gtr_Llama_2_7b_chat_hf_summary_shot2_ndoc10_20240430.json
│   │   │   │   ├── asqa_gtr_Llama_2_7b_chat_hf_summary_shot2_ndoc5_20240430.json
│   │   │   │   ├── asqa_gtr_Llama_2_7b_chat_hf_vanilla_shot2_ndoc5_20240415.json
│   │   │   │   └── asqa_oracle_Llama_2_7b_chat_hf_vanilla_shot2_ndoc5_20240428.json
│   │   │   └── run.sh
│   │   └── ELI5
│   │       ├── README.md
│   │       ├── eli5_benchmark.py
│   │       ├── generate.py
│   │       ├── prompts
│   │       │   └── eli5_prompt.json
│   │       ├── results
│   │       │   ├── eli5-bm25-Llama-2-7b-chat-hf-vanilla-shot2-ndoc5-20240310.json
│   │       │   └── eli5_oracle_Llama_2_7b_chat_hf_vanilla_shot2_ndoc5_20240430.json
│   │       └── run.sh
│   ├── ASQA
│   │   ├── README.md
│   │   ├── asqa_benchmark.py
│   │   ├── generate.py
│   │   ├── prompts.py
│   │   ├── run.sh
│   │   └── run_generate.sh
│   ├── HOTPOTQA
│   │   ├── README.md
│   │   ├── generate.py
│   │   ├── hotpotqa_benchmark.py
│   │   ├── prompts.py
│   │   ├── run.sh
│   │   └── run_generate.sh
│   ├── WebGLM
│   │   ├── generate.py
│   │   ├── results
│   │   │   └── webglm_Llama_2_7b_chat_hf_20240502.json
│   │   ├── run.sh
│   │   └── webglm_benchmark.py
│   ├── __init__.py
│   ├── auto
│   │   ├── README.md
│   │   ├── auto_benchmark.py
│   │   ├── corpus
│   │   │   ├── corpus.json
│   │   │   └── few_shot_cases.json
│   │   ├── output
│   │   │   └── dataset.json
│   │   ├── prompt.py
│   │   └── run.sh
│   ├── base.py
│   └── utils.py
├── pyproject.toml
├── pytest.ini
├── rageval
│   ├── __init__.py
│   ├── evaluations.py
│   ├── exceptions.py
│   ├── metrics
│   │   ├── __init__.py
│   │   ├── answer_correctness
│   │   │   ├── _answer_accuracy.py
│   │   │   ├── _answer_bert_score.py
│   │   │   ├── _answer_bleu.py
│   │   │   ├── _answer_chrf.py
│   │   │   ├── _answer_claim_recall.py
│   │   │   ├── _answer_disambig_f1.py
│   │   │   ├── _answer_edit_distance.py
│   │   │   ├── _answer_exact_match.py
│   │   │   ├── _answer_f1.py
│   │   │   ├── _answer_lcs_ratio.py
│   │   │   ├── _answer_relevancy.py
│   │   │   ├── _answer_rouge_correctness.py
│   │   │   └── _answer_ter.py
│   │   ├── answer_groundedness
│   │   │   ├── _answer_citation_precision.py
│   │   │   ├── _answer_citation_recall.py
│   │   │   ├── _claim_faithfulness.py
│   │   │   └── _context_reject_rate.py
│   │   ├── answer_informativeness
│   │   │   ├── _answer_distinct12.py
│   │   │   ├── _claim_num.py
│   │   │   ├── _pairwise_accuracy.py
│   │   │   ├── _repetitiveness.py
│   │   │   └── _text_length.py
│   │   ├── base.py
│   │   ├── context_adequacy
│   │   │   └── _context_recall.py
│   │   └── context_relevance
│   │       ├── _accuracy.py
│   │       ├── _hit_rate.py
│   │       ├── _mrr.py
│   │       └── _ndcg.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── nli.py
│   │   └── openai.py
│   ├── tasks
│   │   ├── __init__.py
│   │   ├── _generate.py
│   │   └── base.py
│   ├── utils
│   │   ├── RAGAS_prompt.py
│   │   ├── __init__.py
│   │   ├── check_utils.py
│   │   ├── prompt.py
│   │   └── utility.py
│   ├── validation.py
│   └── version.py
├── requirements.txt
├── setup.py
├── tests
│   ├── demo.py
│   ├── test_evaluation.py
│   └── units
│       ├── test_answer_accuracy.py
│       ├── test_answer_bert_score.py
│       ├── test_answer_bleu.py
│       ├── test_answer_chrf.py
│       ├── test_answer_citation_precision.py
│       ├── test_answer_citation_recall.py
│       ├── test_answer_claim_recall.py
│       ├── test_answer_disambig_f1.py
│       ├── test_answer_distinct.py
│       ├── test_answer_edit_distance.py
│       ├── test_answer_exect_match.py
│       ├── test_answer_f1.py
│       ├── test_answer_lcs_ratio.py
│       ├── test_answer_rouge.py
│       ├── test_answer_ter.py
│       ├── test_context_recall.py
│       ├── test_context_reject_rate.py
│       ├── test_nli.py
│       ├── test_openai_api.py
│       └── test_text_length.py
└── tutorials
    ├── README.md
    └── tutorial 1
        ├── df_result_excel.xlsx
        ├── main.ipynb
        └── requirements.txt
/.flake8:
--------------------------------------------------------------------------------
1 | [flake8]
2 | ignore =
3 | # D401 First line should be in imperative mood
4 | D401,
5 |
6 | # D202 No blank lines allowed after function docstring
7 | D202,
8 |
9 | # For doctests:
10 | # D207 Docstring is under-indented
11 | D207,
12 | # D301 Use r""" if any backslashes in a docstring
13 | D301,
14 | # F401 'blah blah' imported but unused
15 | F401,
16 |
17 | # D101 Missing docstring in public class
18 | D101,
19 |
20 | # D100 Missing docstring in public module
21 | D100,
22 |
23 | # E501 line too long (88 > 79 characters)
24 | E501,
25 |
26 | # D400 First line should end with a period
27 | D400,
28 |
29 | # D103 Missing docstring in public function
30 | D103,
31 |
--------------------------------------------------------------------------------
/.github/workflows/makefile.yml:
--------------------------------------------------------------------------------
1 | name: Makefile CI
2 |
3 | on:
4 | push:
5 | branches: [ "main" ]
6 | pull_request:
7 | branches: [ "main" ]
8 |
9 | jobs:
10 | build:
11 |
12 | runs-on: ubuntu-latest
13 |
14 | steps:
15 | - uses: actions/checkout@v3
16 |
17 | - name: Setup Python version
18 | uses: actions/setup-python@v1
19 | with:
20 | python-version: 3.10.15
21 |
22 | - name: Install requirements
23 | run: make init
24 |
25 | - name: Run check
26 | run: make test
27 |
28 | - name: Upload coverage reports to Codecov
29 | uses: codecov/codecov-action@v3
30 | env:
31 | CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
32 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | *.log
3 | *.swp
4 | *.bak
5 | *.weights
6 | *.DS_Store
7 | .vscode
8 | .coverage*
9 | RagEval.egg-info/*
10 | log/*
11 | build/*
12 | dist/*
13 | .idea/
14 | .pytest_cache/
15 | .cache
16 | htmlcov/*
17 | .rageval/
18 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | Contributing to RAGEval
2 | ----------
3 |
4 | > Note: RAGEval is developed under Python 3.8.18
5 |
6 | Welcome! RAGEval is a community project that aims to evaluate different modules of a RAG system, including query rewriting, document ranking, information compression, evidence verification, answer generation, and result validation. Your experience and what you can contribute are important to the project's success.
7 |
8 | Discussion
9 | ----------
10 |
11 | If you've run into behavior in RAGEval you don't understand, or you're having trouble working out a good way to apply it to your code, or you've found a bug or would like a feature it doesn't have, we want to hear from you!
12 |
13 | Our main forum for discussion is the project's [GitHub issue tracker](https://github.com/gomate-community/rageval/issues). This is the right place to start a discussion of any of the above or most any other topic concerning the project.
14 |
15 | First Time Contributors
16 | -----------------------
17 |
18 | RAGEval appreciates your contribution! If you are interested in helping improve RAGEval, there are several ways to get started:
19 |
20 | * Work on [new metrics and datasets](https://github.com/gomate-community/rageval/tree/main/rageval) (a short usage sketch of the existing metric interface follows this list).
21 | * Try to answer questions on [the issue tracker](https://github.com/gomate-community/rageval/issues).
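
If you want to get a feel for the metric interface before adding a new one, here is a minimal sketch of scoring a prediction with an existing metric. The calling convention mirrors `benchmarks/ASQA/asqa_benchmark.py`; the toy data is purely illustrative.

```python
from rageval.metrics import AnswerEMCorrectness

# Toy example (illustrative): one prediction and its reference short answers.
answers = ["Paris is the capital of France."]
gt_answers = [[["Paris"]]]  # per example: a list of short-answer lists

metric = AnswerEMCorrectness(ignore_case=True)
avg_score, scores = metric.compute(answers, gt_answers)  # overall score, per-sample scores
print(avg_score, scores)
```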
22 |
23 | Submitting Changes
24 | ------------------
25 |
26 | Even more excellent than a good bug report is a fix for a bug, or the implementation of much-needed new metrics or benchmarks.
27 |
28 | (*) We'd love to have your contributions.
29 |
30 | (*) If your new feature will be a lot of work, we recommend talking to us early -- see below.
31 |
32 | We use the usual GitHub pull-request flow, which may be familiar to you if you've contributed to other projects on GitHub -- see below.
33 |
34 | Anyone interested in RAGEval may review your code. One of the RAGEval core developers will merge your pull request when they think it's ready.
35 | For every pull request, we aim to promptly either merge it or say why it's not yet ready; if you go a few days without a reply, please feel
36 | free to ping the thread by adding a new comment.
37 |
38 | For a list of RAGEval core developers, see [Readme](https://github.com/gomate-community/rageval/blob/main/README.md).
39 |
40 | Contributing Flow
41 | ------------------
42 |
43 | 1. Fork the latest version of [RAGEval](https://github.com/gomate-community/rageval) into your repo.
44 | 2. Create an issue under [gomate-Community/rageval](https://github.com/gomate-community/rageval/issues) and write a description of the bug/enhancement.
45 | 3. Clone your forked RAGEval onto your machine, and add your changes together with associated tests.
46 | 4. Run `make test` in a terminal, and ensure all unit tests & integration tests pass on your computer.
47 | 5. Push to your forked repo, then send a pull request to the official repo. In the pull request, link to the issue you created using `#[issue_id]` and describe what has been changed.
48 | 6. Wait for [continuous integration](https://github.com/gomate-community/rageval/blob/main/.github/workflows/makefile.yml) to pass.
49 | 7. Wait for [Codecov](https://app.codecov.io/gh/gomate-community/rageval) to generate the coverage report.
50 | 8. We'll assign reviewers to review your code.
51 |
52 |
53 | Your PR will be merged if:
54 | - It functionally benefits the project.
55 | - It passes continuous integration (all unit tests, integration tests and the [PEP8](https://www.python.org/dev/peps/pep-0008/) check).
56 | - Test coverage does not decrease; we use [pytest](https://docs.pytest.org/en/latest/).
57 | - It has proper docstrings; see the codebase for examples.
58 | - It has type hints; see [typing](https://docs.python.org/3/library/typing.html).
59 | - All reviewers approve your changes.
60 |
61 |
62 | **Thanks and let's improve RAGEval together!**
63 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | # Usages:
2 | #
3 | # to install rageval dependencies:
4 | # $ make init
5 | #
6 | # to run all rageval tests, recommended for big PRs and new versions:
7 | # $ make test
8 | #
9 | # there are three kinds of tests:
10 | #
11 | # 1. "quick" tests
12 | # - run in seconds
13 | # - include all unit tests without marks and all doctests
14 | # - for rapid prototyping
15 | #    - CI runs these for all PRs
16 | #
17 | # 2. "slow" tests
18 | # - run in minutes
19 | # - include all unit tests marked "slow"
20 | #    - CI runs these for all PRs
21 | #
22 | # 3. "cron" tests
23 | # - run in minutes
24 | #    - involve nondeterministic behaviors (e.g. network connection)
25 | #    - include all unit tests marked "cron"
26 | #    - CI runs these on a daily basis
27 | #
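#    a sketch of how a test opts into the "slow" or "cron" suite (standard
#    pytest markers matching the -m expressions used below; the function
#    names are illustrative):
#
#        @pytest.mark.slow
#        def test_expensive_metric():
#            ...
#
#        @pytest.mark.cron
#        def test_external_api_roundtrip():
#            ...
#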
28 | # to run quick tests, excluding time consuming tests and crons:
29 | # $ make quick
30 | #
31 | # to run slow tests, excluding normal tests and crons:
32 | # $ make slow
33 | #
34 | # to run crons:
35 | # $ make cron
36 | #
37 | # to run all tests:
38 | # $ make test
39 | #
40 | # to run CI push/PR tests:
41 | # $ make push
42 | #
43 | # to run docstring style check:
44 | # $ make flake
45 |
46 | init:
47 | pip3 install -r requirements.txt
48 |
49 | TEST_ARGS = -v --full-trace -l --doctest-modules --doctest-continue-on-failure --cov rageval/ --cov-report term-missing --cov-report html --cov-config .coveragerc rageval/ tests/ -W ignore::DeprecationWarning
50 | FLAKE_ARGS = ./rageval --exclude=__init__.py
51 |
52 | test:
53 | python3 -m pytest $(TEST_ARGS)
54 | python3 -m flake8 $(FLAKE_ARGS)
55 |
56 | push:
57 | python3 -m pytest -m 'not cron' $(TEST_ARGS) ${ARGS}
58 | python3 -m flake8 $(FLAKE_ARGS)
59 |
60 | quick:
61 | python3 -m pytest -m 'not slow and not cron' $(TEST_ARGS) ${ARGS}
62 |
63 | slow:
64 | python3 -m pytest -m 'slow and not cron' $(TEST_ARGS) ${ARGS}
65 |
66 | cron:
67 | python3 -m pytest -m 'cron' $(TEST_ARGS) ${ARGS}
68 |
69 | flake:
70 | python3 -m flake8 $(FLAKE_ARGS) ${ARGS}
71 |
--------------------------------------------------------------------------------
/benchmarks/ALCE/ASQA/asqa_benchmark.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import time
3 | from typing import Dict, Tuple, Any
4 |
5 | from datasets import Dataset
6 |
7 | import rageval as rl
8 | from benchmarks import BaseBenchmark
9 | from rageval.metrics import AnswerEMCorrectness, AnswerCitationRecall, AnswerCitationPrecision
10 |
11 |
12 | class ASQABenchmark(BaseBenchmark):
13 |
14 | name = "asqa_benchmark"
15 |
16 | def __init__(self, cache_path) -> None:
17 | super().__init__()
18 | nli_model = rl.models.NLIModel(
19 | "text2text-generation",
20 | cache_path + "/models/t5_xxl_true_nli_mixture",
21 | )
22 | self.metrics = [
23 | AnswerEMCorrectness(),
24 | AnswerCitationRecall(nli_model=nli_model),
25 | AnswerCitationPrecision(nli_model=nli_model)
26 | ]
27 |
28 | def _evaluate(self) -> Tuple[Dict[Any, Any], Dataset]:
29 | self.dataset = self.dataset.rename_column("output", "answers")
30 | self.dataset = self.dataset.map(lambda data: {
31 | "gt_answers": [
32 | pair["short_answers"]
33 | for pair in data["qa_pairs"]
34 | ]
35 | })
36 |
37 | results = {}
38 | for metric in self.metrics:
39 | results[metric.name], self.dataset = metric.compute(self.dataset, self.batch_size)
40 | return results, self.dataset
41 |
42 |
43 | if __name__ == "__main__":
44 | parser = argparse.ArgumentParser()
45 | parser.add_argument("--cache_path", type=str, default=None)
46 | parser.add_argument("--remote_split", type=str, default=None)
47 | parser.add_argument("--local_file", type=str, default=None)
48 | args = parser.parse_args()
49 | date = time.strftime("%Y%m%d", time.localtime())
50 |
51 | benchmark = ASQABenchmark(cache_path=args.cache_path)
52 | if args.local_file:
53 | results = benchmark.evaluate(
54 | path="json",
55 | data_files={
56 | "test": args.cache_path+"/results/"+args.local_file
57 | },
58 | split="test"
59 | )
60 | benchmark.save_results(f"benchmarks/ALCE/ASQA/results/{args.local_file[:-5]}_{date}.json")
61 | else:
62 | results = benchmark.evaluate(path="golaxy/rag-bench", name="alce_asqa_gtr", split=args.remote_split)
63 | benchmark.save_results(f"benchmarks/ALCE/ASQA/results/{args.remote_split}_{date}.json")
64 |
--------------------------------------------------------------------------------
/benchmarks/ALCE/ASQA/results/asqa_dpr_Llama_2_7b_chat_hf_vanilla_shot2_ndoc5_20240427.json:
--------------------------------------------------------------------------------
1 | {
2 | "answer_exact_match": 0.2921589310829817,
3 | "answer_citation_recall": 0.49225060277275473,
4 | "answer_citation_precision": 0.8100701083492671
5 | }
--------------------------------------------------------------------------------
/benchmarks/ALCE/ASQA/results/asqa_gtr_Llama_2_7b_chat_hf_snippet_shot2_ndoc10_20240430.json:
--------------------------------------------------------------------------------
1 | {
2 | "answer_exact_match": 0.34504219409282705,
3 | "answer_citation_recall": 0.5594480062834494,
4 | "answer_citation_precision": 0.7248677248677249
5 | }
--------------------------------------------------------------------------------
/benchmarks/ALCE/ASQA/results/asqa_gtr_Llama_2_7b_chat_hf_snippet_shot2_ndoc5_20240430.json:
--------------------------------------------------------------------------------
1 | {
2 | "answer_exact_match": 0.3428973277074543,
3 | "answer_citation_recall": 0.566895218002813,
4 | "answer_citation_precision": 0.8382519863791147
5 | }
--------------------------------------------------------------------------------
/benchmarks/ALCE/ASQA/results/asqa_gtr_Llama_2_7b_chat_hf_summary_shot2_ndoc10_20240430.json:
--------------------------------------------------------------------------------
1 | {
2 | "answer_exact_match": 0.37202883263009845,
3 | "answer_citation_recall": 0.6047329457930724,
4 | "answer_citation_precision": 0.744484556758925
5 | }
--------------------------------------------------------------------------------
/benchmarks/ALCE/ASQA/results/asqa_gtr_Llama_2_7b_chat_hf_summary_shot2_ndoc5_20240430.json:
--------------------------------------------------------------------------------
1 | {
2 | "answer_exact_match": 0.3699542897327708,
3 | "answer_citation_recall": 0.6328197207152902,
4 | "answer_citation_precision": 0.8281718281718282
5 | }
--------------------------------------------------------------------------------
/benchmarks/ALCE/ASQA/results/asqa_gtr_Llama_2_7b_chat_hf_vanilla_shot2_ndoc5_20240415.json:
--------------------------------------------------------------------------------
1 | {
2 | "answer_exact_match": 0.3334739803094233,
3 | "answer_citation_recall": 0.5590039180229054,
4 | "answer_citation_precision": 0.8004201680672269
5 | }
--------------------------------------------------------------------------------
/benchmarks/ALCE/ASQA/results/asqa_oracle_Llama_2_7b_chat_hf_vanilla_shot2_ndoc5_20240428.json:
--------------------------------------------------------------------------------
1 | {
2 | "answer_exact_match": 0.4170534458509142,
3 | "answer_citation_recall": 0.5808795459111914,
4 | "answer_citation_precision": 0.788681204569055
5 | }
--------------------------------------------------------------------------------
/benchmarks/ALCE/ASQA/run.sh:
--------------------------------------------------------------------------------
1 | script_dir=$(cd $(dirname $0);pwd)
2 | cache_dir=$(dirname $(dirname $(dirname $script_dir)))/.rageval
3 | wget -cP $cache_dir/datasets https://huggingface.co/datasets/princeton-nlp/ALCE-data/resolve/main/ALCE-data.tar
4 | tar -xvf $cache_dir/datasets/ALCE-data.tar -C $cache_dir/datasets
5 | python3 setup.py install
6 |
7 | #python3 $script_dir/generate.py\
8 | # --cache_path $cache_dir\
9 | # --model Llama-2-7b-chat-hf\
10 | # --dataset gtr\
11 | # --method vanilla\
12 | # --ndoc 5\
13 | # --shot 2
14 |
15 | python3 $script_dir/asqa_benchmark.py\
16 | --cache_path $cache_dir\
17 | --remote_split Llama_2_7b_chat_hf_vanilla_shot2_ndoc5
18 |
--------------------------------------------------------------------------------
/benchmarks/ALCE/ELI5/eli5_benchmark.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import time
3 | from typing import Dict, Tuple, Any
4 |
5 | from datasets import Dataset
6 |
7 | import rageval as rl
8 | from benchmarks import BaseBenchmark
9 | from rageval.metrics import AnswerNLICorrectness, AnswerCitationRecall, AnswerCitationPrecision
10 |
11 |
12 | class ELI5Benchmark(BaseBenchmark):
13 |
14 | name = "eli5_benchmark"
15 |
16 | def __init__(self, cache_path) -> None:
17 | super().__init__()
18 | nli_model = rl.models.NLIModel(
19 | "text2text-generation",
20 | cache_path + "/models/t5_xxl_true_nli_mixture",
21 | )
22 | self.metrics = [
23 | AnswerNLICorrectness(nli_model=nli_model, decompose_model="nltk"),
24 | AnswerCitationRecall(nli_model=nli_model),
25 | AnswerCitationPrecision(nli_model=nli_model)
26 | ]
27 |
28 | def _evaluate(self) -> Tuple[Dict[Any, Any], Dataset]:
29 | self.dataset = self.dataset.rename_column("output", "answers")
30 | self.dataset = self.dataset.rename_column("claims", "gt_answers")
31 |
32 | results = {}
33 | for metric in self.metrics:
34 | results[metric.name], self.dataset = metric.compute(self.dataset, self.batch_size)
35 | return results, self.dataset
36 |
37 |
38 | if __name__ == "__main__":
39 | parser = argparse.ArgumentParser()
40 | parser.add_argument("--cache_path", type=str, default=None)
41 | parser.add_argument("--remote_split", type=str, default=None)
42 | parser.add_argument("--local_file", type=str, default=None)
43 | args = parser.parse_args()
44 | date = time.strftime("%Y%m%d", time.localtime())
45 |
46 | benchmark = ELI5Benchmark(cache_path=args.cache_path)
47 | if args.local_file:
48 | results = benchmark.evaluate(
49 | path="json",
50 | data_files={
51 | "test": args.cache_path+"/results/"+args.local_file
52 | },
53 | split="test"
54 | )
55 | benchmark.save_results(f"benchmarks/ALCE/ELI5/results/{args.local_file[:-5]}_{date}.json")
56 | else:
57 | results = benchmark.evaluate(path="golaxy/rag-bench", name="alce_eli5_bm25", split=args.remote_split)
58 | benchmark.save_results(f"benchmarks/ALCE/ELI5/results/{args.remote_split}_{date}.json")
--------------------------------------------------------------------------------
/benchmarks/ALCE/ELI5/generate.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 |
5 | from tqdm import tqdm
6 | from transformers import AutoModelForCausalLM, AutoTokenizer
7 |
8 | from rageval.models.openai import OpenAILLM
9 |
10 |
11 | def make_doc_prompt(doc, doc_id, doc_prompt, method):
12 | # For doc prompt:
13 | # - {ID}: doc id (starting from 1)
14 | # - {T}: title
15 | # - {P}: text
16 |     # method: "vanilla", "summary", or "snippet"
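    # Illustrative example (the actual doc_prompt template is loaded from
    # prompts/eli5_prompt.json): with doc_prompt = "Document [{ID}](Title: {T}): {P}\n",
    # doc_id = 0 and a doc titled "Example", this returns
    # "Document [1](Title: Example): <text>\n".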
17 |
18 | if method == "vanilla":
19 | text = doc['text']
20 | elif method == "summary":
21 | text = doc["summary"]
22 | elif method == "snippet":
23 | text = doc["extraction"]
24 | else:
25 | raise ValueError("Don't support such method.")
26 | return doc_prompt.replace("{T}", doc["title"]).replace("{P}", text).replace("{ID}", str(doc_id+1))
27 |
28 |
29 | def make_demo(item, prompt, ndoc, doc_prompt, instruction, method, test=False):
30 | # For demo prompt
31 | # - {INST}: the instruction
32 | # - {D}: the documents
33 | # - {Q}: the question
34 | # - {A}: the answers
35 | # ndoc: number of documents to put in context
36 |     # method: "vanilla", "summary", or "snippet" (passed through to make_doc_prompt)
37 |
38 | prompt = prompt.replace("{INST}", instruction).replace("{Q}", item['question'])
39 | doc_texts = []
40 | if "{D}" in prompt:
41 | if ndoc == 0:
42 | prompt = prompt.replace("{D}\n", "") # if there is no doc we also delete the empty line
43 | else:
44 | doc_list = item["docs"][:ndoc]
45 | for doc_id, doc in enumerate(doc_list):
46 | doc_texts.append(make_doc_prompt(doc, doc_id, doc_prompt, method=method))
47 | text = "".join(doc_texts)
48 | prompt = prompt.replace("{D}", text)
49 |
50 | if not test:
51 | answer = "\n" + "\n".join(item["answer"]) if isinstance(item["answer"], list) else item["answer"]
52 | prompt = prompt.replace("{A}", "").rstrip() + answer
53 | else:
54 | prompt = prompt.replace("{A}", "").rstrip() # remove any space or \n
55 |
56 | return prompt, doc_texts
57 |
58 |
59 | if __name__ == "__main__":
60 | parser = argparse.ArgumentParser()
61 | parser.add_argument("--cache_path", type=str, default=None)
62 | parser.add_argument("--model", type=str, default="gpt-3.5-turbo")
63 | parser.add_argument("--api_key", type=str, default=None)
64 | parser.add_argument("--max_length", type=int, default=4096)
65 | parser.add_argument("--temperature", type=float, default=0.5)
66 | parser.add_argument("--top_p", type=float, default=1.0)
67 | parser.add_argument("--dataset", type=str, default="bm25")
68 | parser.add_argument("--method", type=str, default="vanilla")
69 | parser.add_argument("--ndoc", type=int, default=5)
70 | parser.add_argument("--shot", type=int, default=2)
71 | args = parser.parse_args()
72 |
73 | print("-" * 10 + "Loading dataset" + "-" * 10)
74 |
75 | if args.dataset == "bm25":
76 | eval_data = json.load(
77 | open(args.cache_path + "/datasets/ALCE-data/eli5_eval_bm25_top100.json", "r")
78 | )
79 | elif args.dataset == "oracle":
80 | eval_data = json.load(
81 | open(args.cache_path + "/datasets/ALCE-data/eli5_eval_bm25_top100_reranked_oracle.json", "r")
82 | )
83 | else:
84 | raise ValueError("Don't support such dataset.")
85 |
86 | print("-" * 10 + "Finish loading dataset" + "-" * 10)
87 |
88 | print("-" * 10 + "Generating prompts" + "-" * 10)
89 |
90 | eval_prompt = json.load(
91 | open("benchmarks/ALCE/ELI5/prompts/eli5_prompt.json", "r")
92 | )
93 |
94 | head_prompt = ""
95 | for demo_id in range(args.shot):
96 | demo_item = eval_prompt["demos"][demo_id]
97 | prompt, _ = make_demo(
98 | demo_item,
99 | prompt=eval_prompt["demo_prompt"],
100 | ndoc=args.ndoc,
101 | doc_prompt=eval_prompt["doc_prompt"],
102 | instruction=eval_prompt["instruction"],
103 | method=args.method
104 | )
105 | head_prompt += prompt
106 | head_prompt += eval_prompt["demo_sep"]
107 |
108 | for idx, eval_item in enumerate(tqdm(eval_data)):
109 | prompt, doc_texts = make_demo(
110 | eval_item,
111 | prompt=eval_prompt["demo_prompt"],
112 | ndoc=args.ndoc,
113 | doc_prompt=eval_prompt["doc_prompt"],
114 | instruction=eval_prompt["instruction"],
115 | method=args.method,
116 | test=True
117 | )
118 | eval_data[idx]['prompt'] = head_prompt + prompt
119 | eval_data[idx]['contexts'] = doc_texts
120 | eval_data[idx]['docs'] = eval_item["docs"][:args.ndoc]
121 |
122 | print("-" * 10 + "Finish generating prompts" + "-" * 10)
123 |
124 | print("-" * 10 + "Loading model" + "-" * 10)
125 |
126 | model_name = args.model.split("/")[-1]
127 | if "gpt" in model_name:
128 | os.environ["OPENAI_API_KEY"] = args.api_key
129 | model = OpenAILLM(
130 | model=model_name,
131 | _api_key_env_var="OPENAI_API_KEY",
132 | max_tokens=args.max_length,
133 | temperature=args.temperature,
134 | top_p=args.top_p
135 | )
136 | else:
137 | tokenizer = AutoTokenizer.from_pretrained(args.cache_path+"/models/"+args.model, use_fast=False)
138 | model = AutoModelForCausalLM.from_pretrained(
139 | args.cache_path+"/models/"+args.model,
140 | device_map='auto'
141 | )
142 |
143 | print("-" * 10 + "Finish loading model" + "-" * 10)
144 |
145 | print("-" * 10 + "Predict" + "-" * 10)
146 |
147 | for idx, item in enumerate(tqdm(eval_data)):
148 | prompt = item['prompt']
149 | if "gpt" in model_name:
150 | output = model.generate(
151 | inputs=[prompt],
152 | system_role="You are a helpful assistant that answers the following questions with proper citations."
153 | )
154 | item['output'] = output.generations[0][0].text
155 | else:
156 | inputs = tokenizer([prompt], return_tensors="pt").to(model.device)
157 | stop = ["\n", "Ċ", "ĊĊ", "<0x0A>"] # In Llama \n is <0x0A>; In OPT \n is Ċ
158 | stop_token_ids = list(
159 | set(
160 | [tokenizer._convert_token_to_id(stop_token) for stop_token in stop]
161 | + [model.config.eos_token_id]
162 | )
163 | )
164 | if "llama" in model_name.lower():
165 | stop_token_ids.remove(tokenizer.unk_token_id)
166 |
167 | generation = model.generate(
168 | **inputs,
169 | max_length=args.max_length,
170 | temperature=args.temperature,
171 | top_p=args.top_p,
172 | eos_token_id=stop_token_ids,
173 | do_sample=True
174 | )
175 | output = tokenizer.decode(generation[0][inputs['input_ids'].size(1):], skip_special_tokens=True)
176 | item['output'] = output
177 |
178 | print("-" * 10 + "Finish predicting" + "-" * 10)
179 |
180 | file_name = f"eli5-{args.dataset}-{model_name}-{args.method}-shot{args.shot}-ndoc{args.ndoc}"
181 | file_name = file_name.replace("-", "_")
182 | result_path = args.cache_path + "/results/" + file_name + ".json"
183 | json.dump(eval_data, open(result_path, "w"), indent=4)
184 |
185 | print(f"\nResult file saved as {result_path}")
186 |
--------------------------------------------------------------------------------
/benchmarks/ALCE/ELI5/results/eli5-bm25-Llama-2-7b-chat-hf-vanilla-shot2-ndoc5-20240310.json:
--------------------------------------------------------------------------------
1 | {
2 | "nli_claim": 11.5,
3 | "citation_recall": 26.623809523809523,
4 | "citation_precision": 74.54545454545455
5 | }
--------------------------------------------------------------------------------
/benchmarks/ALCE/ELI5/results/eli5_oracle_Llama_2_7b_chat_hf_vanilla_shot2_ndoc5_20240430.json:
--------------------------------------------------------------------------------
1 | {
2 | "answer_claim_recall": 0.17766666666666667,
3 | "answer_citation_recall": 0.34007738095238094,
4 | "answer_citation_precision": 0.7563654518222666
5 | }
--------------------------------------------------------------------------------
/benchmarks/ALCE/ELI5/run.sh:
--------------------------------------------------------------------------------
1 | script_dir=$(cd $(dirname $0);pwd)
2 | cache_dir=$(dirname $(dirname $(dirname $script_dir)))/.rageval
3 | wget -cP $cache_dir/datasets https://huggingface.co/datasets/princeton-nlp/ALCE-data/resolve/main/ALCE-data.tar
4 | tar -xvf $cache_dir/datasets/ALCE-data.tar -C $cache_dir/datasets
5 | python3 setup.py install
6 |
7 | #python3 $script_dir/generate.py\
8 | # --cache_path $cache_dir\
9 | # --model Llama-2-7b-chat-hf\
10 | # --dataset bm25\
11 | # --method vanilla\
12 | # --ndoc 5\
13 | # --shot 2
14 |
15 | python3 $script_dir/eli5_benchmark.py\
16 | --cache_path $cache_dir\
17 | --remote_split Llama_2_7b_chat_hf_vanilla_shot2_ndoc5
18 |
--------------------------------------------------------------------------------
/benchmarks/ASQA/asqa_benchmark.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Tuple, Any, Optional
2 | from datasets import Dataset
3 | import numpy as np
4 | import os
5 | import argparse
6 | from benchmarks import BaseBenchmark
7 | from rageval.metrics import (AnswerRougeCorrectness, AnswerEMCorrectness, AnswerDisambigF1Correctness)
8 |
9 |
10 | class ASQABenchmark(BaseBenchmark):
11 | """Benchmark for ASQA dataset.
12 |
13 | The ASQA dataset is a question-answering dataset that contains factoid questions and long-form answers. The benchmark evaluates the correctness of the answers in the dataset.
14 | """
15 |
16 | name = "asqa_benchmark"
17 | metrics = [AnswerRougeCorrectness(rouge_type="rougeL"),
18 | AnswerEMCorrectness(ignore_case=True),
19 | AnswerDisambigF1Correctness()]
20 |
21 | ground_truths = {
22 | "answer_disambig_f1": "long_answers",
23 | "answer_rouge_correctness": "long_answers",
24 | "answer_exact_match": "short_answers"
25 | }
26 |
27 | def __init__(self) -> None:
28 | """Initialization."""
29 | super().__init__()
30 |
31 | def is_existed(self, column_name: str) -> bool:
32 | """Check if the column exists in the dataset."""
33 | return column_name in self.dataset.column_names
34 |
35 |     def _evaluate(self) -> Tuple[Dict[Any, Any], Dataset]:
36 | """Evaluate the dataset and return the dataset with scores.
37 |
38 | For the ASQA dataset, the `short_answers` and `long_answers` are stored in the "qa_pairs" and "annotations" columns, respectively. We need to extract them and add them to the dataset.
39 |
40 | We use the `short_answers` as the `gt_answers` to evaluate the string Exact Match correctness and the `long_answers` to evaluate the RougeL and DisambigF1 score. And then we calculate the `DR score` as the geometric mean of the RougeL and DisambigF1 scores.
41 | """
42 | if not self.is_existed("short_answers"):
43 | self.dataset = self.dataset.map(lambda example: {"short_answers": [ann["short_answers"] for ann in example["qa_pairs"]]})
44 | if not self.is_existed("long_answers"):
45 | self.dataset = self.dataset.map(lambda example: {"long_answers": [ann["long_answer"] for ann in example["annotations"]]})
46 |
47 | results = {}
48 | for m in self.metrics:
49 | if m.name in self.ground_truths:
50 | print(f"Calculating {m.name}...")
51 |
52 | if self.is_existed(m.name):
53 | # Remove the metric column if it already exists
54 | self.dataset = self.dataset.remove_columns(m.name)
55 | if not self.is_existed(self.ground_truths[m.name]):
56 | # Check if the ground truth column exists
57 | raise ValueError(f"The column {self.ground_truths[m.name]} is not in the dataset. Please check the column names.")
58 |
59 | avg_scores, scores = m.compute(
60 | self.dataset["answers"],
61 | self.dataset[self.ground_truths[m.name]]
62 | )
63 | results[m.name] = avg_scores
64 | self.dataset = self.dataset.add_column(m.name, scores)
65 |
66 | print(f"{m.name}: {avg_scores}")
67 |
68 | if self.is_existed("answer_rouge_correctness") and self.is_existed("answer_disambig_f1"):
69 |             # Note that the DR score is the overall geometric mean of the RougeL and DisambigF1 scores, i.e. sqrt(mean(DisambigF1) * mean(RougeL)) computed over the whole dataset rather than averaged per sample.
70 | print("Calculating DR score...")
71 | results["DR_score"] = np.sqrt(np.average(self.dataset["answer_disambig_f1"]) * np.average(self.dataset["answer_rouge_correctness"]))
72 | print(f"DR_score: {results['DR_score']}")
73 |
74 | return results, self.dataset
75 |
76 | if __name__ == "__main__":
77 | parser = argparse.ArgumentParser()
78 | parser.add_argument("--output_dir", type=str, default=".rageval/benchmark")
79 | parser.add_argument("--split", type=str, default="llama2_7b_chat")
80 | args = parser.parse_args()
81 |
82 | benchmark = ASQABenchmark()
83 |
84 | results = benchmark.evaluate(path="golaxy/rag-bench", name="asqa", split=args.split)
85 | print(f"Results:\n {results}")
86 |
87 | benchmark.save_results(os.path.join(args.output_dir,"results", f"{args.split}.jsonl"))
88 | benchmark.save_dataset(os.path.join(args.output_dir,"dataset", f"{args.split}.jsonl"))
89 |
90 | benchmark.set_metric([AnswerEMCorrectness(ignore_case=False)])
91 | results = benchmark.evaluate()
92 | print(f"Results:\n {results}")
93 |
--------------------------------------------------------------------------------
/benchmarks/ASQA/generate.py:
--------------------------------------------------------------------------------
1 | from datasets import Dataset, load_dataset
2 | import re
3 | from rageval.models import OpenAILLM
4 | import os
5 | import logging
6 | import argparse
7 |
8 | from prompts import (FEW_SHOT_EXAMPLES, PROMPT)
9 |
10 | def extract_key_information(pred: str) -> str:
11 | '''Extract key information from the response.'''
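    # Illustrative example: a response following the few-shot format in prompts.py
    # contains two enumerated lists, e.g.
    #   "... this question has 2 interpretations: (1) Q1? (2) Q2? The answers to
    #    all interpretations are: (1) answer one. (2) answer two."
    # The regex below keeps the second enumerated list (the answers), and the
    # re.sub at the end strips the "(1) " style index markers.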
12 | pattern = r"(?:1\.|\(1\)).*?((?:1\.|\(1\)).*)" # find the second list starting with 1. or (1)
13 |     pred = pred.strip()
14 | matches = re.findall(pattern, pred, re.DOTALL)
15 | if matches:
16 | pred = matches[0]
17 | else:
18 | print(f"Cannot extract key information from the response: {pred}")
19 | return pred
20 | pred = re.sub(r'\(\d+\)\s', '', pred) # remove the index numbers
21 | return pred
22 |
23 | def generate_answers(engine: OpenAILLM, dataset: Dataset) -> Dataset:
24 | prompts = [
25 | PROMPT.format(few_shot_examples=FEW_SHOT_EXAMPLES,
26 | question=data['ambiguous_question'])
27 | for data in dataset
28 | ]
29 | responses = engine.batch_generate(prompts)
30 | response_texts = [r.generations[0][0].text for r in responses]
31 | answers = [extract_key_information(response) for response in response_texts]
32 | dataset = dataset.add_column("responses", response_texts)
33 | return dataset.add_column("answers", answers)
34 |
35 | if __name__ == "__main__":
36 | parser = argparse.ArgumentParser()
37 | parser.add_argument("--max_num_examples", type=int, default=5)
38 | parser.add_argument("--max_new_tokens", type=int, default=256)
39 | parser.add_argument("--output_path", type=str, default="benchmarks/ASQA/output")
40 | parser.add_argument("--model", type=str, default="gpt-3.5-turbo-instruct")
41 | parser.add_argument("--api_key", type=str, default=None)
42 |
43 | args = parser.parse_args()
44 |
45 | print("\nLoad ASQA dataset...")
46 | dataset = load_dataset("din0s/asqa")
47 | dataset = dataset['dev'].select(range(args.max_num_examples))
48 |
49 |     print("Init OpenAI engine...")
50 | os.environ['OPENAI_API_KEY'] = args.api_key
51 | engine = OpenAILLM(args.model,
52 | _api_key_env_var = 'OPENAI_API_KEY',
53 | max_tokens=args.max_new_tokens)
54 |
55 |     print("Start generating answers...")
56 | dataset = generate_answers(engine, dataset)
57 |
58 | file_path = os.path.join(args.output_path, f"{args.model}.jsonl")
59 | dataset.to_json(file_path)
60 |     print(f"\nFinished generating dataset. Dataset saved as {file_path}")
61 |
62 | engine.calculate_api_cost()
63 |
--------------------------------------------------------------------------------
/benchmarks/ASQA/prompts.py:
--------------------------------------------------------------------------------
1 | FEW_SHOT_EXAMPLES = """Given an ambiguous question, figure out its interpretations and answer them one by one.
2 | Question: Who played bonnie in gone with the wind?
3 | Answer: This question is ambiguous in terms of which version or adaptation of Gone with the Wind is being referred to. In order to figure out its interpretations, we need to consider different versions or adaptations of Gone with the Wind. Gone with the Wind has two versions or adaptations: the 1939 film Gone with the Wind or the 2008 musical Gone with the Wind. Therefore, this question has 2 interpretations: (1) Who played Bonnie in the 1939 film Gone with the Wind? (2) Who played Bonnie in the 2008 musical Gone with the Wind? The answers to all interpretations are: (1) The 1939 film Gone with the Wind’s character Bonnie was played by Eleanore Cammack "Cammie" King. (2) The 2008 musical Gone with the Wind’s character Bonnie was played by Leilah de Meza.
4 |
5 | Given an ambiguous question, figure out its interpretations and answer them one by one.
6 | Question: What is the second largest city in the usa?
7 | Answer: This question is ambiguous in terms of the criteria being used to determine the second largest city in the USA. In order to figure out its interpretations, we need to consider different criteria to determine a city’s size. City size can be measured by two criteria: population or area. Therefore, this question has 2 interpretations: (1) What is the second largest city in the USA by population? (2) What is the second largest city in the USA by area? The answers to all interpretations are: (1) The second largest city in the USA by population is Los Angeles, California. (2) The second largest city in the USA by area is Juneau, Alaska.
8 |
9 | Given an ambiguous question, figure out its interpretations and answer them one by one.
10 | Question: When was bohemian rhapsody released as a single?
11 | Answer: This question is ambiguous in terms of which country’s release of the single is being referred to. In order to figure out its interpretations, we need to consider different countries where Bohemian Rhapsody is released. Bohemian Rhapsody was released in the United Kingdom and in the United States on different dates. Therefore, this question has 2 interpretations: (1) When was Bohemian Rhapsody released as a single in the United Kingdom? (2) When was Bohemian Rhapsody released as a single in the United States? The answers to all interpretations are: (1) Bohemian Rhapsody was released as a single in the United Kingdom on 31 October 1975. (2) Bohemian Rhapsody was released as a single in the United States on December 1975.
12 |
13 | Given an ambiguous question, figure out its interpretations and answer them one by one.
14 | Question: Where do the philadelphia eagles play their home games?
15 | Answer: This question is ambiguous in terms of which specific location or venue is being referred to. In order to figure out its interpretations, we need to consider the different possible locations or venues that could be considered the home field of the Philadelphia Eagles. These include the city, the sports complex, or the stadium. Therefore, this question has 3 interpretations: (1) What city do the Philadelphia Eagles play their home games? (2) In what sports complex do the Philadelphia Eagles play their home games? (3) What stadium do the Philadelphia Eagles play their home games? The answers to all interpretations are: (1) Philadelphia Eagles play their home games in the city Philadelphia. (2) Philadelphia Eagles play their home games in the South Philadelphia Sports Complex. (3) Philadelphia Eagles play their home games in the Lincoln Financial Field stadium.
16 |
17 | Given an ambiguous question, figure out its interpretations and answer them one by one.
18 | Question: When did xbox one come out in australia?
19 | Answer: This question is ambiguous in terms of which specific version of the Xbox One is being referred to. In order to figure out its interpretations, we need to consider the different versions of the Xbox One that have been released. Xbox One has two versions: the Xbox One video game console or the Xbox One X high-end model. Therefore, this question has 2 interpretations: (1) When did the Xbox One release in Australia? (2) When did the Xbox One X release in Australia? The answers to all interpretations are: (1) The Xbox One video game console was released in Australia on November 22, 2013. (2) The Xbox One X video game console was released in Australia on November 7, 2017.
20 |
21 | Given an ambiguous question, figure out its interpretations and answer them one by one.
22 | Question: When does the movie summer of 84 come out?
23 | Answer: This question is ambiguous in terms of which release of the movie is being referred to. In order to figure out its interpretations, we need to consider different releases of the movie Summer of ’84. The movie Summer of ’84 is first released at the Sundance Festival before it’s released throughout the US. Therefore, this question has 2 interpretations: (1) When did the movie Summer of ’84 first release at the Sundance Festival? (2) When did the movie Summer of ’84 first release throughout the US? The answers to all interpretations are: (1) Summer of ’84 was released at the Sundance Festival on January 22, 2018. (2) Summer of ’84 was released throughout the US on August 10, 2018.
24 |
25 | Given an ambiguous question, figure out its interpretations and answer them one by one.
26 | Question: What was roy orbison’s first number one hit?
27 | Answer: This question is ambiguous in terms of which specific chart or region is being referred to. In order to figure out its interpretations, we need to consider the different charts and regions where Roy Orbison’s music was popular. Roy Orbison is popular in both the US Hot 100 and Canada, and the UK and Ireland. Therefore, this question has 2 interpretations: (1) What was Roy Orbison’s first number one hit in the US Hot 100 and Canada? (2) What was Roy Orbison’s first number one hit in the UK and Ireland? The answers to all interpretations are: (1) Running Scared was the first number one hit for Roy Orbison in the US Hot 100 and Canada. (2) Only the Lonely (Know the Way I Feel) was the first number one hit for Roy Orbison in the UK and Ireland.
28 |
29 | Given an ambiguous question, figure out its interpretations and answer them one by one.
30 | Question: What is the criminal’s name in the breakfast club?
31 | Answer: This question is ambiguous in terms of which specific name is being referred to - the character’s name or the actor’s name. In order to figure out its interpretations, we need to consider both possibilities: the character’s name or the actor’s name. Therefore, this question has 2 interpretations: (1) What is the criminal’s character name in The Breakfast Club? (2) What is the the name of the actor who played the criminal in The Breakfast Club? The answers to all interpretations are: (1) John Bender was the name of the criminal’s character in The Breakfast Club. (2) Judd Nelson was the actor of the criminal in The Breakfast Club."""
32 |
33 | PROMPT = """{few_shot_examples}\n\nGiven an ambiguous question, figure out its interpretations and answer them one by one.\nQuestion: {question}\nAnswer: """
--------------------------------------------------------------------------------
/benchmarks/ASQA/run.sh:
--------------------------------------------------------------------------------
1 | rageval_dir=$(dirname $(dirname $(dirname $(realpath $0))))
2 | cd $rageval_dir
3 | echo "Running ASQA Benchmark"
4 | python3 benchmarks/ASQA/asqa_benchmark.py --output_dir ".rageval/benchmark" --split "gpt_3.5_turbo_instruct"
5 | echo "ASQA Benchmark Complete"
--------------------------------------------------------------------------------
/benchmarks/ASQA/run_generate.sh:
--------------------------------------------------------------------------------
1 | rageval_dir=$(dirname $(dirname $(dirname $(realpath $0))))
2 | cd $rageval_dir
3 | echo "Generating ASQA examples"
4 | python3 benchmarks/ASQA/generate.py \
5 | --max_num_examples 500 \
6 | --max_new_tokens 256 \
7 | --output_path "benchmarks/ASQA/output" \
8 | --model "gpt-3.5-turbo-instruct" \
9 | --api_key "YOUR_API_KEY"
--------------------------------------------------------------------------------
/benchmarks/HOTPOTQA/hotpotqa_benchmark.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import time
4 | from typing import Dict, Tuple, Any
5 |
6 | from datasets import Dataset
7 |
8 | from benchmarks import BaseBenchmark
9 | from rageval.metrics import (AnswerEMCorrectness, AnswerF1Correctness)
10 |
11 |
12 | class HOTPOTQABenchmark(BaseBenchmark):
13 | """Benchmark for HotPotQA dataset.
14 |
15 |     HotPotQA is a new dataset with 113k Wikipedia-based question-answer pairs with four key features: (1) the questions require finding and reasoning over multiple supporting documents to answer; (2) the questions are diverse and not constrained to any pre-existing knowledge bases or knowledge schemas; (3) we provide sentence-level supporting facts required for reasoning, allowing QA systems to reason with strong supervision and explain the predictions; (4) we offer a new type of factoid comparison questions to test QA systems’ ability to extract relevant facts and perform necessary comparison.
16 |
17 | """
18 |
19 | name = "hotpot_qa_benchmark"
20 |
21 | def __init__(self) -> None:
22 | """Initialization."""
23 | super().__init__()
24 |
25 | def _recode_gt_supporting_facts(self, data: object) -> object:
26 |         """Build gt_sent_ids for F1 by joining each supporting-fact title (whitespace removed) with its sentence id, e.g. ("Title A", 0) -> "TitleA0"."""
27 | recode_answers = []
28 | for title, sent_id in zip(data['supporting_facts']['title'], data['supporting_facts']['sent_id']):
29 | recode = title.replace(" ","")+ str(sent_id)
30 | recode_answers.append(recode)
31 | recode_answers = [' '.join(recode_answers)]
32 | data["gt_sent_ids"] = recode_answers
33 | return data
34 |
35 | def _evaluate(self) -> Tuple[Dict[Any, Any], Dataset]:
36 | """Evaluate the dataset and return the dataset with scores.
37 |
38 |         For the HotPotQA dataset (Distractor Setting), we evaluate models using the `short_answer` and `supporting_answer`.
39 | 
40 |         For the HotPotQA dataset (Fullwiki Setting), we evaluate models using the `response`.
41 | 
42 |         In the Distractor Setting, we use the `answer` as the `gt_answers` to evaluate string Exact Match correctness, and the `supporting_facts` to build "gt_sent_ids" for evaluating F1.
43 | 
44 |         In the Fullwiki Setting, we use the `answer` as the `gt_answers` to evaluate string Exact Match correctness.
45 | """
46 |
47 | self.metrics = [AnswerEMCorrectness(ignore_case=True),
48 | AnswerF1Correctness()
49 | ]
50 | if (("supporting_answer" in self.dataset.column_names) and "short_answer" in self.dataset.column_names):
51 | self.dataset = self.dataset.map(self._recode_gt_supporting_facts)
52 |             self.dataset = self.dataset.map(lambda example: {"answer": [[example['answer']]]})
53 | ground_truths = {
54 | "answer_f1": ("supporting_answer", "gt_sent_ids"),
55 | "answer_exact_match": ("short_answer", "answer")
56 | }
57 | else:
58 |             self.dataset = self.dataset.map(lambda example: {"answer": [[example['answer']]]})
59 | ground_truths = {
60 | "answer_exact_match": ("response", "answer")
61 | }
62 |
63 | results = {}
64 |
65 | for metric in self.metrics:
66 | if metric.name in ground_truths:
67 | print(f"Calculating {metric.name}...")
68 |
69 | if metric.name in self.dataset.column_names:
70 | self.dataset = self.dataset.remove_columns(metric.name)
71 |
72 | an, gtan = ground_truths[metric.name]
73 | self.dataset = self.dataset.rename_column(an, "answers")
74 | self.dataset = self.dataset.rename_column(gtan, "gt_answers")
75 |
76 | results[metric.name], self.dataset = metric.compute(self.dataset, self.batch_size)
77 |
78 | self.dataset = self.dataset.rename_column("answers", an)
79 | self.dataset = self.dataset.rename_column("gt_answers", gtan)
80 | self.dataset = self.dataset.map(lambda example: {"answer": example['answer'][0][0]})
81 | return results, self.dataset
82 |
83 |
84 | if __name__ == "__main__":
85 | parser = argparse.ArgumentParser()
86 |
87 | parser.add_argument("--output_dir", type=str, default="benchmarks/HOTPOTQA")
88 | parser.add_argument("--remote_split", type=str, default="gpt_3.5_turbo")
89 | parser.add_argument("--local_file", type=str, default=None)
90 |
91 | args = parser.parse_args()
92 | date = time.strftime("%Y%m%d", time.localtime())
93 |
94 | benchmark = HOTPOTQABenchmark()
95 | if args.local_file:
96 | data_file = os.path.join(args.output_dir, 'output', args.local_file)
97 | results = benchmark.evaluate(
98 | path='json',
99 | data_files={"test": data_file},
100 | split="test"
101 | )
102 | print(f"Results:\n {results}")
103 | benchmark.save_results(os.path.join(args.output_dir, 'results', f"{args.local_file[:-5]}_{date}.jsonl"))
104 | benchmark.save_dataset(os.path.join(args.output_dir, 'output', f"{args.local_file[:-5]}_{date}.jsonl"))
105 | else:
106 | results = benchmark.evaluate(path='golaxy/rag-bench', name='hotpot_qa', split=args.remote_split)
107 | print(f"Results:\n {results}")
108 | benchmark.save_results(os.path.join(args.output_dir, 'results', f"{args.remote_split}_{date}.jsonl"))
109 | benchmark.save_dataset(os.path.join(args.output_dir, 'output', f"{args.remote_split}_{date}.jsonl"))
110 |
--------------------------------------------------------------------------------
/benchmarks/HOTPOTQA/run.sh:
--------------------------------------------------------------------------------
1 | rageval_dir=$(dirname $(dirname $(dirname $(realpath $0))))
2 | cd $rageval_dir
3 |
4 | echo "Running HotPotQA Benchmark"
5 |
6 | python3 benchmarks/HOTPOTQA/hotpotqa_benchmark.py \
7 | --output_dir "benchmarks/HOTPOTQA" \
8 | --remote_split "gpt_3.5_turbo"
9 |
10 | # Check return status code
11 | if [ $? -eq 0 ]; then
12 | echo "HotPotQA Benchmark Complete"
13 | else
14 | echo "Error: HotPotQA Benchmark failed to execute."
15 | fi
16 |
17 |
--------------------------------------------------------------------------------
/benchmarks/HOTPOTQA/run_generate.sh:
--------------------------------------------------------------------------------
1 | rageval_dir=$(dirname $(dirname $(dirname $(realpath $0))))
2 | cache_dir="$rageval_dir/HOTPOTQA"
3 | cd $rageval_dir
4 | echo "Generating HOTPOTQA examples"
5 | python3 benchmarks/HOTPOTQA/generate.py \
6 | --subset "distractor"\
7 | --num_documents 10 \
8 | --max_num_examples 500 \
9 | --max_length 4096 \
10 |     --output_path "benchmarks/HOTPOTQA/output" \
11 | --cache_path $cache_dir \
12 | --model "gpt-3.5-turbo" \
13 | --api_key "YOUR_API_KEY"
--------------------------------------------------------------------------------
/benchmarks/WebGLM/generate.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 |
5 | from tqdm import tqdm
6 | from transformers import AutoModelForCausalLM, AutoTokenizer
7 |
8 | from rageval.models.openai import OpenAILLM
9 |
10 |
11 | PROMPT = "Answer the question based on the following references with citations. Use a mark for each helpful reference you cited, such as [1]. If there are multiple citations at one position, please use a format like [1][2][3]. If a reference is useless, do not cite it."
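# An answer that follows this instruction would look like (illustrative):
#   "Solar panels convert sunlight into electricity [1][3]."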
12 |
13 | if __name__ == "__main__":
14 | parser = argparse.ArgumentParser()
15 | parser.add_argument("--cache_path", type=str, default=None)
16 | parser.add_argument("--model", type=str, default="gpt-3.5-turbo")
17 | parser.add_argument("--api_key", type=str, default=None)
18 | parser.add_argument("--max_length", type=int, default=4096)
19 | parser.add_argument("--temperature", type=float, default=0.5)
20 | parser.add_argument("--top_p", type=float, default=1.0)
21 | args = parser.parse_args()
22 |
23 | print("-" * 10 + "Loading dataset" + "-" * 10)
24 |
25 | eval_data = []
26 | with open(args.cache_path + "/datasets/webglm-test.jsonl", "r") as read_file:
27 | for line in tqdm(read_file):
28 | eval_data.append(json.loads(line))
29 |
30 | print("-" * 10 + "Finish loading dataset" + "-" * 10)
31 |
32 | print("-" * 10 + "Generating prompts" + "-" * 10)
33 |
34 | for idx, item in enumerate(tqdm(eval_data)):
35 | prompt = PROMPT + '\n'
36 | for ix, ref in enumerate(item["references"]):
37 | prompt += f'Reference [{ix + 1}]: {ref}\n'
38 | prompt += f'Question: {item["question"]}\nAnswer: '
39 | eval_data[idx]["prompt"] = prompt
40 |
41 | print("-" * 10 + "Finish generating prompts" + "-" * 10)
42 |
43 | print("-" * 10 + "Loading model" + "-" * 10)
44 |
45 | model_name = args.model.split("/")[-1]
46 | if "gpt" in model_name:
47 | os.environ["OPENAI_API_KEY"] = args.api_key
48 | model = OpenAILLM(
49 | model=model_name,
50 | _api_key_env_var="OPENAI_API_KEY",
51 | max_tokens=args.max_length,
52 | temperature=args.temperature,
53 | top_p=args.top_p
54 | )
55 | else:
56 | tokenizer = AutoTokenizer.from_pretrained(args.cache_path + "/models/" + args.model, use_fast=False)
57 | model = AutoModelForCausalLM.from_pretrained(
58 | args.cache_path + "/models/" + args.model,
59 | device_map='auto'
60 | )
61 |
62 | print("-" * 10 + "Finish loading model" + "-" * 10)
63 |
64 | print("-" * 10 + "Predict" + "-" * 10)
65 |
66 | for idx, item in enumerate(tqdm(eval_data)):
67 | prompt = item['prompt']
68 | if "gpt" in model_name:
69 | output = model.generate(
70 | inputs=[prompt],
71 | system_role="You are a helpful assistant that answers the following questions with proper citations."
72 | )
73 | item['output'] = output.generations[0][0].text
74 | else:
75 | inputs = tokenizer([prompt], return_tensors="pt").to(model.device)
76 | stop = ["\n", "Ċ", "ĊĊ", "<0x0A>"] # In Llama \n is <0x0A>; In OPT \n is Ċ
77 | stop_token_ids = list(
78 | set(
79 | [tokenizer._convert_token_to_id(stop_token) for stop_token in stop]
80 | + [model.config.eos_token_id]
81 | )
82 | )
83 | if "llama" in model_name.lower():
84 | stop_token_ids.remove(tokenizer.unk_token_id)
85 |
86 | generation = model.generate(
87 | **inputs,
88 | max_length=args.max_length,
89 | temperature=args.temperature,
90 | top_p=args.top_p,
91 | eos_token_id=stop_token_ids,
92 | do_sample=True
93 | )
94 | output = tokenizer.decode(generation[0][inputs['input_ids'].size(1):], skip_special_tokens=True)
95 | item['output'] = output
96 |
97 | print("-" * 10 + "Finish predicting" + "-" * 10)
98 |
99 | file_name = f"webglm-{model_name}"
100 | file_name = file_name.replace("-", "_")
101 | result_path = args.cache_path + "/results/" + file_name + ".json"
102 | json.dump(eval_data, open(result_path, "w"), indent=4)
103 |
104 | print(f"\nResult file saved as {result_path}")
105 |
--------------------------------------------------------------------------------
/benchmarks/WebGLM/results/webglm_Llama_2_7b_chat_hf_20240502.json:
--------------------------------------------------------------------------------
1 | {
2 | "answer_rouge_correctness": 0.2548292585536276,
3 | "answer_citation_recall": 0.10667063492063492,
4 | "answer_citation_precision": 0.9397590361445783
5 | }
--------------------------------------------------------------------------------
/benchmarks/WebGLM/run.sh:
--------------------------------------------------------------------------------
1 | script_dir=$(cd $(dirname $0);pwd)
2 | cache_dir=$(dirname $(dirname $script_dir))/.rageval
3 | wget -c https://huggingface.co/datasets/THUDM/webglm-qa/resolve/main/data/test.jsonl -O $cache_dir/datasets/webglm-test.jsonl
4 | python3 setup.py install
5 |
6 | #python3 $script_dir/generate.py\
7 | # --cache_path $cache_dir\
8 | # --model Llama-2-7b-chat-hf
9 |
10 | python3 $script_dir/webglm_benchmark.py\
11 | --cache_path $cache_dir\
12 | --remote_split Llama_2_7b_chat_hf
--------------------------------------------------------------------------------
/benchmarks/WebGLM/webglm_benchmark.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import time
3 | from typing import Dict, Tuple, Any
4 |
5 | from datasets import Dataset
6 |
7 | import rageval as rl
8 | from benchmarks import BaseBenchmark
9 | from rageval.metrics import AnswerRougeCorrectness, AnswerCitationRecall, AnswerCitationPrecision
10 |
11 |
12 | class WebGLMBenchmark(BaseBenchmark):
13 |
14 | name = "webglm_benchmark"
15 |
16 | def __init__(self, cache_path) -> None:
17 | super().__init__()
18 | nli_model = rl.models.NLIModel(
19 | "text2text-generation",
20 | cache_path + "/models/t5_xxl_true_nli_mixture",
21 | )
22 | self.metrics = [
23 | AnswerRougeCorrectness(rouge_type="rougeL"),
24 | AnswerCitationRecall(nli_model=nli_model),
25 | AnswerCitationPrecision(nli_model=nli_model)
26 | ]
27 |
28 | def _evaluate(self) -> Tuple[Dict[Any, Any], Dataset]:
29 | self.dataset = self.dataset.rename_column("output", "answers")
30 | self.dataset = self.dataset.rename_column("answer", "gt_answers")
31 | self.dataset = self.dataset.rename_column("references", "contexts")
32 |
33 | results = {}
34 |
35 | for metric in self.metrics:
36 | if metric.name == "answer_rouge_correctness":
37 | self.dataset = self.dataset.map(lambda data: {"gt_answers": [data["gt_answers"]]})
38 |
39 | results[metric.name], self.dataset = metric.compute(self.dataset, self.batch_size)
40 |
41 | if metric.name == "answer_rouge_correctness":
42 | self.dataset = self.dataset.map(lambda data: {"gt_answers": data["gt_answers"][0]})
43 |
44 | return results, self.dataset
45 |
46 |
47 | if __name__ == "__main__":
48 | parser = argparse.ArgumentParser()
49 | parser.add_argument("--cache_path", type=str, default=None)
50 | parser.add_argument("--remote_split", type=str, default=None)
51 | parser.add_argument("--local_file", type=str, default=None)
52 | args = parser.parse_args()
53 | date = time.strftime("%Y%m%d", time.localtime())
54 |
55 | benchmark = WebGLMBenchmark(cache_path=args.cache_path)
56 | if args.local_file:
57 | results = benchmark.evaluate(
58 | path="json",
59 | data_files={
60 | "test": args.cache_path+"/results/"+args.local_file
61 | },
62 | split="test"
63 | )
64 | benchmark.save_results(f"benchmarks/WebGLM/results/{args.local_file[:-5]}_{date}.json")
65 | else:
66 | results = benchmark.evaluate(path="golaxy/rag-bench", name="webglm", split=args.remote_split)
67 | benchmark.save_results(f"benchmarks/WebGLM/results/{args.remote_split}_{date}.json")
68 |
--------------------------------------------------------------------------------
/benchmarks/__init__.py:
--------------------------------------------------------------------------------
1 | from .base import BaseBenchmark
--------------------------------------------------------------------------------
/benchmarks/auto/README.md:
--------------------------------------------------------------------------------
1 | # AUTO BENCHMARK
2 |
3 | Auto benchmark aims to generate test sets of synthetic Q&A pairs from raw passages, using as few human annotations as possible. To achieve this, we use an LLM as both the annotator and the dataset quality inspector.
4 |
5 | ## Usage
6 |
7 | 1. Prepare your corpus in a JSON file, following the format below:
8 |
9 | ```
10 | corpus.json: ["Document 1", "Document 2", ...]
11 | few_shot_cases.json: [
12 | {"document": "Sample document",
13 | "Query": "Sample question"}
14 | ]
15 | ```
16 |
17 | 2. Place the corpus JSON file(s) in the `corpus` directory.
18 | 3. Run `run.sh` to start dataset generation. The result will be saved in the `output` directory.
19 |
20 | ### Arguments:
21 |
22 | `--corpus_dir`: Directory containing the corpus and few-shot case JSON files.
23 |
24 | `--output_dir`: Directory where the generated dataset JSON will be saved.
25 |
26 | `--model`: The OpenAI GPT model to use, e.g., gpt-3.5-turbo-16k.
27 |
28 | `--api_key`: Your OpenAI API key.
29 |
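For example, the benchmark can be launched directly; the following mirrors the command in `run.sh` (replace `YOUR_API_KEY` with your own OpenAI API key):

```bash
python3 benchmarks/auto/auto_benchmark.py \
    --corpus_dir "benchmarks/auto/corpus" \
    --output_dir "benchmarks/auto/output" \
    --model "gpt-3.5-turbo" \
    --api_key "YOUR_API_KEY"
```
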
30 | ## Citations
31 |
32 | In this case, `documents` refers to a list of news articles, and the `few-shot cases` are derived from a random split of the Multi-RC dataset. The prompts are adapted from ARES: An Automated Evaluation Framework for Retrieval-Augmented Generation Systems.
33 |
34 | ``` bibtex
35 | @misc{saadfalcon2023ares,
36 | title={ARES: An Automated Evaluation Framework for Retrieval-Augmented Generation Systems},
37 | author={Jon Saad-Falcon and Omar Khattab and Christopher Potts and Matei Zaharia},
38 | year={2023},
39 | eprint={2311.09476},
40 | archivePrefix={arXiv},
41 | primaryClass={cs.CL}
42 | }
43 |
44 | @inproceedings{MultiRC2018,
45 | author = {Daniel Khashabi and Snigdha Chaturvedi and Michael Roth and Shyam Upadhyay and Dan Roth},
46 | title = {Looking Beyond the Surface:A Challenge Set for Reading Comprehension over Multiple Sentences},
47 | booktitle = {Proceedings of North American Chapter of the Association for Computational Linguistics (NAACL)},
48 | year = {2018}
49 | }
50 | ```
51 |
--------------------------------------------------------------------------------
/benchmarks/auto/auto_benchmark.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 | import random
5 | import pandas as pd
6 | from datasets import Dataset
7 | from typing import List, Any, Tuple, Dict
8 |
9 | from rageval.models import OpenAILLM
10 | from prompt import (SYNTHETIC_QUERY_FEW_SHOT, SYNTHETIC_QUERY_SYSTEM, SYNTHETIC_QUERY_USER, SYNTHETIC_ANSWER_SYSTEM, SYNTHETIC_ANSWER_USER)
11 |
12 |
13 | def load_corpus(corpus_dir):
14 | with open(f"{corpus_dir}/corpus.json", "r", encoding="utf-8") as f:
15 | docs = json.load(f)
16 | df = pd.DataFrame(docs)
17 | df.drop_duplicates(inplace=True)
18 |     dataset = Dataset.from_dict({'document': df[0].apply(lambda x: x.strip())})
19 |
20 | with open(f"{corpus_dir}/few_shot_cases.json", "r", encoding="utf-8") as f:
21 | cases = json.load(f)
22 | cases = random.sample(cases, 3)
23 | return dataset, cases
24 |
25 | def generate_responses(engine: OpenAILLM, user_prompts: List[List[str]], system_prompt: List[str]) -> List[str]:
26 | '''Generate responses from the OpenAILLM model.'''
27 | responses = engine.batch_generate(user_prompts, system_roles=system_prompt * len(user_prompts))
28 | response_texts = [r.generations[0][0].text for r in responses]
29 |
30 | return response_texts
31 |
32 | def generate_questions(engine: OpenAILLM, dataset: Dataset, cases) -> Dataset:
33 | system_prompt = [SYNTHETIC_QUERY_SYSTEM]
34 | few_shot_cases = ""
35 | for i in range(len(cases)):
36 | few_shot_cases += SYNTHETIC_QUERY_FEW_SHOT.format(
37 | document=cases[i]["document"], question=cases[i]["Query"])
38 | user_prompts = [[SYNTHETIC_QUERY_USER.format(
39 | few_shot_cases=few_shot_cases, document=d['document'])] for d in dataset]
40 |
41 | questions = generate_responses(engine, user_prompts, system_prompt)
42 | return dataset.add_column("question", questions)
43 |
44 | def generate_answers(engine: OpenAILLM, dataset: Dataset) -> Dataset:
45 | system_prompt = [SYNTHETIC_ANSWER_SYSTEM]
46 | user_prompts = [[SYNTHETIC_ANSWER_USER.format(
47 | question=d['question'], document=d['document']) + "\n"] for d in dataset]
48 |
49 | answers = generate_responses(engine, user_prompts, system_prompt)
50 | return dataset.add_column("answer", answers)
51 |
52 | def validate_question_with_answer(dataset: Dataset) -> Dataset:
53 | def check_generated_answer(answer: str):
54 |         problematic_phrases = ["i don't know"]  # the check below lowercases the answer, so lowercase phrases suffice
55 | for phrase in problematic_phrases:
56 | if phrase in answer.lower():
57 | return False
58 | return True
59 | # dataset.filter(lambda x: not check_generated_answer(x["answer"])).to_json(f"{args.output_dir}/filtered_dataset.json")
60 | return dataset.filter(lambda x: check_generated_answer(x["answer"]))
61 |
62 | if __name__ == "__main__":
63 | parser = argparse.ArgumentParser()
64 | parser.add_argument("--corpus_dir", type=str, default="benchmarks/auto/corpus")
65 | parser.add_argument("--output_dir", type=str, default="benchmarks/auto/output")
66 | parser.add_argument("--model", type=str, default="gpt-3.5-turbo")
67 | parser.add_argument("--api_key", type=str, default=None)
68 | args = parser.parse_args()
69 |
70 | os.environ["OPENAI_API_KEY"] = args.api_key
71 | engine = OpenAILLM(args.model, "OPENAI_API_KEY")
72 |
73 | print(f"\nLoad corpus from {args.corpus_dir}")
74 | dataset, cases = load_corpus(args.corpus_dir)
75 |
76 |     print("Start generating questions...")
77 | dataset = generate_questions(engine, dataset, cases)
78 |
79 |     print("Start generating answers...")
80 | dataset = generate_answers(engine, dataset)
81 |
82 | print("Validate questions...")
83 | dataset = validate_question_with_answer(dataset)
84 |
85 | engine.calculate_api_cost()
86 |
87 | dataset.to_json(f"{args.output_dir}/dataset.json")
88 |     print(f"\nFinished generating dataset. Dataset saved as {args.output_dir}/dataset.json")
89 |
--------------------------------------------------------------------------------
/benchmarks/auto/prompt.py:
--------------------------------------------------------------------------------
1 | SYNTHETIC_QUERY_SYSTEM = '''You are an expert question-answering system. You must create a question for the provided document. The question must be answerable within the context of the document.''' # system prompt
2 |
3 | SYNTHETIC_QUERY_FEW_SHOT = '''Document: {document}
4 | Question: {question}
5 |
6 | ''' # few-shot
7 |
8 | SYNTHETIC_QUERY_USER = '''{few_shot_cases}Document: {document}
9 | Question: ''' # user prompt
10 |
11 | SYNTHETIC_ANSWER_SYSTEM = '''You are a helpful assistant that is good at answering a query step by step based on the context, where the context is a document. If there is a good answer in the context, try to summarize the context as the answer. If the query doesn't form a complete question, or you don't know the answer, or there is not enough information to determine the answer, or the context is irrelevant to the question, just say I DON'T KNOW.''' # system prompt
12 |
13 | SYNTHETIC_ANSWER_USER = '''Here is the question: {question}
14 | Here is the context: {document}''' # user prompt
--------------------------------------------------------------------------------
/benchmarks/auto/run.sh:
--------------------------------------------------------------------------------
1 | rageval_dir=$(dirname $(dirname $(dirname $(realpath $0))))
2 | cd $rageval_dir
3 | echo "Running Auto Benchmark"
4 | python3 $rageval_dir/benchmarks/auto/auto_benchmark.py --corpus_dir "benchmarks/auto/corpus"\
5 | --output_dir "benchmarks/auto/output"\
6 | --model "gpt-3.5-turbo"\
7 | --api_key "YOUR_API_KEY"
8 | echo "Auto Benchmark Complete"
--------------------------------------------------------------------------------
/benchmarks/base.py:
--------------------------------------------------------------------------------
1 | from typing import List, Union, Dict, Any, Tuple, Optional
2 | from abc import abstractmethod, ABC
3 | from dataclasses import dataclass
4 | # import importlib
5 | from datasets import Dataset, load_dataset
6 | from rageval.metrics import Metric
7 | from .utils import save_json
8 |
9 | class BaseBenchmark(ABC):
10 | """Base class for benchmarks."""
11 |
12 | metrics: List[Metric] = []
13 | dataset: Dataset
14 |
15 | def __init__(self, batch_size: int = 1) -> None:
16 | """Initialization."""
17 | self.batch_size = batch_size
18 |
19 | @property
20 | @abstractmethod
21 | def name(self) -> str:
22 | """The benchmark name."""
23 | ...
24 |
25 | @property
26 | def metric_names(self) -> List[str]:
27 | """The metric names."""
28 | return [m.name for m in self.metrics]
29 |
30 | def load_data(self, **kwargs) -> None:
31 | """Load the dataset with answers to evaluate."""
32 | print("Load dataset...")
33 | self.dataset = load_dataset(**kwargs)
34 | print("Dataset loaded.")
35 |
36 | @abstractmethod
37 | def _evaluate(self) -> Tuple[Dict[Any, Any], Dataset]:
38 |         """Evaluate the dataset and return the aggregated results along with the detailed dataset containing per-sample scores."""
39 | ...
40 |
41 | def evaluate(self, **kwargs) -> Dict[Any, Any]:
42 |         """Load the dataset and evaluate it, returning a result dict."""
43 | if not hasattr(self, "dataset"):
44 | self.load_data(**kwargs)
45 | print("Start evaluating...")
46 | self.results, self.dataset = self._evaluate()
47 | print("Evaluation finished.")
48 | return self.results
49 |
50 | def set_metric(self, metrics: List[Metric]) -> None:
51 | """Reset the metrics."""
52 | if all(isinstance(m, Metric) for m in metrics):
53 | self.metrics = metrics
54 | else:
55 | raise ValueError("The metrics should be a list of Metric objects.")
56 |
57 | def save_dataset(self, file_path: str) -> None:
58 |         """Save the detailed dataset with per-sample scores to a file."""
59 | if not hasattr(self, "dataset"):
60 | raise ValueError("Please load the dataset and evaluate it first.")
61 | self.dataset.to_json(file_path, orient="records")
62 | print(f"Dataset saved to {file_path}.")
63 |
64 | def save_results(self, file_path: str) -> None:
65 |         """Save the aggregated results to a file."""
66 | if not hasattr(self, "results"):
67 | raise ValueError("Please run evaluation first.")
68 | save_json(self.results, file_path)
69 | print(f"Results saved to {file_path}.")
70 |
71 | # def get_metric(self, name: str, **kwargs) -> Union[Metric, MetricWithLLM]:
72 | # """Get the metric by name."""
73 | # module = importlib.import_module(f"rageval.metrics")
74 | # return getattr(module, name)(**kwargs)
75 |
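    # Illustrative usage sketch (`SomeBenchmark` stands for any concrete subclass, e.g.
    # WebGLMBenchmark in benchmarks/WebGLM/webglm_benchmark.py; file names are placeholders):
    #
    #     benchmark = SomeBenchmark(...)
    #     results = benchmark.evaluate(path="json", data_files={"test": "answers.json"}, split="test")
    #     benchmark.save_results("results.json")    # aggregated metric scores
    #     benchmark.save_dataset("details.json")    # dataset with per-sample scores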
--------------------------------------------------------------------------------
/benchmarks/utils.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 |
4 | def save_json(data, file_path: str):
5 | os.makedirs(os.path.dirname(file_path), exist_ok=True)
6 | with open(file_path, "w+") as f:
7 | json.dump(data, f, indent=4)
8 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools >= 42", "wheel"]
3 | build-backend = "setuptools.build_meta"
4 |
5 | [project]
6 | name = "RagEval"
7 | version = "0.1.0"
8 | description = "Evaluation tools for Retrieval-augmented Generation (RAG) methods."
9 | keywords = ["RAG evaluation tools"]
10 | readme = {file = "README.md", content-type = "text/markdown"}
11 | license = {file = "LICENSE"}
12 | authors = [
13 | { name="Wenshan Wang, Yixing Fan, etc.", email="wangwenshan@ict.ac.cn" },
14 | ]
15 | requires-python = ">=3.10"
16 | classifiers = [
17 | "Development Status :: 3 - Alpha",
18 | "Environment :: Console",
19 | "Operating System :: POSIX :: Linux",
20 | "Topic :: Scientific/Engineering :: Artificial Intelligence",
21 | "License :: OSI Approved :: Apache Software License",
22 | "Programming Language :: Python :: 3"
23 | ]
24 | dependencies = [
25 | "refchecker == 0.2.13",
26 | "numpy >= 1.26",
27 | "tqdm >= 4.66",
28 | "hyperopt >= 0.1.1",
29 | "h5py >= 2.8.0",
30 | "openai >= 1.10.0",
31 | "datasets >= 3.0.1",
32 | "langchain >= 0.3.1",
33 | "langchain-community >= 0.3.1",
34 | "transformers >= 4.37.2",
35 | "torch >= 2.2.0",
36 | "pandas >= 2.0.0",
37 | "nltk >= 3.9.1",
38 | "spacy >= 3.7.4",
39 | "en_core_web_sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.7.1/en_core_web_sm-3.7.1-py3-none-any.whl",
40 | "rouge_score >= 0.1.2",
41 | "jieba >= 0.42.1",
42 | "evaluate >= 0.4.3"
43 | ]
44 |
45 | [project.optional-dependencies]
46 | tests = [
47 | "coverage >= 4.3.4",
48 | "codecov >= 2.0.15",
49 | "pytest >= 3.7.4",
50 | "pytest-cov >= 2.4.0",
51 | "flake8 == 7.0.0",
52 | "pydocstyle == 6.1",
53 | "flake8_docstrings >= 1.7.0"
54 | ]
55 | benchmarks = [
56 | "accelerate == 0.27.2",
57 | "sentencepiece == 0.2.0",
58 | "protobuf == 4.25.3"
59 | ]
60 |
61 | [project.urls]
62 | Homepage = "https://github.com/gomate-community/rageval"
63 | Issues = "https://github.com/gomate-community/rageval/issues"
64 |
--------------------------------------------------------------------------------
/pytest.ini:
--------------------------------------------------------------------------------
1 | [pytest]
2 | markers =
3 |     api: mark a test as requiring a web API
4 |     slow: mark a test as part of the full (slow) test run
5 |     quick: mark a test as a quick test
6 |
--------------------------------------------------------------------------------
/rageval/__init__.py:
--------------------------------------------------------------------------------
1 | try:
2 | from .version import version as __version__
3 | except ImportError:
4 | __version__ = "unknown version"
5 |
6 | from . import tasks
7 | from . import metrics
8 | from . import models
9 | from . import utils
10 |
11 | from .evaluations import evaluate
12 |
13 | # __all__ = ["evaluate", "__version__"]
14 |
--------------------------------------------------------------------------------
/rageval/evaluations.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from typing import Callable, List, Tuple
3 |
4 | from datasets import Dataset, concatenate_datasets
5 |
6 | from rageval.metrics import Metric
7 |
8 |
9 | def evaluate(
10 | testset: Dataset,
11 | metrics: List[Metric] | None = None,
12 |         models: List[Callable] | None = None) -> Tuple[Dataset, Dataset]:
13 | """Conduct the evaluation on testset."""
14 |
15 | # run evaluation
16 | assert (len(metrics) == len(models))
17 | [metrics[i].init_model(models[i]) for i in range(len(metrics))]
18 | avg_scores = []
19 | instance_scores = [testset]
20 | for metric in metrics:
21 | print(f"evaluating with [{metric.name}]")
22 | avg_score, _testset = metric.compute(testset)
23 | avg_scores.append(Dataset.from_dict({metric.name: [avg_score]}))
24 | instance_scores.append(_testset.select_columns(metric.name))
25 |
26 | return concatenate_datasets(avg_scores), concatenate_datasets(instance_scores)
27 |
--------------------------------------------------------------------------------
/rageval/exceptions.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gomate-community/rageval/01e2580efe714b3130602e02a7e7b017d16d79a2/rageval/exceptions.py
--------------------------------------------------------------------------------
/rageval/metrics/__init__.py:
--------------------------------------------------------------------------------
1 | from .base import Metric, MetricWithLLM, add_attribute
2 |
3 | # Metrics about the answer correctness
4 | from .answer_correctness._answer_accuracy import AnswerAccuracy
5 | from .answer_correctness._answer_bleu import AnswerBleuScore
6 | from .answer_correctness._answer_chrf import AnswerCHRFCorrectness
7 | from .answer_correctness._answer_exact_match import AnswerEMCorrectness
8 | from .answer_correctness._answer_f1 import AnswerF1Correctness
9 | from .answer_correctness._answer_rouge_correctness import AnswerRougeCorrectness
10 | from .answer_correctness._answer_bert_score import AnswerBERTScore
11 | from .answer_correctness._answer_edit_distance import AnswerEditDistance
12 | from .answer_correctness._answer_claim_recall import AnswerNLICorrectness
13 | from .answer_correctness._answer_disambig_f1 import AnswerDisambigF1Correctness
14 | from .answer_correctness._answer_lcs_ratio import AnswerLCSRatio
15 | from .answer_correctness._answer_ter import AnswerTERCorrectness
16 | ##from .answer_correctness._answer_relevancy import AnswerRelevancy
17 |
18 | # Metrics about the answer groundedness
19 | from .answer_groundedness._answer_citation_precision import AnswerCitationPrecision
20 | from .answer_groundedness._answer_citation_recall import AnswerCitationRecall
21 | from .answer_groundedness._context_reject_rate import ContextRejectRate
22 | ##from .answer_groundedness._claim_faithfulness import ClaimFaithfulness
23 |
24 | # Metrics about the answer informativeness
25 | ##from .answer_informativeness._claim_num import ClaimNum
26 | from .answer_informativeness._text_length import TextLength
27 | ##from .answer_informativeness._repetitiveness import Repetitiveness
28 | ##from .answer_informativeness._pairwise_accuracy import PairwiseAccuracy
29 | from .answer_informativeness._answer_distinct12 import AnswerDistinct
30 |
31 | # Metrics about the context relevance
32 |
33 | # Metrics about the context adequacy
34 | from .context_adequacy._context_recall import ContextRecall
35 |
--------------------------------------------------------------------------------
/rageval/metrics/answer_correctness/_answer_accuracy.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import List
3 | import evaluate
4 |
5 | import datasets
6 |
7 | from rageval.metrics import Metric, add_attribute
8 |
9 |
10 | _DESCRIPTION = """\
11 | The AnswerAccuracy is to measure the correctness of answers.
12 |
13 | This metric is applicable in scenarios where the LLM is required to output a unique short answer, such as options for \
14 | multiple-choice questions or a single entity name.
15 | The renowned MMLU dataset utilizes this metric for evaluation. In the evaluation of the MMLU dataset, probabilities \
16 | for each answer are first calculated, and the answer with the highest probability is selected as the predicted result.
17 | In our tool, we assume that the prediction result has already been obtained, and only perform the final score \
18 | calculation.
19 | """
20 |
21 | _KWARGS_DESCRIPTION = """\
22 | Args:
23 | name : str
24 | batch_size : int, Batch size for openai completion.
25 |
26 | Optional Args:
27 | None
28 |
29 | Functions:
30 | _compute_one: Evaluating the correctness of answer.
31 |
32 | Examples:
33 | >>> from datasets import Dataset
34 | >>> import rageval as rl
35 | >>> sample = {
36 | ... "answers": [
37 | ... "A",
38 | ... "B",
39 | ... "C"
40 | ... ],
41 | ... "gt_answers": [
42 | ... "A",
43 | ... "C",
44 | ... "C"
45 | ... ]
46 | ... }
47 | >>> dataset = Dataset.from_dict(sample)
48 | >>> metric = rl.metrics.AnswerAccuracy()
49 | >>> metric.mtype
50 | 'AnswerCorrectness'
51 | >>> score, results = metric.compute(dataset["answers"], dataset["gt_answers"], 1)
52 | >>> score
53 | 0.6666666666666666
54 | >>> results[0]
55 | True
56 | """
57 |
58 | _CITATION = """\
59 | @misc{hendrycks2021measuring,
60 | title={Measuring Massive Multitask Language Understanding},
61 | author={Dan Hendrycks and Collin Burns and Steven Basart and Andy Zou and Mantas Mazeika and Dawn Song and Jacob Steinhardt},
62 | year={2021},
63 | eprint={2009.03300},
64 | archivePrefix={arXiv},
65 | primaryClass={cs.CY}
66 | }
67 | """
68 |
69 |
70 | @dataclass
71 | @add_attribute('mtype', 'AnswerCorrectness')
72 | @datasets.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
73 | class AnswerAccuracy(Metric):
74 | """Estimates the correctness of answers."""
75 |
76 | name = "answer_accuracy"
77 |
78 | ALIAS = ['answer_accuracy']
79 |
80 | def __init__(self):
81 | """
82 | Explicitly initialize AnswerAccuracy.
83 |
84 | Ensure all parent classes are initialized.
85 | """
86 | super().__init__()
87 | self.info = evaluate.MetricInfo(
88 | description=_DESCRIPTION,
89 | inputs_description=_KWARGS_DESCRIPTION,
90 | citation=_CITATION,
91 | homepage="",
92 | features=datasets.Features(
93 | {
94 | "answers": datasets.Value("string"),
95 | "gt_answers": datasets.Value("string")
96 | }
97 | ),
98 | codebase_urls=["https://github.com/hendrycks/test"],
99 | reference_urls=["https://arxiv.org/abs/2009.03300"]
100 | )
101 |
102 | def __repr__(self) -> str:
103 | """:return: Formatted string representation of the metric."""
104 | return f"{self.ALIAS[0]}"
105 |
106 | def _compute_one(
107 | self,
108 | answer: str,
109 | gt_answer: str
110 | ) -> float:
111 | """Evaluating the correctness of answer."""
112 | return answer == gt_answer
113 |
--------------------------------------------------------------------------------
/rageval/metrics/answer_correctness/_answer_bert_score.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import List, Tuple
3 | import evaluate
4 |
5 | import datasets
6 | from rageval.metrics import Metric, add_attribute
7 | from bert_score import BERTScorer
8 | import logging
9 | import transformers
10 | transformers.tokenization_utils.logger.setLevel(logging.ERROR)
11 | transformers.configuration_utils.logger.setLevel(logging.ERROR)
12 | transformers.modeling_utils.logger.setLevel(logging.ERROR)
13 |
14 | _DESCRIPTION = """\
15 | BERTScore leverages the pre-trained contextual embeddings from BERT and matches words in candidate and reference sentences by cosine similarity. It has been shown to correlate with human judgment on sentence-level and system-level evaluation. Moreover, BERTScore computes precision, recall, and F1 measure, which can be useful for evaluating different language generation tasks.
16 |
17 | For details, see the paper: https://openreview.net/forum?id=SkeHuCVFDr
18 | """
19 |
20 | _KWARGS_DESCRIPTION = """\
21 | Args:
22 | name : str
23 | lang : str, Language of the text. Default is "en".
24 |     rescale_with_baseline : bool, Whether to rescale the score with a pre-computed baseline. This does not affect BERTScore's correlation with human judgment. Default is False. For more details, see https://github.com/Tiiiger/bert_score/blob/master/journal/rescale_baseline.md
25 |
26 | Optional Args:
27 | None
28 |
29 | Functions:
30 |     _compute_one: compute the BERTScore for a single prediction with its references;
31 |         the prediction is scored against each reference and the per-sample score is
32 |         the maximum F1 over all references.
33 |
34 | Examples:
35 | >>> from datasets import Dataset
36 | >>> import rageval as rl
37 | >>> sample = {
38 | ... "answers": [
39 | ... "It is a guide to action which ensures that the military always obeys the commands of the party.",
40 | ... "It is to insure the troops forever hearing the activity guidebook that party direct.",
41 | ... ],
42 | ... "gt_answers": [
43 | ... [
44 | ... "It is a guide to action that ensures that the military will forever heed Party commands.",
45 | ... "It is the guiding principle which guarantees the military forces always being under the command of the Party.",
46 | ... "It is the practical guide for the army always to heed the directions of the party.",
47 | ... ],
48 | ... [
49 | ... "It is a guide to action that ensures that the military will forever heed Party commands.",
50 | ... "It is the guiding principle which guarantees the military forces always being under the command of the Party.",
51 | ... "It is the practical guide for the army always to heed the directions of the party.",
52 | ... ]
53 | ... ],
54 | ... }
55 | >>> dataset = Dataset.from_dict(sample)
56 | >>> metric = rl.metrics.AnswerBERTScore(lang='en', rescale_with_baseline=True)
57 | >>> metric.mtype
58 | 'AnswerCorrectness'
59 | >>> score, results = metric.compute(dataset["answers"], dataset["gt_answers"], 1)
60 | >>> round(score, 2)
61 | 0.55
62 | >>> round(results[0], 1)
63 | 0.7
64 | """
65 |
66 |
67 | _CITATION = """\
68 | @inproceedings{bert-score,
69 | title={BERTScore: Evaluating Text Generation with BERT},
70 | author={Tianyi Zhang* and Varsha Kishore* and Felix Wu* and Kilian Q. Weinberger and Yoav Artzi},
71 | booktitle={International Conference on Learning Representations},
72 | year={2020},
73 | url={https://openreview.net/forum?id=SkeHuCVFDr}
74 | }
75 | """
76 |
77 |
78 | @dataclass
79 | @add_attribute('mtype', 'AnswerCorrectness')
80 | @datasets.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
81 | class AnswerBERTScore(Metric):
82 | """BERTScore depends on the model and language pair selected."""
83 |
84 | name = "answer_bert_score"
85 |
86 | ALIAS = ['answer_bert_score']
87 |
88 | def __init__(self, lang: str = "en", rescale_with_baseline=False):
89 | """Explicitly initialize the AnswerBERTScore to ensure all parent class initialized."""
90 | super().__init__()
91 | self.scorer = BERTScorer(lang=lang, rescale_with_baseline=rescale_with_baseline)
92 | self.info = evaluate.MetricInfo(
93 | description=_DESCRIPTION,
94 | inputs_description=_KWARGS_DESCRIPTION,
95 | citation=_CITATION,
96 | homepage="",
97 | features=datasets.Features(
98 | {
99 | "answers": datasets.Value("string"),
100 | "gt_answers": datasets.Sequence(datasets.Value("string"))
101 | }
102 | ),
103 | codebase_urls=[
104 | "https://github.com/Tiiiger/bert_score/tree/master",
105 | ],
106 | reference_urls=["https://openreview.net/forum?id=SkeHuCVFDr"]
107 | )
108 |
109 | def __repr__(self) -> str:
110 | """:return: Formatted string representation of the metric."""
111 | return f"{self.ALIAS[0]}"
112 |
113 | def _compute_one(
114 | self,
115 | pred_answers: str,
116 | ref_answers: List[str]
117 | ) -> float:
118 | """Compute the BERTscore for a pair of predictions and references."""
119 | P, R, F1 = self.scorer.score([pred_answers] * len(ref_answers), ref_answers)
120 | return F1.max().tolist()
121 |
--------------------------------------------------------------------------------
/rageval/metrics/answer_correctness/_answer_bleu.py:
--------------------------------------------------------------------------------
1 | import re
2 | from dataclasses import dataclass
3 | from typing import List, Tuple
4 | import evaluate
5 | import datasets
6 | from rageval.metrics import Metric, add_attribute
7 | from tqdm import tqdm
8 |
9 |
10 | _DESCRIPTION = """\
11 | BLEU (Bilingual Evaluation Understudy) is an algorithm for evaluating the quality of text which has been machine-translated from one natural language to another.
12 | Scores are calculated by comparing individual translated segments, e.g., sentences, with a set of high-quality reference translations.
13 | Those scores are then averaged over the whole corpus to reach an estimate of the translation's overall quality.
14 | Neither intelligibility nor grammatical correctness is taken into account.
15 |
16 | For details, see the paper: http://www.aclweb.org/anthology/P02-1040.pdf
17 | """
18 |
19 | _KWARGS_DESCRIPTION = """\
20 | Args:
21 | name : str
22 | batch_size : int, Batch size for openai completion.
23 |
24 | Optional Args:
25 | None
26 |
27 | Functions:
28 |     compute: compute the corpus-level BLEU score and the per-instance scores.
29 |     _compute_one: compute the BLEU score for a single prediction with its references.
30 |
31 | Examples:
32 | >>> from datasets import Dataset
33 | >>> import rageval as rl
34 | >>> sample = {
35 | ... "answers": [
36 | ... "It is a guide to action which ensures that the military always obeys the commands of the party.",
37 | ... "It is to insure the troops forever hearing the activity guidebook that party direct.",
38 | ... ],
39 | ... "gt_answers": [
40 | ... [
41 | ... "It is a guide to action that ensures that the military will forever heed Party commands.",
42 | ... "It is the guiding principle which guarantees the military forces always being under the command of the Party.",
43 | ... "It is the practical guide for the army always to heed the directions of the party.",
44 | ... ],
45 | ... [
46 | ... "It is a guide to action that ensures that the military will forever heed Party commands.",
47 | ... "It is the guiding principle which guarantees the military forces always being under the command of the Party.",
48 | ... "It is the practical guide for the army always to heed the directions of the party.",
49 | ... ]
50 | ... ],
51 | ... }
52 | >>> dataset = Dataset.from_dict(sample)
53 | >>> metric = rl.metrics.AnswerBleuScore()
54 | >>> metric.mtype
55 | 'AnswerCorrectness'
56 | >>> score, results = metric.compute(dataset["answers"], dataset["gt_answers"], 1)
57 | >>> score
58 | 0.3450835085970013
59 | >>> results[0]
60 | 0.5401725898595141
61 | """
62 |
63 |
64 | _CITATION = """\
65 | @misc{Kishore2002bleu,
66 | title={Bleu: a method for automatic evaluation of machine translation},
67 | author={Kishore Papineni and Salim Roukos and Todd Ward and Wei-Jing Zhu},
68 | year={2002},
69 | page={311-318},
70 | primaryClass={cs.CL}
71 | }
72 | """
73 |
74 |
75 | @dataclass
76 | @add_attribute('mtype', 'AnswerCorrectness')
77 | @datasets.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
78 | class AnswerBleuScore(Metric):
79 | """Bleu score computing with good quality reference."""
80 |
81 |     """Note: this metric currently only supports English data (as of 2024/03/12)."""
82 |
83 | name = "answer_bleu"
84 |
85 | ALIAS = ['answer_bleu']
86 |
87 | def __init__(self):
88 | """Explicitly initialize the AnswerBleuScore to ensure all parent class initialized."""
89 | super().__init__()
90 | self.info = evaluate.MetricInfo(
91 | description=_DESCRIPTION,
92 | inputs_description=_KWARGS_DESCRIPTION,
93 | citation=_CITATION,
94 | homepage="",
95 | features=datasets.Features(
96 | {
97 | "answers": datasets.Value("string"),
98 | "gt_answers": datasets.Sequence(datasets.Value("string"))
99 | }
100 | ),
101 | codebase_urls=[
102 | "https://github.com/tensorflow/nmt/blob/master/nmt/scripts/bleu.py",
103 | "https://github.com/huggingface/datasets/blob/main/metrics/bleu/bleu.py"
104 | ],
105 | reference_urls=["https://www.aclweb.org/anthology/P02-1040.pdf"]
106 | )
107 |
108 | def __repr__(self) -> str:
109 | """:return: Formatted string representation of the metric."""
110 | return f"{self.ALIAS[0]}" # pragma: no cover
111 |
112 | def compute(
113 | self,
114 | pred_answers: List[str],
115 | ref_answers: List[List[str]],
116 | batch_size: int,
117 | ) -> Tuple[float, List[float]]:
118 | """Compute the bleu score on both corpus level and instance level."""
119 | bleu = evaluate.load("bleu")
120 | # corpus level
121 | bleu_result = bleu.compute(predictions=pred_answers, references=ref_answers)
122 | score = bleu_result['bleu']
123 | # instance level
124 | scores = []
125 | for pred_answer, ref_answer in tqdm(zip(pred_answers, ref_answers),
126 | desc=f"Computing {self.name}",
127 | total=len(pred_answers)):
128 | scores.append(self._compute_one(pred_answer, ref_answer))
129 | return score, scores
130 |
131 |     def _compute_one(
132 |         self,
133 |         pred_answer: str,
134 |         ref_answers: List[str]
135 |     ) -> float:
136 |         """Compute the bleu score on an instance level."""
137 | 
138 |         bleu = evaluate.load("bleu")
139 |         bleu_result = bleu.compute(predictions=[pred_answer], references=[ref_answers])
140 |         bleu_score = bleu_result['bleu']
141 |         return bleu_score
142 |
--------------------------------------------------------------------------------
/rageval/metrics/answer_correctness/_answer_claim_recall.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import List, Callable, Tuple
3 | import evaluate
4 |
5 | import datasets
6 | import numpy as np
7 | from tqdm import tqdm
8 |
9 | from rageval.metrics import Metric, add_attribute
10 | from rageval.utils.check_utils import text_to_sents
11 |
12 | _DESCRIPTION = """\
13 | The AnswerNLICorrectness is to measure the correctness of long-form answers. In the original paper, the authors first \
14 | use Instruct-GPT (text-davinci-003) to generate three "sub-claims" (based on gold answers) and then use a state-of-the-art \
15 | natural-language inference (NLI) model, TRUE (Honovich et al., 2022), to check whether the model output entails the \
16 | sub-claims (claim recall).
17 |
18 | For details, see the paper: https://arxiv.org/abs/2305.14627.
19 | """
20 |
21 | _KWARGS_DESCRIPTION = """\
22 | Args:
23 | name : str
24 | batch_size : int, Batch size for openai completion.
25 |
26 | Optional Args:
27 | None
28 |
29 | Functions:
30 |     _compute_one: compute the score by measuring whether the args:`claims` can be supported by args:`answer`.
31 |     _compute_batch: compute the scores for a batch of answers against their corresponding claims.
32 |
33 | Examples:
34 | >>> from datasets import Dataset
35 | >>> import rageval as rl
36 | >>> sample = {
37 | ... "answers": [
38 | ... "They went a while before introducing ads, so they could make money, as they needed to establish "
39 | ... "their brand and amass users. Once you have dedicated users, introducing ads won't deter most, but if "
40 | ... "you are still new, having ads will deter a lot. The same goes for Uber, it's not that they aren't "
41 | ... "making money, it's that they are reinvesting a ton of it to make their service better."
42 | ... ],
43 | ... "gt_answers": [
44 | ... [
45 | ... "Firms like Snapchat and Uber need to establish their brand and amass users before introducing "
46 | ... "ads.",
47 | ... "Introducing ads too early can deter potential users.",
48 | ... "Uber is reinvesting a lot of money to make their service better."
49 | ... ]
50 | ... ]
51 | ... }
52 | >>> dataset = Dataset.from_dict(sample)
53 | >>> nli_model = rl.models.NLIModel(
54 | ... 'text2text-generation',
55 | ... 'hf-internal-testing/tiny-random-T5ForConditionalGeneration'
56 | ... )
57 | >>> metric = rl.metrics.AnswerNLICorrectness(nli_model=nli_model, decompose_model="nltk")
58 | >>> metric.mtype
59 | 'AnswerCorrectness'
60 | >>> score, results = metric.compute(dataset['answers'], dataset['gt_answers'], 1)
61 | >>> assert score == 0 or score == 1
62 | """
63 |
64 | _CITATION = """\
65 | @misc{gao2023enabling,
66 | title={Enabling Large Language Models to Generate Text with Citations},
67 | author={Tianyu Gao and Howard Yen and Jiatong Yu and Danqi Chen},
68 | year={2023},
69 | eprint={2305.14627},
70 | archivePrefix={arXiv},
71 | primaryClass={cs.CL}
72 | }
73 | """
74 |
75 |
76 | @dataclass
77 | @add_attribute('mtype', 'AnswerCorrectness')
78 | @datasets.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
79 | class AnswerNLICorrectness(Metric):
80 | """Estimates the correctness of long-form answers based on the NLI model."""
81 |
82 | name = "answer_claim_recall"
83 |
84 | ALIAS = ['answer_claim_recall']
85 |
86 | def __init__(self, nli_model: Callable, decompose_model: str = "gpt-3.5-turbo"):
87 | """
88 | Explicitly initialize AnswerNLICorrectness.
89 |
90 | Ensure all parent classes are initialized.
91 |         Ensure nli_model and decompose_model are initialized.
92 | """
93 | super().__init__()
94 | self.nli_model = nli_model
95 | self.decompose_model = decompose_model
96 | self.info = evaluate.MetricInfo(
97 | description=_DESCRIPTION,
98 | inputs_description=_KWARGS_DESCRIPTION,
99 | citation=_CITATION,
100 | homepage="",
101 | features=datasets.Features(
102 | {
103 | "answers": datasets.Value("string"),
104 | "gt_answers": datasets.Value("string")
105 | }
106 | ),
107 | codebase_urls=["https://github.com/princeton-nlp/ALCE"],
108 | reference_urls=["https://arxiv.org/abs/2305.14627"]
109 | )
110 |
111 | def __repr__(self) -> str:
112 | """:return: Formatted string representation of the metric."""
113 | return f"{self.ALIAS[0]}" # pragma: no cover
114 |
115 | def _compute_one(
116 | self,
117 | answer: str,
118 | claims: List[str]
119 | ) -> float:
120 | """
121 | Evaluate the correctness of an answer.
122 |
123 | Firstly, split the gt_answer into a set of claims.
124 | Then, compute the faithfulness score of each claim. The faithfulness is a binary score.
125 | Finally, aggregate all faithfulness score of each claim.
126 | """
127 |
128 | detail_results = []
129 | scores = []
130 |
131 | for i, claim in enumerate(claims):
132 | # obtain the faithfulness of each claim by language inference model.
133 | label = self.nli_model.generate_infer(premise=answer, hypothesis=claim)
134 | detail_results.append({
135 | "answer": answer,
136 | "claim": claim,
137 | "reasoning": "",
138 | "error": "",
139 | "factuality": label,
140 | })
141 | scores.append(label)
142 | # Note that the detail_results can be recorded by logger.info
143 | return np.average(scores)
144 |
145 | def _compute_batch(
146 | self,
147 | pred_answers: List[str],
148 | ref_answers: List[List[str]]
149 | ) -> List[float]:
150 | """
151 | Evaluate the correctness of a batch of answers.
152 |
153 | Firstly, split the gt_answer into a set of claims.
154 | Then, compute the faithfulness score of each claim. The faithfulness is a binary score.
155 | Finally, aggregate all faithfulness score of each claim.
156 | """
157 |
158 | if isinstance(ref_answers, list):
159 | if isinstance(ref_answers[0], list):
160 | # gt_answers has been decomposed into claims list
161 | claims = ref_answers
162 | elif isinstance(ref_answers[0], str):
163 | # use decompose_model to decompose the gt_answers into claims list
164 | claims = [text_to_sents(gt_answer, self.decompose_model) for gt_answer in ref_answers]
165 | else:
166 | raise ValueError("The type of gt_answers element should be list or string.") # pragma: no cover
167 | else:
168 | raise ValueError("The type of gt_answers should be list.") # pragma: no cover
169 |
170 | results = []
171 | for i, answer in tqdm(enumerate(pred_answers)):
172 | r = self._compute_one(answer, claims[i])
173 | results.append(r)
174 | return results
175 |
--------------------------------------------------------------------------------
/rageval/metrics/answer_correctness/_answer_disambig_f1.py:
--------------------------------------------------------------------------------
1 | import re
2 | import string
3 | from collections import Counter
4 | from dataclasses import dataclass
5 | from typing import List
6 | import evaluate
7 |
8 | import datasets
9 | import numpy as np
10 | import spacy
11 |
12 | from rageval.metrics import Metric, add_attribute
13 |
14 | _DESCRIPTION = """\
15 | The Disambig-F1 is a variant of the F1 score that estimates the similarity between the disambiguation of the answer and the ground truth answer.
16 | 
17 | The original metric was presented in the [ASQA paper](https://aclanthology.org/2022.emnlp-main.566/) and implemented in [this code](https://github.com/google-research/language/blob/master/language/asqa/scoring.py#L273). We adopt an [alternative implementation](https://github.com/jzbjyb/FLARE/tree/main/src/datasets.py#L29) from the paper [Active Retrieval Augmented Generation](https://arxiv.org/abs/2305.06983).
18 | """
19 |
20 | _KWARGS_DESCRIPTION = """\
21 | Args:
22 | name : str
23 | model : str, model name of spacy model to ner.
24 |
25 | Optional Args:
26 | None
27 |
28 | Functions:
29 | _normalize_text: normalize the text by removing articles, white spaces, punctuations and lowercasing.
30 | _ner: extract named entities from the text.
31 | _validate_data: validate the dataset format.
32 | _f1_score: compute the f1 score between `pred` string and `ref` string.
33 |     _compute_one: evaluate the disambig f1 score between `answer` and `gt_answers`, returning the highest score over all pairs.
34 |
35 | Examples:
36 | >>> from datasets import Dataset
37 | >>> import rageval as rl
38 | >>> sample = {
39 | ... "answers": [
40 | ... "Democrat rick kriseman won the 2016 mayoral election, while re- publican former mayor rick baker did so in the 2017 mayoral election."
41 | ... ],
42 | ... "gt_answers": [
43 | ... [
44 | ... "Kriseman",
45 | ... "Rick Kriseman"
46 | ... ]
47 | ... ]
48 | ... }
49 | >>> dataset = Dataset.from_dict(sample)
50 | >>> metric = rl.metrics.AnswerDisambigF1Correctness(model="en_core_web_sm")
51 | >>> metric.mtype
52 | 'AnswerCorrectness'
53 | >>> score, results = metric.compute(dataset['answers'], dataset['gt_answers'], 1)
54 | >>> assert 0 <= score <= 1
55 | """
56 |
57 | _CITATION = """\
58 | @inproceedings{stelmakh-etal-2022-asqa,
59 | title = "{ASQA}: Factoid Questions Meet Long-Form Answers",
60 | author = "Stelmakh, Ivan and
61 | Luan, Yi and
62 | Dhingra, Bhuwan and
63 | Chang, Ming-Wei",
64 | editor = "Goldberg, Yoav and
65 | Kozareva, Zornitsa and
66 | Zhang, Yue",
67 | booktitle = "Proceedings of the 2022 Conference on Empirical Methods in Natural Language Processing",
68 | month = dec,
69 | year = "2022",
70 | address = "Abu Dhabi, United Arab Emirates",
71 | publisher = "Association for Computational Linguistics",
72 | url = "https://aclanthology.org/2022.emnlp-main.566",
73 | doi = "10.18653/v1/2022.emnlp-main.566",
74 | pages = "8273--8288",
75 | }
76 | @misc{jiang2023active,
77 | title={Active Retrieval Augmented Generation},
78 | author={Zhengbao Jiang and Frank F. Xu and Luyu Gao and Zhiqing Sun and Qian Liu and Jane Dwivedi-Yu and Yiming Yang and Jamie Callan and Graham Neubig},
79 | year={2023},
80 | eprint={2305.06983},
81 | archivePrefix={arXiv},
82 | primaryClass={cs.CL}
83 | }
84 | """
85 |
86 |
87 | @dataclass
88 | @add_attribute('mtype', 'AnswerCorrectness')
89 | @datasets.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
90 | class AnswerDisambigF1Correctness(Metric):
91 | """Estimates the Disambig-F1 between answers and ground truth answers."""
92 |
93 | name = "answer_disambig_f1"
94 |
95 | ALIAS = ['answer_disambig_f1']
96 |
97 | def __init__(self, model: str = "en_core_web_sm"):
98 | """
99 | Explicitly initialize AnswerDisambigF1Correctness.
100 |
101 | Ensure all parent classes are initialized.
102 | Ensure spacy ner model is initialized.
103 | """
104 | super().__init__()
105 | self.model = model
106 | self.nlp = spacy.load(model)
107 | self.info = evaluate.MetricInfo(
108 | description=_DESCRIPTION,
109 | inputs_description=_KWARGS_DESCRIPTION,
110 | citation=_CITATION,
111 | homepage="",
112 | features=datasets.Features(
113 | {
114 | "answers": datasets.Value("string"),
115 | "gt_answers": datasets.Sequence(datasets.Value("string"))
116 | }
117 | ),
118 | codebase_urls=[
119 | "https://github.com/google-research/language/blob/master/language/asqa",
120 | "https://github.com/jzbjyb/FLARE"
121 | ],
122 | reference_urls=[
123 | "https://aclanthology.org/2022.emnlp-main.566",
124 | "https://arxiv.org/abs/2305.06983"
125 | ]
126 | )
127 |
128 | def __repr__(self) -> str:
129 | """:return: Formatted string representation of the metric."""
130 | return f"{self.ALIAS[0]}" # pragma: no cover
131 |
132 | def _normalize_text(self, s: str) -> str:
133 | def remove_articles(text):
134 | return re.sub(r'\b(a|an|the)\b', ' ', text)
135 |
136 | def white_space_fix(text):
137 | return ' '.join(text.split())
138 |
139 | def remove_punc(text):
140 | exclude = set(string.punctuation)
141 | return ''.join(ch for ch in text if ch not in exclude)
142 |
143 | def lower(text):
144 | return text.lower()
145 | return white_space_fix(remove_articles(remove_punc(lower(s))))
146 |
147 | def _ner(self, s: str) -> List[str]:
148 | """Extract named entities from the text."""
149 | doc = self.nlp(s)
150 | ents = doc.ents
151 | return [self._normalize_text(e.text) for e in ents]
152 |
153 | def _f1_score(self, pred: str, ref: str) -> float:
154 | """Compute the f1 score between pred and ref."""
155 | pred_ents = self._ner(pred)
156 | ref_ents = self._ner(ref)
157 |
158 | pred_counter = Counter(pred_ents)
159 | ref_counter = Counter(ref_ents)
160 |
161 | tp = sum((pred_counter & ref_counter).values())
162 | fp = sum((pred_counter - ref_counter).values())
163 | fn = sum((ref_counter - pred_counter).values())
164 |
165 | precision = (tp / (tp + fp)) if (tp + fp) > 0 else 1
166 | recall = (tp / (tp + fn)) if (tp + fn) > 0 else 1
167 |
168 | if precision + recall == 0:
169 | return 0
170 | return 2 * (precision * recall) / (precision + recall)
171 |
172 | def _compute_one(
173 | self,
174 | pred_answer: str,
175 | ref_answers: List[str]
176 | ) -> float:
177 | """Evaluate the disambig f1 score of an answer."""
178 | return np.max([self._f1_score(pred_answer, ref_answer) for ref_answer in ref_answers])
179 |
--------------------------------------------------------------------------------
/rageval/metrics/answer_correctness/_answer_edit_distance.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import List
3 | import datasets
4 |
5 | from rageval.metrics import Metric, add_attribute
6 |
7 | _DESCRIPTION = """\
8 | The AnswerEditDistance is to measure the similarity between answer and gt_answer by calculating the edit distance.
9 |
10 | This is a very traditional method, but to this day, some work is still being carried out using it, such as \
11 | https://ieeexplore.ieee.org/abstract/document/10172590.
12 | """
13 |
14 | _KWARGS_DESCRIPTION = """\
15 | Args:
16 | name : str
17 | batch_size : int, Batch size for openai completion.
18 |
19 | Optional Args:
20 | None
21 |
22 | Functions:
23 | _compute_one: evaluating the similarity between answer and gt_answer by calculating the edit distance.
24 |
25 | Examples:
26 | >>> from datasets import Dataset
27 | >>> import rageval as rl
28 | >>> sample = {
29 | ... "answers": [
30 | ... "Language models trained on massive code corpora can generalize to tasks without the need "
31 | ... "for task-specific fine tuning."
32 | ... ],
33 | ... "gt_answers": [
34 | ... "Large language models trained on massive code corpora can generalize to new tasks without the need "
35 | ... "for task-specific fine-tuning."
36 | ... ]
37 | ... }
38 | >>> dataset = Dataset.from_dict(sample)
39 | >>> metric = rl.metrics.AnswerEditDistance()
40 | >>> metric.mtype
41 | 'AnswerCorrectness'
42 | >>> score, results = metric.compute(dataset['answers'], dataset['gt_answers'], 1)
43 | >>> assert score == 5 / 18
44 | """
45 |
46 | _CITATION = """\
47 | @INPROCEEDINGS{10172590,
48 | author={Nashid, Noor and Sintaha, Mifta and Mesbah, Ali},
49 | booktitle={2023 IEEE/ACM 45th International Conference on Software Engineering (ICSE)},
50 | title={Retrieval-Based Prompt Selection for Code-Related Few-Shot Learning},
51 | year={2023},
52 | volume={},
53 | number={},
54 | pages={2450-2462},
55 | doi={10.1109/ICSE48619.2023.00205}
56 | }
57 | """
58 |
59 |
60 | @dataclass
61 | @add_attribute('mtype', 'AnswerCorrectness')
62 | @datasets.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
63 | class AnswerEditDistance(Metric):
64 | """Estimates the similarity between answers and gt_answers."""
65 |
66 | name = "answer_edit_distance"
67 |
68 | ALIAS = ['answer_edit_distance']
69 |
70 | def __init__(self):
71 | """
72 | Explicitly initialize AnswerEditDistance.
73 |
74 | Ensure all parent classes are initialized.
75 | """
76 | super().__init__()
77 |
78 | def __repr__(self) -> str:
79 | """:return: Formatted string representation of the metric."""
80 | return f"{self.ALIAS[0]}"
81 |
82 | def _info(self):
83 | return datasets.MetricInfo(
84 | description=_DESCRIPTION,
85 | inputs_description=_KWARGS_DESCRIPTION,
86 | citation=_CITATION,
87 | homepage="",
88 | features=datasets.Features(
89 | {
90 | "answers": datasets.Value("string"),
91 | "gt_answers": datasets.Value("string")
92 | }
93 | ),
94 | codebase_urls=[],
95 | reference_urls=["https://ieeexplore.ieee.org/abstract/document/10172590"]
96 | )
97 |
98 | def _compute_one(
99 | self,
100 | pred_answer: str,
101 | ref_answer: str
102 | ) -> float:
103 | """Evaluating the similarity between answer and gt_answer by calculating the edit distance."""
104 | pred_answer = pred_answer.split()
105 | ref_answer = ref_answer.split()
106 | m, n = len(pred_answer), len(ref_answer)
107 |
108 | if m == 0 or n == 0:
109 | return 0
110 |
111 | dp = [[0] * (n + 1) for _ in range(m + 1)]
112 | for i in range(m + 1):
113 | dp[i][0] = i
114 | for j in range(n + 1):
115 | dp[0][j] = j
116 |
117 | for i in range(1, m + 1):
118 | for j in range(1, n + 1):
119 | dp[i][j] = min(dp[i - 1][j] + 1, dp[i][j - 1] + 1)
120 | if pred_answer[i - 1] != ref_answer[j - 1]:
121 | dp[i][j] = min(dp[i][j], dp[i - 1][j - 1] + 1)
122 | else:
123 | dp[i][j] = min(dp[i][j], dp[i - 1][j - 1])
124 |
125 | return dp[m][n] / m
126 |
--------------------------------------------------------------------------------
/rageval/metrics/answer_correctness/_answer_exact_match.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import List
3 |
4 | import datasets
5 | import numpy as np
6 |
7 | from rageval.metrics import Metric, add_attribute
8 |
9 |
10 | _DESCRIPTION = """\
11 | AnswerEMCorrectness evaluates answer correctness based on exact matching of annotated short answers.
12 |
13 | For details, see the paper: https://arxiv.org/abs/2204.06092.
14 | """
15 |
16 | _KWARGS_DESCRIPTION = """\
17 | Args:
18 | name : str
19 | batch_size : int, Batch size for openai completion.
20 | ignore_case : bool, whether to ignore case when comparing the answer and ground truth answers.
21 |
22 | Optional Args:
23 | None
24 |
25 | Functions:
26 |     _compute_one: compute the score by measuring whether the args:`answer` contains the short answers in list:`gt_answers`.
27 |
28 | Examples:
29 | >>> from datasets import Dataset
30 | >>> import rageval as rl
31 | >>> sample = {
32 | ... "answers": [
33 | ... "Ali Dael has the highest goals in men's world international football with 109 goals. Josef Bican has "
34 | ... "the highest goals all-time in men's football and Christine Sinclair has the highest goals in women's "
35 | ... "world international football.",
36 | ... "A supercentenarian is someone who has reached the age of 110. Sarah Knauss, whose age is undisputed, "
37 | ... "was the oldest person ever from the United States and the second-oldest fully documented person ever. "
38 | ... "Jeanne Calment was a French supercentenarian and the oldest human whose age is well-documented, with "
39 | ... "a lifespan of 122 years and 164 days, and was the oldest person in the world as of 1997. In 1985, "
40 | ... "the oldest living person was Mathew Beard and in 1986 it was Augusta Holtz, who lived 115 years and "
41 | ... "79 days, from 1871 to 1986."
42 | ... ],
43 | ... "gt_answers": [
44 | ... [
45 | ... ["Daei", "Ali Daei"],
46 | ... ["Bican", "Josef Bican"],
47 | ... ["Sinclair","Christine Sinclair"]
48 | ... ],
49 | ... [
50 | ... ["Jeanne Calment"],
51 | ... ["Sarah Knauss"],
52 | ... ["Augusta-Holtz"],
53 | ... ]
54 | ... ],
55 | ... }
56 | >>> dataset = Dataset.from_dict(sample)
57 | >>> metric = rl.metrics.AnswerEMCorrectness()
58 | >>> metric.mtype
59 | 'AnswerCorrectness'
60 | >>> score, results = metric.compute(dataset['answers'], dataset['gt_answers'], 1)
61 | >>> assert 0 <= score <= 1
62 | """
63 |
64 | _CITATION = """\
65 | @misc{stelmakh2023asqa,
66 | title={ASQA: Factoid Questions Meet Long-Form Answers},
67 | author={Ivan Stelmakh and Yi Luan and Bhuwan Dhingra and Ming-Wei Chang},
68 | year={2023},
69 | eprint={2204.06092},
70 | archivePrefix={arXiv},
71 | primaryClass={cs.CL}
72 | }
73 | """
74 |
75 |
76 | @dataclass
77 | @add_attribute('mtype', 'AnswerCorrectness')
78 | @datasets.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
79 | class AnswerEMCorrectness(Metric):
80 | """Estimates correctness using annotated short answers."""
81 |
82 | name = "answer_exact_match"
83 |
84 | ALIAS = ['answer_exact_match']
85 |
86 | def __init__(self, ignore_case: bool = False):
87 | """Explicitly initialize the AnswerEMCorrectness to ensure all parent class initialized."""
88 | super().__init__()
89 | self.ignore_case = ignore_case
90 |
91 | def __repr__(self) -> str:
92 | """:return: Formatted string representation of the metric."""
93 | return f"{self.ALIAS[0]}"
94 |
95 | def _info(self):
96 | return datasets.MetricInfo(
97 | description=_DESCRIPTION,
98 | inputs_description=_KWARGS_DESCRIPTION,
99 | citation=_CITATION,
100 | homepage="",
101 | features=datasets.Features(
102 | {
103 | "answers": datasets.Value("string"),
104 | "gt_answers": datasets.Sequence(datasets.Value("string"))
105 | }
106 | ),
107 | codebase_urls=[],
108 | reference_urls=["https://arxiv.org/abs/2204.06092"]
109 | )
110 |
111 | def _compute_one(self, pred_answer: str, short_answers: List[List[str]]) -> float:
112 | """Compute the correctness of a single answer."""
113 | acc = []
114 | if self.ignore_case:
115 | pred_answer = pred_answer.lower()
116 | short_answers = [[a.lower() for a in candidate_short_answers] for candidate_short_answers in short_answers]
117 | for candidate_short_answers in short_answers:
118 | for candidate_short_answer in candidate_short_answers:
119 | if candidate_short_answer in pred_answer:
120 | acc.append(True)
121 | break
122 | else:
123 | acc.append(False)
124 | return np.average(acc)
125 |
--------------------------------------------------------------------------------
/rageval/metrics/answer_correctness/_answer_lcs_ratio.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import List
3 | import datasets
4 |
5 | from rageval.metrics import Metric, add_attribute
6 |
7 | _DESCRIPTION = """\
8 | The AnswerLCSRatio is to measure the similarity between answer and gt_answer by calculating the longest common subsequence.
9 |
10 | This is a very traditional method, but to this day, some work is still being carried out using it, such as \
11 | https://ieeexplore.ieee.org/abstract/document/10172590.
12 | """
13 |
14 | _KWARGS_DESCRIPTION = """\
15 | Args:
16 | name : str
17 | batch_size : int, Batch size for openai completion.
18 |
19 | Optional Args:
20 | None
21 |
22 | Functions:
23 | _compute_one: evaluating the similarity between answer and gt_answer by calculating the longest common subsequence.
24 |
25 | Examples:
26 | >>> from datasets import Dataset
27 | >>> import rageval as rl
28 | >>> sample = {
29 | ... "answers": [
30 | ... "Language models trained on massive code corpora can generalize to tasks without the need "
31 | ... "for task-specific fine-tuning."
32 | ... ],
33 | ... "gt_answers": [
34 | ... "Large language models trained on massive code corpora can generalize to new tasks without the need "
35 | ... "for task-specific fine-tuning."
36 | ... ]
37 | ... }
38 | >>> dataset = Dataset.from_dict(sample)
39 | >>> metric = rl.metrics.AnswerLCSRatio()
40 | >>> metric.mtype
41 | 'AnswerCorrectness'
42 | >>> score, results = metric.compute(dataset['answers'], dataset['gt_answers'], 1)
43 | >>> assert score == 16 / 17
44 | """
45 |
46 | _CITATION = """\
47 | @INPROCEEDINGS{10172590,
48 | author={Nashid, Noor and Sintaha, Mifta and Mesbah, Ali},
49 | booktitle={2023 IEEE/ACM 45th International Conference on Software Engineering (ICSE)},
50 | title={Retrieval-Based Prompt Selection for Code-Related Few-Shot Learning},
51 | year={2023},
52 | volume={},
53 | number={},
54 | pages={2450-2462},
55 | doi={10.1109/ICSE48619.2023.00205}
56 | }
57 | """
58 |
59 |
60 | @dataclass
61 | @add_attribute('mtype', 'AnswerCorrectness')
62 | @datasets.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
63 | class AnswerLCSRatio(Metric):
64 | """Estimates the similarity between answers and gt_answers."""
65 |
66 | name = "answer_lcs_ratio"
67 |
68 | ALIAS = ['answer_lcs_ratio']
69 |
70 | def __init__(self):
71 | """
72 | Explicitly initialize AnswerLCSRatio.
73 |
74 | Ensure all parent classes are initialized.
75 | """
76 | super().__init__()
77 |
78 | def __repr__(self) -> str:
79 | """:return: Formatted string representation of the metric."""
80 | return f"{self.ALIAS[0]}"
81 |
82 | def _info(self):
83 | return datasets.MetricInfo(
84 | description=_DESCRIPTION,
85 | inputs_description=_KWARGS_DESCRIPTION,
86 | citation=_CITATION,
87 | homepage="",
88 | features=datasets.Features(
89 | {
90 | "answers": datasets.Value("string"),
91 | "gt_answers": datasets.Value("string")
92 | }
93 | ),
94 | codebase_urls=[],
95 | reference_urls=["https://ieeexplore.ieee.org/abstract/document/10172590"]
96 | )
97 |
98 | def _compute_one(
99 | self,
100 | pred_answer: str,
101 | ref_answer: str
102 | ) -> float:
103 |         """Evaluate the similarity between answer and gt_answer by computing the longest common subsequence."""
104 | pred_answer = pred_answer.split()
105 | ref_answer = ref_answer.split()
106 | m, n = len(pred_answer), len(ref_answer)
107 |
108 | if m == 0 or n == 0:
109 | return 0
110 |
111 | dp = [0] * (n + 1)
112 | for i in range(m):
113 | pre = 0
114 | for j in range(n):
115 | tmp = dp[j + 1]
116 | dp[j + 1] = pre + 1 if pred_answer[i] == ref_answer[j] else max(dp[j + 1], dp[j])
117 | pre = tmp
118 |
119 | return dp[-1] / m
120 |
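Note: _compute_one above is a space-optimized (rolling one-row) LCS over whitespace tokens, normalized by the length of the predicted answer. A plain two-dimensional version, shown below as a cross-check sketch (not part of the source file), reproduces the 16/17 from the doctest: every predicted token except the case-mismatched "Language" joins the common subsequence.

def lcs_len(a, b):
    a, b = a.split(), b.split()
    dp = [[0] * (len(b) + 1) for _ in range(len(a) + 1)]
    for i, x in enumerate(a):
        for j, y in enumerate(b):
            dp[i + 1][j + 1] = dp[i][j] + 1 if x == y else max(dp[i][j + 1], dp[i + 1][j])
    return dp[-1][-1]

pred = ("Language models trained on massive code corpora can generalize to tasks "
        "without the need for task-specific fine-tuning.")
ref = ("Large language models trained on massive code corpora can generalize to new tasks "
       "without the need for task-specific fine-tuning.")
print(lcs_len(pred, ref) / len(pred.split()))  # 16/17 ~= 0.941, matching the doctest
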
--------------------------------------------------------------------------------
/rageval/metrics/answer_correctness/_answer_relevancy.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gomate-community/rageval/01e2580efe714b3130602e02a7e7b017d16d79a2/rageval/metrics/answer_correctness/_answer_relevancy.py
--------------------------------------------------------------------------------
/rageval/metrics/answer_correctness/_answer_rouge_correctness.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | from dataclasses import dataclass
4 | from typing import List, Callable, Optional
5 |
6 | import datasets
7 | from datasets import Dataset
8 | from rouge_score import rouge_scorer
9 |
10 | from rageval.metrics import Metric, add_attribute
11 |
12 | _DESCRIPTION = """Estimates the ROUGE score by comparing the answer against the ground-truth answers.
13 | 
14 | ROUGE is case insensitive, so the input text is converted to lower case before computing the score. This metric is a wrapper around https://github.com/google-research/google-research/blob/master/rouge/rouge_scorer.py
15 |
16 | """
17 |
18 | _KWARGS_DESCRIPTION = """\
19 | Args:
20 | name : str
21 |     rouge_type : str, the ROUGE variant to calculate. One of 'rouge1', 'rouge2', 'rougeL', 'rougeLsum'
22 |         "rouge1": unigram (1-gram) based scoring
23 |         "rouge2": bigram (2-gram) based scoring
24 |         "rougeL": longest common subsequence based scoring
25 |         "rougeLsum": like "rougeL", but splits the text into sentences using "\n" first
26 |
27 | Optional Args:
28 | tokenizer : Callable, a tokenizer can be passed to the scorer, replacing the default tokenizer which tokenizes on whitespace, especially for non-latin languages. For example, the `jieba.cut` can be used for Chinese.
29 |
30 | Functions:
31 |     _compute_one: compute the ROUGE F1 score between a single answer and its ground-truth answers.
32 |
33 | Examples:
34 | >>> from datasets import Dataset
35 | >>> import rageval as rl
36 | >>> sample = {
37 | ... "answers": [
38 | ... "Some nanomaterials may give rise to various kinds of lung damage."
39 | ... ],
40 | ... "gt_answers":[
41 | ... [
42 | ... "Nanoparticles can penetrate the body, affecting the lungs, brain, and other organs,\
43 | ... leading to possible respiratory, cardiovascular, and brain health problems.",
44 | ... "Due to their small size, nanoparticles can infiltrate the body and impact vital organs,\
45 | ... posing risks to respiratory, heart, and neurological health."
46 | ... ]
47 | ... ]
48 | ... }
49 | >>> dataset = Dataset.from_dict(sample)
50 | >>> metric = rl.metrics.AnswerRougeCorrectness('rougeL')
51 | >>> score, results = metric.compute(dataset['answers'], dataset['gt_answers'], 1)
52 | >>> assert 0 <= score <= 1
53 | """
54 |
55 | _CITATION = """\
56 | @inproceedings{lin-2004-rouge,
57 | title = "{ROUGE}: A Package for Automatic Evaluation of Summaries",
58 | author = "Lin, Chin-Yew",
59 | booktitle = "Text Summarization Branches Out",
60 | month = jul,
61 | year = "2004",
62 | address = "Barcelona, Spain",
63 | publisher = "Association for Computational Linguistics",
64 | url = "https://aclanthology.org/W04-1013",
65 | pages = "74--81",
66 | }
67 | @article{lewis2020retrieval,
68 | title={Retrieval-augmented generation for knowledge-intensive nlp tasks},
69 | author={Lewis, Patrick and Perez, Ethan and Piktus, Aleksandra and Petroni, Fabio and Karpukhin, Vladimir and Goyal, Naman and K{\"u}ttler, Heinrich and Lewis, Mike and Yih, Wen-tau and Rockt{\"a}schel, Tim and others},
70 | journal={Advances in Neural Information Processing Systems},
71 | volume={33},
72 | pages={9459--9474},
73 | year={2020}
74 | }
75 | """
76 |
77 |
78 | @dataclass
79 | @add_attribute('mtype', 'AnswerCorrectness')
80 | @datasets.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
81 | class AnswerRougeCorrectness(Metric):
82 |     """Estimates the ROUGE score between answers and ground-truth answers."""
83 | name = "answer_rouge_correctness"
84 |
85 | ALIAS = ['answer_rouge_correctness']
86 |
87 | def __init__(self, rouge_type: str, tokenizer: Optional[Callable] = None):
88 | """Explicitly initialize the AnswerRougeCorrectness to ensure all parent class initialized as well as initialize the rouge type and tokenizer."""
89 | self.rouge_type = rouge_type
90 | self.scorer = rouge_scorer.RougeScorer([rouge_type], use_stemmer=True, tokenizer=tokenizer)
91 | super().__init__()
92 |
93 | def __repr__(self) -> str:
94 |         """:return: Formatted string representation of the metric."""
95 | return f"{self.ALIAS[0]}"
96 |
97 | def _info(self):
98 | return datasets.MetricInfo(
99 | description=_DESCRIPTION,
100 | inputs_description=_KWARGS_DESCRIPTION,
101 | citation=_CITATION,
102 | homepage="",
103 | features=datasets.Features(
104 | {
105 | "answers": datasets.Value("string", id="sequence"),
106 | "gt_answers": datasets.Value("string", id="sequence"),
107 | }
108 | ),
109 | codebase_urls=["https://github.com/mim-solutions/rouge_score"],
110 | reference_urls=[
111 | "https://aclanthology.org/W04-1013/",
112 | "https://arxiv.org/abs/2005.11401"
113 | ]
114 | )
115 |
116 | def _compute_one(self, pred_answer: str, ref_answers: List[str]) -> float:
117 | """Evaluate the ROUGE between a single answer and groundtruth answers."""
118 | score = self.scorer.score_multi(ref_answers, pred_answer)
119 | return score[self.rouge_type].fmeasure
120 |
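Note: the tokenizer argument is handed straight to rouge_score's RougeScorer, which expects an object exposing a tokenize() method rather than a bare function. A sketch of wrapping jieba for Chinese text (the wrapper class and the sample sentences are illustrative, not part of rageval):

import jieba
from rageval.metrics import AnswerRougeCorrectness

class JiebaTokenizer:
    """Thin wrapper so rouge_score can call .tokenize() on Chinese text."""

    def tokenize(self, text):
        return list(jieba.cut(text))

metric = AnswerRougeCorrectness('rougeL', tokenizer=JiebaTokenizer())
score, results = metric.compute(["李白是唐代著名的诗人。"], [["李白是唐朝著名诗人。"]], 1)
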
--------------------------------------------------------------------------------
/rageval/metrics/answer_correctness/_answer_ter.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import List, Tuple
3 |
4 | import datasets
5 | from sacrebleu.metrics import TER
6 | import numpy as np
7 |
8 | from rageval.metrics import Metric, add_attribute
9 |
10 | _DESCRIPTION = """\
11 | TER (Translation Edit Rate, also called Translation Error Rate) is a metric to quantify the edit operations that a
12 | hypothesis requires to match a reference translation. The implementation is already present in sacrebleu
13 | (https://github.com/mjpost/sacreBLEU#ter), which in turn is inspired by the TERCOM implementation, which can be found
14 | here: https://github.com/jhclark/tercom.
15 | """
16 |
17 | _KWARGS_DESCRIPTION = """\
18 | Args:
19 | name : str
20 | normalized (boolean): If `True`, applies basic tokenization and normalization to sentences. Defaults to `False`.
21 |     ignore_punct (boolean): If `True`, removes punctuation before scoring. Defaults to `False`.
22 | support_zh_ja_chars (boolean): If `True`, tokenization/normalization supports processing of Chinese characters,
23 | as well as Japanese Kanji, Hiragana, Katakana, and Phonetic Extensions of Katakana.
24 | Only applies if `normalized = True`. Defaults to `False`.
25 | case_sensitive (boolean): If `False`, makes all predictions and references lowercase to ignore differences in case. Defaults to `False`.
26 |
27 | Optional Args:
28 | None
29 |
30 | Functions:
31 | _validate_data: validate the dataset format.
32 |
33 | Examples:
34 | >>> from datasets import Dataset
35 | >>> import rageval as rl
36 | >>> sample = {
37 | ... "answers": [
38 | ... "does this sentence match??",
39 | ... "what about this sentence?",
40 | ... "What did the TER metric user say to the developer?"
41 | ... ],
42 | ... "gt_answers": [
43 | ... ["does this sentence match", "does this sentence match!?!"],
44 | ... ["wHaT aBoUt ThIs SeNtEnCe?", "wHaT aBoUt ThIs SeNtEnCe?"],
45 | ... ["Your jokes are...", "...TERrible"]
46 | ... ]
47 | ... }
48 | >>> dataset = Dataset.from_dict(sample)
49 | >>> metric = rl.metrics.AnswerTERCorrectness()
50 | >>> metric.mtype
51 | 'AnswerCorrectness'
52 | >>> score, results = metric.compute(dataset['answers'], dataset['gt_answers'])
53 | >>> assert score == 110.00000000000001
54 | >>> assert results[0] == 25.0
55 | """
56 |
57 | _CITATION = """\
58 | @inproceedings{snover-etal-2006-study,
59 | title = "A Study of Translation Edit Rate with Targeted Human Annotation",
60 | author = "Snover, Matthew and
61 | Dorr, Bonnie and
62 | Schwartz, Rich and
63 | Micciulla, Linnea and
64 | Makhoul, John",
65 | booktitle = "Proceedings of the 7th Conference of the Association for Machine Translation in the Americas: Technical Papers",
66 | month = aug # " 8-12",
67 | year = "2006",
68 | address = "Cambridge, Massachusetts, USA",
69 | publisher = "Association for Machine Translation in the Americas",
70 | url = "https://aclanthology.org/2006.amta-papers.25",
71 | pages = "223--231",
72 | }
73 | @inproceedings{post-2018-call,
74 | title = "A Call for Clarity in Reporting {BLEU} Scores",
75 | author = "Post, Matt",
76 | booktitle = "Proceedings of the Third Conference on Machine Translation: Research Papers",
77 | month = oct,
78 | year = "2018",
79 | address = "Belgium, Brussels",
80 | publisher = "Association for Computational Linguistics",
81 | url = "https://www.aclweb.org/anthology/W18-6319",
82 | pages = "186--191",
83 | }
84 | """
85 |
86 |
87 | @dataclass
88 | @add_attribute('mtype', 'AnswerCorrectness')
89 | @datasets.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
90 | class AnswerTERCorrectness(Metric):
91 | """Estimates the TER between answers and ground truth answers."""
92 |
93 | name = "answer_ter"
94 |
95 | ALIAS = ['answer_ter']
96 |
97 | def __init__(
98 | self,
99 | normalized: bool = False,
100 | ignore_punct: bool = False,
101 | support_zh_ja_chars: bool = False,
102 | case_sensitive: bool = False
103 | ):
104 | """
105 | Explicitly initialize AnswerTERCorrectness.
106 |
107 | Ensure all parent classes are initialized.
108 | """
109 | super().__init__()
110 | self.ter = TER(
111 | normalized=normalized,
112 | no_punct=ignore_punct,
113 | asian_support=support_zh_ja_chars,
114 | case_sensitive=case_sensitive
115 | )
116 |
117 | def __repr__(self) -> str:
118 | """:return: Formatted string representation of the metric."""
119 | return f"{self.ALIAS[0]}"
120 |
121 | def _info(self):
122 | return datasets.MetricInfo(
123 | description=_DESCRIPTION,
124 | inputs_description=_KWARGS_DESCRIPTION,
125 | citation=_CITATION,
126 | features=datasets.Features(
127 | {
128 | "answers": datasets.Value("string"),
129 | "gt_answers": datasets.Sequence(datasets.Value("string"))
130 | }
131 | ),
132 | codebase_urls=["https://github.com/mjpost/sacreBLEU#ter"],
133 | reference_urls=["https://aclanthology.org/2006.amta-papers.25", "https://www.aclweb.org/anthology/W18-6319"]
134 | )
135 |
136 | def _validate_data(
137 | self,
138 | pred_answers: List[str],
139 | ref_answers: List[List[str]]
140 | ) -> None:
141 | """Validate the input predictions and references."""
142 | super()._validate_data(pred_answers, ref_answers)
143 | if not all(isinstance(pred_answer, str) for pred_answer in pred_answers):
144 | raise ValueError("The type of pred_answers should be a list of strings.")
145 | if not all(isinstance(reference_list, list) and all(isinstance(reference, str) for reference in reference_list) for reference_list in ref_answers):
146 | raise ValueError("The type of ref_answers should be a list of lists of strings.")
147 |
148 | def _compute_one(
149 | self,
150 | pred_answer: str,
151 | ref_answers: List[str]
152 | ) -> float:
153 | """Compute the TER score of a single answer."""
154 | return self.ter.sentence_score(pred_answer, ref_answers).score
155 |
156 | def compute(
157 | self,
158 | pred_answers: List[str],
159 | ref_answers: List[List[str]],
160 | ) -> Tuple[float, List[float]]:
161 | """Evaluate the dataset."""
162 | self._validate_data(pred_answers, ref_answers)
163 | scores = self._compute_batch(pred_answers, ref_answers)
164 | ref_answers = np.array(ref_answers)
165 | ref_answers = ref_answers.T.tolist()
166 | return self.ter.corpus_score(pred_answers, ref_answers).score, scores
167 |
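Note: the transpose in compute() exists because sacrebleu's corpus_score expects references grouped by reference stream, i.e. refs[k][i] is the k-th reference of the i-th prediction, which also requires every prediction to have the same number of references. A standalone sketch using the sentences from the doctest above:

from sacrebleu.metrics import TER

preds = ["does this sentence match??", "what about this sentence?"]
refs_per_pred = [
    ["does this sentence match", "does this sentence match!?!"],
    ["wHaT aBoUt ThIs SeNtEnCe?", "wHaT aBoUt ThIs SeNtEnCe?"],
]
streams = [list(s) for s in zip(*refs_per_pred)]  # transpose into reference streams
print(TER().corpus_score(preds, streams).score)
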
--------------------------------------------------------------------------------
/rageval/metrics/answer_groundedness/_claim_faithfulness.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gomate-community/rageval/01e2580efe714b3130602e02a7e7b017d16d79a2/rageval/metrics/answer_groundedness/_claim_faithfulness.py
--------------------------------------------------------------------------------
/rageval/metrics/answer_informativeness/_answer_distinct12.py:
--------------------------------------------------------------------------------
1 | from collections import Counter
2 | from dataclasses import dataclass
3 | from typing import List, Optional, Iterable, Tuple
4 | import datasets
5 | from nltk import ngrams
6 | from rageval.metrics import Metric, add_attribute
7 |
8 | _DESCRIPTION = """\
9 | Distinct 1/2 measures the diversity of generated text by calculating the ratio of unique n-grams to the total number of n-grams.
10 | """
11 |
12 | _KWARGS_DESCRIPTION = """\
13 | Args:
14 | pred_answers (list of str): List of generated texts for which distinct metrics are computed.
15 | n_grams (int): The n-gram order for which distinct metrics are computed.
16 |
17 | Returns:
18 |     tuple: The overall Distinct-n score and a list of per-answer scores.
19 |
20 | Examples:
21 | >>> from datasets import Dataset
22 | >>> import rageval as rl
23 | >>> sample = {
24 | ... "answers": [
25 | ... "This is the first sentence.",
26 | ... "This is the second sentence."
27 | ... ]
28 | ... }
29 | >>> dataset = Dataset.from_dict(sample)
30 | >>> metric = rl.metrics.AnswerDistinct(1)
31 | >>> metric.mtype
32 | 'AnswerInformativeness'
33 | >>> score, results = metric.compute(dataset['answers'])
34 | >>> score
35 | 0.6
36 | """
37 |
38 | _CITATION = """\
39 | @misc{selfmemory2023,
40 | title={Lift Yourself Up: Retrieval-augmented Text Generation with Self Memory},
41 | author={Xin Cheng and Di Luo and Xiuying Chen and Lemao Liu and Dongyan Zhao and Rui Yan},
42 | year={2023},
43 | eprint={2305.02437},
44 | archivePrefix={arXiv},
45 | primaryClass={cs.CL}
46 | }
47 | """
48 |
49 |
50 | def get_distinct_score(pred_answers: List[str], n_grams: int) -> float:
51 |     """Compute the Distinct-n score: the ratio of unique n-grams to the total number of n-grams."""
52 | c = Counter()
53 | for answer in pred_answers:
54 | tokens = answer.split()
55 | c.update(ngrams(tokens, n_grams))
56 |
57 | return len(c) / sum(c.values())
58 |
59 |
60 | @dataclass
61 | @add_attribute('mtype', 'AnswerInformativeness')
62 | @datasets.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
63 | class AnswerDistinct(Metric):
64 | """Distinct 1/2 metric for text generation."""
65 |
66 | name = "answer_distinct"
67 |
68 | ALIAS = ['answer_distinct']
69 |
70 | def __init__(self, n_grams: int = 1):
71 | """
72 | Explicitly initialize Distinct.
73 |
74 | Ensure all parent classes are initialized.
75 | """
76 | super().__init__()
77 | self.n_grams = n_grams
78 |
79 | def __repr__(self) -> str:
80 | """:return: Formatted string representation of the metric."""
81 | return f"{self.ALIAS[0]}"
82 |
83 | def _info(self):
84 | return datasets.MetricInfo(
85 | description=_DESCRIPTION,
86 | inputs_description=_KWARGS_DESCRIPTION,
87 | citation=_CITATION,
88 | features=datasets.Features(
89 | {
90 | "pred_answers": datasets.Value("string"),
91 | }
92 | ),
93 | codebase_urls=["https://github.com/Hannibal046/SelfMemory/blob/main/src/utils/metrics_utils.py"],
94 | reference_urls=["https://arxiv.org/abs/2305.02437"]
95 | )
96 |
97 | def _validate_data(
98 | self,
99 | pred_answers: Optional[Iterable] = None,
100 | ref_answers: Optional[Iterable] = None,
101 |     ) -> None:
102 | """Validate the input data."""
103 | assert isinstance(pred_answers, str) or isinstance(pred_answers, list) # pragma: no cover
104 |
105 | def compute(
106 | self,
107 | pred_answers: Optional[Iterable] = None,
108 | ) -> Tuple[float, List[float]]:
109 | """
110 | Evaluate the dataset.
111 |
112 | Return average scores of all inputs and a score list for each example.
113 | """
114 | return get_distinct_score(pred_answers, self.n_grams), [get_distinct_score([pred_answer], self.n_grams) for pred_answer in pred_answers]
115 |
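Note: a worked check of the 0.6 in the doctest above: the two answers contribute 10 unigrams in total, of which 6 are unique ("This", "is", "the", "first", "second", "sentence."), so the Distinct-1 score is 6 / 10.

from collections import Counter
from nltk import ngrams

answers = ["This is the first sentence.", "This is the second sentence."]
c = Counter()
for answer in answers:
    c.update(ngrams(answer.split(), 1))
print(len(c), sum(c.values()), len(c) / sum(c.values()))  # 6 10 0.6
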
--------------------------------------------------------------------------------
/rageval/metrics/answer_informativeness/_claim_num.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gomate-community/rageval/01e2580efe714b3130602e02a7e7b017d16d79a2/rageval/metrics/answer_informativeness/_claim_num.py
--------------------------------------------------------------------------------
/rageval/metrics/answer_informativeness/_pairwise_accuracy.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gomate-community/rageval/01e2580efe714b3130602e02a7e7b017d16d79a2/rageval/metrics/answer_informativeness/_pairwise_accuracy.py
--------------------------------------------------------------------------------
/rageval/metrics/answer_informativeness/_repetitiveness.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gomate-community/rageval/01e2580efe714b3130602e02a7e7b017d16d79a2/rageval/metrics/answer_informativeness/_repetitiveness.py
--------------------------------------------------------------------------------
/rageval/metrics/answer_informativeness/_text_length.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import Optional, Iterable
3 | from transformers import AutoTokenizer
4 | import evaluate
5 |
6 | import datasets
7 |
8 | from rageval.metrics import Metric, add_attribute
9 |
10 |
11 | _DESCRIPTION = """\
12 | TextLength is a metric used to evaluate the length of a model-generated response.
13 |
14 | It measures the number of tokens in the generated text by first converting the text into tokens and then counting the total number. This metric provides insight into the verbosity or conciseness of the model's output, offering a standardized way to compare text length across different responses.
15 | """
16 |
17 | _KWARGS_DESCRIPTION = """\
18 | Args:
19 | name : str
20 |
21 | Optional Args:
22 | None
23 |
24 | Functions:
25 |     _compute_one: Evaluate the token length of the answer.
26 |
27 | Examples:
28 | >>> from datasets import Dataset
29 | >>> import rageval as rl
30 | >>> sample = {
31 | ... "answers": [
32 | ... "A",
33 | ... "C",
34 | ... ]
35 | ... }
36 | >>> dataset = Dataset.from_dict(sample)
37 | >>> metric = TextLength(tokenize_model="Qwen/Qwen2-0.5B-Instruct")
38 | >>> metric.mtype
39 | 'answer_informativeness'
40 | """
41 |
42 |
43 | @dataclass
44 | @add_attribute('mtype', 'answer_informativeness')
45 | @datasets.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
46 | class TextLength(Metric):
47 | """Estimates the text length of answers."""
48 |
49 | name = "text_length"
50 |
51 | ALIAS = ['text_length']
52 |
53 | def __init__(self, tokenize_model: str = "Qwen/Qwen2-0.5B-Instruct"):
54 | """
55 | Explicitly initialize TextLength.
56 |
57 | Ensure all parent classes are initialized.
58 | """
59 | self.tokenizer = AutoTokenizer.from_pretrained(tokenize_model)
60 | super().__init__()
61 | self.info = evaluate.MetricInfo(
62 | description=_DESCRIPTION,
63 | inputs_description=_KWARGS_DESCRIPTION,
64 | citation="",
65 | homepage="",
66 | features=datasets.Features(
67 | {
68 | "answers": datasets.Value("string"),
69 | }
70 | ),
71 | codebase_urls=[],
72 | reference_urls=[]
73 | )
74 |
75 | def __repr__(self) -> str:
76 | """:return: Formatted string representation of the metric."""
77 | return f"{self.ALIAS[0]}" # pragma: no cover
78 |
79 | def _compute_one(
80 | self,
81 | answer: str,
82 | *args: Optional[Iterable],
83 | ) -> float:
84 |         """Evaluate the token length of the answer."""
85 | length = len(self.tokenizer(answer, return_tensors="pt")['input_ids'][0])
86 | return length
87 |
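Note: a minimal sketch of what the metric counts, assuming the default Qwen tokenizer above is available for download (the answer string is made up):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
answer = "Retrieval-augmented generation combines a retriever with a generator."
print(len(tok(answer, return_tensors="pt")["input_ids"][0]))  # token count, i.e. the TextLength score
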
--------------------------------------------------------------------------------
/rageval/metrics/base.py:
--------------------------------------------------------------------------------
1 | from typing import List, Tuple, Callable, Optional, Iterable
2 | from abc import abstractmethod
3 | from dataclasses import dataclass
4 |
5 | import numpy as np
6 | from langchain.schema import LLMResult
7 | from tqdm import tqdm
8 |
9 |
10 | def add_attribute(attribute_name, attribute_value):
11 | """
12 |     This decorator is used to set an attribute on a class.
13 |
14 |     Currently, this decorator is used to set attr:metric_type for each metric.
15 | There are four types, i.e., 'AnswerCorrectness', 'AnswerGroundedness', 'ContextRelevancy', 'ContextAdequacy', \
16 | for all RAG metrics.
17 | """
18 | def decorator(cls):
19 | setattr(cls, attribute_name, attribute_value)
20 | return cls
21 | return decorator
22 |
23 |
24 | @dataclass
25 | class Metric():
26 | """Metric base class without LLM."""
27 |
28 | def __init__(
29 | self,
30 | config_name: Optional[str] = None,
31 | experiment_id: Optional[str] = None
32 | ):
33 | """Initialization.
34 |
35 | Args:
36 | config_name: type(string), Optional.
37 | experiment_id: type(string), Optional.
38 | """ # pragma: no cover
39 |
40 | @property
41 | @abstractmethod
42 | def name(self) -> str:
43 | """The metric name."""
44 | ... # pragma: no cover
45 |
46 | def _validate_data(
47 | self,
48 | pred_answers: Optional[Iterable] = None,
49 | ref_answers: Optional[Iterable] = None,
50 | *args: Optional[Iterable]
51 | ) -> None:
52 |         """Validate the format of the input dataset."""
53 | if (pred_answers and ref_answers):
54 | if len(pred_answers) != len(ref_answers) or any(len(pred_answers) != len(arg) for arg in args):
55 | raise ValueError("The length of predictions and references should be the same.") # pragma: no cover
56 |
57 | def compute(
58 | self,
59 | pred_answers: Optional[Iterable] = None,
60 | ref_answers: Optional[Iterable] = None,
61 | batch_size: Optional[int] = None,
62 | *args: Optional[Iterable],
63 | ) -> Tuple[float, List[float]]:
64 | """
65 | Evaluate the dataset.
66 |
67 | Return average scores of all inputs and a score list for each example.
68 | """
69 | self._validate_data(pred_answers, ref_answers, *args)
70 | scores = self._compute_batch(pred_answers, ref_answers, *args)
71 |
72 | return np.average(scores), scores
73 |
74 | @abstractmethod
75 | def _compute_one(
76 | self,
77 | pred_answer: Optional[Iterable] = None,
78 | ref_answer: Optional[Iterable] = None,
79 | *args: Optional[Iterable]
80 | ) -> float:
81 | ... # pragma: no cover
82 |
83 | def _compute_batch(
84 | self,
85 | pred_answers: Optional[Iterable] = None,
86 | ref_answers: Optional[Iterable] = None,
87 | *args: Optional[Iterable]
88 | ) -> List[float]:
89 | """Compute the metric for a batch of predictions and references."""
90 | scores = []
91 | if (pred_answers and ref_answers): # if both columns exist
92 | for pred_answer, ref_answer in tqdm(zip(pred_answers, ref_answers),
93 | desc=f"Computing {self.name}",
94 | total=len(pred_answers)):
95 | scores.append(self._compute_one(pred_answer, ref_answer))
96 | else:
97 | for pred_answer in tqdm(pred_answers,
98 | desc=f"Computing {self.name}",
99 | total=len(pred_answers)):
100 | scores.append(self._compute_one(pred_answer))
101 | return scores
102 |
103 |
104 | @dataclass
105 | class MetricWithLLM(Metric):
106 | """Metrics based on LLM."""
107 |
108 | def __init__(self, model: Callable):
109 | """Initialization."""
110 | super().__init__()
111 | self.llm = model
112 |
113 | @abstractmethod
114 | def parse_llm_result(self, prompts: List[str], result: LLMResult):
115 | """Parse the LLM Result based on the Prompt."""
116 | ... # pragma: no cover
117 |
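Note: new metrics plug into this base class by setting a name, tagging an mtype via add_attribute, and implementing _compute_one; compute() then handles validation, batching, and averaging. The toy metric below is a hypothetical sketch, not part of rageval:

from dataclasses import dataclass

from rageval.metrics import Metric, add_attribute


@dataclass
@add_attribute('mtype', 'AnswerCorrectness')
class AnswerLengthRatio(Metric):
    """Toy metric: predicted length over reference length, capped at 1."""

    name = "answer_length_ratio"

    ALIAS = ['answer_length_ratio']

    def __init__(self):
        super().__init__()

    def _compute_one(self, pred_answer: str, ref_answer: str) -> float:
        return min(len(pred_answer.split()) / max(len(ref_answer.split()), 1), 1.0)


score, scores = AnswerLengthRatio().compute(["a short answer"], ["a somewhat longer reference answer"], 1)
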
--------------------------------------------------------------------------------
/rageval/metrics/context_relevance/_accuracy.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gomate-community/rageval/01e2580efe714b3130602e02a7e7b017d16d79a2/rageval/metrics/context_relevance/_accuracy.py
--------------------------------------------------------------------------------
/rageval/metrics/context_relevance/_hit_rate.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gomate-community/rageval/01e2580efe714b3130602e02a7e7b017d16d79a2/rageval/metrics/context_relevance/_hit_rate.py
--------------------------------------------------------------------------------
/rageval/metrics/context_relevance/_mrr.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gomate-community/rageval/01e2580efe714b3130602e02a7e7b017d16d79a2/rageval/metrics/context_relevance/_mrr.py
--------------------------------------------------------------------------------
/rageval/metrics/context_relevance/_ndcg.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gomate-community/rageval/01e2580efe714b3130602e02a7e7b017d16d79a2/rageval/metrics/context_relevance/_ndcg.py
--------------------------------------------------------------------------------
/rageval/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .openai import OpenAILLM
2 | from .nli import NLIModel
--------------------------------------------------------------------------------
/rageval/models/base.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import typing as t
4 | from abc import ABC
5 |
6 | from langchain.schema.output import LLMResult
7 |
8 | if t.TYPE_CHECKING:
9 | from langchain.callbacks.base import Callbacks
10 | from langchain.prompts import ChatPromptTemplate
11 |
12 |
13 | class BaseLLM(ABC):
14 | """
15 | BaseLLM is the base class for all LLMs.
16 |
17 |     It provides a consistent interface for other classes that interact with LLMs like LangChain, LlamaIndex, \
18 |     LiteLLM, etc. It handles multiple completions even if not supported by the LLM.
19 |
20 | It currently takes in ChatPromptTemplates and returns LLMResults which are Langchain primitives.
21 | """
22 |
23 | # supports multiple completions for the given prompt
24 | n_completions_supported: bool = False
25 |
26 | @property
27 | def llm(self) -> t.Any:
28 | """LLM model."""
29 | ...
30 |
31 | def validate_api_key(self):
32 |         """Validate that the API key is set for the LLM."""
33 | pass
34 |
35 | def generate(
36 | self,
37 | prompts: list[ChatPromptTemplate],
38 | n: int = 1,
39 | temperature: float = 1e-8,
40 | callbacks: t.Optional[Callbacks] = None,
41 | ) -> LLMResult:
42 | """Call the llm model to generate results."""
43 | ...
44 |
--------------------------------------------------------------------------------
/rageval/models/nli.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from abc import ABC
3 |
4 | import pytest
5 | from transformers import pipeline
6 |
7 | logger = logging.getLogger(__name__)
8 |
9 |
10 | class NLIModel(ABC):
11 | """This is the Roberta-based NLI model."""
12 |
13 | def __init__(self, task: str = "sentiment-analysis", model: str = "roberta-large-mnli") -> None:
14 | """Init the Roberta Model."""
15 | self._model_name = model
16 | self._model = pipeline(task=task, model=model, device_map="auto")
17 |
18 | self._labelmap = {
19 | "NEUTRAL": 3,
20 | "CONTRADICTION": 2,
21 | "ENTAILMENT": 1
22 | }
23 | self._nli2stance = {
24 | "NEUTRAL": "irrelevant",
25 | "CONTRADICTION": "refute",
26 | "ENTAILMENT": "support"
27 | }
28 | self._stancemap = {
29 | 'irrelevant': 3,
30 | 'refute': 2,
31 | 'partially-support': 1,
32 | 'completely-support': 1
33 | }
34 |
35 | @property
36 | def model(self):
37 |         """Return the underlying NLI pipeline."""
38 | return self._model
39 |
40 | @pytest.mark.api
41 | def infer_prob(self, premise, hypothesis):
42 | """Predict one sample with NLI model."""
43 | try:
44 | if len(premise) > 200:
45 | premise = premise[:200]
46 | if len(hypothesis) > 200:
47 | hypothesis = hypothesis[:200]
48 | input = "{}{}".format(premise, hypothesis)
49 | pred = self._model(input)
50 | # print(pred)
51 | except Exception as e:
52 | # token length > 514
53 | L = len(premise)
54 | premise = premise[:int(L / 2)]
55 | input = "{}{}".format(premise, hypothesis)
56 | pred = self._model(input)
57 | logger.info(f"An exception occurred during nli inference: {e}")
58 | return pred
59 |
60 | @pytest.mark.api
61 | def infer(self, premise, hypothesis):
62 | """Predict one sample with NLI model."""
63 | pred = self.infer_prob(premise, hypothesis)
64 | # [{'label': 'CONTRADICTION', 'score': 0.9992701411247253}]
65 | if 'mnli' in self._model_name:
66 | return self._nli2stance[pred[0]['label']]
67 | else:
68 | nli2stance = {
69 | "LABEL_0": "irrelevant",
70 | "LABEL_1": "support"
71 | }
72 | return nli2stance[pred[0]['label']]
73 |
74 | @pytest.mark.api
75 | def generate_infer(self, premise, hypothesis):
76 | """Predict one sample with NLI model."""
77 | input_text = "premise: {} hypothesis: {}".format(premise, hypothesis)
78 | pred = self._model(input_text, max_new_tokens=10)
79 | # [{'generated_text': 'support'}]
80 | if pred[0]["generated_text"] == "1":
81 | return 1
82 | return 0
83 |
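Note: a short usage sketch of the default configuration (downloads roberta-large-mnli on first use; the premise/hypothesis pair is made up, and infer() returns one of "support", "refute", or "irrelevant"):

from rageval.models import NLIModel

nli = NLIModel()  # defaults: task="sentiment-analysis", model="roberta-large-mnli"
stance = nli.infer(premise="Paris is the capital of France.",
                   hypothesis="France's capital is Paris.")
print(stance)
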
--------------------------------------------------------------------------------
/rageval/tasks/__init__.py:
--------------------------------------------------------------------------------
1 | from .base import BaseTask
2 | from ._generate import Generate
3 |
--------------------------------------------------------------------------------
/rageval/tasks/_generate.py:
--------------------------------------------------------------------------------
1 | from typing import Union, List
2 |
3 | from rageval.metrics import Metric
4 | from rageval.tasks import BaseTask
5 |
6 |
7 | class Generate(BaseTask):
8 | name = 'Generator'
9 | # Define required columns in testset for the evaluation of the task
10 | required_columns = ['questions', 'answers', 'gt_answers']
11 |
12 | def __init__(self, metrics: Union[str, List[str], List[Metric]]):
13 |         """Init the task."""
14 |
15 | super().__init__(metrics)
16 |
--------------------------------------------------------------------------------
/rageval/tasks/base.py:
--------------------------------------------------------------------------------
1 | """Base task."""
2 |
3 | from abc import ABC, abstractmethod
4 | from dataclasses import dataclass
5 | from typing import Union, Type, List
6 |
7 | from datasets import Dataset, concatenate_datasets
8 |
9 | from rageval.metrics import Metric
10 |
11 |
12 | @dataclass
13 | class BaseTask(ABC):
14 | """Base Task, shouldn't be used directly."""
15 |
16 | def __init__(self, metrics: Union[str, List[str], List[Metric]]):
17 | """Base task constructor."""
18 | self.detailed_result = []
19 | self.result = {}
20 | self.metrics = self._init_metrics(metrics)
21 |
22 | @property
23 | @abstractmethod
24 | def name(self) -> str:
25 | """The task name."""
26 | ...
27 |
28 | def _init_metrics(self, metrics):
29 | if not metrics:
30 | raise ValueError("metrics should not be empty")
31 | if isinstance(metrics, str):
32 | metrics = [metrics]
33 | return [self._parse_metric(m) for m in metrics]
34 |
35 | def _parse_metric(self, metric: Union[str, Type[Metric], Metric]):
36 | """
37 | Parse input metric in any form into a :class:`Metric` instance.
38 |
39 | :param metric: Input metric in any form.
40 | :return: A :class:`Metric` instance
41 |
42 | """
43 |
44 | if isinstance(metric, str):
45 | metric = metric.lower() # ignore case
46 |
47 | # TODO: parse metrics in str form
48 | """
49 | for subclass in Metric.__subclasses__():
50 | if metric in subclass.ALIAS:
51 | return subclass()
52 | """
53 |
54 | elif isinstance(metric, Metric):
55 | return metric
56 | elif issubclass(metric, Metric):
57 | return metric()
58 | else:
59 | raise ValueError(metric)
60 |
61 | def evaluate(self, testset) -> Dataset:
62 |         """Evaluate each metric."""
63 |
64 | self._validate_columns(testset)
65 | for m in self.metrics:
66 | res, de_res = m.compute(testset)
67 | self.result[m.name] = res
68 | self.detailed_result.append(de_res)
69 | return self.result
70 |
71 | def _validate_columns(self, testset: Dataset):
72 |         """Make sure the required columns are all present in the testset."""
73 |
74 | if not set(self.required_columns).issubset(set(testset.column_names)):
75 |             print("Testset should contain the following columns: ", ', '.join(self.required_columns))
76 | raise ValueError(testset)
77 |
78 | def obtain_detailed_result(self):
79 | """Obtain instance level result for the test case."""
80 |
81 | if not self.detailed_result:
82 | raise NameError(self.detailed_result)
83 | colnames = self.required_columns + [m.name for m in self.metrics]
84 | self.detailed_result = concatenate_datasets(self.detailed_result).select_columns(colnames)
85 | return self.detailed_result
86 |
--------------------------------------------------------------------------------
/rageval/utils/RAGAS_prompt.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """All prompts used for RAGAS related metrics."""
3 |
4 | # Most of the prompt in this file comes from the RAGAS project:
5 | # https://github.com/explodinggradients/ragas
6 |
7 | CONTEXT_RECALL_RA = """
8 | Given a context, and an answer, analyze each sentence in the answer and classify if the sentence can be attributed to the given context or not. Use only "Yes" (1) or "No" (0) as a binary classification. Output json with reason.
9 |
10 |
11 | question: What can you tell me about albert Albert Einstein?
12 | context: Albert Einstein (14 March 1879 – 18 April 1955) was a German-born theoretical physicist,widely held to be one of the greatest and most influential scientists of all time. Best known for developing the theory of relativity, he also made important contributions to quantum mechanics, and was thus a central figure in the revolutionary reshaping of the scientific understanding of nature that modern physics accomplished in the first decades of the twentieth century. His mass–energy equivalence formula E = mc2, which arises from relativity theory, has been called "the world's most famous equation". He received the 1921 Nobel Prize in Physics "for his services to theoretical physics, and especially for his discovery of the law of the photoelectric effect", a pivotal step in the development of quantum theory. His work is also known for its influence on the philosophy of science. In a 1999 poll of 130 leading physicists worldwide by the British journal Physics World, Einstein was ranked the greatest physicist of all time. His intellectual achievements and originality have made Einstein synonymous with genius.
13 | answer: Albert Einstein born in 14 March 1879 was German-born theoretical physicist, widely held to be one of the greatest and most influential scientists of all time. He received the 1921 Nobel Prize in Physics "for his services to theoretical physics. He published 4 papers in 1905. Einstein moved to Switzerland in 1895
14 | classification:
15 | [
16 | {{ "statement_1":"Albert Einstein, born on 14 March 1879, was a German-born theoretical physicist, widely held to be one of the greatest and most influential scientists of all time.",
17 | "reason": "The date of birth of Einstein is mentioned clearly in the context.",
18 | "Attributed": "1"
19 | }},
20 | {{
21 | "statement_2":"He received the 1921 Nobel Prize in Physics 'for his services to theoretical physics.",
22 | "reason": "The exact sentence is present in the given context.",
23 | "Attributed": "1"
24 | }},
25 | {{
26 | "statement_3": "He published 4 papers in 1905.",
27 | "reason": "There is no mention about papers he wrote in the given context.",
28 | "Attributed": "0"
29 | }},
30 | {{
31 | "statement_4":"Einstein moved to Switzerland in 1895.",
32 | "reason": "There is no supporting evidence for this in the given context.",
33 | "Attributed": "0"
34 | }}
35 | ]
36 |
37 | question: who won 2020 icc world cup?
38 | context: Who won the 2022 ICC Men's T20 World Cup?
39 | The 2022 ICC Men's T20 World Cup, held from October 16 to November 13, 2022, in Australia, was the eighth edition of the tournament. Originally scheduled for 2020, it was postponed due to the COVID-19 pandemic. England emerged victorious, defeating Pakistan by five wickets in the final to clinch their second ICC Men's T20 World Cup title.
40 | answer: England
41 | classification:
42 | [
43 | {{
44 | "statement_1":"England won the 2022 ICC Men's T20 World Cup.",
45 | "reason": "From context it is clear that England defeated Pakistan to win the World Cup.",
46 | "Attributed": "1"
47 | }}
48 | ]
49 |
50 | question:{question}
51 | context:{context}
52 | answer:{answer}
53 | classification:
54 | """
55 |
--------------------------------------------------------------------------------
/rageval/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .check_utils import text_to_sents, remove_citations
2 | from .RAGAS_prompt import CONTEXT_RECALL_RA
3 |
--------------------------------------------------------------------------------
/rageval/utils/check_utils.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import re
3 | from typing import List
4 |
5 | import nltk
6 | from nltk.downloader import Downloader
7 |
8 | from rageval.models import OpenAILLM
9 | from .prompt import DOC_TO_SENTENCES_PROMPT
10 |
11 | logger = logging.getLogger(__name__)
12 | if not Downloader().is_installed('punkt_tab'):
13 | nltk.download('punkt_tab')
14 |
15 |
16 | def text_to_sents(text: str, model_name="nltk") -> List[str]:
17 | """Convert the text into a set of sentences."""
18 | sentences = []
19 | if model_name == "nltk":
20 | sentences = nltk.sent_tokenize(text)
21 | sentences = [s.strip() for s in sentences if len(s.strip()) >= 3]
22 |
23 | elif model_name == "gpt-3.5-turbo":
24 | model = OpenAILLM("gpt-3.5-turbo", "OPENAI_API_KEY") # pragma: no cover
25 | prompt = DOC_TO_SENTENCES_PROMPT # pragma: no cover
26 | input_str = prompt.format(doc=text).strip() # pragma: no cover
27 | r = model.generate([input_str]) # pragma: no cover
28 | sentences = eval(r) # pragma: no cover
29 | else:
30 | logger.info("The parameter `model_name` should be in [`nltk`, `gpt-3.5-turbo`]. ") # pragma: no cover
31 | assert isinstance(sentences, list)
32 |
33 | return sentences
34 |
35 |
36 | def remove_citations(text: str) -> str:
37 | """Remove the citation in the text."""
38 | return re.sub(r"\[\d+", "", re.sub(r" \[\d+", "", text)).replace(" |", "").replace("]", "")
39 |
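Note: a quick check of the citation-stripping behaviour of remove_citations (the example sentence is made up):

from rageval.utils import remove_citations

print(remove_citations("Einstein won the Nobel Prize in 1921 [1][2], as noted above [3]."))
# -> "Einstein won the Nobel Prize in 1921, as noted above."
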
--------------------------------------------------------------------------------
/rageval/utils/utility.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import json
4 | import os
5 | import typing as t
6 | import warnings
7 | from dataclasses import dataclass
8 | from functools import lru_cache
9 |
10 | from langchain.callbacks.manager import CallbackManager, trace_as_chain_group
11 | from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate
12 |
13 | if t.TYPE_CHECKING:
14 | from rageval.models import ragevalLLM
15 |
16 | DEBUG_ENV_VAR = "rageval_DEBUG"
17 | # constant to tell us that there is no key passed to the llm/embeddings
18 | NO_KEY = "no-key"
19 |
20 |
21 | @lru_cache(maxsize=1)
22 | def get_debug_mode() -> bool:
23 | """Get debug mode."""
24 | if os.environ.get(DEBUG_ENV_VAR, str(False)).lower() == "true":
25 | return True
26 | else:
27 | return False
28 |
29 |
30 | def load_as_json(text):
31 | """Validate and return given text as json."""
32 |
33 | try:
34 | return json.loads(text)
35 | except ValueError as e:
36 | warnings.warn(f"Invalid json: {e}")
37 |
38 | return {}
39 |
40 |
41 | JSON_PROMPT = HumanMessagePromptTemplate.from_template(
42 | """
43 |
44 | Rewrite the input into valid json
45 |
46 |
47 | Input:
48 | {{
49 | "name": "John Doe",
50 | "age": 30,
51 | "isStudent": false
52 | "address": {{
53 | "street": "123 Main St",
54 | "city": "Anytown",
55 | "state": "CA",
56 | }}
57 | "hobbies": ["reading", "swimming", "cycling"]
58 | }}
59 | Output:
60 | {{
61 | "name": "John Doe",
62 | "age": 30,
63 | "isStudent": false,
64 | "address": {{
65 | "street": "123 Main St",
66 | "city": "Anytown",
67 | "state": "CA"
68 | }},
69 | "hobbies": ["reading", "swimming", "cycling"]
70 | }}
71 |
72 |
73 | Input:
74 | {{
75 | "statement": "The Earth is also known as "Terra" "
76 | }}
77 | Output:
78 | {{
79 | "statement": "The Earth is also known as 'Terra'"
80 | }}
81 |
82 | Input:
83 | {input}
84 |
85 | Output:
86 | """
87 | )
88 |
89 |
90 | @dataclass
91 | class JsonLoader:
92 |     """Safely load JSON from LLM output, asking the LLM to repair invalid JSON when needed."""
93 |
94 | max_retries: int = 2
95 |
96 | def safe_load(self, text: str, llm: ragevalLLM):
97 | """Load json in safety mode."""
98 | retry = 0
99 | while retry <= self.max_retries:
100 | try:
101 | start, end = self._find_outermost_json(text)
102 | return json.loads(text[start:end])
103 | except ValueError:
104 | text = self._fix_to_json(text, llm)
105 | retry += 1
106 |
107 | return {}
108 |
109 | def _fix_to_json(
110 | self,
111 | text,
112 | llm,
113 | callbacks: t.Optional[CallbackManager] = None,
114 | callback_group_name: str = "batch",
115 | ):
116 | # TODO (executor)
117 | with trace_as_chain_group(
118 | callback_group_name, callback_manager=callbacks
119 | ) as batch_group:
120 | human_prompt = ChatPromptTemplate.from_messages(
121 | [JSON_PROMPT.format(input=text)]
122 | )
123 | results = llm.generate(
124 | [human_prompt],
125 | n=1,
126 | callbacks=batch_group,
127 | )
128 | return results.generations[0][0].text
129 |
130 | def _find_outermost_json(self, text):
131 | stack = []
132 | start_index = -1
133 |
134 | for i, char in enumerate(text):
135 | if char in "{[":
136 | if len(stack) == 0:
137 | start_index = i
138 | stack.append(char)
139 |
140 | elif char in "}]":
141 | if len(stack) > 0:
142 | last = stack.pop()
143 | if (char == "}" and last != "{") or (char == "]" and last != "["):
144 | # Mismatched closing brace/bracket, invalid JSON
145 | break
146 |
147 | if len(stack) == 0 and start_index != -1:
148 | # Found a valid outermost JSON
149 | return (
150 | start_index,
151 | i + 1,
152 | ) # Add 1 to include the closing brace/bracket in the range
153 |
154 | return -1, -1 # No valid JSON found
155 |
156 |
157 | json_loader = JsonLoader()
158 |
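Note: an illustration of the outermost-JSON extraction used by safe_load; when the text already contains valid JSON, no LLM round-trip is needed (the surrounding text is made up):

from rageval.utils.utility import json_loader

text = 'Model output: {"statement": "The Earth is also known as Terra"} -- end of output'
start, end = json_loader._find_outermost_json(text)
print(text[start:end])  # {"statement": "The Earth is also known as Terra"}
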
--------------------------------------------------------------------------------
/rageval/validation.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gomate-community/rageval/01e2580efe714b3130602e02a7e7b017d16d79a2/rageval/validation.py
--------------------------------------------------------------------------------
/rageval/version.py:
--------------------------------------------------------------------------------
1 | """Rageval version file."""
2 |
3 | __version__ = '0.0.1'
4 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | refchecker == 0.2.13
2 | numpy >= 1.26
3 | tqdm >= 4.66
4 | hyperopt >= 0.1.1
5 | h5py >= 2.8.0
6 | coverage >= 4.3.4
7 | codecov >= 2.0.15
8 | pytest >= 3.7.4
9 | pytest-cov >= 2.4.0
10 | flake8 >= 7.0.0
11 | flake8_docstrings >= 1.7.0
12 | pydocstyle >= 6.1
13 | openai >= 1.10.0
14 | datasets >= 3.0.1
15 | langchain >= 0.3.1
16 | langchain-community >= 0.3.1
17 | transformers >= 4.37.2
18 | torch >= 2.2.0
19 | pandas >= 2.0.0
20 | nltk >= 3.9.1
21 | spacy >= 3.7.4
22 | en_core_web_sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.7.1/en_core_web_sm-3.7.1-py3-none-any.whl
23 | rouge_score >= 0.1.2
24 | accelerate >= 0.27.2
25 | sentencepiece >= 0.2.0
26 | protobuf >= 4.25.3
27 | sacrebleu >= 2.3.3
28 | bert_score >= 0.3.13
29 | jieba >= 0.42.1
30 | evaluate >= 0.4.3
31 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import io
2 | import os
3 |
4 | from setuptools import setup, find_packages
5 |
6 |
7 | here = os.path.abspath(os.path.dirname(__file__))
8 |
9 | # Avoids IDE errors, but actual version is read from version.py
10 | __version__ = None
11 | exec(open('rageval/version.py').read())
12 |
13 | short_description = 'Evaluation tools for Retrieval-augmented Generation (RAG) methods.'
14 |
15 | # Get the long description from the README file
16 | with io.open(os.path.join(here, 'README.md'), encoding='utf-8') as f:
17 | long_description = f.read()
18 |
19 | install_requires = [
20 | 'numpy >= 1.14',
21 | 'tqdm >= 4.23.4',
22 | 'hyperopt >= 0.1.1',
23 | 'h5py >= 2.8.0',
24 | 'openai == 1.10.0',
25 | 'datasets == 2.16.1',
26 | 'langchain == 0.1.4',
27 | 'transformers == 4.37.2',
28 | 'torch == 2.2.0',
29 | 'pandas == 2.0.0',
30 | 'nltk == 3.8.1',
31 | 'spacy == 3.7.4',
32 | 'en_core_web_sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.7.1/en_core_web_sm-3.7.1-py3-none-any.whl',
33 | 'rouge_score == 0.1.2',
34 | 'sacrebleu == 2.3.3'
35 | ]
36 |
37 | extras_requires = {
38 | 'tests': [
39 | 'coverage >= 4.3.4',
40 | 'codecov >= 2.0.15',
41 | 'pytest >= 3.7.4',
42 | 'pytest-cov >= 2.4.0',
43 | 'flake8 == 7.0.0',
44 | 'pydocstyle == 6.1',
45 | 'flake8_docstrings >= 1.7.0'
46 | ],
47 | 'benchmarks': [
48 | 'accelerate == 0.27.2',
49 | 'sentencepiece == 0.2.0',
50 | 'protobuf == 4.25.3'
51 | ]
52 | }
53 |
54 |
55 | setup(
56 | name="RagEval",
57 | version=__version__,
58 | author="Wenshan Wang, Yixing Fan, etc.",
59 | author_email="wangwenshan@ict.ac.cn",
60 | description=short_description,
61 | license="Apache 2.0",
62 | keywords="RAG evaluation tools",
63 | url="https://github.com/gomate-community/rageval",
64 | packages=find_packages(),
65 | long_description=long_description,
66 | long_description_content_type='text/markdown',
67 | classifiers=[
68 | "Development Status :: 3 - Alpha",
69 | 'Environment :: Console',
70 | 'Operating System :: POSIX :: Linux',
71 | 'Topic :: Scientific/Engineering :: Artificial Intelligence',
72 | "License :: OSI Approved :: Apache Software License",
73 | 'Programming Language :: Python :: 3.6'
74 | ],
75 | install_requires=install_requires,
76 | extras_require=extras_requires
77 | )
78 |
--------------------------------------------------------------------------------
/tests/demo.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # import sys
4 | # sys.path.insert(0, '../src')
5 |
6 | # from rageval.evaluations import evaluate
7 | """
8 | from .src.tasks import (
9 | retriever,
10 | )
11 | from .src.metrics import (
12 | context_recall,
13 | )
14 | """
15 |
16 |
17 | # from rageval.metrics import ContextRecall
18 | from datasets import Dataset
19 | # import os
20 |
21 |
22 | # Prepare your huggingface dataset in the following format
23 | # Dataset({
24 | #     features: ['question', 'contexts', 'answer', 'ground_truths'],
25 | #     num_rows: 25
26 | # })
27 | 
28 | # Mock data to be evaluated
29 | questions = ["恐龙是怎么被命名的?"]
30 | ground_truths = [
31 | [
32 | "1841年,英国科学家理查德·欧文在研究几块样子像蜥蜴骨头化石时,认为它们是某种史前动物留下来的,并命名为恐龙,意思是“恐怖的蜥蜴”。"
33 | ]
34 | ]
35 | answers = ["人从恐龙进化而来"]
36 | contexts = [
37 | [
38 | "[12]恐龙是 介于冷血和温血之间的动物2014年6月,有关恐龙究竟是像鸟类和哺乳动物一样的温血动物,还是类似爬行动物、鱼类和两栖动物的冷血动物的问题终于有了答案——恐龙其实是介于冷血和温血之间的动物。 "
39 | "[12]“我们的结果显示恐龙所具有的生长速率和新陈代谢速率,既不是冷血生物体也不是温血生物体所具有的特征。它们既不像哺乳动物或者鸟类,也不像爬行动物或者鱼类,而是介于现代冷血动物和温血动物之间。"
40 | "简言之,它们的生理机能在现代社会并不常见。”美国亚利桑那大学进化生物学家和生态学家布莱恩·恩奎斯特说。墨西哥生物学家表示,正是这种中等程度的新陈代谢使得恐龙可以长得比任何哺乳动物都要大。"
41 | "温血动物需要大量进食,因此它们频繁猎捕和咀嚼植物。“很难想象霸王龙大小的狮子能够吃饱以 存活下来。","[12]哺乳动物起源于爬行动物,它们的前身是“似哺乳类的爬行动物”,即兽孔目,早期则是“似爬行类的哺乳动物”,"
42 | "即哺乳型动物。 [12]中生代的爬行动物,大部分在中生代的末期灭绝了;一部分适应了变化的环境被保留下来,即现存的爬行动物(如龟鳖类、蛇类、鳄类等);还有一部分沿着不同的进化方向,进化成了现今的鸟类和哺乳类。 "
43 | "[12]恐龙是 介于冷血和温血之间的动物2014年6月,有关恐龙究竟是像鸟类和哺乳动物一样的温血动物,还是类似爬行动物、鱼类和两栖动物的冷血动物的问题终于有了答案——恐龙其实是介于冷血和温血之间的动物。"
44 | ]
45 | ]
46 |
47 | # To dict
48 | data = {
49 | "question": questions,
50 | "answer": answers,
51 | "contexts": contexts,
52 | "ground_truths": ground_truths
53 | }
54 |
55 | # Convert dict to dataset
56 | dataset = Dataset.from_dict(data)
57 |
58 | # dataset: Dataset
59 |
60 | # results = evaluate(dataset, task='retriever', metrics=['context_recall'])
61 | # results = evaluate(dataset, metrics=[ContextRecall()])
62 | # print(results)
63 |
--------------------------------------------------------------------------------
/tests/test_evaluation.py:
--------------------------------------------------------------------------------
1 | """Test the evaluation function."""
2 |
3 | import sys
4 |
5 | import pytest
6 | from datasets import load_dataset, Dataset
7 | from langchain.llms.fake import FakeListLLM
8 |
9 | import rageval as rl
10 | from rageval.models import NLIModel
11 |
12 | sys.path.insert(0, '../src')
13 |
14 |
15 | @pytest.mark.skip
16 | def test_evaluation():
17 | """
18 |     This is a unit test for loading a dataset and running the evaluation workflow.
19 | """
20 |
21 | # 1) init test task: task type, metrics
22 |
23 | # 2) load dataset, and extract testset
24 | # train_data = rageval.datasets.load_data('', task='')
25 | # assert len(train_data) == 300
26 |
27 | # 3) run evaluation
28 | # result = evaluate(testset, info)
29 |
30 | ds = load_dataset("explodinggradients/fiqa", "ragas_eval")["baseline"]
31 | ds = ds.rename_column("question", "questions")
32 | ds = ds.rename_column("answer", "answers")
33 | ds = ds.rename_column("ground_truths", "gt_answers")
34 |
35 |     # truncate ground-truth answers to 100 characters, since the tiny NLI model has a limited maximum sequence length
36 | def truncate_answer(example):
37 | max_length = 100
38 | answers = []
39 | gt_answers = []
40 | """
41 | for a in example["answers"]:
42 | answers.append([c[:max_length] if len(c) > max_length else c for c in a])
43 | example["answers"] = answers
44 | """
45 | for ga in example["gt_answers"]:
46 | gt_answers.append([q[:max_length] if len(q) > max_length else q for q in ga])
47 | example["gt_answers"] = gt_answers
48 | return example
49 | ds = ds.map(truncate_answer, batched=True)
50 |
51 | # define model for each metric
52 | cr_model = FakeListLLM(
53 | responses=[
54 | '[\n {\n "statement_1":"恐龙的命名始于1841年,由英国科学家理查德·欧文命名。",\n "reason": "The answer provides '
55 | 'the exact year and the scientist who named the dinosaurs.",\n "Attributed": "1"\n },\n {\n'
56 | ' "statement_2":"欧文在研究几块样子像蜥蜴骨头化石时,认为它们是某种史前动物留下来的,并命名为恐龙。",\n "reason": "The answer '
57 | 'accurately describes the process of how dinosaurs were named.",\n "Attributed": "1"\n }\n]'
58 | ]
59 | )
60 | ag_model = NLIModel(
61 | 'text2text-generation',
62 | 'hf-internal-testing/tiny-random-T5ForConditionalGeneration'
63 | )
64 |
65 | # define metrics
66 | metrics = [
67 | rl.metrics.ContextRecall(cr_model),
68 | rl.metrics.AnswerNLICorrectness(nli_model=ag_model, decompose_model="nltk")
69 | ]
70 |
71 | # define task
72 | task = rl.tasks.Generate(metrics=metrics)
73 |
74 | # run evaluate
75 | result = task.evaluate(ds)
76 | assert isinstance(result, dict)
77 | detailed_result = task.obtain_detailed_result()
78 | assert isinstance(detailed_result, Dataset)
79 |
--------------------------------------------------------------------------------
/tests/units/test_answer_accuracy.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from datasets import Dataset
3 |
4 | from rageval.metrics import AnswerAccuracy
5 |
6 |
7 | @pytest.fixture(scope='module')
8 | def sample():
9 | test_case = {
10 | "answers": [
11 | "A",
12 | "B",
13 | "C"
14 | ],
15 | "gt_answers": [
16 | "A",
17 | "C",
18 | "C"
19 | ]
20 | }
21 | return test_case
22 |
23 |
24 | @pytest.fixture(scope='module')
25 | def testset(sample):
26 | ds = Dataset.from_dict(sample)
27 | return ds
28 |
29 |
30 | @pytest.mark.slow
31 | def test_case_on_answer_accuracy(testset):
32 | metric = AnswerAccuracy()
33 | assert metric.name == "answer_accuracy"
34 | assert metric.mtype == 'AnswerCorrectness'
35 | assert repr(metric) == "answer_accuracy"
36 | score, results = metric.compute(testset["answers"], testset["gt_answers"], 1)
37 | assert score == 2 / 3
38 | assert results[0] is True
39 |
--------------------------------------------------------------------------------
/tests/units/test_answer_bert_score.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from datasets import Dataset
3 |
4 | from rageval.metrics import AnswerBERTScore
5 |
6 |
7 | @pytest.fixture(scope='module')
8 | def sample():
9 | test_case = {
10 | "answers": [
11 | "It is a guide to action which ensures that the military always obeys the commands of the party.",
12 | "It is to insure the troops forever hearing the activity guidebook that party direct."
13 | ],
14 | "gt_answers": [
15 | [
16 | "It is a guide to action that ensures that the military will forever heed Party commands.",
17 | "It is the guiding principle which guarantees the military forces always being under the command of the Party.",
18 | "It is the practical guide for the army always to heed the directions of the party."
19 | ],
20 | [
21 | "It is a guide to action that ensures that the military will forever heed Party commands.",
22 | "It is the guiding principle which guarantees the military forces always being under the command of the Party.",
23 | "It is the practical guide for the army always to heed the directions of the party."
24 | ]
25 | ]
26 | }
27 | return test_case
28 |
29 |
30 | @pytest.fixture(scope='module')
31 | def testset(sample):
32 | ds = Dataset.from_dict(sample)
33 | return ds
34 |
35 |
36 | @pytest.mark.slow
37 | def test_case_on_answer_bert_score(testset):
38 | metric = AnswerBERTScore(lang='en', rescale_with_baseline=True)
39 | assert metric.name == "answer_bert_score"
40 | assert metric.mtype == 'AnswerCorrectness'
41 | assert repr(metric) == "answer_bert_score"
42 | score, results = metric.compute(testset['answers'], testset['gt_answers'], 1)
43 | assert round(score, 2) == 0.55
44 | assert round(results[0], 1) == 0.7
45 |
--------------------------------------------------------------------------------
/tests/units/test_answer_bleu.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from datasets import Dataset
3 |
4 | from rageval.metrics import AnswerBleuScore
5 |
6 |
7 | @pytest.fixture(scope='module')
8 | def sample():
9 | test_case = {
10 | "answers": [
11 | "It is a guide to action which ensures that the military always obeys the commands of the party.",
12 | "It is to insure the troops forever hearing the activity guidebook that party direct."
13 | ],
14 | "gt_answers": [
15 | [
16 | "It is a guide to action that ensures that the military will forever heed Party commands.",
17 | "It is the guiding principle which guarantees the military forces always being under the command of the Party.",
18 | "It is the practical guide for the army always to heed the directions of the party."
19 | ],
20 | [
21 | "It is a guide to action that ensures that the military will forever heed Party commands.",
22 | "It is the guiding principle which guarantees the military forces always being under the command of the Party.",
23 | "It is the practical guide for the army always to heed the directions of the party."
24 | ]
25 | ]
26 | }
27 | return test_case
28 |
29 |
30 | @pytest.fixture(scope='module')
31 | def testset(sample):
32 | ds = Dataset.from_dict(sample)
33 | return ds
34 |
35 |
36 | @pytest.mark.slow
37 | def test_case_on_answer_bleu(testset):
38 | metric = AnswerBleuScore()
39 | assert metric.name == "answer_bleu"
40 | assert metric.mtype == 'AnswerCorrectness'
41 | assert repr(metric) == "answer_bleu"
42 | score, results = metric.compute(testset['answers'], testset['gt_answers'], 1)
43 | assert score == 0.3450835085970013
44 | assert results[0] == 0.5401725898595141
45 |
--------------------------------------------------------------------------------
/tests/units/test_answer_chrf.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from datasets import Dataset
3 |
4 | from rageval.metrics import AnswerCHRFCorrectness
5 |
6 |
7 | @pytest.fixture(scope='module')
8 | def sample():
9 | test_case = {
10 | "answers": [
11 | "The relationship between cats and dogs is not exactly friendly.",
12 | "a good bookshop is just a genteel black hole that knows how to read."
13 | ],
14 | "gt_answers": [
15 | ["The relationship between dogs and cats is not exactly friendly.", ],
16 | ["A good bookshop is just a genteel Black Hole that knows how to read."]
17 | ]
18 | }
19 | return test_case
20 |
21 |
22 | @pytest.fixture(scope='module')
23 | def testset(sample):
24 | ds = Dataset.from_dict(sample)
25 | return ds
26 |
27 |
28 | @pytest.mark.slow
29 | def test_case_on_answer_chrf(testset):
30 | metric = AnswerCHRFCorrectness()
31 | assert metric.name == "answer_chrf"
32 | assert metric.mtype == 'AnswerCorrectness'
33 | assert repr(metric) == "answer_chrf"
34 | score, results = metric.compute(testset['answers'], testset['gt_answers'], 1)
35 | assert score == 84.64214891738334
36 | assert results[0] == 84.41131092011067
37 |
--------------------------------------------------------------------------------
/tests/units/test_answer_citation_precision.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from datasets import Dataset
3 |
4 | from rageval.models import NLIModel
5 | from rageval.metrics import AnswerCitationPrecision
6 |
7 |
8 | @pytest.fixture(scope='module')
9 | def sample():
10 | test_case = {
11 | "answers": [
12 | "Several places on Earth claim to be the most rainy, such as Lloró, Colombia, which reported an average "
13 | "annual rainfall of 12,717 mm between 1952 and 1989, and López de Micay, Colombia, which reported an "
14 | "annual 12,892 mm between 1960 and 2012 [3]. However, the official record is held by Mawsynram, India "
15 | "with an average annual rainfall of 11,872 mm [3], although nearby town Sohra, India, also known as "
16 | "Cherrapunji, holds the record for most rain in a calendar month for July 1861 and most rain in a year "
17 | "from August 1860 to July 1861 [1]."
18 | ],
19 | "contexts": [
20 | [
21 | "Cherrapunji Cherrapunji (; with the native name Sohra being more commonly used, and can also be "
22 | "spelled Cherrapunjee or Cherrapunji) is a subdivisional town in the East Khasi Hills district in "
23 | "the Indian state of Meghalaya. It is the traditional capital of aNongkhlaw \"hima\" (Khasi tribal "
24 | "chieftainship constituting a petty state), both known as Sohra or Churra. Cherrapunji has often been "
25 | "credited as being the wettest place on Earth, but for now nearby Mawsynram currently holds that "
26 | "distinction. Cherrapunji still holds the all-time record for the most rainfall in a calendar month "
27 | "for July 1861 and most rain in a year from August 1860 to July 1861, however: it received in",
28 | "Radio relay station known as Akashvani Cherrapunji. It broadcasts on FM frequencies. Cherrapunji "
29 | "Cherrapunji (; with the native name Sohra being more commonly used, and can also be spelled "
30 | "Cherrapunjee or Cherrapunji) is a subdivisional town in the East Khasi Hills district in the Indian "
31 | "state of Meghalaya. It is the traditional capital of aNongkhlaw \"hima\" (Khasi tribal chieftainship "
32 | "constituting a petty state), both known as Sohra or Churra. Cherrapunji has often been credited as "
33 | "being the wettest place on Earth, but for now nearby Mawsynram currently holds that distinction. "
34 | "Cherrapunji still holds the all-time record for the most rainfall",
35 | "Mawsynram Mawsynram () is a village in the East Khasi Hills district of Meghalaya state in "
36 | "north-eastern India, 65 kilometres from Shillong. Mawsynram receives one of the highest rainfalls "
37 | "in India. It is reportedly the wettest place on Earth, with an average annual rainfall of 11,872 mm, "
38 | "but that claim is disputed by Lloró, Colombia, which reported an average yearly rainfall of 12,717 mm "
39 | "between 1952 and 1989 and López de Micay, also in Colombia, which reported an annual 12,892 mm per "
40 | "year between 1960 and 2012. According to the \"Guinness Book of World Records\", Mawsynram received "
41 | "of rainfall in 1985. Mawsynram is located at 25° 18′"
42 | ]
43 | ]
44 | }
45 | return test_case
46 |
47 |
48 | @pytest.fixture(scope='module')
49 | def testset(sample):
50 | ds = Dataset.from_dict(sample)
51 | return ds
52 |
53 |
54 | @pytest.mark.slow
55 | def test_answer_citation_precision(testset):
56 | nli_model = NLIModel(
57 | 'text2text-generation',
58 | 'hf-internal-testing/tiny-random-T5ForConditionalGeneration'
59 | )
60 | metric = AnswerCitationPrecision(nli_model=nli_model)
61 | assert metric.name == "answer_citation_precision"
62 | assert metric.mtype == 'AnswerGroundedness'
63 | assert repr(metric) == "answer_citation_precision"
64 | score, results = metric.compute(testset['answers'], testset['contexts'], 1)
65 | assert 0 <= score <= 1
66 |
--------------------------------------------------------------------------------
/tests/units/test_answer_citation_recall.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from datasets import Dataset
3 |
4 | from rageval.models import NLIModel
5 | from rageval.metrics import AnswerCitationRecall
6 |
7 |
8 | @pytest.fixture(scope='module')
9 | def sample():
10 | test_case = {
11 | "answers": [
12 | "Several places on Earth claim to be the most rainy, such as Lloró, Colombia, which reported an average "
13 | "annual rainfall of 12,717 mm between 1952 and 1989, and López de Micay, Colombia, which reported an "
14 | "annual 12,892 mm between 1960 and 2012 [3]. However, the official record is held by Mawsynram, India "
15 | "with an average annual rainfall of 11,872 mm [3], although nearby town Sohra, India, also known as "
16 | "Cherrapunji, holds the record for most rain in a calendar month for July 1861 and most rain in a year "
17 | "from August 1860 to July 1861 [1]."
18 | ],
19 | "contexts": [
20 | [
21 | "Cherrapunji Cherrapunji (; with the native name Sohra being more commonly used, and can also be "
22 | "spelled Cherrapunjee or Cherrapunji) is a subdivisional town in the East Khasi Hills district in "
23 | "the Indian state of Meghalaya. It is the traditional capital of aNongkhlaw \"hima\" (Khasi tribal "
24 | "chieftainship constituting a petty state), both known as Sohra or Churra. Cherrapunji has often been "
25 | "credited as being the wettest place on Earth, but for now nearby Mawsynram currently holds that "
26 | "distinction. Cherrapunji still holds the all-time record for the most rainfall in a calendar month "
27 | "for July 1861 and most rain in a year from August 1860 to July 1861, however: it received in",
28 | "Radio relay station known as Akashvani Cherrapunji. It broadcasts on FM frequencies. Cherrapunji "
29 | "Cherrapunji (; with the native name Sohra being more commonly used, and can also be spelled "
30 | "Cherrapunjee or Cherrapunji) is a subdivisional town in the East Khasi Hills district in the Indian "
31 | "state of Meghalaya. It is the traditional capital of aNongkhlaw \"hima\" (Khasi tribal chieftainship "
32 | "constituting a petty state), both known as Sohra or Churra. Cherrapunji has often been credited as "
33 | "being the wettest place on Earth, but for now nearby Mawsynram currently holds that distinction. "
34 | "Cherrapunji still holds the all-time record for the most rainfall",
35 | "Mawsynram Mawsynram () is a village in the East Khasi Hills district of Meghalaya state in "
36 | "north-eastern India, 65 kilometres from Shillong. Mawsynram receives one of the highest rainfalls "
37 | "in India. It is reportedly the wettest place on Earth, with an average annual rainfall of 11,872 mm, "
38 | "but that claim is disputed by Lloró, Colombia, which reported an average yearly rainfall of 12,717 mm "
39 | "between 1952 and 1989 and López de Micay, also in Colombia, which reported an annual 12,892 mm per "
40 | "year between 1960 and 2012. According to the \"Guinness Book of World Records\", Mawsynram received "
41 | "of rainfall in 1985. Mawsynram is located at 25° 18′"
42 | ]
43 | ]
44 | }
45 | return test_case
46 |
47 |
48 | @pytest.fixture(scope='module')
49 | def testset(sample):
50 | ds = Dataset.from_dict(sample)
51 | return ds
52 |
53 |
54 | @pytest.mark.slow
55 | def test_answer_citation_recall(testset):
56 | nli_model = NLIModel(
57 | 'text2text-generation',
58 | 'hf-internal-testing/tiny-random-T5ForConditionalGeneration'
59 | )
60 | metric = AnswerCitationRecall(nli_model=nli_model)
61 | assert metric.name == "answer_citation_recall"
62 | assert metric.mtype == 'AnswerGroundedness'
63 | score, results = metric.compute(testset['answers'], testset['contexts'], 1)
64 | assert 0 <= score <= 1
65 |
--------------------------------------------------------------------------------
/tests/units/test_answer_claim_recall.py:
--------------------------------------------------------------------------------
1 | """Test the AnswerClaimRecall Metric."""
2 |
3 | import pytest
4 | from datasets import Dataset
5 |
6 | from rageval.models import NLIModel
7 | from rageval.metrics import AnswerNLICorrectness
8 |
9 |
10 | @pytest.fixture(scope='module')
11 | def sample():
12 | test_case = {
13 | "answers": [
14 | "Yes. Did you watch The Social Network? They went a while before introducing ads, so they could make "
15 | "money, as they needed to establish their brand and amass users. Once you have dedicated users, "
16 | "introducing ads won't deter most, but if you are still new, having ads will deter a lot. The same goes "
17 | "for Uber, it's not that they aren't making money, it's that they are reinvesting a ton of it to make "
18 | "their service better."
19 | ],
20 | "gt_answers": [
21 | [
22 | "Firms like Snapchat and Uber need to establish their brand and amass users before introducing ads.",
23 | "Introducing ads too early can deter potential users.",
24 | "Uber is reinvesting a lot of money to make their service better."
25 | ]
26 | ]
27 | }
28 | return test_case
29 |
30 |
31 | @pytest.fixture(scope='module')
32 | def sample_with_decompose():
33 | test_case = {
34 | "answers": [
35 | "Yes. Did you watch The Social Network? They went a while before introducing ads, so they could make \
36 | money, as they needed to establish their brand and amass users. Once you have dedicated users, \
37 | introducing ads won't deter most, but if you are still new, having ads will deter a lot. The same goes \
38 | for Uber, it's not that they aren't making money, it's that they are reinvesting a ton of it to make \
39 | their service better."
40 | ],
41 | "gt_answers": [
42 | "Firms like Snapchat and Uber need to establish their brand and amass users before introducing ads. \
43 | Introducing ads too early can deter potential users. Uber is reinvesting a lot of money to make their \
44 | service better."
45 | ]
46 | }
47 | return test_case
48 |
49 |
50 | @pytest.fixture(scope='module')
51 | def testset(sample):
52 | ds = Dataset.from_dict(sample)
53 | return ds
54 |
55 |
56 | @pytest.fixture(scope='module')
57 | def testset_with_decompose(sample_with_decompose):
58 | ds = Dataset.from_dict(sample_with_decompose)
59 | return ds
60 |
61 |
62 | @pytest.mark.slow
63 | def test_case_on_answer_claim_recall_metric(testset):
64 | nli_model = NLIModel(
65 | 'text2text-generation',
66 | 'hf-internal-testing/tiny-random-T5ForConditionalGeneration'
67 | )
68 | metric = AnswerNLICorrectness(nli_model=nli_model)
69 | assert metric.name == "answer_claim_recall"
70 | assert metric.mtype == 'AnswerCorrectness'
71 | score, results = metric.compute(testset['answers'], testset['gt_answers'], 1)
72 | assert score == 0 or score == 1
73 |
74 |
75 | @pytest.mark.slow
76 | def test_case_on_answer_claim_recall_metric_with_decompose(testset_with_decompose):
77 | nli_model = NLIModel(
78 | 'text2text-generation',
79 | 'hf-internal-testing/tiny-random-T5ForConditionalGeneration'
80 | )
81 | metric = AnswerNLICorrectness(nli_model=nli_model, decompose_model="nltk")
82 | assert metric.name == "answer_claim_recall"
83 | assert metric.mtype == 'AnswerCorrectness'
84 | score, results = metric.compute(testset_with_decompose['answers'], testset_with_decompose['gt_answers'], 1)
85 | assert score == 0 or score == 1
86 |
--------------------------------------------------------------------------------
/tests/units/test_answer_disambig_f1.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from datasets import Dataset
3 |
4 | from rageval.metrics import AnswerDisambigF1Correctness
5 |
6 |
7 | @pytest.fixture(scope='module')
8 | def sample():
9 | test_case = {
10 | "answers": [
11 | "Ali Daei holds the record for the most international goals in world football according to FIFA. Josef Bican holds the record for the most total goals in world football according to UEFA.",
12 | ],
13 | "gt_answers": [
14 | ["Ali Dael has the highest goals in men's world international football with 109 goals. Josef Bican has the highest goals all-time in men's football and Christine Sinclair has the highest goals in women's world international football.",
15 | "The players with the highest all-time goals and highest men's and women's international football goals differ. The player with the highest all-time men's football goals is Josef Bican, who in 2020 was recognized by FIFA, the international governing body of football, as the record scorer with an estimated 805 goals. Christine Sinclair has the highest goals in women's international football with 187 and is the all-time leader for international goals scored for men or women. Cristiano Ronaldo and Ali Daei are currently tied for leading goalscorer in the history of men's international football with 109."],
16 | ]
17 | }
18 | return test_case
19 |
20 |
21 | @pytest.fixture(scope='module')
22 | def testset(sample):
23 | ds = Dataset.from_dict(sample)
24 | return ds
25 |
26 |
27 | @pytest.mark.slow
28 | def test_case_on_answer_disambig_f1(testset):
29 | metric = AnswerDisambigF1Correctness()
30 | assert metric.name == "answer_disambig_f1"
31 | assert metric.mtype == 'AnswerCorrectness'
32 | score, results = metric.compute(testset['answers'], testset['gt_answers'], 1)
33 | assert 0 <= score <= 1
34 |
--------------------------------------------------------------------------------
/tests/units/test_answer_distinct.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from datasets import Dataset
3 |
4 | from rageval.metrics import AnswerDistinct
5 |
6 |
7 | @pytest.fixture(scope='module')
8 | def sample():
9 | test_case = {
10 | "answers": [
11 | "Ali Dael has the highest goals in men's world international football with 109 goals. Josef Bican has the \
12 | highest goals all-time in men's football and Christine Sinclair has the highest goals in women's world \
13 | international football.",
14 | "A supercentenarian is someone who has reached the age of 110. Sarah Knauss, whose age is undisputed, was \
15 | the oldest person ever from the United States and the second-oldest fully documented person ever. Jeanne \
16 | Calment was a French supercentenarian and the oldest human whose age is well-documented, with a lifespan \
17 | of 122 years and 164 days, and was the oldest person in the world as of 1997."
18 | ]
19 | }
20 | return test_case
21 |
22 |
23 | @pytest.fixture(scope='module')
24 | def testset(sample):
25 | ds = Dataset.from_dict(sample)
26 | return ds
27 |
28 |
29 | @pytest.mark.slow
30 | def test_case_on_answer_distinct(testset):
31 | metric = AnswerDistinct(n_grams=1)
32 | assert metric.name == "answer_distinct"
33 | assert repr(metric) == 'answer_distinct'
34 | assert metric.mtype == 'AnswerInformativeness'
35 | score, results = metric.compute(pred_answers=testset['answers'])
36 | assert 0 <= score <= 1
37 |
38 | metric = AnswerDistinct(n_grams=2)
39 | assert metric.name == "answer_distinct"
40 | assert repr(metric) == 'answer_distinct'
41 | assert metric.mtype == 'AnswerInformativeness'
42 | score, results = metric.compute(pred_answers=testset['answers'])
43 | assert 0 <= score <= 1
--------------------------------------------------------------------------------
/tests/units/test_answer_edit_distance.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from datasets import Dataset
3 |
4 | from rageval.metrics import AnswerEditDistance
5 |
6 |
7 | @pytest.fixture(scope='module')
8 | def sample():
9 | test_case = {
10 | "answers": [
11 | "Language models trained on massive code corpora can generalize to tasks without the need "
12 | "for task-specific fine tuning."
13 | ],
14 | "gt_answers": [
15 | "Large language models trained on massive code corpora can generalize to new tasks without the need "
16 | "for task-specific fine-tuning."
17 | ]
18 | }
19 | return test_case
20 |
21 |
22 | @pytest.fixture(scope='module')
23 | def testset(sample):
24 | ds = Dataset.from_dict(sample)
25 | return ds
26 |
27 |
28 | @pytest.mark.slow
29 | def test_case_on_answer_edit_distance(testset):
30 | metric = AnswerEditDistance()
31 | assert metric.name == "answer_edit_distance"
32 | assert metric.mtype == 'AnswerCorrectness'
33 | score, results = metric.compute(testset['answers'], testset['gt_answers'], 1)
34 | assert score == 5 / 18
35 |
--------------------------------------------------------------------------------
/tests/units/test_answer_exect_match.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from datasets import Dataset
3 |
4 | from rageval.metrics import AnswerEMCorrectness
5 |
6 |
7 | @pytest.fixture(scope='module')
8 | def sample():
9 | test_case = {
10 | "answers": [
11 | "Ali Dael has the highest goals in men's world international football with 109 goals. Josef Bican has the "
12 | "highest goals all-time in men's football and Christine Sinclair has the highest goals in women's world "
13 | "international football.",
14 | "A supercentenarian is someone who has reached the age of 110. Sarah Knauss, whose age is undisputed, was "
15 | "the oldest person ever from the United States and the second-oldest fully documented person ever. Jeanne "
16 | "Calment was a French supercentenarian and the oldest human whose age is well-documented, with a lifespan "
17 | "of 122 years and 164 days, and was the oldest person in the world as of 1997. In 1985, the oldest living "
18 | "person was Mathew Beard and in 1986 it was Augusta Holtz, who lived 115 years and 79 days, from 1871 to "
19 | "1986."
20 | ],
21 | "gt_answers": [
22 | [
23 | ["Daei", "Ali Daei"],
24 | ["Bican", "Josef Bican"],
25 | ["Sinclair", "Christine Sinclair"]
26 | ],
27 | [
28 | ["Jeanne Calment"],
29 | ["Sarah Knauss"],
30 | ["Augusta-Holtz"],
31 | ]
32 | ]
33 | }
34 | return test_case
35 |
36 |
37 | @pytest.fixture(scope='module')
38 | def testset(sample):
39 | ds = Dataset.from_dict(sample)
40 | return ds
41 |
42 |
43 | @pytest.mark.slow
44 | def test_case_on_answer_exact_match(testset):
45 | metric = AnswerEMCorrectness()
46 | assert metric.name == "answer_exact_match"
47 | assert metric.mtype == 'AnswerCorrectness'
48 | score, results = metric.compute(testset['answers'], testset['gt_answers'], 1)
49 | assert 0 <= score <= 1
50 |
51 | metric = AnswerEMCorrectness(ignore_case=True)
52 | assert metric.name == "answer_exact_match"
53 | assert metric.mtype == 'AnswerCorrectness'
54 | score, results = metric.compute(testset['answers'], testset['gt_answers'], 1)
55 | assert 0 <= score <= 1
56 |
--------------------------------------------------------------------------------
/tests/units/test_answer_f1.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from datasets import Dataset
3 |
4 | from rageval.metrics import AnswerF1Correctness
5 |
6 |
7 | @pytest.fixture(scope='module')
8 | def sample():
9 | test_case = {
10 | "answers": [
11 | "Ali Dael has the highest goals in men's world international football with 109 goals. Josef Bican has the \
12 | highest goals all-time in men's football and Christine Sinclair has the highest goals in women's world \
13 | international football.",
14 | "A supercentenarian is someone who has reached the age of 110. Sarah Knauss, whose age is undisputed, was \
15 | the oldest person ever from the United States and the second-oldest fully documented person ever. Jeanne \
16 | Calment was a French supercentenarian and the oldest human whose age is well-documented, with a lifespan \
17 | of 122 years and 164 days, and was the oldest person in the world as of 1997."
18 | ],
19 | "gt_answers": [
20 | ["Daei", "Ali Daei"],
21 | ["Jeanne Calment"]
22 | ],
23 | "answers_zh": [
24 | "魏晋",
25 | "北齐只设于清都郡。",
26 | ],
27 | "gt_answers_zh": [
28 | ["魏晋", "魏晋时期"],
29 | ["北齐只设于清都郡。", "清都郡"]
30 | ],
31 | "answers_num": [[1, 2, 3], [4, 5, 6]],
32 | "gt_answers_num": [[2, 3, 4, 5, 6], [1, 2, 3, 4, 5]]
33 | }
34 | return test_case
35 |
36 | @pytest.fixture(scope='module')
37 | def testset(sample):
38 | ds = Dataset.from_dict(sample)
39 | return ds
40 |
41 | @pytest.mark.slow
42 | def test_case_on_answer_f1(testset):
43 | metric = AnswerF1Correctness(normalize=True, language='en')
44 | assert metric.name == "answer_f1"
45 | assert metric.mtype == 'AnswerCorrectness'
46 | score, results = metric.compute(testset['answers'], testset['gt_answers'])
47 | assert 0 <= score <= 1
48 |
49 | metric = AnswerF1Correctness(normalize=True, language='zh')
50 | assert metric.name == "answer_f1"
51 | assert metric.mtype == 'AnswerCorrectness'
52 | score_zh, results_zh = metric.compute(testset['answers_zh'], testset['gt_answers_zh'])
53 | assert 0 <= score_zh <= 1
54 |
55 | metric = AnswerF1Correctness(normalize=False)
56 | assert metric.name == "answer_f1"
57 | assert metric.mtype == 'AnswerCorrectness'
58 | score, results = metric.compute(testset['answers_num'], testset['gt_answers_num'])
59 | assert 0 <= score <= 1
60 |
--------------------------------------------------------------------------------
/tests/units/test_answer_lcs_ratio.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from datasets import Dataset
3 |
4 | from rageval.metrics import AnswerLCSRatio
5 |
6 |
7 | @pytest.fixture(scope='module')
8 | def sample():
9 | test_case = {
10 | "answers": [
11 | "Language models trained on massive code corpora can generalize to tasks without the need "
12 | "for task-specific fine-tuning."
13 | ],
14 | "gt_answers": [
15 | "Large language models trained on massive code corpora can generalize to new tasks without the need "
16 | "for task-specific fine-tuning."
17 | ]
18 | }
19 | return test_case
20 |
21 |
22 | @pytest.fixture(scope='module')
23 | def testset(sample):
24 | ds = Dataset.from_dict(sample)
25 | return ds
26 |
27 |
28 | @pytest.mark.slow
29 | def test_case_on_answer_lcs_ratio(testset):
30 | metric = AnswerLCSRatio()
31 | assert metric.name == "answer_lcs_ratio"
32 | assert metric.mtype == 'AnswerCorrectness'
33 | score, results = metric.compute(testset['answers'], testset['gt_answers'], 1)
34 | assert score == 16 / 17
35 |
--------------------------------------------------------------------------------
/tests/units/test_answer_rouge.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from datasets import Dataset
3 | from typing import List
4 | from rageval.metrics import AnswerRougeCorrectness
5 |
6 |
7 | class CharTokenizer:
8 | """Tokenize text into characters."""
9 | def tokenize(self, text: str) -> List[str]:
10 | # Tokenize by characters to avoid a dependency on word segmentation methods.
11 | return [c for c in text]
12 |
13 |
14 | @pytest.fixture(scope='module')
15 | def sample():
16 | test_case = {
17 | "answers": [
18 | "###刚刚发声,A股这种情况十分罕见!大聪明逆市抄底330亿,一篇研报引爆全球,市场逻辑生变?",
19 | "The quick brown fox jumps over the lazy dog."
20 | ],
21 | "gt_answers": [
22 | [
23 | "刚刚过去的这个月,美股总市值暴跌了将近6万亿美元(折合人民币超过40万亿),这背后的原因可能不仅仅是加息这么简单。最近瑞士信贷知名分析师Zoltan Polzsar撰写了一篇极其重要的文章,详细分析了现有世界秩序的崩坏本质以及美国和西方将要采取的应对策略。在该文中,Zoltan Polzsar直指美国通胀的本质和其长期性。同期,A股市场亦出现了大幅杀跌的情况。"
24 | ],
25 | [
26 | "The quick brown fox jumps over the lazy dog.",
27 | "The brown fox jumps over the lazy dog."
28 | ]
29 | ]
30 | }
31 | return test_case
32 |
33 |
34 | @pytest.fixture(scope='module')
35 | def testset(sample):
36 | ds = Dataset.from_dict(sample)
37 | return ds
38 |
39 |
40 | def test_case_on_answer_rouge(testset):
41 |
42 | # Test with Chinese tokenizer
43 | chinese_tokenizer = CharTokenizer()
44 | metric = AnswerRougeCorrectness('rouge1', chinese_tokenizer)
45 | score, results = metric.compute(testset['answers'], testset['gt_answers'], 1)
46 | assert metric.mtype == 'AnswerCorrectness'
47 | assert 0 <= score <= 1
48 |
49 | # Test with English tokenizer
50 | metric = AnswerRougeCorrectness('rouge1')
51 | score, results = metric.compute(testset['answers'], testset['gt_answers'], 1)
52 | assert metric.mtype == 'AnswerCorrectness'
53 | assert 0 <= score <= 1
54 |
--------------------------------------------------------------------------------
/tests/units/test_answer_ter.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from datasets import Dataset
3 |
4 | from rageval.metrics import AnswerTERCorrectness
5 |
6 |
7 | @pytest.fixture(scope='module')
8 | def sample():
9 | test_case = {
10 | "answers": [
11 | "does this sentence match??",
12 | "what about this sentence?",
13 | "What did the TER metric user say to the developer?"
14 | ],
15 | "gt_answers": [
16 | ["does this sentence match", "does this sentence match!?!"],
17 | ["wHaT aBoUt ThIs SeNtEnCe?", "wHaT aBoUt ThIs SeNtEnCe?"],
18 | ["Your jokes are...", "...TERrible"]
19 | ]
20 | }
21 | return test_case
22 |
23 |
24 | @pytest.fixture(scope='module')
25 | def testset(sample):
26 | ds = Dataset.from_dict(sample)
27 | return ds
28 |
29 |
30 | @pytest.mark.slow
31 | def test_case_on_answer_ter(testset):
32 | metric = AnswerTERCorrectness()
33 | assert metric.name == "answer_ter"
34 | assert metric.mtype == 'AnswerCorrectness'
35 | score, results = metric.compute(testset['answers'], testset['gt_answers'])
36 | assert score == 110.00000000000001
37 | assert results[0] == 25.0
38 |
39 |
--------------------------------------------------------------------------------
/tests/units/test_context_recall.py:
--------------------------------------------------------------------------------
1 | """Test the ContextRecall Metric."""
2 |
3 | # -*- coding: utf-8 -*-
4 |
5 | import pytest
6 | from datasets import Dataset
7 | from langchain.llms.fake import FakeListLLM
8 |
9 | from rageval.models.openai import OpenAILLM
10 | from rageval.metrics import ContextRecall
11 |
12 |
13 | @pytest.fixture(scope='module')
14 | def sample():
15 | test_case = {
16 | "questions": ["恐龙是怎么被命名的?"],
17 | "gt_answers": [
18 | [
19 | "1841年,英国科学家理查德·欧文在研究几块样子像蜥蜴骨头化石时,认为它们是某种史前动物留下来的,并命名为恐龙,意思是“恐怖的蜥蜴”。"
20 | ]
21 | ],
22 | "contexts": [
23 | [
24 | "[12]恐龙是 介于冷血和温血之间的动物2014年6月,有关恐龙究竟是像鸟类和哺乳动物一样的温血动物,还是类似爬行动物、鱼类和两栖动物的冷血动物的问题终于有了答案——恐龙其实是介于冷血和温血之间的动物。 "
25 | "[12]“我们的结果显示恐龙所具有的生长速率和新陈代谢速率,既不是冷血生物体也不是温血生物体所具有的特征。它们既不像哺乳动物或者鸟类,也不像爬行动物或者鱼类,而是介于现代冷血动物和温血动物之间。"
26 | "简言之,它们的生理机能在现代社会并不常见。”美国亚利桑那大学进化生物学家和生态学家布莱恩·恩奎斯特说。墨西哥生物学家表示,正是这种中等程度的新陈代谢使得恐龙可以长得比任何哺乳动物都要大。"
27 | "温血动物需要大量进食,因此它们频繁猎捕和咀嚼植物。“很难想象霸王龙大小的狮子能够吃饱以 存活下来。",
28 | "[12]哺乳动物起源于爬行动物,它们的前身是“似哺乳类的爬行动物”,即兽孔目,早期则是“似爬行类的哺乳动物”,即哺乳型动物。 [12]中生代的爬行动物,大部分在中生代的末期灭绝了;一部分适应了变化的环境"
29 | "被保留下来,即现存的爬行动物(如龟鳖类、蛇类、鳄类等);还有一部分沿着不同的进化方向,进化成了现今的鸟类和哺乳类。 [12]恐龙是 介于冷血和温血之间的动物2014年6月,有关恐龙究竟是像鸟类和哺乳动"
30 | "物一样的温血动物,还是类似爬行动物、鱼类和两栖动物的冷血动物的问题终于有了答案——恐龙其实是介于冷血和温血之间的动物。"
31 | ]
32 | ]
33 | }
34 | return test_case
35 |
36 |
37 | @pytest.fixture(scope='module')
38 | def testset(sample):
39 | ds = Dataset.from_dict(sample)
40 | return ds
41 |
42 |
43 | @pytest.mark.skip
44 | def test_batch_on_context_recall_metric(testset):
45 | model = OpenAILLM('gpt-3.5-turbo-16k', 'OPENAI_API_KEY')
46 | metric = ContextRecall(model)
47 | score, results = metric.compute(testset['questions'], testset['gt_answers'], testset['contexts'], 1)
48 | assert score == 0 or score == 1
49 |
50 |
51 | @pytest.mark.slow
52 | def test_batch_on_context_recall_metric_fakellm1(testset):
53 | model = FakeListLLM(
54 | responses=[
55 | '[\n {\n "statement_1":"恐龙的命名始于1841年,由英国科学家理查德·欧文命名。",\n "reason": "The answer provides '
56 | 'the exact year and the scientist who named the dinosaurs.",\n "Attributed": "1"\n },\n {\n'
57 | ' "statement_2":"欧文在研究几块样子像蜥蜴骨头化石时,认为它们是某种史前动物留下来的,并命名为恐龙。",\n "reason": "The answer '
58 | 'accurately describes the process of how dinosaurs were named.",\n "Attributed": "1"\n }\n]'
59 | ]
60 | )
61 | metric = ContextRecall(model)
62 | score, results = metric.compute(testset['questions'], testset['gt_answers'], testset['contexts'], 1)
63 | assert metric.mtype == 'ContextRelevancy'
64 | assert 0 <= score <= 1
65 |
66 |
67 | @pytest.mark.slow
68 | def test_batch_on_context_recall_metric_fakellm2(testset):
69 | model = FakeListLLM(responses=['wrong response format'])
70 | metric = ContextRecall(model)
71 | score, results = metric.compute(testset['questions'], testset['gt_answers'], testset['contexts'], 1)
72 | assert metric.mtype == 'ContextRelevancy'
73 | assert 0 <= score <= 1
74 |
--------------------------------------------------------------------------------
/tests/units/test_context_reject_rate.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from datasets import Dataset
3 | from langchain.llms.fake import FakeListLLM
4 |
5 | from rageval.metrics import ContextRejectRate
6 |
7 |
8 | @pytest.fixture(scope='module')
9 | def sample():
10 | test_case = {
11 | "questions": [
12 | "Why did Bushnell set himself on fire?",
13 | "Did Bushnell have a wife?"
14 | ],
15 | "contexts": [
16 | [
17 | [
18 | "An active-duty member of the U.S. Air Force has died after he set himself ablaze outside the "
19 | "Israeli Embassy in Washington, D.C., while declaring that he “will no longer be complicit in "
20 | "genocide.”"
21 | ],
22 | [
23 | "The 25-year-old airman, Aaron Bushnell, of San Antonio, Texas, died from his injuries, the "
24 | "Metropolitan Police Department said Monday."
25 | ],
26 | [
27 | "Bushnell had walked up to the embassy shortly before 1 p.m. Sunday and began livestreaming on "
28 | "the video streaming platform Twitch, a person familiar with the matter told The Associated "
29 | "Press. Law enforcement officials believe he set his phone down and then doused himself in "
30 | "accelerant and ignited the flames. At one point, he said he “will no longer be complicit in "
31 | "genocide,” the person said. The video was later removed from the platform, but law enforcement "
32 | "officials have obtained and reviewed a copy."
33 | ]
34 | ],
35 | [
36 | [
37 | "An active-duty member of the U.S. Air Force has died after he set himself ablaze outside the "
38 | "Israeli Embassy in Washington, D.C., while declaring that he “will no longer be complicit in "
39 | "genocide.”"
40 | ],
41 | [
42 | "The 25-year-old airman, Aaron Bushnell, of San Antonio, Texas, died from his injuries, the "
43 | "Metropolitan Police Department said Monday."
44 | ],
45 | [
46 | "Bushnell had walked up to the embassy shortly before 1 p.m. Sunday and began livestreaming on "
47 | "the video streaming platform Twitch, a person familiar with the matter told The Associated "
48 | "Press. Law enforcement officials believe he set his phone down and then doused himself in "
49 | "accelerant and ignited the flames. At one point, he said he “will no longer be complicit in "
50 | "genocide,” the person said. The video was later removed from the platform, but law enforcement "
51 | "officials have obtained and reviewed a copy."
52 | ]
53 | ]
54 | ]
55 | }
56 | return test_case
57 |
58 |
59 | @pytest.fixture(scope='module')
60 | def testset(sample):
61 | ds = Dataset.from_dict(sample)
62 | return ds
63 |
64 |
65 | @pytest.mark.slow
66 | def test_case_on_context_reject_rate(testset):
67 | model = FakeListLLM(
68 | responses=[
69 | "Answer: wrong response format",
70 | "Answer: sorry, cannot answer the question"
71 | ]
72 | )
73 | metric = ContextRejectRate(model)
74 | assert metric.name == "context_reject_rate"
75 | assert metric.mtype == 'AnswerGroundedness'
76 | score, results = metric.compute(testset['questions'], testset['contexts'], 1)
77 | assert score == 0.5
78 |
--------------------------------------------------------------------------------
/tests/units/test_nli.py:
--------------------------------------------------------------------------------
1 | # import sys
2 | # sys.path.insert(0, '../src')
3 |
4 | import pytest
5 |
6 | from rageval.models import NLIModel
7 |
8 |
9 | @pytest.fixture(scope='module')
10 | def test_case():
11 | sample = {
12 | "claim": "In 1980, the oldest justice on the United States Supreme Court was Justice William O. Douglas.",
13 | "evidence": "August 3, 1994 \u2013 June 30, 2022 (27 years, 10 months, 27 days) photo source: Wikimedia "
14 | "Commons After the passing of Ruth Bader Ginsberg in 2020, Stephen Breyer was the oldest "
15 | "sitting member of the Supreme Court until his retirement in 2022. Stepping down at the age "
16 | "of 83, Breyer is now one of the oldest Supreme Court justices ever. Breyer was nominated by "
17 | "Bill Clinton and served on the Court for more than 27 years. During his tenure, Breyer fell "
18 | "in line with the liberal wing of the court. Before he was appointed to the Supreme Court, "
19 | "Breyer served as a judge on the U.S. Court of Appeals for the First Circuit; he was the "
20 | "Chief Judge for the last four years of his appointment.",
21 | "stance": "irrelevant"
22 | }
23 | return sample
24 |
25 |
26 | @pytest.mark.slow
27 | def test_nli(test_case):
28 |
29 | # model = NLIModel('sentiment-analysis', 'roberta-large-mnli')
30 | model = NLIModel(
31 | 'text-classification',
32 | 'hf-internal-testing/tiny-random-RobertaPreLayerNormForSequenceClassification'
33 | )
34 |
35 | # test request
36 | result = model.infer_prob(test_case['evidence'], test_case['claim'])
37 | # print(result)
38 | assert result[0]['label'] in ['LABEL_0', 'LABEL_1']
39 | assert 'score' in result[0]
40 |
41 | # case = test_case()
42 | # test_nli(case)
43 |
--------------------------------------------------------------------------------
/tests/units/test_openai_api.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import os
4 | import pytest
5 | from unittest.mock import patch, MagicMock
6 | import openai
7 | import httpx
8 |
9 | from rageval.models import OpenAILLM
10 |
11 | from langchain.schema import Generation, LLMResult
12 |
13 |
14 | @pytest.fixture(scope='module')
15 | def test_case():
16 | questions = ["截止2001年岛上的人口有多少?", "墨西哥土拨鼠栖息的地区海拔约是多少米?"]
17 | ground_truths = [
18 | [
19 | "总人口9495(2001年)。",
20 | "墨西哥土拨鼠栖息在海拔1600-2200米的平原上。"
21 | ]
22 | ]
23 | contexts = [
24 | [
25 | "米尼科伊岛(Minicoy)位于印度拉克沙群岛中央直辖区最南端,是Lakshadweep县的一个城镇。它与拉克沙群岛隔九度海峡相望,与马尔代夫伊哈万迪富卢环礁隔八度海峡相望。总人口9495(2001年)。米尼科伊岛位于盐水胡东南部的一个有人居住的岛屿,全岛几乎被椰子树覆盖,唯一的地标是一座灯塔。Viringili位于米尼科伊岛西南侧,是一个长度不到两百米的小岛,曾被用作麻风病患者的驱逐地。该地2001年总人口9495人,其中男性4616人,女性4879人;0—6岁人口1129人,其中男571人,女558人;识字率81.95%,其中男性为83.51%,女性为80.47%。",
26 | "墨西哥土拨鼠(\"Cynomys mexicanus\"),又名墨西哥草原松鼠或墨西哥草原犬鼠,是原住于墨西哥的一种啮齿目。牠们是日间活动的。由于牠们被看为害虫,估其数量下降至濒危水平。墨西哥土拨鼠栖息在海拔1600-2200米的平原上。牠们分布在墨西哥的圣路易斯波托西州北部及科阿韦拉州。牠们主要吃草,并透过食物来吸收水份。牠们有时也会吃昆虫。牠们的天敌有郊狼、短尾猫、鹰、獾及鼬。墨西哥土拨鼠是会冬眠的,其繁殖季节也较短,一般只在1月至4月。妊娠期为1个月,雌鼠每年只会产一胎,一胎平均有四子。幼鼠出生时眼睛闭合,会先以尾巴作为辅助,直至出生后40日才能看见。于5月至6月会断奶,到了1岁就会离开巢穴。冬天前幼鼠就会离开母鼠。幼鼠之间会互咬、嘶叫及扭住来玩耍。牠们1岁后就达至性成熟,寿命约3-5年。成年重约1公斤及长14-17吋,雄性较雌性大只。牠们呈黄色,耳朵较深色,腹部较浅色。墨西哥土拨鼠的语言最为复杂,能奔跑达每小时55公里。所以当受到威胁时,牠们会大叫作为警报,并且高速逃走。墨西哥土拨鼠的巢穴是挖掘出来的。巢穴的入口像漏斗,通道长达100呎,两侧有空间储存食物及休息。巢穴可以多达几百只墨西哥土拨鼠,但一般少于50只,群族有一只雄性的领袖。牠们有时会与斑点黄鼠及穴鸮分享他们的洞穴。于1956年,墨西哥土拨鼠曾在科阿韦拉州、新莱昂州及圣路易斯波托西州出没。到了1980年代,牠们从新莱昂州消失,其分布地少于800平方米。由于牠们被认为是害虫,故经常被毒杀,到了1994年到达濒危的状况。"
27 | ]
28 | ]
29 | return {'questions': questions,
30 | 'ground_truths': ground_truths,
31 | 'contexts': contexts}
32 |
33 | @pytest.fixture
34 | def openai_llm():
35 | return OpenAILLM(model="gpt-3.5-turbo", _api_key_env_var="OPENAI_API_KEY")
36 |
37 | def test_init(openai_llm):
38 | assert openai_llm.model == "gpt-3.5-turbo"
39 | assert openai_llm.base_url == "https://api.openai.com/v1"
40 | assert openai_llm.num_retries == 3
41 | assert openai_llm.timeout == 60
42 | assert openai_llm.api_key == os.getenv("OPENAI_API_KEY", "NO_KEY")
43 |
44 | def test_build_request(openai_llm):
45 | request = openai_llm.build_request()
46 | assert request["model"] == "gpt-3.5-turbo"
47 | assert request["max_tokens"] is None
48 | assert request["n"] is None
49 | assert request["temperature"] is None
50 | assert request["top_p"] is None
51 | assert request["logprobs"] is None
52 |
53 | def test_is_chat_model_engine(openai_llm):
54 | assert openai_llm._is_chat_model_engine is True
55 |
56 | @patch("rageval.models.openai.openai.OpenAI")
57 | def test_llm(mock_openai, openai_llm):
58 | assert openai_llm.llm is not None  # accessing the property should construct the client
59 | mock_openai.assert_called_once_with(
60 | api_key=openai_llm.api_key,
61 | base_url=openai_llm.base_url,
62 | max_retries=openai_llm.num_retries,
63 | timeout=openai_llm.timeout
64 | )
65 |
66 | @patch("rageval.models.openai.openai.OpenAI")
67 | def test_get_chat_model_response(mock_openai, openai_llm):
68 | mock_response = MagicMock()
69 | mock_openai().chat.completions.create.return_value = mock_response
70 | prompt = [{"role": "user", "content": "Hello"}]
71 | response = openai_llm._get_chat_model_response(prompt)
72 | assert response == mock_response
73 | mock_openai().chat.completions.create.assert_called_once()
74 |
75 | @patch("rageval.models.openai.openai.OpenAI")
76 | def test_get_instruct_model_response(mock_openai, openai_llm):
77 | mock_response = MagicMock()
78 | mock_openai().completions.create.return_value = mock_response
79 | prompt = "Hello"
80 | response = openai_llm._get_instruct_model_response(prompt)
81 | assert response == mock_response
82 | mock_openai().completions.create.assert_called_once()
83 |
84 | @patch("rageval.models.openai.OpenAILLM._get_chat_model_response")
85 | @patch("rageval.models.openai.OpenAILLM._get_instruct_model_response")
86 | def test_generate(mock_instruct_response, mock_chat_response, openai_llm):
87 | mock_chat_response.return_value = {"choices": [{"message": {"content": "Hi"}}]}
88 | mock_instruct_response.return_value = {"choices": [{"text": "Hi"}]}
89 |
90 | prompt_chat = [{"role": "user", "content": "Hello"}]
91 | result_chat = openai_llm.generate(prompt_chat)
92 | assert isinstance(result_chat, LLMResult)
93 | assert result_chat.generations[0][0].text == "Hi"
94 |
95 | prompt_instruct = "Hello"
96 | openai_llm.model = "gpt-3.5-turbo-instruct"
97 | result_instruct = openai_llm.generate(prompt_instruct)
98 | assert isinstance(result_instruct, LLMResult)
99 | assert result_instruct.generations[0][0].text == "Hi"
100 |
101 | def test_create_llm_result(openai_llm):
102 | response = {
103 | "usage": {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30},
104 | "choices": [{"message": {"content": "Hi"}, "finish_reason": "stop", "logprobs": None}]
105 | }
106 | result = openai_llm.create_llm_result(response)
107 | assert isinstance(result, LLMResult)
108 | assert result.llm_output["token_usage"]["prompt_tokens"] == 10
109 | assert result.llm_output["token_usage"]["completion_tokens"] == 20
110 | assert result.llm_output["token_usage"]["total_tokens"] == 30
111 | assert result.generations[0][0].text == "Hi"
112 |
113 | @patch.object(OpenAILLM, 'generate')
114 | def test_batch_generate(mock_generate, openai_llm):
115 | # Mock the generate method to return a simple LLMResult
116 | mock_generate.return_value = LLMResult(generations=[[Generation(text="Hi")]])
117 |
118 | # Define prompts for testing
119 | prompts = [
120 | [{"role": "user", "content": "Hello"}],
121 | [{"role": "user", "content": "How are you?"}],
122 | ]
123 |
124 | # Call batch_generate
125 | results = openai_llm.batch_generate(prompts, max_workers=2)
126 |
127 | # Verify generate was called with each prompt
128 | assert mock_generate.call_count == len(prompts)
129 | mock_generate.assert_any_call(prompts[0])
130 | mock_generate.assert_any_call(prompts[1])
131 |
132 | # Check results
133 | assert len(results) == len(prompts)
134 | for result in results:
135 | assert isinstance(result, LLMResult)
136 | assert result.generations[0][0].text == "Hi"
137 |
138 | @patch.object(OpenAILLM, 'generate')
139 | def test_batch_generate_order(mock_generate, openai_llm):
140 | # Mock the generate method to return different results based on input
141 | def side_effect(prompt):
142 | if prompt == [{"role": "user", "content": "Hello"}]:
143 | return LLMResult(generations=[[Generation(text="Hi")]])
144 | elif prompt == [{"role": "user", "content": "How are you?"}]:
145 | return LLMResult(generations=[[Generation(text="I am fine")]])
146 |
147 | mock_generate.side_effect = side_effect
148 |
149 | # Define prompts for testing
150 | prompts = [
151 | [{"role": "user", "content": "Hello"}],
152 | [{"role": "user", "content": "How are you?"}],
153 | ]
154 |
155 | # Call batch_generate
156 | results = openai_llm.batch_generate(prompts, max_workers=2)
157 |
158 | # Check results are in the correct order
159 | assert results[0].generations[0][0].text == "Hi"
160 | assert results[1].generations[0][0].text == "I am fine"
161 |
--------------------------------------------------------------------------------
/tests/units/test_text_length.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from datasets import Dataset
3 |
4 | from rageval.metrics import TextLength
5 | import rageval as rl
6 |
7 |
8 | @pytest.fixture(scope='module')
9 | def sample():
10 | test_case = {
11 | #"questions": [
12 | # "习近平主席在何时何地会见了埃塞俄比亚总理海尔马里亚姆?",
13 | # "埃塞俄比亚希望与中国在哪些领域加强合作?"
14 | #],
15 | "answers": [
16 | "习近平主席在2017年5月12日于北京人民大会堂会见了埃塞俄比亚总理海尔马里亚姆。",
17 | "埃塞俄比亚希望与中国在以下领域加强合作:\n1. **共建“一带一路”框架下合作**:埃塞俄比亚表示希望能够积极参与“一带一路”倡议,深化与中国在基础设施建设、产能合作、互联互通等领域的合作。\n2. **提高工业化水平和出口创汇能力**:埃塞俄比亚期待中国在推动其工业化进程以及提升出口创汇能力方面提供帮助和合作。\n3. **安全、有序、有效推进经贸合作**:希望与中方在贸易和投资合作领域取得进展,实现稳定、有序和高效的合作。"
18 | ]
19 | }
20 | return test_case
21 |
22 |
23 | @pytest.fixture(scope='module')
24 | def testset(sample):
25 | ds = Dataset.from_dict(sample)
26 | return ds
27 |
28 |
29 | @pytest.mark.slow
30 | def test_case_on_text_length(testset):
31 | metric = TextLength(tokenize_model="Qwen/Qwen2-0.5B-Instruct")
32 | assert metric.name == "text_length"
33 | score, results = metric.compute(testset["answers"])
34 | print(score, results)
35 | assert score == 75.0
36 |
--------------------------------------------------------------------------------
/tutorials/README.md:
--------------------------------------------------------------------------------
1 | # Tutorials
2 |
3 | Welcome to the tutorials directory for the RAGEval project! This directory contains step-by-step guides and resources to help you get started with RAG evaluation methods, best practices, and practical implementations.
4 |
5 |
6 | ## Table of Contents
7 |
8 | 1. [Overview](#overview)
9 | 2. [Getting Started](#getting-started)
10 | 3. [Tutorials](#tutorials)
11 | - [Basic RAG Evaluation](#basic-rag-evaluation)
12 | - [Advanced Techniques](#advanced-techniques)
13 | - [Custom Dataset Evaluation](#custom-dataset-evaluation)
14 | 4. [Contributing](#contributing)
15 | 5. [License](#license)
16 |
17 | ## Overview
18 |
19 | The RAGEval project aims to provide a comprehensive framework for evaluating the performance of retrieval-augmented generation (RAG) systems. This directory is dedicated to tutorials that will guide you through various aspects of RAG evaluation, helping you understand both the theory and its practical application.
20 |
21 | ## Getting Started
22 |
23 | To begin, ensure you have the necessary dependencies installed. Follow the [installation instructions](../INSTALL.md) to set up your environment.
24 |
25 | ## Tutorials
26 |
27 | ### Basic RAG Evaluation
28 |
29 | - **Description**: Learn how to perform a simple evaluation of RAG models using standard metrics (a minimal sketch is shown below).
30 | - **Link**: [Basic RAG Evaluation Tutorial](basic_rag_evaluation.md)
31 |
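Before diving into the full tutorial, the snippet below gives a rough idea of what a basic evaluation looks like. It is a minimal sketch that mirrors the metric usage exercised in `tests/units/test_answer_exect_match.py`; the example answer and ground-truth strings are placeholders, and the exact numbers depend on the installed version of `rageval`.

```python
from datasets import Dataset
from rageval.metrics import AnswerEMCorrectness

# A tiny evaluation set: one generated answer plus its gold short answers.
# `gt_answers` is a list (per sample) of groups of acceptable string variants.
data = {
    "answers": ["Josef Bican holds the all-time scoring record with an estimated 805 goals."],
    "gt_answers": [[["Bican", "Josef Bican"]]],
}
testset = Dataset.from_dict(data)

# Exact-match correctness, ignoring letter case. `compute` returns the overall
# score together with per-sample results (roughly, the fraction of gold short
# answers found verbatim in each generated answer); the trailing 1 follows the
# batch-size argument used throughout the unit tests.
metric = AnswerEMCorrectness(ignore_case=True)
score, per_sample = metric.compute(testset["answers"], testset["gt_answers"], 1)
print(score, per_sample)
```
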
32 | ### Advanced Techniques
33 |
34 | - **Description**: Explore advanced evaluation techniques that enhance your understanding of model performance.
35 | - **Link**: [Advanced Techniques Tutorial](advanced_techniques.md)
36 |
37 | ### Custom Dataset Evaluation
38 |
39 | - **Description**: A guide on how to evaluate RAG models on your custom datasets, including preprocessing steps and metric selection (see the sketch below).
40 | - **Link**: [Custom Dataset Evaluation Tutorial](custom_dataset_evaluation.md)
41 |
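As a starting point for your own data, the sketch below builds a `datasets.Dataset` from an in-memory dict and runs two reference-based metrics, following the calling conventions used in `tests/units/test_answer_f1.py` and `tests/units/test_answer_rouge.py`. The column names and texts are placeholders for the answers produced by your RAG system and your own references.

```python
from datasets import Dataset
from rageval.metrics import AnswerF1Correctness, AnswerRougeCorrectness

# Replace these with your system's answers and your reference answers.
ds = Dataset.from_dict({
    "answers": ["Josef Bican is recognized as the record scorer with an estimated 805 goals."],
    "gt_answers": [["In 2020, FIFA recognized Josef Bican as the record scorer with an estimated 805 goals."]],
})

# Token-overlap F1 between each answer and its references.
f1 = AnswerF1Correctness(normalize=True, language="en")
f1_score, f1_per_sample = f1.compute(ds["answers"], ds["gt_answers"])

# ROUGE-1 overlap; the trailing argument mirrors the batch size used in the unit tests.
rouge = AnswerRougeCorrectness("rouge1")
rouge_score, rouge_per_sample = rouge.compute(ds["answers"], ds["gt_answers"], 1)

print(f1_score, rouge_score)
```
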
42 | ## Contributing
43 |
44 | We welcome contributions! If you have ideas for additional tutorials or improvements to existing ones, please submit a pull request or open an issue. For more information, check our [Contributing Guidelines](../CONTRIBUTING.md).
45 |
46 | ## License
47 |
48 | This project is licensed under the Apache License. See the [LICENSE](../LICENSE) file for details.
49 |
50 | ---
51 |
52 | If you have any questions or need further assistance, feel free to open an issue in the main repository. Happy evaluating!
--------------------------------------------------------------------------------
/tutorials/tutorial 1/df_result_excel.xlsx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gomate-community/rageval/01e2580efe714b3130602e02a7e7b017d16d79a2/tutorials/tutorial 1/df_result_excel.xlsx
--------------------------------------------------------------------------------
/tutorials/tutorial 1/requirements.txt:
--------------------------------------------------------------------------------
1 | openpyxl
2 | pandas
3 | transformers
4 | matplotlib
5 | datasets
6 | ragchecker
7 | vllm
--------------------------------------------------------------------------------