├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── NOTICE ├── README.md ├── auto-rag-eval ├── Data │ ├── Arxiv │ │ ├── arxiv_categories.json │ │ └── preprocessor.py │ ├── SecFilings │ │ └── preprocessor.py │ └── StackExchange │ │ ├── preprocessor.py │ │ └── stack_exchanges.json ├── ExamAnalysis │ ├── README.md │ ├── __pycache__ │ │ ├── bloom_taxonomy_model.cpython-39.pyc │ │ ├── generate_irt_plots.cpython-37.pyc │ │ ├── generate_iterative_irt_plots.cpython-37.pyc │ │ ├── generate_recursive_irt_plots.cpython-37.pyc │ │ ├── item_response_models.cpython-37.pyc │ │ ├── item_response_models.cpython-39.pyc │ │ └── iterative_item_response_models.cpython-37.pyc │ ├── bloom_taxonomy_model.py │ ├── compute_exam_radar_plot.ipynb │ ├── generate_irt_plots.py │ ├── generate_iterative_irt_plots.py │ ├── item_response_models.py │ ├── iterative_item_response_models.py │ └── taxonomy_analysis.ipynb ├── ExamEvaluator │ ├── DevOpsExam │ │ ├── DevOpsExam.yaml │ │ ├── DevOpsRagExam.yaml │ │ ├── __pycache__ │ │ │ └── preprocess_exam.cpython-38.pyc │ │ └── preprocess_exam.py │ ├── README.md │ └── task_evaluation.sh ├── ExamGenerator │ ├── README.md │ ├── __pycache__ │ │ ├── distractors_generator.cpython-37.pyc │ │ ├── enrich_existing_exam.cpython-37.pyc │ │ ├── extend_ir_existing_exam.cpython-37.pyc │ │ ├── fake_exam_generator.cpython-37.pyc │ │ ├── multi_choice_exam.cpython-37.pyc │ │ ├── multi_choice_exam_generator.cpython-37.pyc │ │ ├── multi_choice_question.cpython-37.pyc │ │ ├── question_generator.cpython-37.pyc │ │ ├── raw_question_generator.cpython-37.pyc │ │ └── utils.cpython-37.pyc │ ├── distractors_generator.py │ ├── extend_ir_existing_exam.py │ ├── fake_exam_generator.py │ ├── multi_choice_exam.py │ ├── multi_choice_question.py │ ├── question_generator.py │ └── utils.py ├── LLMServer │ ├── README.md │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-37.pyc │ │ └── base_model.cpython-37.pyc │ ├── base_model.py │ ├── bedrock │ │ ├── __pycache__ │ │ │ ├── claude_instant.cpython-37.pyc │ │ │ ├── claude_v2.cpython-37.pyc │ │ │ └── conversation.cpython-37.pyc │ │ ├── claude_instant.py │ │ ├── claude_v2.py │ │ └── claude_v3.py │ └── llm_exam_generator.py ├── RetrievalSystems │ ├── README.md │ ├── __pycache__ │ │ ├── bm25.cpython-37.pyc │ │ ├── common.cpython-37.pyc │ │ ├── context_utils.cpython-37.pyc │ │ ├── docs_faiss_index.cpython-37.pyc │ │ ├── dpr_context_aggregator.cpython-37.pyc │ │ ├── dpr_context_retriever.cpython-37.pyc │ │ ├── embedding_retriever.cpython-37.pyc │ │ └── siamese_retriever.cpython-37.pyc │ ├── bm25.py │ ├── context_utils.py │ ├── docs_faiss_index.py │ ├── dpr_context_aggregator.py │ ├── embedding_retriever.py │ ├── siamese_retriever.py │ └── test_retrieval_models.py ├── __init__.py └── py.typed ├── images └── generation_summary.png └── pyproject.toml /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | ## Code of Conduct 2 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 3 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 4 | opensource-codeofconduct@amazon.com with any additional questions or comments. 5 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing Guidelines 2 | 3 | Thank you for your interest in contributing to our project. 
Whether it's a bug report, new feature, correction, or additional 4 | documentation, we greatly value feedback and contributions from our community. 5 | 6 | Please read through this document before submitting any issues or pull requests to ensure we have all the necessary 7 | information to effectively respond to your bug report or contribution. 8 | 9 | 10 | ## Reporting Bugs/Feature Requests 11 | 12 | We welcome you to use the GitHub issue tracker to report bugs or suggest features. 13 | 14 | When filing an issue, please check existing open, or recently closed, issues to make sure somebody else hasn't already 15 | reported the issue. Please try to include as much information as you can. Details like these are incredibly useful: 16 | 17 | * A reproducible test case or series of steps 18 | * The version of our code being used 19 | * Any modifications you've made relevant to the bug 20 | * Anything unusual about your environment or deployment 21 | 22 | 23 | ## Contributing via Pull Requests 24 | Contributions via pull requests are much appreciated. Before sending us a pull request, please ensure that: 25 | 26 | 1. You are working against the latest source on the *main* branch. 27 | 2. You check existing open, and recently merged, pull requests to make sure someone else hasn't addressed the problem already. 28 | 3. You open an issue to discuss any significant work - we would hate for your time to be wasted. 29 | 30 | To send us a pull request, please: 31 | 32 | 1. Fork the repository. 33 | 2. Modify the source; please focus on the specific change you are contributing. If you also reformat all the code, it will be hard for us to focus on your change. 34 | 3. Ensure local tests pass. 35 | 4. Commit to your fork using clear commit messages. 36 | 5. Send us a pull request, answering any default questions in the pull request interface. 37 | 6. Pay attention to any automated CI failures reported in the pull request, and stay involved in the conversation. 38 | 39 | GitHub provides additional document on [forking a repository](https://help.github.com/articles/fork-a-repo/) and 40 | [creating a pull request](https://help.github.com/articles/creating-a-pull-request/). 41 | 42 | 43 | ## Finding contributions to work on 44 | Looking at the existing issues is a great way to find something to contribute on. As our projects, by default, use the default GitHub issue labels (enhancement/bug/duplicate/help wanted/invalid/question/wontfix), looking at any 'help wanted' issues is a great place to start. 45 | 46 | 47 | ## Code of Conduct 48 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 49 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 50 | opensource-codeofconduct@amazon.com with any additional questions or comments. 51 | 52 | 53 | ## Security issue notifications 54 | If you discover a potential security issue in this project we ask that you notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/). Please do **not** create a public github issue. 55 | 56 | 57 | ## Licensing 58 | 59 | See the [LICENSE](LICENSE) file for our project's licensing. We will ask you to confirm the licensing of your contribution. 
60 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 
62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 
123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | -------------------------------------------------------------------------------- /NOTICE: -------------------------------------------------------------------------------- 1 | Copyright Amazon.com, Inc. 
or its affiliates. All Rights Reserved. 2 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Automated Evaluation of Retrieval-Augmented Language Models with Task-Specific Exam Generation 2 | 3 | This repository is the companion of the ICML 2024 paper [Automated Evaluation of Retrieval-Augmented Language Models with Task-Specific Exam Generation](https://arxiv.org/abs/2405.13622) ([Blog](https://www.amazon.science/blog/automated-evaluation-of-rag-pipelines-with-exam-generation)) 4 | 5 |

6 | ![Exam generation summary](images/generation_summary.png) 7 | 

8 | 9 | **Goal**: For a given knowledge corpus: 10 | * Leverage an LLM to generate a multiple-choice exam associated with the task of interest. 11 | * Evaluate variants of RAG systems on this exam. 12 | * Evaluate and iteratively improve the exam. 13 | 14 | The only thing you need to experiment with this code is a `json` file with your knowledge corpus in the format described below. 15 | 16 | ## I - Package Structure 17 | 18 | * `Data`: For each use case, contains: 19 | * Preprocessing Code 20 | * Knowledge Corpus Data 21 | * Exam Data (Raw and Processed) 22 | * Retrieval Index 23 | * `ExamGenerator`: Code to generate and process the multiple-choice exam using the knowledge corpus and LLM generator(s). 24 | * `ExamEvaluator`: Code to evaluate an exam using a combination `(Retrieval System, LLM, ExamCorpus)`, relying on the `lm-harness` library. 25 | * `LLMServer`: Unified LLM endpoints to generate the exam. 26 | * `RetrievalSystems`: Unified Retrieval System classes (e.g. DPR, BM25, Embedding Similarity...). 27 | 28 | ## II - Exam Data Generation Process 29 | 30 | We illustrate our methodology on 4 tasks of interest: AWS DevOps Troubleshooting, StackExchange Q&A, Sec Filings Q&A, and Arxiv Q&A. We then show how to adapt the methodology to any task. 31 | 32 | ### StackExchange 33 | 34 | Run the commands below, where `question-date` is the date stamp of the raw exam data generation. Add `--save-exam` if you want to save the exam, and remove it if you're only interested in the analytics. 35 | 36 | ```bash 37 | cd auto-rag-eval 38 | rm -rf Data/StackExchange/KnowledgeCorpus/main/* 39 | python3 -m Data.StackExchange.preprocessor 40 | python3 -m ExamGenerator.question_generator --task-domain StackExchange 41 | python3 -m ExamGenerator.multi_choice_exam --task-domain StackExchange --question-date "question-date" --save-exam 42 | ``` 43 | 44 | 45 | ### Arxiv 46 | 47 | ```bash 48 | cd auto-rag-eval 49 | rm -rf Data/Arxiv/KnowledgeCorpus/main/* 50 | python3 -m Data.Arxiv.preprocessor 51 | python3 -m ExamGenerator.question_generator --task-domain Arxiv 52 | python3 -m ExamGenerator.multi_choice_exam --task-domain Arxiv --question-date "question-date" --save-exam 53 | ``` 54 | 55 | ### Sec Filings 56 | 57 | ```bash 58 | cd auto-rag-eval 59 | rm -rf Data/SecFilings/KnowledgeCorpus/main/* 60 | python3 -m Data.SecFilings.preprocessor 61 | python3 -m ExamGenerator.question_generator --task-domain SecFilings 62 | python3 -m ExamGenerator.multi_choice_exam --task-domain SecFilings --question-date "question-date" --save-exam 63 | ``` 64 | 65 | ### Add your own task MyOwnTask 66 | 67 | #### Create file structure 68 | 69 | ```bash 70 | cd auto-rag-eval/Data/ 71 | mkdir MyOwnTask 72 | mkdir MyOwnTask/KnowledgeCorpus 73 | mkdir MyOwnTask/KnowledgeCorpus/main 74 | mkdir MyOwnTask/RetrievalIndex 75 | mkdir MyOwnTask/RetrievalIndex/main 76 | mkdir MyOwnTask/ExamData 77 | mkdir MyOwnTask/RawExamData 78 | ``` 79 | 80 | #### Create documentation corpus 81 | 82 | Store in `MyOwnTask/KnowledgeCorpus/main` a `json` file which contains a list of documents, each in the format below. See `DevOps/html_parser.py`, `DevOps/preprocessor.py` or `StackExchange/preprocessor.py` for some examples. 
83 | 84 | ```python 85 | {'source': 'my_own_source', 86 | 'docs_id': 'Doc1022', 87 | 'title': 'Dev Desktop Set Up', 88 | 'section': 'How to [...]', 89 | 'text': "Documentation Text, should be long enough to make informative questions but short enough to fit into the context", 90 | 'start_character': 'N/A', 91 | 'end_character': 'N/A', 92 | 'date': 'N/A', 93 | } 94 | ``` 95 | 96 | #### Generate Exam and Retrieval index 97 | 98 | First generate the raw exam and the retrieval index. 99 | Note that you might need to add support for your own LLM; more on this below. 100 | You might also want to modify the prompt used for exam generation in the `LLMExamGenerator` class in `ExamGenerator/question_generator.py`. 101 | 102 | ```bash 103 | python3 -m ExamGenerator.question_generator --task-domain MyOwnTask 104 | ``` 105 | 106 | Once this is done (it can take a couple of hours, depending on the documentation size), generate the processed exam. 107 | To do so, check MyRawExamDate in RawExamData (e.g. 2023091223) and run: 108 | 109 | ```bash 110 | python3 -m ExamGenerator.multi_choice_exam --task-domain MyOwnTask --question-date MyRawExamDate --save-exam 111 | ``` 112 | 113 | ### Bring your own LLM 114 | 115 | We currently support endpoints for Bedrock (Claude) in the `LLMServer` package. 116 | The only thing needed to bring your own is a class with an `inference` function that takes a prompt as input and outputs both the prompt and the completed text; a minimal illustrative sketch of such a class is given at the end of this README. 117 | Modify the `LLMExamGenerator` class in `ExamGenerator/question_generator.py` to incorporate it. 118 | Different LLMs generate different types of questions, so you might want to modify the raw exam parsing in `ExamGenerator/multi_choice_question.py`. 119 | You can experiment using the `failed_questions.ipynb` notebook from `ExamGenerator`. 120 | 121 | ## III - Exam Evaluation Process 122 | 123 | We leverage the [lm-harness](https://github.com/EleutherAI/lm-evaluation-harness/tree/big-refactor) package to evaluate the (LLM & Retrieval) system on the generated exam. 124 | To do so, follow these steps: 125 | 126 | ### Create a benchmark 127 | 128 | Create a benchmark folder for your task, here `DevOpsExam`; see `ExamEvaluator/DevOpsExam` for the template. 129 | It contains a code file `preprocess_exam.py` for prompt templates and, more importantly, a set of tasks to evaluate models on: 130 | 131 | * `DevOpsExam` contains the tasks associated with ClosedBook (no retrieval) and OpenBook (Oracle Retrieval). 132 | * `DevOpsRagExam` contains the tasks associated with Retrieval variants (DPR/Embeddings/BM25...). 133 | 134 | The provided script `task_evaluation.sh` illustrates the evaluation of `Llamav2:Chat:13B` and `Llamav2:Chat:70B` on the task, using In-Context Learning (ICL) with 0, 1 and 2 samples respectively. 135 | 136 | ## Citation 137 | 138 | To cite this work, please use: 139 | ```bibtex 140 | @misc{autorageval2024, 141 | title={Automated Evaluation of Retrieval-Augmented Language Models with Task-Specific Exam Generation}, 142 | author={Gauthier Guinet and Behrooz Omidvar-Tehrani and Anoop Deoras and Laurent Callot}, 143 | year={2024}, 144 | eprint={2405.13622}, 145 | archivePrefix={arXiv}, 146 | primaryClass={cs.CL} 147 | } 148 | ``` 149 | 150 | 151 | ## Security 152 | 153 | See [CONTRIBUTING](CONTRIBUTING.md#security-issue-notifications) for more information. 154 | 155 | ## License 156 | 157 | This project is licensed under the Apache-2.0 License. 
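
## Appendix - Custom LLM Sketch

As mentioned in the "Bring your own LLM" section, any model can be plugged in as long as it exposes an `inference` method that returns both the prompt and the completed text. The snippet below is a minimal, illustrative sketch of such a wrapper and is not part of the repository code: the class name `MyOwnLLM`, the `call_my_endpoint` helper, and the returned dictionary keys are placeholders to adapt to your own client and to whatever `LLMExamGenerator` expects in your setup.

```python
from typing import Dict


def call_my_endpoint(prompt: str) -> str:
    # Placeholder: replace with the actual call to your LLM endpoint or SDK.
    raise NotImplementedError("Plug in your own LLM endpoint here.")


class MyOwnLLM:
    """Minimal sketch of a custom LLM wrapper for exam generation."""

    def __init__(self, model_name: str = "my-own-model"):
        self.model_name = model_name

    def inference(self, prompt: str) -> Dict[str, str]:
        # Query the endpoint and return both the prompt and the completion,
        # mirroring the contract described in "Bring your own LLM".
        completion = call_my_endpoint(prompt)
        return {"prompt": prompt, "completion": completion}
```

You would then modify the `LLMExamGenerator` class in `ExamGenerator/question_generator.py` to instantiate this wrapper in place of the Bedrock clients.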
158 | 159 | -------------------------------------------------------------------------------- /auto-rag-eval/Data/Arxiv/preprocessor.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import re 4 | import time 5 | from datetime import datetime 6 | from functools import reduce 7 | from os.path import abspath, dirname 8 | from typing import List 9 | 10 | from datasets import concatenate_datasets, load_dataset 11 | from tqdm import tqdm 12 | 13 | logger = logging.getLogger(__name__) 14 | ROOTPATH = dirname(dirname(abspath(__file__))) 15 | 16 | with open(f"{ROOTPATH}/Arxiv/arxiv_categories.json", "r") as f: 17 | CATEGORIES = {elem['tag']: elem['name'] for elem in json.load(f)} 18 | 19 | 20 | class ArxivData: 21 | 22 | def __init__(self, 23 | n_samples: int, 24 | max_char_length: int, 25 | min_char_length: int): 26 | 27 | self.n_samples = n_samples 28 | self.max_char_length = max_char_length 29 | self.min_char_length = min_char_length 30 | self.cat_list = ['cs', 'econ', 'eess', 'math', 'astro-ph', 31 | 'cond-mat', 'hep', 'nlin', 'nucl', 32 | 'physics', 'q-bio', 'q-fin', 'stat'] 33 | 34 | def process_qna(self, row): 35 | 36 | return {'source': row['authors'], 37 | 'docs_id': row['id'], 38 | 'title': row['title'], 39 | 'section': self.get_full_cat_list(row['categories']), 40 | 'start_character': 'N/A', 41 | 'end_character': 'N/A', 42 | 'date': 'N/A', 43 | 'text': f"{row['title']}. {self.preprocess(row['abstract'])}", 44 | } 45 | 46 | def get_full_cat_list(self, cat: List[str]) -> List[str]: 47 | 48 | convert_categories = {'q-alg': 'math.QA', 49 | 'alg-geom': 'math.AG', 50 | 'chao-dyn': 'nlin.CD', 51 | 'solv-int': 'nlin.SI', 52 | 'cmp-lg': 'cs.CL', 53 | 'dg-ga': 'math.DG', 54 | 'patt-sol': 'nlin.PS', 55 | 'adap-org': 'nlin.AO', 56 | 'funct-an': 'math.FA', 57 | 'mtrl-th': 'cond-mat.mtrl-sci', 58 | 'comp-gas': 'cond-mat.stat-mech', 59 | 'supr-con': 'cond-mat.supr-con', 60 | 'acc-phys': 'physics.acc-ph', 61 | 'plasm-ph': 'physics.plasm-ph', 62 | 'ao-sci': 'physics.ao-ph', 63 | 'bayes-an': 'stat.ME', 64 | 'atom-ph': 'physics.atom-ph', 65 | 'chem-ph': 'physics.chem-ph', 66 | **{k: k for k in CATEGORIES.keys()}} 67 | 68 | return [convert_categories[y] for x in cat for y in x.split(' ')] 69 | 70 | def preprocess(self, text: str) -> str: 71 | 72 | # Remove any URLs 73 | # text = re.sub(r'http\S+', '', text) 74 | # Remove any LaTeX expressions 75 | # text = re.sub(r'\$[^$]+\$', '', text) 76 | # Replace newline characters and extra spaces 77 | text = text.replace('\n', ' ').replace('\r', '').strip() 78 | text = re.sub(' +', ' ', text) 79 | return text 80 | 81 | def load_save_data(self) -> None: 82 | 83 | dataset = load_dataset("gfissore/arxiv-abstracts-2021", 84 | split="train", 85 | # cache_dir="/home/USER/.cache/huggingface/datasets/", 86 | ) 87 | 88 | # Remove too lengthy or shorty answers to avoid repeting operation 89 | sub_df = dataset.filter(lambda x: self.min_char_length <= len(x['abstract']) <= self.max_char_length) 90 | logger.error((f"Reducing dataset size from {len(dataset)} to {len(sub_df)} by keeping abstract with" 91 | f" character length between {self.min_char_length} and {self.max_char_length}.")) 92 | 93 | # Join over all categories at the end and shuffle again 94 | def funcs(cat): 95 | return [ 96 | 97 | # Filter only for a given category 98 | lambda data: data.filter(lambda x: any([tag[:len(cat)] == cat 99 | for tag in self.get_full_cat_list(x['categories'])])), 100 | 101 | # Select Subset of data and preprocess to 
keep only top answer 102 | lambda data: data.shuffle(seed=42).select(range(min(len(data), self.n_samples))).map(self.process_qna, 103 | remove_columns=['id', 'submitter', 104 | 'authors', 'title', 105 | 'comments', 'journal-ref', 106 | 'doi', 'abstract', 107 | 'report-no', 'categories', 108 | 'versions'] 109 | ), 110 | 111 | ] 112 | 113 | data_cat_list = [] 114 | 115 | for cat in tqdm(self.cat_list): 116 | data_cat_list.append(reduce(lambda res, f: f(res), funcs(cat), sub_df)) 117 | concat_dataset = concatenate_datasets(data_cat_list).shuffle(seed=42) 118 | 119 | concat_dataset.to_json(f"{ROOTPATH}/Arxiv/KnowledgeCorpus/main/data_{datetime.fromtimestamp(time.time()).strftime('%Y%m%d%H')}.json", 120 | lines=False) 121 | 122 | 123 | if __name__ == "__main__": 124 | 125 | arxiv_data = ArxivData(n_samples=1000, 126 | max_char_length=1500, 127 | min_char_length=1000) 128 | 129 | arxiv_data.load_save_data() 130 | -------------------------------------------------------------------------------- /auto-rag-eval/Data/SecFilings/preprocessor.py: -------------------------------------------------------------------------------- 1 | import re 2 | import time 3 | from datetime import datetime 4 | from functools import reduce 5 | from os.path import abspath, dirname 6 | from typing import Dict, List 7 | 8 | from datasets import concatenate_datasets, load_dataset 9 | from markdownify import markdownify as md 10 | from tqdm import tqdm 11 | 12 | ROOTPATH = dirname(dirname(abspath(__file__))) 13 | 14 | 15 | # SEC TEMPLATE 16 | 17 | # [KEEP] 0.Business: Overview of the company's main operations, including its products or services. 18 | # [KEEP] 1.Risk Factors: Discussion of risks and challenges the company faces. 19 | # [REMOVE] 2.Unresolved Staff Comments: Comments by SEC staff on the company's previous filings that haven't been resolved. 20 | # [REMOVE] 3.Properties: Information about the company's physical properties (like real estate). 21 | # [REMOVE] 4.Legal Proceedings: Information on any significant legal actions involving the company. 22 | # [REMOVE] 5.Market for Registrant’s Common Equity, Related Stockholder Matters and Issuer Purchases of Equity Securities: Details about the company’s stock, including dividends, the number of shareholders, and any buyback programs. 23 | # [REMOVE] 6.Selected Financial Data: Summary of specific financial data for a five-year period. 24 | 25 | # [KEEP] 8.Management’s Discussion and Analysis of Financial Condition and Results of Operations (MD&A): A detailed analysis from management’s perspective on the company’s financials and operations. 26 | # [REMOVE] 9.Quantitative and Qualitative Disclosures About Market Risk: Information on market risk, such as foreign exchange risk, interest rate risk, etc. 27 | # [REMOVE] 1.Financial Statements and Supplementary Data: Complete financial statements including balance sheets, income statements, and cash flow statements. 28 | # [REMOVE] 11.Changes in and Disagreements with Accountants on Accounting and Financial Disclosure: If there have been changes or disagreements with accountants, this section provides details. 29 | # [REMOVE] 12.Directors, Executive Officers and Corporate Governance: Information about the company’s directors and high-level executives. 30 | # [REMOVE] 13.Executive Compensation: Detailed information about the compensation of top executives. 
31 | # [REMOVE] 14.Security Ownership of Certain Beneficial Owners and Management and Related Stockholder Matters: Details about the shares held by major shareholders and company executives. 32 | # [REMOVE] 15.Certain Relationships and Related Transactions, and Director Independence: Information about any transactions between the company and its directors or executives. 33 | # [REMOVE] 16.Principal Accountant Fees and Services: Fees and services provided by the company's accountants. 34 | # [REMOVE] 17.Exhibits, Financial Statement Schedules: Lists all the exhibits and financial statements schedules. 35 | # [REMOVE] 18.Form 10-K Summary: Summary of the key information from the 10-K (optional). 36 | # [REMOVE] 19. [OPTIONAl] CEO and CFO Certifications: As required by the Sarbanes-Oxley Act, certifications by the CEO and CFO regarding the accuracy of the financial statements. 37 | 38 | 39 | class ExchangeData: 40 | 41 | def __init__(self, 42 | n_samples: int, 43 | max_char_length: int): 44 | 45 | self.n_samples = n_samples 46 | self.max_char_length = max_char_length 47 | self.cat_list = ['Stackoverflow', 'math', 'superuser', 48 | 'serverfault', 'askubuntu', 'electronics', 49 | 'physics', 'unix', 'tex', 'english', 50 | 'meta', 'apple', 'ell', 'gaming', 51 | 'stats', 'softwareengineering', 52 | 'mathoverflow', 'gis', 'diy', 'magento'] 53 | 54 | def get_best_answer(self, answers: List[Dict[str, str]]): 55 | """return the best answer, that is, the one with the highest score""" 56 | best_index = 0 57 | best_score = answers[0]["pm_score"] 58 | for i in range(1, len(answers)): 59 | if answers[i]["pm_score"] > best_score : 60 | best_score = answers[i]["pm_score"] 61 | best_index = i 62 | return answers[best_index]["text"] 63 | 64 | def lang_callback(self, el): 65 | lang = el['class'][0] if el.has_attr('class') else None 66 | return lang.split("-")[-1] if lang else None 67 | 68 | def html2md(self, text: str) -> str: 69 | text = md(text, code_language_callback=self.lang_callback) 70 | text = re.sub(r"\n\s*\n", "\n\n", text).strip() 71 | return text.encode('utf-8', 'replace').decode() 72 | 73 | def process_qna(self, row): 74 | 75 | return {'source': row['metadata'], 76 | 'docs_id': row['qid'], 77 | 'title': 'N/A', 78 | 'section': 'N/A', 79 | 'start_character': 'N/A', 80 | 'end_character': 'N/A', 81 | 'text': self.html2md(f"### User: {row['question']} -\n\n### Top Answer: {self.get_best_answer(row['answers'])}"), 82 | } 83 | 84 | def get_topic(self, source_list: List[str]) -> str: 85 | 86 | filtered_list = list(set([elem.replace('https://', '').split('/')[0].split('.')[0] for elem in source_list])) 87 | 88 | return filtered_list[0] 89 | 90 | def load_save_dataset(self) -> None: 91 | 92 | # Heavy dataset, ~22GB, use cache location if not enough memory 93 | dataset = load_dataset("HuggingFaceH4/stack-exchange-preferences", 94 | split="train", 95 | # cache_dir="/home/USERID/.cache/huggingface/datasets/", 96 | ) 97 | 98 | # funcs = [ 99 | # # Select Subset of data and preprocess to keep only top answer 100 | # lambda data: data.shuffle(seed=42).select(range(self.n_samples)).map(self.process_qna, 101 | # remove_columns=['qid', 'metadata', 'answers', 'question'] 102 | # ), 103 | # # Remove too lengthy answers 104 | # lambda data: data.filter(lambda x: len(x['text']) <= self.max_char_length) 105 | # ] 106 | 107 | # filtered_dataset = reduce(lambda res, f: f(res), funcs, dataset) 108 | # 
filtered_dataset.to_json(f"{ROOTPATH}/Data/StackExchange/KnowledgeCorpus/main/data_{datetime.fromtimestamp(time.time()).strftime('%Y%m%d%H')}.json", 109 | # lines=False) 110 | 111 | # Join over all categories at the end and shuffle again 112 | def funcs(cat): 113 | return [ 114 | 115 | # Filter only for a given category 116 | lambda data: data.filter(lambda x: cat == self.get_topic(x['metadata'])), 117 | 118 | # Select Subset of data and preprocess to keep only top answer 119 | lambda data: data.shuffle(seed=42).select(range(min(len(data), self.n_samples))).map(self.process_qna, 120 | remove_columns=['qid', 121 | 'metadata', 122 | 'answers', 123 | 'question'] 124 | ), 125 | 126 | ] 127 | 128 | data_cat_list = [] 129 | 130 | for cat in tqdm(self.cat_list): 131 | data_cat_list.append(reduce(lambda res, f: f(res), funcs(cat), dataset)) 132 | concat_dataset = concatenate_datasets(data_cat_list).shuffle(seed=42) 133 | 134 | concat_dataset.to_json(f"{ROOTPATH}/Data/StackExchange/KnowledgeCorpus/main/data_{datetime.fromtimestamp(time.time()).strftime('%Y%m%d%H')}.json", 135 | lines=False) 136 | 137 | 138 | if __name__ == "__main__": 139 | 140 | stack_exchange_data = ExchangeData(n_samples=400, 141 | max_char_length=1500) 142 | 143 | stack_exchange_data.load_save_dataset() 144 | -------------------------------------------------------------------------------- /auto-rag-eval/Data/StackExchange/preprocessor.py: -------------------------------------------------------------------------------- 1 | import re 2 | import time 3 | from datetime import datetime 4 | from functools import reduce 5 | from os.path import abspath, dirname 6 | from typing import Dict, List 7 | 8 | from datasets import concatenate_datasets, load_dataset 9 | from markdownify import markdownify as md 10 | from tqdm import tqdm 11 | 12 | ROOTPATH = dirname(dirname(abspath(__file__))) 13 | 14 | 15 | # Please use first the HuggingFace script at https://huggingface.co/datasets/HuggingFaceH4/stack-exchange-preferences to get the data 16 | 17 | class StackExchangeData: 18 | 19 | def __init__(self, 20 | n_samples: int, 21 | max_char_length: int): 22 | 23 | self.n_samples = n_samples 24 | self.max_char_length = max_char_length 25 | self.cat_list = ['Stackoverflow', 'math', 'superuser', 26 | 'serverfault', 'askubuntu', 'electronics', 27 | 'physics', 'unix', 'tex', 'english', 28 | 'meta', 'apple', 'ell', 'gaming', 29 | 'stats', 'softwareengineering', 30 | 'mathoverflow', 'gis', 'diy', 'magento'] 31 | 32 | def get_best_answer(self, answers: List[Dict[str, str]]): 33 | """return the best answer, that is, the one with the highest score""" 34 | best_index = 0 35 | best_score = answers[0]["pm_score"] 36 | for i in range(1, len(answers)): 37 | if answers[i]["pm_score"] > best_score : 38 | best_score = answers[i]["pm_score"] 39 | best_index = i 40 | return answers[best_index]["text"] 41 | 42 | def lang_callback(self, el): 43 | lang = el['class'][0] if el.has_attr('class') else None 44 | return lang.split("-")[-1] if lang else None 45 | 46 | def html2md(self, text: str) -> str: 47 | text = md(text, code_language_callback=self.lang_callback) 48 | text = re.sub(r"\n\s*\n", "\n\n", text).strip() 49 | return text.encode('utf-8', 'replace').decode() 50 | 51 | def process_qna(self, row): 52 | 53 | return {'source': row['metadata'], 54 | 'docs_id': row['qid'], 55 | 'title': 'N/A', 56 | 'section': 'N/A', 57 | 'start_character': 'N/A', 58 | 'end_character': 'N/A', 59 | 'text': self.html2md(f"### User: {row['question']} -\n\n### Top Answer: 
{self.get_best_answer(row['answers'])}"), 60 | } 61 | 62 | def get_topic(self, source_list: List[str]) -> str: 63 | 64 | filtered_list = list(set([elem.replace('https://', '').split('/')[0].split('.')[0] for elem in source_list])) 65 | 66 | return filtered_list[0] 67 | 68 | def load_save_dataset(self) -> None: 69 | 70 | # Heavy dataset, ~22GB, use cache location if not enough memory 71 | dataset = load_dataset("HuggingFaceH4/stack-exchange-preferences", 72 | split="train", 73 | # cache_dir="/home/USERID/.cache/huggingface/datasets/", 74 | ) 75 | 76 | # funcs = [ 77 | # # Select Subset of data and preprocess to keep only top answer 78 | # lambda data: data.shuffle(seed=42).select(range(self.n_samples)).map(self.process_qna, 79 | # remove_columns=['qid', 'metadata', 'answers', 'question'] 80 | # ), 81 | # # Remove too lengthy answers 82 | # lambda data: data.filter(lambda x: len(x['text']) <= self.max_char_length) 83 | # ] 84 | 85 | # filtered_dataset = reduce(lambda res, f: f(res), funcs, dataset) 86 | # filtered_dataset.to_json(f"{ROOTPATH}/Data/StackExchange/KnowledgeCorpus/main/data_{datetime.fromtimestamp(time.time()).strftime('%Y%m%d%H')}.json", 87 | # lines=False) 88 | 89 | # Join over all categories at the end and shuffle again 90 | def funcs(cat): 91 | return [ 92 | 93 | # Filter only for a given category 94 | lambda data: data.filter(lambda x: cat == self.get_topic(x['metadata'])), 95 | 96 | # Select Subset of data and preprocess to keep only top answer 97 | lambda data: data.shuffle(seed=42).select(range(min(len(data), self.n_samples))).map(self.process_qna, 98 | remove_columns=['qid', 99 | 'metadata', 100 | 'answers', 101 | 'question'] 102 | ), 103 | 104 | ] 105 | 106 | data_cat_list = [] 107 | 108 | for cat in tqdm(self.cat_list): 109 | data_cat_list.append(reduce(lambda res, f: f(res), funcs(cat), dataset)) 110 | concat_dataset = concatenate_datasets(data_cat_list).shuffle(seed=42) 111 | 112 | concat_dataset.to_json(f"{ROOTPATH}/Data/StackExchange/KnowledgeCorpus/main/data_{datetime.fromtimestamp(time.time()).strftime('%Y%m%d%H')}.json", 113 | lines=False) 114 | 115 | 116 | if __name__ == "__main__": 117 | 118 | stack_exchange_data = StackExchangeData( 119 | n_samples=400, 120 | max_char_length=1500) 121 | 122 | stack_exchange_data.load_save_dataset() 123 | -------------------------------------------------------------------------------- /auto-rag-eval/Data/StackExchange/stack_exchanges.json: -------------------------------------------------------------------------------- 1 | [ 2 | "3dprinting", 3 | "3dprinting.meta", 4 | "academia", 5 | "ai.meta", 6 | "ai", 7 | "android.meta", 8 | "android", 9 | "anime.meta", 10 | "anime", 11 | "apple.meta", 12 | "apple", 13 | "arduino.meta", 14 | "arduino", 15 | "askubuntu", 16 | "astronomy", 17 | "astronomy.meta", 18 | "aviation", 19 | "aviation.meta", 20 | "avp", 21 | "avp.meta", 22 | "beer", 23 | "beer.meta", 24 | "bicycles", 25 | "bicycles.meta", 26 | "bioinformatics", 27 | "bioinformatics.meta", 28 | "biology", 29 | "biology.meta", 30 | "bitcoin", 31 | "bitcoin.meta", 32 | "blender", 33 | "blender.meta", 34 | "boardgames", 35 | "boardgames.meta", 36 | "bricks", 37 | "bricks.meta", 38 | "buddhism", 39 | "buddhism.meta", 40 | "cardano", 41 | "cardano.meta", 42 | "chemistry", 43 | "chemistry.meta", 44 | "chess", 45 | "chess.meta", 46 | "chinese", 47 | "chinese.meta", 48 | "christianity", 49 | "christianity.meta", 50 | "civicrm", 51 | "civicrm.meta", 52 | "codegolf", 53 | "codegolf.meta", 54 | "codereview", 55 | "codereview.meta", 56 | 
"coffee", 57 | "coffee.meta", 58 | "cogsci", 59 | "cogsci.meta", 60 | "computergraphics", 61 | "computergraphics.meta", 62 | "conlang", 63 | "conlang.meta", 64 | "cooking", 65 | "cooking.meta", 66 | "craftcms", 67 | "craftcms.meta", 68 | "crafts", 69 | "crafts.meta", 70 | "crypto", 71 | "crypto.meta", 72 | "cs", 73 | "cs.meta", 74 | "cseducators", 75 | "cseducators.meta", 76 | "cstheory", 77 | "cstheory.meta", 78 | "datascience", 79 | "datascience.meta", 80 | "dba", 81 | "dba.meta", 82 | "devops", 83 | "devops.meta", 84 | "diy", 85 | "diy.meta", 86 | "drones", 87 | "drones.meta", 88 | "drupal", 89 | "drupal.meta", 90 | "dsp", 91 | "dsp.meta", 92 | "earthscience", 93 | "earthscience.meta", 94 | "ebooks", 95 | "ebooks.meta", 96 | "economics", 97 | "economics.meta", 98 | "electronics", 99 | "electronics.meta", 100 | "elementaryos", 101 | "elementaryos.meta", 102 | "ell", 103 | "ell.meta", 104 | "emacs", 105 | "emacs.meta", 106 | "engineering", 107 | "engineering.meta", 108 | "english", 109 | "english.meta", 110 | "eosio", 111 | "eosio.meta", 112 | "esperanto", 113 | "esperanto.meta", 114 | "ethereum", 115 | "ethereum.meta", 116 | "expatriates", 117 | "expatriates.meta", 118 | "expressionengine", 119 | "expressionengine.meta", 120 | "fitness", 121 | "fitness.meta", 122 | "freelancing", 123 | "freelancing.meta", 124 | "french", 125 | "french.meta", 126 | "gamedev", 127 | "gamedev.meta", 128 | "gaming", 129 | "gaming.meta", 130 | "gardening", 131 | "gardening.meta", 132 | "genealogy", 133 | "genealogy.meta", 134 | "german", 135 | "german.meta", 136 | "gis", 137 | "gis.meta", 138 | "graphicdesign", 139 | "graphicdesign.meta", 140 | "ham", 141 | "ham.meta", 142 | "hardwarerecs", 143 | "hardwarerecs.meta", 144 | "health", 145 | "health.meta", 146 | "hermeneutics", 147 | "hermeneutics.meta", 148 | "hinduism", 149 | "hinduism.meta", 150 | "history", 151 | "history.meta", 152 | "homebrew", 153 | "homebrew.meta", 154 | "hsm", 155 | "hsm.meta", 156 | "interpersonal", 157 | "interpersonal.meta", 158 | "iot", 159 | "iot.meta", 160 | "iota", 161 | "iota.meta", 162 | "islam", 163 | "islam.meta", 164 | "italian", 165 | "italian.meta", 166 | "japanese", 167 | "japanese.meta", 168 | "joomla", 169 | "joomla.meta", 170 | "judaism", 171 | "judaism.meta", 172 | "korean", 173 | "korean.meta", 174 | "languagelearning", 175 | "languagelearning.meta", 176 | "latin", 177 | "latin.meta", 178 | "law", 179 | "law.meta", 180 | "lifehacks", 181 | "lifehacks.meta", 182 | "linguistics", 183 | "linguistics.meta", 184 | "literature", 185 | "literature.meta", 186 | "magento", 187 | "magento.meta", 188 | "martialarts", 189 | "martialarts.meta", 190 | "materials", 191 | "materials.meta", 192 | "math", 193 | "math.meta", 194 | "matheducators", 195 | "matheducators.meta", 196 | "mathematica", 197 | "mathematica.meta", 198 | "mathoverflow", 199 | "mechanics.meta", 200 | "mechanics", 201 | "meta.askubuntu", 202 | "meta.mathoverflow", 203 | "meta.serverfault", 204 | "meta.stackexchange", 205 | "meta.stackoverflow", 206 | "meta.superuser", 207 | "moderators.meta", 208 | "moderators", 209 | "monero.meta", 210 | "monero", 211 | "money.meta", 212 | "money", 213 | "movies.meta", 214 | "movies", 215 | "music.meta", 216 | "music", 217 | "musicfans.meta", 218 | "musicfans", 219 | "mythology.meta", 220 | "mythology", 221 | "networkengineering.meta", 222 | "networkengineering", 223 | "opendata.meta", 224 | "opendata", 225 | "opensource.meta", 226 | "opensource", 227 | "or.meta", 228 | "or", 229 | "outdoors.meta", 230 | "outdoors", 231 | 
"parenting.meta", 232 | "parenting", 233 | "patents.meta", 234 | "patents", 235 | "pets.meta", 236 | "pets", 237 | "philosophy.meta", 238 | "philosophy", 239 | "photo.meta", 240 | "photo", 241 | "physics.meta", 242 | "physics", 243 | "pm.meta", 244 | "pm", 245 | "poker.meta", 246 | "poker", 247 | "politics.meta", 248 | "politics", 249 | "portuguese.meta", 250 | "portuguese", 251 | "puzzling.meta", 252 | "puzzling", 253 | "quant.meta", 254 | "quant", 255 | "quantumcomputing.meta", 256 | "quantumcomputing", 257 | "raspberrypi.meta", 258 | "raspberrypi", 259 | "retrocomputing.meta", 260 | "retrocomputing", 261 | "reverseengineering.meta", 262 | "reverseengineering", 263 | "robotics.meta", 264 | "robotics", 265 | "rpg.meta", 266 | "rpg", 267 | "rus.meta", 268 | "rus", 269 | "russian.meta", 270 | "russian", 271 | "salesforce.meta", 272 | "salesforce", 273 | "scicomp.meta", 274 | "scicomp", 275 | "scifi.meta", 276 | "scifi", 277 | "security.meta", 278 | "security", 279 | "serverfault", 280 | "sharepoint", 281 | "sharepoint.meta", 282 | "sitecore", 283 | "sitecore.meta", 284 | "skeptics", 285 | "skeptics.meta", 286 | "softwareengineering", 287 | "softwareengineering.meta", 288 | "softwarerecs", 289 | "softwarerecs.meta", 290 | "sound", 291 | "sound.meta", 292 | "space", 293 | "space.meta", 294 | "spanish", 295 | "spanish.meta", 296 | "sports", 297 | "sports.meta", 298 | "sqa", 299 | "sqa.meta", 300 | "stackapps", 301 | "stats.meta", 302 | "stats", 303 | "stellar.meta", 304 | "stellar", 305 | "superuser", 306 | "sustainability", 307 | "sustainability.meta", 308 | "tex", 309 | "tex.meta", 310 | "tezos", 311 | "tezos.meta", 312 | "tor", 313 | "tor.meta", 314 | "travel", 315 | "travel.meta", 316 | "tridion", 317 | "tridion.meta", 318 | "ukrainian", 319 | "ukrainian.meta", 320 | "unix", 321 | "unix.meta", 322 | "ux", 323 | "ux.meta", 324 | "vegetarianism", 325 | "vegetarianism.meta", 326 | "vi", 327 | "vi.meta", 328 | "webapps", 329 | "webapps.meta", 330 | "webmasters", 331 | "webmasters.meta", 332 | "windowsphone", 333 | "windowsphone.meta", 334 | "woodworking", 335 | "woodworking.meta", 336 | "wordpress", 337 | "wordpress.meta", 338 | "workplace", 339 | "workplace.meta", 340 | "worldbuilding", 341 | "worldbuilding.meta", 342 | "writers", 343 | "writers.meta", 344 | "Stackoverflow", 345 | ] -------------------------------------------------------------------------------- /auto-rag-eval/ExamAnalysis/README.md: -------------------------------------------------------------------------------- 1 | # Exam Analysis 2 | 3 | This folder contains several function and notebook utilies for the analysis of the generated exam. 4 | In particular: 5 | 6 | * **Item Response Theory Models** 7 | * `item_response_models.py` contains the classes for the base IRT model `BaseItemResponseModel` and for the `HierarchicalItemResponseModel` 8 | * `iterative_item_response_models.py` contains the class for the `IterativeHierarchicalItemResponseModel` described in section 6 of the paper. 9 | * `generate_irt_plots` allows to generate the IRT graphs and analysis results for your task of interest, using the previous classes. 10 | * **Bloom's Taxonomy** 11 | * `bloom_taxonomy_model.py`: Automated classification of a question into Bloom's taxonomy criteria. 12 | * `taxonomy_analysis.ipynb`: Notebook to apply Bloom's taxonomy model to a given exam and study results. 13 | * **General Utilities** 14 | * `compute_exam_radar_plot.ipynb` is a utility notebook to generate radar plot per categories of the exam performance. 
-------------------------------------------------------------------------------- /auto-rag-eval/ExamAnalysis/__pycache__/bloom_taxonomy_model.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/auto-rag-eval/a25bc50e78790044ddc45874e0c9085a73f0262e/auto-rag-eval/ExamAnalysis/__pycache__/bloom_taxonomy_model.cpython-39.pyc -------------------------------------------------------------------------------- /auto-rag-eval/ExamAnalysis/__pycache__/generate_irt_plots.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/auto-rag-eval/a25bc50e78790044ddc45874e0c9085a73f0262e/auto-rag-eval/ExamAnalysis/__pycache__/generate_irt_plots.cpython-37.pyc -------------------------------------------------------------------------------- /auto-rag-eval/ExamAnalysis/__pycache__/generate_iterative_irt_plots.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/auto-rag-eval/a25bc50e78790044ddc45874e0c9085a73f0262e/auto-rag-eval/ExamAnalysis/__pycache__/generate_iterative_irt_plots.cpython-37.pyc -------------------------------------------------------------------------------- /auto-rag-eval/ExamAnalysis/__pycache__/generate_recursive_irt_plots.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/auto-rag-eval/a25bc50e78790044ddc45874e0c9085a73f0262e/auto-rag-eval/ExamAnalysis/__pycache__/generate_recursive_irt_plots.cpython-37.pyc -------------------------------------------------------------------------------- /auto-rag-eval/ExamAnalysis/__pycache__/item_response_models.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/auto-rag-eval/a25bc50e78790044ddc45874e0c9085a73f0262e/auto-rag-eval/ExamAnalysis/__pycache__/item_response_models.cpython-37.pyc -------------------------------------------------------------------------------- /auto-rag-eval/ExamAnalysis/__pycache__/item_response_models.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/auto-rag-eval/a25bc50e78790044ddc45874e0c9085a73f0262e/auto-rag-eval/ExamAnalysis/__pycache__/item_response_models.cpython-39.pyc -------------------------------------------------------------------------------- /auto-rag-eval/ExamAnalysis/__pycache__/iterative_item_response_models.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/auto-rag-eval/a25bc50e78790044ddc45874e0c9085a73f0262e/auto-rag-eval/ExamAnalysis/__pycache__/iterative_item_response_models.cpython-37.pyc -------------------------------------------------------------------------------- /auto-rag-eval/ExamAnalysis/bloom_taxonomy_model.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List 2 | 3 | import spacy 4 | 5 | # If required, download the spaCy English model 6 | # !python -m spacy download en_core_web_sm 7 | 8 | TAXONOMY_V2 = {'Remembering': ['list', 'identify', 'name', 'define', 'mention', 'recall', 'label', 'state', 'recognize', 'repeat'], 9 | 'Understanding': ['explain', 'describe', 'summarize', 
'predict', 'interpret', 'paraphrase', 'translate', 'illustrate', 'rephrase', 'clarify', 'check', 'find', 'experience', 'suspect', 'review', 'notice', 'assume', 'interact', 'observe', 'understand'], 10 | 'Applying': ['demonstrate', 'apply', 'use', 'illustrate', 'solve', 'show', 'execute', 'implement', 'operate', 'practice', 'set', 'configure', 'use', 'try', 'follow', 'take', 'use', 'run', 'serve', 'task', 'operate', 'work', 'enable', 'exist', 'read', 'write'], 11 | 'Analyzing': ['analyze', 'distinguish', 'compare', 'differentiate', 'examine', 'test', 'question', 'inspect', 'debate', 'investigate', 'manage', 'resolve', 'optimize', 'troubleshoot', 'investigate', 'compare', 'differentiate'], 12 | 'Evaluating': ['evaluate', 'rate', 'justify', 'critique', 'decide', 'rank', 'measure', 'recommend', 'test', 'validate', 'assess', 'evaluate', 'decide', 'choose', 'verify', 'test', 'monitor', 'validate', 'recommend'], 13 | 'Creating': ['design', 'construct', 'produce', 'invent', 'devise', 'formulate', 'originate', 'assemble', 'generate', 'compose', 'create', 'design', 'develop', 'generate', 'implement', 'produce', 'build', 'customize', 'formulate']} 14 | 15 | TAXONOMY_V1 = {"Remembering" : ["list", "identify", "name", "define", "mention", "recall", "label", "state", "recognize", "repeat"], 16 | "Understanding" : ["explain", "describe", "summarize", "predict", "interpret", "paraphrase", "translate", "illustrate", "rephrase", "clarify"], 17 | "Applying" : ["demonstrate", "apply", "use", "illustrate", "solve", "show", "execute", "implement", "operate", "practice"], 18 | "Analyzing" : ["analyze", "distinguish", "compare", "differentiate", "examine", "test", "question", "inspect", "debate", "investigate"], 19 | "Evaluating" : ["evaluate", "rate", "justify", "critique", "decide", "rank", "measure", "recommend", "test", "validate", "assess"], 20 | "Creating" : ["design", "construct", "produce", "invent", "devise", "formulate", "originate", "assemble", "generate", "compose"]} 21 | 22 | 23 | def categorize_question(question: str, 24 | taxonomy: Dict[str, List[str]] = TAXONOMY_V2) -> List[str]: 25 | ''' 26 | Categorize questions using Bloom's taxonomy approximation. 
27 | 28 | Parameters: 29 | question (str): The question to categorize 30 | taxonomy (Dict[str, List[str]]): The taxonomy to use for categorization 31 | ''' 32 | 33 | nlp = spacy.load("en_core_web_sm") 34 | 35 | # Define verb lists for each category 36 | 37 | # Convert the question to lowercase and split it into words 38 | doc = nlp(question) 39 | verbs = [token.lemma_ for token in doc if token.pos_ == "VERB"] 40 | 41 | # Check for verbs from each category 42 | classif = [key 43 | for key in taxonomy.keys() 44 | if any(verb in verbs for verb in taxonomy[key])] 45 | 46 | return classif if len(classif) > 0 else ["Uncategorized"] 47 | -------------------------------------------------------------------------------- /auto-rag-eval/ExamAnalysis/generate_irt_plots.py: -------------------------------------------------------------------------------- 1 | import json 2 | from os.path import abspath, dirname 3 | 4 | from ExamAnalysis.item_response_models import ( 5 | ExamSetting, 6 | HierarchicalItemResponseModel, 7 | ItemResponseModel, 8 | ) 9 | from tqdm import tqdm 10 | 11 | 12 | def get_all_students(model, task): 13 | 14 | root_path = f'{dirname(dirname(abspath(__file__)))}/Data/{task}/EvalResults' 15 | extended_students = [ 16 | [ExamSetting(path_pattern=f'{root_path}/{task}Exam/llamav2/13b/full_sample_{task}Exam_closed_book_{model}_results_*_icl{i}.jsonl', 17 | llm='llamav2:13B', 18 | retrieval='closed_book', 19 | icl=i, 20 | name=f'Closed Book@{i} [13B]'), 21 | ExamSetting(path_pattern=f'{root_path}/{task}RagExam/llamav2/13b/full_sample_{task}Exam_rag_siamese_{model}_results_*_icl{i}.jsonl', 22 | llm='llamav2:13B', 23 | retrieval='rag_siamese', 24 | icl=i, 25 | name=f'Rag Siamese@{i} [13B]'), 26 | ExamSetting(path_pattern=f'{root_path}/{task}RagExam/llamav2/13b/full_sample_{task}Exam_rag_dpr_{model}_results_*_icl{i}.jsonl', 27 | llm='llamav2:13B', 28 | retrieval='rag_dpr', 29 | icl=i, 30 | name=f'Rag DPR@{i} [13B]'), 31 | ExamSetting(path_pattern=f'{root_path}/{task}RagExam/llamav2/13b/full_sample_{task}Exam_rag_bm25_{model}_results_*_icl{i}.jsonl', 32 | llm='llamav2:13B', 33 | retrieval='rag_bm25', 34 | icl=i, 35 | name=f'Rag BM25@{i} [13B]'), 36 | ExamSetting(path_pattern=f'{root_path}/{task}NewRagExam/llamav2/13b/full_sample_{task}Exam_rag_multi_qa_{model}_results_*_icl{i}.jsonl', 37 | llm='llamav2:13B', 38 | retrieval='rag_multi_qa', 39 | icl=i, 40 | name=f'Rag MultiQA@{i} [13B]'), 41 | ExamSetting(path_pattern=f'{root_path}/{task}NewRagExam/llamav2/13b/full_sample_{task}Exam_rag_dpr_bm25_multi_qa_{model}_results_*_icl{i}.jsonl', 42 | llm='llamav2:13B', 43 | retrieval='rag_dprv2', 44 | icl=i, 45 | name=f'Rag DPRV2@{i} [13B]'), 46 | ExamSetting(path_pattern=f'{root_path}/{task}Exam/llamav2/13b/full_sample_{task}Exam_open_book_{model}_results_*_icl{i}.jsonl', 47 | llm='llamav2:13B', 48 | retrieval='open_book', 49 | icl=i, 50 | name=f'Open Book@{i} [13B]')] 51 | for i in range(3) 52 | ] 53 | 54 | # Add 70B Models 55 | extended_students.extend([[ 56 | ExamSetting(path_pattern=f'{root_path}/{task}Exam/llamav2/70b/full_sample_{task}Exam_closed_book_{model}_results_*_icl{i}.jsonl', 57 | llm='llamav2:70B', 58 | retrieval='closed_book', 59 | icl=i, 60 | name=f'Closed Book@{i} [70B]'), 61 | ExamSetting(path_pattern=f'{root_path}/{task}RagExam/llamav2/70b/full_sample_{task}Exam_rag_siamese_{model}_results_*_icl{i}.jsonl', 62 | llm='llamav2:70B', 63 | retrieval='rag_siamese', 64 | icl=i, 65 | name=f'Rag Siamese@{i} [70B]'), 66 | 
ExamSetting(path_pattern=f'{root_path}/{task}RagExam/llamav2/70b/full_sample_{task}Exam_rag_dpr_{model}_results_*_icl{i}.jsonl', 67 | llm='llamav2:70B', 68 | retrieval='rag_dpr', 69 | icl=i, 70 | name=f'Rag DPR@{i} [70B]'), 71 | ExamSetting(path_pattern=f'{root_path}/{task}RagExam/llamav2/70b/full_sample_{task}Exam_rag_bm25_{model}_results_*_icl{i}.jsonl', 72 | llm='llamav2:70B', 73 | retrieval='rag_bm25', 74 | icl=i, 75 | name=f'Rag BM25@{i} [70B]'), 76 | ExamSetting(path_pattern=f'{root_path}/{task}NewRagExam/llamav2/70b/full_sample_{task}Exam_rag_multi_qa_{model}_results_*_icl{i}.jsonl', 77 | llm='llamav2:70B', 78 | retrieval='rag_multi_qa', 79 | icl=i, 80 | name=f'Rag MultiQA@{i} [70B]'), 81 | ExamSetting(path_pattern=f'{root_path}/{task}NewRagExam/llamav2/70b/full_sample_{task}Exam_rag_dpr_bm25_multi_qa_{model}_results_*_icl{i}.jsonl', 82 | llm='llamav2:70B', 83 | retrieval='rag_dprv2', 84 | icl=i, 85 | name=f'Rag DPRV2@{i} [70B]'), 86 | ExamSetting(path_pattern=f'{root_path}/{task}Exam/llamav2/70b/full_sample_{task}Exam_open_book_{model}_results_*_icl{i}.jsonl', 87 | llm='llamav2:70B', 88 | retrieval='open_book', 89 | icl=i, 90 | name=f'Open Book@{i} [70B]')] for i in range(3)], 91 | ) 92 | 93 | # Add Mistral:7B Models 94 | extended_students.extend([[ 95 | ExamSetting(path_pattern=f'{root_path}/{task}Exam/mistral/7b/full_sample_{task}Exam_closed_book_{model}_results_*_icl{i}.jsonl', 96 | llm='mistral:7b', 97 | retrieval='closed_book', 98 | icl=i, 99 | name=f'Closed Book@{i} [7B]'), 100 | ExamSetting(path_pattern=f'{root_path}/{task}RagExam/mistral/7b/full_sample_{task}Exam_rag_siamese_{model}_results_*_icl{i}.jsonl', 101 | llm='mistral:7b', 102 | retrieval='rag_siamese', 103 | icl=i, 104 | name=f'Rag Siamese@{i} [7B]'), 105 | ExamSetting(path_pattern=f'{root_path}/{task}RagExam/mistral/7b/full_sample_{task}Exam_rag_dpr_{model}_results_*_icl{i}.jsonl', 106 | llm='mistral:7b', 107 | retrieval='rag_dpr', 108 | icl=i, 109 | name=f'Rag DPR@{i} [7B]'), 110 | ExamSetting(path_pattern=f'{root_path}/{task}RagExam/mistral/7b/full_sample_{task}Exam_rag_bm25_{model}_results_*_icl{i}.jsonl', 111 | llm='mistral:7b', 112 | retrieval='rag_bm25', 113 | icl=i, 114 | name=f'Rag BM25@{i} [7B]'), 115 | ExamSetting(path_pattern=f'{root_path}/{task}NewRagExam/mistral/7b/full_sample_{task}Exam_rag_multi_qa_{model}_results_*_icl{i}.jsonl', 116 | llm='mistral:7b', 117 | retrieval='rag_multi_qa', 118 | icl=i, 119 | name=f'Rag MultiQA@{i} [7B]'), 120 | ExamSetting(path_pattern=f'{root_path}/{task}NewRagExam/mistral/7b/full_sample_{task}Exam_rag_dpr_bm25_multi_qa_{model}_results_*_icl{i}.jsonl', 121 | llm='mistral:7b', 122 | retrieval='rag_dprv2', 123 | icl=i, 124 | name=f'Rag DPRV2@{i} [7B]'), 125 | ExamSetting(path_pattern=f'{root_path}/{task}Exam/mistral/7b/full_sample_{task}Exam_open_book_{model}_results_*_icl{i}.jsonl', 126 | llm='mistral:7b', 127 | retrieval='open_book', 128 | icl=i, 129 | name=f'Open Book@{i} [7B]')] for i in range(3)], 130 | ) 131 | 132 | return [i for elem in extended_students for i in elem] 133 | 134 | 135 | def print_nested_dict(d, indent=0): 136 | """Recursively prints nested dictionaries with increasing indentation.""" 137 | for key, value in d.items(): 138 | print(' ' * indent + str(key)) 139 | if isinstance(value, dict): 140 | print_nested_dict(value, indent + 1) 141 | else: 142 | print(' ' * (indent + 1) + (f"{value:.02f}" if type(value) != str else value)) 143 | 144 | 145 | if __name__ == '__main__': 146 | 147 | MODELS = ["llamav2"] 148 | TASKS = ['StackExchange', 'Arxiv', 
'SecFilings'] 149 | IRT_MODELS = [3] 150 | 151 | for task in tqdm(TASKS): 152 | 153 | all_stats = {} 154 | task_path = f"{dirname(dirname(abspath(__file__)))}/Data/{task}/EvalResults/IRT" 155 | 156 | for llm_model in MODELS: 157 | 158 | for irt_model_type in IRT_MODELS: 159 | 160 | print(f'Starting Analysis for task {task}, llm: {llm_model} and irt {irt_model_type}') 161 | expe_name = f"{llm_model}_hierar_irt_{irt_model_type}" 162 | 163 | item_response_analyzer = HierarchicalItemResponseModel(students=get_all_students(llm_model, task), 164 | irt_model_type=irt_model_type) 165 | estimator = item_response_analyzer.fit() 166 | all_stats[expe_name] = item_response_analyzer.compute_stats(estimator) 167 | 168 | item_response_analyzer.plot(estimator=estimator, 169 | exam_model=f'{task}:{llm_model.capitalize()}', 170 | save_path=f"{task_path}/12_{task}_fig_{expe_name}.png", 171 | font_size=12) 172 | 173 | item_response_analyzer.plot(estimator=estimator, 174 | exam_model=f'{task}:{llm_model.capitalize()}', 175 | save_path=f"{task_path}/14_{task}_fig_{expe_name}.png", 176 | font_size=14) 177 | 178 | item_response_analyzer.plot(estimator=estimator, 179 | exam_model=f'{task}:{llm_model.capitalize()}', 180 | save_path=f"{task_path}/16_{task}_fig_{expe_name}.png", 181 | font_size=16) 182 | 183 | item_response_analyzer.plot(estimator=estimator, 184 | exam_model=f'{task}:{llm_model.capitalize()}', 185 | save_path=f"{task_path}/18_{task}_fig_{expe_name}.png", 186 | font_size=18) 187 | 188 | item_response_analyzer.plot(estimator=estimator, 189 | exam_model=f'{task}:{llm_model.capitalize()}', 190 | save_path=f"{task_path}/20_{task}_fig_{expe_name}.png", 191 | font_size=20) 192 | 193 | item_response_analyzer.plot(estimator=estimator, 194 | exam_model=f'{task}:{llm_model.capitalize()}', 195 | save_path=f"{task_path}/22_{task}_fig_{expe_name}.png", 196 | font_size=22) 197 | 198 | with open(f"{task_path}/{task}_stats_hierar_irt.json", "w") as outfile: 199 | outfile.write(json.dumps(all_stats)) 200 | 201 | for task in tqdm(TASKS): 202 | all_stats = {} 203 | task_path = f"{dirname(dirname(abspath(__file__)))}/Data/{task}/EvalResults/IRT" 204 | 205 | for llm_model in MODELS: 206 | 207 | for irt_model_type in [2, 3]: 208 | 209 | print(f'Starting Analysis for task {task}, llm: {llm_model} and irt {irt_model_type}') 210 | expe_name = f"{llm_model}_base_irt_{irt_model_type}" 211 | 212 | item_response_analyzer = ItemResponseModel(students=get_all_students(llm_model, task), 213 | irt_model_type=irt_model_type) 214 | estimator = item_response_analyzer.fit() 215 | all_stats[expe_name] = item_response_analyzer.compute_stats(estimator) 216 | 217 | item_response_analyzer.plot(estimator=estimator, 218 | exam_model=f'{task}:{llm_model.capitalize()}', 219 | save_path=f"{task_path}/fig_{expe_name}.png") 220 | 221 | with open(f"{task_path}/stats_base_irt.json", "w") as outfile: 222 | outfile.write(json.dumps(all_stats)) 223 | -------------------------------------------------------------------------------- /auto-rag-eval/ExamAnalysis/generate_iterative_irt_plots.py: -------------------------------------------------------------------------------- 1 | import json 2 | from os.path import abspath, dirname 3 | 4 | from ExamAnalysis.item_response_models import ExamSetting 5 | from ExamAnalysis.iterative_item_response_models import \ 6 | IterativeHierarchicalItemResponseModel 7 | from tqdm import tqdm 8 | 9 | 10 | def get_all_students(model, task): 11 | 12 | 13 | root_path = f'{dirname(dirname(abspath(__file__)))}/Data/{task}/EvalResults' 
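    # Each ExamSetting below identifies one (LLM, retrieval mode, ICL shots) evaluation run through a
    # glob pattern over lm-harness result files; the wildcard is expected to match the run timestamp in
    # the filename. For the MultiQA and DPRV2 retrievers, this iterative variant points to the
    # *_new_ir_* result files, unlike generate_irt_plots.py, which uses the plain *_icl{i} files.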
14 | 15 | extended_students = [ 16 | [ExamSetting(path_pattern=f'{root_path}/{task}Exam/llamav2/13b/full_sample_{task}Exam_closed_book_{model}_results_*_icl{i}.jsonl', 17 | llm='llamav2:13B', 18 | retrieval='closed_book', 19 | icl=i, 20 | name=f'Closed Book@{i} [13B]'), 21 | ExamSetting(path_pattern=f'{root_path}/{task}RagExam/llamav2/13b/full_sample_{task}Exam_rag_siamese_{model}_results_*_icl{i}.jsonl', 22 | llm='llamav2:13B', 23 | retrieval='rag_siamese', 24 | icl=i, 25 | name=f'Rag Siamese@{i} [13B]'), 26 | ExamSetting(path_pattern=f'{root_path}/{task}RagExam/llamav2/13b/full_sample_{task}Exam_rag_dpr_{model}_results_*_icl{i}.jsonl', 27 | llm='llamav2:13B', 28 | retrieval='rag_dpr', 29 | icl=i, 30 | name=f'Rag DPR@{i} [13B]'), 31 | ExamSetting(path_pattern=f'{root_path}/{task}RagExam/llamav2/13b/full_sample_{task}Exam_rag_bm25_{model}_results_*_icl{i}.jsonl', 32 | llm='llamav2:13B', 33 | retrieval='rag_bm25', 34 | icl=i, 35 | name=f'Rag BM25@{i} [13B]'), 36 | ExamSetting(path_pattern=f'{root_path}/{task}NewRagExam/llamav2/13b/full_sample_{task}Exam_rag_multi_qa_{model}_results_*_new_ir_icl{i}.jsonl', 37 | llm='llamav2:13B', 38 | retrieval='rag_multi_qa', 39 | icl=i, 40 | name=f'Rag MultiQA@{i} [13B]'), 41 | ExamSetting(path_pattern=f'{root_path}/{task}NewRagExam/llamav2/13b/full_sample_{task}Exam_rag_dpr_bm25_multi_qa_{model}_results_*_new_ir_icl{i}.jsonl', 42 | llm='llamav2:13B', 43 | retrieval='rag_dprv2', 44 | icl=i, 45 | name=f'Rag DPRV2@{i} [13B]'), 46 | ExamSetting(path_pattern=f'{root_path}/{task}Exam/llamav2/13b/full_sample_{task}Exam_open_book_{model}_results_*_icl{i}.jsonl', 47 | llm='llamav2:13B', 48 | retrieval='open_book', 49 | icl=i, 50 | name=f'Open Book@{i} [13B]')] 51 | for i in range(3) 52 | ] 53 | 54 | # Add 70B Models 55 | extended_students.extend([[ 56 | ExamSetting(path_pattern=f'{root_path}/{task}Exam/llamav2/70b/full_sample_{task}Exam_closed_book_{model}_results_*_icl{i}.jsonl', 57 | llm='llamav2:70B', 58 | retrieval='closed_book', 59 | icl=i, 60 | name=f'Closed Book@{i} [70B]'), 61 | ExamSetting(path_pattern=f'{root_path}/{task}RagExam/llamav2/70b/full_sample_{task}Exam_rag_siamese_{model}_results_*_icl{i}.jsonl', 62 | llm='llamav2:70B', 63 | retrieval='rag_siamese', 64 | icl=i, 65 | name=f'Rag Siamese@{i} [70B]'), 66 | ExamSetting(path_pattern=f'{root_path}/{task}RagExam/llamav2/70b/full_sample_{task}Exam_rag_dpr_{model}_results_*_icl{i}.jsonl', 67 | llm='llamav2:70B', 68 | retrieval='rag_dpr', 69 | icl=i, 70 | name=f'Rag DPR@{i} [70B]'), 71 | ExamSetting(path_pattern=f'{root_path}/{task}RagExam/llamav2/70b/full_sample_{task}Exam_rag_bm25_{model}_results_*_icl{i}.jsonl', 72 | llm='llamav2:70B', 73 | retrieval='rag_bm25', 74 | icl=i, 75 | name=f'Rag BM25@{i} [70B]'), 76 | ExamSetting(path_pattern=f'{root_path}/{task}NewRagExam/llamav2/70b/full_sample_{task}Exam_rag_multi_qa_{model}_results_*_new_ir_icl{i}.jsonl', 77 | llm='llamav2:70B', 78 | retrieval='rag_multi_qa', 79 | icl=i, 80 | name=f'Rag MultiQA@{i} [70B]'), 81 | ExamSetting(path_pattern=f'{root_path}/{task}NewRagExam/llamav2/70b/full_sample_{task}Exam_rag_dpr_bm25_multi_qa_{model}_results_*_new_ir_icl{i}.jsonl', 82 | llm='llamav2:70B', 83 | retrieval='rag_dprv2', 84 | icl=i, 85 | name=f'Rag DPRV2@{i} [70B]'), 86 | ExamSetting(path_pattern=f'{root_path}/{task}Exam/llamav2/70b/full_sample_{task}Exam_open_book_{model}_results_*_icl{i}.jsonl', 87 | llm='llamav2:70B', 88 | retrieval='open_book', 89 | icl=i, 90 | name=f'Open Book@{i} [70B]')] for i in range(3)], 91 | ) 92 | 93 | # Add Mistral:7B Models 94 | 
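    # Same layout as the LLaMA blocks above; only the model directory (mistral/7b), the llm tag and
    # the display names change.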
extended_students.extend([[ 95 | ExamSetting(path_pattern=f'{root_path}/{task}Exam/mistral/7b/full_sample_{task}Exam_closed_book_{model}_results_*_icl{i}.jsonl', 96 | llm='mistral:7b', 97 | retrieval='closed_book', 98 | icl=i, 99 | name=f'Closed Book@{i} [7B]'), 100 | ExamSetting(path_pattern=f'{root_path}/{task}RagExam/mistral/7b/full_sample_{task}Exam_rag_siamese_{model}_results_*_icl{i}.jsonl', 101 | llm='mistral:7b', 102 | retrieval='rag_siamese', 103 | icl=i, 104 | name=f'Rag Siamese@{i} [7B]'), 105 | ExamSetting(path_pattern=f'{root_path}/{task}RagExam/mistral/7b/full_sample_{task}Exam_rag_dpr_{model}_results_*_icl{i}.jsonl', 106 | llm='mistral:7b', 107 | retrieval='rag_dpr', 108 | icl=i, 109 | name=f'Rag DPR@{i} [7B]'), 110 | ExamSetting(path_pattern=f'{root_path}/{task}RagExam/mistral/7b/full_sample_{task}Exam_rag_bm25_{model}_results_*_icl{i}.jsonl', 111 | llm='mistral:7b', 112 | retrieval='rag_bm25', 113 | icl=i, 114 | name=f'Rag BM25@{i} [7B]'), 115 | ExamSetting(path_pattern=f'{root_path}/{task}NewRagExam/mistral/7b/full_sample_{task}Exam_rag_multi_qa_{model}_results_*_new_ir_icl{i}.jsonl', 116 | llm='mistral:7b', 117 | retrieval='rag_multi_qa', 118 | icl=i, 119 | name=f'Rag MultiQA@{i} [7B]'), 120 | ExamSetting(path_pattern=f'{root_path}/{task}NewRagExam/mistral/7b/full_sample_{task}Exam_rag_dpr_bm25_multi_qa_{model}_results_*_new_ir_icl{i}.jsonl', 121 | llm='mistral:7b', 122 | retrieval='rag_dprv2', 123 | icl=i, 124 | name=f'Rag DPRV2@{i} [7B]'), 125 | ExamSetting(path_pattern=f'{root_path}/{task}Exam/mistral/7b/full_sample_{task}Exam_open_book_{model}_results_*_icl{i}.jsonl', 126 | llm='mistral:7b', 127 | retrieval='open_book', 128 | icl=i, 129 | name=f'Open Book@{i} [7B]')] for i in range(3)], 130 | ) 131 | 132 | return [i for elem in extended_students for i in elem] 133 | 134 | 135 | def print_nested_dict(d, indent=0): 136 | """Recursively prints nested dictionaries with increasing indentation.""" 137 | for key, value in d.items(): 138 | print(' ' * indent + str(key)) 139 | if isinstance(value, dict): 140 | print_nested_dict(value, indent + 1) 141 | else: 142 | print(' ' * (indent + 1) + (f"{value:.02f}" if type(value) != str else value)) 143 | 144 | 145 | if __name__ == '__main__': 146 | 147 | LLM_MODELS = ["llamav2"] 148 | TASKS = ['DevOps', 'StackExchange', 'Arxiv', 'SecFilings'] 149 | IRT_TYPE = [3] 150 | N_STEPS = 4 151 | DROP_RATIO = 0.1 152 | 153 | for task in tqdm(TASKS): 154 | all_stats = {} 155 | task_path = f"{dirname(dirname(abspath(__file__)))}/Data/{task}/EvalResults/IterativeIRT" 156 | 157 | for llm_model in LLM_MODELS: 158 | 159 | for irt_model_type in IRT_TYPE: 160 | 161 | print(f'Starting Analysis for task {task}, llm: {llm_model} and irt {irt_model_type}') 162 | expe_name = f"{llm_model}_recursive_irt_{irt_model_type}" 163 | 164 | iterative_item_response_analyzer = IterativeHierarchicalItemResponseModel(students=get_all_students(llm_model, task), 165 | irt_model_type=irt_model_type) 166 | estimator_dict = iterative_item_response_analyzer.fit(n_steps = N_STEPS, 167 | drop_ratio = DROP_RATIO) 168 | all_stats[expe_name] = {step_k: iterative_item_response_analyzer.compute_stats(estimator_dict[step_k]) 169 | for step_k in estimator_dict.keys()} 170 | 171 | iterative_item_response_analyzer.plot_iterative_informativeness( 172 | estimator_dict=estimator_dict, 173 | exam_model=f'{task}:{llm_model.capitalize()}', 174 | save_path=f"{task_path}/18_{task}_fig_{expe_name}_step{N_STEPS}.png") 175 | 176 | with open(f"{task_path}/recursive_irt_step{N_STEPS}.json", "w") 
as outfile: 177 | outfile.write(json.dumps(all_stats)) -------------------------------------------------------------------------------- /auto-rag-eval/ExamAnalysis/iterative_item_response_models.py: -------------------------------------------------------------------------------- 1 | # from scipy.stats import norm, lognorm, beta 2 | from typing import Dict, List 3 | 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | from ExamAnalysis.item_response_models import BaseItemResponseModel, ExamSetting 7 | from scipy.optimize import minimize 8 | from sklearn.metrics import mean_squared_error 9 | 10 | 11 | class IterativeHierarchicalItemResponseModel(BaseItemResponseModel): 12 | 13 | def __init__(self, 14 | students: List[ExamSetting], 15 | irt_model_type: int): 16 | 17 | super().__init__(students=students, 18 | irt_model_type=irt_model_type) 19 | 20 | def compute_theta(self, theta_params: np.array) -> np.array: 21 | 22 | llm_params = theta_params[:self.num_llm] 23 | retrieval_params = theta_params[self.num_llm:self.num_llm + self.num_retrieval] 24 | icl_params = theta_params[self.num_llm + self.num_retrieval:] 25 | 26 | return np.array([(llm_params[self.llm_map[model.llm]] 27 | + retrieval_params[self.retrieval_map[model.retrieval]] 28 | + icl_params[model.icl]) 29 | for model in self.students]) 30 | 31 | # Define the negative log-likelihood function for the 3PL model 32 | def hierarchical_neg_log_likelihood(self, params: np.array) -> float: 33 | 34 | a = params[:self.num_items] 35 | b = params[self.num_items:2 * self.num_items] 36 | c = params[2 * self.num_items:3 * self.num_items] 37 | theta = self.compute_theta(theta_params=params[3 * self.num_items:]) 38 | 39 | likelihood = 0 40 | for i in range(self.num_items): 41 | p = self.irt_model(theta=theta, a=a[i], b=b[i], c=c[i]) 42 | likelihood += np.sum(self.data[:, i] * np.log(p) + (1 - self.data[:, i]) * np.log(1 - p)) 43 | 44 | # Add a param penalty 45 | # l2_penalty = lambda_l2 * np.sum(params ** 2) 46 | # return -likelihood + l2_penalty 47 | 48 | # Add Gaussian priors for a, b, and c 49 | # prior_a = np.sum(-0.5 * ((a - prior_means['a']) ** 2) / prior_vars['a']) 50 | # prior_b = np.sum(-0.5 * ((b - prior_means['b']) ** 2) / prior_vars['b']) 51 | # prior_c = np.sum(-0.5 * ((c - prior_means['c']) ** 2) / prior_vars['c']) 52 | 53 | return -likelihood 54 | 55 | def _fit(self, 56 | initial_guess: np.array, 57 | params_bounds: List) -> Dict[str, np.array]: 58 | 59 | # Run optimization 60 | result = minimize( 61 | self.hierarchical_neg_log_likelihood, 62 | initial_guess, 63 | method='L-BFGS-B', 64 | bounds=[elem for bounds in params_bounds for elem in bounds] 65 | ) 66 | 67 | return { 68 | 'discrimination': result.x[:self.num_items], 69 | 'difficulty': result.x[self.num_items:2 * self.num_items], 70 | 'guessing': result.x[2 * self.num_items:3 * self.num_items], 71 | 'theta_params': result.x[3 * self.num_items:], 72 | 'theta': self.compute_theta(result.x[3 * self.num_items:]) 73 | } 74 | 75 | def _get_params_bounds(self) -> List: 76 | 77 | return [ 78 | [(0.5, 1.5) for _ in range(self.num_items)], # Bounds for a [discrimination] 79 | [(0.01, 1) for _ in range(self.num_items)], # Bounds for b [difficulty] 80 | [(0.2, 0.4) for _ in range(self.num_items)], # Bounds for c [guessing] 81 | [(-3, 3) for _ in range(self.num_theta_params)] # Bounds for theta 82 | ] 83 | 84 | def fit(self, 85 | n_steps: int = 2, 86 | drop_ratio: float = 0.1) -> Dict[str, np.array]: 87 | 88 | estimator_dict = {} 89 | 90 | # Initial guesses for a, b, c and theta 91 | 
initial_guess = np.concatenate([ 92 | np.ones(self.num_items), # Initial guesses for a [discrimination] 93 | np.zeros(self.num_items), # Initial guesses for b [difficulty] 94 | np.full(self.num_items, 0.25), # Initial guesses for c [guessing] 95 | np.zeros(self.num_theta_params) # Initial guesses for theta 96 | ]) 97 | 98 | params = self._fit(initial_guess=initial_guess, 99 | params_bounds=self._get_params_bounds()) 100 | 101 | estimator_dict[0] = params 102 | 103 | for step in range(1, n_steps): 104 | 105 | # Low-discrimation filtering, remove low self.drop_ratio % of questions 106 | percentile_value = np.percentile(params['discrimination'], 107 | drop_ratio) 108 | 109 | # Find the index of the closest value to this percentile in the array and filter it 110 | indices_to_remove = [k 111 | for k,v in enumerate(params['discrimination']) 112 | if v <= percentile_value] 113 | self.num_items -= len(indices_to_remove) 114 | 115 | 116 | # Round 1 guesses for a, b, c and theta 117 | updated_guess = np.concatenate([ 118 | np.delete(params['discrimination'], indices_to_remove), 119 | np.delete(params['difficulty'], indices_to_remove), 120 | np.delete(params['guessing'], indices_to_remove), 121 | params['theta_params'] 122 | ]) 123 | 124 | params = self._fit(initial_guess=updated_guess, 125 | params_bounds=self._get_params_bounds()) 126 | 127 | estimator_dict[step] = params 128 | 129 | return estimator_dict 130 | 131 | 132 | def compute_stats(self, estimator: Dict[str, np.array]): 133 | 134 | # Hierachical Model Params 135 | llm_params = estimator['theta_params'][:self.num_llm] 136 | retrieval_params = estimator['theta_params'][self.num_llm:self.num_llm + self.num_retrieval] 137 | icl_params = estimator['theta_params'][self.num_llm + self.num_retrieval:] 138 | 139 | # Calculate the RMSE for each item 140 | rmse_val = [np.sqrt(mean_squared_error(self.data[:, i], 141 | self.irt_model(a=estimator['discrimination'][i], 142 | b=estimator['difficulty'][i], 143 | c=estimator['guessing'][i], 144 | theta=estimator['theta']))) for i in range(len(estimator['discrimination']))] 145 | rmse_val_moy = [np.sqrt(mean_squared_error(self.data[:, i], 146 | self.data.mean(axis=1))) for i in range(len(estimator['discrimination']))] 147 | 148 | def get_mean_std(array: np.array) -> Dict[str, float]: 149 | return {'mean': np.mean(array), 'std': np.std(array)} 150 | 151 | stats = { 152 | "Mean Exam accuracy": {'mean': 100 * self.data.mean(), 'std': 100 * self.data.mean(axis=1).std()}, 153 | "Estimators": 154 | { 155 | "Discrimination (a)": get_mean_std(estimator['discrimination']), 156 | "Difficulty (b)": get_mean_std(estimator['difficulty']), 157 | "Guessing (c)": get_mean_std(estimator['guessing']), 158 | "Theta": get_mean_std(estimator['theta']), 159 | }, 160 | 'Theta': { 161 | 'LLM': {k: f"{llm_params[i]:.02f} [+ {llm_params[i]-llm_params[0]:.02f}]" 162 | for k, i in self.llm_map.items()}, 163 | 'Retrieval': {k: f"{retrieval_params[i]:.02f} [+ {retrieval_params[i]-retrieval_params[0]:.02f}]" 164 | for k, i in self.retrieval_map.items()}, 165 | 'ICL': {f"ICL@{k}": f"{icl_params[k]:.02f} [+ {icl_params[k]-icl_params[0]:.02f}]" 166 | for k in range(self.num_icl)}, 167 | }, 168 | 'All Thetas': {stud.name: f"{estimator['theta'][i]:.02f} (Acc: {self.data.mean(axis=1)[i]:.02f})" 169 | for i, stud in enumerate(self.students)}, 170 | 'RMSE': 171 | { 172 | 'IRT Pred': get_mean_std(rmse_val), 173 | 'Mean Pred Baseline': get_mean_std(rmse_val_moy), 174 | } 175 | } 176 | 177 | return stats 178 | 179 | def 
plot_iterative_informativeness(self, 180 | estimator_dict: Dict[str, Dict[str, np.array]], 181 | exam_model: str, 182 | save_path: str = None, 183 | font_size: int = 18) -> None: 184 | 185 | # Set global font size 186 | plt.rcParams.update({'font.size': font_size}) 187 | 188 | # Create an array of theta values for plotting 189 | theta_values = np.linspace(-3, 3, 300) 190 | 191 | # Create a 2x2 grid of subplots 192 | fig, ax = plt.subplots(figsize=(12, 8)) 193 | 194 | # colors = ['red', 'green', 'blue', 'purple', 'orange'] 195 | 196 | for step, estimator in estimator_dict.items(): 197 | 198 | # Assume these are the estimated parameters for 3 items 199 | a = estimator['discrimination'] # Discrimination parameters 200 | b = estimator['difficulty'] # Difficulty parameters 201 | c = estimator['guessing'] # Guessing parameters 202 | 203 | test_information = np.zeros_like(theta_values) 204 | for i in range(len(a)): 205 | p = self.irt_model(theta=theta_values, a=a[i], b=b[i], c=c[i]) 206 | information = a[i]**2 * p * (1 - p) 207 | test_information += information # Sum up information from all items 208 | ax.plot(theta_values, test_information / len(a), label=f'Step {step}') 209 | 210 | # # Add markers on the x-axis for the estimated theta values 211 | # for k, theta in enumerate(estimator['theta']): 212 | # color = colors[k % len(colors)] 213 | # ax.scatter(theta, 0, marker='x', color=color) 214 | 215 | ax.set_title(f'Exam Information Curve - {exam_model} Exam - {self.irt_model_type}PL Model') 216 | ax.set_xlabel('Theta (Ability)') 217 | ax.set_ylabel('Fisher Information') 218 | ax.legend() 219 | ax.grid(True) 220 | ax.grid(True) 221 | 222 | plt.tight_layout() 223 | if save_path: 224 | plt.savefig(save_path) 225 | plt.show() 226 | -------------------------------------------------------------------------------- /auto-rag-eval/ExamAnalysis/taxonomy_analysis.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 24, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import json\n", 10 | "import glob\n", 11 | "from dataclasses import dataclass\n", 12 | "from bloom_taxonomy_model import categorize_question\n", 13 | "from pathlib import Path" 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": 25, 19 | "metadata": {}, 20 | "outputs": [], 21 | "source": [ 22 | "@dataclass\n", 23 | "class ExamSetting:\n", 24 | " llm: str\n", 25 | " retrieval: str\n", 26 | " icl: int\n", 27 | " name: str\n", 28 | " path_pattern: str # Assuming base path is a constant attribute of the class\n", 29 | "\n", 30 | " def find_file_path(self):\n", 31 | " \"\"\"\n", 32 | " Find the file path using the class attributes.\n", 33 | " \"\"\"\n", 34 | " # Search for files matching the pattern\n", 35 | " matching_files = glob.glob(self.path_pattern)\n", 36 | " \n", 37 | " # Return the first matching file or None\n", 38 | " if matching_files is None or matching_files == []:\n", 39 | " raise ValueError(f\"Incorrect path pattern {self.path_pattern}\")\n", 40 | "\n", 41 | " return matching_files[0]\n", 42 | " \n", 43 | " @property\n", 44 | " def exists(self):\n", 45 | "\n", 46 | " # Search for files matching the pattern\n", 47 | " matching_files = glob.glob(self.path_pattern)\n", 48 | "\n", 49 | " return matching_files is not None and matching_files != []\n", 50 | "\n", 51 | " @property\n", 52 | " def data_path(self):\n", 53 | " \"\"\"\n", 54 | " Property to get the data path.\n", 55 | " \"\"\"\n", 56 | " 
return self.find_file_path()" 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": 26, 62 | "metadata": {}, 63 | "outputs": [], 64 | "source": [ 65 | "def get_all_students(model, task):\n", 66 | "\n", 67 | " root_path = f'{Path('.').resolve().parent}/Data/{task}/EvalResults'\n", 68 | " extended_students = [\n", 69 | " [ExamSetting(path_pattern=f'{root_path}/{task}Exam/llamav2/13b/full_sample_{task}Exam_closed_book_{model}_results_*_icl{i}.jsonl',\n", 70 | " llm='llamav2:13B',\n", 71 | " retrieval='closed_book',\n", 72 | " icl=i,\n", 73 | " name=f'Closed Book@{i} [13B]'),\n", 74 | " ExamSetting(path_pattern=f'{root_path}/{task}RagExam/llamav2/13b/full_sample_{task}Exam_rag_siamese_{model}_results_*_icl{i}.jsonl',\n", 75 | " llm='llamav2:13B',\n", 76 | " retrieval='rag_siamese',\n", 77 | " icl=i,\n", 78 | " name=f'Rag Siamese@{i} [13B]'),\n", 79 | " ExamSetting(path_pattern=f'{root_path}/{task}RagExam/llamav2/13b/full_sample_{task}Exam_rag_dpr_{model}_results_*_icl{i}.jsonl',\n", 80 | " llm='llamav2:13B',\n", 81 | " retrieval='rag_dpr',\n", 82 | " icl=i,\n", 83 | " name=f'Rag DPR@{i} [13B]'),\n", 84 | " ExamSetting(path_pattern=f'{root_path}/{task}RagExam/llamav2/13b/full_sample_{task}Exam_rag_bm25_{model}_results_*_icl{i}.jsonl',\n", 85 | " llm='llamav2:13B',\n", 86 | " retrieval='rag_bm25',\n", 87 | " icl=i,\n", 88 | " name=f'Rag BM25@{i} [13B]'),\n", 89 | " ExamSetting(path_pattern=f'{root_path}/{task}NewRagExam/llamav2/13b/full_sample_{task}Exam_rag_multi_qa_{model}_results_*_icl{i}.jsonl',\n", 90 | " llm='llamav2:13B',\n", 91 | " retrieval='rag_multi_qa',\n", 92 | " icl=i,\n", 93 | " name=f'Rag MultiQA@{i} [13B]'),\n", 94 | " ExamSetting(path_pattern=f'{root_path}/{task}NewRagExam/llamav2/13b/full_sample_{task}Exam_rag_dpr_bm25_multi_qa_{model}_results_*_icl{i}.jsonl',\n", 95 | " llm='llamav2:13B',\n", 96 | " retrieval='rag_dprv2',\n", 97 | " icl=i,\n", 98 | " name=f'Rag DPRV2@{i} [13B]'),\n", 99 | " ExamSetting(path_pattern=f'{root_path}/{task}Exam/llamav2/13b/full_sample_{task}Exam_open_book_{model}_results_*_icl{i}.jsonl',\n", 100 | " llm='llamav2:13B',\n", 101 | " retrieval='open_book',\n", 102 | " icl=i,\n", 103 | " name=f'Open Book@{i} [13B]')] \n", 104 | " for i in range(3)\n", 105 | " ]\n", 106 | "\n", 107 | " # Add 70B Models\n", 108 | " extended_students.extend([[\n", 109 | " ExamSetting(path_pattern=f'{root_path}/{task}Exam/llamav2/70b/full_sample_{task}Exam_closed_book_{model}_results_*_icl{i}.jsonl',\n", 110 | " llm='llamav2:70B',\n", 111 | " retrieval='closed_book',\n", 112 | " icl=i,\n", 113 | " name=f'Closed Book@{i} [70B]'),\n", 114 | " ExamSetting(path_pattern=f'{root_path}/{task}RagExam/llamav2/70b/full_sample_{task}Exam_rag_siamese_{model}_results_*_icl{i}.jsonl',\n", 115 | " llm='llamav2:70B',\n", 116 | " retrieval='rag_siamese',\n", 117 | " icl=i,\n", 118 | " name=f'Rag Siamese@{i} [70B]'),\n", 119 | " ExamSetting(path_pattern=f'{root_path}/{task}RagExam/llamav2/70b/full_sample_{task}Exam_rag_dpr_{model}_results_*_icl{i}.jsonl',\n", 120 | " llm='llamav2:70B',\n", 121 | " retrieval='rag_dpr',\n", 122 | " icl=i,\n", 123 | " name=f'Rag DPR@{i} [70B]'),\n", 124 | " ExamSetting(path_pattern=f'{root_path}/{task}RagExam/llamav2/70b/full_sample_{task}Exam_rag_bm25_{model}_results_*_icl{i}.jsonl',\n", 125 | " llm='llamav2:70B',\n", 126 | " retrieval='rag_bm25',\n", 127 | " icl=i,\n", 128 | " name=f'Rag BM25@{i} [70B]'),\n", 129 | " 
ExamSetting(path_pattern=f'{root_path}/{task}NewRagExam/llamav2/70b/full_sample_{task}Exam_rag_multi_qa_{model}_results_*_icl{i}.jsonl',\n", 130 | " llm='llamav2:70B',\n", 131 | " retrieval='rag_multi_qa',\n", 132 | " icl=i,\n", 133 | " name=f'Rag MultiQA@{i} [70B]'),\n", 134 | " ExamSetting(path_pattern=f'{root_path}/{task}NewRagExam/llamav2/70b/full_sample_{task}Exam_rag_dpr_bm25_multi_qa_{model}_results_*_icl{i}.jsonl',\n", 135 | " llm='llamav2:70B',\n", 136 | " retrieval='rag_dprv2',\n", 137 | " icl=i,\n", 138 | " name=f'Rag DPRV2@{i} [70B]'),\n", 139 | " ExamSetting(path_pattern=f'{root_path}/{task}Exam/llamav2/70b/full_sample_{task}Exam_open_book_{model}_results_*_icl{i}.jsonl',\n", 140 | " llm='llamav2:70B',\n", 141 | " retrieval='open_book',\n", 142 | " icl=i,\n", 143 | " name=f'Open Book@{i} [70B]')] for i in range(3)],\n", 144 | " )\n", 145 | "\n", 146 | " # Add Mistral:7B Models\n", 147 | " extended_students.extend([[\n", 148 | " ExamSetting(path_pattern=f'{root_path}/{task}Exam/mistral/7b/full_sample_{task}Exam_closed_book_{model}_results_*_icl{i}.jsonl',\n", 149 | " llm='mistral:7b',\n", 150 | " retrieval='closed_book',\n", 151 | " icl=i,\n", 152 | " name=f'Closed Book@{i} [7B]'),\n", 153 | " ExamSetting(path_pattern=f'{root_path}/{task}RagExam/mistral/7b/full_sample_{task}Exam_rag_siamese_{model}_results_*_icl{i}.jsonl',\n", 154 | " llm='mistral:7b',\n", 155 | " retrieval='rag_siamese',\n", 156 | " icl=i,\n", 157 | " name=f'Rag Siamese@{i} [7B]'),\n", 158 | " ExamSetting(path_pattern=f'{root_path}/{task}RagExam/mistral/7b/full_sample_{task}Exam_rag_dpr_{model}_results_*_icl{i}.jsonl',\n", 159 | " llm='mistral:7b',\n", 160 | " retrieval='rag_dpr',\n", 161 | " icl=i,\n", 162 | " name=f'Rag DPR@{i} [7B]'),\n", 163 | " ExamSetting(path_pattern=f'{root_path}/{task}RagExam/mistral/7b/full_sample_{task}Exam_rag_bm25_{model}_results_*_icl{i}.jsonl',\n", 164 | " llm='mistral:7b',\n", 165 | " retrieval='rag_bm25',\n", 166 | " icl=i,\n", 167 | " name=f'Rag BM25@{i} [7B]'),\n", 168 | " ExamSetting(path_pattern=f'{root_path}/{task}NewRagExam/mistral/7b/full_sample_{task}Exam_rag_multi_qa_{model}_results_*_icl{i}.jsonl',\n", 169 | " llm='mistral:7b',\n", 170 | " retrieval='rag_multi_qa',\n", 171 | " icl=i,\n", 172 | " name=f'Rag MultiQA@{i} [7B]'),\n", 173 | " ExamSetting(path_pattern=f'{root_path}/{task}NewRagExam/mistral/7b/full_sample_{task}Exam_rag_dpr_bm25_multi_qa_{model}_results_*_icl{i}.jsonl',\n", 174 | " llm='mistral:7b',\n", 175 | " retrieval='rag_dprv2',\n", 176 | " icl=i,\n", 177 | " name=f'Rag DPRV2@{i} [7B]'),\n", 178 | " ExamSetting(path_pattern=f'{root_path}/{task}Exam/mistral/7b/full_sample_{task}Exam_open_book_{model}_results_*_icl{i}.jsonl',\n", 179 | " llm='mistral:7b',\n", 180 | " retrieval='open_book',\n", 181 | " icl=i,\n", 182 | " name=f'Open Book@{i} [7B]')] for i in range(3)],\n", 183 | " )\n", 184 | "\n", 185 | " return [i for elem in extended_students for i in elem]" 186 | ] 187 | }, 188 | { 189 | "cell_type": "code", 190 | "execution_count": 27, 191 | "metadata": {}, 192 | "outputs": [], 193 | "source": [ 194 | "MODELS = [\"llamav2\"]\n", 195 | "TASKS = ['StackExchange', 'Arxiv', 'SecFilings']\n", 196 | "\n", 197 | "def load_data(data_path):\n", 198 | " with open(data_path, 'r') as f:\n", 199 | " data = [json.loads(line) for line in f]\n", 200 | "\n", 201 | " return data" 202 | ] 203 | }, 204 | { 205 | "cell_type": "code", 206 | "execution_count": 42, 207 | "metadata": {}, 208 | "outputs": [], 209 | "source": [ 210 | "for task in ['StackExchange', 'Arxiv', 
'SecFilings', 'DevOps']:\n", 211 | "\n", 212 | " students = get_all_students(model='llamav2', task=task)\n", 213 | "\n", 214 | " questions_taxonomy = [categorize_question(elem['doc']['question'])\n", 215 | " for elem in load_data(students[0].data_path)]\n", 216 | "\n", 217 | " my_task_dict = {\n", 218 | " k: [elem['doc']['question'] for i, elem in enumerate(load_data(students[0].data_path)) if k in questions_taxonomy[i]]\n", 219 | " for k in ['Remembering', 'Understanding', 'Applying', 'Analyzing', 'Evaluating', 'Creating', 'Uncategorized']\n", 220 | " }\n", 221 | "\n", 222 | " with open(f\"{task}.json\", 'w') as file:\n", 223 | " json.dump(my_task_dict, file)" 224 | ] 225 | } 226 | ], 227 | "metadata": { 228 | "kernelspec": { 229 | "display_name": "Python 3", 230 | "language": "python", 231 | "name": "python3" 232 | }, 233 | "language_info": { 234 | "codemirror_mode": { 235 | "name": "ipython", 236 | "version": 3 237 | }, 238 | "file_extension": ".py", 239 | "mimetype": "text/x-python", 240 | "name": "python", 241 | "nbconvert_exporter": "python", 242 | "pygments_lexer": "ipython3", 243 | "version": "3.9.16" 244 | } 245 | }, 246 | "nbformat": 4, 247 | "nbformat_minor": 2 248 | } 249 | -------------------------------------------------------------------------------- /auto-rag-eval/ExamEvaluator/DevOpsExam/DevOpsExam.yaml: -------------------------------------------------------------------------------- 1 | group: DevOpsExam 2 | task: 3 | - dataset_kwargs: &id001 4 | data_files: 5 | test: exam.json 6 | dataset_path: /home/ubuntu/workspace/OpenAssistantEndpoint/MultiChoiceExam/DevOpsData/openllama_090805/ 7 | doc_to_choice: '{{choices}}' 8 | doc_to_target: '{{correct_answer}}' 9 | doc_to_text: !function preprocess_exam.make_prompt_closed_book 10 | group: &id002 11 | - multiple_choice 12 | - DevOps 13 | metric_list: &id003 14 | - aggregation: mean 15 | higher_is_better: 'true' 16 | metric: acc 17 | - aggregation: mean 18 | higher_is_better: 'true' 19 | metric: acc_norm 20 | output_type: multiple_choice 21 | task: DevOpsExam_closed_book_openllama 22 | test_split: test 23 | training_split: null 24 | validation_split: null 25 | - dataset_kwargs: *id001 26 | dataset_path: /home/ubuntu/workspace/OpenAssistantEndpoint/MultiChoiceExam/DevOpsData/openllama_090805/ 27 | doc_to_choice: '{{choices}}' 28 | doc_to_target: '{{correct_answer}}' 29 | doc_to_text: !function preprocess_exam.make_prompt_open_book 30 | group: *id002 31 | metric_list: *id003 32 | output_type: multiple_choice 33 | task: DevOpsExam_open_book_openllama 34 | test_split: test 35 | training_split: null 36 | validation_split: null 37 | - dataset_kwargs: *id001 38 | dataset_path: /home/ubuntu/workspace/OpenAssistantEndpoint/MultiChoiceExam/DevOpsData/llamav2_090805/ 39 | doc_to_choice: '{{choices}}' 40 | doc_to_target: '{{correct_answer}}' 41 | doc_to_text: !function preprocess_exam.make_prompt_closed_book 42 | group: *id002 43 | metric_list: *id003 44 | output_type: multiple_choice 45 | task: DevOpsExam_closed_book_llamav2 46 | test_split: test 47 | training_split: null 48 | validation_split: null 49 | - dataset_kwargs: *id001 50 | dataset_path: /home/ubuntu/workspace/OpenAssistantEndpoint/MultiChoiceExam/DevOpsData/llamav2_090805/ 51 | doc_to_choice: '{{choices}}' 52 | doc_to_target: '{{correct_answer}}' 53 | doc_to_text: !function preprocess_exam.make_prompt_open_book 54 | group: *id002 55 | metric_list: *id003 56 | output_type: multiple_choice 57 | task: DevOpsExam_open_book_llamav2 58 | test_split: test 59 | training_split: null 60 | 
validation_split: null 61 | -------------------------------------------------------------------------------- /auto-rag-eval/ExamEvaluator/DevOpsExam/DevOpsRagExam.yaml: -------------------------------------------------------------------------------- 1 | group: DevOpsRagExam 2 | task: 3 | - dataset_kwargs: &id001 4 | data_files: 5 | test: exam.json 6 | dataset_path: /home/ubuntu/workspace/OpenAssistantEndpoint/MultiChoiceExam/DevOpsData/openllama_2023090805/ 7 | doc_to_choice: '{{choices}}' 8 | doc_to_target: '{{correct_answer}}' 9 | doc_to_text: !function preprocess_exam.make_prompt_rag_dpr 10 | group: &id002 11 | - multiple_choice 12 | - DevOps 13 | metric_list: &id003 14 | - aggregation: mean 15 | higher_is_better: 'true' 16 | metric: acc 17 | - aggregation: mean 18 | higher_is_better: 'true' 19 | metric: acc_norm 20 | output_type: multiple_choice 21 | task: DevOpsExam_rag_dpr_openllama 22 | test_split: test 23 | training_split: null 24 | validation_split: null 25 | - dataset_kwargs: *id001 26 | dataset_path: /home/ubuntu/workspace/OpenAssistantEndpoint/MultiChoiceExam/DevOpsData/openllama_2023090805/ 27 | doc_to_choice: '{{choices}}' 28 | doc_to_target: '{{correct_answer}}' 29 | doc_to_text: !function preprocess_exam.make_prompt_rag_siamese 30 | group: *id002 31 | metric_list: *id003 32 | output_type: multiple_choice 33 | task: DevOpsExam_rag_siamese_openllama 34 | test_split: test 35 | training_split: null 36 | validation_split: null 37 | - dataset_kwargs: *id001 38 | dataset_path: /home/ubuntu/workspace/OpenAssistantEndpoint/MultiChoiceExam/DevOpsData/openllama_2023090805/ 39 | doc_to_choice: '{{choices}}' 40 | doc_to_target: '{{correct_answer}}' 41 | doc_to_text: !function preprocess_exam.make_prompt_rag_bm25 42 | group: *id002 43 | metric_list: *id003 44 | output_type: multiple_choice 45 | task: DevOpsExam_rag_bm25_openllama 46 | test_split: test 47 | training_split: null 48 | validation_split: null 49 | - dataset_kwargs: *id001 50 | dataset_path: /home/ubuntu/workspace/OpenAssistantEndpoint/MultiChoiceExam/DevOpsData/llamav2_2023090805/ 51 | doc_to_choice: '{{choices}}' 52 | doc_to_target: '{{correct_answer}}' 53 | doc_to_text: !function preprocess_exam.make_prompt_rag_dpr 54 | group: *id002 55 | metric_list: *id003 56 | output_type: multiple_choice 57 | task: DevOpsExam_rag_dpr_llamav2 58 | test_split: test 59 | training_split: null 60 | validation_split: null 61 | - dataset_kwargs: *id001 62 | dataset_path: /home/ubuntu/workspace/OpenAssistantEndpoint/MultiChoiceExam/DevOpsData/llamav2_2023090805/ 63 | doc_to_choice: '{{choices}}' 64 | doc_to_target: '{{correct_answer}}' 65 | doc_to_text: !function preprocess_exam.make_prompt_rag_siamese 66 | group: *id002 67 | metric_list: *id003 68 | output_type: multiple_choice 69 | task: DevOpsExam_rag_siamese_llamav2 70 | test_split: test 71 | training_split: null 72 | validation_split: null 73 | - dataset_kwargs: *id001 74 | dataset_path: /home/ubuntu/workspace/OpenAssistantEndpoint/MultiChoiceExam/DevOpsData/llamav2_2023090805/ 75 | doc_to_choice: '{{choices}}' 76 | doc_to_target: '{{correct_answer}}' 77 | doc_to_text: !function preprocess_exam.make_prompt_rag_bm25 78 | group: *id002 79 | metric_list: *id003 80 | output_type: multiple_choice 81 | task: DevOpsExam_rag_bm25_llamav2 82 | test_split: test 83 | training_split: null 84 | validation_split: null 85 | -------------------------------------------------------------------------------- /auto-rag-eval/ExamEvaluator/DevOpsExam/__pycache__/preprocess_exam.cpython-38.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/auto-rag-eval/a25bc50e78790044ddc45874e0c9085a73f0262e/auto-rag-eval/ExamEvaluator/DevOpsExam/__pycache__/preprocess_exam.cpython-38.pyc -------------------------------------------------------------------------------- /auto-rag-eval/ExamEvaluator/DevOpsExam/preprocess_exam.py: -------------------------------------------------------------------------------- 1 | def make_prompt_closed_book(doc): 2 | 3 |     return "###Human: Question: {}\n\nCandidates:\n{}\n\n###Assistant: Correct answer".format(doc['question'], 4 |                                                                                               "\n".join(doc['choices'])) 5 | 6 | 7 | def make_prompt_open_book(doc): 8 | 9 |     return "###Human: Question: {}\n\nContext:{}\n\nCandidates:\n{}\n\n###Assistant: Correct answer".format(doc['question'], 10 |                                                                                                              doc['documentation'], 11 |                                                                                                              "\n".join(doc['choices'])) 12 | 13 | 14 | def make_prompt_rag_dpr(doc, n_retrieved_docs: int = 1): 15 | 16 |     return "###Human: Question: {}\n\nRetrieved Documents:\n{}\n\nCandidates:\n{}\n\n###Assistant: Correct answer".format(doc['question'], 17 |                                                                                                                            "\n".join( 18 |                                                                                                                                doc['retrieved_context']['DPR'][:n_retrieved_docs]), 19 |                                                                                                                            "\n".join(doc['choices'])) 20 | 21 | 22 | def make_prompt_rag_siamese(doc, n_retrieved_docs: int = 1): 23 | 24 |     return "###Human: Question: {}\n\nRetrieved Documents:\n{}\n\nCandidates:\n{}\n\n###Assistant: Correct answer".format(doc['question'], 25 |                                                                                                                            "\n".join( 26 |                                                                                                                                doc['retrieved_context']['SIAMESE'][:n_retrieved_docs]), 27 |                                                                                                                            "\n".join(doc['choices'])) 28 | 29 | 30 | def make_prompt_rag_bm25(doc, n_retrieved_docs: int = 1): 31 | 32 |     return "###Human: Question: {}\n\nRetrieved Documents:\n{}\n\nCandidates:\n{}\n\n###Assistant: Correct answer".format(doc['question'], 33 |                                                                                                                            "\n".join( 34 |                                                                                                                                doc['retrieved_context']['BM25'][:n_retrieved_docs]), 35 |                                                                                                                            "\n".join(doc['choices'])) 36 | -------------------------------------------------------------------------------- /auto-rag-eval/ExamEvaluator/README.md: -------------------------------------------------------------------------------- 1 | # Exam Evaluation 2 | 3 | We leverage the [lm-harness](https://github.com/EleutherAI/lm-evaluation-harness/tree/big-refactor) package to evaluate the (LLM & Retrieval) system on the generated exam. 4 | To do so, follow the steps below (lm-harness usage might have been updated since): 5 | 6 | ### Create a benchmark 7 | 8 | Create a benchmark folder for your task (here `DevOpsExam`); see `ExamEvaluator/DevOpsExam` for the template. 9 | It contains a code file `preprocess_exam.py` for prompt templates and, more importantly, a set of tasks to evaluate models on: 10 | 11 | * `DevOpsExam` contains the tasks associated with ClosedBook (no retrieval) and OpenBook (Oracle Retrieval). 12 | * `DevOpsRagExam` contains the tasks associated with retrieval variants (DPR/Embeddings/BM25...). 13 | 14 | The provided script `task_evaluation.sh` illustrates the evaluation of `Llamav2:Chat:13B` and `Llamav2:Chat:70B` on the task, using In-Context Learning (ICL) with 0, 1 and 2 samples respectively. 15 | -------------------------------------------------------------------------------- /auto-rag-eval/ExamEvaluator/task_evaluation.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Tasks supported: StackExchange, DevOps and RAG variants StackExchangeRag, DevOpsRag...
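# Illustrative usage (assumes lm-evaluation-harness is cloned in the working directory and the
# ExamEvaluator benchmark YAMLs have been registered as lm-harness tasks):
#   bash task_evaluation.sh DevOps      # closed-book / open-book exam
#   bash task_evaluation.sh DevOpsRag   # retrieval-augmented exam variants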
3 | 4 | cd lm-evaluation-harness 5 | current_date=$(date +"%y%m%d%H") 6 | task_domain=$1 7 | echo "Evaluating ${task_domain} task" 8 | 9 | model_path="add/your/model/path/here" 10 | echo "Evaluating Llamav2 - 13B - ICL@0" 11 | accelerate launch main.py \ 12 |     --model hf \ 13 |     --model_args "pretrained=${model_path},load_in_8bit=True" \ 14 |     --tasks "${task_domain}Exam" \ 15 |     --device cuda \ 16 |     --output_path "results/${task_domain}Exam/llamav2/13b/results_${current_date}_icl0.json" 17 | echo "Evaluating Llamav2 - 13B - ICL@1" 18 | accelerate launch main.py \ 19 |     --model hf \ 20 |     --model_args "pretrained=${model_path},load_in_8bit=True" \ 21 |     --tasks "${task_domain}Exam" \ 22 |     --device cuda \ 23 |     --num_fewshot 1 \ 24 |     --output_path "results/${task_domain}Exam/llamav2/13b/results_${current_date}_icl1.json" 25 | echo "Evaluating Llamav2 - 13B - ICL@2" 26 | accelerate launch main.py \ 27 |     --model hf \ 28 |     --model_args "pretrained=${model_path},load_in_8bit=True" \ 29 |     --tasks "${task_domain}Exam" \ 30 |     --device cuda \ 31 |     --num_fewshot 2 \ 32 |     --output_path "results/${task_domain}Exam/llamav2/13b/results_${current_date}_icl2.json" 33 | 34 | # Note the difference in arguments when using 70B models: python3 + parallelize=True vs accelerate launch 35 | model_path="add/your/model/path/here" 36 | echo "Evaluating Llamav2 - 70B - ICL@0" 37 | python3 main.py \ 38 |     --model hf \ 39 |     --model_args "pretrained=${model_path},parallelize=True" \ 40 |     --tasks "${task_domain}Exam" \ 41 |     --device cuda \ 42 |     --output_path "results/${task_domain}Exam/llamav2/70b/results_${current_date}_icl0.json" 43 | echo "Evaluating Llamav2 - 70B - ICL@1" 44 | python3 main.py \ 45 |     --model hf \ 46 |     --model_args "pretrained=${model_path},parallelize=True" \ 47 |     --tasks "${task_domain}Exam" \ 48 |     --device cuda \ 49 |     --num_fewshot 1 \ 50 |     --output_path "results/${task_domain}Exam/llamav2/70b/results_${current_date}_icl1.json" 51 | echo "Evaluating Llamav2 - 70B - ICL@2" 52 | python3 main.py \ 53 |     --model hf \ 54 |     --model_args "pretrained=${model_path},parallelize=True" \ 55 |     --tasks "${task_domain}Exam" \ 56 |     --device cuda \ 57 |     --num_fewshot 2 \ 58 |     --output_path "results/${task_domain}Exam/llamav2/70b/results_${current_date}_icl2.json" -------------------------------------------------------------------------------- /auto-rag-eval/ExamGenerator/README.md: -------------------------------------------------------------------------------- 1 | # Exam Generation 2 | 3 | This folder contains several functions and notebook utilities for the automated generation of the exam. 4 | 5 | ## Generation 6 | 7 | * `question_generator.py`: Python class to generate the raw exam, given a knowledge corpus. 8 | * `multi_choice_question.py`: Python classes `MultiChoiceQuestionParser` and `MultiChoiceQuestion` to convert raw questions into filtered questions. 9 | * `multi_choice_exam.py`: Python class `MultiChoiceExam` to be invoked to generate the processed exam from the raw exam data. 10 | 11 | ## Distractors 12 | 13 | * `distractors_generator.py` can be used to generate new distractors for your exam questions, following a two-step approach. 14 | 15 | ## Utilities 16 | 17 | * `fake_exam_generator.py`: Code to create a fake exam from a given exam to check the validity of the evaluation. 18 | * `utils.py`: Among others, class `SimilarityChecker` to evaluate the similarity between answers, using embeddings and n-gram based methods.
19 | * In case you already generated an exam and want to add new retriever to evaluate, one can use the code from `extend_ir_existing_exam.py` 20 | -------------------------------------------------------------------------------- /auto-rag-eval/ExamGenerator/__pycache__/distractors_generator.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/auto-rag-eval/a25bc50e78790044ddc45874e0c9085a73f0262e/auto-rag-eval/ExamGenerator/__pycache__/distractors_generator.cpython-37.pyc -------------------------------------------------------------------------------- /auto-rag-eval/ExamGenerator/__pycache__/enrich_existing_exam.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/auto-rag-eval/a25bc50e78790044ddc45874e0c9085a73f0262e/auto-rag-eval/ExamGenerator/__pycache__/enrich_existing_exam.cpython-37.pyc -------------------------------------------------------------------------------- /auto-rag-eval/ExamGenerator/__pycache__/extend_ir_existing_exam.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/auto-rag-eval/a25bc50e78790044ddc45874e0c9085a73f0262e/auto-rag-eval/ExamGenerator/__pycache__/extend_ir_existing_exam.cpython-37.pyc -------------------------------------------------------------------------------- /auto-rag-eval/ExamGenerator/__pycache__/fake_exam_generator.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/auto-rag-eval/a25bc50e78790044ddc45874e0c9085a73f0262e/auto-rag-eval/ExamGenerator/__pycache__/fake_exam_generator.cpython-37.pyc -------------------------------------------------------------------------------- /auto-rag-eval/ExamGenerator/__pycache__/multi_choice_exam.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/auto-rag-eval/a25bc50e78790044ddc45874e0c9085a73f0262e/auto-rag-eval/ExamGenerator/__pycache__/multi_choice_exam.cpython-37.pyc -------------------------------------------------------------------------------- /auto-rag-eval/ExamGenerator/__pycache__/multi_choice_exam_generator.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/auto-rag-eval/a25bc50e78790044ddc45874e0c9085a73f0262e/auto-rag-eval/ExamGenerator/__pycache__/multi_choice_exam_generator.cpython-37.pyc -------------------------------------------------------------------------------- /auto-rag-eval/ExamGenerator/__pycache__/multi_choice_question.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/auto-rag-eval/a25bc50e78790044ddc45874e0c9085a73f0262e/auto-rag-eval/ExamGenerator/__pycache__/multi_choice_question.cpython-37.pyc -------------------------------------------------------------------------------- /auto-rag-eval/ExamGenerator/__pycache__/question_generator.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/auto-rag-eval/a25bc50e78790044ddc45874e0c9085a73f0262e/auto-rag-eval/ExamGenerator/__pycache__/question_generator.cpython-37.pyc 
-------------------------------------------------------------------------------- /auto-rag-eval/ExamGenerator/__pycache__/raw_question_generator.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/auto-rag-eval/a25bc50e78790044ddc45874e0c9085a73f0262e/auto-rag-eval/ExamGenerator/__pycache__/raw_question_generator.cpython-37.pyc -------------------------------------------------------------------------------- /auto-rag-eval/ExamGenerator/__pycache__/utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/auto-rag-eval/a25bc50e78790044ddc45874e0c9085a73f0262e/auto-rag-eval/ExamGenerator/__pycache__/utils.cpython-37.pyc -------------------------------------------------------------------------------- /auto-rag-eval/ExamGenerator/distractors_generator.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import concurrent.futures 3 | import json 4 | import logging 5 | import time 6 | from datetime import datetime 7 | from os.path import abspath, dirname 8 | from typing import Dict, List 9 | 10 | from ExamGenerator.utils import get_single_file_in_folder 11 | from LLMServer.bedrock.claude_instant import ClaudeInstant 12 | from LLMServer.bedrock.claude_v2 import ClaudeV2 13 | from LLMServer.base_model import BaseLLM 14 | from tqdm import tqdm 15 | 16 | logger = logging.getLogger(__name__) 17 | ROOTPATH = dirname(dirname(abspath(__file__))) 18 | 19 | 20 | class LLMDistractorGenerator: 21 | 22 | def __init__(self, 23 | llm_model: BaseLLM): 24 | 25 | self.llm_model = llm_model 26 | 27 | def make_distractor_prompts(self, 28 | question: str, 29 | answer: str) -> str: 30 | return (f"### Human: Here is a technical question on AWS documentation: {question}." 
31 | f"\nThe correct answer is {answer}.\nProvide 3 incorrect answers or distractors to " 32 | "this question.\n### Assistant:") 33 | 34 | def generate_distractors(self, 35 | exam: List[Dict[str, str]]) -> Dict[int, Dict[str, str]]: 36 | 37 | generated_distractors = {} 38 | for k in tqdm(range(0, len(exam))): 39 | answer = self.llm_model.invoke( 40 | prompt=self.make_distractor_prompts(question=exam[k]['question'], 41 | answer=exam[k]['correct_answer']), 42 | params={}) 43 | generated_distractors[k] = { 44 | **exam[k], 45 | "raw_distractors": answer 46 | } 47 | return generated_distractors 48 | 49 | 50 | class BatchDistractorGenerator: 51 | 52 | def __init__(self, 53 | task_domain: str, 54 | model_list: List[str], 55 | batch_size: int): 56 | 57 | self.batch_size = batch_size 58 | self.model_list = model_list 59 | self.task_domain = task_domain 60 | 61 | self.model_map = { 62 | 'ClaudeInstant': LLMDistractorGenerator( 63 | llm_model=ClaudeInstant()), 64 | 'ClaudeV2': LLMDistractorGenerator( 65 | llm_model=ClaudeV2()) 66 | } 67 | 68 | def batch_generate_distractors(self, exam_folder: str) -> None: 69 | 70 | with open(get_single_file_in_folder(exam_folder), "r") as f: 71 | data = json.load(f) 72 | 73 | logger.error((f"Processing a total of {len(data)} documentation pieces for {self.task_domain}" 74 | f" using models {self.model_list}, with batch size of {self.batch_size} " 75 | f"({1+len(data)//self.batch_size} batches)")) 76 | 77 | # Split the data into batches 78 | batches = [data[i:i + self.batch_size] 79 | for i in range(0, len(data), self.batch_size)] 80 | 81 | start_time = datetime.fromtimestamp( 82 | time.time()).strftime('%Y%m%d%H') 83 | 84 | try: 85 | 86 | for batch_index, batch in enumerate(batches): 87 | logger.error(f"Running batch {batch_index}") 88 | 89 | with concurrent.futures.ProcessPoolExecutor() as executor: 90 | 91 | futurs = {model: executor.submit(self.model_map[model].generate_distractors, batch) 92 | for model in self.model_list} 93 | updated_questions = {model: futur.result() for model, futur in futurs.items()} 94 | # Write the dictionary to a JSON file 95 | for model in updated_questions.keys(): 96 | filename = (f"{self.task_domain}_QCM_distractors_base_exam_{exam_folder.split('/')[-1]}" 97 | f"_to_{model}_{start_time}_batch{batch_index}.json") 98 | with open(f"{ROOTPATH}/Data/{self.task_domain}/RawExamData/{filename}", "w") as write_file: 99 | json.dump(updated_questions[model], write_file) 100 | 101 | except Exception as e: 102 | 103 | logger.error(f"Failure to generate disractors for batch {batch_index}: {e}") 104 | 105 | 106 | if __name__ == "__main__": 107 | 108 | parser = argparse.ArgumentParser( 109 | description="Creates Distractors from Exam Data") 110 | 111 | parser.add_argument( 112 | "--task-domain", 113 | help="Task Domain, among DevOps, StackExchange, MyOwnTask...", 114 | ) 115 | 116 | parser.add_argument( 117 | "--exam-folder", 118 | help="Exam data to use to generate distractors, eg html_llamav2_2023091421...", 119 | ) 120 | 121 | main_args, _ = parser.parse_known_args() 122 | 123 | raw_distractor_generator = BatchDistractorGenerator(batch_size=10, 124 | task_domain=main_args.task_domain, 125 | # model_list=['openllama', 'llamav2'] 126 | model_list=['openllama'] 127 | ) 128 | 129 | # TODO: Modify prompt 130 | raw_distractor_generator.batch_generate_distractors( 131 | exam_folder=f"{ROOTPATH}/Data/{main_args.task_domain}/ExamData/{main_args.exam_folder}") 132 | -------------------------------------------------------------------------------- 
/auto-rag-eval/ExamGenerator/extend_ir_existing_exam.py: -------------------------------------------------------------------------------- 1 | import json 2 | from os.path import abspath, dirname 3 | 4 | from RetrievalSystems.bm25 import BM25ContextProvider 5 | from RetrievalSystems.dpr_context_aggregator import DPRContextGenerator 6 | from RetrievalSystems.embedding_retriever import EmbeddingContextProvider 7 | from RetrievalSystems.siamese_retriever import SiameseContextProvider 8 | from tqdm import tqdm 9 | 10 | ROOTPATH = dirname(dirname(abspath(__file__))) 11 | 12 | if __name__ == '__main__': 13 | 14 | for exam_setting in tqdm([ 15 | {'task_domain': 'Arxiv', 16 | 'exam_folder': 'small_llamav2_2023091905'}, 17 | {'task_domain': 'Arxiv', 18 | 'exam_folder': 'small_openllama_2023091905'}, 19 | ]): 20 | 21 | context_generator_dict = { 22 | 'DPR': DPRContextGenerator(context_sources={ 23 | 'SIAMESE' : SiameseContextProvider(index_folder=f"{ROOTPATH}/Data/{exam_setting['task_domain']}/RetrievalIndex/siamese_emb", 24 | data_folder=f"{ROOTPATH}/Data/{exam_setting['task_domain']}/KnowledgeCorpus/main", 25 | regenerate_index=False), 26 | 'BM25': BM25ContextProvider(data_folder=f"{ROOTPATH}/Data/{exam_setting['task_domain']}/KnowledgeCorpus/main") 27 | }), 28 | 'BM25' : BM25ContextProvider(data_folder=f"{ROOTPATH}/Data/{exam_setting['task_domain']}/KnowledgeCorpus/main"), 29 | 'SIAMESE' : SiameseContextProvider(index_folder=f"{ROOTPATH}/Data/{exam_setting['task_domain']}/RetrievalIndex/siamese_emb", 30 | data_folder=f"{ROOTPATH}/Data/{exam_setting['task_domain']}/KnowledgeCorpus/main", 31 | regenerate_index=True), 32 | 'MultiQA' : EmbeddingContextProvider(index_folder=f"{ROOTPATH}/Data/{exam_setting['task_domain']}/RetrievalIndex/multi_qa_emb", 33 | data_folder=f"{ROOTPATH}/Data/{exam_setting['task_domain']}/KnowledgeCorpus/main", 34 | regenerate_index=True), 35 | 'DPR:MultiQA:BM25': DPRContextGenerator(context_sources={ 36 | 'MultiQA' : EmbeddingContextProvider(index_folder=f"{ROOTPATH}/Data/{exam_setting['task_domain']}/RetrievalIndex/multi_qa_emb", 37 | data_folder=f"{ROOTPATH}/Data/{exam_setting['task_domain']}/KnowledgeCorpus/main", 38 | regenerate_index=False), 39 | 'BM25': BM25ContextProvider(data_folder=f"{ROOTPATH}/Data/{exam_setting['task_domain']}/KnowledgeCorpus/main") 40 | }), 41 | } 42 | 43 | with open(f"{ROOTPATH}/Data/{exam_setting['task_domain']}/ExamData/{exam_setting['exam_folder']}/exam.json", "r") as outfile: 44 | docs_exam = json.load(outfile) 45 | 46 | for question in docs_exam: 47 | 48 | question['retrieved_context'] = {**question['retrieved_context'], 49 | **{retriever : [elem.text for elem in context_generator.get_context_from_query( 50 | question['question'])] for retriever, context_generator in context_generator_dict.items()}} 51 | 52 | with open(f"{ROOTPATH}/Data/{exam_setting['task_domain']}/ExamData/{exam_setting['exam_folder']}/updated_ir_exam.json", "w") as outfile: 53 | outfile.write(json.dumps(docs_exam)) 54 | -------------------------------------------------------------------------------- /auto-rag-eval/ExamGenerator/fake_exam_generator.py: -------------------------------------------------------------------------------- 1 | import json 2 | import random 3 | from os.path import abspath, dirname 4 | from typing import Dict 5 | 6 | ROOTPATH = dirname(dirname(abspath(__file__))) 7 | 8 | 9 | class FakeExamGenerator: 10 | 11 | def __init__(self, 12 | task_domain: str, 13 | exam_folder: str): 14 | 15 | self.task_domain = task_domain 16 | self.exam_folder = exam_folder 
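    # The method below swaps the original distractors for three obviously wrong fillers, shuffles them
    # together with the true answer, and relabels the choices A), B), ... so that the correct letter is
    # re-drawn at random; an evaluation pipeline that cannot ace this sanity-check exam is likely broken.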
17 | 18 | def generate_fake_distractor(self, 19 | exam_question: Dict[str, str]) -> Dict[str, str]: 20 | 21 | new_candidates = [") This is an absurd choice.", 22 | ") This is an ridiculus choice.", 23 | ") Picking this choice is a nonsense.", 24 | exam_question['correct_answer'][1:]] 25 | 26 | # shuffle the candidates 27 | random.shuffle(new_candidates) 28 | 29 | # build new candidates list 30 | shuffled_candidates = [f"{chr(65 + i)}{x}" 31 | for i, x in enumerate(new_candidates)] 32 | 33 | # find the new letter for the correct answer 34 | new_correct_answer = [f"{chr(65 + i)}{x}" 35 | for i, x in enumerate(new_candidates) if x == exam_question['correct_answer'][1:]][0] 36 | 37 | return {**{k: v 38 | for k, v in exam_question.items() if k not in ['choices', 'correct_answer']}, 39 | 'choices': shuffled_candidates, 40 | 'correct_answer': new_correct_answer} 41 | 42 | def generate_fake_exam(self) -> None: 43 | 44 | with open(f"{ROOTPATH}/Data/{self.task_domain}/ExamData/{self.exam_folder}/exam.json", 'r') as f: 45 | self.exam_data = json.load(f) 46 | 47 | fake_exam = [self.generate_fake_distractor(exam_question=elem) for elem in self.exam_data] 48 | 49 | with open(f"{ROOTPATH}/Data/FakeExam/{self.task_domain}/fake_{self.exam_folder}.json", "w") as outfile: 50 | outfile.write(json.dumps(fake_exam)) 51 | 52 | 53 | if __name__ == '__main__': 54 | 55 | fake_exam = FakeExamGenerator(task_domain='Arxiv', 56 | exam_folder='llamav2_2023091905') 57 | 58 | fake_exam.generate_fake_exam() 59 | -------------------------------------------------------------------------------- /auto-rag-eval/ExamGenerator/multi_choice_exam.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import logging 4 | import os 5 | import random 6 | import re 7 | from collections import Counter 8 | from os.path import abspath, dirname 9 | from typing import Dict, List 10 | 11 | import numpy as np 12 | from ExamGenerator.multi_choice_question import MultiChoiceQuestion 13 | from ExamGenerator.utils import SimilarityChecker, get_n_sentences 14 | from RetrievalSystems.bm25 import BM25ContextProvider 15 | from RetrievalSystems.context_utils import ContextProvider 16 | from RetrievalSystems.dpr_context_aggregator import DPRContextGenerator 17 | from RetrievalSystems.embedding_retriever import EmbeddingContextProvider 18 | from RetrievalSystems.siamese_retriever import SiameseContextProvider 19 | from tqdm import tqdm 20 | 21 | ROOTPATH = dirname(dirname(abspath(__file__))) 22 | logger = logging.getLogger(__name__) 23 | 24 | 25 | class MultiChoiceExam: 26 | 27 | def __init__(self, 28 | task_domain: str, 29 | model_name: str, 30 | context_generator_dict: Dict[str, ContextProvider], 31 | question_date: str): 32 | 33 | self.task_domain = task_domain 34 | self.model_name = model_name 35 | self.question_date = question_date 36 | self.context_generator_dict = context_generator_dict 37 | assert context_generator_dict is not None, "Provide Context Retriever Model" 38 | self.question_list: List[MultiChoiceQuestion] = [] 39 | self.question_parsing_fail: int = 0 40 | self.choices_parsing_fail: int = 0 41 | self.correct_answer_parsing_fail: int = 0 42 | self.other_parsing_fail: int = 0 43 | self.failed_question_list: List[str] = [] 44 | 45 | def load_from_list(self, 46 | raw_exam_list: List[str]) -> None: 47 | 48 | for raw_question in raw_exam_list: 49 | 50 | mcq = MultiChoiceQuestion(documentation=raw_question['documentation']['text'], 51 | raw_answer=raw_question['answer'], 52 
| model_name=self.model_name) 53 | 54 | mcq.extract_information() 55 | 56 | if mcq.valid_mcq() and self.task_based_constraints(mcq=mcq): 57 | mcq.add_retrieved_context(self.context_generator_dict) 58 | self.question_list.append(mcq) 59 | else: 60 | if mcq.question is None: 61 | self.question_parsing_fail += 1 62 | 63 | if mcq.choices is None: 64 | self.choices_parsing_fail += 1 65 | 66 | if mcq.correct_answer is None: 67 | self.correct_answer_parsing_fail += 1 68 | 69 | if mcq.valid_mcq(): 70 | self.other_parsing_fail += 1 71 | 72 | self.failed_question_list.append(mcq.raw_answer) 73 | 74 | def load_all_model_question(self) -> bool: 75 | 76 | exam_directory = f"{ROOTPATH}/Data/{self.task_domain}/RawExamData/" 77 | self.n_question = 0 78 | 79 | logger.error(f"Starting to load all raw questions from {exam_directory}") 80 | 81 | raw_question_files = [os.path.join(exam_directory, f) 82 | for f in os.listdir(exam_directory) 83 | if (os.path.isfile(os.path.join(exam_directory, f)) 84 | and f.startswith(f"{self.task_domain}_QCM_{self.model_name}_{self.question_date}"))] 85 | 86 | if len(raw_question_files) == 0: 87 | 88 | return False 89 | 90 | for file in tqdm(raw_question_files): 91 | 92 | with open(file, "r") as f: 93 | raw_exam_list = list(json.load(f).values()) 94 | self.load_from_list(raw_exam_list=raw_exam_list) 95 | self.n_question += len(raw_exam_list) 96 | 97 | return True 98 | 99 | def task_based_constraints(self, 100 | mcq: MultiChoiceQuestion) -> bool: 101 | 102 | def refers_to_document(question: str) -> bool: 103 | # Patterns prioritizing specificity 104 | document_patterns = [ 105 | # term immediately followed by title in quotes 106 | r'\b(documentation|paper|article|research|study)\b\s*\"[^\"]+\"', 107 | # citation-like sentence followed by title 108 | r'\b(discussed in|addressed in|described in|of the)\b\s*\"[^\"]+\"', 109 | # fallback to original terms 110 | r'\b(documentation|paper|article|research|study)\b', 111 | ] 112 | 113 | # Check if any of the patterns match 114 | for pattern in document_patterns: 115 | if re.search(pattern, question, re.IGNORECASE): 116 | return False 117 | return True 118 | 119 | if self.task_domain in ['Arxiv', 'StackExchange']: 120 | 121 | return refers_to_document(mcq.question) 122 | 123 | else: 124 | 125 | return True 126 | 127 | def compute_exam_analytics(self, 128 | save_failed_question: bool, 129 | display_n_samples: int = 1) -> None: 130 | 131 | if self.n_question == 0: 132 | raise ValueError("Empty exam, please check model name, date and path to ensure the exam is loaded properly.") 133 | 134 | if save_failed_question: 135 | 136 | with open((f"{ROOTPATH}/ExamGenerator/DebugingData/failed_questions_" 137 | f"{self.task_domain}_{self.model_name}_{self.question_date}.json"), "w") as outfile: 138 | outfile.write(json.dumps(self.failed_question_list)) 139 | 140 | def convert_perc(x): 141 | return 100 * x / len(self.question_list) 142 | 143 | logger.error( 144 | f"\n###################### {self.model_name} ######################\n") 145 | logger.error(f"ExamID: {self.task_domain}:{self.model_name}:{self.question_date}") 146 | logger.error(("\n### Parsing Analysis:\n\n" 147 | f"Total of {len(self.question_list)}/{self.n_question} questions processed" 148 | f" ({100*len(self.question_list)/self.n_question:.02f}%)")) 149 | logger.error((f"Statistics over {len(self.failed_question_list)} failed parsing:\n" 150 | f"Question Parsing Error: {100*self.question_parsing_fail/len(self.failed_question_list):.02f}%\n" 151 | f"Choices Parsing Error: 
{100*self.choices_parsing_fail/len(self.failed_question_list):.02f}%\n" 152 | f"Correct Answer Parsing Error: {100*self.correct_answer_parsing_fail/len(self.failed_question_list):.02f}%\n" 153 | f"Other Parsing Error or Constraints: {100*self.other_parsing_fail/len(self.failed_question_list):.02f}%\n")) 154 | 155 | if len(self.question_list) == 0: 156 | raise ValueError( 157 | "None of the questions are properly parsed. Please check the parsing logic, using for instance ExamAnalysis/failed_question_analysis.ipynb") 158 | 159 | # Positional bias has been removed so 160 | # --- 161 | answer_analysis = Counter([question.correct_answer[0] 162 | for question in self.question_list]) 163 | # logger.error( 164 | # (f"Position Bias: {answer_analysis} (Baseline Acc: {100*max(answer_analysis.values())/len(self.question_list):.02f}%)")) 165 | 166 | logger.error(("### Accuracy Analysis:\n\n" 167 | f"Best Fixed Answer Baseline: {convert_perc(max(answer_analysis.values())):.02f}%\n" 168 | f"Longest Answer Baseline: {convert_perc(sum([mcq.correct_candidate_is_longest() for mcq in self.question_list])):.02f}%\n")) 169 | 170 | logger.error("### Sample questions:\n\n{}".format('\n'.join([f"Question {k+1}: {mcq.question}" 171 | for k, mcq in enumerate(self.question_list[:10])]))) 172 | 173 | question_keyword = ['Which', 'What', 'How', 'When', 'Why', 'Where'] 174 | question_counter = [(f"{k}{(7-len(k))*' '} -- " 175 | f"{convert_perc(sum([k.lower() in mcq.question.lower() for mcq in self.question_list])):.02f}%") 176 | for k in question_keyword] 177 | other_key = sum([not (any([k.lower() in mcq.question.lower() for k in question_keyword])) 178 | for mcq in self.question_list]) 179 | question_counter.append(f"Other -- {convert_perc(other_key):.02f}%") 180 | 181 | logger.error("\n### Question Analysis\n") 182 | logger.error("Question type:\n{}".format('\n'.join(question_counter))) 183 | first_word_analysis = sum([mcq.question.split(' ')[0].lower() in [e.lower() for e in question_keyword] 184 | for mcq in self.question_list]) 185 | logger.error(f"\nQuestion starts with {question_keyword}: {convert_perc(first_word_analysis):.02f}%") 186 | logger.error(("Avg. question char. length: " 187 | f"{np.mean([len(mcq.question) for mcq in self.question_list]):.02f}" 188 | f" (std: {np.std([len(mcq.question) for mcq in self.question_list]):.02f})")) 189 | logger.error((f"Avg. number of sentence in question: " 190 | f"{np.mean([get_n_sentences(mcq.question) for mcq in self.question_list]):.02f}")) 191 | logger.error(("Avg. answers char. length: " 192 | f"{np.mean([len(''.join(mcq.choices))/4 for mcq in self.question_list]):.02f}" 193 | f" (std: {np.std([len(''.join(mcq.choices))/4 for mcq in self.question_list]):.02f})")) 194 | logger.error(("Avg. correct answer char. length: " 195 | f"{np.mean([len(mcq.correct_answer) for mcq in self.question_list]):.02f}" 196 | f" (std: {np.std([len(mcq.correct_answer) for mcq in self.question_list]):.02f})")) 197 | logger.error(("Avg. documentation char. length: " 198 | f"{np.mean([len(mcq.documentation) for mcq in self.question_list]):.02f}" 199 | f" (std: {np.std([len(mcq.question) for mcq in self.question_list]):.02f})")) 200 | logger.error(("Avg. 
number of sentence in documentation: " 201 | f"{np.mean([get_n_sentences(mcq.documentation) for mcq in self.question_list]):.02f}\n")) 202 | 203 | for elem in random.sample(self.question_list, display_n_samples): 204 | 205 | similarity_checker = SimilarityChecker() 206 | 207 | elem.display(similarity_checker) 208 | 209 | def save_exam_dataset(self) -> None: 210 | 211 | if self.n_question == 0: 212 | raise ValueError("Empty exam, please check model name, date and path to ensure the exam is loaded properly.") 213 | 214 | docs_exam = [{'question': mcq.question, 215 | 'documentation': mcq.documentation, 216 | 'choices': mcq.choices, 217 | 'correct_answer': mcq.correct_answer, 218 | 'retrieved_context': mcq.retrieved_context, 219 | } for mcq in self.question_list] 220 | 221 | dir_path = f"{ROOTPATH}/Data/{self.task_domain}/ExamData/{self.model_name}_{self.question_date}" 222 | os.makedirs(dir_path, 223 | exist_ok=True) 224 | 225 | with open(f"{dir_path}/exam.json", "w") as outfile: 226 | outfile.write(json.dumps(docs_exam)) 227 | 228 | 229 | if __name__ == '__main__': 230 | 231 | parser = argparse.ArgumentParser( 232 | description="Creates Exam from Raw Exam Data") 233 | 234 | parser.add_argument( 235 | "--task-domain", 236 | help="Task Domain, among DevOps, StackExchange...", 237 | ) 238 | parser.add_argument( 239 | "--question-date", 240 | help="Date associated with the raw exam (eg 2023091223), can be seen in RawExamData", 241 | ) 242 | parser.add_argument('--save-exam', 243 | action='store_true', 244 | help='If provided, the exam is saved. Otherwise, we just compute analytics') 245 | 246 | main_args, _ = parser.parse_known_args() 247 | 248 | context_generator_dict = { 249 | 'DPR': DPRContextGenerator(context_sources={ 250 | 'SIAMESE' : SiameseContextProvider(index_folder=f"{ROOTPATH}/Data/{main_args.task_domain}/RetrievalIndex/siamese_emb", 251 | data_folder=f"{ROOTPATH}/Data/{main_args.task_domain}/KnowledgeCorpus/main", 252 | regenerate_index=False), 253 | 'BM25': BM25ContextProvider(data_folder=f"{ROOTPATH}/Data/{main_args.task_domain}/KnowledgeCorpus/main") 254 | }), 255 | 'BM25' : BM25ContextProvider(data_folder=f"{ROOTPATH}/Data/{main_args.task_domain}/KnowledgeCorpus/main"), 256 | 'SIAMESE' : SiameseContextProvider(index_folder=f"{ROOTPATH}/Data/{main_args.task_domain}/RetrievalIndex/siamese_emb", 257 | data_folder=f"{ROOTPATH}/Data/{main_args.task_domain}/KnowledgeCorpus/main", 258 | regenerate_index=True), 259 | 'MultiQA' : EmbeddingContextProvider(index_folder=f"{ROOTPATH}/Data/{main_args.task_domain}/RetrievalIndex/multi_qa_emb", 260 | data_folder=f"{ROOTPATH}/Data/{main_args.task_domain}/KnowledgeCorpus/main", 261 | regenerate_index=True), 262 | 'DPR:MultiQA:BM25': DPRContextGenerator(context_sources={ 263 | 'MultiQA' : EmbeddingContextProvider(index_folder=f"{ROOTPATH}/Data/{main_args.task_domain}/RetrievalIndex/multi_qa_emb", 264 | data_folder=f"{ROOTPATH}/Data/{main_args.task_domain}/KnowledgeCorpus/main", 265 | regenerate_index=False), 266 | 'BM25': BM25ContextProvider(data_folder=f"{ROOTPATH}/Data/{main_args.task_domain}/KnowledgeCorpus/main") 267 | }), 268 | } 269 | 270 | for model_name in ['llamav2', 'openllama', 'claudev2', 'claude_instant']: 271 | 272 | MultiChoiceExamLLM = MultiChoiceExam(task_domain=main_args.task_domain, 273 | model_name=model_name, 274 | question_date=main_args.question_date, 275 | context_generator_dict=context_generator_dict) 276 | 277 | llm_exam_exists = MultiChoiceExamLLM.load_all_model_question() 278 | 279 | if llm_exam_exists: 280 | 281 | 
MultiChoiceExamLLM.compute_exam_analytics(save_failed_question=True) 282 | 283 | if main_args.save_exam: 284 | MultiChoiceExamLLM.save_exam_dataset() 285 | -------------------------------------------------------------------------------- /auto-rag-eval/ExamGenerator/multi_choice_question.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import random 3 | import re 4 | from os.path import abspath, dirname 5 | from typing import Dict, List 6 | 7 | from RetrievalSystems.context_utils import ContextProvider 8 | 9 | ROOTPATH = dirname(dirname(abspath(__file__))) 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | 14 | class MultiChoiceQuestionParser: 15 | 16 | def __init__(self): 17 | 18 | self.min_length_question = 50 19 | 20 | def extract_with_patterns(self, 21 | text: str, 22 | patterns: List) -> List[str]: 23 | 24 | for pattern in patterns: 25 | try: 26 | matches = re.findall(pattern, text, re.DOTALL) 27 | if matches: 28 | return matches 29 | except re.error: 30 | continue 31 | return None 32 | 33 | def parse_question(self, text: str) -> str: 34 | 35 | question_patterns = [r"Question:(.*?)(?:\n[a-dA-D1-4]\)|\n\n[a-dA-D1-4]\))", 36 | r"Question 1:(.*?)(?:\n[a-dA-D1-4]\)|\n\n[a-dA-D1-4]\))", 37 | r"question:(.*?)(?:\n[a-dA-D1-4]\)|\n\n[a-dA-D1-4]\))", 38 | r"question 1:(.*?)(?:\n[a-dA-D1-4]\)|\n\n[a-dA-D1-4]\))", 39 | r"documentation:(.*?)(?:\n[a-dA-D1-4]\)|\n\n[a-dA-D1-4]\))", # for ClaudeV2 mostly 40 | r"### Assistant: (.*?)\n"] 41 | 42 | # extra_questions_patterns = [ 43 | # r"Question:(.*?)(?:\nCandidate [A-D1-4]\)|\n\nCandidate [A-D1-4]\))", 44 | # r"Question 1:(.*?)(?:\nCandidate [A-D1-4]\)|\n\nCandidate [A-D1-4]\))", 45 | # r"Question:(.*?)(?:\nCandidate [A-D1-4]\.|\n\nCandidate [A-D1-4]\.)", 46 | # r"Question 1:(.*?)(?:\nCandidate [A-D1-4]\.|\n\nCandidate [A-D1-4]\.)", 47 | # r"Question:(.*?)(?:\nOption [A-D1-4]\)|\n\nOption [A-D1-4]\))", 48 | # r"Question 1:(.*?)(?:\nOption [A-D1-4]\)|\n\nOption [A-D1-4]\))", 49 | # r"Question:(.*?)(?:\nOption [A-D1-4]\.|\n\nOption [A-D1-4]\.)", 50 | # r"Question 1:(.*?)(?:\nOption [A-D1-4]\.|\n\nOption [A-D1-4]\.)"] 51 | 52 | # Extract the question 53 | question_matches = self.extract_with_patterns(text, question_patterns) 54 | question = question_matches[0].strip() if question_matches else None 55 | question = (question 56 | if (question and len(question) > self.min_length_question and question[-1] == '?') 57 | else None) 58 | 59 | return question 60 | 61 | def parse_choices(self, text: str) -> str: 62 | 63 | choices_patterns = [r"([A-D]\) .*?)(?=$|\n[A-D]\)|\n\n)", 64 | r"([A-D]\)(?:.|\n)*?)(?=$|\n[A-D]\)|\n\n)", 65 | r"([A-D]\. .*?)(?=$|\n[A-D]\.|\n\n)", 66 | r"([A-D]\.)(?:.|\n)*?)(?=$|\n[A-D]\.|\n\n)", 67 | r"([1-4]\) .*?)(?=$|\n[1-4]\)|\n\n)", 68 | r"([1-4]\)(?:.|\n)*?)(?=$|\n[1-4]\)|\n\n)", 69 | r"([1-4]\. .*?)(?=$|\n[1-4]\.|\n\n)", 70 | r"([1-4]\.)(?:.|\n)*?)(?=$|\n[1-4]\.|\n\n)", 71 | r"([a-d]\) .*?)(?=$|\n[a-d]\)|\n\n)", 72 | r"([a-d]\)(?:.|\n)*?)(?=$|\n[a-d]\)|\n\n)", 73 | r"([a-d]\. 
.*?)(?=$|\n[a-d]\.|\n\n)", 74 | r"([a-d]\.)(?:.|\n)*?)(?=$|\n[a-d]\.|\n\n)"] 75 | 76 | # Extract the choices 77 | choices_matches = self.extract_with_patterns(text, choices_patterns) 78 | choices = [match.strip() 79 | for match in choices_matches] if choices_matches else None 80 | 81 | # Only keep first 4 answers 82 | choices = choices[:4] if choices and len(choices) >= 4 and len( 83 | set([choice[0] for choice in choices[:4]])) == 4 else None 84 | 85 | # Remove scenarios with empty answers ['A)], 'B)', 'C)', 'D)'] 86 | choices = choices if choices and min([len(choice) for choice in choices]) > 2 else None 87 | 88 | return choices 89 | 90 | def parse_correct_answer_key(self, text): 91 | 92 | correct_answer_patterns = [r"answer:\n\n([A-D1-4a-d])", 93 | r"answer: ([A-D1-4a-d])", 94 | r"Answer: ([A-D1-4a-d])", 95 | r"answer is ([A-D1-4])"] 96 | 97 | # Extract the correct answer key 98 | correct_answer_key_matches = self.extract_with_patterns( 99 | text, correct_answer_patterns) 100 | correct_answer_key = correct_answer_key_matches[0] if correct_answer_key_matches else None 101 | 102 | return correct_answer_key 103 | 104 | def parse_text(self, text: str) -> Dict[str, str]: 105 | 106 | text = (text.split('### Assistant:')[-1] 107 | if '### Assistant:' in text 108 | else text) 109 | 110 | question = self.parse_question(text) 111 | choices = self.parse_choices(text) 112 | correct_answer_key = self.parse_correct_answer_key(text) 113 | 114 | # Find the full text of the correct answer 115 | correct_answer = next((a for a in choices if a.startswith( 116 | correct_answer_key)), None) if correct_answer_key and choices else None 117 | 118 | # Replace first letter to be only in A-D 119 | letter_map = {'1': 'A', '2': 'B', '3': 'C', '4': 'D', 120 | 'a': 'A', 'b': 'B', 'c': 'C', 'd': 'D'} 121 | 122 | return { 123 | 'question': question, 124 | 'choices': [letter_map[s[0]] + s[1:] if s[0] 125 | in letter_map else s for s in choices] if choices else None, 126 | 'correct_answer': (letter_map[correct_answer[0]] + correct_answer[1:] 127 | if correct_answer[0] in letter_map else correct_answer) if correct_answer else None 128 | } 129 | 130 | 131 | class MultiChoiceQuestion: 132 | 133 | def __init__(self, 134 | documentation: str, 135 | raw_answer: str, 136 | model_name: str): 137 | 138 | self.documentation = documentation 139 | self.raw_answer = raw_answer 140 | self.question = None 141 | self.choices = None 142 | self.correct_answer = None 143 | self.model_name = model_name 144 | self.retrieved_context = None 145 | self.parser = MultiChoiceQuestionParser() 146 | 147 | def parse_text(self, text: str) -> None: 148 | 149 | # For new syntax prompt, one needs to remove post assistant 150 | # parsed_text = self.parser.parse_text(text.split('Assistant:')[-1]) 151 | parsed_text = self.parser.parse_text(text) 152 | 153 | self.question = parsed_text['question'] 154 | self.choices = parsed_text['choices'] 155 | self.correct_answer = parsed_text['correct_answer'] 156 | 157 | # Suffle the candidate order to remove positional bias 158 | if self.choices and self.correct_answer: 159 | 160 | self.shuffle_question() 161 | 162 | def shuffle_question(self) -> None: 163 | # strip out the letters and just keep the choices 164 | stripped_candidates = [x[1:] for x in self.choices] 165 | correct_answer_stripped = self.correct_answer[1:] 166 | 167 | # shuffle the candidates 168 | random.shuffle(stripped_candidates) 169 | 170 | # build new candidates list 171 | shuffled_candidates = [ 172 | f"{chr(65 + i)}{x}" for i, x in 
enumerate(stripped_candidates)] 173 | 174 | # find the new letter for the correct answer 175 | new_correct_answer = [f"{chr(65 + i)}{x}" 176 | for i, x in enumerate(stripped_candidates) if x == correct_answer_stripped][0] 177 | 178 | self.choices = shuffled_candidates 179 | self.correct_answer = new_correct_answer 180 | 181 | def extract_information(self) -> None: 182 | 183 | self.parse_text(self.raw_answer) 184 | 185 | def valid_mcq(self) -> bool: 186 | 187 | return self.question and self.choices and self.correct_answer 188 | 189 | def correct_candidate_is_longest(self): 190 | return len(self.correct_answer) >= max([len(choice) for choice in self.choices]) 191 | 192 | def display(self, 193 | similarity_checker=None) -> None: 194 | 195 | if self.question is not None and self.choices is not None and self.correct_answer is not None: 196 | 197 | logger.error(f"########### {self.model_name} ###########\n") 198 | logger.error(f'Documentation: \n {self.documentation}') 199 | logger.error(f"Question: \n {self.question}") 200 | 201 | # if self.retrieved_context is not None: 202 | # logger.error( 203 | # f"Retrieved Context: \n {self.retrieved_context}") 204 | 205 | if similarity_checker: 206 | self_processed_answer = [f"[{simil}] - {elem}" 207 | for simil, elem in zip(similarity_checker.compute_similarity(self), 208 | self.choices)] 209 | else: 210 | self_processed_answer = self.choices 211 | logger.error("Answers: \n {}".format( 212 | '\n '.join(self_processed_answer))) 213 | logger.error(f"Correct Answer: \n {self.correct_answer}\n") 214 | 215 | def generate_question_answer_pair(self, 216 | add_context: bool) -> Dict[str, str]: 217 | ''' 218 | Format the prompt for the Exam Evaluation Section 219 | ''' 220 | prompt = ("###Human: Question: {}\n\nCandidates:\n{}\n\n###Assistant: Correct answer: ".format(self.question, 221 | '\n'.join(self.choices)) 222 | if add_context is False 223 | else "###Human: Question: {}\n\nContext: {}\n\nCandidates:\n{}\n\n###Assistant: Correct answer: ".format(self.question, 224 | self.documentation, 225 | '\n'.join(self.choices))) 226 | 227 | return {"prompt": prompt, 228 | "answer": self.correct_answer} 229 | 230 | def add_retrieved_context(self, 231 | context_generator_dict: Dict[str, ContextProvider]) -> Dict[str, List[str]]: 232 | 233 | self.retrieved_context = {retriever : [elem.text for elem in context_generator.get_context_from_query( 234 | self.question)] for retriever, context_generator in context_generator_dict.items()} 235 | -------------------------------------------------------------------------------- /auto-rag-eval/ExamGenerator/question_generator.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import concurrent.futures 3 | import json 4 | import logging 5 | import random 6 | import time 7 | from datetime import datetime 8 | from os.path import abspath, dirname 9 | from typing import List 10 | 11 | from ExamGenerator.utils import get_single_file_in_folder 12 | from LLMServer.bedrock.claude_instant import ClaudeInstant 13 | from LLMServer.bedrock.claude_v2 import ClaudeV2 14 | from LLMServer.llm_exam_generator import ClaudeExamGenerator, LLMExamGenerator 15 | 16 | logger = logging.getLogger(__name__) 17 | ROOTPATH = dirname(dirname(abspath(__file__))) 18 | 19 | 20 | class BatchExamGenerator: 21 | 22 | def __init__(self, 23 | task_domain: str, 24 | model_list: List[str], 25 | batch_size: int): 26 | 27 | self.batch_size = batch_size 28 | self.model_list = model_list 29 | self.task_domain = 
task_domain 30 | 31 | self.model_map = { 32 | 'claudev2': ClaudeExamGenerator(step_size=1, 33 | task_domain=self.task_domain, 34 | llm_model=ClaudeV2()), 35 | 'claude_instant': ClaudeExamGenerator(step_size=1, 36 | task_domain=self.task_domain, 37 | llm_model=ClaudeInstant()) 38 | } 39 | assert not (any([model not in self.model_map.keys() for model in self.model_list])) 40 | 41 | def batch_generate_exam(self, data_folder: str) -> None: 42 | 43 | with open(get_single_file_in_folder(data_folder), "r") as f: 44 | data = json.load(f) 45 | 46 | # Suffle the data to prevent overfocusing on a topic 47 | # --- 48 | random.seed(10) 49 | random.shuffle(data) 50 | 51 | logger.error((f"Processing a total of {len(data)} documentation pieces for {self.task_domain}" 52 | f" using models {self.model_list}, with batch size of {self.batch_size} " 53 | f"({1+len(data)//self.batch_size} batches)")) 54 | 55 | # Split the data into batches 56 | batches = [data[i:i + self.batch_size] 57 | for i in range(0, len(data), self.batch_size)] 58 | 59 | start_time = datetime.fromtimestamp( 60 | time.time()).strftime('%Y%m%d%H') 61 | 62 | try: 63 | 64 | for batch_index, batch in enumerate(batches): 65 | 66 | logger.error(f"Running batch {batch_index}.") 67 | if len(self.model_list) > 1: 68 | # Multiprocessing not compatible with Bedrock Usage 69 | with concurrent.futures.ProcessPoolExecutor() as executor: 70 | futurs = {model: executor.submit(self.model_map[model].generate_exam, batch) 71 | for model in self.model_list} 72 | generated_questions = {model: futur.result() for model, futur in futurs.items()} 73 | else: 74 | generated_questions = {model: self.model_map[model].generate_exam(batch) 75 | for model in self.model_list} 76 | 77 | # Write the dictionary to a JSON file 78 | for model in generated_questions.keys(): 79 | filename = f"{self.task_domain}_QCM_{model}_{start_time}_batch{batch_index}.json" 80 | with open(f"{ROOTPATH}/Data/{self.task_domain}/RawExamData/{filename}", "w") as write_file: 81 | json.dump(generated_questions[model], write_file) 82 | 83 | except Exception as e: 84 | 85 | logger.error(f"Failure to collect questions for batch {batch_index}: {e}") 86 | 87 | 88 | if __name__ == "__main__": 89 | 90 | parser = argparse.ArgumentParser( 91 | description="Creates Raw Exam from Documentation Corpus") 92 | 93 | parser.add_argument( 94 | "--task-domain", 95 | help="Task Domain, among DevOps, StackExchange, MyOwnTask...", 96 | ) 97 | 98 | main_args, _ = parser.parse_known_args() 99 | 100 | raw_exam_generator = BatchExamGenerator(batch_size=60, 101 | task_domain=main_args.task_domain, 102 | # model_list=['openllama', 'llamav2'] 103 | model_list=['claudev2'] 104 | ) 105 | 106 | raw_exam_generator.batch_generate_exam( 107 | data_folder=f"{ROOTPATH}/Data/{main_args.task_domain}/KnowledgeCorpus/main") 108 | -------------------------------------------------------------------------------- /auto-rag-eval/ExamGenerator/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | from collections import Counter 3 | from typing import List 4 | 5 | import nltk 6 | import numpy as np 7 | from ExamGenerator.multi_choice_question import MultiChoiceQuestion 8 | from nltk.tokenize import sent_tokenize 9 | from sentence_transformers import SentenceTransformer 10 | from sklearn.metrics.pairwise import cosine_similarity 11 | 12 | nltk.download('punkt') 13 | 14 | 15 | def get_n_sentences(text): 16 | return len([sent for sent in nltk.sent_tokenize(text) if len(sent) > 5]) 17 | 18 | 19 
| def get_single_file_in_folder(folder_path): 20 | # List all entries in the given folder 21 | entries = os.listdir(folder_path) 22 | 23 | # Filter out only the files (excluding directories and other types) 24 | files = [os.path.join(folder_path, f) for f in entries if os.path.isfile(os.path.join(folder_path, f))] 25 | 26 | # Check the number of files 27 | if len(files) == 1: 28 | return files[0] 29 | elif len(files) == 0: 30 | raise ValueError(f"No files found in the directory {folder_path}") 31 | else: 32 | raise ValueError(f"More than one file found in the directory {folder_path}. Files are: {', '.join(files)}") 33 | 34 | 35 | class SimilarityChecker: 36 | 37 | def __init__(self): 38 | self.model = SentenceTransformer('all-MiniLM-L6-v2') 39 | 40 | def preprocess_text(self, text: str) -> int: 41 | text = text.lower() 42 | word_count = Counter(text.split()) 43 | return word_count 44 | 45 | def jaccard_similarity(self, 46 | counter1: Counter, 47 | counter2: Counter) -> float: 48 | intersection = sum((counter1 & counter2).values()) 49 | union = sum((counter1 | counter2).values()) 50 | return intersection / union 51 | 52 | def calculate_max_similarity(self, 53 | sentence: List[str], 54 | reference_doc: str) -> float: 55 | similarities = [ 56 | self.jaccard_similarity(self.preprocess_text(main_sentence), self.preprocess_text(sentence)) 57 | for main_sentence in sent_tokenize(reference_doc) 58 | ] 59 | return max(similarities) 60 | 61 | def get_ngrams(self, 62 | text: str, 63 | n: int) -> List[str]: 64 | words = text.split() 65 | return [' '.join(words[i:i + n]) 66 | for i in range(len(words) - (n - 1))] 67 | 68 | def calculate_max_ngram_similarity(self, 69 | sentence: List[str], 70 | reference_doc: str, 71 | n: int) -> float: 72 | main_ngrams = self.get_ngrams(reference_doc, n) 73 | similarities = [ 74 | self.jaccard_similarity(self.preprocess_text(main_ngram), self.preprocess_text(sentence)) 75 | for main_ngram in main_ngrams 76 | ] 77 | return max(similarities, default=0) 78 | 79 | def calculate_embedding_similarity(self, 80 | sentence: List[str], 81 | mcq: MultiChoiceQuestion): 82 | main_text_embedding = self.model.encode([mcq.documentation]) 83 | sentence_embeddings = self.model.encode([sentence]) 84 | return cosine_similarity( 85 | [main_text_embedding[0]], 86 | [sentence_embeddings[0]] 87 | )[0][0] 88 | 89 | def compute_similarity(self, 90 | mcq: MultiChoiceQuestion) -> List[str]: 91 | mean_ngram = int(np.mean([len(answer.split()) for answer in mcq.choices])) 92 | return [(f"{self.calculate_max_similarity(answer, mcq.documentation):.02f}" 93 | f"{self.calculate_max_ngram_similarity(answer, mcq.documentation, mean_ngram):.02f}" 94 | f"{self.calculate_embedding_similarity(answer, mcq):.02f}") 95 | for answer in mcq.choices] 96 | -------------------------------------------------------------------------------- /auto-rag-eval/LLMServer/README.md: -------------------------------------------------------------------------------- 1 | # Large Language Models - Exam Generation 2 | 3 | ## Provided Models 4 | 5 | This folder contains the code for Large Language models to be used for exam Generation. So far, we provide the implementation for Bedrock Claude family of models and will be adding support for more models soon. 
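As a minimal sketch of how a provided model is used (assuming AWS credentials with Bedrock access are configured; the prompt text is purely illustrative), one can call `invoke` directly:

```[python]
from LLMServer.bedrock.claude_v2 import ClaudeV2

# Instantiate the Bedrock-backed Claude V2 wrapper and check its identifier.
model = ClaudeV2()
print(model.get_id())  # "ClaudeV2"

# Send a single prompt through Bedrock and retrieve the completion text.
answer = model.invoke(prompt="\n\nHuman: Summarize what BM25 retrieval does.\n\nAssistant:")
print(answer)
```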
6 | 7 | Moreover, in `llm_exam_generator.py`, we provide the wrapper on top of `BaseLLM` to get `LLMExamGenerator` class and `ClaudeExamGenerator` 8 | 9 | ## Custom LLM 10 | 11 | The only piece required to add you own LLM is to follow the BaseLLM class from `base_model.py`: 12 | 13 | ```[python] 14 | class BaseLLM: 15 | 16 | def invoke(self, 17 | prompt: str) -> str: 18 | 19 | pass 20 | 21 | def stream_inference(self, 22 | prompt: str) -> Generator[str, None, None]: 23 | 24 | pass 25 | 26 | def get_id(self) -> str: 27 | 28 | pass 29 | ``` 30 | -------------------------------------------------------------------------------- /auto-rag-eval/LLMServer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/auto-rag-eval/a25bc50e78790044ddc45874e0c9085a73f0262e/auto-rag-eval/LLMServer/__init__.py -------------------------------------------------------------------------------- /auto-rag-eval/LLMServer/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/auto-rag-eval/a25bc50e78790044ddc45874e0c9085a73f0262e/auto-rag-eval/LLMServer/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /auto-rag-eval/LLMServer/__pycache__/base_model.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/auto-rag-eval/a25bc50e78790044ddc45874e0c9085a73f0262e/auto-rag-eval/LLMServer/__pycache__/base_model.cpython-37.pyc -------------------------------------------------------------------------------- /auto-rag-eval/LLMServer/base_model.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Dict, Generator, Union 3 | 4 | 5 | class BaseLLM(ABC): 6 | 7 | @abstractmethod 8 | def invoke(self, 9 | prompt: str, 10 | params: Dict[str, Union[int, str]]) -> str: 11 | 12 | pass 13 | 14 | @abstractmethod 15 | def stream_inference(self, 16 | prompt: str, 17 | params: Dict[str, Union[int, str]]) -> Generator[str, None, None]: 18 | 19 | pass 20 | 21 | @abstractmethod 22 | def get_id(self) -> str: 23 | 24 | pass 25 | -------------------------------------------------------------------------------- /auto-rag-eval/LLMServer/bedrock/__pycache__/claude_instant.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/auto-rag-eval/a25bc50e78790044ddc45874e0c9085a73f0262e/auto-rag-eval/LLMServer/bedrock/__pycache__/claude_instant.cpython-37.pyc -------------------------------------------------------------------------------- /auto-rag-eval/LLMServer/bedrock/__pycache__/claude_v2.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/auto-rag-eval/a25bc50e78790044ddc45874e0c9085a73f0262e/auto-rag-eval/LLMServer/bedrock/__pycache__/claude_v2.cpython-37.pyc -------------------------------------------------------------------------------- /auto-rag-eval/LLMServer/bedrock/__pycache__/conversation.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/amazon-science/auto-rag-eval/a25bc50e78790044ddc45874e0c9085a73f0262e/auto-rag-eval/LLMServer/bedrock/__pycache__/conversation.cpython-37.pyc -------------------------------------------------------------------------------- /auto-rag-eval/LLMServer/bedrock/claude_instant.py: -------------------------------------------------------------------------------- 1 | import json 2 | import time 3 | from typing import Generator 4 | 5 | import boto3 6 | from botocore.config import Config 7 | from LLMServer.base_model import BaseLLM 8 | from tenacity import retry, stop_after_attempt, wait_exponential 9 | 10 | STOP_AFTER_ATTEMPT = 6 11 | WAIT_EXPONENTIAL_MIN = 4 12 | WAIT_EXPONENTIAL_MAX = 30 13 | 14 | 15 | def delayed_text_generator(text: str, delay: float = 0.2): 16 | tokens = text.split() 17 | for i in range(1, len(tokens) + 1): 18 | time.sleep(delay) 19 | yield ' '.join(tokens[:i]) 20 | 21 | 22 | class ClaudeInstant(BaseLLM): 23 | 24 | def __init__(self): 25 | self.bedrock = boto3.client( 26 | service_name='bedrock', 27 | config=Config(read_timeout=1000)) 28 | self.modelId = 'anthropic.claude-instant-v1' 29 | self.accept = 'application/json' 30 | self.contentType = 'application/json' 31 | self.inference_params = { 32 | # max_tokens_to_sample can be at most 4096 33 | "max_tokens_to_sample": 4096, 34 | "temperature": 0, 35 | "top_p": 0.9, 36 | } 37 | 38 | @retry( 39 | stop=stop_after_attempt(STOP_AFTER_ATTEMPT), 40 | wait=wait_exponential(min=WAIT_EXPONENTIAL_MIN, 41 | max=WAIT_EXPONENTIAL_MAX), 42 | ) 43 | def invoke(self, 44 | prompt: str) -> str: 45 | 46 | body = json.dumps({ 47 | "prompt": prompt, 48 | **self.inference_params 49 | }) 50 | 51 | response = self.bedrock.invoke_model(body=body, 52 | modelId=self.modelId, 53 | accept=self.accept, 54 | contentType=self.contentType) 55 | if response['ResponseMetadata']['HTTPStatusCode'] == 200: 56 | return json.loads( 57 | response.get('body').read()).get('completion') 58 | 59 | raise ValueError("Incorrect Generation") 60 | 61 | @retry( 62 | stop=stop_after_attempt(STOP_AFTER_ATTEMPT), 63 | wait=wait_exponential(min=WAIT_EXPONENTIAL_MIN, 64 | max=WAIT_EXPONENTIAL_MAX), 65 | ) 66 | def stream_inference(self, 67 | prompt: str) -> Generator[str, None, None]: 68 | 69 | body = json.dumps({ 70 | "prompt": prompt, 71 | **self.inference_params 72 | }) 73 | 74 | response = self.bedrock.invoke_model(body=body, 75 | modelId=self.modelId, 76 | accept=self.accept, 77 | contentType=self.contentType) 78 | if response['ResponseMetadata']['HTTPStatusCode'] == 200: 79 | return delayed_text_generator(json.loads( 80 | response.get('body').read()).get('completion')) 81 | 82 | raise ValueError("Incorrect Generation") 83 | 84 | def get_id(self): 85 | 86 | return "ClaudeV2:Instant" 87 | -------------------------------------------------------------------------------- /auto-rag-eval/LLMServer/bedrock/claude_v2.py: -------------------------------------------------------------------------------- 1 | import json 2 | import time 3 | from typing import Generator 4 | 5 | import boto3 6 | from botocore.config import Config 7 | from LLMServer.base_model import BaseLLM 8 | from tenacity import retry, stop_after_attempt, wait_exponential 9 | 10 | STOP_AFTER_ATTEMPT = 6 11 | WAIT_EXPONENTIAL_MIN = 4 12 | WAIT_EXPONENTIAL_MAX = 30 13 | 14 | 15 | def delayed_text_generator(text: str, delay: float = 0.2): 16 | tokens = text.split() 17 | for i in range(1, len(tokens) + 1): 18 | time.sleep(delay) 19 | yield ' '.join(tokens[:i]) 20 | 21 | 22 | class 
ClaudeV2(BaseLLM): 23 | 24 | def __init__(self): 25 | self.bedrock = boto3.client( 26 | service_name='bedrock', 27 | config=Config(read_timeout=1000)) 28 | self.modelId = 'anthropic.claude-v2' 29 | self.accept = 'application/json' 30 | self.contentType = 'application/json' 31 | self.inference_params = { 32 | # max_tokens_to_sample can be at most 4096 33 | "max_tokens_to_sample": 4096, 34 | "temperature": 0, 35 | "top_p": 0.9, 36 | } 37 | 38 | @retry( 39 | stop=stop_after_attempt(STOP_AFTER_ATTEMPT), 40 | wait=wait_exponential(min=WAIT_EXPONENTIAL_MIN, 41 | max=WAIT_EXPONENTIAL_MAX), 42 | ) 43 | def invoke(self, 44 | prompt: str) -> str: 45 | 46 | body = json.dumps({ 47 | "prompt": prompt, 48 | **self.inference_params 49 | }) 50 | 51 | response = self.bedrock.invoke_model(body=body, 52 | modelId=self.modelId, 53 | accept=self.accept, 54 | contentType=self.contentType) 55 | if response['ResponseMetadata']['HTTPStatusCode'] == 200: 56 | return json.loads( 57 | response.get('body').read()).get('completion') 58 | 59 | raise ValueError("Incorrect Generation") 60 | 61 | @retry( 62 | stop=stop_after_attempt(STOP_AFTER_ATTEMPT), 63 | wait=wait_exponential(min=WAIT_EXPONENTIAL_MIN, 64 | max=WAIT_EXPONENTIAL_MAX), 65 | ) 66 | def stream_inference(self, 67 | prompt: str) -> Generator[str, None, None]: 68 | 69 | body = json.dumps({ 70 | "prompt": prompt, 71 | **self.inference_params 72 | }) 73 | 74 | response = self.bedrock.invoke_model(body=body, 75 | modelId=self.modelId, 76 | accept=self.accept, 77 | contentType=self.contentType) 78 | if response['ResponseMetadata']['HTTPStatusCode'] == 200: 79 | return delayed_text_generator(json.loads( 80 | response.get('body').read()).get('completion')) 81 | 82 | raise ValueError("Incorrect Generation") 83 | 84 | def get_id(self) -> str: 85 | 86 | return "ClaudeV2" 87 | -------------------------------------------------------------------------------- /auto-rag-eval/LLMServer/bedrock/claude_v3.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | from typing import Dict, Generator, List, Literal 4 | 5 | import boto3 6 | from botocore.config import Config 7 | from LLMServer.base_model import BaseLLM 8 | from tenacity import retry, stop_after_attempt, wait_exponential 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | STOP_AFTER_ATTEMPT = 6 13 | WAIT_EXPONENTIAL_MIN = 4 14 | WAIT_EXPONENTIAL_MAX = 30 15 | 16 | AllowedModelNames = Literal['anthropic.claude-3-sonnet-20240229-v1:0', 17 | 'anthropic.claude-3-haiku-20240307-v1:0'] 18 | 19 | 20 | class ClaudeV3(BaseLLM): 21 | 22 | def __init__(self, 23 | model_name: AllowedModelNames): 24 | 25 | self.model_name = model_name 26 | self.inference_params = { 27 | # max_tokens_to_sample can be at most 4096 28 | "max_tokens_to_sample": 4096, 29 | "temperature": 0, 30 | "top_p": 0.9, 31 | } 32 | self.brt = boto3.client("bedrock-runtime", 33 | region_name="us-east-1", 34 | config=Config(read_timeout=1000) 35 | ) 36 | 37 | # To bypass the length limit, we set a maximum number of iterations 38 | self.max_n_iterations = 20 39 | 40 | @retry( 41 | stop=stop_after_attempt(STOP_AFTER_ATTEMPT), 42 | wait=wait_exponential(min=WAIT_EXPONENTIAL_MIN, 43 | max=WAIT_EXPONENTIAL_MAX), 44 | ) 45 | def _get_bedrock_response(self, query: str) -> Dict: 46 | 47 | messages = self.query_to_messages(query) 48 | request_body = { 49 | "anthropic_version": "bedrock-2023-05-31", 50 | "max_tokens": self.inference_params['max_tokens_to_sample'], 51 | "messages": messages, 52 | 
"temperature": self.inference_params['temperature'], 53 | } 54 | response = self.brt.invoke_model( 55 | modelId=self.model_name, 56 | body=json.dumps(request_body), 57 | ) 58 | result = json.loads(response.get('body').read()) 59 | return result 60 | 61 | def invoke(self, 62 | prompt: str) -> str: 63 | 64 | result = self._get_bedrock_response( 65 | query=prompt) 66 | 67 | return correct_generation( 68 | result["content"][0]["text"]) 69 | 70 | def invoke_with_prolongation(self, 71 | prompt: str) -> str: 72 | ''' 73 | Claude Bedrock strickly enforces a lenght of 4096 tokens. 74 | This method will try to prolong the output until the model 75 | reaches the desired output by feeding back the generated text 76 | into the prompt, at most self.max_n_iterations. 77 | ''' 78 | 79 | generated_text = "" 80 | generated_text_with_query = prompt + "\nAssistant: " 81 | 82 | for i in range(self.max_n_iterations): 83 | 84 | response = self._get_bedrock_response( 85 | query=generated_text_with_query) 86 | generated_chunk = correct_generation( 87 | response["content"][0]["text"]) 88 | 89 | # TODO: Ensure that strip usage is relevant 90 | generated_text += generated_chunk.strip() 91 | generated_text_with_query += generated_chunk.strip() 92 | 93 | logger.info(f"Prolongation of generation at round {i}") 94 | 95 | if not self._has_enforced_interruption(response=response): 96 | return generated_text 97 | 98 | raise ValueError("Failure to complete prompt: File is too big.") 99 | 100 | def _has_enforced_interruption(self, 101 | response: Dict) -> bool: 102 | ''' 103 | Detect if the generation ends because of end of text or 104 | forced interuptions. 105 | ''' 106 | if response.get("stop_reason", "") == "max_tokens": 107 | return True 108 | 109 | token_delta = 10 110 | # Check if the model has reached the maximum token limit 111 | if response.get("usage", {}).get('output_tokens', 0) >= self.inference_params['max_tokens_to_sample'] - token_delta: 112 | return True 113 | 114 | return False 115 | 116 | def query_to_messages(query: str) -> List[Dict[str, str]]: 117 | human_tag = "Human:" 118 | assistant_tag = "Assistant:" 119 | if query.startswith(human_tag): 120 | query = query[len(human_tag):] 121 | if assistant_tag in query: 122 | query, prefill = query.split(assistant_tag, 1) 123 | return [ 124 | {"role": "user", "content": query}, 125 | {"role": "assistant", "content": prefill.strip()}, 126 | ] 127 | return [{"role": "user", "content": query}] 128 | 129 | def stream_inference(self, 130 | prompt: str) -> Generator[str, None, None]: 131 | pass 132 | 133 | def get_id(self) -> str: 134 | return self.model_name 135 | 136 | 137 | def correct_generation(generated_text: str) -> str: 138 | 139 | return generated_text.replace('<', '<').replace('>', '>') 140 | -------------------------------------------------------------------------------- /auto-rag-eval/LLMServer/llm_exam_generator.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from os.path import abspath, dirname 3 | from typing import Dict, List 4 | 5 | from tqdm import tqdm 6 | from LLMServer.base_model import BaseLLM 7 | 8 | logger = logging.getLogger(__name__) 9 | ROOTPATH = dirname(dirname(abspath(__file__))) 10 | 11 | 12 | class LLMExamGenerator: 13 | 14 | def __init__(self, 15 | step_size: int, 16 | task_domain: str, 17 | llm_model: BaseLLM): 18 | 19 | # Step size is to mitigate when one model inference is faster than another 20 | # eg openllama:13b = 3* llamav2:70B 21 | self.step_size = step_size 22 | 
self.task_domain = task_domain 23 | self.llm_model = llm_model 24 | 25 | def make_question_prompt(self, documentation: str) -> str: 26 | # Adding the syntax constraint was done in V2 and appears to impact the formatting of the question. 27 | return (f"### Human: Here is some documentation from {self.task_domain}: {documentation}.\n" 28 | "From this generate a difficult multi-form question for an exam. It should have 4 candidates," 29 | " 1 correct answer and explanations. Syntax should be Question: {question}\nA){candidate A}\nB){candidate B}\n" 30 | "C){candidate C}\nD){candidate D} Correct Answer: {correct answer}\n### Assistant:") 31 | 32 | # def make_question_prompt_icl(self, example, documentation: str) -> str: 33 | # # icl = (f"### Human: Here is some documentation from {self.task_domain}: {example.documentation}.\n" 34 | # # f"From this generate a difficult multi-form question for an exam. It should have 4 candidates," 35 | # # " 1 correct answer and explanations.\n### Assistant:" 36 | # # "Question: {}\nCandidates: {}\n".format(example.question, '\n'.join(example.choices)) 37 | # # f"Correct Answer: {example.correct_answer}\n") 38 | # prompt = (f"### Human: Here is some documentation from {self.task_domain}: {documentation}.\n" 39 | # f"From this generate a difficult multi-form question for an exam. It should have 4 candidates," 40 | # " 1 correct answer and explanations.\n### Assistant:") 41 | # return f"{icl}\n{prompt}" 42 | 43 | def generate_exam(self, data: List[Dict[str, str]]) -> Dict[int, Dict[str, str]]: 44 | 45 | generated_questions = {} 46 | for k in tqdm(range(0, len(data), self.step_size)): 47 | answer = self.llm_model.invoke( 48 | prompt=self.make_question_prompt(data[k]['text']), 49 | params={}) 50 | generated_questions[k] = { 51 | "documentation": data[k], 52 | "answer": answer 53 | } 54 | return generated_questions 55 | 56 | 57 | class ClaudeExamGenerator(LLMExamGenerator): 58 | 59 | def __init__(self, 60 | step_size: int, 61 | task_domain: str, 62 | llm_model: BaseLLM): 63 | 64 | super().__init__(step_size=step_size, 65 | task_domain=task_domain, 66 | llm_model=llm_model) 67 | 68 | def make_question_prompt(self, documentation: str) -> str: 69 | return (f"\n\nHuman: Here is some documentation from {self.task_domain}: {documentation}.\n" 70 | "From this generate a difficult multi-form question for an exam. It should have 4 candidates," 71 | " 1 correct answer and explanations. Syntax should be Question: {question}\nA){candidate A}\nB){candidate B}\n" 72 | "C){candidate C}\nD){candidate D} Correct Answer: {correct answer}\n\nAssistant:") 73 | 74 | def generate_exam(self, data: List[Dict[str, str]]) -> Dict[int, Dict[str, str]]: 75 | 76 | generated_questions = {} 77 | for k in tqdm(range(0, len(data), self.step_size)): 78 | answer = self.llm_model.invoke( 79 | prompt=self.make_question_prompt(data[k]['text']), 80 | params={}) 81 | generated_questions[k] = { 82 | "documentation": data[k], 83 | "answer": answer 84 | } 85 | return generated_questions 86 | -------------------------------------------------------------------------------- /auto-rag-eval/RetrievalSystems/README.md: -------------------------------------------------------------------------------- 1 | # Retrieval Systems 2 | 3 | ## Provided Models 4 | 5 | This folder contains several classes of retrieval models, to be evaluated in combination with LLMs. 
Most notably, we provide the code for: 6 | 7 | * **Sparse Methods**: 8 |   * **BM25**: Classical BM25 retriever, implementation from [this repo](https://github.com/dorianbrown/rank_bm25/blob/master/rank_bm25.py) 9 | * **Dense Methods**: 10 |   * **SiameseContextProvider**: Siamese model, using `vblagoje/dpr-question_encoder-single-lfqa-base`, from Hugging Face. 11 |   * **EmbeddingContextProvider**: Classical embedding model, using `sentence-transformers/multi-qa-MiniLM-L6-cos-v1`, from Hugging Face. 12 |   * Moreover, we provide a general index implementation in `docs_faiss_index.py`. 13 | * **Hybrid Methods**: 14 |   * **DPRContextGenerator**: Cross-Encoder that aggregates base retrieval models, using `cross-encoder/ms-marco-MiniLM-L-6-v2`, from Hugging Face. 15 | 16 | One can use the file `test_retrieval_models.py` to test the implementation of the models. 17 | 18 | ## Custom Models 19 | 20 | To implement your own model, follow the convention from the abstract class `ContextProvider` in `context_utils.py`. 21 | 22 | ## Model Usage 23 | 24 | To leverage different models during inference, and potentially ensemble several of them, one just needs to follow the convention below: 25 | 26 | ```[python] 27 | context_generator_dict = { 28 |     'DPR': DPRContextGenerator(context_sources={ 29 |         'SIAMESE' : SiameseContextProvider(index_folder=f"{ROOTPATH}/Data/{main_args.task_domain}/RetrievalIndex/siamese_emb", 30 | data_folder=f"{ROOTPATH}/Data/{main_args.task_domain}/KnowledgeCorpus/main", 31 | regenerate_index=False), 32 | 'BM25': BM25ContextProvider(data_folder=f"{ROOTPATH}/Data/{main_args.task_domain}/KnowledgeCorpus/main") 33 | }), 34 | 'BM25' : BM25ContextProvider(data_folder=f"{ROOTPATH}/Data/{main_args.task_domain}/KnowledgeCorpus/main"), 35 | 'SIAMESE' : SiameseContextProvider(index_folder=f"{ROOTPATH}/Data/{main_args.task_domain}/RetrievalIndex/siamese_emb", 36 | data_folder=f"{ROOTPATH}/Data/{main_args.task_domain}/KnowledgeCorpus/main", 37 | regenerate_index=True), 38 | 'MultiQA' : EmbeddingContextProvider(index_folder=f"{ROOTPATH}/Data/{main_args.task_domain}/RetrievalIndex/multi_qa_emb", 39 | data_folder=f"{ROOTPATH}/Data/{main_args.task_domain}/KnowledgeCorpus/main", 40 | regenerate_index=True), 41 | 'DPR:MultiQA:BM25': DPRContextGenerator(context_sources={ 42 | 'MultiQA' : EmbeddingContextProvider(index_folder=f"{ROOTPATH}/Data/{main_args.task_domain}/RetrievalIndex/multi_qa_emb", 43 | data_folder=f"{ROOTPATH}/Data/{main_args.task_domain}/KnowledgeCorpus/main", 44 | regenerate_index=False), 45 | 'BM25': BM25ContextProvider(data_folder=f"{ROOTPATH}/Data/{main_args.task_domain}/KnowledgeCorpus/main") 46 | }), 47 | } 48 | ``` 49 | -------------------------------------------------------------------------------- /auto-rag-eval/RetrievalSystems/__pycache__/bm25.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/auto-rag-eval/a25bc50e78790044ddc45874e0c9085a73f0262e/auto-rag-eval/RetrievalSystems/__pycache__/bm25.cpython-37.pyc -------------------------------------------------------------------------------- /auto-rag-eval/RetrievalSystems/__pycache__/common.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/auto-rag-eval/a25bc50e78790044ddc45874e0c9085a73f0262e/auto-rag-eval/RetrievalSystems/__pycache__/common.cpython-37.pyc -------------------------------------------------------------------------------- 
/auto-rag-eval/RetrievalSystems/__pycache__/context_utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/auto-rag-eval/a25bc50e78790044ddc45874e0c9085a73f0262e/auto-rag-eval/RetrievalSystems/__pycache__/context_utils.cpython-37.pyc -------------------------------------------------------------------------------- /auto-rag-eval/RetrievalSystems/__pycache__/docs_faiss_index.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/auto-rag-eval/a25bc50e78790044ddc45874e0c9085a73f0262e/auto-rag-eval/RetrievalSystems/__pycache__/docs_faiss_index.cpython-37.pyc -------------------------------------------------------------------------------- /auto-rag-eval/RetrievalSystems/__pycache__/dpr_context_aggregator.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/auto-rag-eval/a25bc50e78790044ddc45874e0c9085a73f0262e/auto-rag-eval/RetrievalSystems/__pycache__/dpr_context_aggregator.cpython-37.pyc -------------------------------------------------------------------------------- /auto-rag-eval/RetrievalSystems/__pycache__/dpr_context_retriever.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/auto-rag-eval/a25bc50e78790044ddc45874e0c9085a73f0262e/auto-rag-eval/RetrievalSystems/__pycache__/dpr_context_retriever.cpython-37.pyc -------------------------------------------------------------------------------- /auto-rag-eval/RetrievalSystems/__pycache__/embedding_retriever.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/auto-rag-eval/a25bc50e78790044ddc45874e0c9085a73f0262e/auto-rag-eval/RetrievalSystems/__pycache__/embedding_retriever.cpython-37.pyc -------------------------------------------------------------------------------- /auto-rag-eval/RetrievalSystems/__pycache__/siamese_retriever.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/auto-rag-eval/a25bc50e78790044ddc45874e0c9085a73f0262e/auto-rag-eval/RetrievalSystems/__pycache__/siamese_retriever.cpython-37.pyc -------------------------------------------------------------------------------- /auto-rag-eval/RetrievalSystems/bm25.py: -------------------------------------------------------------------------------- 1 | # From https://github.com/dorianbrown/rank_bm25/blob/master/rank_bm25.py 2 | 3 | import json 4 | import math 5 | from multiprocessing import Pool, cpu_count 6 | from typing import Dict, List 7 | 8 | import numpy as np 9 | from RetrievalSystems.context_utils import (ContextPassage, ContextProvider, 10 | get_single_file_in_folder) 11 | 12 | """ 13 | All of these algorithms have been taken from the paper: 14 | Trotmam et al, Improvements to BM25 and Language Models Examined 15 | 16 | Here we implement all the BM25 variations mentioned. 
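For reference, the Okapi BM25 variant below (see BM25Okapi.get_scores) scores a document D against a query Q as

    score(D, Q) = sum over q in Q of IDF(q) * f(q, D) * (k1 + 1) / (f(q, D) + k1 * (1 - b + b * |D| / avgdl))

where f(q, D) is the frequency of term q in D, |D| the document length, avgdl the average document length of the corpus, and k1, b are free parameters (1.5 and 0.75 by default). BM25L and BM25Plus modify this formula as described in the paper above.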
17 | """ 18 | 19 | 20 | class BM25: 21 | def __init__(self, corpus, tokenizer=None): 22 | self.corpus_size = 0 23 | self.avgdl = 0 24 | self.doc_freqs = [] 25 | self.idf = {} 26 | self.doc_len = [] 27 | self.tokenizer = tokenizer 28 | 29 | if tokenizer: 30 | corpus = self._tokenize_corpus(corpus) 31 | 32 | nd = self._initialize(corpus) 33 | self._calc_idf(nd) 34 | 35 | def _initialize(self, corpus): 36 | nd = {} # word -> number of documents with word 37 | num_doc = 0 38 | for document in corpus: 39 | self.doc_len.append(len(document)) 40 | num_doc += len(document) 41 | 42 | frequencies = {} 43 | for word in document: 44 | if word not in frequencies: 45 | frequencies[word] = 0 46 | frequencies[word] += 1 47 | self.doc_freqs.append(frequencies) 48 | 49 | for word, freq in frequencies.items(): 50 | try: 51 | nd[word] += 1 52 | except KeyError: 53 | nd[word] = 1 54 | 55 | self.corpus_size += 1 56 | 57 | self.avgdl = num_doc / self.corpus_size 58 | return nd 59 | 60 | def _tokenize_corpus(self, corpus): 61 | pool = Pool(cpu_count()) 62 | tokenized_corpus = pool.map(self.tokenizer, corpus) 63 | return tokenized_corpus 64 | 65 | def _calc_idf(self, nd): 66 | raise NotImplementedError() 67 | 68 | def get_scores(self, query): 69 | raise NotImplementedError() 70 | 71 | def get_batch_scores(self, query, doc_ids): 72 | raise NotImplementedError() 73 | 74 | def get_top_n(self, query, documents, n=5): 75 | 76 | assert self.corpus_size == len(documents), "The documents given don't match the index corpus!" 77 | 78 | scores = self.get_scores(query) 79 | top_n = np.argsort(scores)[::-1][:n] 80 | return [ContextPassage(**documents[i]) for i in top_n] 81 | 82 | 83 | class BM25Okapi(BM25): 84 | def __init__(self, corpus, tokenizer=None, k1=1.5, b=0.75, epsilon=0.25): 85 | self.k1 = k1 86 | self.b = b 87 | self.epsilon = epsilon 88 | super().__init__(corpus, tokenizer) 89 | 90 | def _calc_idf(self, nd): 91 | """ 92 | Calculates frequencies of terms in documents and in corpus. 93 | This algorithm sets a floor on the idf values to eps * average_idf 94 | """ 95 | # collect idf sum to calculate an average idf for epsilon value 96 | idf_sum = 0 97 | # collect words with negative idf to set them a special epsilon value. 98 | # idf can be negative if word is contained in more than half of documents 99 | negative_idfs = [] 100 | for word, freq in nd.items(): 101 | idf = math.log(self.corpus_size - freq + 0.5) - math.log(freq + 0.5) 102 | self.idf[word] = idf 103 | idf_sum += idf 104 | if idf < 0: 105 | negative_idfs.append(word) 106 | self.average_idf = idf_sum / len(self.idf) 107 | 108 | eps = self.epsilon * self.average_idf 109 | for word in negative_idfs: 110 | self.idf[word] = eps 111 | 112 | def get_scores(self, query): 113 | """ 114 | The ATIRE BM25 variant uses an idf function which uses a log(idf) score. To prevent negative idf scores, 115 | this algorithm also adds a floor to the idf value of epsilon. 116 | See [Trotman, A., X. Jia, M. 
Crane, Towards an Efficient and Effective Search Engine] for more info 117 | :param query: 118 | :return: 119 | """ 120 | score = np.zeros(self.corpus_size) 121 | doc_len = np.array(self.doc_len) 122 | for q in query: 123 | q_freq = np.array([(doc.get(q) or 0) for doc in self.doc_freqs]) 124 | score += (self.idf.get(q) or 0) * (q_freq * (self.k1 + 1) 125 | / (q_freq + self.k1 * (1 - self.b + self.b * doc_len / self.avgdl))) 126 | return score 127 | 128 | def get_batch_scores(self, query, doc_ids): 129 | """ 130 | Calculate bm25 scores between query and subset of all docs 131 | """ 132 | assert all(di < len(self.doc_freqs) for di in doc_ids) 133 | score = np.zeros(len(doc_ids)) 134 | doc_len = np.array(self.doc_len)[doc_ids] 135 | for q in query: 136 | q_freq = np.array([(self.doc_freqs[di].get(q) or 0) for di in doc_ids]) 137 | score += (self.idf.get(q) or 0) * (q_freq * (self.k1 + 1) 138 | / (q_freq + self.k1 * (1 - self.b + self.b * doc_len / self.avgdl))) 139 | return score.tolist() 140 | 141 | 142 | class BM25L(BM25): 143 | def __init__(self, corpus, tokenizer=None, k1=1.5, b=0.75, delta=0.5): 144 | # Algorithm specific parameters 145 | self.k1 = k1 146 | self.b = b 147 | self.delta = delta 148 | super().__init__(corpus, tokenizer) 149 | 150 | def _calc_idf(self, nd): 151 | for word, freq in nd.items(): 152 | idf = math.log(self.corpus_size + 1) - math.log(freq + 0.5) 153 | self.idf[word] = idf 154 | 155 | def get_scores(self, query): 156 | score = np.zeros(self.corpus_size) 157 | doc_len = np.array(self.doc_len) 158 | for q in query: 159 | q_freq = np.array([(doc.get(q) or 0) for doc in self.doc_freqs]) 160 | ctd = q_freq / (1 - self.b + self.b * doc_len / self.avgdl) 161 | score += (self.idf.get(q) or 0) * (self.k1 + 1) * (ctd + self.delta) / \ 162 | (self.k1 + ctd + self.delta) 163 | return score 164 | 165 | def get_batch_scores(self, query, doc_ids): 166 | """ 167 | Calculate bm25 scores between query and subset of all docs 168 | """ 169 | assert all(di < len(self.doc_freqs) for di in doc_ids) 170 | score = np.zeros(len(doc_ids)) 171 | doc_len = np.array(self.doc_len)[doc_ids] 172 | for q in query: 173 | q_freq = np.array([(self.doc_freqs[di].get(q) or 0) for di in doc_ids]) 174 | ctd = q_freq / (1 - self.b + self.b * doc_len / self.avgdl) 175 | score += (self.idf.get(q) or 0) * (self.k1 + 1) * (ctd + self.delta) / \ 176 | (self.k1 + ctd + self.delta) 177 | return score.tolist() 178 | 179 | 180 | class BM25Plus(BM25): 181 | def __init__(self, corpus, tokenizer=None, k1=1.5, b=0.75, delta=1): 182 | # Algorithm specific parameters 183 | self.k1 = k1 184 | self.b = b 185 | self.delta = delta 186 | super().__init__(corpus, tokenizer) 187 | 188 | def _calc_idf(self, nd): 189 | for word, freq in nd.items(): 190 | idf = math.log((self.corpus_size + 1) / freq) 191 | self.idf[word] = idf 192 | 193 | def get_scores(self, query): 194 | score = np.zeros(self.corpus_size) 195 | doc_len = np.array(self.doc_len) 196 | for q in query: 197 | q_freq = np.array([(doc.get(q) or 0) for doc in self.doc_freqs]) 198 | score += (self.idf.get(q) or 0) * (self.delta + (q_freq * (self.k1 + 1)) 199 | / (self.k1 * (1 - self.b + self.b * doc_len / self.avgdl) + q_freq)) 200 | return score 201 | 202 | def get_batch_scores(self, query, doc_ids): 203 | """ 204 | Calculate bm25 scores between query and subset of all docs 205 | """ 206 | assert all(di < len(self.doc_freqs) for di in doc_ids) 207 | score = np.zeros(len(doc_ids)) 208 | doc_len = np.array(self.doc_len)[doc_ids] 209 | for q in query: 210 | q_freq = 
np.array([(self.doc_freqs[di].get(q) or 0) for di in doc_ids]) 211 | score += (self.idf.get(q) or 0) * (self.delta + (q_freq * (self.k1 + 1)) 212 | / (self.k1 * (1 - self.b + self.b * doc_len / self.avgdl) + q_freq)) 213 | return score.tolist() 214 | 215 | 216 | class BM25ContextProvider(ContextProvider): 217 | 218 | def __init__(self, 219 | data_folder: str, 220 | bm25algo: BM25 = BM25Okapi, 221 | top_k_results: int = 3): 222 | 223 | with open(get_single_file_in_folder(data_folder), "r") as f: 224 | self.corpus = json.load(f) 225 | 226 | self.bm25 = bm25algo([self.tokenizer(doc['text']) 227 | for doc in self.corpus]) 228 | self.top_k_results = top_k_results 229 | 230 | def get_context_from_query(self, 231 | query: str) -> List[Dict[str, str]]: 232 | 233 | tokenized_query = self.tokenizer(query) 234 | 235 | return self.bm25.get_top_n(tokenized_query, 236 | self.corpus, 237 | n=self.top_k_results) 238 | 239 | def tokenizer(self, text: str) -> str: 240 | 241 | return text.split(" ") 242 | -------------------------------------------------------------------------------- /auto-rag-eval/RetrievalSystems/context_utils.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import inspect 3 | import os 4 | import re 5 | from typing import Dict, List, Union 6 | 7 | from pydantic import BaseModel 8 | 9 | 10 | def string_to_date(string: str): 11 | year, month, day = map(int, string.split("-")) 12 | return datetime.date(year, month, day) 13 | 14 | 15 | def filter_args(func, args_dict): 16 | sig = inspect.signature(func) 17 | return {k: v for k, v in args_dict.items() if k in sig.parameters} 18 | 19 | 20 | class ContextPassage(BaseModel): 21 | source: Union[str, List[str]] 22 | docs_id: str 23 | title: str 24 | section: Union[str, List[str]] 25 | text: str 26 | start_character: Union[str, int] 27 | end_character: Union[str, int] 28 | date: str 29 | answer_similarity: float = 0 30 | 31 | 32 | class ConstraintException(Exception): 33 | pass 34 | 35 | 36 | class ContextProvider: 37 | 38 | def get_context_from_query(self, 39 | query: str, 40 | params: Dict[str, Union[int, str]] = {}) -> List[ContextPassage]: 41 | 42 | pass 43 | 44 | 45 | def get_single_file_in_folder(folder_path): 46 | # List all entries in the given folder 47 | entries = os.listdir(folder_path) 48 | 49 | # Filter out only the files (excluding directories and other types) 50 | files = [os.path.join(folder_path, f) for f in entries if os.path.isfile(os.path.join(folder_path, f))] 51 | 52 | # Check the number of files 53 | if len(files) == 1: 54 | return files[0] 55 | elif len(files) == 0: 56 | raise ValueError(f"No files found in the directory {folder_path}") 57 | else: 58 | raise ValueError(f"More than one file found in the directory {folder_path}. Files are: {', '.join(files)}") 59 | 60 | 61 | def clean_question(text): 62 | result = cleanup_references(text) 63 | result = result.replace("\n", " ") 64 | result = re.sub(r"\s\s+", " ", result) 65 | result = result.replace("[deleted]", "") 66 | return result.lower().strip() 67 | 68 | 69 | def cleanup_references(text): 70 | # URL reference where we need to remove both the link text and URL 71 | # ...and this letter is used by most biographers as the cornerstone of Lee's personal 72 | # views on slavery ([1](_URL_2_ & pg=PA173), [2](_URL_1_), [3](_URL_5_)). 73 | # ...and this letter is used by most biographers as the cornerstone of Lee's personal views on slavery. 
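# The pattern below strips the whole reference: any leading '(' or whitespace, a bracketed citation number such as [1], its (URL) target, and any trailing commas or closing parentheses.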
74 | result = re.sub(r"[\(\s]*\[\d+\]\([^)]+\)[,)]*", "", text, 0, re.MULTILINE) 75 | 76 | # URL reference where we need to preserve link text but remove URL 77 | # At the outbreak of the Civil War, [Leyburn left his church](_URL_19_) and joined the South. 78 | # At the outbreak of the Civil War, Leyburn left his church and joined the South. 79 | result = re.sub(r"\[([^]]+)\]\([^)]+\)", "\\1", result, 0, re.MULTILINE) 80 | 81 | # lastly remove just dangling _URL_[0-9]_ URL references 82 | result = re.sub(r"_URL_\d_", "", result, 0, re.MULTILINE) 83 | return result 84 | 85 | 86 | def clean_answer(text): 87 | result = cleanup_references(text) 88 | result = result.replace("\n", " ") 89 | result = re.sub(r"\s\s+", " ", result) 90 | result = re.sub(r"BULLET::::-", "", result) 91 | return trim(result.strip()) 92 | 93 | 94 | def trim(text, word_count: int = 100): 95 | return " ".join(text.split(" ")[:word_count]) 96 | -------------------------------------------------------------------------------- /auto-rag-eval/RetrievalSystems/docs_faiss_index.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import os 4 | from os.path import abspath, dirname 5 | from typing import Dict 6 | 7 | import faiss 8 | import numpy as np 9 | import torch 10 | from datasets import load_dataset 11 | from sentence_transformers import SentenceTransformer 12 | from transformers import AutoTokenizer, DPRContextEncoder 13 | 14 | ROOTPATH = dirname(dirname(abspath(__file__))) 15 | 16 | logger = logging.getLogger(__name__) 17 | 18 | 19 | class FaissIndex: 20 | 21 | def embed_passages_for_retrieval(self, passages: Dict[str, str]) -> Dict[str, np.array]: 22 | 23 | pass 24 | 25 | def create_faiss(self, 26 | data_folder: str, 27 | index_folder: str) -> None: 28 | 29 | index_file_name = f"{index_folder}/kilt_dpr_data.faiss" 30 | cache_file_name = f"{index_folder}/data_kilt_embedded.arrow" 31 | 32 | docs_data = load_dataset(data_folder, 33 | split="train", 34 | # field="data" # To be removed for BH data template, which differs from others 35 | ) 36 | 37 | if os.path.isfile(index_file_name): 38 | logger.error(f"Deleting existing Faiss index: {index_file_name}") 39 | os.remove(index_file_name) 40 | if os.path.isfile(cache_file_name): 41 | logger.error(f"Deleting existing Faiss index cache {cache_file_name}") 42 | os.remove(cache_file_name) 43 | 44 | # TODO: asssert set(self.docs_data_columns.features.keys()) == set(self.docs_data_columns) 45 | 46 | paragraphs_embeddings = docs_data.map(self.embed_passages_for_retrieval, 47 | remove_columns=self.docs_data_columns, 48 | batched=True, 49 | batch_size=512, 50 | cache_file_name=cache_file_name, 51 | desc="Creating faiss index") 52 | 53 | # Faiss implementation of HNSW for fast approximate nearest neighbor search 54 | # custom_index = faiss.IndexHNSWFlat(dims, 128, faiss.METRIC_INNER_PRODUCT) 55 | # custom_index = faiss.IndexFlatIP(dims) 56 | # custom_index = faiss.index_cpu_to_all_gpus(custom_index) 57 | 58 | paragraphs_embeddings.add_faiss_index( 59 | column="embeddings", 60 | custom_index=faiss.IndexFlatIP(self.dims)) 61 | paragraphs_embeddings.save_faiss_index( 62 | "embeddings", index_file_name) 63 | logger.error("Faiss index successfully created") 64 | 65 | 66 | class DocFaissIndex(FaissIndex): 67 | 68 | def __init__(self, 69 | ctx_encoder_name: str = "vblagoje/dpr-ctx_encoder-single-lfqa-base"): 70 | 71 | self.dims = 128 72 | self.device = ("cuda" if torch.cuda.is_available() else "cpu") 73 | 
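# Load the DPR context encoder and its tokenizer; the encoder is moved to GPU when available and put in eval() mode so dropout is disabled while embedding passages.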
self.ctx_tokenizer = AutoTokenizer.from_pretrained( 74 | ctx_encoder_name) 75 | self.ctx_model = DPRContextEncoder.from_pretrained( 76 | ctx_encoder_name).to(self.device) 77 | _ = self.ctx_model.eval() 78 | 79 | self.docs_data_columns = ['source', 80 | 'docs_id', 81 | 'title', 82 | 'section', 83 | 'text', 84 | 'start_character', 85 | 'end_character', 86 | 'date'] 87 | 88 | def embed_passages_for_retrieval(self, 89 | passages: Dict[str, str]): 90 | p = self.ctx_tokenizer(passages["text"], 91 | max_length=128, 92 | padding="max_length", 93 | truncation=True, 94 | return_tensors="pt") 95 | with torch.no_grad(): 96 | a_reps = self.ctx_model(p["input_ids"].to("cuda:0"), 97 | p["attention_mask"].to("cuda:0")).pooler_output 98 | 99 | return {"embeddings": a_reps.cpu().numpy()} 100 | 101 | 102 | class EmbedFaissIndex(FaissIndex): 103 | 104 | def __init__(self, 105 | model_name: str = "sentence-transformers/multi-qa-MiniLM-L6-cos-v1"): 106 | 107 | self.dims = 384 108 | self.device = ("cuda" if torch.cuda.is_available() else "cpu") 109 | self.model = SentenceTransformer(model_name) 110 | 111 | self.docs_data_columns = ['source', 112 | 'docs_id', 113 | 'title', 114 | 'section', 115 | 'text', 116 | 'start_character', 117 | 'end_character', 118 | 'date'] 119 | 120 | def embed_passages_for_retrieval(self, examples): 121 | return {"embeddings": self.model.encode(examples['text'])} 122 | 123 | 124 | if __name__ == "__main__": 125 | parser = argparse.ArgumentParser( 126 | description="Creates Faiss Docs index file") 127 | 128 | parser.add_argument( 129 | "--task-domain", 130 | help="Task Domain, among DevOps, StackExchange...", 131 | ) 132 | 133 | main_args, _ = parser.parse_known_args() 134 | 135 | faiss_index = DocFaissIndex() 136 | faiss_index.create_faiss(data_folder=f"{ROOTPATH}/Data/{main_args.task_domain}/KnowledgeCorpus/main", 137 | index_folder=f"{ROOTPATH}/Data/{main_args.task_domain}/RetrievalIndex") 138 | -------------------------------------------------------------------------------- /auto-rag-eval/RetrievalSystems/dpr_context_aggregator.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from os.path import abspath, dirname 3 | from typing import Dict, List, Union 4 | 5 | import numpy as np 6 | from RetrievalSystems.context_utils import ( 7 | ContextPassage, 8 | ContextProvider, 9 | SearchConstraint, 10 | filter_args, 11 | ) 12 | from sentence_transformers import CrossEncoder 13 | 14 | logger = logging.getLogger(__name__) 15 | 16 | ROOTPATH = dirname(dirname(abspath(__file__))) 17 | 18 | 19 | class DPRContextGenerator(ContextProvider): 20 | 21 | # Here the SearchAggregator Works over a single ContextProvider class 22 | TOPK_CROSSENCODER = 3 23 | 24 | def __init__(self, 25 | context_sources: Dict[str, ContextProvider]): 26 | 27 | self.context_sources = context_sources 28 | self.crossencoder = CrossEncoder( 29 | "cross-encoder/ms-marco-MiniLM-L-6-v2") 30 | self.search_constraints = SearchConstraint() 31 | 32 | def get_matching_context(self, 33 | query: str) -> List[ContextPassage]: 34 | 35 | context_passages = [] 36 | 37 | # TODO: Run in parallel 38 | for context_provider_id, context_provider in self.context_sources.items(): 39 | 40 | try: 41 | 42 | context_passages.extend( 43 | context_provider.get_context_from_query(query=query)) 44 | # logger.info( 45 | # f'{context_provider_id} Context successfully extracted for query "{query}"') 46 | 47 | except Exception as e: 48 | logger.error( 49 | f'Failure to extract {context_provider_id} context 
for query "{query}": {e}') 50 | 51 | return context_passages 52 | 53 | def get_ranked_context(self, 54 | query: str, 55 | context_passages: List[ContextPassage], 56 | topk_crossencoder: int = TOPK_CROSSENCODER) -> List[ContextPassage]: 57 | 58 | question_passage_combinations = [ 59 | [query, p.text] for p in context_passages] 60 | 61 | # Compute the similarity scores for these combinations 62 | similarity_scores = self.crossencoder.predict( 63 | question_passage_combinations) 64 | 65 | # Sort the scores in decreasing order 66 | sim_ranking_idx = np.flip(np.argsort(similarity_scores)) 67 | 68 | return [context_passages[rank_idx] 69 | for rank_idx in sim_ranking_idx[:topk_crossencoder]] 70 | 71 | def get_context_from_query(self, 72 | query: str, 73 | params: Dict[str, Union[int, str]] = {}) -> List[ContextPassage]: 74 | 75 | preprocessed_query = query.replace('\n', ' ') 76 | context_passages = self.get_matching_context( 77 | query=preprocessed_query) 78 | 79 | ranked_context_passages = self.get_ranked_context( 80 | query=preprocessed_query, 81 | context_passages=context_passages, 82 | **filter_args(func=self.get_ranked_context, 83 | args_dict=params)) 84 | 85 | return ranked_context_passages 86 | -------------------------------------------------------------------------------- /auto-rag-eval/RetrievalSystems/embedding_retriever.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | from typing import List 3 | 4 | import numpy as np 5 | from datasets import load_dataset 6 | from RetrievalSystems.context_utils import ContextPassage, ContextProvider 7 | from RetrievalSystems.docs_faiss_index import EmbedFaissIndex 8 | from sentence_transformers import SentenceTransformer 9 | 10 | 11 | class EmbeddingContextProvider(ContextProvider): 12 | def __init__(self, 13 | index_folder: str, 14 | data_folder: str, 15 | regenerate_index: bool = True): 16 | """ 17 | index_folder := f"{ROOTPATH}/Data/DevOps/RetrievalIndex/multi_qa_emb" 18 | data_folder := f"{ROOTPATH}/Data/DevOps/KnowledgeCorpus/main" 19 | """ 20 | self.model = SentenceTransformer('sentence-transformers/multi-qa-MiniLM-L6-cos-v1') 21 | 22 | self.topk_embeddings = 20 23 | self.min_snippet_length = 20 24 | 25 | self.docs_data = load_dataset(data_folder, 26 | split="train", 27 | # field="data" Old artifact from BH Template 28 | ) 29 | # Generate a new index each time to avoid using an incorrect one 30 | if regenerate_index or not os.path.isfile(f"{index_folder}/kilt_dpr_data.faiss"): 31 | faiss_index = EmbedFaissIndex() 32 | faiss_index.create_faiss(data_folder=data_folder, 33 | index_folder=index_folder) 34 | self.docs_data.load_faiss_index("embeddings", 35 | f"{index_folder}/kilt_dpr_data.faiss") 36 | self.columns = ['source', 'docs_id', 'title', 'section', 37 | 'text', 'start_character', 'end_character', 'date'] 38 | 39 | def embed_questions_for_retrieval(self, 40 | question: str) -> np.array: 41 | return self.model.encode(question) 42 | 43 | def query_index(self, 44 | query: str) -> List[ContextPassage]: 45 | question_embedding = self.embed_questions_for_retrieval([query]) 46 | a, docs_passages = self.docs_data.get_nearest_examples( 47 | "embeddings", question_embedding, k=self.topk_embeddings) 48 | retrieved_examples = [] 49 | r = list(zip(docs_passages[k] for k in self.columns)) 50 | for i in range(self.topk_embeddings): 51 | retrieved_examples.append(ContextPassage(**{k: v for k, v in zip( 52 | self.columns, [r[j][0][i] for j in range(len(self.columns))])})) 53 | return 
retrieved_examples 54 | 55 | def get_context_from_query(self, 56 | query: str) -> List[ContextPassage]: 57 | 58 | context_passages = [res for res in self.query_index(query=query) 59 | if len(res.text.split()) > self.min_snippet_length][:int(self.topk_embeddings / 3)] 60 | 61 | return context_passages 62 | 63 | def get_id(self) -> str: 64 | 65 | return "MultiQAEmbContextProvider" 66 | -------------------------------------------------------------------------------- /auto-rag-eval/RetrievalSystems/siamese_retriever.py: -------------------------------------------------------------------------------- 1 | 2 | import os.path 3 | from typing import List 4 | 5 | import numpy as np 6 | import torch 7 | from datasets import load_dataset 8 | from RetrievalSystems.context_utils import ContextPassage, ContextProvider 9 | from RetrievalSystems.docs_faiss_index import DocFaissIndex 10 | from transformers import AutoTokenizer, DPRQuestionEncoder 11 | 12 | 13 | class SiameseContextProvider(ContextProvider): 14 | 15 | def __init__(self, 16 | index_folder: str, 17 | data_folder: str, 18 | regenerate_index: bool = True): 19 | """ 20 | index_folder := f"{ROOTPATH}/Data/DevOps/RetrievalIndex" 21 | data_folder := f"{ROOTPATH}/Data/DevOps/KnowledgeCorpus/main" 22 | """ 23 | self.device = ("cuda" if torch.cuda.is_available() else "cpu") 24 | self.model = DPRQuestionEncoder.from_pretrained( 25 | "vblagoje/dpr-question_encoder-single-lfqa-base").to(self.device) 26 | self.tokenizer = AutoTokenizer.from_pretrained( 27 | "vblagoje/dpr-question_encoder-single-lfqa-base") 28 | _ = self.model.eval() 29 | 30 | self.topk_embeddings = 20 31 | self.min_snippet_length = 20 32 | 33 | self.docs_data = load_dataset(data_folder, 34 | split="train", 35 | # field="data" Old artifact from BH Template 36 | ) 37 | # Generate a new index each time to avoid using an incorrect one 38 | if regenerate_index or not os.path.isfile(f"{index_folder}/kilt_dpr_data.faiss"): 39 | faiss_index = DocFaissIndex() 40 | faiss_index.create_faiss(data_folder=data_folder, 41 | index_folder=index_folder) 42 | self.docs_data.load_faiss_index("embeddings", 43 | f"{index_folder}/kilt_dpr_data.faiss") 44 | self.columns = ['source', 'docs_id', 'title', 'section', 45 | 'text', 'start_character', 'end_character', 'date'] 46 | 47 | def embed_questions_for_retrieval(self, 48 | questions: List[str]) -> np.array: 49 | query = self.tokenizer(questions, max_length=128, padding=True, 50 | truncation=True, return_tensors="pt") 51 | with torch.no_grad(): 52 | q_reps = self.model(query["input_ids"].to(self.device), 53 | query["attention_mask"].to(self.device)).pooler_output 54 | return q_reps.cpu().numpy() 55 | 56 | def query_index(self, 57 | query: str) -> List[ContextPassage]: 58 | question_embedding = self.embed_questions_for_retrieval([query]) 59 | a, docs_passages = self.docs_data.get_nearest_examples( 60 | "embeddings", question_embedding, k=self.topk_embeddings) 61 | retrieved_examples = [] 62 | r = list(zip(docs_passages[k] for k in self.columns)) 63 | for i in range(self.topk_embeddings): 64 | retrieved_examples.append(ContextPassage(**{k: v for k, v in zip( 65 | self.columns, [r[j][0][i] for j in range(len(self.columns))])})) 66 | return retrieved_examples 67 | 68 | def get_context_from_query(self, 69 | query: str) -> List[ContextPassage]: 70 | 71 | context_passages = [res for res in self.query_index(query=query) 72 | if len(res.text.split()) > self.min_snippet_length][:int(self.topk_embeddings / 3)] 73 | 74 | return context_passages 75 | 76 | def get_id(self) -> 
str: 77 | 78 | return "SiameseContextProvider" 79 | -------------------------------------------------------------------------------- /auto-rag-eval/RetrievalSystems/test_retrieval_models.py: -------------------------------------------------------------------------------- 1 | 2 | from os.path import abspath, dirname 3 | 4 | from RetrievalSystems.bm25 import BM25ContextProvider, BM25Okapi 5 | from RetrievalSystems.dpr_context_aggregator import DPRContextGenerator 6 | from RetrievalSystems.embedding_retriever import EmbeddingContextProvider 7 | from RetrievalSystems.siamese_retriever import SiameseContextProvider 8 | 9 | ROOTPATH = dirname(dirname(abspath(__file__))) 10 | 11 | if __name__ == "__main__": 12 | 13 | bm25_context_provider = BM25ContextProvider( 14 | data_folder=f"{ROOTPATH}/Data/DevOps/KnowledgeCorpus/main", 15 | bm25algo=BM25Okapi) 16 | 17 | # query = "How to connect an EC2 instance to an s3 bucket ?" 18 | query = "Which of the following is a valid method for verifying the Availability Zone mapping on an AWS account?" 19 | 20 | print(bm25_context_provider.get_context_from_query(query)) 21 | 22 | emb_context_generator = EmbeddingContextProvider( 23 | index_folder=f"{ROOTPATH}/Data/DevOps/RetrievalIndex/multi_qa_emb", 24 | data_folder=f"{ROOTPATH}/Data/DevOps/KnowledgeCorpus/main", 25 | regenerate_index=True) 26 | 27 | print(emb_context_generator.get_context_from_query("How to terminate an EC2 instance ?")) 28 | 29 | # Testing the Siamese Context Generator 30 | # --- 31 | siamese_context_generator = SiameseContextProvider( 32 | index_folder=f"{ROOTPATH}/Data/DevOps/RetrievalIndex/siamese_emb", 33 | data_folder=f"{ROOTPATH}/Data/DevOps/KnowledgeCorpus/main", 34 | regenerate_index=False) 35 | 36 | print(siamese_context_generator.get_context_from_query("How to terminate an EC2 instance ?")) 37 | 38 | # Testing the DPR Context Generator 39 | # --- 40 | dpr_context_generator = DPRContextGenerator( 41 | index_folder=f"{ROOTPATH}/Data/DevOps/RetrievalIndex/siamese_emb", 42 | data_folder=f"{ROOTPATH}/Data/DevOps/KnowledgeCorpus/main") 43 | 44 | print(dpr_context_generator.get_context_from_query("How to terminate an EC2 instance ?")) 45 | -------------------------------------------------------------------------------- /auto-rag-eval/__init__.py: -------------------------------------------------------------------------------- 1 | # Implement your code here. 2 | -------------------------------------------------------------------------------- /auto-rag-eval/py.typed: -------------------------------------------------------------------------------- 1 | # Marker file that indicates this package supports typing 2 | -------------------------------------------------------------------------------- /images/generation_summary.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/auto-rag-eval/a25bc50e78790044ddc45874e0c9085a73f0262e/images/generation_summary.png -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | line-length = 100 3 | 4 | [tool.isort] 5 | known_first_party = ["llm_automated_exam_evaluation"] 6 | 7 | # required for compatibility with black: 8 | profile = "black" 9 | 10 | # To maintain consistency with other settings 11 | line_length = 100 12 | 13 | [tool.mypy] 14 | # See https://mypy.readthedocs.io/en/latest/config_file.html for more mypy options. 
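# mypy reads this [tool.mypy] table from pyproject.toml when run from the repository root (for example `mypy auto-rag-eval`).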
15 | 16 | # Enables the type-checker on the interior of functions without type annotations. 17 | check_untyped_defs = true 18 | 19 | # Displaying specific error codes makes it easier to silence specific errors 20 | # See also https://mypy.readthedocs.io/en/latest/error_codes.html 21 | show_error_codes = true 22 | 23 | # Show source code snippets and location markers in error messages 24 | pretty = true 25 | 26 | # Suppresses errors about packages which do not implement type-hint sharing. 27 | # See also https://mypy.readthedocs.io/en/latest/running_mypy.html#missing-imports. 28 | ignore_missing_imports = true 29 | 30 | dependencies = [ 31 | "boto3", 32 | "datasets", 33 | "matplotlib", 34 | "nltk", 35 | "numpy", 36 | "pandas", 37 | "plotly", 38 | "requests", 39 | "scipy", 40 | "sentence_transformers", 41 | "scikit-learn", 42 | "spacy", 43 | "torch", 44 | "tqdm", 45 | "lm-harness", 46 | ] --------------------------------------------------------------------------------