├── .github └── workflows │ └── release.yml ├── .gitignore ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── NOTICE ├── README.md ├── benchmark ├── README.md ├── data │ ├── accurate_context │ │ └── dolly_example_ids.json │ ├── download_data.sh │ ├── gather_benchmark_data.py │ ├── noisy_context │ │ └── msmarco_example_ids.json │ └── zero_context │ │ └── nq_example_ids.json ├── evaluation │ ├── autocheck.py │ ├── corr.py │ └── evaluate.py ├── human_annotations_v1 │ ├── LICENSE.txt │ ├── accurate_context │ │ ├── dolly_alpaca_7B_answers.json │ │ ├── dolly_chatgpt_answers.json │ │ ├── dolly_claude2_answers.json │ │ ├── dolly_davinci001_answers.json │ │ ├── dolly_falcon_40B_instruct_answers.json │ │ ├── dolly_gpt4_answers.json │ │ └── dolly_llama2_70b_chat_answers.json │ ├── noisy_context │ │ ├── msmarco_alpaca_7B_answers.json │ │ ├── msmarco_chatgpt_answers.json │ │ ├── msmarco_claude2_answers.json │ │ ├── msmarco_davinci001_answers.json │ │ ├── msmarco_falcon_40B_instruct_answers.json │ │ ├── msmarco_gpt4_answers.json │ │ └── msmarco_llama2_70b_chat_answers.json │ └── zero_context │ │ ├── nq_alpaca_7B_answers.json │ │ ├── nq_chatgpt_answers.json │ │ ├── nq_claude2_answers.json │ │ ├── nq_davinci001_answers.json │ │ ├── nq_falcon_40B_instruct_answers.json │ │ ├── nq_gpt4_answers.json │ │ └── nq_llama2_70b_chat_answers.json └── response_collection │ ├── README.md │ ├── __init__.py │ ├── chatglm3_6b.py │ ├── collector_base.py │ ├── gpt4_turbo.py │ ├── main.py │ └── mistral.py ├── demo ├── README.md ├── __init__.py ├── main.py └── miscellaneous.py ├── example ├── check.sh ├── example_in.json ├── extract-check.sh ├── extract-check_wo_ref.sh └── extract.sh ├── imgs ├── demo.gif ├── evaluation.png ├── framework.png ├── localization_example_1.jpg ├── localization_example_2.jpg ├── localization_example_3.jpg ├── localization_example_4.jpg ├── localization_method.jpg ├── settings.png └── venn.png ├── notebooks └── refchecker_usages.ipynb ├── pyproject.toml └── refchecker ├── __init__.py ├── aggregator.py ├── base.py ├── checker ├── README.md ├── __init__.py ├── alignscore │ ├── __init__.py │ ├── alignscore.py │ ├── alignscore_checker.py │ ├── dataloader.py │ ├── inference.py │ └── model.py ├── checker_base.py ├── checker_prompts.py ├── llm_checker.py ├── nli_checker.py └── repc │ ├── general.py │ ├── ml_models.py │ └── repc_checker.py ├── cli.py ├── extractor ├── README.md ├── __init__.py ├── extractor_base.py ├── extractor_prompts.py └── llm_extractor.py ├── localizer ├── __init__.py └── embed_localizer.py ├── retriever ├── __init__.py └── google_retriever.py └── utils.py /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: Publish to PyPI 2 | on: 3 | release: 4 | types: [published] 5 | 6 | jobs: 7 | pypi_release: 8 | name: Builds Using Poetry and Publishes to PyPI 9 | runs-on: ubuntu-latest 10 | steps: 11 | - uses: actions/checkout@v4 12 | - uses: actions/setup-python@v5 13 | with: 14 | python-version: '3.x' 15 | - name: Install poetry 16 | uses: abatilo/actions-poetry@v2 17 | - run: poetry config pypi-token.pypi "${{ secrets.PYPI_API_TOKEN }}" 18 | - name: Publish package 19 | run: poetry publish --build 20 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.DS_Store 2 | __pycache__/ 3 | *.ckpt 4 | *.pkl 5 | *.sh 6 | 7 | 
-------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | ## Code of Conduct 2 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 3 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 4 | opensource-codeofconduct@amazon.com with any additional questions or comments. 5 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing Guidelines 2 | 3 | Thank you for your interest in contributing to our project. Whether it's a bug report, new feature, correction, or additional 4 | documentation, we greatly value feedback and contributions from our community. 5 | 6 | Please read through this document before submitting any issues or pull requests to ensure we have all the necessary 7 | information to effectively respond to your bug report or contribution. 8 | 9 | 10 | ## Reporting Bugs/Feature Requests 11 | 12 | We welcome you to use the GitHub issue tracker to report bugs or suggest features. 13 | 14 | When filing an issue, please check existing open, or recently closed, issues to make sure somebody else hasn't already 15 | reported the issue. Please try to include as much information as you can. Details like these are incredibly useful: 16 | 17 | * A reproducible test case or series of steps 18 | * The version of our code being used 19 | * Any modifications you've made relevant to the bug 20 | * Anything unusual about your environment or deployment 21 | 22 | 23 | ## Contributing via Pull Requests 24 | Contributions via pull requests are much appreciated. Before sending us a pull request, please ensure that: 25 | 26 | 1. You are working against the latest source on the *main* branch. 27 | 2. You check existing open, and recently merged, pull requests to make sure someone else hasn't addressed the problem already. 28 | 3. You open an issue to discuss any significant work - we would hate for your time to be wasted. 29 | 30 | To send us a pull request, please: 31 | 32 | 1. Fork the repository. 33 | 2. Modify the source; please focus on the specific change you are contributing. If you also reformat all the code, it will be hard for us to focus on your change. 34 | 3. Ensure local tests pass. 35 | 4. Commit to your fork using clear commit messages. 36 | 5. Send us a pull request, answering any default questions in the pull request interface. 37 | 6. Pay attention to any automated CI failures reported in the pull request, and stay involved in the conversation. 38 | 39 | GitHub provides additional document on [forking a repository](https://help.github.com/articles/fork-a-repo/) and 40 | [creating a pull request](https://help.github.com/articles/creating-a-pull-request/). 41 | 42 | 43 | ## Finding contributions to work on 44 | Looking at the existing issues is a great way to find something to contribute on. As our projects, by default, use the default GitHub issue labels (enhancement/bug/duplicate/help wanted/invalid/question/wontfix), looking at any 'help wanted' issues is a great place to start. 45 | 46 | 47 | ## Code of Conduct 48 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 
49 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 50 | opensource-codeofconduct@amazon.com with any additional questions or comments. 51 | 52 | 53 | ## Security issue notifications 54 | If you discover a potential security issue in this project we ask that you notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/). Please do **not** create a public github issue. 55 | 56 | 57 | ## Licensing 58 | 59 | See the [LICENSE](LICENSE) file for our project's licensing. We will ask you to confirm the licensing of your contribution. 60 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. 
For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | -------------------------------------------------------------------------------- /NOTICE: -------------------------------------------------------------------------------- 1 | Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | -------------------------------------------------------------------------------- /benchmark/README.md: -------------------------------------------------------------------------------- 1 | ## Benchmark 2 | 3 | This folder contains the script for downloading the benchmark dataset. 4 | 5 | ### Download Benchmark Data 6 | 7 | Please take the following steps to download the benchmark dataset. 8 | 9 | 1. Install `gsutil` with instructions here: https://cloud.google.com/storage/docs/gsutil_install 10 | 11 | 2. Run the following script 12 | ```bash 13 | cd benchmark/data 14 | sh download_data.sh 15 | ``` 16 | 17 | After downloading and processing, the benchmark data will be saved to the three folders under [benchmark/data](../benchmark/data) 18 | 19 | ### Collect Responses for Your Model 20 | 21 | Please refer to the instructions in [response_collection](response_collection/README.md). 
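For a quick sense of what that involves, below is a minimal, illustrative sketch of such a collector built on Hugging Face `transformers`. This is only a sketch: the checkpoint name `your-org/your-model` is a placeholder, and the greedy-decoding logic in `get_response` is an assumption about how you might generate; `response_collection/chatglm3_6b.py` and `response_collection/mistral.py` (included later in this repository) are the complete, working examples.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

from collector_base import ResponseCollectorBase


class MyModelCollector(ResponseCollectorBase):
    """Illustrative collector sketch; see chatglm3_6b.py and mistral.py for real examples."""

    def __init__(self, mname, device='cuda') -> None:
        super().__init__(mname, device)
        # Placeholder checkpoint id -- replace with your own model.
        self.tokenizer = AutoTokenizer.from_pretrained("your-org/your-model")
        self.model = AutoModelForCausalLM.from_pretrained("your-org/your-model").to(self.device)
        self.model.eval()
        # Context budget used by ResponseCollectorBase.collect_response() for prompt truncation.
        self.max_contex_length = 8100

    def tokenizer_encode(self, prompt):
        return self.tokenizer.encode(prompt, add_special_tokens=False)

    def tokenizer_decode(self, encoded):
        return self.tokenizer.decode(encoded)

    def get_response(self, prompt):
        inputs = self.tokenizer(prompt, return_tensors='pt').to(self.device)
        output_ids = self.model.generate(
            **inputs, max_new_tokens=self.max_new_tokens, do_sample=False
        )
        # Keep only the newly generated tokens, dropping the prompt.
        new_tokens = output_ids[0][inputs['input_ids'].shape[1]:]
        return self.tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
```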
22 | 23 | ### Run the Checking Pipeline and Evaluation Script 24 | 25 | Run the checking pipeline: 26 | ```bash 27 | python evaluation/autocheck.py --model= --extractor= --checker= 28 | ``` 29 | 30 | Evaluate the results: 31 | ```bash 32 | python evaluation/evaluate.py --model= --extractor= --checker= --output_file= 33 | ``` -------------------------------------------------------------------------------- /benchmark/data/accurate_context/dolly_example_ids.json: -------------------------------------------------------------------------------- 1 | [ 2 | "12646", 3 | "1097", 4 | "2394", 5 | "13774", 6 | "2000", 7 | "5780", 8 | "12318", 9 | "10432", 10 | "1478", 11 | "11587", 12 | "14649", 13 | "13169", 14 | "6922", 15 | "2178", 16 | "4763", 17 | "9188", 18 | "2673", 19 | "12470", 20 | "13708", 21 | "10945", 22 | "14049", 23 | "2065", 24 | "13836", 25 | "13455", 26 | "12437", 27 | "5893", 28 | "2951", 29 | "6876", 30 | "12377", 31 | "1258", 32 | "10998", 33 | "4070", 34 | "4859", 35 | "12263", 36 | "12686", 37 | "12427", 38 | "11049", 39 | "3576", 40 | "8480", 41 | "11647", 42 | "10240", 43 | "10894", 44 | "11129", 45 | "14669", 46 | "7522", 47 | "5716", 48 | "4540", 49 | "4060", 50 | "11403", 51 | "9953", 52 | "8470", 53 | "272", 54 | "1831", 55 | "5521", 56 | "12253", 57 | "286", 58 | "6205", 59 | "5100", 60 | "13821", 61 | "4682", 62 | "10001", 63 | "10662", 64 | "11722", 65 | "2623", 66 | "5356", 67 | "2600", 68 | "9397", 69 | "7704", 70 | "3133", 71 | "7924", 72 | "3999", 73 | "14230", 74 | "221", 75 | "11836", 76 | "4603", 77 | "2954", 78 | "4065", 79 | "2382", 80 | "2052", 81 | "13618", 82 | "9779", 83 | "12359", 84 | "9079", 85 | "13894", 86 | "9326", 87 | "1812", 88 | "5175", 89 | "9608", 90 | "10547", 91 | "4980", 92 | "3601", 93 | "12004", 94 | "12310", 95 | "7390", 96 | "9525", 97 | "65", 98 | "2971", 99 | "6734", 100 | "13918", 101 | "4750" 102 | ] -------------------------------------------------------------------------------- /benchmark/data/download_data.sh: -------------------------------------------------------------------------------- 1 | mkdir zero_context/nq 2 | gsutil -m cp -R gs://natural_questions/v1.0/dev zero_context/nq 3 | gzip -d zero_context/nq/dev/nq-dev-00.jsonl.gz 4 | gzip -d zero_context/nq/dev/nq-dev-01.jsonl.gz 5 | gzip -d zero_context/nq/dev/nq-dev-02.jsonl.gz 6 | gzip -d zero_context/nq/dev/nq-dev-03.jsonl.gz 7 | gzip -d zero_context/nq/dev/nq-dev-04.jsonl.gz 8 | 9 | python gather_benchmark_data.py --dataset=nq 10 | python gather_benchmark_data.py --dataset=msmarco 11 | python gather_benchmark_data.py --dataset=dolly 12 | 13 | rm -r zero_context/nq 14 | -------------------------------------------------------------------------------- /benchmark/data/gather_benchmark_data.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | from datasets import load_dataset 4 | 5 | 6 | def process_dolly(): 7 | dolly_data = load_dataset('databricks/databricks-dolly-15k', split='train') 8 | example_ids = json.load(open('accurate_context/dolly_example_ids.json')) 9 | 10 | chosen_examples = dict() 11 | for i, d in enumerate(dolly_data): 12 | if str(i) in example_ids: 13 | ex = {'id': str(i)} 14 | ex['question'] = d['instruction'] 15 | ex['context'] = [d['context']] 16 | ex['category'] = d['category'] 17 | ex['human_response'] = d['response'] 18 | 19 | chosen_examples[ex['id']] = ex 20 | 21 | ret = [chosen_examples[ex_id] for ex_id in example_ids] 22 | json.dump(ret, 
open('accurate_context/dolly.json', 'w'), indent=4) 23 | 24 | 25 | def process_msmarco(): 26 | ms_data = load_dataset('ms_marco', 'v2.1', split='validation') 27 | example_ids = json.load(open('noisy_context/msmarco_example_ids.json')) 28 | 29 | chosen_examples = dict() 30 | for d in ms_data: 31 | if str(d['query_id']) in example_ids: 32 | ex = {'id': str(d['query_id'])} 33 | ex['question'] = d['query'] 34 | ex['context'] = d['passages']['passage_text'] 35 | ex['query_type'] = d['query_type'] 36 | ex['answers'] = d['answers'] 37 | ex['wellFormedAnswers'] = d['wellFormedAnswers'] 38 | ex['context_is_selected'] = d['passages']['is_selected'] 39 | ex['context_ur'] = d['passages']['url'] 40 | 41 | chosen_examples[ex['id']] = ex 42 | 43 | ret = [chosen_examples[ex_id] for ex_id in example_ids] 44 | json.dump(ret, open('noisy_context/msmarco.json', 'w'), indent=4) 45 | 46 | 47 | def process_nq(): 48 | example_ids = json.load(open('zero_context/nq_example_ids.json')) 49 | 50 | chosen_examples = dict() 51 | for i in range(5): 52 | with open(f'zero_context/nq/dev/nq-dev-0{i}.jsonl') as f: 53 | for l in f.readlines(): 54 | l = json.loads(l) 55 | assert 'example_id' in l 56 | if str(l['example_id']) in example_ids: 57 | ex = {'id': str(l['example_id'])} 58 | ex['question'] = l['question_text'] 59 | 60 | annotations = [] 61 | for anno in l['annotations']: 62 | cleaned_anno = dict() 63 | cleaned_anno['short_answers'] = [] 64 | for short_ans in anno['short_answers']: 65 | short_ans_start = short_ans['start_token'] 66 | short_ans_end = short_ans['end_token'] 67 | short_answer = [tok['token'] for tok in l['document_tokens'][short_ans_start: short_ans_end] if 68 | not tok['html_token']] 69 | cleaned_anno['short_answers'].append(' '.join(short_answer).strip()) 70 | long_ans_start = anno['long_answer']['start_token'] 71 | long_ans_end = anno['long_answer']['end_token'] 72 | long_answer = [tok['token'] for tok in l['document_tokens'][long_ans_start: long_ans_end] if 73 | not tok['html_token']] 74 | cleaned_anno['long_answer'] = ' '.join(long_answer).strip() 75 | 76 | if len(cleaned_anno['short_answers']) > 0 and any([len(a) for a in cleaned_anno['short_answers']]) and len(cleaned_anno['long_answer']) > 0: 77 | annotations.append(cleaned_anno) 78 | assert len(annotations) > 0 79 | ex['context'] = [annotations[0]['long_answer']] 80 | ex['short_answers'] = annotations[0]['short_answers'] 81 | chosen_examples[ex['id']] = ex 82 | ret = [chosen_examples[ex_id] for ex_id in example_ids] 83 | json.dump(ret, open('zero_context/nq.json', 'w'), indent=4) 84 | 85 | 86 | if __name__ == '__main__': 87 | parser = argparse.ArgumentParser() 88 | parser.add_argument('--dataset', type=str, choices=['nq', 'msmarco', 'dolly']) 89 | 90 | args = parser.parse_args() 91 | 92 | if args.dataset == 'nq': 93 | process_nq() 94 | elif args.dataset == 'msmarco': 95 | process_msmarco() 96 | elif args.dataset == 'dolly': 97 | process_dolly() 98 | -------------------------------------------------------------------------------- /benchmark/data/noisy_context/msmarco_example_ids.json: -------------------------------------------------------------------------------- 1 | [ 2 | "167875", 3 | "1098142", 4 | "60339", 5 | "1098131", 6 | "423918", 7 | "164058", 8 | "431481", 9 | "164078", 10 | "74518", 11 | "20924", 12 | "196366", 13 | "830068", 14 | "424911", 15 | "853309", 16 | "424287", 17 | "71168", 18 | "158999", 19 | "144833", 20 | "71250", 21 | "415165", 22 | "411030", 23 | "168000", 24 | "73420", 25 | "1094146", 26 | "1093157", 27 | "1098565", 28 
| "1097044", 29 | "192218", 30 | "453218", 31 | "271818", 32 | "1002884", 33 | "1096985", 34 | "1006506", 35 | "909506", 36 | "991037", 37 | "643122", 38 | "450391", 39 | "998737", 40 | "785306", 41 | "708853", 42 | "21793", 43 | "810239", 44 | "57346", 45 | "1102431", 46 | "248616", 47 | "741305", 48 | "40371", 49 | "1095893", 50 | "309249", 51 | "318922", 52 | "1090507", 53 | "392483", 54 | "283271", 55 | "1098016", 56 | "1090502", 57 | "278453", 58 | "1097486", 59 | "284565", 60 | "465781", 61 | "1099731", 62 | "1072119", 63 | "166795", 64 | "405776", 65 | "429983", 66 | "1088312", 67 | "1088355", 68 | "1008326", 69 | "544914", 70 | "1033994", 71 | "995200", 72 | "157359", 73 | "129756", 74 | "421437", 75 | "100248", 76 | "1083749", 77 | "143463", 78 | "1084984", 79 | "851286", 80 | "537560", 81 | "540432", 82 | "63192", 83 | "708743", 84 | "1093959", 85 | "1040099", 86 | "219496", 87 | "1100064", 88 | "147966", 89 | "76098", 90 | "1080229", 91 | "155382", 92 | "213986", 93 | "857956", 94 | "1092210", 95 | "1100861", 96 | "145322", 97 | "458110", 98 | "425121", 99 | "172330", 100 | "376687", 101 | "160460" 102 | ] -------------------------------------------------------------------------------- /benchmark/data/zero_context/nq_example_ids.json: -------------------------------------------------------------------------------- 1 | [ 2 | "4073418639113603971", 3 | "-3339006116507262453", 4 | "-1582234629547061501", 5 | "2223188268381215709", 6 | "-8883866849303680195", 7 | "-3476455683698256952", 8 | "-5646056836091194880", 9 | "-5003538501505873412", 10 | "-8680361524959876970", 11 | "-6500826823871433949", 12 | "1851594182842213971", 13 | "-2651390481836592629", 14 | "7422852161474049311", 15 | "238018115866608950", 16 | "-6395885127392955720", 17 | "2536150218240275989", 18 | "-3170493396531530980", 19 | "-204697481439605710", 20 | "-4283464153717448291", 21 | "-853451338783771197", 22 | "-1388143263608839064", 23 | "3413977600585971868", 24 | "-4834721811831894172", 25 | "-6591614197125818072", 26 | "2339339749699668536", 27 | "7526277995958905171", 28 | "-2419576910417046724", 29 | "-1773607187248198254", 30 | "4086158102790148091", 31 | "-1732734622903151644", 32 | "-7778233237890403173", 33 | "583026970021621830", 34 | "5185978890489959594", 35 | "-1531676396763282931", 36 | "5078709962400196312", 37 | "-7296166998534064378", 38 | "734309880189507498", 39 | "5063128834246440475", 40 | "5239772721292332989", 41 | "6217752837777787594", 42 | "6882831655380911332", 43 | "-1222948057007052760", 44 | "-278447034238251050", 45 | "-4882553194262710105", 46 | "5901661238535651276", 47 | "5217063434217549041", 48 | "-8544960023313580012", 49 | "4272117895813370426", 50 | "4245798066923223457", 51 | "8852834747561852791", 52 | "-5039045537721106027", 53 | "-6687867009117829006", 54 | "-6933377434468493740", 55 | "850563670386330666", 56 | "-8490085242981497626", 57 | "8361850218460994084", 58 | "5259818422556035007", 59 | "465961107727601737", 60 | "-5525220120773157166", 61 | "-2107324154024383180", 62 | "-2126374007719970792", 63 | "-1123449784666821676", 64 | "4265986416867251482", 65 | "2141218895657905276", 66 | "6672231732847608194", 67 | "4943207042321394497", 68 | "6014950976264156000", 69 | "-5004457603684974952", 70 | "-1206653570097564556", 71 | "6405015309290964923", 72 | "3790834797035922554", 73 | "-6992125597790283273", 74 | "-610965397636500508", 75 | "4876812477232188511", 76 | "-7186555013910059700", 77 | "-3428106355897335676", 78 | "8451264603946634151", 79 | 
"6605227747829720945", 80 | "8848582210866367992", 81 | "7539197459439257235", 82 | "6403691727808631943", 83 | "8900387163847503723", 84 | "334627437220863517", 85 | "1669394955232263644", 86 | "-1259577751352734948", 87 | "4674269357266381610", 88 | "-7306805811512692008", 89 | "2532633865902253646", 90 | "1389300307230755697", 91 | "2191432481376971337", 92 | "-5489426796364143729", 93 | "3480908309420822259", 94 | "3311962143974666464", 95 | "2352051038192309240", 96 | "7735703748660141931", 97 | "3028343500075334931", 98 | "1848006246448298349", 99 | "-927966355158112429", 100 | "-5984668857988357373", 101 | "-704529031648349308" 102 | ] -------------------------------------------------------------------------------- /benchmark/evaluation/autocheck.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | from tqdm import tqdm 4 | import os 5 | 6 | from refchecker import ( 7 | LLMChecker, 8 | NLIChecker, 9 | AlignScoreChecker, 10 | RepCChecker 11 | ) 12 | 13 | from refchecker.extractor import LLMExtractor 14 | from refchecker.checker import LLMChecker 15 | 16 | 17 | def _get_checker(checker_model): 18 | checker = None 19 | if checker_model == 'nli': 20 | checker = NLIChecker() 21 | elif checker_model == 'alignscore': 22 | checker = AlignScoreChecker( 23 | batch_size=args.batch_size 24 | ) 25 | elif checker_model == 'repc': 26 | checker = RepCChecker() 27 | else: 28 | checker = LLMChecker( 29 | model=checker_model, 30 | api_base=args.api_base, 31 | batch_size=args.batch_size 32 | ) 33 | return checker 34 | 35 | 36 | def _get_extractor(extractor_model): 37 | claim_extractor = LLMExtractor( 38 | claim_format=args.claim_format, 39 | model=extractor_model, 40 | api_base=args.api_base, 41 | batch_size=args.batch_size 42 | ) 43 | return claim_extractor 44 | 45 | 46 | def autocheck(extractor_model, checker_model): 47 | claim_extractor = None 48 | checker = None 49 | 50 | for setting, ds in zip( 51 | ["zero_context", "noisy_context", "accurate_context"], 52 | ["nq", "msmarco", "dolly"] 53 | ): 54 | print(f'Evaluating {args.model} on {setting} setting with {extractor_model} extractor and {checker_model} checker') 55 | 56 | input_dir = os.path.join('human_annotations_v1', setting) 57 | response_filename = f'{ds}_{args.model}_answers.json' 58 | output_dir = os.path.join(args.output_dir, setting) 59 | if not os.path.exists(output_dir): 60 | os.mkdir(output_dir) 61 | output_filename = os.path.join(output_dir, response_filename) 62 | 63 | if os.path.exists(output_filename): 64 | response_data = json.load(open(output_filename)) 65 | else: 66 | response_data = json.load(open(os.path.join(input_dir, response_filename))) 67 | # in case the order of response data is not aligned with ours 68 | id_to_data = {d['id']: d for d in json.load(open(f'data/{setting}/{ds}.json'))} 69 | 70 | # === Extraction === 71 | batch_questions = [] 72 | batch_responses = [] 73 | kg_key = f'{extractor_model}_response_kg' 74 | for r in response_data: 75 | d = id_to_data[r['id']] 76 | r['question'] = d['question'] 77 | r['context'] = d['context'] 78 | if kg_key not in r: 79 | batch_questions.append(r['question']) 80 | batch_responses.append(r['response']) 81 | 82 | if len(batch_responses): 83 | if claim_extractor is None: 84 | claim_extractor = _get_extractor(extractor_model) 85 | 86 | print(f'Running Claim Extraction on {len(batch_responses)} examples...') 87 | extraction_results = claim_extractor.extract( 88 | batch_responses=batch_responses, 89 | 
batch_questions=batch_questions, 90 | max_new_tokens=1000 91 | ) 92 | 93 | assert len(extraction_results) == len(batch_responses) 94 | 95 | _i = 0 96 | for r in response_data: 97 | if kg_key not in r: 98 | r[kg_key] = [{'claim': c.content, 'attributed_sent_ids': c.attributed_sent_ids} for c in extraction_results[_i].claims] 99 | r['extraction_orig_response'] = extraction_results[_i].extractor_response 100 | _i += 1 101 | 102 | json.dump(response_data, open(output_filename, 'w'), indent=4) 103 | 104 | # # === Checking === 105 | batch_claims = [] 106 | batch_references = [] 107 | batch_questions = [] 108 | batch_responses = [] 109 | 110 | label_key = f'{checker_model}_label' 111 | for r in response_data: 112 | if kg_key in r: 113 | claims = [c['claim'] for c in r[kg_key] if label_key not in c] 114 | if len(claims): 115 | batch_claims.append(claims) 116 | _references = [] 117 | if len(r['context']) > 1: 118 | for pi, psg in enumerate(r['context']): 119 | _references.append(f'Passage {pi}: {psg}') 120 | else: 121 | for pi, psg in enumerate(r['context']): 122 | _references.append(psg) 123 | batch_references.append(_references) 124 | 125 | batch_questions.append(r['question']) 126 | batch_responses.append(r['response']) 127 | 128 | if checker_model in ['nli', 'alignscore', 'repc'] and ds != 'nq': 129 | max_reference_segment_length = 200 130 | else: 131 | max_reference_segment_length = 0 132 | 133 | if len(batch_claims): 134 | print(f'Running Checking on {len(batch_claims)} examples...') 135 | if checker is None: 136 | checker = _get_checker(checker_model) 137 | 138 | checking_results = checker.check( 139 | batch_claims=batch_claims, 140 | batch_references=batch_references, 141 | batch_questions=batch_questions, 142 | batch_responses=batch_responses, 143 | max_reference_segment_length=max_reference_segment_length, 144 | merge_psg=True 145 | ) 146 | _i = 0 147 | for r in response_data: 148 | d = id_to_data[r['id']] 149 | if kg_key in r: 150 | claims = [c['claim'] for c in r[kg_key] if label_key not in c] 151 | if len(claims): 152 | labels = checking_results[_i] 153 | _j = 0 154 | for claim in r[kg_key]: 155 | if label_key not in claim: 156 | claim[label_key] = labels[_j] 157 | _j += 1 158 | _i += 1 159 | json.dump(response_data, open(output_filename, 'w'), indent=4) 160 | 161 | cnt = 0 162 | for r in response_data: 163 | if kg_key in r: 164 | is_checking_finished = True 165 | for t in r[kg_key]: 166 | if label_key not in t: 167 | is_checking_finished = False 168 | break 169 | if is_checking_finished: 170 | cnt += 1 171 | 172 | print(f'{setting}: {cnt} finished.') 173 | 174 | 175 | def main(): 176 | if not os.path.exists(args.output_dir): 177 | os.mkdir(args.output_dir) 178 | 179 | autocheck(args.extractor, args.checker) 180 | 181 | 182 | if __name__ == '__main__': 183 | parser = argparse.ArgumentParser() 184 | parser.add_argument('--model', type=str) 185 | parser.add_argument('--extractor', type=str) 186 | parser.add_argument('--claim_format', type=str, choices=['triplet', 'subsentence']) 187 | parser.add_argument('--checker', type=str) 188 | parser.add_argument('--api_base', type=str) 189 | parser.add_argument('--batch_size', type=int, default=100) 190 | parser.add_argument('--output_dir', type=str) 191 | 192 | args = parser.parse_args() 193 | 194 | main() -------------------------------------------------------------------------------- /benchmark/evaluation/corr.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | from 
sklearn.metrics import f1_score, accuracy_score 4 | from scipy import stats 5 | 6 | 7 | def main(): 8 | print(f'===== {args.extractor} + {args.checker} ====') 9 | for setting, ds in zip( 10 | ["zero_context", "noisy_context", "accurate_context"], 11 | ["nq", "msmarco", "dolly"] 12 | ): 13 | gt_factual_list = [] 14 | pred_factual_list = [] 15 | 16 | gt_halu_rates = [] 17 | pred_halu_rates = [] 18 | # for llm in ["alpaca_7B", "chatgpt", "claude2", "davinci001", "falcon_40B_instruct", "gpt4", "llama2_70b_chat"]: 19 | for llm in ["alpaca_7B", "chatgpt", "claude2", "falcon_40B_instruct", "gpt4", "llama2_70b_chat"]: 20 | response_file = f'{args.data_dir}/{setting}/{ds}_{llm}_answers.json' 21 | response_data = json.load(open(response_file)) 22 | 23 | kg_key = f'{args.extractor}_response_kg' 24 | label_key = f'{args.checker}_label' 25 | for r in response_data: 26 | if 'claude2_response_kg' not in r or len(r['claude2_response_kg']) == 0: 27 | continue 28 | 29 | gt_factual = all([t['human_label'] == 'Entailment' for t in r['claude2_response_kg']]) 30 | gt_factual_list.append(gt_factual) 31 | gt_halu_rates.append(len([c for c in r['claude2_response_kg'] if c['human_label'] != 'Entailment']) / len(r['claude2_response_kg'])) 32 | 33 | assert kg_key in r 34 | pred_factual = all([c[label_key] == 'Entailment' for c in r[kg_key]]) 35 | pred_factual_list.append(pred_factual) 36 | if len(r[kg_key]): 37 | pred_halu_rates.append(len([c for c in r[kg_key] if c[label_key] != 'Entailment']) / len(r[kg_key])) 38 | else: 39 | pred_halu_rates.append(0) 40 | 41 | print(f'{setting}') 42 | acc = round(accuracy_score(gt_factual_list, pred_factual_list) * 100, 2) 43 | fact_f1 = round(f1_score(gt_factual_list, pred_factual_list, pos_label=True) * 100, 2) 44 | nonfact_f1 = round(f1_score(gt_factual_list, pred_factual_list, pos_label=False) * 100, 2) 45 | pearson = round(stats.pearsonr(gt_halu_rates, pred_halu_rates).statistic * 100, 2) 46 | spearman = round(stats.spearmanr(gt_halu_rates, pred_halu_rates).statistic * 100, 2) 47 | print(f'Acc: {acc}\tFact. F1: {fact_f1}\tNonFact. F1: {nonfact_f1}\tPearson: {pearson}\tSpearman: {spearman}') 48 | # print(f'Fact. F1: {fact_f1}') 49 | # print(f'NonFact. 
F1: {nonfact_f1}') 50 | # print(f'Pearson: {pearson}') 51 | # print(f'Spearman: {spearman}') 52 | print(f'=========================================\n') 53 | 54 | 55 | if __name__ == "__main__": 56 | parser = argparse.ArgumentParser() 57 | parser.add_argument('--extractor', type=str) 58 | parser.add_argument('--checker', type=str) 59 | parser.add_argument('--data_dir', type=str) 60 | 61 | args = parser.parse_args() 62 | 63 | main() 64 | # python evaluation/corr.py --extractor=bedrock/meta.llama3-70b-instruct-v1:0 --checker=bedrock/meta.llama3-70b-instruct-v1:0 --data_dir=triplet_llama3 65 | # python evaluation/corr.py --extractor=bedrock/meta.llama3-70b-instruct-v1:0 --checker=alignscore --data_dir=triplet_llama3 66 | # python evaluation/corr.py --extractor=bedrock/meta.llama3-70b-instruct-v1:0 --checker=gpt-4-turbo --data_dir=triplet_llama3 67 | 68 | # python evaluation/corr.py --extractor=bedrock/meta.llama3-70b-instruct-v1:0 --checker=bedrock/meta.llama3-70b-instruct-v1:0 --data_dir=subsent_llama3 69 | # python evaluation/corr.py --extractor=bedrock/meta.llama3-70b-instruct-v1:0 --checker=alignscore --data_dir=subsent_llama3 70 | # python evaluation/corr.py --extractor=bedrock/meta.llama3-70b-instruct-v1:0 --checker=gpt-4-turbo --data_dir=subsent_llama3 -------------------------------------------------------------------------------- /benchmark/evaluation/evaluate.py: -------------------------------------------------------------------------------- 1 | import json 2 | import numpy as np 3 | import argparse 4 | from collections import Counter 5 | 6 | 7 | def get_evaluation_results(model, extractor, checker): 8 | ret = dict() 9 | avg_abstain_list = [] 10 | avg_contra_list = [] 11 | avg_entail_list = [] 12 | avg_neutral_list = [] 13 | 14 | for setting in ['zero_context', 'noisy_context', 'accurate_context']: 15 | contra_list = [] 16 | entail_list = [] 17 | neutral_list = [] 18 | 19 | abstain_cnt = 0 20 | response_data = json.load(open(f'data/{setting}/{setting}_{model}_answers.json')) 21 | for r in response_data: 22 | c, e, n = 0, 0, 0 23 | 24 | if f'{extractor}_response_kg' not in r: 25 | n_triplets = 0 26 | else: 27 | n_triplets = len(r[f'{extractor}_response_kg']) 28 | if n_triplets == 0: 29 | # abstain response 30 | abstain_cnt += 1 31 | else: 32 | # non abstain response 33 | if checker == 'ensemble': 34 | for t in r[f'{extractor}_response_kg']: 35 | v = Counter([t['gpt4_label'], t['claude2_label'], t['nli_label']]).most_common(1)[0][0] 36 | if v == 'Entailment': 37 | e += 1 38 | elif v == 'Neutral': 39 | n += 1 40 | elif v == 'Contradiction': 41 | c += 1 42 | else: 43 | n += 1 44 | else: 45 | for v in [x[f'{checker}_label'] for x in r[f'{extractor}_response_kg']]: 46 | if v == 'Entailment': 47 | e += 1 48 | elif v == 'Neutral': 49 | n += 1 50 | elif v == 'Contradiction': 51 | c += 1 52 | else: 53 | n += 1 54 | assert e + n + c == n_triplets, r[f'{extractor}_response_kg'] 55 | contra_list.append(c / n_triplets) 56 | entail_list.append(e / n_triplets) 57 | neutral_list.append(n / n_triplets) 58 | abstain_rate = abstain_cnt / len(response_data) 59 | 60 | ret[setting] = { 61 | 'abstain': abstain_rate * 100, 62 | 'entailment': np.mean(entail_list) * 100, 63 | 'neutral': np.mean(neutral_list) * 100, 64 | 'contradiction': np.mean(contra_list) * 100 65 | } 66 | avg_entail_list += entail_list 67 | avg_neutral_list += neutral_list 68 | avg_contra_list += contra_list 69 | avg_abstain_list.append(abstain_rate) 70 | 71 | ret['avg'] = { 72 | 'abstain': np.mean(avg_abstain_list) * 100, 73 | 
'entailment': np.mean(avg_entail_list) * 100, 74 | 'neutral': np.mean(avg_neutral_list) * 100, 75 | 'contradiction': np.mean(avg_contra_list) * 100 76 | } 77 | 78 | return ret 79 | 80 | 81 | if __name__ == '__main__': 82 | parser = argparse.ArgumentParser() 83 | parser.add_argument('--model', type=str) 84 | parser.add_argument('--extractor', type=str) 85 | parser.add_argument('--checker', type=str) 86 | parser.add_argument('--output_file', type=str) 87 | 88 | args = parser.parse_args() 89 | 90 | ret = get_evaluation_results(args.model, args.extractor, args.checker) 91 | json.dump(ret, open(args.output_file, 'w'), indent=4) -------------------------------------------------------------------------------- /benchmark/response_collection/README.md: -------------------------------------------------------------------------------- 1 | ## Collecting Responses from Your LLM on the RefChecker Benchmark 2 | 3 | Please take the following steps to collect responses from your LLM on the RefChecker benchmark data. 4 | 5 | 1. Write a class that inherits from `ResponseCollectorBase` in [collector_base.py](collector_base.py). Please check the examples in [gpt4_turbo.py](gpt4_turbo.py), [chatglm3_6b.py](chatglm3_6b.py) and [mistral.py](mistral.py). 6 | 7 | 8 | 2. Modify [main.py](main.py) to add your model there. 9 | 10 | 3. Run the following command to collect responses from your model: 11 | 12 | ```bash 13 | cd benchmark 14 | python response_collection/main.py --model= 15 | ``` 16 | 17 | The file containing responses will be saved to the folders of the different settings in [benchmark/data](../../benchmark/data/). -------------------------------------------------------------------------------- /benchmark/response_collection/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/RefChecker/1df1b25cee792ba2b171302e31ca4f768bd67703/benchmark/response_collection/__init__.py -------------------------------------------------------------------------------- /benchmark/response_collection/chatglm3_6b.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoModel, AutoTokenizer 2 | 3 | from collector_base import ResponseCollectorBase 4 | 5 | 6 | class ChatGLM3(ResponseCollectorBase): 7 | def __init__( 8 | self, 9 | mname, 10 | device='cuda' 11 | ) -> None: 12 | super().__init__(mname, device) 13 | 14 | self.model = AutoModel.from_pretrained("THUDM/chatglm3-6b", trust_remote_code=True).to(self.device) 15 | self.model.eval() 16 | self.tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm3-6b", trust_remote_code=True) 17 | 18 | self.max_contex_length = 8100 19 | 20 | def tokenizer_encode(self, prompt): 21 | return self.tokenizer.encode(prompt, add_special_tokens=False) 22 | 23 | def tokenizer_decode(self, encoded): 24 | return self.tokenizer.decode(encoded) 25 | 26 | def get_response(self, prompt): 27 | res, _ = self.model.chat(self.tokenizer, prompt, history=[], do_sample=False) 28 | res = res.strip() 29 | return res -------------------------------------------------------------------------------- /benchmark/response_collection/collector_base.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | from tqdm import tqdm 4 | 5 | 6 | closed_qa_prompt = """Instruction: Provide a well-formed answer to the question using information from the given context. 
7 | Question: {question} 8 | Context: {context} 9 | """ 10 | 11 | ie_prompt = """Instruction: {question} 12 | Context: {context} 13 | """ 14 | 15 | sum_prompt = """Instruction: {question} 16 | Context: {context} 17 | """ 18 | 19 | 20 | class ResponseCollectorBase: 21 | def __init__( 22 | self, 23 | mname, 24 | device='cuda' 25 | ) -> None: 26 | self.model = None 27 | self.tokenizer = None 28 | 29 | self.mname = mname 30 | self.device = device 31 | 32 | self.max_contex_length = None 33 | self.max_new_tokens = 300 34 | 35 | def collect_response(self): 36 | assert self.max_contex_length is not None 37 | 38 | for setting, ds in zip( 39 | ['zero_context', 'noisy_context', 'accurate_context'], 40 | ["nq", "msmarco", "dolly"] 41 | ): 42 | examples = json.load(open(f'data/{setting}/{ds}.json')) 43 | response_file = f'data/{setting}/{setting}_{self.mname}_answers.json' 44 | if os.path.exists(response_file): 45 | response_data = json.load(open(response_file)) 46 | else: 47 | response_data = [{'id': ex['id']} for ex in examples] 48 | json.dump(response_data, open(response_file, 'w'), indent=4) 49 | 50 | finish_cnt = 0 51 | for ex, r in tqdm(zip(examples, response_data), total=len(examples)): 52 | if 'response' in r: 53 | finish_cnt += 1 54 | continue 55 | input_prompt = self.get_input(ds, ex) 56 | res = self.get_response(input_prompt) 57 | if res and len(res): 58 | r['input'] = input_prompt 59 | r['response'] = res 60 | json.dump(response_data, open(response_file, 'w'), indent=4) 61 | finish_cnt += 1 62 | print(f'{setting}: {finish_cnt} responses collected.') 63 | 64 | def get_input(self, split, example): 65 | if split == 'nq': 66 | return example['question'] 67 | elif split == 'msmarco': 68 | tail = '\nAnswer: \n' 69 | 70 | prompt = f'Please answer the following question based on the provided passages.\n\nQuestion: {example["question"]}?\n\nPassages:\n' 71 | for i, p in enumerate(example['context']): 72 | prompt += f'Passage {i}: {p}\n' 73 | prompt_encoded = self.tokenizer_encode(prompt) 74 | if prompt_encoded: 75 | tail_length = len(self.tokenizer_encode(tail)) 76 | 77 | prompt_encoded = prompt_encoded[:self.max_contex_length - self.max_new_tokens - tail_length] 78 | prompt_truncated = self.tokenizer_decode(prompt_encoded) 79 | prompt_truncated += tail 80 | return prompt_truncated 81 | return prompt + tail 82 | elif split == 'dolly': 83 | if example['category'] == 'closed_qa': 84 | prompt = closed_qa_prompt.format(**{'question': example['question'], 'context': example['context'][0]}) 85 | elif example['category'] == 'information_extraction': 86 | prompt = closed_qa_prompt.format(**{'question': example['question'], 'context': example['context'][0]}) 87 | elif example['category'] == 'summarization': 88 | prompt = sum_prompt.format(**{'question': example['question'], 'context': example['context'][0]}) 89 | prompt_encoded = self.tokenizer_encode(prompt) 90 | if prompt_encoded: 91 | prompt_encoded = prompt_encoded[:self.max_contex_length - self.max_new_tokens] 92 | prompt_truncated = self.tokenizer_decode(prompt_encoded) 93 | return prompt_truncated 94 | return prompt 95 | 96 | def tokenizer_encode(self, prompt): 97 | raise NotImplementedError 98 | 99 | def tokenizer_decode(self, encoded): 100 | raise NotImplementedError 101 | 102 | def get_response(self, prompt): 103 | raise NotImplementedError 104 | -------------------------------------------------------------------------------- /benchmark/response_collection/gpt4_turbo.py: 
-------------------------------------------------------------------------------- 1 | import tiktoken 2 | 3 | 4 | from refchecker.utils import get_openai_model_response 5 | from collector_base import ResponseCollectorBase 6 | 7 | class GPT4Turbo(ResponseCollectorBase): 8 | def __init__(self, mname, device='cuda') -> None: 9 | super().__init__(mname, device) 10 | 11 | self.tokenizer = tiktoken.encoding_for_model("gpt-4-1106-preview") 12 | self.max_contex_length = 127000 13 | 14 | def tokenizer_encode(self, prompt): 15 | return self.tokenizer.encode(prompt) 16 | 17 | def tokenizer_decode(self, encoded): 18 | return self.tokenizer.decode(encoded) 19 | 20 | def get_response(self, prompt): 21 | res = get_openai_model_response(prompt, temperature=0, model='gpt-4-1106-preview') 22 | return res.strip() -------------------------------------------------------------------------------- /benchmark/response_collection/main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | 4 | def get_model(): 5 | if args.model == 'mistral_7b': 6 | from mistral import Mistral 7 | model = Mistral(mname=args.model) 8 | elif args.model == 'chatglm3_6b': 9 | from chatglm3_6b import ChatGLM3 10 | model = ChatGLM3(mname=args.model) 11 | elif args.model == 'gpt4_turbo': 12 | from gpt4_turbo import GPT4Turbo 13 | model = GPT4Turbo(mname=args.model) 14 | return model 15 | 16 | def main(): 17 | model = get_model() 18 | model.collect_response() 19 | 20 | 21 | if __name__ == '__main__': 22 | parser = argparse.ArgumentParser() 23 | parser.add_argument('--model', type=str, choices=['mistral_7b', 'chatglm3_6b', 'gpt4_turbo']) 24 | 25 | args = parser.parse_args() 26 | 27 | main() 28 | -------------------------------------------------------------------------------- /benchmark/response_collection/mistral.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoModelForCausalLM, AutoTokenizer 2 | from vllm import LLM, SamplingParams 3 | 4 | from collector_base import ResponseCollectorBase 5 | 6 | 7 | class Mistral(ResponseCollectorBase): 8 | def __init__( 9 | self, 10 | mname, 11 | device='cuda' 12 | ) -> None: 13 | super().__init__(mname, device) 14 | 15 | self.tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1") 16 | 17 | self.llm = LLM( 18 | model='mistralai/Mistral-7B-Instruct-v0.1', 19 | tokenizer='mistralai/Mistral-7B-Instruct-v0.1', 20 | trust_remote_code=True, 21 | tensor_parallel_size=8) 22 | 23 | self.max_contex_length = 8100 24 | 25 | self.sampling_params = SamplingParams(temperature=0, max_tokens=self.max_new_tokens) 26 | 27 | def tokenizer_encode(self, prompt): 28 | return self.tokenizer.encode(prompt, add_special_tokens=False) 29 | 30 | def tokenizer_decode(self, encoded): 31 | return self.tokenizer.decode(encoded) 32 | 33 | def get_response(self, prompt): 34 | prompt = f'[INST] {prompt} [/INST]' 35 | 36 | res = self.llm.generate([prompt], self.sampling_params, use_tqdm=False)[0].outputs[0].text 37 | res = res.strip() 38 | return res -------------------------------------------------------------------------------- /demo/README.md: -------------------------------------------------------------------------------- 1 | ## RefChecker Demo 2 | 3 | You can run the RefChecker on your server. 4 | 5 | ### Export API Keys 6 | 7 | To run the demo, we should first set the relevant API keys for the extractor and checker. 8 | 9 | - If we use OpenAI models (i.e. 
GPT-4), run the following command: 10 | ```bash 11 | export OPENAI_API_KEY= 12 | ``` 13 | 14 | - To use Claude 2, if we have an Anthropic API Key, run: 15 | ```bash 16 | export ANTHROPIC_API_KEY= 17 | ``` 18 | 19 | If we are using Claude 2 on AWS Bedrock and running the demo on AWS, we need to export the region of the server: 20 | 21 | ```bash 22 | export aws_bedrock_region= 23 | ``` 24 | 25 | - If we want to use Google search to find references, export the Serper API key: 26 | 27 | ```bash 28 | export SERPER_API_KEY= 29 | ``` 30 | 31 | ### Run the Demo 32 | 33 | Execute the following command to run the demo: 34 | 35 | ```bash 36 | streamlit run demo/main.py \ 37 | --server.port={port} -- \ # set the deploy port 38 | --enable_search # enable Google search 39 | ``` 40 | 41 | It will print the URL of the demo, you can open it in your browser to interact with the demo. 42 | -------------------------------------------------------------------------------- /demo/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/RefChecker/1df1b25cee792ba2b171302e31ca4f768bd67703/demo/__init__.py -------------------------------------------------------------------------------- /example/check.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | refchecker-cli check \ 3 | --input_path example/example_out_claims.json \ 4 | --output_path example/example_out.json \ 5 | --checker_name bedrock/anthropic.claude-3-sonnet-20240229-v1:0 \ 6 | --aggregator_name soft 7 | -------------------------------------------------------------------------------- /example/example_in.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "reference": "`` I Dig Rock and Roll Music '' is a 1967 song by the American folk group Peter , Paul and Mary , written by Paul Stookey , James Mason and Dave Dixon .", 4 | "question": "who sang i dig rock and roll music", 5 | "response": "The song \"I Dig Rock and Roll Music\" was originally performed by Chuck Berry in 1957. However, it has since been covered by many artists, including The Beatles, The Rolling Stones, and Led Zeppelin." 6 | }, 7 | { 8 | "reference": "The Impalas were an American doo - wop group in the late 1950s , best known for their hit , `` Sorry ( I Ran All the Way Home ) '' .", 9 | "question": "who sang i ran all the way home", 10 | "response": "The song \"I Ran All the Way Home\" was originally performed by the American rock band The Black Eyed Peas. The song was released as a single in 2003 and became a hit, reaching the top of the Billboard Hot 100 chart in the United States." 11 | }, 12 | { 13 | "reference": "`The Parliament of India is bicameral . Concurrence of both houses are required to pass any bill . However , the authors of the Constitution of India visualised situations of deadlock between the upper house i.e. Rajya Sabha and the lower house i.e. Lok Sabha . Therefore , the Constitution of India provides for Joint sittings of both the Houses to break this deadlock . The joint sitting of the Parliament is called by the President and is presided over by the Speaker or , in his absence , by the Deputy Speaker of the Lok Sabha or in his absence , the Deputy - Chairman of the Rajya Sabha . 
If any of the above officers are not present then any other member of the Parliament can preside by consensus of both the House .", 14 | "question": "who presides over the joint sessions of parliament", 15 | "response": "The Speaker of the House of Representatives presides over the joint sessions of Congress. The Speaker is chosen by the majority party in the House of Representatives and is responsible for conducting the business of the House. During joint sessions, the Speaker acts as the presiding officer and is responsible for maintaining order and ensuring that the proceedings run smoothly." 16 | } 17 | ] 18 | -------------------------------------------------------------------------------- /example/extract-check.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | refchecker-cli extract-check \ 3 | --input_path example/example_in.json \ 4 | --output_path example/example_out.json \ 5 | --extractor_name bedrock/anthropic.claude-3-sonnet-20240229-v1:0 \ 6 | --extractor_max_new_tokens 1000 \ 7 | --checker_name bedrock/anthropic.claude-3-sonnet-20240229-v1:0 \ 8 | --aggregator_name soft -------------------------------------------------------------------------------- /example/extract-check_wo_ref.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # using google retriever requires access to claude2 for query generation 3 | refchecker-cli extract-check \ 4 | --input_path example/example_in.json \ 5 | --output_path example/example_out.json \ 6 | --extractor_name bedrock/anthropic.claude-3-sonnet-20240229-v1:0 \ 7 | --extractor_max_new_tokens 1000 \ 8 | --checker_name bedrock/anthropic.claude-3-sonnet-20240229-v1:0 \ 9 | --aggregator_name soft \ 10 | --use_retrieval -------------------------------------------------------------------------------- /example/extract.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | refchecker-cli extract \ 3 | --input_path example/example_in.json \ 4 | --output_path example/example_out_claims.json \ 5 | --extractor_name bedrock/anthropic.claude-3-sonnet-20240229-v1:0 \ 6 | --extractor_max_new_tokens 1000 7 | -------------------------------------------------------------------------------- /imgs/demo.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/RefChecker/1df1b25cee792ba2b171302e31ca4f768bd67703/imgs/demo.gif -------------------------------------------------------------------------------- /imgs/evaluation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/RefChecker/1df1b25cee792ba2b171302e31ca4f768bd67703/imgs/evaluation.png -------------------------------------------------------------------------------- /imgs/framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/RefChecker/1df1b25cee792ba2b171302e31ca4f768bd67703/imgs/framework.png -------------------------------------------------------------------------------- /imgs/localization_example_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/RefChecker/1df1b25cee792ba2b171302e31ca4f768bd67703/imgs/localization_example_1.jpg -------------------------------------------------------------------------------- 
/imgs/localization_example_2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/RefChecker/1df1b25cee792ba2b171302e31ca4f768bd67703/imgs/localization_example_2.jpg -------------------------------------------------------------------------------- /imgs/localization_example_3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/RefChecker/1df1b25cee792ba2b171302e31ca4f768bd67703/imgs/localization_example_3.jpg -------------------------------------------------------------------------------- /imgs/localization_example_4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/RefChecker/1df1b25cee792ba2b171302e31ca4f768bd67703/imgs/localization_example_4.jpg -------------------------------------------------------------------------------- /imgs/localization_method.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/RefChecker/1df1b25cee792ba2b171302e31ca4f768bd67703/imgs/localization_method.jpg -------------------------------------------------------------------------------- /imgs/settings.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/RefChecker/1df1b25cee792ba2b171302e31ca4f768bd67703/imgs/settings.png -------------------------------------------------------------------------------- /imgs/venn.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/RefChecker/1df1b25cee792ba2b171302e31ca4f768bd67703/imgs/venn.png -------------------------------------------------------------------------------- /notebooks/refchecker_usages.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import os\n", 10 | "\n", 11 | "# Modify the following values accordingly\n", 12 | "os.environ[\"AWS_ACCESS_KEY_ID\"] = \"\"\n", 13 | "os.environ[\"AWS_SECRET_ACCESS_KEY\"] = \"\"\n", 14 | "os.environ[\"AWS_REGION_NAME\"] = \"\"\n", 15 | "\n", 16 | "\n", 17 | "question = \"\"\"What's the longest river in the world?\"\"\"\n", 18 | "\n", 19 | "response1 = \"\"\"\n", 20 | "The longest river in the world is the Nile River, located in northeastern Africa. \n", 21 | "It stretches for approximately 6,853 kilometers (4,258 miles) from its sources in Burundi, Rwanda, and Tanzania to its delta on the Mediterranean Sea in Egypt.\n", 22 | "\"\"\"\n", 23 | "\n", 24 | "response2 = \"\"\"\n", 25 | "The longest river in the world is the Nile River. \n", 26 | "It flows northward through northeastern Africa for approximately 6,650 kilometers (4,132 miles) from its most distant source in the African Great Lakes region to the Mediterranean Sea.\n", 27 | "\"\"\"\n", 28 | "\n", 29 | "reference = \"\"\"\n", 30 | "The Nile is a major north-flowing river in northeastern Africa. \n", 31 | "It flows into the Mediterranean Sea. The Nile is the longest river in Africa and has historically been considered the longest river in the world, though this has been contested by research suggesting that the Amazon River is slightly longer. 
\n", 32 | "Of the world's major rivers, the Nile is one of the smallest, as measured by annual flow in cubic metres of water. \n", 33 | "About 6,650 km (4,130 mi) long, its drainage basin covers eleven countries: the Democratic Republic of the Congo, Tanzania, Burundi, Rwanda, Uganda, Kenya, Ethiopia, Eritrea, South Sudan, Sudan, and Egypt. \n", 34 | "In particular, the Nile is the primary water source of Egypt, Sudan and South Sudan. \n", 35 | "Additionally, the Nile is an important economic river, supporting agriculture and fishing.\n", 36 | "\"\"\"" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": 2, 42 | "metadata": {}, 43 | "outputs": [ 44 | { 45 | "name": "stdout", 46 | "output_type": "stream", 47 | "text": [ 48 | "Warning: scikit-learn-intelex not installed, sklearn acceleration for the RepC checker is not enabled.\n" 49 | ] 50 | }, 51 | { 52 | "name": "stderr", 53 | "output_type": "stream", 54 | "text": [ 55 | "100%|██████████| 1/1 [00:02<00:00, 2.80s/it]" 56 | ] 57 | }, 58 | { 59 | "name": "stdout", 60 | "output_type": "stream", 61 | "text": [ 62 | "Claims in Response 1:\n", 63 | "['The longest river in the world', 'is', 'the Nile River']\n", 64 | "['The Nile River', 'located in', 'northeastern Africa']\n", 65 | "['The Nile River', 'stretches for', 'approximately 6,853 kilometers (4,258 miles)']\n", 66 | "['The Nile River', 'has sources in', 'Burundi']\n", 67 | "['The Nile River', 'has sources in', 'Rwanda']\n", 68 | "['The Nile River', 'has sources in', 'Tanzania']\n", 69 | "['The Nile River', 'has delta on', 'the Mediterranean Sea']\n", 70 | "['The Nile River delta', 'located in', 'Egypt']\n", 71 | "----\n", 72 | "Claims in Response 2:\n", 73 | "['The longest river in the world', 'is', 'the Nile River']\n", 74 | "['the Nile River', 'flows', 'northward']\n", 75 | "['the Nile River', 'flows through', 'northeastern Africa']\n", 76 | "['the Nile River', 'has length of approximately', '6,650 kilometers (4,132 miles)']\n", 77 | "['the Nile River', 'originates from', 'its most distant source in the African Great Lakes region']\n", 78 | "['the Nile River', 'ends at', 'the Mediterranean Sea']\n", 79 | "['the African Great Lakes region', 'is source of', 'the Nile River']\n", 80 | "['the Mediterranean Sea', 'is destination of', 'the Nile River']\n", 81 | "----\n" 82 | ] 83 | }, 84 | { 85 | "name": "stderr", 86 | "output_type": "stream", 87 | "text": [ 88 | "\n" 89 | ] 90 | } 91 | ], 92 | "source": [ 93 | "from refchecker import LLMExtractor\n", 94 | "\n", 95 | "# claim extraction\n", 96 | "extractor = LLMExtractor(\n", 97 | " claim_format='triplet', \n", 98 | " model='bedrock/anthropic.claude-3-sonnet-20240229-v1:0',\n", 99 | " batch_size=8\n", 100 | ")\n", 101 | "\n", 102 | "# each element in claims is an instance of Claim\n", 103 | "extraction_results = extractor.extract(\n", 104 | " batch_responses=[response1, response2],\n", 105 | " max_new_tokens=1000\n", 106 | ")\n", 107 | "for i, res in enumerate(extraction_results):\n", 108 | " print(f'Claims in Response {i+1}:')\n", 109 | " for claim in res.claims:\n", 110 | " print(claim.content)\n", 111 | " print('----')" 112 | ] 113 | }, 114 | { 115 | "cell_type": "code", 116 | "execution_count": 3, 117 | "metadata": {}, 118 | "outputs": [ 119 | { 120 | "name": "stderr", 121 | "output_type": "stream", 122 | "text": [ 123 | "100%|██████████| 1/1 [00:01<00:00, 1.55s/it]" 124 | ] 125 | }, 126 | { 127 | "name": "stdout", 128 | "output_type": "stream", 129 | "text": [ 130 | "Checking results for Response 1:\n", 131 | "['The 
longest river in the world', 'is', 'the Nile River'] --> Contradiction\n", 132 | "['The Nile River', 'located in', 'northeastern Africa'] --> Entailment\n", 133 | "['The Nile River', 'stretches for', 'approximately 6,853 kilometers (4,258 miles)'] --> Entailment\n", 134 | "['The Nile River', 'has sources in', 'Burundi'] --> Entailment\n", 135 | "['The Nile River', 'has sources in', 'Rwanda'] --> Entailment\n", 136 | "['The Nile River', 'has sources in', 'Tanzania'] --> Entailment\n", 137 | "['The Nile River', 'has delta on', 'the Mediterranean Sea'] --> Entailment\n", 138 | "['The Nile River delta', 'located in', 'Egypt'] --> Entailment\n", 139 | "---\n", 140 | "Checking results for Response 2:\n", 141 | "['The longest river in the world', 'is', 'the Nile River'] --> Contradiction\n", 142 | "['the Nile River', 'flows', 'northward'] --> Entailment\n", 143 | "['the Nile River', 'flows through', 'northeastern Africa'] --> Entailment\n", 144 | "['the Nile River', 'has length of approximately', '6,650 kilometers (4,132 miles)'] --> Entailment\n", 145 | "['the Nile River', 'originates from', 'its most distant source in the African Great Lakes region'] --> Entailment\n", 146 | "['the Nile River', 'ends at', 'the Mediterranean Sea'] --> Entailment\n", 147 | "['the African Great Lakes region', 'is source of', 'the Nile River'] --> Entailment\n", 148 | "['the Mediterranean Sea', 'is destination of', 'the Nile River'] --> Entailment\n", 149 | "---\n" 150 | ] 151 | }, 152 | { 153 | "name": "stderr", 154 | "output_type": "stream", 155 | "text": [ 156 | "\n" 157 | ] 158 | } 159 | ], 160 | "source": [ 161 | "from refchecker import LLMChecker\n", 162 | "\n", 163 | "checker = LLMChecker(model='bedrock/anthropic.claude-3-sonnet-20240229-v1:0')\n", 164 | "\n", 165 | "batch_claims = []\n", 166 | "for res in extraction_results:\n", 167 | " batch_claims.append([claim.content for claim in res.claims])\n", 168 | "\n", 169 | "batch_reference = [reference] * len(batch_claims)\n", 170 | "\n", 171 | "checking_results = checker.check(\n", 172 | " batch_claims=batch_claims,\n", 173 | " batch_references=batch_reference,\n", 174 | " max_reference_segment_length=0\n", 175 | ")\n", 176 | "\n", 177 | "for i, (extract_res, check_res) in enumerate(zip(extraction_results, checking_results)):\n", 178 | " print(f'Checking results for Response {i+1}:')\n", 179 | " for claim, pred_label in zip(extract_res.claims, check_res):\n", 180 | " print(f'{claim.content} --> {pred_label}')\n", 181 | " print('---')" 182 | ] 183 | } 184 | ], 185 | "metadata": { 186 | "kernelspec": { 187 | "display_name": "rc", 188 | "language": "python", 189 | "name": "python3" 190 | }, 191 | "language_info": { 192 | "codemirror_mode": { 193 | "name": "ipython", 194 | "version": 3 195 | }, 196 | "file_extension": ".py", 197 | "mimetype": "text/x-python", 198 | "name": "python", 199 | "nbconvert_exporter": "python", 200 | "pygments_lexer": "ipython3", 201 | "version": "3.10.14" 202 | } 203 | }, 204 | "nbformat": 4, 205 | "nbformat_minor": 2 206 | } 207 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "refchecker" 3 | version = "0.2.17" 4 | description = "RefChecker provides automatic checking pipeline for detecting fine-grained hallucinations generated by Large Language Models." 
5 | authors = [ 6 | "Xiangkun Hu ", 7 | "Dongyu Ru ", 8 | "Qipeng Guo ", 9 | "Lin Qiu ", 10 | "Zheng Zhang " 11 | ] 12 | readme = "README.md" 13 | license = "Apache-2.0" 14 | 15 | [tool.poetry.dependencies] 16 | python = "^3.10" 17 | spacy = "^3.7" 18 | boto3 = "^1.35" 19 | torch = "^2" 20 | transformers = "^4.41" 21 | rank-bm25 = "^0.2" 22 | beautifulsoup4 = "^4.12" 23 | anthropic = "^0.29" 24 | plotly = "^5.22" 25 | nltk = "^3.8" 26 | pytorch_lightning = "^2.3" # for alignscore 27 | scikit-learn = "^1.5" 28 | accelerate = "^0.31" 29 | litellm = "^1.49" 30 | diskcache = "^5" 31 | 32 | # optional dependencies required by specific modules 33 | scikit-learn-intelex = { version = "^2024.1.0", optional = true } 34 | vllm = { version = "^0.5", optional = true } 35 | 36 | [tool.poetry.extras] 37 | repcex = ["scikit-learn-intelex"] 38 | open-extractor = ["vllm"] 39 | 40 | [tool.poetry.scripts] 41 | refchecker-cli = "refchecker.cli:main" 42 | 43 | 44 | [build-system] 45 | requires = ["poetry-core"] 46 | build-backend = "poetry.core.masonry.api" 47 | -------------------------------------------------------------------------------- /refchecker/__init__.py: -------------------------------------------------------------------------------- 1 | from .checker import LLMChecker 2 | from .extractor import LLMExtractor -------------------------------------------------------------------------------- /refchecker/aggregator.py: -------------------------------------------------------------------------------- 1 | from collections import Counter 2 | 3 | 4 | def soft_agg(results): 5 | """Aggregate results by taking the ratio of each category.""" 6 | if not results: 7 | return { 8 | "Entailment": 0.0, 9 | "Neutral": 0.0, 10 | "Contradiction": 0.0, 11 | "Abstain": 1.0, 12 | } 13 | 14 | if all(len(result) == 1 for result in results): 15 | for i in range(len(results)): 16 | if len(results[i]) == 1: 17 | results[i] = results[i][0] 18 | 19 | total = len(results) 20 | agg = { 21 | "Entailment": 0.0, 22 | "Neutral": 0.0, 23 | "Contradiction": 0.0, 24 | "Abstain": 0.0, 25 | } 26 | for result in results: 27 | agg[result] += 1.0 28 | for key in agg: 29 | agg[key] /= total 30 | return agg 31 | 32 | 33 | def strict_agg(results): 34 | """Aggregate results by zero-tolerance on negative labels.""" 35 | if not results: 36 | return "Abstain" 37 | 38 | if all(len(result) == 1 for result in results): 39 | for i in range(len(results)): 40 | if len(results[i]) == 1: 41 | results[i] = results[i][0] 42 | 43 | ret = "Entailment" 44 | for result in results: 45 | if result == "Contradiction": 46 | return "Contradiction" 47 | if result == "Neutral": 48 | ret = "Neutral" 49 | return ret 50 | 51 | 52 | def major_agg(results): 53 | """Aggregate results by majority vote.""" 54 | if not results: 55 | return "Abstain" 56 | 57 | if all(len(result) == 1 for result in results): 58 | for i in range(len(results)): 59 | if len(results[i]) == 1: 60 | results[i] = results[i][0] 61 | 62 | agg = Counter(results) 63 | return agg.most_common(1)[0][0] 64 | -------------------------------------------------------------------------------- /refchecker/base.py: -------------------------------------------------------------------------------- 1 | import re 2 | from typing import List, Union 3 | 4 | from spacy.lang.en import English 5 | 6 | 7 | class RCSentence: 8 | def __init__(self, sentence_text, is_blank, start=None, end=None) -> None: 9 | self.text = sentence_text 10 | self.is_blank = is_blank 11 | self.start = start 12 | self.end = end 13 | 14 | def __repr__(self) 
-> str: 15 | return self.text 16 | 17 | def to_dict(self): 18 | return {'text': self.text, 'is_blank': self.is_blank, 'start': self.start, 'end': self.end} 19 | 20 | @classmethod 21 | def from_dict(cls, sent_dict: dict): 22 | return cls(text=sent_dict['text'], is_blank=sent_dict['is_blank'], start=sent_dict['start'], end=sent_dict['end']) 23 | 24 | 25 | class RCText: 26 | def __init__(self, response_text) -> None: 27 | self.orig_text = response_text 28 | 29 | self.nlp = English() 30 | self.nlp.add_pipe("sentencizer") 31 | 32 | self.sentences = None 33 | self._sent_id_to_index = dict() 34 | self.sentencize() 35 | 36 | self.indexed_response = None 37 | 38 | def sentencize(self): 39 | blanks = [' ', '\n', '\r'] 40 | 41 | sents = [] 42 | for s in self.nlp(self.orig_text).sents: 43 | s_text = s.text 44 | prefix = '' 45 | start_idx = 0 46 | while start_idx < len(s_text): 47 | if s_text[start_idx] in blanks: 48 | prefix += s_text[start_idx] 49 | start_idx += 1 50 | else: 51 | start_idx = start_idx 52 | break 53 | if len(prefix): 54 | sents.append(RCSentence(prefix, is_blank=True)) 55 | 56 | surfix = '' 57 | end_idx = len(s_text) - 1 58 | while end_idx > start_idx: 59 | if s_text[end_idx] in blanks: 60 | surfix = s_text[end_idx] + surfix 61 | end_idx -= 1 62 | else: 63 | break 64 | if len(s_text[start_idx: end_idx+1]): 65 | sents.append(RCSentence(s_text[start_idx: end_idx+1], is_blank=False, start=s.start, end=s.end)) 66 | if len(surfix): 67 | sents.append(RCSentence(surfix, is_blank=True)) 68 | 69 | self.sentences = sents 70 | sent_id = 1 71 | for index, sent in enumerate(self.sentences): 72 | if not sent.is_blank: 73 | self._sent_id_to_index[str(sent_id)] = index 74 | sent_id += 1 75 | 76 | def get_indexed_response(self, condense_newlines: bool): 77 | if self.indexed_response is None: 78 | sent_id = 1 79 | res = '' 80 | for i, s in enumerate(self.sentences): 81 | sent_text = s.text 82 | if condense_newlines: 83 | sent_text = re.sub(r'(\n\s*)+\n', '\n', sent_text) 84 | sent_text = re.sub(r'(\r\s*)+\r', '\r', sent_text) 85 | if s.is_blank: 86 | res += sent_text 87 | else: 88 | res += f'[{sent_id}] {sent_text}' 89 | sent_id += 1 90 | if i < len(self.sentences) - 1: 91 | res += ' ' 92 | self.indexed_response = res 93 | return self.indexed_response 94 | 95 | def get_sentence_by_id(self, sent_id: str): 96 | assert sent_id in self._sent_id_to_index, "Invalid sentence ID" 97 | assert self._sent_id_to_index[sent_id] < len(self.sentences) 98 | return self.sentences[self._sent_id_to_index[sent_id]] 99 | 100 | def get_sentence_ids(self): 101 | return list(self._sent_id_to_index.keys()) 102 | 103 | def to_dict(self): 104 | return { 105 | 'sents': [s.to_dict() for s in self.sentences], 106 | 'sent_id_to_index': self._sent_id_to_index 107 | } 108 | 109 | 110 | class RCClaim: 111 | def __init__( 112 | self, 113 | format: str, 114 | content: Union[str, list], 115 | attributed_sent_ids: List[str] 116 | ) -> None: 117 | self.format = format 118 | self.content = content 119 | self.attributed_sent_ids = attributed_sent_ids 120 | 121 | def __repr__(self) -> str: 122 | if self.format == 'triplet': 123 | return f'("{self.content[0]}", "{self.content[1]}", "{self.content[2]}")' 124 | elif self.format == 'subsentence': 125 | ret = self.content + ' ' 126 | for sid in self.attributed_sent_ids: 127 | ret += f'[{sid}]' 128 | return ret 129 | else: 130 | raise ValueError(f'Unknown Claim Format: {self.format}') 131 | 132 | def get_content(self, preserve_triplet_form=False): 133 | if self.format == 'triplet': 134 | if 
preserve_triplet_form: 135 | return f'("{self.content[0]}", "{self.content[1]}", "{self.content[2]}")' 136 | else: 137 | return f'{self.content[0]} {self.content[1]} {self.content[2]}' 138 | else: 139 | return self.content 140 | 141 | def to_dict(self): 142 | ret = { 143 | 'format': self.format, 144 | 'content': self.content, 145 | 'attributed_sent_ids': self.attributed_sent_ids 146 | } 147 | return ret 148 | 149 | @classmethod 150 | def from_dict(cls, claim_dict: dict): 151 | return cls( 152 | format=claim_dict['format'], 153 | content=claim_dict['content'], 154 | attributed_sent_ids=claim_dict['attributed_sent_ids'] 155 | ) 156 | 157 | 158 | class ExtractionResult: 159 | def __init__( 160 | self, 161 | claims: List[RCClaim], 162 | response: Union[str, RCText], 163 | extractor_response: str = None, 164 | ) -> None: 165 | self.claims = claims 166 | self.response = response 167 | self.extractor_response = extractor_response -------------------------------------------------------------------------------- /refchecker/checker/README.md: -------------------------------------------------------------------------------- 1 | ## Checker 2 | 3 | Our hallucination checkers take as input a list of reference documents from retrieval or provided by users when querying LLMs, output a label list with each element chosen from `["Entailment", "Neutral", "Contradiction"]`. We provide [LLMChecker](llm_checker.py), and [NLIChecker](nli_checker.py) with the usage demonstrated below. 4 | 5 | ```python 6 | >>> from refchecker.checker import NLIChecker 7 | 8 | >>> checker = NLIChecker() 9 | >>> references = [ 10 | "`` I Dreamed a Dream '' is a song from the musical Les Mis\u00e9rables . " 11 | "It is a solo that is sung by the character Fantine during the first act . " 12 | "The music is by Claude - Michel Sch\u00f6nberg , with orchestrations by " 13 | "John Cameron . The English lyrics are by Neil Diamond And Herbert Kretzmer ," 14 | " based on the original French libretto by Alain Boublil and Jean - Marc " 15 | "Natel from the original French production ." 16 | ] # each element is the reference or list of references for each input example. 17 | >>> claims = [[["I Dreamed a Dream", "originally from", "the stage musical Les Mis\u00e9rables"], 18 | ["I Dreamed a Dream", "written by", "Claude-Michel Sch\u00f6nberg and Alain Boublil"], 19 | ["Anne Hathaway", "sang I Dreamed a Dream in", "the 2012 film adaptation of Les Mis\u00e9rables"]]] 20 | # each element is the claims for each input example. 21 | >>> checker.check( 22 | claims, 23 | references 24 | ) # [['Entailment', 'Contradiction', 'Neutral']] 25 | 26 | ``` 27 | 28 | For LLM-based checkers, we query LLMs with the following prompt to get the prediction: 29 | 30 | ``` 31 | I have a claim that made by a language model to a question, please help me for checking whether the claim can be entailed according to the provided reference which is related to the question. 32 | The reference is a list of passages, and the claim is represented as a triplet formatted with ("subject", "predicate", "object"). 33 | 34 | If the claim is supported by ANY passage in the reference, answer 'Entailment'. 35 | If the claim is contradicted with the reference, answer 'Contradiction'. 36 | If the reference is not relevant to the claim or DOES NOT contain information to verify the claim, answer 'Neutral'. 37 | 38 | Please DO NOT use your own knowledge for the judgement, just compare the reference and the claim to get the answer. 
39 | 40 | ### Question: 41 | {question} 42 | 43 | ### Reference: 44 | {reference} 45 | 46 | ### Claim: 47 | {claim} 48 | 49 | Your answer should be only a single word in ['Entailment', 'Neutral', 'Contradiction'] 50 | ``` 51 | 52 | NLI-based checkers conduct pair-wise text classification on premise and hypothesis. We concatenate the question q and the reference R as the premise, and concatenate the three elements in a triplet as the hypothesis. We adopt a pre-trained language model (RoBERTa with 355M parameters) as the encoder, and perform ternary-classification as in the usual NLI setting. 53 | -------------------------------------------------------------------------------- /refchecker/checker/__init__.py: -------------------------------------------------------------------------------- 1 | from .llm_checker import LLMChecker 2 | from .nli_checker import NLIChecker 3 | from .alignscore.alignscore_checker import AlignScoreChecker 4 | -------------------------------------------------------------------------------- /refchecker/checker/alignscore/__init__.py: -------------------------------------------------------------------------------- 1 | from .alignscore import AlignScore -------------------------------------------------------------------------------- /refchecker/checker/alignscore/alignscore.py: -------------------------------------------------------------------------------- 1 | from .inference import Inferencer 2 | from typing import List 3 | 4 | class AlignScore: 5 | def __init__(self, model: str, batch_size: int, device: int, ckpt_path: str, evaluation_mode='nli_sp', verbose=True) -> None: 6 | self.model = Inferencer( 7 | ckpt_path=ckpt_path, 8 | model=model, 9 | batch_size=batch_size, 10 | device=device, 11 | verbose=verbose 12 | ) 13 | self.model.nlg_eval_mode = evaluation_mode 14 | 15 | def score(self, contexts: List[str], claims: List[str]) -> List[float]: 16 | return self.model.nlg_eval(contexts, claims)[1].tolist() -------------------------------------------------------------------------------- /refchecker/checker/alignscore/alignscore_checker.py: -------------------------------------------------------------------------------- 1 | import os 2 | import subprocess 3 | from typing import List, Union 4 | from tqdm import tqdm 5 | 6 | from ..checker_base import CheckerBase 7 | from .inference import Inferencer 8 | from ...base import RCClaim 9 | 10 | import torch 11 | 12 | 13 | LABELS = ["Entailment", "Neutral", "Contradiction"] 14 | 15 | 16 | class AlignScoreChecker(CheckerBase): 17 | def __init__( 18 | self, 19 | ckpt_path='alignscore.ckpt', 20 | device=0, 21 | batch_size=16 22 | ): 23 | """ 24 | Initializes the AlignScoreChecker with the specified checkpoint path, device, and batch size. 25 | 26 | Parameters 27 | ---------- 28 | ckpt_path : str, optional 29 | The path to the AlignScore checkpoint file, defaults to 'alignscore.ckpt'. 30 | device : int, optional 31 | The device to run inference on, defaults to 0. 32 | batch_size : int, optional 33 | The batch size for inference, defaults to 16. 
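        Notes
        -----
        If ckpt_path does not point to an existing file, the AlignScore-large
        checkpoint is downloaded (via wget) from the Hugging Face Hub before the
        RoBERTa-large scorer is instantiated; see _download_ckpt below.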
34 | """ 35 | 36 | super().__init__() 37 | self._download_ckpt(ckpt_path) 38 | self.scorer = Inferencer( 39 | ckpt_path, model="roberta-large", device=device, verbose=False 40 | ) 41 | self.batch_size = batch_size 42 | 43 | def _download_ckpt(self, ckpt_path): 44 | if not os.path.exists(ckpt_path): 45 | url = "https://huggingface.co/yzha/AlignScore/resolve/main/AlignScore-large.ckpt" 46 | command=["wget", "-O", ckpt_path, url] 47 | try: 48 | subprocess.call(command) 49 | except Exception as e: 50 | print(e) 51 | 52 | @torch.no_grad() 53 | def _check( 54 | self, 55 | claims: List[Union[str, List[str]]], 56 | references: List[str], 57 | **kwargs 58 | 59 | ): 60 | """ 61 | Batch checking claims against references. 62 | 63 | Parameters 64 | ---------- 65 | claims : List[Union[str, List[str]]] 66 | List of claims. 67 | references : List[str] 68 | List of reference passages (split according to 'max_reference_segment_length'). 69 | responses : List[str] 70 | List of model response texts. 71 | questions : List[str] 72 | List of questions corresponding to each triplet. 73 | 74 | Returns 75 | ------- 76 | ret : List[str] 77 | List of labels for the checking results. 78 | 79 | """ 80 | 81 | N1, N2 = len(references), len(claims) 82 | assert N1 == N2, f"Batches must be of the same length. {N1} != {N2}" 83 | if isinstance(claims[0], list): 84 | assert len(claims[0]) == 3 85 | claims = [f"{c[0]} {c[1]} {c[2]}" for c in claims] 86 | 87 | batch_preds = [] 88 | for i in tqdm(range(0, len(claims), self.batch_size)): 89 | batch_claims = claims[i:i + self.batch_size] 90 | batch_references = references[i:i + self.batch_size] 91 | scores = self.scorer.inference(premise=batch_references, hypo=batch_claims)[-1] 92 | preds = scores.argmax(dim=-1) 93 | batch_preds.extend(preds) 94 | ret = [LABELS[p] for p in batch_preds] 95 | 96 | return ret 97 | 98 | 99 | if __name__ == "__main__": 100 | checker = AlignScoreChecker() 101 | print(checker._check( 102 | claims=["The dog is cute.", "The dog is cute."], 103 | references=["The dog is cute.", "The dog is not cute."], 104 | response=None, question=None 105 | )) 106 | -------------------------------------------------------------------------------- /refchecker/checker/alignscore/inference.py: -------------------------------------------------------------------------------- 1 | from logging import warning 2 | import spacy 3 | from nltk.tokenize import sent_tokenize 4 | import torch 5 | from .model import BERTAlignModel 6 | from transformers import AutoConfig, AutoTokenizer 7 | import torch.nn as nn 8 | from tqdm import tqdm 9 | 10 | class Inferencer(): 11 | def __init__(self, ckpt_path, model='bert-base-uncased', batch_size=32, device='cuda', verbose=True) -> None: 12 | self.device = device 13 | if ckpt_path is not None: 14 | self.model = BERTAlignModel(model=model).load_from_checkpoint(checkpoint_path=ckpt_path, strict=False).to(self.device) 15 | else: 16 | warning('loading UNTRAINED model!') 17 | self.model = BERTAlignModel(model=model).to(self.device) 18 | self.model.eval() 19 | self.batch_size = batch_size 20 | 21 | self.config = AutoConfig.from_pretrained(model) 22 | self.tokenizer = AutoTokenizer.from_pretrained(model) 23 | self.spacy = spacy.load('en_core_web_sm') 24 | 25 | self.loss_fct = nn.CrossEntropyLoss(reduction='none') 26 | self.softmax = nn.Softmax(dim=-1) 27 | 28 | self.smart_type = 'smart-n' 29 | self.smart_n_metric = 'f1' 30 | 31 | self.disable_progress_bar_in_inference = False 32 | 33 | self.nlg_eval_mode = None # bin, bin_sp, nli, nli_sp 34 | self.verbose 
= verbose 35 | 36 | def inference_example_batch(self, premise: list, hypo: list): 37 | """ 38 | inference a example, 39 | premise: list 40 | hypo: list 41 | using self.inference to batch the process 42 | 43 | SummaC Style aggregation 44 | """ 45 | self.disable_progress_bar_in_inference = True 46 | assert len(premise) == len(hypo), "Premise must has the same length with Hypothesis!" 47 | 48 | out_score = [] 49 | for one_pre, one_hypo in tqdm(zip(premise, hypo), desc="Evaluating", total=len(premise), disable=(not self.verbose)): 50 | out_score.append(self.inference_per_example(one_pre, one_hypo)) 51 | 52 | return None, torch.tensor(out_score), None 53 | 54 | def inference_per_example(self, premise:str, hypo: str): 55 | """ 56 | inference a example, 57 | premise: string 58 | hypo: string 59 | using self.inference to batch the process 60 | """ 61 | def chunks(lst, n): 62 | """Yield successive n-sized chunks from lst.""" 63 | for i in range(0, len(lst), n): 64 | yield ' '.join(lst[i:i + n]) 65 | 66 | premise_sents = sent_tokenize(premise) 67 | premise_sents = premise_sents or [''] 68 | 69 | n_chunk = len(premise.strip().split()) // 350 + 1 70 | n_chunk = max(len(premise_sents) // n_chunk, 1) 71 | premise_sents = [each for each in chunks(premise_sents, n_chunk)] 72 | 73 | hypo_sents = sent_tokenize(hypo) 74 | 75 | premise_sent_mat = [] 76 | hypo_sents_mat = [] 77 | for i in range(len(premise_sents)): 78 | for j in range(len(hypo_sents)): 79 | premise_sent_mat.append(premise_sents[i]) 80 | hypo_sents_mat.append(hypo_sents[j]) 81 | 82 | if self.nlg_eval_mode is not None: 83 | if self.nlg_eval_mode == 'nli_sp': 84 | output_score = self.inference(premise_sent_mat, hypo_sents_mat)[2][:,0] ### use NLI head OR ALIGN head 85 | elif self.nlg_eval_mode == 'bin_sp': 86 | output_score = self.inference(premise_sent_mat, hypo_sents_mat)[1] ### use NLI head OR ALIGN head 87 | elif self.nlg_eval_mode == 'reg_sp': 88 | output_score = self.inference(premise_sent_mat, hypo_sents_mat)[0] ### use NLI head OR ALIGN head 89 | 90 | output_score = output_score.view(len(premise_sents), len(hypo_sents)).max(dim=0).values.mean().item() ### sum or mean depends on the task/aspect 91 | return output_score 92 | 93 | 94 | output_score = self.inference(premise_sent_mat, hypo_sents_mat)[2][:,0] ### use NLI head OR ALIGN head 95 | output_score = output_score.view(len(premise_sents), len(hypo_sents)).max(dim=0).values.mean().item() ### sum or mean depends on the task/aspect 96 | 97 | return output_score 98 | 99 | 100 | def inference(self, premise, hypo): 101 | """ 102 | inference a list of premise and hypo 103 | 104 | Standard aggregation 105 | """ 106 | if isinstance(premise, str) and isinstance(hypo, str): 107 | premise = [premise] 108 | hypo = [hypo] 109 | 110 | batch = self.batch_tokenize(premise, hypo) 111 | output_score_reg = [] 112 | output_score_bin = [] 113 | output_score_tri = [] 114 | 115 | for mini_batch in tqdm(batch, desc="Evaluating", disable=not self.verbose or self.disable_progress_bar_in_inference): 116 | mini_batch = mini_batch.to(self.device) 117 | with torch.no_grad(): 118 | model_output = self.model(mini_batch) 119 | model_output_reg = model_output.reg_label_logits.cpu() 120 | model_output_bin = model_output.seq_relationship_logits # Temperature Scaling / 2.5 121 | model_output_tri = model_output.tri_label_logits 122 | 123 | model_output_bin = self.softmax(model_output_bin).cpu() 124 | model_output_tri = self.softmax(model_output_tri).cpu() 125 | output_score_reg.append(model_output_reg[:,0]) 126 | 
output_score_bin.append(model_output_bin[:,1]) 127 | output_score_tri.append(model_output_tri[:,:]) 128 | 129 | output_score_reg = torch.cat(output_score_reg) 130 | output_score_bin = torch.cat(output_score_bin) 131 | output_score_tri = torch.cat(output_score_tri) 132 | 133 | if self.nlg_eval_mode is not None: 134 | if self.nlg_eval_mode == 'nli': 135 | output_score_nli = output_score_tri[:,0] 136 | return None, output_score_nli, None 137 | elif self.nlg_eval_mode == 'bin': 138 | return None, output_score_bin, None 139 | elif self.nlg_eval_mode == 'reg': 140 | return None, output_score_reg, None 141 | else: 142 | ValueError("unrecognized nlg eval mode") 143 | 144 | 145 | return output_score_reg, output_score_bin, output_score_tri 146 | 147 | def inference_reg(self, premise, hypo): 148 | """ 149 | inference a list of premise and hypo 150 | 151 | Standard aggregation 152 | """ 153 | self.model.is_reg_finetune = True 154 | if isinstance(premise, str) and isinstance(hypo, str): 155 | premise = [premise] 156 | hypo = [hypo] 157 | 158 | batch = self.batch_tokenize(premise, hypo) 159 | output_score = [] 160 | 161 | for mini_batch in tqdm(batch, desc="Evaluating", disable=self.disable_progress_bar_in_inference): 162 | mini_batch = mini_batch.to(self.device) 163 | with torch.no_grad(): 164 | model_output = self.model(mini_batch).seq_relationship_logits.cpu().view(-1) 165 | output_score.append(model_output) 166 | output_score = torch.cat(output_score) 167 | return output_score 168 | 169 | def batch_tokenize(self, premise, hypo): 170 | """ 171 | input premise and hypos are lists 172 | """ 173 | assert isinstance(premise, list) and isinstance(hypo, list) 174 | assert len(premise) == len(hypo), "premise and hypo should be in the same length." 175 | 176 | batch = [] 177 | for mini_batch_pre, mini_batch_hypo in zip(self.chunks(premise, self.batch_size), self.chunks(hypo, self.batch_size)): 178 | try: 179 | mini_batch = self.tokenizer(mini_batch_pre, mini_batch_hypo, truncation='only_first', padding='max_length', max_length=self.tokenizer.model_max_length, return_tensors='pt') 180 | except: 181 | warning('text_b too long...') 182 | mini_batch = self.tokenizer(mini_batch_pre, mini_batch_hypo, truncation=True, padding='max_length', max_length=self.tokenizer.model_max_length, return_tensors='pt') 183 | batch.append(mini_batch) 184 | 185 | return batch 186 | def smart_doc(self, premise: list, hypo: list): 187 | """ 188 | inference a example, 189 | premise: list 190 | hypo: list 191 | using self.inference to batch the process 192 | 193 | SMART Style aggregation 194 | """ 195 | self.disable_progress_bar_in_inference = True 196 | assert len(premise) == len(hypo), "Premise must has the same length with Hypothesis!" 
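        # SMART-style aggregation: each (premise, hypothesis) document pair is scored
        # sentence-by-sentence, and the resulting alignment matrix is combined with either
        # a soft LCS (smart-l) or n-gram style precision/recall (smart-n); see smart_l()
        # and smart_n() below.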
197 | assert self.smart_type in ['smart-n', 'smart-l'] 198 | 199 | out_score = [] 200 | for one_pre, one_hypo in tqdm(zip(premise, hypo), desc="Evaluating SMART", total=len(premise)): 201 | out_score.append(self.smart_l(one_pre, one_hypo)[1] if self.smart_type == 'smart-l' else self.smart_n(one_pre, one_hypo)[1]) 202 | 203 | return None, torch.tensor(out_score), None 204 | 205 | def smart_l(self, premise, hypo): 206 | premise_sents = [each.text for each in self.spacy(premise).sents] 207 | hypo_sents = [each.text for each in self.spacy(hypo).sents] 208 | 209 | premise_sent_mat = [] 210 | hypo_sents_mat = [] 211 | for i in range(len(premise_sents)): 212 | for j in range(len(hypo_sents)): 213 | premise_sent_mat.append(premise_sents[i]) 214 | hypo_sents_mat.append(hypo_sents[j]) 215 | 216 | output_score = self.inference(premise_sent_mat, hypo_sents_mat)[2][:,0] 217 | output_score = output_score.view(len(premise_sents), len(hypo_sents)) 218 | 219 | ### smart-l 220 | lcs = [[0] * (len(hypo_sents)+1)] * (len(premise_sents)+1) 221 | for i in range(len(premise_sents)+1): 222 | for j in range(len(hypo_sents)+1): 223 | if i != 0 and j != 0: 224 | m = output_score[i-1, j-1] 225 | lcs[i][j] = max([lcs[i-1][j-1]+m, 226 | lcs[i-1][j]+m, 227 | lcs[i][j-1]]) 228 | 229 | return None, lcs[-1][-1] / len(premise_sents), None 230 | 231 | def smart_n(self, premise, hypo): 232 | ### smart-n 233 | n_gram = 1 234 | 235 | premise_sents = [each.text for each in self.spacy(premise).sents] 236 | hypo_sents = [each.text for each in self.spacy(hypo).sents] 237 | 238 | premise_sent_mat = [] 239 | hypo_sents_mat = [] 240 | for i in range(len(premise_sents)): 241 | for j in range(len(hypo_sents)): 242 | premise_sent_mat.append(premise_sents[i]) 243 | hypo_sents_mat.append(hypo_sents[j]) 244 | 245 | output_score = self.inference(premise_sent_mat, hypo_sents_mat)[2][:,0] 246 | output_score = output_score.view(len(premise_sents), len(hypo_sents)) 247 | 248 | prec = sum([max([sum([output_score[i+n, j+n]/n_gram for n in range(0, n_gram)]) for i in range(len(premise_sents)-n_gram+1)]) for j in range(len(hypo_sents)-n_gram+1)]) 249 | prec = prec / (len(hypo_sents) - n_gram + 1) if (len(hypo_sents) - n_gram + 1) > 0 else 0. 250 | 251 | 252 | premise_sents = [each.text for each in self.spacy(hypo).sents]# simple change 253 | hypo_sents = [each.text for each in self.spacy(premise).sents]# 254 | 255 | premise_sent_mat = [] 256 | hypo_sents_mat = [] 257 | for i in range(len(premise_sents)): 258 | for j in range(len(hypo_sents)): 259 | premise_sent_mat.append(premise_sents[i]) 260 | hypo_sents_mat.append(hypo_sents[j]) 261 | 262 | output_score = self.inference(premise_sent_mat, hypo_sents_mat)[2][:,0] 263 | output_score = output_score.view(len(premise_sents), len(hypo_sents)) 264 | 265 | recall = sum([max([sum([output_score[i+n, j+n]/n_gram for n in range(0, n_gram)]) for i in range(len(premise_sents)-n_gram+1)]) for j in range(len(hypo_sents)-n_gram+1)]) 266 | recall = prec / (len(hypo_sents) - n_gram + 1) if (len(hypo_sents) - n_gram + 1) > 0 else 0. 
267 | 268 | f1 = 2 * prec * recall / (prec + recall) 269 | 270 | if self.smart_n_metric == 'f1': 271 | return None, f1, None 272 | elif self.smart_n_metric == 'precision': 273 | return None, prec, None 274 | elif self.smart_n_metric == 'recall': 275 | return None, recall, None 276 | else: 277 | ValueError("SMART return type error") 278 | 279 | def chunks(self, lst, n): 280 | """Yield successive n-sized chunks from lst.""" 281 | for i in range(0, len(lst), n): 282 | yield lst[i:i + n] 283 | 284 | def nlg_eval(self, premise, hypo): 285 | assert self.nlg_eval_mode is not None, "Select NLG Eval mode!" 286 | if (self.nlg_eval_mode == 'bin') or (self.nlg_eval_mode == 'nli') or (self.nlg_eval_mode == 'reg'): 287 | return self.inference(premise, hypo) 288 | 289 | elif (self.nlg_eval_mode == 'bin_sp') or (self.nlg_eval_mode == 'nli_sp') or (self.nlg_eval_mode == 'reg_sp'): 290 | return self.inference_example_batch(premise, hypo) 291 | 292 | else: 293 | ValueError("Unrecognized NLG Eval mode!") 294 | -------------------------------------------------------------------------------- /refchecker/checker/alignscore/model.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Optional, Tuple 3 | from transformers import get_linear_schedule_with_warmup, AutoConfig 4 | from torch.optim import AdamW 5 | from transformers import BertForPreTraining, BertModel, RobertaModel, AlbertModel, AlbertForMaskedLM, RobertaForMaskedLM 6 | import torch 7 | import torch.nn as nn 8 | import pytorch_lightning as pl 9 | from sklearn.metrics import f1_score 10 | from dataclasses import dataclass 11 | 12 | 13 | 14 | class BERTAlignModel(pl.LightningModule): 15 | def __init__(self, model='bert-base-uncased', using_pretrained=True, *args, **kwargs) -> None: 16 | super().__init__() 17 | # Already defined in lightning: self.device 18 | self.save_hyperparameters() 19 | self.model = model 20 | 21 | if 'muppet' in model: 22 | assert using_pretrained == True, "Only support pretrained muppet!" 
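            # MUPPET checkpoints reuse the RoBERTa architecture: the backbone is loaded from
            # the pretrained weights, while the MLM head is freshly initialized from the config.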
23 | self.base_model = RobertaModel.from_pretrained(model) 24 | self.mlm_head = RobertaForMaskedLM(AutoConfig.from_pretrained(model)).lm_head 25 | 26 | elif 'roberta' in model: 27 | if using_pretrained: 28 | self.base_model = RobertaModel.from_pretrained(model) 29 | self.mlm_head = RobertaForMaskedLM.from_pretrained(model).lm_head 30 | else: 31 | self.base_model = RobertaModel(AutoConfig.from_pretrained(model)) 32 | self.mlm_head = RobertaForMaskedLM(AutoConfig.from_pretrained(model)).lm_head 33 | 34 | elif 'albert' in model: 35 | if using_pretrained: 36 | self.base_model = AlbertModel.from_pretrained(model) 37 | self.mlm_head = AlbertForMaskedLM.from_pretrained(model).predictions 38 | else: 39 | self.base_model = AlbertModel(AutoConfig.from_pretrained(model)) 40 | self.mlm_head = AlbertForMaskedLM(AutoConfig.from_pretrained(model)).predictions 41 | 42 | elif 'bert' in model: 43 | if using_pretrained: 44 | self.base_model = BertModel.from_pretrained(model) 45 | self.mlm_head = BertForPreTraining.from_pretrained(model).cls.predictions 46 | else: 47 | self.base_model = BertModel(AutoConfig.from_pretrained(model)) 48 | self.mlm_head = BertForPreTraining(AutoConfig.from_pretrained(model)).cls.predictions 49 | 50 | elif 'electra' in model: 51 | self.generator = BertModel(AutoConfig.from_pretrained('prajjwal1/bert-small')) 52 | self.generator_mlm = BertForPreTraining(AutoConfig.from_pretrained('prajjwal1/bert-small')).cls.predictions 53 | 54 | self.base_model = BertModel(AutoConfig.from_pretrained('bert-base-uncased')) 55 | self.discriminator_predictor = ElectraDiscriminatorPredictions(self.base_model.config) 56 | 57 | 58 | self.bin_layer = nn.Linear(self.base_model.config.hidden_size, 2) 59 | self.tri_layer = nn.Linear(self.base_model.config.hidden_size, 3) 60 | self.reg_layer = nn.Linear(self.base_model.config.hidden_size, 1) 61 | 62 | self.dropout = nn.Dropout(p=0.1) 63 | 64 | self.need_mlm = True 65 | self.is_finetune = False 66 | self.mlm_loss_factor = 0.5 67 | 68 | self.softmax = nn.Softmax(dim=-1) 69 | 70 | def forward(self, batch): 71 | if 'electra' in self.model: 72 | return self.electra_forward(batch) 73 | base_model_output = self.base_model( 74 | input_ids = batch['input_ids'], 75 | attention_mask = batch['attention_mask'], 76 | token_type_ids = batch['token_type_ids'] if 'token_type_ids' in batch.keys() else None 77 | ) 78 | 79 | prediction_scores = self.mlm_head(base_model_output.last_hidden_state) ## sequence_output for mlm 80 | seq_relationship_score = self.bin_layer(self.dropout(base_model_output.pooler_output)) ## pooled output for classification 81 | tri_label_score = self.tri_layer(self.dropout(base_model_output.pooler_output)) 82 | reg_label_score = self.reg_layer(base_model_output.pooler_output) 83 | 84 | total_loss = None 85 | if 'mlm_label' in batch.keys(): ### 'mlm_label' and 'align_label' when training 86 | ce_loss_fct = nn.CrossEntropyLoss(reduction='sum') 87 | masked_lm_loss = ce_loss_fct(prediction_scores.view(-1, self.base_model.config.vocab_size), batch['mlm_label'].view(-1)) #/ self.con vocabulary 88 | next_sentence_loss = ce_loss_fct(seq_relationship_score.view(-1, 2), batch['align_label'].view(-1)) / math.log(2) 89 | tri_label_loss = ce_loss_fct(tri_label_score.view(-1, 3), batch['tri_label'].view(-1)) / math.log(3) 90 | reg_label_loss = self.mse_loss(reg_label_score.view(-1), batch['reg_label'].view(-1), reduction='sum') 91 | 92 | masked_lm_loss_num = torch.sum(batch['mlm_label'].view(-1) != -100) 93 | next_sentence_loss_num = 
torch.sum(batch['align_label'].view(-1) != -100) 94 | tri_label_loss_num = torch.sum(batch['tri_label'].view(-1) != -100) 95 | reg_label_loss_num = torch.sum(batch['reg_label'].view(-1) != -100.0) 96 | 97 | return ModelOutput( 98 | loss=total_loss, 99 | all_loss=[masked_lm_loss, next_sentence_loss, tri_label_loss, reg_label_loss] if 'mlm_label' in batch.keys() else None, 100 | loss_nums=[masked_lm_loss_num, next_sentence_loss_num, tri_label_loss_num, reg_label_loss_num] if 'mlm_label' in batch.keys() else None, 101 | prediction_logits=prediction_scores, 102 | seq_relationship_logits=seq_relationship_score, 103 | tri_label_logits=tri_label_score, 104 | reg_label_logits=reg_label_score, 105 | hidden_states=base_model_output.hidden_states, 106 | attentions=base_model_output.attentions 107 | ) 108 | 109 | def electra_forward(self, batch): 110 | if 'mlm_label' in batch.keys(): 111 | ce_loss_fct = nn.CrossEntropyLoss() 112 | generator_output = self.generator_mlm(self.generator( 113 | input_ids = batch['input_ids'], 114 | attention_mask = batch['attention_mask'], 115 | token_type_ids = batch['token_type_ids'] if 'token_type_ids' in batch.keys() else None 116 | ).last_hidden_state) 117 | masked_lm_loss = ce_loss_fct(generator_output.view(-1, self.generator.config.vocab_size), batch['mlm_label'].view(-1)) 118 | 119 | hallucinated_tokens = batch['input_ids'].clone() 120 | 121 | hallucinated_tokens[batch['mlm_label']!=-100] = torch.argmax(generator_output, dim=-1)[batch['mlm_label']!=-100] 122 | replaced_token_label = (batch['input_ids'] == hallucinated_tokens).long()#.type(torch.LongTensor) #[batch['mlm_label'] == -100] = -100 123 | replaced_token_label[batch['mlm_label']!=-100] = (batch['mlm_label'] == hallucinated_tokens)[batch['mlm_label']!=-100].long() 124 | replaced_token_label[batch['input_ids'] == 0] = -100 ### ignore paddings 125 | 126 | base_model_output = self.base_model( 127 | input_ids = hallucinated_tokens if 'mlm_label' in batch.keys() else batch['input_ids'], 128 | attention_mask = batch['attention_mask'], 129 | token_type_ids = batch['token_type_ids'] if 'token_type_ids' in batch.keys() else None 130 | ) 131 | hallu_detect_score = self.discriminator_predictor(base_model_output.last_hidden_state) 132 | seq_relationship_score = self.bin_layer(self.dropout(base_model_output.pooler_output)) ## pooled output for classification 133 | tri_label_score = self.tri_layer(self.dropout(base_model_output.pooler_output)) 134 | reg_label_score = self.reg_layer(base_model_output.pooler_output) 135 | 136 | total_loss = None 137 | 138 | if 'mlm_label' in batch.keys(): ### 'mlm_label' and 'align_label' when training 139 | total_loss = [] 140 | ce_loss_fct = nn.CrossEntropyLoss() 141 | hallu_detect_loss = ce_loss_fct(hallu_detect_score.view(-1,2),replaced_token_label.view(-1)) 142 | next_sentence_loss = ce_loss_fct(seq_relationship_score.view(-1, 2), batch['align_label'].view(-1)) 143 | tri_label_loss = ce_loss_fct(tri_label_score.view(-1, 3), batch['tri_label'].view(-1)) 144 | reg_label_loss = self.mse_loss(reg_label_score.view(-1), batch['reg_label'].view(-1)) 145 | 146 | total_loss.append(10.0 * hallu_detect_loss if not torch.isnan(hallu_detect_loss).item() else 0.) 147 | total_loss.append(0.2 * masked_lm_loss if (not torch.isnan(masked_lm_loss).item() and self.need_mlm) else 0.) 148 | total_loss.append(next_sentence_loss if not torch.isnan(next_sentence_loss).item() else 0.) 149 | total_loss.append(tri_label_loss if not torch.isnan(tri_label_loss).item() else 0.) 
150 | total_loss.append(reg_label_loss if not torch.isnan(reg_label_loss).item() else 0.) 151 | 152 | total_loss = sum(total_loss) 153 | 154 | return ModelOutput( 155 | loss=total_loss, 156 | all_loss=[masked_lm_loss, next_sentence_loss, tri_label_loss, reg_label_loss, hallu_detect_loss] if 'mlm_label' in batch.keys() else None, 157 | prediction_logits=hallu_detect_score, 158 | seq_relationship_logits=seq_relationship_score, 159 | tri_label_logits=tri_label_score, 160 | reg_label_logits=reg_label_score, 161 | hidden_states=base_model_output.hidden_states, 162 | attentions=base_model_output.attentions 163 | ) 164 | 165 | def training_step(self, train_batch, batch_idx): 166 | output = self(train_batch) 167 | 168 | return {'losses': output.all_loss, 'loss_nums': output.loss_nums} 169 | 170 | def training_step_end(self, step_output): 171 | losses = step_output['losses'] 172 | loss_nums = step_output['loss_nums'] 173 | assert len(loss_nums) == len(losses), 'loss_num should be the same length as losses' 174 | 175 | loss_mlm_num = torch.sum(loss_nums[0]) 176 | loss_bin_num = torch.sum(loss_nums[1]) 177 | loss_tri_num = torch.sum(loss_nums[2]) 178 | loss_reg_num = torch.sum(loss_nums[3]) 179 | 180 | loss_mlm = torch.sum(losses[0]) / loss_mlm_num if loss_mlm_num > 0 else 0. 181 | loss_bin = torch.sum(losses[1]) / loss_bin_num if loss_bin_num > 0 else 0. 182 | loss_tri = torch.sum(losses[2]) / loss_tri_num if loss_tri_num > 0 else 0. 183 | loss_reg = torch.sum(losses[3]) / loss_reg_num if loss_reg_num > 0 else 0. 184 | 185 | total_loss = self.mlm_loss_factor * loss_mlm + loss_bin + loss_tri + loss_reg 186 | 187 | self.log('train_loss', total_loss)# , sync_dist=True 188 | self.log('mlm_loss', loss_mlm) 189 | self.log('bin_label_loss', loss_bin) 190 | self.log('tri_label_loss', loss_tri) 191 | self.log('reg_label_loss', loss_reg) 192 | 193 | return total_loss 194 | 195 | def validation_step(self, val_batch, batch_idx): 196 | if not self.is_finetune: 197 | with torch.no_grad(): 198 | output = self(val_batch) 199 | 200 | return {'losses': output.all_loss, 'loss_nums': output.loss_nums} 201 | 202 | with torch.no_grad(): 203 | output = self(val_batch)['seq_relationship_logits'] 204 | output = self.softmax(output)[:, 1].tolist() 205 | pred = [int(align_prob>0.5) for align_prob in output] 206 | 207 | labels = val_batch['align_label'].tolist() 208 | 209 | return {"pred": pred, 'labels': labels}#, "preds":preds, "labels":x['labels']} 210 | 211 | def validation_step_end(self, step_output): 212 | losses = step_output['losses'] 213 | loss_nums = step_output['loss_nums'] 214 | assert len(loss_nums) == len(losses), 'loss_num should be the same length as losses' 215 | 216 | loss_mlm_num = torch.sum(loss_nums[0]) 217 | loss_bin_num = torch.sum(loss_nums[1]) 218 | loss_tri_num = torch.sum(loss_nums[2]) 219 | loss_reg_num = torch.sum(loss_nums[3]) 220 | 221 | loss_mlm = torch.sum(losses[0]) / loss_mlm_num if loss_mlm_num > 0 else 0. 222 | loss_bin = torch.sum(losses[1]) / loss_bin_num if loss_bin_num > 0 else 0. 223 | loss_tri = torch.sum(losses[2]) / loss_tri_num if loss_tri_num > 0 else 0. 224 | loss_reg = torch.sum(losses[3]) / loss_reg_num if loss_reg_num > 0 else 0. 
225 | 226 | total_loss = self.mlm_loss_factor * loss_mlm + loss_bin + loss_tri + loss_reg 227 | 228 | self.log('train_loss', total_loss)# , sync_dist=True 229 | self.log('mlm_loss', loss_mlm) 230 | self.log('bin_label_loss', loss_bin) 231 | self.log('tri_label_loss', loss_tri) 232 | self.log('reg_label_loss', loss_reg) 233 | 234 | return total_loss 235 | 236 | def validation_epoch_end(self, outputs): 237 | if not self.is_finetune: 238 | total_loss = torch.stack(outputs).mean() 239 | self.log("val_loss", total_loss, prog_bar=True, sync_dist=True) 240 | 241 | else: 242 | all_predictions = [] 243 | all_labels = [] 244 | for each_output in outputs: 245 | all_predictions.extend(each_output['pred']) 246 | all_labels.extend(each_output['labels']) 247 | 248 | self.log("f1", f1_score(all_labels, all_predictions), prog_bar=True, sync_dist=True) 249 | 250 | def configure_optimizers(self): 251 | """Prepare optimizer and schedule (linear warmup and decay)""" 252 | no_decay = ["bias", "LayerNorm.weight"] 253 | optimizer_grouped_parameters = [ 254 | { 255 | "params": [p for n, p in self.named_parameters() if not any(nd in n for nd in no_decay)], 256 | "weight_decay": self.hparams.weight_decay, 257 | }, 258 | { 259 | "params": [p for n, p in self.named_parameters() if any(nd in n for nd in no_decay)], 260 | "weight_decay": 0.0, 261 | }, 262 | ] 263 | optimizer = AdamW(optimizer_grouped_parameters, lr=self.hparams.learning_rate, eps=self.hparams.adam_epsilon) 264 | 265 | scheduler = get_linear_schedule_with_warmup( 266 | optimizer, 267 | num_warmup_steps=int(self.hparams.warmup_steps_portion * self.trainer.estimated_stepping_batches), 268 | num_training_steps=self.trainer.estimated_stepping_batches, 269 | ) 270 | scheduler = {"scheduler": scheduler, "interval": "step", "frequency": 1} 271 | return [optimizer], [scheduler] 272 | 273 | def mse_loss(self, input, target, ignored_index=-100.0, reduction='mean'): 274 | mask = (target == ignored_index) 275 | out = (input[~mask]-target[~mask])**2 276 | if reduction == "mean": 277 | return out.mean() 278 | elif reduction == "sum": 279 | return out.sum() 280 | 281 | class ElectraDiscriminatorPredictions(nn.Module): 282 | """Prediction module for the discriminator, made up of two dense layers.""" 283 | 284 | def __init__(self, config): 285 | super().__init__() 286 | 287 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 288 | self.dense_prediction = nn.Linear(config.hidden_size, 2) 289 | self.config = config 290 | self.gelu = nn.GELU() 291 | 292 | def forward(self, discriminator_hidden_states): 293 | hidden_states = self.dense(discriminator_hidden_states) 294 | hidden_states = self.gelu(hidden_states) 295 | logits = self.dense_prediction(hidden_states).squeeze(-1) 296 | 297 | return logits 298 | 299 | @dataclass 300 | class ModelOutput(): 301 | loss: Optional[torch.FloatTensor] = None 302 | all_loss: Optional[list] = None 303 | loss_nums: Optional[list] = None 304 | prediction_logits: torch.FloatTensor = None 305 | seq_relationship_logits: torch.FloatTensor = None 306 | tri_label_logits: torch.FloatTensor = None 307 | reg_label_logits: torch.FloatTensor = None 308 | hidden_states: Optional[Tuple[torch.FloatTensor]] = None 309 | attentions: Optional[Tuple[torch.FloatTensor]] = None -------------------------------------------------------------------------------- /refchecker/checker/checker_base.py: -------------------------------------------------------------------------------- 1 | from typing import List, Union 2 | from itertools import groupby 3 | 4 | 
from ..utils import split_text 5 | from ..base import RCClaim 6 | 7 | 8 | def merge_ret(ret): 9 | """Merge results from multiple paragraphs""" 10 | if "Entailment" in ret: 11 | return "Entailment" 12 | if "Contradiction" in ret: 13 | return "Contradiction" 14 | return "Neutral" 15 | 16 | 17 | def merge_multi_psg_ret(ret): 18 | """Merge results from multiple passages 19 | TODO: consider possible cases where the results are inconsistent. 20 | """ 21 | if "Entailment" in ret: 22 | return "Entailment" 23 | if "Contradiction" in ret: 24 | return "Contradiction" 25 | return "Neutral" 26 | 27 | 28 | class CheckerBase: 29 | def __init__(self) -> None: 30 | """ 31 | Initializer for the CheckerBase class. 32 | 33 | Initialize labels for 'Entailment', 'Neutral', and 'Contradiction'. 34 | Also initializes a list of all labels. 35 | """ 36 | 37 | self.label_entailment = 'Entailment' 38 | self.label_neutral = 'Neutral' 39 | self.label_contradiction = 'Contradiction' 40 | self.labels = ["Entailment", "Neutral", "Contradiction"] 41 | 42 | def check( 43 | self, 44 | batch_claims: List[List[Union[str, List[str]]]], 45 | batch_references: Union[List[str], List[List[str]]], 46 | batch_questions: List[str] = None, 47 | max_reference_segment_length: int = 0, 48 | merge_psg: bool = False, 49 | is_joint: bool = True, 50 | joint_check_num: int = 5, 51 | sagemaker_client=None, 52 | sagemaker_params=None, 53 | sagemaker_get_response_func=None, 54 | custom_llm_api_func=None, 55 | **kwargs 56 | ): 57 | """ 58 | Check claims against references. 59 | 60 | Parameters 61 | ---------- 62 | batch_claims : List[List[Union[str, List[str]]]] 63 | List consists of the claims extracted from each given example. 64 | batch_references : Union[List[str], List[List[str]]] 65 | List of reference passages for each given example. 66 | batch_questions : List[str], optional 67 | List of questions for each given example, defaults to None. 68 | max_reference_segment_length : int, optional 69 | Maximum length of each reference segment, defaults to 0. 70 | merge_psg : bool, optional 71 | Whether to merge results from multiple passages, defaults to False. 72 | is_joint: bool, optional 73 | Whether perform joint checking for claims to accelerate the checking process. 74 | joint_check_num: int, optional 75 | Number of claims to check jointly in one prompt. Defaults to 5. 76 | 77 | Returns 78 | ------- 79 | results : List[List[str]] 80 | Grouped triplet checking results corresponding to each given example. 
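        Examples
        --------
        A minimal sketch (assuming checker is an instantiated subclass such as
        LLMChecker; the shape of the returned labels depends on merge_psg and on
        whether each reference is a single string or a list of passages):

            labels = checker.check(
                batch_claims=[[["Paris", "is the capital of", "France"]]],
                batch_references=["Paris is the capital city of France."],
            )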
81 | 82 | """ 83 | assert len(batch_claims) == len(batch_references) 84 | if batch_questions is None: 85 | batch_questions = [None] * len(batch_claims) 86 | 87 | # check whether the claims or references are empty 88 | valid_batch_claims = [] 89 | valid_batch_references = [] 90 | valid_batch_questions = [] 91 | 92 | empty_claim_indices = set() 93 | empty_ref_indices = set() 94 | for index, (claims, references, questions) in enumerate(zip(batch_claims, batch_references, batch_questions)): 95 | if len(claims) == 0: 96 | empty_claim_indices.add(index) 97 | if isinstance(references, list) and len(references) == 0: 98 | empty_ref_indices.add(index) 99 | 100 | if index not in empty_claim_indices and index not in empty_ref_indices: 101 | valid_batch_claims.append(claims) 102 | valid_batch_references.append(references) 103 | valid_batch_questions.append(questions) 104 | 105 | if is_joint: 106 | # joint checking is for LLM-based checkers only, and it doesn't need merge_psg 107 | labels = self._check( 108 | claims=valid_batch_claims, 109 | references=valid_batch_references, 110 | questions=valid_batch_questions, 111 | is_joint=True, 112 | joint_check_num=joint_check_num, 113 | sagemaker_client=sagemaker_client, 114 | sagemaker_params=sagemaker_params, 115 | sagemaker_get_response_func=sagemaker_get_response_func, 116 | custom_llm_api_func=custom_llm_api_func, 117 | **kwargs 118 | ) 119 | if merge_psg: 120 | labels = [ 121 | [merge_multi_psg_ret(claim_labels) for claim_labels in item_labels] 122 | for item_labels in labels 123 | ] 124 | else: 125 | input_flattened = [] 126 | input_ids = [] 127 | for idx, (claims, references, questions) in enumerate(zip(valid_batch_claims, valid_batch_references, valid_batch_questions)): 128 | if isinstance(references, str): 129 | references = [references] 130 | segments_all_psg = [] 131 | for psg in references: 132 | if max_reference_segment_length > 0: 133 | segments = split_text(psg, max_reference_segment_length) 134 | else: 135 | segments = [psg] 136 | segments_all_psg.append(segments) 137 | for c_idx, claim in enumerate(claims): 138 | for idx_psg, seg_psg in enumerate(segments_all_psg): 139 | for seg in seg_psg: 140 | input_flattened.append([claim, seg, questions]) 141 | input_ids.append([idx, c_idx, idx_psg]) 142 | ret = self._check( 143 | claims=[inp[0] for inp in input_flattened], 144 | references=[inp[1] for inp in input_flattened], 145 | questions=[inp[2] for inp in input_flattened], 146 | is_joint=False, 147 | sagemaker_client=sagemaker_client, 148 | sagemaker_params=sagemaker_params, 149 | sagemaker_get_response_func=sagemaker_get_response_func, 150 | custom_llm_api_func=custom_llm_api_func, 151 | ) 152 | 153 | ret = [[x] + y for x, y in zip(ret, input_ids)] 154 | ret_merge_seg = [[merge_ret([item[0] for item in group])] + key[:-1] for key, group in groupby(ret, key=lambda x: x[1:])] 155 | if merge_psg: 156 | ret_merge_psg = [ 157 | [merge_multi_psg_ret([item[0] for item in group])] + key[:-1] 158 | for key, group in groupby(ret_merge_seg, key=lambda x: x[1:]) 159 | ] 160 | else: 161 | ret_merge_psg = [ 162 | [[item[0] for item in group]] + key[:-1] 163 | for key, group in groupby(ret_merge_seg, key=lambda x: x[1:]) 164 | ] 165 | labels = [[item[0] for item in group] for key, group in groupby(ret_merge_psg, key=lambda x: x[1:])] 166 | 167 | # filling the results with empty claims or references 168 | final_labels = [] 169 | cur_i = 0 170 | for index, (claims, references) in enumerate(zip(batch_claims, batch_references)): 171 | if index in 
empty_claim_indices: 172 | final_labels.append([]) 173 | elif index in empty_ref_indices: 174 | final_labels.append([[]] * len(claims)) 175 | else: 176 | final_labels.append(labels[cur_i]) 177 | cur_i += 1 178 | 179 | return final_labels # [batch_size, claim_num, reference_num] 180 | 181 | 182 | def _check( 183 | self, 184 | claims: List[RCClaim], 185 | references: List[str], 186 | responses: List[str], 187 | questions: List[str], 188 | **kwargs 189 | ): 190 | """ 191 | Internal method for checking claims against references. 192 | 193 | This method should be implemented by subclasses. 194 | 195 | Parameters 196 | ---------- 197 | claims : List[RCClaim] 198 | List of claims to be checked. 199 | references : List[str] 200 | List of reference passages. 201 | responses : List[str] 202 | List of model response texts. 203 | questions : List[str] 204 | List of questions. 205 | 206 | Returns 207 | ------- 208 | List[str] 209 | List of checking results. 210 | """ 211 | 212 | raise NotImplementedError 213 | -------------------------------------------------------------------------------- /refchecker/checker/checker_prompts.py: -------------------------------------------------------------------------------- 1 | JOINT_CHECKING_PROMPT_Q = """I have a list of claims made by a language model in answer to a question, please help me check whether the claims can be entailed according to the provided reference which is related to the question. 2 | The reference is a list of passages, and each of the claims is represented as a triplet formatted with ("subject", "predicate", "object"). 3 | 4 | If the claim is supported by ANY passage in the reference, answer 'Entailment'. 5 | If NO passage in the reference entails the claim, and the claim is contradicted by some passage in the reference, answer 'Contradiction'. 6 | If NO passage entails or contradicts the claim, or the reference DOES NOT contain information to verify the claim, answer 'Neutral'. 7 | 8 | Please DO NOT use your own knowledge for the judgement, just compare the reference and the claim to get the answer. 9 | 10 | ### Question: 11 | [QUESTION] 12 | 13 | ### Reference: 14 | [REFERENCE] 15 | 16 | ### Claims: 17 | [CLAIMS] 18 | 19 | 20 | Your answer should always be only a list of labels, each of the labels is a single word in ['Entailment', 'Neutral', 'Contradiction'], for example, you should output a list like the following: 21 | 22 | Entailment 23 | Neutral 24 | Contradiction 25 | Neutral 26 | 27 | 28 | DO NOT add explanations or your own reasoning to the output, only output the label list. 29 | """ 30 | 31 | 32 | LLM_CHECKING_PROMPT_Q = \ 33 | """I have a claim made by a language model in answer to a question, please help me check whether the claim can be entailed according to the provided reference which is related to the question. 34 | The reference is a list of passages, and the claim is represented as a triplet formatted with ("subject", "predicate", "object"). 35 | 36 | If the claim is supported by ANY passage in the reference, answer 'Entailment'. 37 | If NO passage in the reference entails the claim, and the claim is contradicted by some passage in the reference, answer 'Contradiction'. 38 | If NO passage entails or contradicts the claim, or the reference DOES NOT contain information to verify the claim, answer 'Neutral'. 39 | 40 | Please DO NOT use your own knowledge for the judgement, just compare the reference and the claim to get the answer. 
41 | 42 | ### Question: 43 | {question} 44 | 45 | ### Reference: 46 | {reference} 47 | 48 | ### Claim: 49 | {claim} 50 | 51 | Your answer should always be only a single word in ['Entailment', 'Neutral', 'Contradiction']. DO NOT add explanations or your own reasoning to the output. 52 | """ 53 | 54 | LLM_CHECKING_PROMPT = \ 55 | """I have a claim made by a language model, please help me check whether the claim can be entailed according to the provided reference. 56 | The reference is a list of passages, and the claim is represented as a triplet formatted with ("subject", "predicate", "object"). 57 | 58 | If the claim is supported by ANY passage in the reference, answer 'Entailment'. 59 | If NO passage in the reference entails the claim, and the claim is contradicted by some passage in the reference, answer 'Contradiction'. 60 | If NO passage entails or contradicts the claim, or the reference DOES NOT contain information to verify the claim, answer 'Neutral'. 61 | 62 | Please DO NOT use your own knowledge for the judgement, just compare the reference and the claim to get the answer. 63 | 64 | ### Reference: 65 | {reference} 66 | 67 | ### Claim: 68 | {claim} 69 | 70 | Your answer should always be only a single word in ['Entailment', 'Neutral', 'Contradiction']. DO NOT add explanations or your own reasoning to the output. 71 | """ 72 | 73 | 74 | SUBSENTENCE_CLAIM_CHECKING_PROMPT = \ 75 | """I have a claim made by a language model, please help me check whether the claim can be entailed according to the provided reference. 76 | The reference is a list of passages, and the claim is a sentence. 77 | 78 | If the claim is supported by ANY passage in the reference, answer 'Entailment'. 79 | If NO passage in the reference entails the claim, and the claim is contradicted by some passage in the reference, answer 'Contradiction'. 80 | If NO passage entails or contradicts the claim, or the reference DOES NOT contain information to verify the claim, answer 'Neutral'. 81 | 82 | Please DO NOT use your own knowledge for the judgement, just compare the reference and the claim to get the answer. 83 | 84 | ### Reference: 85 | {reference} 86 | 87 | ### Claim: 88 | {claim} 89 | 90 | Your answer should always be only a single word in ['Entailment', 'Neutral', 'Contradiction']. DO NOT add explanations or your own reasoning to the output. 91 | """ -------------------------------------------------------------------------------- /refchecker/checker/llm_checker.py: -------------------------------------------------------------------------------- 1 | import re 2 | from typing import List, Union 3 | from tqdm import tqdm 4 | import numpy as np 5 | 6 | from .checker_base import CheckerBase 7 | from ..utils import get_model_batch_response 8 | from .checker_prompts import * 9 | 10 | 11 | class LLMChecker(CheckerBase): 12 | def __init__( 13 | self, 14 | model: str = 'bedrock/anthropic.claude-3-sonnet-20240229-v1:0', 15 | batch_size: int = 16, 16 | api_base: str = None 17 | ) -> None: 18 | """ 19 | Initializer for the LLMChecker class. 20 | 21 | Initializes LLMChecker with the provided model and batch size. 22 | 23 | Parameters: 24 | ----------- 25 | model : str 26 | The name or identifier of the language model to use. 27 | batch_size : int, optional 28 | Batch size for checking, defaults to 16. 
29 | """ 30 | 31 | super().__init__() 32 | self.prompt_temp = LLM_CHECKING_PROMPT 33 | self.prompt_temp_wq = LLM_CHECKING_PROMPT_Q 34 | self.prompt_temp_subsent = SUBSENTENCE_CLAIM_CHECKING_PROMPT 35 | 36 | self.batch_size = batch_size 37 | 38 | self.model = model 39 | self.api_base = api_base 40 | 41 | def _check( 42 | self, 43 | claims: List[Union[str, List[str], List[List[str]]]], 44 | references: List[Union[str, List[str]]], 45 | questions: List[str] = None, 46 | is_joint: bool = False, 47 | joint_check_num: int = 5, 48 | sagemaker_client=None, 49 | sagemaker_params=None, 50 | sagemaker_get_response_func=None, 51 | custom_llm_api_func=None, 52 | **kwargs 53 | ): 54 | """ 55 | Batch checking claims against references. 56 | 57 | Parameters 58 | ---------- 59 | claims : List[Union[str, List[str]]] 60 | List of claims. 61 | references : List[str] 62 | List of reference passages (split according to 'max_reference_segment_length'). 63 | responses : List[str] 64 | List of model response texts. 65 | questions : List[str] 66 | List of questions corresponding to each triplet. 67 | is_joint: bool, optional 68 | Whether perform joint checking for claims to accelerate the checking process. 69 | joint_check_num: int, optional 70 | Number of claims to check jointly in one prompt. Defaults to 5. 71 | Returns 72 | ------- 73 | ret : List[str] 74 | List of labels for the checking results. 75 | 76 | """ 77 | if is_joint: 78 | batch_claim_nums = [len(claims_per_batch) for claims_per_batch in claims] 79 | batch_ref_nums = [] 80 | for ref_per_batch in references: 81 | if isinstance(ref_per_batch, str): 82 | batch_ref_nums.append(1) 83 | else: 84 | assert isinstance(ref_per_batch, list) 85 | batch_ref_nums.append(len(ref_per_batch)) 86 | 87 | prompt_template = JOINT_CHECKING_PROMPT_Q 88 | 89 | prompt_list = [] 90 | prompt_ids = [] # for setting the limit of max num of claims 91 | claim_nums = [] 92 | p_id = 0 93 | for claims_per_batch, references_per_batch, question_per_batch in zip(claims, references, questions): 94 | if len(claims_per_batch) == 0: 95 | continue 96 | 97 | if isinstance(references_per_batch, str): 98 | references_per_batch = [references_per_batch] 99 | 100 | for ref in references_per_batch: 101 | _claim_cnt = 0 102 | claims_text = '' 103 | 104 | for _ci, c in enumerate(claims_per_batch): 105 | claims_text += f'("{c[0]}", "{c[1]}", "{c[2]}")\n' 106 | _claim_cnt += 1 107 | if _claim_cnt >= joint_check_num or _ci == len(claims_per_batch) - 1: 108 | prompt = prompt_template.replace('[QUESTION]', question_per_batch) 109 | prompt = prompt.replace('[REFERENCE]', ref) 110 | prompt = prompt.replace('[CLAIMS]', claims_text.strip()) 111 | prompt_list.append(prompt) 112 | 113 | prompt_ids.append(p_id) 114 | claim_nums.append(_claim_cnt) 115 | _claim_cnt = 0 116 | claims_text = '' 117 | 118 | p_id += 1 119 | 120 | labels_list = [] 121 | for i in tqdm(range(0, len(prompt_list), self.batch_size)): 122 | batch_prompts = prompt_list[i:i + self.batch_size] 123 | 124 | llm_responses = get_model_batch_response( 125 | prompts=batch_prompts, 126 | temperature=0, 127 | model=self.model, 128 | max_new_tokens=joint_check_num * 10 + 100, 129 | api_base=self.api_base, 130 | sagemaker_client=sagemaker_client, 131 | sagemaker_params=sagemaker_params, 132 | sagemaker_get_response_func=sagemaker_get_response_func, 133 | custom_llm_api_func=custom_llm_api_func, 134 | **kwargs 135 | ) 136 | 137 | for llm_response in llm_responses: 138 | if llm_response is not None: 139 | labels = 
self._parse_joint_checking_labels(llm_response) 140 | labels_list.append(labels) 141 | else: 142 | raise RuntimeError('API returned None or an empty string') 143 | 144 | # pad labels with Neutral 145 | assert len(claim_nums) == len(labels_list) 146 | for _i, claim_n in enumerate(claim_nums): 147 | if len(labels_list[_i]) < claim_n: 148 | labels_list[_i] = labels_list[_i] + ['Neutral'] * (claim_n - len(labels_list[_i])) 149 | elif len(labels_list[_i]) > claim_n: 150 | labels_list[_i] = labels_list[_i][:claim_n] 151 | # merge labels 152 | merged_label_list = [] 153 | for _i, _pid in enumerate(prompt_ids): 154 | if _i > 0 and _pid == prompt_ids[_i - 1]: 155 | merged_label_list[-1] += labels_list[_i] 156 | else: 157 | merged_label_list.append(labels_list[_i]) 158 | 159 | ret_labels = [] 160 | _index = 0 161 | for _i, claim_num in enumerate(batch_claim_nums): 162 | if claim_num > 0: 163 | one_batch_labels = merged_label_list[_index: _index + batch_ref_nums[_i]] # [ref_num, claim_num] 164 | 165 | _index += batch_ref_nums[_i] 166 | 167 | one_batch_labels = np.array(one_batch_labels).transpose(1, 0) 168 | # if batch_ref_nums[_i] == 1: 169 | # one_batch_labels = one_batch_labels.squeeze(-1) 170 | ret_labels.append(one_batch_labels.tolist()) 171 | else: 172 | ret_labels.append([]) 173 | return ret_labels 174 | else: 175 | ret_labels = [] 176 | prompt_list = [] 177 | for claim, reference, question in zip(claims, references, questions): 178 | claim_text = str(claim) 179 | 180 | if isinstance(claim, list) and len(claim) == 3: 181 | if question is None: 182 | prompt = self.prompt_temp.format( 183 | reference=reference, 184 | claim=claim_text 185 | ) 186 | else: 187 | prompt = self.prompt_temp_wq.format( 188 | question=question, 189 | reference=reference, 190 | claim=claim_text 191 | ) 192 | elif isinstance(claim, str): 193 | if question and len(question): 194 | reference = question + ' ' + reference 195 | prompt = self.prompt_temp_subsent.format( 196 | reference=reference, 197 | claim=claim_text 198 | ) 199 | else: 200 | raise ValueError(f'Unknown claim format: {type(claim)}') 201 | prompt_list.append(prompt) 202 | 203 | for i in tqdm(range(0, len(prompt_list), self.batch_size)): 204 | batch_prompts = prompt_list[i:i + self.batch_size] 205 | 206 | llm_responses = get_model_batch_response( 207 | prompts=batch_prompts, 208 | temperature=0, 209 | model=self.model, 210 | max_new_tokens=10, 211 | api_base=self.api_base, 212 | sagemaker_client=sagemaker_client, 213 | sagemaker_params=sagemaker_params, 214 | sagemaker_get_response_func=sagemaker_get_response_func, 215 | custom_llm_api_func=custom_llm_api_func, 216 | **kwargs 217 | ) 218 | 219 | for llm_response in llm_responses: 220 | if llm_response and len(llm_response): 221 | label = None 222 | if self.label_contradiction.lower() in llm_response.lower(): 223 | label = self.label_contradiction 224 | elif self.label_entailment.lower() in llm_response.lower(): 225 | label = self.label_entailment 226 | else: 227 | label = self.label_neutral 228 | ret_labels.append(label) 229 | else: 230 | raise RuntimeError('API returned None or an empty string') 231 | return ret_labels 232 | 233 | def _parse_joint_checking_labels(self, text): 234 | pattern = r'\b(Entailment|Neutral|Contradiction)\b' 235 | matches = re.findall(pattern, text, re.IGNORECASE) 236 | parsed_labels = [label.title() for label in matches] 237 | return parsed_labels -------------------------------------------------------------------------------- /refchecker/checker/nli_checker.py: 
-------------------------------------------------------------------------------- 1 | from typing import Any, List, Union 2 | from tqdm import tqdm 3 | 4 | import torch 5 | from transformers import ( 6 | AutoTokenizer, AutoModelForSequenceClassification 7 | ) 8 | 9 | from .checker_base import CheckerBase 10 | from ..base import RCClaim 11 | 12 | 13 | LABELS = ["Entailment", "Neutral", "Contradiction"] 14 | 15 | 16 | class NLIChecker(CheckerBase): 17 | def __init__( 18 | self, 19 | model='ynie/roberta-large-snli_mnli_fever_anli_R1_R2_R3-nli', 20 | device=0, # Todo: support distributed inference 21 | batch_size=16 22 | ): 23 | """ 24 | Initializes the NLIChecker with the specified model, device, and batch size. 25 | 26 | Parameters 27 | ---------- 28 | model : str, optional 29 | The name or identifier of the model to use, defaults to 'ynie/roberta-large-snli_mnli_fever_anli_R1_R2_R3-nli'. 30 | device : int, optional 31 | The device to run inference on, defaults to 0. 32 | batch_size : int, optional 33 | The batch size for inference, defaults to 16. 34 | """ 35 | 36 | super().__init__() 37 | self.model = AutoModelForSequenceClassification.from_pretrained(model).to(device) 38 | self.model.eval() 39 | self.tokenizer = AutoTokenizer.from_pretrained(model) 40 | self.device = device 41 | self.batch_size = batch_size 42 | 43 | @torch.no_grad() 44 | def _check( 45 | self, 46 | claims: List[Union[str, List[str], List[List[str]]]], 47 | references: List[Union[str, List[str]]], 48 | **kwargs 49 | ): 50 | """ 51 | Batch checking claims against references. 52 | 53 | Parameters 54 | ---------- 55 | claims : List[RCClaim] 56 | List of claims. 57 | references : List[str] 58 | List of reference passages (split according to 'max_reference_segment_length'). 59 | responses : List[str] 60 | List of model response texts. 61 | questions : List[str] 62 | List of questions corresponding to each triplet. 63 | 64 | Returns 65 | ------- 66 | ret : List[str] 67 | List of labels for the checking results. 68 | 69 | """ 70 | 71 | N1, N2 = len(references), len(claims) 72 | assert N1 == N2, f"Batches must be of the same length. 
{N1} != {N2}" 73 | # claims = [c.get_content() for c in claims] 74 | batch_preds = [] 75 | for i in tqdm(range(0, len(claims), self.batch_size)): 76 | batch_claims = claims[i:i + self.batch_size] 77 | batch_references = references[i:i + self.batch_size] 78 | 79 | inputs = self.tokenizer( 80 | batch_references, batch_claims, max_length=512, truncation=True, 81 | return_tensors="pt", padding=True, return_token_type_ids=True 82 | ) 83 | inputs = {k: v.to(self.device) for k, v in inputs.items()} 84 | output = self.model(**inputs).logits.softmax(dim=-1).cpu() # [batch_size, 3] 85 | preds = output.argmax(dim=-1) 86 | batch_preds.extend(preds) 87 | ret = [LABELS[p] for p in batch_preds] 88 | 89 | return ret 90 | -------------------------------------------------------------------------------- /refchecker/checker/repc/general.py: -------------------------------------------------------------------------------- 1 | prompt_template_dict = { 2 | "teknium/OpenHermes-2.5-Mistral-7B": 3 | { 4 | "system_begin": "<|im_start|>system\n", 5 | "system_end": "<|im_end|>\n", 6 | "user_begin": "<|im_start|>user\n", 7 | "user_end": "<|im_end|>\n", 8 | "assistant_begin": "<|im_start|>assistant\n", 9 | "assistant_end": "<|im_end|>\n" 10 | }, 11 | "42MARU/GenAI-llama-2-13b": 12 | { 13 | "system_begin": "### System:\n", 14 | "system_end": "\n\n", 15 | "user_begin": "### User:\n", 16 | "user_end": "\n\n", 17 | "assistant_begin": "### Assistant:\n" 18 | }, 19 | "ehartford/dolphin-2.1-mistral-7b": 20 | { 21 | "system_begin": "<|im_start|>system\n", 22 | "system_end": "<|im_end|>\n", 23 | "user_begin": "<|im_start|>user\n", 24 | "user_end": "<|im_end|>\n", 25 | "assistant_begin": "<|im_start|>assistant\n", 26 | "assistant_end": "<|im_end|>\n" 27 | }, 28 | "Qwen/Qwen-7B-Chat": 29 | { 30 | "system_begin": "<|im_start|>system\n", 31 | "system_end": "<|im_end|>\n", 32 | "user_begin": "<|im_start|>user\n", 33 | "user_end": "<|im_end|>\n", 34 | "assistant_begin": "<|im_start|>assistant\n", 35 | "assistant_end": "<|im_end|>\n" 36 | }, 37 | "daryl149/llama-2-7b-chat-hf": 38 | { 39 | "system_begin": "[INST] <>\n", 40 | "system_end": "\n<>\n\n", 41 | "user_begin": "", 42 | "user_end": "", 43 | "assistant_begin": " [/INST] ", 44 | "assistant_end": "[INST] " 45 | }, 46 | "AIDC-ai-business/Marcoroni-7B-v3": 47 | { 48 | "system_begin": "### Instruction:\n", 49 | "system_end": "\n\n", 50 | "user_begin": "### Input:\n", 51 | "user_end": "\n\n", 52 | "assistant_begin": "### Response:\n", 53 | "assistant_end": "\n\n" 54 | }, 55 | "/home/daven/research/GeoLLaMA-PT/outputs/llama-2023-05-27-23-10/checkpoint-30140": 56 | { 57 | "system_begin": "### Instruction:\n", 58 | "system_end": "\n\n", 59 | "user_begin": "### Input:\n", 60 | "user_end": "\n\n", 61 | "assistant_begin": "### Response:\n" 62 | }, 63 | "/home/daven/geo30b_ckpts/step900_sft": 64 | { 65 | "system_begin": "### Instruction:\n", 66 | "system_end": "\n\n", 67 | "user_begin": "### Input:\n", 68 | "user_end": "\n\n", 69 | "assistant_begin": "### Response:\n" 70 | }, 71 | } -------------------------------------------------------------------------------- /refchecker/checker/repc/repc_checker.py: -------------------------------------------------------------------------------- 1 | import os 2 | from tqdm import tqdm 3 | import tarfile 4 | from huggingface_hub import hf_hub_download 5 | from typing import Any, List, Union 6 | 7 | from transformers import ( 8 | AutoTokenizer, AutoModelForCausalLM 9 | ) 10 | 11 | from ..checker_base import CheckerBase 12 | from .ml_models import * 13 | from 
...base import RCClaim 14 | 15 | LABELS = ["Entailment", "Neutral", "Contradiction"] 16 | 17 | prompt_template_dict = { 18 | "chatml": 19 | { 20 | "system_begin": "<|im_start|>system\n", 21 | "system_end": "<|im_end|>\n", 22 | "user_begin": "<|im_start|>user\n", 23 | "user_end": "<|im_end|>\n", 24 | "assistant_begin": "<|im_start|>assistant\n", 25 | "assistant_end": "<|im_end|>\n" 26 | } 27 | } 28 | class RepCChecker(CheckerBase): 29 | def __init__( 30 | self, 31 | model='teknium/OpenHermes-2.5-Mistral-7B', 32 | classifier='nn_ensemble', 33 | classifier_dir='saved_models/repc', 34 | prompt_style='chatml', 35 | selected_token=-1, 36 | device=0, 37 | batch_size=16 38 | ): 39 | """ 40 | Initializes the RepCChecker with the specified parameters. 41 | 42 | Parameters 43 | ---------- 44 | model : str, optional 45 | The name or identifier of the RepC backbone to use, defaults to 'teknium/OpenHermes-2.5-Mistral-7B'. 46 | classifier : str, optional 47 | The type of classifier to use, must be one of ['svm', 'nn', 'svm_ensemble', 'nn_ensemble'], defaults to 'nn_ensemble'. 48 | classifier_dir : str, optional 49 | The directory to save/load the classifier model, defaults to 'saved_models/repc'. 50 | prompt_style : str, optional 51 | The style of the prompt to use, defaults to 'chatml'. 52 | selected_token : int, optional 53 | The selected token index to obtain the embedding used for classification, defaults to -1 (the last token). 54 | device : int, optional 55 | The device to run classifier on, defaults to 0. 56 | batch_size : int, optional 57 | The batch size for the backbone model, defaults to 16. 58 | """ 59 | 60 | super().__init__() 61 | self.model = AutoModelForCausalLM.from_pretrained( 62 | model, 63 | device_map="cuda:1", 64 | torch_dtype=torch.float16, 65 | trust_remote_code=True, 66 | ) 67 | self.model.eval() 68 | self.tokenizer = AutoTokenizer.from_pretrained(model, padding_side="left") 69 | self.tokenizer.pad_token = self.tokenizer.eos_token 70 | self.prompt_style = prompt_style 71 | self.selected_token = selected_token 72 | self.device = device 73 | self.classifier_str = classifier 74 | self.classifier_dir = classifier_dir 75 | self.batch_size = batch_size 76 | if classifier == "nn_ensemble": 77 | self.n_train = 2000 78 | expert_paths = [f"{self.classifier_dir}/nn/upload/nn_anli_n{self.n_train}_l{i}" for i in range(self.model.config.num_hidden_layers)] 79 | if not os.path.exists(f"{self.classifier_dir}/nn/upload/nn_anli_n{self.n_train}_l31"): 80 | hf_hub_download(repo_id="zthang/repe", filename="nn.tar.gz", local_dir=self.classifier_dir) 81 | tar = tarfile.open(os.path.join(self.classifier_dir, "nn.tar.gz"), "r:gz") 82 | tar.extractall(path=os.path.join(self.classifier_dir, "nn")) 83 | tar.close() 84 | self.classifier = EnsembleClassifier(input_size=(self.model.config.num_hidden_layers) * 3, 85 | output_size=3, 86 | num_experts=self.model.config.num_hidden_layers, 87 | expert_paths=expert_paths, 88 | expert_type="nn", 89 | classifier_type="mlp") 90 | self.classifier_path = os.path.join(self.classifier_dir, "ensemble_mlp_nn_2000_anli_n2000_l0") 91 | if not os.path.exists(self.classifier_path): 92 | hf_hub_download(repo_id="zthang/repe", filename="ensemble_mlp_nn_2000_anli_n2000_l0", local_dir=self.classifier_dir) 93 | elif classifier == "nn": 94 | self.selected_layer = 17 95 | self.n_train = 2000 96 | self.input_size = 4096 97 | self.hidden_size = 4096 // 4 98 | self.classifier = PyTorchClassifier(input_size=self.input_size, hidden_size=self.hidden_size) 99 | self.classifier_path = 
f"{self.classifier_dir}/nn/nn_anli_n{self.n_train}_l{self.selected_layer}" 100 | if not os.path.exists(self.classifier_path): 101 | hf_hub_download(repo_id="zthang/repe", filename=f"nn/nn_anli_n{self.n_train}_l{self.selected_layer}", local_dir=self.classifier_dir) 102 | elif classifier == "svm_ensemble": 103 | self.n_train = 1000 104 | expert_paths = [f"{self.classifier_dir}/svm/upload/svm_anli_n{self.n_train}_l{i}" for i in range(self.model.config.num_hidden_layers)] 105 | if not os.path.exists(f"{self.classifier_dir}/svm/upload/svm_anli_n{self.n_train}_l31"): 106 | hf_hub_download(repo_id="zthang/repe", filename="svm.tar.gz", local_dir=self.classifier_dir) 107 | tar = tarfile.open(os.path.join(self.classifier_dir, "svm.tar.gz"), "r:gz") 108 | tar.extractall(path=os.path.join(self.classifier_dir, "svm")) 109 | tar.close() 110 | self.classifier = EnsembleClassifier(input_size=(self.model.config.num_hidden_layers) * 3, 111 | output_size=3, 112 | num_experts=self.model.config.num_hidden_layers, 113 | expert_paths=expert_paths, 114 | expert_type="svm", 115 | classifier_type="mlp") 116 | self.classifier_path = os.path.join(self.classifier_dir, "ensemble_mlp_svm_1000_anli_n1000_l0") 117 | if not os.path.exists(self.classifier_path): 118 | hf_hub_download(repo_id="zthang/repe", filename="ensemble_mlp_svm_1000_anli_n1000_l0", local_dir=self.classifier_dir) 119 | elif classifier == "svm": 120 | self.selected_layer = 15 121 | self.n_train = 1000 122 | self.classifier = SVM(kernel="rbf") 123 | self.classifier_path = f"{self.classifier_dir}/svm/svm_anli_n{self.n_train}_l{self.selected_layer}" 124 | if not os.path.exists(self.classifier_path): 125 | hf_hub_download(repo_id="zthang/repe", filename=f"svm/svm_anli_n{self.n_train}_l{self.selected_layer}", local_dir=self.classifier_dir) 126 | else: 127 | raise ValueError("classifier must in [svm, nn, svm_ensemble, nn_ensemble.") 128 | self.classifier.load(self.classifier_path) 129 | 130 | def get_prompt(self, prompt_style, question, premise, hypothesis): 131 | return f"{prompt_template_dict[prompt_style]['system_begin']}Consider the NLI label between the user given premise and hypothesis.{prompt_template_dict[prompt_style]['system_end']}" \ 132 | f"{prompt_template_dict[prompt_style]['user_begin']}Premise: {question}\n{premise}\nHypothesis: {hypothesis}{prompt_template_dict[prompt_style]['user_end']}" \ 133 | f"{prompt_template_dict[prompt_style]['assistant_begin']}The NLI label (Entailment, Neutral, Contradiction) is" 134 | 135 | @torch.no_grad() 136 | def _check( 137 | self, 138 | claims: List[RCClaim], 139 | references: List[str], 140 | responses: List[str], 141 | questions: List[str], 142 | ): 143 | """ 144 | Batch checking claims against references. 145 | 146 | Parameters 147 | ---------- 148 | claims : List[RCClaim] 149 | List of claims. 150 | references : List[str] 151 | List of reference passages (split according to 'max_reference_segment_length'). 152 | responses : List[str] 153 | List of model response texts. 154 | questions : List[str] 155 | List of questions corresponding to each triplet. 156 | 157 | Returns 158 | ------- 159 | ret : List[str] 160 | List of labels for the checking results. 161 | 162 | """ 163 | 164 | N1, N2 = len(references), len(claims) 165 | assert N1 == N2, f"Batches must be of the same length. 
{N1} != {N2}" 166 | 167 | claims = [c.get_content() for c in claims] 168 | 169 | batch_preds = [] 170 | prompt_list = [self.get_prompt(prompt_style=self.prompt_style, question=questions[i], premise=references[i], hypothesis=claims[i]) for i in range(N1)] 171 | for i in tqdm(range(0, len(prompt_list), self.batch_size)): 172 | batch_prompts = prompt_list[i:i + self.batch_size] 173 | inputs = self.tokenizer(batch_prompts, return_tensors="pt", padding=True) 174 | inputs = {k: v.to(self.device) for k, v in inputs.items()} 175 | res = self.model(**inputs, output_hidden_states=True, use_cache=False) 176 | if self.classifier_str in ["svm", "nn"]: 177 | hidden_states = res["hidden_states"][1:][self.selected_layer].cpu().numpy() 178 | hidden_states = hidden_states[:, self.selected_token, :] 179 | else: 180 | hidden_states = torch.stack(res["hidden_states"][1:]).transpose(0, 1)[:, :, -1, :].cpu().numpy() 181 | preds = self.classifier.predict(hidden_states) 182 | batch_preds.extend(preds) 183 | ret = [LABELS[p] for p in batch_preds] 184 | return ret 185 | 186 | if __name__ == "__main__": 187 | claims = ["H&R Block Online time to process tax return 1-2 days", "H&R Block Online time to process tax return 1-2 days"] 188 | references = ["I can’t imagine how it would take 2 hours. What record keeping does someone with a 1040ez generally need to do? I used to do all my taxes by hand. I figured about 8 hours total for federal and state and my tax situation was not simple. ", 189 | "I did a full 1040 with stock sales, itemized deductions and rental properties. They list “1 hour for form submission”. How does that take 1 hour? Most people hit ‘submit’ on their tax software and other people shove them in an envelope and put them in the mail. Where do they get 1 hour from?"] 190 | question = "how long does it usually take to get taxes back from h & r block online?" 191 | checker = RepCChecker() 192 | ret = checker._check(claims=claims, references=references, question=question, response="") 193 | print(ret) 194 | -------------------------------------------------------------------------------- /refchecker/cli.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | from argparse import ArgumentParser, RawTextHelpFormatter 4 | from tqdm import tqdm 5 | 6 | from .extractor import LLMExtractor 7 | from .checker import LLMChecker 8 | 9 | from .retriever import GoogleRetriever 10 | from .aggregator import strict_agg, soft_agg, major_agg 11 | from .base import RCClaim 12 | 13 | 14 | def get_args(): 15 | parser = ArgumentParser(formatter_class=RawTextHelpFormatter) 16 | parser.add_argument( 17 | "mode", nargs="?", choices=["extract", "check", "extract-check"], 18 | help="extract: Extract claims from provided responses.\n" 19 | "check: Check whether the provided claims are factual.\n" 20 | "extract-check: Extract claims and check whether they are factual." 21 | ) 22 | parser.add_argument( 23 | "--input_path", type=str, required=True, 24 | help="Input path to the json file." 25 | ) 26 | parser.add_argument( 27 | "--output_path", type=str, required=True, 28 | help="Output path to the result json file." 29 | ) 30 | parser.add_argument( 31 | "--cache_dir", type=str, default="./.cache", 32 | help="Path to the cache directory. Default: ./.cache" 33 | ) 34 | parser.add_argument( 35 | '--extractor_name', type=str, default="bedrock/anthropic.claude-3-sonnet-20240229-v1:0", 36 | help="Model used for extracting triplets. 
Default: bedrock/anthropic.claude-3-sonnet-20240229-v1:0" 37 | ) 38 | parser.add_argument( 39 | '--extractor_max_new_tokens', type=int, default=500, 40 | help="Max generated tokens of the extractor, set a larger value for longer documents. Default: 500" 41 | ) 42 | parser.add_argument( 43 | '--claim_format', type=str, default='triplet', 44 | choices=['triplet', 'subsentence'], 45 | help='The format of the extracted claims. Default: triplet' 46 | ) 47 | parser.add_argument( 48 | "--checker_name", type=str, default="bedrock/anthropic.claude-3-sonnet-20240229-v1:0", 49 | help="Model used for checking whether the triplets are factual. " 50 | "Default: Claude 3 Sonnet" 51 | ) 52 | parser.add_argument( 53 | "--extractor_api_base", type=str, 54 | help="API base URL if using vllm for deploying the extractor" 55 | ) 56 | parser.add_argument( 57 | "--checker_api_base", type=str, 58 | help="API base URL if using vllm for deploying the checker" 59 | ) 60 | parser.add_argument( 61 | "--repc_classifier_name", type=str, default="nn_ensemble", 62 | choices=["svm", "svm_ensemble", "nn", "nn_ensemble"], 63 | help="Classifier Model used for RepC checker, only valid when RepC checker is used. " 64 | "Default: nn_ensemble, neural network classifier with layer ensemble." 65 | ) 66 | parser.add_argument( 67 | "--retriever_name", type=str, default="google", choices=["google"], 68 | help="Model used for retrieving reference (currently only google is" 69 | " supported). Default: google." 70 | ) 71 | parser.add_argument( 72 | "--aggregator_name", type=str, default="soft", 73 | choices=["strict", "soft", "major"], 74 | help="Aggregator used for aggregating the results from multiple " 75 | "triplets. Default: soft.\n" 76 | "* strict: If any of the triplets is Contradiction, the response" 77 | " is Contradiction.\nIf all of the triplets are Entailment, the " 78 | "response is Entailment. Otherwise, the\nresponse is Neutral.\n" 79 | "* soft: The ratio of each category is calculated.\n" 80 | "* major: The category with the most votes is selected." 81 | ) 82 | parser.add_argument( 83 | "--use_retrieval", action="store_true", 84 | help="Whether to use retrieval to find the reference for checking. " 85 | "Required if the reference\nfield in input data is not provided." 86 | ) 87 | parser.add_argument( 88 | "--batch_size_extractor", type=int, default=16, 89 | help="Batch size for extractor." 90 | ) 91 | parser.add_argument( 92 | "--batch_size_checker", type=int, default=16, 93 | help="Batch size for checker. 
94 | ) 95 | 96 | return parser.parse_args() 97 | 98 | 99 | def main(): 100 | args = get_args() 101 | # set environment variables 102 | # if args.openai_key: 103 | # with open(args.openai_key, "r") as fp: 104 | # os.environ["OPENAI_API_KEY"] = fp.read().strip() 105 | # if args.anthropic_key: 106 | # with open(args.anthropic_key, "r") as fp: 107 | # os.environ["ANTHROPIC_API_KEY"] = fp.read().strip() 108 | # if args.aws_bedrock_region: 109 | # os.environ["AWS_REGION_NAME"] = args.aws_bedrock_region 110 | # if args.serper_api_key: 111 | # os.environ["SERPER_API_KEY"] = args.serper_api_key 112 | 113 | if args.mode == "extract": 114 | extract(args) 115 | elif args.mode == "check": 116 | check(args) 117 | elif args.mode == "extract-check": 118 | output_path = args.output_path 119 | args.output_path = output_path + ".temp" 120 | extract(args) 121 | args.input_path = args.output_path 122 | args.output_path = output_path 123 | check(args) 124 | else: 125 | raise NotImplementedError 126 | 127 | 128 | def extract(args): 129 | # initialize models 130 | extractor = LLMExtractor( 131 | claim_format=args.claim_format, 132 | model=args.extractor_name, 133 | api_base=args.extractor_api_base 134 | ) 135 | 136 | # load data 137 | with open(args.input_path, "r") as fp: 138 | input_data = json.load(fp) 139 | 140 | # extract triplets 141 | print('Extracting') 142 | question_list = [d.get('question', None) for d in input_data] 143 | response_list = [d['response'] for d in input_data] 144 | 145 | extraction_results = extractor.extract( 146 | batch_responses=response_list, 147 | batch_questions=question_list, 148 | max_new_tokens=args.extractor_max_new_tokens 149 | ) 150 | for res, d in zip(extraction_results, input_data): 151 | d['claims'] = [c.content for c in res.claims] 152 | 153 | with open(args.output_path, "w") as fp: 154 | json.dump(input_data, fp, indent=2) 155 | 156 | 157 | def check(args): 158 | # initialize models 159 | if args.checker_name == "nli": 160 | from .checker.nli_checker import NLIChecker 161 | checker = NLIChecker(batch_size=args.batch_size_checker) 162 | elif args.checker_name == "alignscore": 163 | from .checker.alignscore.alignscore_checker import AlignScoreChecker 164 | checker = AlignScoreChecker(batch_size=args.batch_size_checker) 165 | elif args.checker_name == "repc": 166 | from .checker.repc.repc_checker import RepCChecker 167 | checker = RepCChecker(classifier=args.repc_classifier_name, batch_size=args.batch_size_checker) 168 | else: 169 | checker = LLMChecker( 170 | model=args.checker_name, 171 | batch_size=args.batch_size_checker, 172 | api_base=args.checker_api_base 173 | ) 174 | 175 | retriever = None 176 | if args.use_retrieval: 177 | if args.retriever_name == "google": 178 | retriever = GoogleRetriever(args.cache_dir) 179 | else: 180 | raise NotImplementedError 181 | 182 | if args.aggregator_name == "strict": 183 | agg_fn = strict_agg 184 | elif args.aggregator_name == "soft": 185 | agg_fn = soft_agg 186 | elif args.aggregator_name == "major": 187 | agg_fn = major_agg 188 | else: 189 | raise NotImplementedError 190 | 191 | # load data 192 | with open(args.input_path, "r") as fp: 193 | input_data = json.load(fp) 194 | 195 | # check claims 196 | print('Checking') 197 | claim_list = [] 198 | reference_list = [] 199 | question_list = [] 200 | for item in input_data: 201 | assert "claims" in item, "claims field is required" 202 | claims = item['claims'] 203 | if args.use_retrieval: 204 | reference = retriever.retrieve(item["response"]) 205 | item["reference"] = reference 206 | 
else: 207 | assert "reference" in item, \ 208 | "reference field is required if retriever is not used." 209 | reference = item["reference"] 210 | question = item.get("question", None) 211 | claim_list.append(claims) 212 | reference_list.append(reference) 213 | question_list.append(question) 214 | 215 | results = checker.check( 216 | batch_claims=claim_list, 217 | batch_references=reference_list, 218 | batch_questions=question_list) 219 | agg_results = [agg_fn(r) for r in results] 220 | 221 | output_data = [{ 222 | **input_data[i], 223 | **{ 224 | "Y": agg_results[i], 225 | "ys": results[i], 226 | } 227 | } for i in range(len(input_data))] 228 | with open(args.output_path, "w") as fp: 229 | json.dump(output_data, fp, indent=2) 230 | 231 | 232 | if __name__ == "__main__": 233 | main() 234 | -------------------------------------------------------------------------------- /refchecker/extractor/README.md: -------------------------------------------------------------------------------- 1 | ## Claim-Triplet Extractor 2 | 3 | In this work, we adopt LLMs as knowledge extraction models, leveraging their strong language understanding capabilities across diverse textual contexts. We provide [MistralExtractor](mistral_extractor.py) based on Supervised Fine-tuning, [MixtralExtractor](mixtral_extractor.py), [Claude2Extractor](claude2_extractor.py) and [GPT4Extractor](gpt4_extractor.py) as extractor interfaces equipped with different LLMs. 4 | 5 | 6 | ```python 7 | >>> from refchecker.extractor import Claude2Extractor 8 | 9 | >>> response = ( 10 | "Optimus (or Tesla Bot) is a robotic humanoid under development by Tesla, Inc. " 11 | "It was announced at the company's Artificial Intelligence (AI) Day event on " 12 | "August 19, 2021" 13 | ) 14 | >>> extractor = Claude2Extractor() 15 | >>> triplets = extractor.extract(response) 16 | >>> print(triplets) 17 | """ 18 | [['Optimus', 'is', 'robotic humanoid'], ['Optimus', 'under development by', 'Tesla, Inc.'], ['Optimus', 'also known as', 'Tesla Bot'], ['Tesla, Inc.', 'announced', 'Optimus'], ['Announcement of Optimus', 'occurred at', 'Artificial Intelligence (AI) Day event'], ['Artificial Intelligence (AI) Day event', 'held on', 'August 19, 2021'], ['Artificial Intelligence (AI) Day event', 'organized by', 'Tesla, Inc.']] 19 | """ 20 | ``` 21 | 22 | We query LLMs with the following prompt to get decomposed triplets. Each triplet is in the format of (head, relation, tail) and serves as a basic checking unit in the next stage. Note that we include the question for LLM response generation in the extraction process, because it may contain useful information in the QA scenario such as heads and relations in some triplets, as shown in the second in-context example. 23 | 24 | ``` 25 | Given a question and a candidate answer to the question, please extract a KG from the candidate answer condition on the question and represent the KG with triples formatted with ("head", "relation", "tail"). When you finished generating the KG, please output a word "". 26 | Here are some in-context examples: 27 | 28 | Question: 29 | Given these paragraphs about the Tesla bot, what is its alias? 30 | Candidate Answer: 31 | Optimus (or Tesla Bot) is a robotic humanoid under development by Tesla, Inc. 
It was announced at the company's Artificial Intelligence (AI) Day event on August 19, 2021 32 | KG: 33 | ("Optimus", "is", "robotic humanoid") 34 | ("Optimus", "under development by", "Tesla, Inc.") 35 | ("Optimus", "also known as", "Tesla Bot") 36 | ("Tesla, Inc.", "announced", "Optimus") 37 | ("Announcement of Optimus", "occurred at", "Artificial Intelligence (AI) Day event") 38 | ("Artificial Intelligence (AI) Day event", "held on", "August 19, 2021") 39 | ("Artificial Intelligence (AI) Day event", "organized by", "Tesla, Inc.") 40 | 41 | 42 | Question: 43 | here is some text about Andre Weiss, how many years was Andre at University of Dijon in Paris? 44 | Candidate Answer: 45 | 11 years 46 | KG: 47 | ("Andre Weiss at University of Dijon in Paris", "duration", "11 years") 48 | 49 | 50 | Now generate the KG for the following candidate answer based on the provided question: 51 | 52 | Question: 53 | {q}? 54 | Candidate Answer: 55 | {a} 56 | KG: 57 | ``` -------------------------------------------------------------------------------- /refchecker/extractor/__init__.py: -------------------------------------------------------------------------------- 1 | from .llm_extractor import LLMExtractor -------------------------------------------------------------------------------- /refchecker/extractor/extractor_base.py: -------------------------------------------------------------------------------- 1 | import re 2 | from typing import List, Union 3 | from ..base import RCText, RCClaim 4 | 5 | 6 | class ExtractorBase: 7 | def __init__( 8 | self, 9 | claim_format:str='triplet' 10 | ) -> None: 11 | assert claim_format in ['triplet', 'subsentence'] 12 | self.claim_format = claim_format 13 | 14 | def extract( 15 | self, 16 | batch_responses, 17 | batch_questions=None, 18 | max_new_tokens=500, 19 | sagemaker_client=None, 20 | sagemaker_params=None, 21 | sagemaker_get_response_func=None, 22 | custom_llm_api_func=None, 23 | **kwargs 24 | ): 25 | if self.claim_format == 'triplet': 26 | result = self.extract_claim_triplets( 27 | batch_responses=batch_responses, 28 | batch_questions=batch_questions, 29 | max_new_tokens=max_new_tokens, 30 | sagemaker_client=sagemaker_client, 31 | sagemaker_params=sagemaker_params, 32 | sagemaker_get_response_func=sagemaker_get_response_func, 33 | custom_llm_api_func=custom_llm_api_func, 34 | **kwargs 35 | ) 36 | elif self.claim_format == 'subsentence': 37 | result = self.extract_subsentence_claims( 38 | batch_responses=batch_responses, 39 | batch_questions=batch_questions, 40 | max_new_tokens=max_new_tokens, 41 | sagemaker_client=sagemaker_client, 42 | sagemaker_params=sagemaker_params, 43 | sagemaker_get_response_func=sagemaker_get_response_func, 44 | custom_llm_api_func=custom_llm_api_func, 45 | **kwargs 46 | ) 47 | return result 48 | 49 | def extract_claim_triplets( 50 | self, 51 | batch_responses, 52 | batch_questions=None, 53 | max_new_tokens=500, 54 | sagemaker_client=None, 55 | sagemaker_params=None, 56 | sagemaker_get_response_func=None, 57 | custom_llm_api_func=None, 58 | **kwargs 59 | ): 60 | raise NotImplementedError 61 | 62 | def extract_subsentence_claims( 63 | self, 64 | batch_responses, 65 | batch_questions=None, 66 | max_new_tokens=500, 67 | sagemaker_client=None, 68 | sagemaker_params=None, 69 | sagemaker_get_response_func=None, 70 | custom_llm_api_func=None, 71 | **kwargs 72 | ): 73 | raise NotImplementedError 74 | 75 | def parse_claims( 76 | self, 77 | response, 78 | claim_starting_prefix=None, 79 | excluded_content_prefix=None, 80 | response_sentence_ids=None 
81 | ): 82 | response = response.strip() 83 | if excluded_content_prefix and excluded_content_prefix in response: 84 | response = response[:response.index(excluded_content_prefix)] 85 | 86 | if claim_starting_prefix and claim_starting_prefix in response: 87 | response = response[response.index(claim_starting_prefix) + len(claim_starting_prefix):] 88 | 89 | if self.claim_format == 'triplet': 90 | return self._parse_claim_triplets(response) 91 | elif self.claim_format == 'subsentence': 92 | claims = [] 93 | # for c in re.findall(r'.*[\[\d+\]]+', response): 94 | for c in re.findall(r'.*[\[(\d+(?:,\s*\d+)*)\]]', response): 95 | sent_ids = [] 96 | first_sid_index = None 97 | for sid in re.finditer(r'\[(\d+(?:,\s*\d+)*)\]', c): 98 | if first_sid_index is None: 99 | first_sid_index = sid.start() 100 | sent_id_str = sid.group()[1:-1] 101 | if ',' in sent_id_str: 102 | for _id in sent_id_str.split(','): 103 | _id = _id.strip() 104 | sent_ids.append(_id) 105 | else: 106 | sent_ids.append(sid.group()[1:-1]) 107 | sent_ids = [_id for _id in sent_ids if _id in response_sentence_ids] 108 | if len(sent_ids): 109 | claims.append(RCClaim( 110 | format=self.claim_format, 111 | content=c[:first_sid_index].strip(), 112 | attributed_sent_ids=sent_ids 113 | )) 114 | return claims 115 | else: 116 | raise ValueError(f'Unknown Claim Format: {self.claim_format}') 117 | 118 | def _parse_claim_triplets(self, text): 119 | ret = [] 120 | patterns = [ 121 | r'\(".*", ".*", ".*"\)', 122 | r'\(".*", ".*", \'.*\'\)', 123 | r'\(".*", \'.*\', ".*"\)', 124 | r'\(\'.*\', ".*", ".*"\)', 125 | r'\(".*", \'.*\', \'.*\'\)', 126 | r'\(\'.*\', ".*", \'.*\'\)', 127 | r'\(\'.*\', \'.*\', ".*"\)', 128 | r'\(\'.*\', \'.*\', \'.*\'\)' 129 | ] 130 | for p in patterns: 131 | triplets = self._parse_triplets(p, text, triple_length=3) 132 | if triplets: 133 | ret += triplets 134 | 135 | # deduplication 136 | final_triple_set = [] 137 | for t in ret: 138 | t = tuple(t) 139 | if t not in final_triple_set and t != ('subject', 'predicate', 'object'): 140 | final_triple_set.append(t) 141 | 142 | # return [list(t) for t in final_triple_set] 143 | return [RCClaim('triplet', list(t), None) for t in final_triple_set] 144 | 145 | def _parse_triplets(self, pattern, text, triple_length=3): 146 | triplets = [] 147 | matches = re.findall(pattern, text) 148 | for m in matches: 149 | try: 150 | t = eval(m) 151 | except: 152 | t = m.split(', ') 153 | if t[0].startswith('('): 154 | t[0] = t[0][1:] 155 | if t[-1].endswith(')'): 156 | t[-1] = t[-1][:-1] 157 | if len(t) != triple_length: 158 | continue 159 | if any([not isinstance(e, str) for e in t]): 160 | continue 161 | if any([len(e) == 0 for e in t]): 162 | continue 163 | triplets.append(list(t)) 164 | return triplets -------------------------------------------------------------------------------- /refchecker/extractor/extractor_prompts.py: -------------------------------------------------------------------------------- 1 | 2 | LLM_Triplet_To_Claim_PROMPT_Q = """Given a question and a response to the question, please extract a KG from the response condition on the question and represent the KG with triplets formatted with `(subject, predicate, object)`. In addition, you should attribute the triplets to the sentences followed by the sentence ids. 3 | 4 | After extracting the KG, you should convert the KG to a list of claims. Each claim should satisfy the following criteria: 5 | * A claim is a piece of `knowledge point` in the response. 6 | * A claim should be fine-grained. 
One claim should not contain more than one piece of knowledge. 7 | * A claim should be self-contained; it is not dependent on other claims. 8 | * Each claim should truly reflect the meaning to be expressed in the response, and the information in the claim should be complete and unambiguous, and necessary conditions and attributes should not be missed. For example, for the text "Donald Trump won the presidential election in 2016.", the claim "Donald Trump won the presidential election" is a bad claim because it misses the necessary information "in 2016", so a complete claim should be "Donald Trump won the presidential election in 2016". 9 | * Some sentences in the response may not contain claims. Some sentences may contain one or more claims. A claim may occur across multiple sentences. 10 | * Opinions, speculations or questions in the text are not claims. 11 | 12 | We have added sentence IDs at the beginning of the sentences in the text, you should output a claim followed by a list of IDs, and these IDs stand for the sentences this claim is attributed to. The extraction process: 13 | 1. You should first identify whether there are claims in the text. If there are no claims, output "Abstain". 14 | 2. If there are claims in the text, you should first identify knowledge points in the text and formulate claims for them. 15 | 3. Note that a knowledge point may just reflect partial information in the text, so if we take it in isolation, it may not reflect the original meaning in the text, you should fill in the missing information to make it complete and clear. The information in a claim may span multiple sentences in the text. 16 | 4. Formulate the claim into one sentence, followed by the sentence indices of the attributed sentences in the text. 17 | 5. Make sure you have generated all the claims in the text. 18 | 19 | Please note that this is an EXTRACTION task, so DO NOT care about whether the content of the text is factual or not, just extract the claims from it. 20 | 21 | Here are some examples: 22 | 23 | ### Question 24 | Given these paragraphs about the Tesla bot, what is its alias? 25 | ### Response 26 | [1] Optimus (or Tesla Bot) is a robotic humanoid under development by Tesla, Inc. [2] It was announced at the company's Artificial Intelligence (AI) Day event on August 19, 2021. 27 | ### KG 28 | (Optimus, is, robotic humanoid) [1] 29 | (Optimus, under development by, Tesla, Inc.) [1] 30 | (Optimus, also known as, Tesla Bot) [1] 31 | (Tesla, Inc., announced, Optimus) [2] 32 | (Announcement of Optimus, occurred at, Artificial Intelligence (AI) Day event) [2] 33 | (Artificial Intelligence (AI) Day event, held on, August 19, 2021) [2] 34 | (Artificial Intelligence (AI) Day event, organized by, Tesla, Inc.) [2] 35 | ### Claims 36 | Optimus is a robotic humanoid. [1] 37 | Optimus is under development by Tesla, Inc. [1] 38 | Optimus is also known as Tesla Bot. [1] 39 | Tesla, Inc. announced Optimus at the Artificial Intelligence (AI) Day event on August 19, 2021. [2] 40 | 41 | ### Question 42 | here is some text about Andre Weiss, how many years was Andre at University of Dijon in Paris? 43 | ### Response 44 | [1] 11 years 45 | ### KG 46 | ("Andre Weiss at University of Dijon in Paris", "duration", "11 years") [1] 47 | ### Claims 48 | Andre Weiss was at University of Dijon in Paris for 11 years. [1] 49 | 50 | ### Question 51 | who sang i ran all the way home 52 | ### Response 53 | [1] I'm sorry, I cannot provide an accurate answer without additional context. 
[2] Could you please provide more information? 54 | ### KG 55 | Abstain 56 | ### Claims 57 | Abstain 58 | 59 | ### Question 60 | does jimmy jazz sell fake shoes 61 | ### Response 62 | [1] Based on the provided passages, it appears that Jimmy Jazz may sell fake shoes. [2] Passage 0 and Passage 1 both mention that the shoes purchased from Jimmy Jazz were found to be fake, and Passage 3 is a Ripoff Report stating that Jimmy Jazz E-Commerce LLC has a history of late shipping and lies on their website. [3] Additionally, Passage 4 mentions that Kanye West, who has collaborated with Jimmy Jazz in the past, has signed fake sneakers, which could suggest that Jimmy Jazz may be involved in selling counterfeit products. 63 | [4] However, it's important to note that there are also passages that do not mention anything about fake shoes, such as Passage 2, Passage 5, Passage 6, Passage 7, Passage 8, and Passage 9. [5] These passages do not provide any information that would confirm or deny the suspicion of Jimmy Jazz selling fake shoes. 64 | [6] Therefore, based on the provided passages, it is possible that Jimmy Jazz may sell fake shoes, but there is not enough conclusive evidence to confirm this suspicion. 65 | ### KG 66 | (Passage 0, mention, shoes purchased from Jimmy Jazz were found to be fake) [2] 67 | (Passage 1, mention, shoes purchased from Jimmy Jazz were found to be fake) [2] 68 | (Passage 3, is, a Ripoff Report) [2] 69 | (Passage 3, state, Jimmy Jazz E-Commerce LLC has a history of late shipping and lies on their website) [2] 70 | (Passage 4, mention, Kanye West has collaborated with Jimmy Jazz in the past) [3] 71 | (Passage 4, mention, Kanye West has signed fake sneakers) [3] 72 | (Passage 2, does not mention, fake shoes) [4] 73 | (Passage 5, does not mention, fake shoes) [4] 74 | (Passage 6, does not mention, fake shoes) [4] 75 | (Passage 7, does not mention, fake shoes) [4] 76 | (Passage 8, does not mention, fake shoes) [4] 77 | (Passage 9, does not mention, fake shoes) [4] 78 | ### Claims 79 | Passage 0 and Passage 1 mention that shoes purchased from Jimmy Jazz were found to be fake. [2] 80 | Passage 3 is a Ripoff Report stating that Jimmy Jazz E-Commerce LLC has a history of late shipping and lies on their website. [2] 81 | Passage 4 mentions that Kanye West has collaborated with Jimmy Jazz in the past. [3] 82 | Passage 4 mentions that Kanye West has signed fake sneakers. [3] 83 | Passage 2, Passage 5, Passage 6, Passage 7, Passage 8, and Passage 9 do not mention anything about fake shoes. [4] 84 | 85 | ### Question 86 | calories in ground beef stroganoff 87 | ### Response 88 | [1] Sure! [2] Based on the provided passages, the answer to the question "calories in ground beef stroganoff?" [3] is 227 calories. [4] This information can be found in Passage 9. 89 | ### KG 90 | (Passage 9, state, ground beef stroganoff contains 227 calories) [2][3][4] 91 | ### Claims 92 | Passage 9 states that ground beef stroganoff contains 227 calories. [2][3][4] 93 | 94 | Now generate the KG and claims for the following response based on the provided question. You should always follow the output format shown in the above examples and just ouput the KG and claims as well as the sentence IDs and DO NOT output any extra information: 95 | ### Question 96 | {q} 97 | ### Response 98 | {r} 99 | ### KG 100 | """ 101 | 102 | 103 | 104 | 105 | LLM_CLAIM_EXTRACTION_PROMPT_Q = """You are an AI assistant, you can help to extract claims from a model-generated response for a question. 
In addition, you should attribute the claims to the sentences followed by the sentence ids. 106 | Each claim should satisfy the following criteria: 107 | * A claim is a piece of `knowledge point` in the response. 108 | * A claim should be fine-grained. One claim should not contain more than one piece of knowledge. 109 | * A claim should be self-contained; it is not dependent on other claims. 110 | * Each claim should truly reflect the meaning to be expressed in the response, and the information in the claim should be complete and unambiguous, and necessary conditions and attributes should not be missed. For example, for the text "Donald Trump won the presidential election in 2016.", the claim "Donald Trump won the presidential election" is a bad claim because it misses the necessary information "in 2016", so a complete claim should be "Donald Trump won the presidential election in 2016". 111 | * Some sentences in the response may not contain claims. Some sentences may contain one or more claims. A claim may occur across multiple sentences. 112 | * Opinions, speculations or questions in the text are not claims. 113 | 114 | We have added sentence IDs at the beginning of the sentences in the text, you should output a claim followed by a list of IDs, and these IDs stand for the sentences this claim is attributed to. The extraction process: 115 | 1. You should first identify whether there are claims in the text. If there are no claims, output "Abstain". 116 | 2. If there are claims in the text, you should first identify knowledge points in the text and formulate claims for them. 117 | 3. Note that a knowledge point may just reflect partial information in the text, so if we take it in isolation, it may not reflect the original meaning in the text, you should fill in the missing information to make it complete and clear. The information in a claim may span multiple sentences in the text. 118 | 4. Formulate the claim into one sentence, followed by the sentence indices of the attributed sentences in the text. 119 | 5. Make sure you have generated all the claims in the text. 120 | 121 | Please note that this is an EXTRACTION task, so DO NOT care about whether the content of the text is factual or not, just extract the claims from it. 122 | 123 | Here are some examples: 124 | 125 | ### Question 126 | What is Optimus? 127 | ### Response 128 | [1] Optimus (or Tesla Bot) is a robotic humanoid under development by Tesla, Inc. [2] It was announced at the company's Artificial Intelligence (AI) Day event on August 19, 2021. 129 | ### Claims in Response 130 | Optimus is a robotic humanoid. [1] 131 | Optimus is under development by Tesla, Inc. [1] 132 | Optimus is also known as Tesla Bot. [1] 133 | Tesla, Inc. announced Optimus at the Artificial Intelligence (AI) Day event on August 19, 2021. [2] 134 | 135 | ### Question 136 | Here is some text about Andre Weiss, how many years was Andre at University of Dijon in Paris? 137 | ### Response 138 | [1] 11 years 139 | ### Claims in Response 140 | Andre Weiss was at University of Dijon in Paris for 11 years. [1] 141 | 142 | ### Question 143 | who sang i ran all the way home? 144 | ### Response 145 | [1] I'm sorry, I cannot provide an accurate answer without additional context. [2] Could you please provide more information? 146 | ### Claims in Response 147 | Abstain 148 | 149 | ### Question 150 | Does jimmy jazz sell fake shoes? 151 | ### Response 152 | [1] Based on the provided passages, it appears that Jimmy Jazz may sell fake shoes. 
[2] Passage 0 and Passage 1 both mention that the shoes purchased from Jimmy Jazz were found to be fake, and Passage 3 is a Ripoff Report stating that Jimmy Jazz E-Commerce LLC has a history of late shipping and lies on their website. [3] Additionally, Passage 4 mentions that Kanye West, who has collaborated with Jimmy Jazz in the past, has signed fake sneakers, which could suggest that Jimmy Jazz may be involved in selling counterfeit products. 153 | [4] However, it's important to note that there are also passages that do not mention anything about fake shoes, such as Passage 2, Passage 5, Passage 6, Passage 7, Passage 8, and Passage 9. [5] These passages do not provide any information that would confirm or deny the suspicion of Jimmy Jazz selling fake shoes. 154 | [6] Therefore, based on the provided passages, it is possible that Jimmy Jazz may sell fake shoes, but there is not enough conclusive evidence to confirm this suspicion. 155 | ### Claims in Response 156 | Passage 0 and Passage 1 mention that shoes purchased from Jimmy Jazz were found to be fake. [2] 157 | Passage 3 is a Ripoff Report stating that Jimmy Jazz E-Commerce LLC has a history of late shipping and lies on their website. [2] 158 | Passage 4 mentions that Kanye West has collaborated with Jimmy Jazz in the past. [3] 159 | Passage 4 mentions that Kanye West has signed fake sneakers. [3] 160 | Passage 2, Passage 5, Passage 6, Passage 7, Passage 8, and Passage 9 do not mention anything about fake shoes. [4] 161 | 162 | ### Question 163 | calories in ground beef stroganoff 164 | ### Response 165 | [1] Sure! [2] Based on the provided passages, the answer to the question "calories in ground beef stroganoff?" [3] is 227 calories. [4] This information can be found in Passage 9. 166 | ### Claims in Response 167 | Passage 9 states that ground beef stroganoff contains 227 calories. [2][3][4] 168 | 169 | Now please generate the claims from the following text. You should always follow the output format shown in the above examples and just ouput the claims without any extra information: 170 | ### Question 171 | {q} 172 | ### Response 173 | {r} 174 | ### Claims in Response 175 | """ 176 | 177 | 178 | LLM_TRIPLET_EXTRACTION_PROMPT_Q = \ 179 | """Given a question and a candidate answer to the question, please extract a KG from the candidate answer condition on the question and represent the KG with triples formatted with ("subject", "predicate", "object"), each triplet in a line. 180 | Please note that this is an EXTRACTION task, so DO NOT care about whether the content of the candidate answer is factual or not, just extract the triplets from it. Importantly, ensure that the extracted KG does not contain overlapping or redundant information. Each piece of information should be represented in the KG only once, and you should avoid creating triplets that are simply the inverse of another triplet. For example, if you extract the triplet ("John", "owns", "Car"), do not also include ("Car", "owned by", "John") as it represents the same information in reverse. 181 | 182 | Clarification on redundancy: First, Do not create triplets that reverse the subject and object to state the same fact. Next, Ensure each fact is represented uniquely in the simplest form, and avoid creating multiple triplets that convey the same information. 183 | 184 | Here are some in-context examples: 185 | 186 | ### Question: 187 | Given these paragraphs about the Tesla bot, what is its alias? 
188 | 
189 | ### Candidate Answer:
190 | Optimus (or Tesla Bot) is a robotic humanoid under development by Tesla, Inc. It was announced at the company's Artificial Intelligence (AI) Day event on August 19, 2021.
191 | 
192 | ### KG:
193 | ("Optimus", "is", "robotic humanoid")
194 | ("Optimus", "under development by", "Tesla, Inc.")
195 | ("Optimus", "also known as", "Tesla Bot")
196 | ("Tesla, Inc.", "announced", "Optimus")
197 | ("Announcement of Optimus", "occurred at", "Artificial Intelligence (AI) Day event")
198 | ("Artificial Intelligence (AI) Day event", "held on", "August 19, 2021")
199 | ("Artificial Intelligence (AI) Day event", "organized by", "Tesla, Inc.")
200 | 
201 | ### Question:
202 | here is some text about Andre Weiss, how many years was Andre at University of Dijon in Paris?
203 | 
204 | ### Candidate Answer:
205 | 11 years
206 | 
207 | ### KG:
208 | ("Andre Weiss at University of Dijon in Paris", "duration", "11 years")
209 | 
210 | 
211 | Now generate the KG for the following candidate answer based on the provided question:
212 | 
213 | ### Question:
214 | {q}
215 | 
216 | ### Candidate Answer:
217 | {a}
218 | 
219 | ### KG:
220 | """
221 | 
222 | LLM_TRIPLET_EXTRACTION_PROMPT = \
223 | """Given an input text, please extract a KG from the text and represent the KG with triples formatted with ("subject", "predicate", "object"), one triplet per line. Please note that this is an EXTRACTION task, so DO NOT care about whether the content of the text is factual or not, just extract the triplets from it. Importantly, ensure that the extracted KG does not contain overlapping or redundant information. Each piece of information should be represented in the KG only once, and you should avoid creating triplets that are simply the inverse of another triplet. For example, if you extract the triplet ("John", "owns", "Car"), do not also include ("Car", "owned by", "John") as it represents the same information in reverse.
224 | 
225 | Clarification on redundancy: First, do not create triplets that reverse the subject and object to state the same fact. Second, ensure each fact is represented uniquely and in its simplest form, and avoid creating multiple triplets that convey the same information.
226 | 
227 | Here are some in-context examples:
228 | 
229 | ### Input:
230 | Optimus (or Tesla Bot) is a robotic humanoid under development by Tesla, Inc. It was announced at the company's Artificial Intelligence (AI) Day event on August 19, 2021.
231 | 
232 | ### KG:
233 | ("Optimus", "is", "robotic humanoid")
234 | ("Optimus", "under development by", "Tesla, Inc.")
235 | ("Optimus", "also known as", "Tesla Bot")
236 | ("Tesla, Inc.", "announced", "Optimus")
237 | ("Announcement of Optimus", "occurred at", "Artificial Intelligence (AI) Day event")
238 | ("Artificial Intelligence (AI) Day event", "held on", "August 19, 2021")
239 | ("Artificial Intelligence (AI) Day event", "organized by", "Tesla, Inc.")
240 | 
241 | ### Input:
242 | The song "Here Comes the Boom" was originally released by American rock band Nelly in 2002 for the soundtrack of the film "The Longest Yard."
243 | 
244 | ### KG:
245 | ("The song 'Here Comes the Boom'", "originally released by", "American rock band Nelly")
246 | ("The song 'Here Comes the Boom'", "released in", "2002")
247 | ("The song 'Here Comes the Boom'", "featured in", "soundtrack of the film 'The Longest Yard'")
248 | ("American rock band Nelly", "released", "The song 'Here Comes the Boom'")
249 | ("The Longest Yard", "had soundtrack featuring", "The song 'Here Comes the Boom'")
250 | 
251 | 
252 | Now generate the KG for the provided input text:
253 | 
254 | ### Input:
255 | {input_text}
256 | 
257 | ### KG:
258 | """
--------------------------------------------------------------------------------
/refchecker/extractor/llm_extractor.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 | from tqdm import tqdm
3 | 
4 | from .extractor_base import ExtractorBase
5 | from ..utils import get_model_batch_response
6 | from ..base import RCText, ExtractionResult
7 | from .extractor_prompts import *
8 | 
9 | 
10 | class LLMExtractor(ExtractorBase):
11 |     def __init__(
12 |         self,
13 |         claim_format: str = 'triplet',
14 |         model: str = 'bedrock/anthropic.claude-3-sonnet-20240229-v1:0',
15 |         batch_size: int = 16,
16 |         api_base: str = None
17 |     ) -> None:
18 |         super().__init__(claim_format)
19 | 
20 |         self.model = model
21 |         self.batch_size = batch_size
22 |         self.api_base = api_base
23 | 
24 |     def extract_subsentence_claims(
25 |         self,
26 |         batch_responses,
27 |         batch_questions=None,
28 |         max_new_tokens=500,
29 |         sagemaker_client=None,
30 |         sagemaker_params=None,
31 |         sagemaker_get_response_func=None,
32 |         custom_llm_api_func=None,
33 |         **kwargs
34 |     ):
35 |         """Extract subsentence claims from the response texts.
36 |         Parameters
37 |         ----------
38 |         batch_responses : List[str]
39 |             List of model response texts.
40 |         batch_questions : List[str|None] | None, optional
41 |             List of questions corresponding to each response, defaults to None.
42 |         max_new_tokens : int, optional
43 |             Maximum number of tokens to generate, defaults to 500.
44 |         Returns
45 |         -------
46 |         List[ExtractionResult]
47 |             List of extracted claims for each response.
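        Notes
        -----
        The optional ``sagemaker_client``, ``sagemaker_params``,
        ``sagemaker_get_response_func`` and ``custom_llm_api_func`` arguments
        are forwarded to ``get_model_batch_response``, allowing generation to
        be routed through a SageMaker endpoint or a custom LLM API instead of
        litellm.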
48 |         """
49 | 
50 |         prompt_list = []
51 |         result_list = []
52 |         rc_responses = []
53 |         for _i, r in enumerate(batch_responses):
54 |             rc_r = RCText(r)
55 |             indexed_r_text = rc_r.get_indexed_response(condense_newlines=True)
56 |             q = None
57 |             if batch_questions:
58 |                 q = batch_questions[_i]
59 |             if q and len(q):
60 |                 prompt = LLM_Triplet_To_Claim_PROMPT_Q.format(q=q, r=indexed_r_text)
61 |             else:
62 |                 raise NotImplementedError
63 |             prompt_list.append(prompt)
64 |             rc_responses.append(rc_r)
65 | 
66 |         for _i in tqdm(range(0, len(prompt_list), self.batch_size)):
67 |             batch_prompts = prompt_list[_i:_i+self.batch_size]
68 |             llm_responses = get_model_batch_response(
69 |                 prompts=batch_prompts,
70 |                 temperature=0,
71 |                 model=self.model,
72 |                 n_choices=1,
73 |                 max_new_tokens=max_new_tokens,
74 |                 api_base=self.api_base,
75 |                 sagemaker_client=sagemaker_client,
76 |                 sagemaker_params=sagemaker_params,
77 |                 sagemaker_get_response_func=sagemaker_get_response_func,
78 |                 custom_llm_api_func=custom_llm_api_func,
79 |                 **kwargs
80 |             )
81 | 
82 |             if llm_responses and len(llm_responses):
83 |                 for _j, res in enumerate(llm_responses):
84 |                     claims = self.parse_claims(
85 |                         res,
86 |                         claim_starting_prefix='### Claims',
87 |                         excluded_content_prefix='### Question',
88 |                         response_sentence_ids=rc_responses[_i + _j].get_sentence_ids())
89 |                     result = ExtractionResult(
90 |                         claims=claims,
91 |                         response=rc_responses[_i + _j],
92 |                         extractor_response=res,
93 |                     )
94 |                     result_list.append(result)
95 |             else:
96 |                 return None
97 |         return result_list
98 | 
99 |     def extract_claim_triplets(
100 |         self,
101 |         batch_responses,
102 |         batch_questions=None,
103 |         max_new_tokens=500,
104 |         sagemaker_client=None,
105 |         sagemaker_params=None,
106 |         sagemaker_get_response_func=None,
107 |         custom_llm_api_func=None,
108 |         **kwargs
109 |     ):
110 |         """Extract KG triplets from the response texts.
111 |         Parameters
112 |         ----------
113 |         batch_responses : List[str]
114 |             List of model response texts.
115 |         batch_questions : List[str|None] | None, optional
116 |             List of questions corresponding to each response, defaults to None.
117 |         max_new_tokens : int, optional
118 |             Maximum number of tokens to generate, defaults to 500.
119 |         Returns
120 |         -------
121 |         List[ExtractionResult]
122 |             List of extracted claims for each response.
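        Notes
        -----
        Optional ``sagemaker_*`` and ``custom_llm_api_func`` arguments are
        forwarded to ``get_model_batch_response``, as in
        ``extract_subsentence_claims``.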
123 |         """
124 | 
125 |         prompt_list = []
126 |         result_list = []
127 | 
128 |         for _i, r in enumerate(batch_responses):
129 |             q = None
130 |             if batch_questions:
131 |                 q = batch_questions[_i]
132 |             if q is None:
133 |                 prompt = LLM_TRIPLET_EXTRACTION_PROMPT.format(
134 |                     input_text=r
135 |                 )
136 |             else:
137 |                 prompt = LLM_TRIPLET_EXTRACTION_PROMPT_Q.format(
138 |                     q=q,
139 |                     a=r
140 |                 )
141 |             prompt_list.append(prompt)
142 | 
143 |         for _i in tqdm(range(0, len(prompt_list), self.batch_size)):
144 |             batch_prompts = prompt_list[_i:_i+self.batch_size]
145 | 
146 |             llm_responses = get_model_batch_response(
147 |                 prompts=batch_prompts,
148 |                 temperature=1e-5,
149 |                 model=self.model,
150 |                 n_choices=1,
151 |                 max_new_tokens=max_new_tokens,
152 |                 api_base=self.api_base,
153 |                 sagemaker_client=sagemaker_client,
154 |                 sagemaker_params=sagemaker_params,
155 |                 sagemaker_get_response_func=sagemaker_get_response_func,
156 |                 custom_llm_api_func=custom_llm_api_func,
157 |                 **kwargs
158 |             )
159 | 
160 |             if llm_responses and len(llm_responses):
161 |                 for res in llm_responses:
162 |                     claims = self.parse_claims(res, '###')
163 |                     result = ExtractionResult(
164 |                         claims=claims,
165 |                         response=None,
166 |                         extractor_response=res
167 |                     )
168 |                     result_list.append(result)
169 |             else:
170 |                 return None
171 |         return result_list
172 | 
--------------------------------------------------------------------------------
/refchecker/localizer/__init__.py:
--------------------------------------------------------------------------------
1 | from .embed_localizer import NaiveEmbedLocalizer
2 | 
--------------------------------------------------------------------------------
/refchecker/localizer/embed_localizer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from transformers import AutoTokenizer, AutoModelForSequenceClassification
4 | 
5 | from ..utils import split_text
6 | 
7 | 
8 | class NaiveEmbedLocalizer(object):
9 |     """Aligns the text and triplets by embedding similarity."""
10 |     def __init__(
11 |         self, device: int = 0, segment_len: int = 256
12 |     ):
13 |         path_or_name = "princeton-nlp/sup-simcse-roberta-large"
14 |         self.model = AutoModelForSequenceClassification.from_pretrained(
15 |             path_or_name
16 |         ).to(device)
17 |         self.tokenizer = AutoTokenizer.from_pretrained(path_or_name)
18 |         self.device = device
19 |         self.segment_len = segment_len
20 | 
21 |     @torch.no_grad()
22 |     def _encode_text(self, text, avg_pooling=False):
23 |         """Encode text into embeddings."""
24 |         inputs = self.tokenizer(
25 |             text, max_length=512, truncation=True, return_tensors="pt",
26 |             padding=True, return_token_type_ids=True
27 |         )
28 |         _tokens = self.tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])
29 |         inputs = {k: v.to(self.device) for k, v in inputs.items()}
30 |         _output = self.model(**inputs, output_hidden_states=True)
31 |         _hid = _output.hidden_states[0][0][1: -1]
32 |         if avg_pooling:
33 |             _hid = _hid.mean(0)
34 |         return _tokens[1: -1], _hid
35 | 
36 |     @staticmethod
37 |     def cosine_dist(emb1, emb2):
38 |         return float(
39 |             (emb1 * emb2).sum() / (torch.norm(emb1, 2) * torch.norm(emb2, 2))
40 |         )
41 | 
42 |     @staticmethod
43 |     def normalize(text):
44 |         return text.lower().replace("Ġ", " ").replace("▁", " ")
45 | 
46 |     @staticmethod
47 |     def decorate(text, color, bgcolor):
48 |         return f"""<span style="color: {color}; background-color:
49 | {bgcolor}">{text}</span>"""
50 | 
51 |     def locate(self, text, triplet, threshold=[0.65, 0.6, 0.65]):
52 |         assert len(triplet) == 3, "triplet should have 3 elements"
53 |         tokens = []
54 |         segments = split_text(text, self.segment_len)
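        # The block below first embeds the text segment by segment to get
        # token-level embeddings, then embeds each triplet element with
        # average pooling, and finally slides windows of roughly the element's
        # token length over the text embeddings, marking spans whose cosine
        # similarity to the element embedding exceeds the per-element threshold.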
55 |         text_emb = []  # embeddings for text [L, d]
56 |         triplet_emb = []  # embeddings for triplet [3, d]
57 |         for seg in segments:
58 |             token, emb = self._encode_text(seg, avg_pooling=False)
59 |             tokens.extend(token)
60 |             text_emb.append(emb)
61 |         text_emb = torch.cat(text_emb, 0)
62 |         mask = np.zeros(len(tokens))
63 |         lens = []
64 |         for element in triplet:
65 |             if len(element) > 0:
66 |                 element_tokens, emb = self._encode_text(element, avg_pooling=True)
67 |                 triplet_emb.append(emb)
68 |                 lens.append(len(element_tokens))
69 |         for i in range(3):
70 |             if len(triplet[i]) > 0:
71 |                 bounds = []
72 |                 # varying window size between 0.8 and 1.2 times the number of
73 |                 # the triplet element's tokens
74 |                 len_lb = max(1, int(0.8 * lens[i]))
75 |                 len_ub = min(len(text_emb) - 1, int(1.2 * lens[i]))
76 |                 for length in range(len_lb, len_ub):
77 |                     for j in range(len(text_emb) - length):
78 |                         emb1 = text_emb[j: j + length].mean(0)
79 |                         emb2 = triplet_emb[i]
80 |                         sim = self.cosine_dist(emb1, emb2)
81 |                         _phrase = self.normalize("".join(tokens[j: j + length]))
82 |                         if (_phrase == triplet[i].strip().lower()):
83 |                             sim = threshold[i] + 0.01
84 |                         if (len(_phrase) - len(triplet[i].strip().lower()) < 5) and (_phrase.startswith(triplet[i].strip().lower())) or (triplet[i].strip().lower().startswith(_phrase)):
85 |                             sim = threshold[i] + 0.01
86 |                         if sim > threshold[i]:
87 |                             bounds.append([j, j + length, sim])
88 |                 for j in range(len(bounds) - 1, -1, -1):
89 |                     if bounds[j][2] < threshold[i] * 1.2 and any([((x[0] >= bounds[j][0] and x[0] < bounds[j][1]) or (x[1] > bounds[j][0] and x[1] <= bounds[j][1])) and x[2] > bounds[j][2] * 1.05 for x in bounds[:j]]):
90 |                         del bounds[j]
91 |                 for b in bounds:
92 |                     mask[b[0]: b[1]] = i + 1
93 |         vs = [int(x) for x in mask]
94 |         ret = ''
95 |         cmap = ['black', 'red', 'blue', 'green']
96 |         for i in range(len(tokens)):
97 |             ret += self.decorate(tokens[i], cmap[vs[i]], "#F1CEF3")
98 |         return ret.replace("Ġ", " ").replace("▁", " ")
99 | 
100 | 
101 | if __name__ == "__main__":
102 |     localizer = NaiveEmbedLocalizer()
103 |     text = """Eleanor Arnason (born 1945) is an American science fiction and fantasy writer. She is best known for her novel A Woman of the Iron People (1991), which won the James Tiptree, Jr. Award and was a finalist for the Nebula Award for Best Novel. Her other works include Ring of Swords (1993), The Sword Smith (1998), and The Hound of Merin (2002). She has also written several short stories, including "Dapple" (1991), which won the Nebula Award for Best Novelette. """
104 |     sents = [
105 |         "Eleanor Arnason (born 1945) is an American science fiction and fantasy writer.",
106 |         "She is best known for her novel A Woman of the Iron People (1991), which won the James Tiptree, Jr. Award and was a finalist for the Nebula Award for Best Novel.",
107 |         "Her other works include Ring of Swords (1993), The Sword Smith (1998), and The Hound of Merin (2002).",
108 |     ]
109 |     triplets = [
110 |         ["Eleanor Arnason", "born", "1945"],
111 |         ["Eleanor Arnason", "is", "American science fiction and fantasy writer"],
112 |         ["A Woman of the Iron People (1991)", "won", "James Tiptree, Jr. Award"]
Award"] 113 | ] 114 | for triplet in triplets: 115 | for sent in sents: 116 | print(localizer.locate(sent, triplet)) 117 | print() 118 | -------------------------------------------------------------------------------- /refchecker/retriever/__init__.py: -------------------------------------------------------------------------------- 1 | from .google_retriever import GoogleRetriever -------------------------------------------------------------------------------- /refchecker/retriever/google_retriever.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import math 4 | import requests 5 | import warnings 6 | from typing import List, Dict, Union, Tuple 7 | 8 | import rank_bm25 9 | import diskcache 10 | from bs4 import BeautifulSoup 11 | 12 | from ..utils import get_model_batch_response, sentencize 13 | 14 | 15 | SERPER_URL = "https://google.serper.dev/search" 16 | 17 | PROMPT_FOR_QUERY_GEN = """Please generate a question on the given text so that when searching on Google with the question, it's possible to get some relevant information on the topics addressed in the text. Note, you just need to give the final question without quotes in one line, and additional illustration should not be included. 18 | 19 | For example: 20 | Input text: The Lord of the Rings trilogy consists of The Fellowship of the Ring, The Two Towers, and The Return of the King. 21 | Output: What are the three books in The Lord of the Rings trilogy? 22 | 23 | Input text: %s 24 | Output: """ 25 | 26 | 27 | class GoogleRetriever: 28 | def __init__(self, cache_dir: str = "./.cache"): 29 | self.bm25 = None 30 | self._load_key() 31 | cache_dir = os.path.join(cache_dir, "serper") 32 | self.cache = diskcache.Cache(cache_dir) 33 | 34 | def _load_key(self): 35 | self.api_key = os.environ.get("SERPER_API_KEY", None) 36 | assert self.api_key is not None, \ 37 | f"Require environment variable SERPER_API_KEY." 38 | 39 | def _query_google(self, query: str) -> dict: 40 | """Search Google using Serper API and retrieve abundant information""" 41 | if query in self.cache: 42 | return self.cache[query] 43 | else: 44 | payload = json.dumps({"q": query}) 45 | headers = { 46 | "X-API-KEY": self.api_key, 47 | "Content-Type": "application/json" 48 | } 49 | response = requests.request( 50 | "POST", SERPER_URL, headers=headers, data=payload 51 | ) 52 | response_dict = json.loads(response.text) 53 | self.cache[query] = response_dict 54 | return response_dict 55 | 56 | def _get_queries(self, paragraph: str) -> List[str]: 57 | """Use LLM to generate query to search on the Internet to get relevant 58 | information. Currently only single query is generated.""" 59 | prompt = PROMPT_FOR_QUERY_GEN % paragraph 60 | query = get_model_batch_response([prompt], model='gpt-3.5-turbo', temperature=0)[0] 61 | if query is None: 62 | raise RuntimeError( 63 | "Retriever: Empty response from LLM for query generation." 64 | ) 65 | return [query.strip()] 66 | 67 | @staticmethod 68 | def _parse_results(results: dict) -> Tuple[List[dict], bool]: 69 | """Adapted from `FacTool` to utilize retrieved results as answers.""" 70 | snippets = [] 71 | with_answerbox = False 72 | if results.get("answerBox"): 73 | # This case indicates that Google has made a good answer to the question, and it's as desired to utilize this information. 
74 |             answer_box: dict = results.get("answerBox", {})
75 |             if answer_box.get("answer"):
76 |                 element = {
77 |                     "content": answer_box.get("answer"),
78 |                     "source": answer_box.get("link"),
79 |                 }
80 |                 snippets = [element]
81 |             elif answer_box.get("snippet"):
82 |                 element = {
83 |                     "content": answer_box.get("snippet").replace("\n", " "),
84 |                     "source": answer_box.get("link"),
85 |                 }
86 |                 snippets = [element]
87 |             elif answer_box.get("snippetHighlighted"):
88 |                 element = {
89 |                     "content": answer_box.get("snippetHighlighted"),
90 |                     "source": answer_box.get("link"),
91 |                 }
92 |                 snippets = [element]
93 |             if len(snippets) > 0:
94 |                 with_answerbox = True
95 |         if results.get("knowledgeGraph"):
96 |             kg: dict = results.get("knowledgeGraph", {})
97 |             title = kg.get("title")
98 |             entity_type = kg.get("type")
99 |             if entity_type:
100 |                 element = {
101 |                     "content": f"{title}: {entity_type}",
102 |                     "source": kg.get("link"),
103 |                 }
104 |                 snippets.append(element)
105 |             description = kg.get("description")
106 |             if description:
107 |                 element = {"content": description, "source": kg.get("link")}
108 |                 snippets.append(element)
109 |             for attribute, value in kg.get("attributes", {}).items():
110 |                 element = {"content": f"{attribute}: {value}", "source": kg.get("link")}
111 |                 snippets.append(element)
112 |         # TODO: make the number of parsed organic results a parameter
113 |         for result in results["organic"][:3]:
114 |             if "snippet" in result:
115 |                 element = {"content": result["snippet"], "source": result["link"]}
116 |                 snippets.append(element)
117 |             for attribute, value in result.get("attributes", {}).items():
118 |                 element = {"content": f"{attribute}: {value}", "source": result["link"]}
119 |                 snippets.append(element)
120 | 
121 |         if len(snippets) == 0:
122 |             warnings.warn("No usable google search results.")
123 | 
124 |         return snippets, with_answerbox
125 | 
126 |     @staticmethod
127 |     def _get_url_text(url) -> str:
128 |         # Read page and return text
129 |         buf = []
130 |         try:
131 |             soup = BeautifulSoup(
132 |                 requests.get(url, timeout=10).text, "html.parser"
133 |             )
134 |             for p in soup.find_all("p"):
135 |                 pt = p.get_text()
136 |                 if len(buf) == 0 or pt not in buf[-1]:
137 |                     buf.append(pt)
138 |             return "\n".join(buf)
139 |         except Exception:
140 |             return ""
141 | 
142 |     @staticmethod
143 |     def _split_doc(
144 |         text: str,
145 |         max_words_per_paragraph=384,
146 |         short_paragraph_threshold=96,
147 |         preserve_threshold=8,
148 |     ) -> List[str]:
149 |         """Use spacy to split a document into paragraphs."""
150 |         paras = text.splitlines()
151 |         splitted = []
152 |         sent_to_be_concat = ""
153 |         accumulate_length = 0
154 |         for p in paras:
155 |             p = p.strip()
156 |             if len(p) < 1:
157 |                 continue  # empty lines
158 |             sents = sentencize(p)
159 |             for sent in sents:
160 |                 if accumulate_length + len(sent) <= max_words_per_paragraph:
161 |                     sent_to_be_concat += sent.text_with_ws
162 |                     accumulate_length += len(sent)
163 |                 else:
164 |                     splitted.append(sent_to_be_concat)
165 |                     sent_to_be_concat = sent.text_with_ws
166 |                     accumulate_length = len(sent)
167 |             if accumulate_length <= short_paragraph_threshold:
168 |                 sent_to_be_concat += " "
169 |             else:
170 |                 splitted.append(sent_to_be_concat)
171 |                 sent_to_be_concat = ""
172 |                 accumulate_length = 0
173 |         if accumulate_length >= preserve_threshold:
174 |             splitted.append(sent_to_be_concat)
175 |         return splitted
176 | 
177 |     def _process_retrieved_docs(
178 |         self,
179 |         docs: List[dict],
180 |         query: str,
181 |         best_k=8,
182 |         max_words_per_paragraph=384,
183 |         skip_repeated_corpus=True,
{"content": , "url": } 185 | if len(docs) == 0: 186 | return None 187 | if len(docs) == 1: 188 | return docs 189 | else: 190 | links_dict = {} 191 | corpus, links = [], [] # List of documents 192 | # retrieve through the links 193 | for relevance in docs: 194 | url = relevance["source"] 195 | if "youtube" in url: 196 | continue # skip youtube due to slow fetching 197 | if url in links_dict.keys(): 198 | if skip_repeated_corpus: 199 | continue 200 | online_text = links_dict[url] 201 | else: 202 | online_text = self._get_url_text(url) 203 | links_dict[url] = online_text 204 | splitted_text = self._split_doc( 205 | online_text, max_words_per_paragraph 206 | ) 207 | corpus.extend(splitted_text) 208 | links.extend([url] * len(splitted_text)) 209 | 210 | meta_doc_dict = dict(zip(corpus, links)) 211 | tokenized_corpus = [doc.split(" ") for doc in corpus] 212 | 213 | bm25 = rank_bm25.BM25Okapi(tokenized_corpus) 214 | best_docs = bm25.get_top_n(query.split(), corpus, n=best_k) 215 | return [ 216 | {"content": k, "source": meta_doc_dict[k]} 217 | for k in best_docs 218 | ] 219 | 220 | def retrieve( 221 | self, 222 | text: str, 223 | top_k=3, 224 | max_words_per_paragraph=384 225 | ) -> List[Dict[str, Union[str, None]]]: 226 | """ 227 | Search reference documents on the Internet based on LLM generated query. 228 | Parameters 229 | ---------- 230 | text : str 231 | Text to be checked. 232 | top_k : int 233 | Number of reference documents to be retrieved. 234 | max_words_per_paragraph : int 235 | Maximum number of words in each reference document. 236 | Returns 237 | ------- 238 | List[str] 239 | List of reference documents 240 | """ 241 | 242 | # Step 1. Generate queries for searching using LLM. 243 | queries = self._get_queries(text) 244 | # Step 2. Search google with the queries. 
245 | relevant_info_dicts, best_docs_all = [], [] 246 | for q in queries: 247 | searched_results = self._query_google(q) 248 | parsed_results, with_answerbox = self._parse_results( 249 | searched_results 250 | ) 251 | if with_answerbox: 252 | answerbox_answer, parsed_results = ( 253 | parsed_results[0], 254 | parsed_results[1:], 255 | ) 256 | relevant_info_dicts.extend(parsed_results) 257 | best_docs = self._process_retrieved_docs( 258 | relevant_info_dicts, 259 | q, 260 | best_k=math.ceil((top_k - with_answerbox) / len(queries)), 261 | max_words_per_paragraph=max_words_per_paragraph, 262 | skip_repeated_corpus=True, 263 | ) 264 | if with_answerbox: 265 | best_docs.insert(0, answerbox_answer) 266 | best_docs_all.extend(best_docs) 267 | refs = [ 268 | doc["content"] for doc in best_docs_all 269 | ] 270 | return refs 271 | -------------------------------------------------------------------------------- /refchecker/utils.py: -------------------------------------------------------------------------------- 1 | import time 2 | import json 3 | import os 4 | 5 | import spacy 6 | from openai.types import Completion as OpenAICompletion 7 | from openai import RateLimitError as OpenAIRateLimitError 8 | from openai import APIError as OpenAIAPIError 9 | from openai import Timeout as OpenAITimeout 10 | 11 | from litellm import batch_completion 12 | from litellm.types.utils import ModelResponse 13 | 14 | # Setup spaCy NLP 15 | nlp = None 16 | 17 | # Setup OpenAI API 18 | openai_client = None 19 | 20 | # Setup Claude 2 API 21 | bedrock = None 22 | anthropic_client = None 23 | 24 | 25 | def sentencize(text): 26 | """Split text into sentences""" 27 | global nlp 28 | if not nlp: 29 | nlp = spacy.load("en_core_web_sm") 30 | doc = nlp(text) 31 | return [sent for sent in doc.sents] 32 | 33 | 34 | def split_text(text, segment_len=200): 35 | """Split text into segments according to sentence boundaries.""" 36 | segments, seg = [], [] 37 | sents = [[token.text for token in sent] for sent in sentencize(text)] 38 | for sent in sents: 39 | if len(seg) + len(sent) > segment_len: 40 | segments.append(" ".join(seg)) 41 | seg = sent 42 | # single sentence longer than segment_len 43 | if len(seg) > segment_len: 44 | # split into chunks of segment_len 45 | seg = [ 46 | " ".join(seg[i:i+segment_len]) 47 | for i in range(0, len(seg), segment_len) 48 | ] 49 | segments.extend(seg) 50 | seg = [] 51 | else: 52 | seg.extend(sent) 53 | if seg: 54 | segments.append(" ".join(seg)) 55 | return segments 56 | 57 | 58 | def get_model_batch_response( 59 | prompts, 60 | model='bedrock/anthropic.claude-3-sonnet-20240229-v1:0', 61 | temperature=0, 62 | n_choices=1, 63 | max_new_tokens=500, 64 | api_base=None, 65 | sagemaker_client=None, 66 | sagemaker_params=None, 67 | sagemaker_get_response_func=None, 68 | custom_llm_api_func=None, 69 | **kwargs 70 | ): 71 | """ 72 | Get batch generation results with given prompts. 73 | 74 | Parameters 75 | ---------- 76 | prompts : List[str] 77 | List of prompts for generation. 78 | temperature : float, optional 79 | The generation temperature, use greedy decoding when setting 80 | temperature=0, defaults to 0. 81 | model : str, optional 82 | The model for generation, defaults to 'bedrock/anthropic.claude-3-sonnet-20240229-v1:0'. 83 | n_choices : int, optional 84 | How many samples to return for each prompt input, defaults to 1. 85 | max_new_tokens : int, optional 86 | Maximum number of newly generated tokens, defaults to 500. 
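    api_base : str, optional
        Optional API base URL passed through to ``litellm.batch_completion``,
        defaults to None.
    sagemaker_client : optional
        A boto3 SageMaker runtime client. When provided, ``model`` is treated
        as a SageMaker endpoint name and prompts are sent to that endpoint
        one by one.
    sagemaker_params : dict, optional
        Overrides for the generation parameters (e.g. ``max_new_tokens``,
        ``temperature``) sent to the SageMaker endpoint.
    sagemaker_get_response_func : callable, optional
        Function used to extract the generated text from the raw SageMaker
        response.
    custom_llm_api_func : callable, optional
        Custom function that takes the list of prompts and returns the list
        of generated texts, bypassing litellm.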
87 | 88 | Returns 89 | ------- 90 | response_list : List[str] 91 | List of generated text. 92 | """ 93 | if not prompts or len(prompts) == 0: 94 | raise ValueError("Invalid input.") 95 | 96 | if sagemaker_client is not None: 97 | parameters = { 98 | "max_new_tokens": max_new_tokens, 99 | "temperature": temperature 100 | } 101 | if sagemaker_params is not None: 102 | for k, v in sagemaker_params.items(): 103 | if k in parameters: 104 | parameters[k] = v 105 | response_list = [] 106 | for prompt in prompts: 107 | r = sagemaker_client.invoke_endpoint( 108 | EndpointName=model, 109 | Body=json.dumps( 110 | { 111 | "inputs": prompt, 112 | "parameters": parameters, 113 | } 114 | ), 115 | ContentType="application/json", 116 | ) 117 | if sagemaker_get_response_func is not None: 118 | response = sagemaker_get_response_func(r) 119 | else: 120 | r = json.loads(r['Body'].read().decode('utf8')) 121 | response = r['outputs'][0] 122 | response_list.append(response) 123 | return response_list 124 | 125 | elif custom_llm_api_func is not None: 126 | return custom_llm_api_func(prompts) 127 | else: 128 | message_list = [] 129 | for prompt in prompts: 130 | if len(prompt) == 0: 131 | raise ValueError("Invalid prompt.") 132 | if isinstance(prompt, str): 133 | messages = [{ 134 | 'role': 'user', 135 | 'content': prompt 136 | }] 137 | elif isinstance(prompt, list): 138 | messages = prompt 139 | else: 140 | raise ValueError("Invalid prompt type.") 141 | message_list.append(messages) 142 | import litellm 143 | litellm.suppress_debug_info = True 144 | # litellm.drop_params=True 145 | while True: 146 | responses = batch_completion( 147 | model=model, 148 | messages=message_list, 149 | temperature=temperature, 150 | n=n_choices, 151 | max_tokens=max_new_tokens, 152 | api_base=api_base, 153 | max_workers=None, 154 | **kwargs 155 | ) 156 | try: 157 | assert all([isinstance(r, ModelResponse) for r in responses]) 158 | if n_choices == 1: 159 | response_list = [r.choices[0].message.content for r in responses] 160 | else: 161 | response_list = [[res.message.content for res in r.choices] for r in responses] 162 | 163 | assert all([r is not None for r in response_list]) 164 | return response_list 165 | except: 166 | exception = None 167 | for e in responses: 168 | if isinstance(e, ModelResponse): 169 | continue 170 | elif isinstance(e, OpenAIRateLimitError) or isinstance(e, OpenAIAPIError) or isinstance(e, OpenAITimeout): 171 | exception = e 172 | break 173 | else: 174 | print('Exit with the following error:') 175 | print(e) 176 | return None 177 | 178 | print(f"{exception} [sleep 10 seconds]") 179 | time.sleep(10) 180 | continue --------------------------------------------------------------------------------
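Usage sketch (illustrative, not part of the repository): the snippet below shows one way the triplet extractor above might be invoked, assuming the package is installed and credentials for a litellm-compatible backend are configured; the model name is a placeholder.

from refchecker.extractor.llm_extractor import LLMExtractor

# Placeholder model id; the default is a Bedrock Claude 3 model, which
# requires AWS credentials. Any litellm-supported model id should work.
extractor = LLMExtractor(claim_format='triplet', model='gpt-4o', batch_size=4)

results = extractor.extract_claim_triplets(
    batch_responses=[
        "Optimus (or Tesla Bot) is a robotic humanoid under development by Tesla, Inc."
    ],
    batch_questions=["What is Optimus?"],
    max_new_tokens=500,
)

# Each ExtractionResult is constructed with the parsed claims and the raw
# extractor output; the attribute access below assumes it exposes these fields.
if results is not None:
    for result in results:
        print(result.claims)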