├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── NOTICE ├── README.md ├── images ├── system_overview.png └── task_definition.png ├── irgr.yml ├── public_reposiroty_irgr_final_v1.zip ├── requirements.txt ├── setup.py └── src ├── __init__.py ├── base_utils.py ├── entailment_bank ├── NOTICE ├── README.md ├── __init__.py ├── eval │ ├── __init__.py │ └── run_scorer.py └── utils │ ├── __init__.py │ ├── angle_utils.py │ ├── entail_trees_utils.py │ ├── eval_utils.py │ └── proof_utils.py ├── entailment_baseline.py ├── entailment_iterative.ipynb ├── entailment_retrieval.ipynb └── retrieval_utils.py /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | ## Code of Conduct 2 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 3 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 4 | opensource-codeofconduct@amazon.com with any additional questions or comments. 5 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing Guidelines 2 | 3 | Thank you for your interest in contributing to our project. Whether it's a bug report, new feature, correction, or additional 4 | documentation, we greatly value feedback and contributions from our community. 5 | 6 | Please read through this document before submitting any issues or pull requests to ensure we have all the necessary 7 | information to effectively respond to your bug report or contribution. 8 | 9 | 10 | ## Reporting Bugs/Feature Requests 11 | 12 | We welcome you to use the GitHub issue tracker to report bugs or suggest features. 13 | 14 | When filing an issue, please check existing open, or recently closed, issues to make sure somebody else hasn't already 15 | reported the issue. Please try to include as much information as you can. Details like these are incredibly useful: 16 | 17 | * A reproducible test case or series of steps 18 | * The version of our code being used 19 | * Any modifications you've made relevant to the bug 20 | * Anything unusual about your environment or deployment 21 | 22 | 23 | ## Contributing via Pull Requests 24 | Contributions via pull requests are much appreciated. Before sending us a pull request, please ensure that: 25 | 26 | 1. You are working against the latest source on the *main* branch. 27 | 2. You check existing open, and recently merged, pull requests to make sure someone else hasn't addressed the problem already. 28 | 3. You open an issue to discuss any significant work - we would hate for your time to be wasted. 29 | 30 | To send us a pull request, please: 31 | 32 | 1. Fork the repository. 33 | 2. Modify the source; please focus on the specific change you are contributing. If you also reformat all the code, it will be hard for us to focus on your change. 34 | 3. Ensure local tests pass. 35 | 4. Commit to your fork using clear commit messages. 36 | 5. Send us a pull request, answering any default questions in the pull request interface. 37 | 6. Pay attention to any automated CI failures reported in the pull request, and stay involved in the conversation. 38 | 39 | GitHub provides additional document on [forking a repository](https://help.github.com/articles/fork-a-repo/) and 40 | [creating a pull request](https://help.github.com/articles/creating-a-pull-request/). 
41 | 42 | 43 | ## Finding contributions to work on 44 | Looking at the existing issues is a great way to find something to contribute on. As our projects, by default, use the default GitHub issue labels (enhancement/bug/duplicate/help wanted/invalid/question/wontfix), looking at any 'help wanted' issues is a great place to start. 45 | 46 | 47 | ## Code of Conduct 48 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 49 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 50 | opensource-codeofconduct@amazon.com with any additional questions or comments. 51 | 52 | 53 | ## Security issue notifications 54 | If you discover a potential security issue in this project we ask that you notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/). Please do **not** create a public github issue. 55 | 56 | 57 | ## Licensing 58 | 59 | See the [LICENSE](LICENSE) file for our project's licensing. We will ask you to confirm the licensing of your contribution. 60 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. 
For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | -------------------------------------------------------------------------------- /NOTICE: -------------------------------------------------------------------------------- 1 | Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Iterative Retrieval-Generation Reasoner (NAACL 2022) 2 | 3 | This repository contains the code and data for the paper: 4 | 5 | [Entailment Tree Explanations via Iterative Retrieval-Generation Reasoner](https://assets.amazon.science/6e/5d/b055e2644da985a210e15b825422/entailment-tree-explanations-via-iterative-retrieval-generation-reasoner.pdf) 6 | 7 | **Entailment Trees** represents a chain of reasoning that shows how a hypothesis (or an answer to a question) can be explained from simpler textual evidence. 8 | 9 |
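For intuition, the proof notation used throughout this repository (and parsed by the evaluation scripts under `src/entailment_bank/eval/`) chains premise sentences (`sent*`) into intermediate conclusions (`int*`) until the hypothesis is reached. The snippet below is a purely schematic illustration of that notation, not an actual dataset entry:

```
sent1 & sent3 -> int1: <first intermediate conclusion>;
int1 & sent2 -> hypothesis
```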

10 | ![Task definition](images/task_definition.png) 11 |

12 | 13 | **Iterative Retrieval-Generation Reasoner (IRGR)** is our proposed architecture: it iteratively searches for suitable premises, constructing a single entailment step at a time. At every generation step, the model searches for a distinct set of premises that will support the generation of a single step, thereby mitigating the language model’s input size limit and improving generation correctness. 14 | 15 |
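The loop below is a minimal sketch of this iterative retrieve-then-generate procedure, included only to make the idea concrete: `retrieve_fn` and `generate_fn` are placeholder callables standing in for the retrieval and generation models (see `src/entailment_retrieval.ipynb` and `src/entailment_iterative.ipynb`), not functions exported by this repository.

```python
from typing import Callable, List, Tuple

def irgr_explain(
    hypothesis: str,
    corpus: List[str],
    retrieve_fn: Callable[[str, List[str], List[str]], List[str]],
    generate_fn: Callable[[str, List[str], List[str]], Tuple[str, bool]],
    max_steps: int = 10,
) -> List[Tuple[List[str], str]]:
    """Iteratively retrieve premises and generate one entailment step at a time."""
    proof_steps: List[Tuple[List[str], str]] = []
    intermediates: List[str] = []  # intermediate conclusions, reusable as premises
    for _ in range(max_steps):
        # Retrieve a small, step-specific premise set conditioned on the hypothesis
        # and on the intermediate conclusions generated so far.
        premises = retrieve_fn(hypothesis, intermediates, corpus)
        # Generate a single entailment step; the generator also signals whether
        # this step's conclusion already entails the hypothesis.
        conclusion, is_final = generate_fn(hypothesis, premises, intermediates)
        proof_steps.append((premises, conclusion))
        if is_final:
            break
        intermediates.append(conclusion)
    return proof_steps
```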

16 | ![System overview](images/system_overview.png) 17 |

18 | 19 | ## Setting Up the Environment 20 | 21 | First, install the project's dependencies: 22 | 23 | ```bash 24 | conda env create --file irgr.yml 25 | pip install -r requirements.txt 26 | ``` 27 | 28 | Then activate your conda environment: 29 | 30 | ```bash 31 | conda activate irgr 32 | ``` 33 | 34 | ### Setting up Jupyter 35 | 36 | Most of the code is wrapped inside Jupyter notebooks. 37 | 38 | You can either start a Jupyter server locally or follow the AWS instructions on how to set up Jupyter notebooks on EC2 instances and access them through your browser: 39 | 40 | https://docs.aws.amazon.com/dlami/latest/devguide/setup-jupyter.html 41 | 42 | ### Data Folder Structure 43 | 44 | You can download the [Entailment Bank](https://allenai.org/data/entailmentbank) data and evaluation code by running: 45 | 46 | ``` 47 | python setup.py 48 | ``` 49 | 50 | ### Running Experiments 51 | 52 | You can restart the kernel and run the whole notebook to execute data loading, training, and evaluation. 53 | 54 | Data loading has to be done before training and evaluation. 55 | 56 | #### Entailment Tree Generation 57 | 58 | The main model's code is in `src/entailment_iterative.ipynb`. 59 | 60 | The model generates explanations and proofs for the EntailmentBank dataset. 61 | 62 | #### Premise Retrieval 63 | 64 | The main model's code is in `src/entailment_retrieval.ipynb`. 65 | 66 | This model retrieves a set of premises from the corpus. Training uses the EntailmentBank + WorldTree V2 corpus. 67 | 68 | ## Citation 69 | 70 | ``` 71 | @inproceedings{neves-ribeiro-etal-2022-entailment, 72 | title = "Entailment Tree Explanations via Iterative Retrieval-Generation Reasoner", 73 | author = "Neves Ribeiro, Danilo and 74 | Wang, Shen and 75 | Ma, Xiaofei and 76 | Dong, Rui and 77 | Wei, Xiaokai and 78 | Zhu, Henghui and 79 | Chen, Xinchi and 80 | Xu, Peng and 81 | Huang, Zhiheng and 82 | Arnold, Andrew and 83 | Roth, Dan", 84 | booktitle = "Findings of the Association for Computational Linguistics: NAACL 2022", 85 | month = jul, 86 | year = "2022", 87 | address = "Seattle, United States", 88 | publisher = "Association for Computational Linguistics", 89 | url = "https://aclanthology.org/2022.findings-naacl.35", 90 | doi = "10.18653/v1/2022.findings-naacl.35", 91 | pages = "465--475", 92 | abstract = "Large language models have achieved high performance on various question answering (QA) benchmarks, but the explainability of their output remains elusive. Structured explanations, called entailment trees, were recently suggested as a way to explain the reasoning behind a QA system{'}s answer. In order to better generate such entailment trees, we propose an architecture called Iterative Retrieval-Generation Reasoner (IRGR). Our model is able to explain a given hypothesis by systematically generating a step-by-step explanation from textual premises. The IRGR model iteratively searches for suitable premises, constructing a single entailment step at a time. Contrary to previous approaches, our method combines generation steps and retrieval of premises, allowing the model to leverage intermediate conclusions, and mitigating the input size limit of baseline encoder-decoder models.
We conduct experiments using the EntailmentBank dataset, where we outperform existing benchmarks on premise retrieval and entailment tree generation, with around 300{\%} gain in overall correctness.", 93 | } 94 | ``` 95 | -------------------------------------------------------------------------------- /images/system_overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/irgr/fd385a77880c92e92167a5766859b2862af6ff26/images/system_overview.png -------------------------------------------------------------------------------- /images/task_definition.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/irgr/fd385a77880c92e92167a5766859b2862af6ff26/images/task_definition.png -------------------------------------------------------------------------------- /irgr.yml: -------------------------------------------------------------------------------- 1 | name: irgr 2 | channels: 3 | - pytorch 4 | - nvidia 5 | - anaconda 6 | - conda-forge 7 | - defaults 8 | dependencies: 9 | - _libgcc_mutex=0.1=conda_forge 10 | - _openmp_mutex=4.5=1_gnu 11 | - anyio=3.3.0=py38h578d9bd_0 12 | - argon2-cffi=20.1.0=py38h497a2fe_2 13 | - async_generator=1.10=py_0 14 | - attrs=21.2.0=pyhd8ed1ab_0 15 | - babel=2.9.1=pyh44b312d_0 16 | - backcall=0.2.0=pyh9f0ad1d_0 17 | - backports=1.0=py_2 18 | - backports.functools_lru_cache=1.6.4=pyhd8ed1ab_0 19 | - blas=1.0=mkl 20 | - bleach=4.0.0=pyhd8ed1ab_0 21 | - brotlipy=0.7.0=py38h497a2fe_1001 22 | - ca-certificates=2020.10.14=0 23 | - certifi=2020.6.20=py38_0 24 | - cffi=1.14.6=py38ha65f79e_0 25 | - chardet=4.0.0=py38h578d9bd_1 26 | - charset-normalizer=2.0.0=pyhd8ed1ab_0 27 | - click=8.0.1=py38h578d9bd_0 28 | - colorama=0.4.4=pyh9f0ad1d_0 29 | - cryptography=3.4.7=py38ha5dfef3_0 30 | - cudatoolkit=11.1.74=h6bb024c_0 31 | - dataclasses=0.8=pyhc8e2a94_1 32 | - debugpy=1.4.1=py38h709712a_0 33 | - decorator=5.0.9=pyhd8ed1ab_0 34 | - defusedxml=0.7.1=pyhd8ed1ab_0 35 | - entrypoints=0.3=pyhd8ed1ab_1003 36 | - filelock=3.0.12=pyh9f0ad1d_0 37 | - huggingface_hub=0.0.15=pyhd8ed1ab_0 38 | - idna=3.1=pyhd3deb0d_0 39 | - importlib-metadata=4.6.3=py38h578d9bd_0 40 | - importlib_metadata=4.6.3=hd8ed1ab_0 41 | - intel-openmp=2021.3.0=h06a4308_3350 42 | - ipykernel=6.0.3=py38hd0cf306_0 43 | - ipython=7.26.0=py38he5a9106_0 44 | - ipython_genutils=0.2.0=py_1 45 | - ipywidgets=7.6.3=pyhd3deb0d_0 46 | - jedi=0.18.0=py38h578d9bd_2 47 | - jinja2=3.0.1=pyhd8ed1ab_0 48 | - joblib=1.0.1=pyhd8ed1ab_0 49 | - json5=0.9.5=pyh9f0ad1d_0 50 | - jsonschema=3.2.0=pyhd8ed1ab_3 51 | - jupyter_client=6.1.12=pyhd8ed1ab_0 52 | - jupyter_core=4.7.1=py38h578d9bd_0 53 | - jupyter_server=1.10.2=pyhd8ed1ab_0 54 | - jupyterlab=3.1.4=pyhd8ed1ab_0 55 | - jupyterlab_pygments=0.1.2=pyh9f0ad1d_0 56 | - jupyterlab_server=2.7.0=pyhd8ed1ab_0 57 | - jupyterlab_widgets=1.0.0=pyhd8ed1ab_1 58 | - ld_impl_linux-64=2.36.1=hea4e1c9_2 59 | - libblas=3.9.0=11_linux64_mkl 60 | - libcblas=3.9.0=11_linux64_mkl 61 | - libffi=3.3=h58526e2_2 62 | - libgcc-ng=11.1.0=hc902ee8_8 63 | - libgomp=11.1.0=hc902ee8_8 64 | - liblapack=3.9.0=11_linux64_mkl 65 | - libprotobuf=3.13.0.1=hd408876_0 66 | - libsodium=1.0.18=h36c2ea0_1 67 | - libstdcxx-ng=11.1.0=h56837e0_8 68 | - libuv=1.42.0=h7f98852_0 69 | - markupsafe=2.0.1=py38h497a2fe_0 70 | - matplotlib-inline=0.1.2=pyhd8ed1ab_2 71 | - mistune=0.8.4=py38h497a2fe_1004 72 | - mkl=2021.3.0=h06a4308_520 73 | - nbclassic=0.3.1=pyhd8ed1ab_1 74 | - 
nbclient=0.5.3=pyhd8ed1ab_0 75 | - nbconvert=6.1.0=py38h578d9bd_0 76 | - nbformat=5.1.3=pyhd8ed1ab_0 77 | - ncurses=6.2=h58526e2_4 78 | - nest-asyncio=1.5.1=pyhd8ed1ab_0 79 | - ninja=1.10.2=h4bd325d_0 80 | - notebook=6.4.3=pyha770c72_0 81 | - numpy=1.21.1=py38h9894fe3_0 82 | - openssl=1.1.1k=h7f98852_1 83 | - packaging=21.0=pyhd8ed1ab_0 84 | - pandas=1.1.3=py38he6710b0_0 85 | - pandoc=2.14.1=h7f98852_0 86 | - pandocfilters=1.4.2=py_1 87 | - parso=0.8.2=pyhd8ed1ab_0 88 | - pexpect=4.8.0=pyh9f0ad1d_2 89 | - pickleshare=0.7.5=py_1003 90 | - pip=21.2.3=pyhd8ed1ab_0 91 | - prometheus_client=0.11.0=pyhd8ed1ab_0 92 | - prompt-toolkit=3.0.19=pyha770c72_0 93 | - protobuf=3.13.0.1=py38he6710b0_1 94 | - ptyprocess=0.7.0=pyhd3deb0d_0 95 | - pycparser=2.20=pyh9f0ad1d_2 96 | - pygments=2.9.0=pyhd8ed1ab_0 97 | - pyopenssl=20.0.1=pyhd8ed1ab_0 98 | - pyparsing=2.4.7=pyh9f0ad1d_0 99 | - pyrsistent=0.17.3=py38h497a2fe_2 100 | - pysocks=1.7.1=py38h578d9bd_3 101 | - python=3.8.10=h49503c6_1_cpython 102 | - python-dateutil=2.8.2=pyhd8ed1ab_0 103 | - python_abi=3.8=2_cp38 104 | - pytorch=1.9.0=py3.8_cuda11.1_cudnn8.0.5_0 105 | - pytz=2021.1=pyhd8ed1ab_0 106 | - pyyaml=5.4.1=py38h497a2fe_0 107 | - pyzmq=22.2.1=py38h2035c66_0 108 | - readline=8.1=h46c0cb4_0 109 | - regex=2021.8.3=py38h497a2fe_0 110 | - requests=2.26.0=pyhd8ed1ab_0 111 | - requests-unixsocket=0.2.0=py_0 112 | - sacremoses=0.0.43=pyh9f0ad1d_0 113 | - send2trash=1.8.0=pyhd8ed1ab_0 114 | - sentencepiece=0.1.95=py38h1fd1430_0 115 | - setuptools=49.6.0=py38h578d9bd_3 116 | - six=1.16.0=pyh6c4a22f_0 117 | - sniffio=1.2.0=py38h578d9bd_1 118 | - sqlite=3.36.0=h9cd32fc_0 119 | - terminado=0.10.1=py38h578d9bd_0 120 | - testpath=0.5.0=pyhd8ed1ab_0 121 | - tk=8.6.10=h21135ba_1 122 | - tokenizers=0.10.1=py38hb63a372_0 123 | - tornado=6.1=py38h497a2fe_1 124 | - tqdm=4.62.0=pyhd8ed1ab_0 125 | - traitlets=5.0.5=py_0 126 | - transformers=4.9.2=pyhd8ed1ab_0 127 | - typing-extensions=3.10.0.0=hd8ed1ab_0 128 | - typing_extensions=3.10.0.0=pyha770c72_0 129 | - urllib3=1.26.6=pyhd8ed1ab_0 130 | - wcwidth=0.2.5=pyh9f0ad1d_2 131 | - webencodings=0.5.1=py_1 132 | - websocket-client=0.57.0=py38h578d9bd_4 133 | - wheel=0.37.0=pyhd8ed1ab_0 134 | - widgetsnbextension=3.5.1=py38h578d9bd_4 135 | - xz=5.2.5=h516909a_1 136 | - yaml=0.2.5=h516909a_0 137 | - zeromq=4.3.4=h9c3ff4c_0 138 | - zipp=3.5.0=pyhd8ed1ab_0 139 | - zlib=1.2.11=h516909a_1010 140 | - pip: 141 | - environment-kernels==1.1.1 142 | prefix: /home/ec2-user/anaconda3/envs/csr 143 | -------------------------------------------------------------------------------- /public_reposiroty_irgr_final_v1.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/irgr/fd385a77880c92e92167a5766859b2862af6ff26/public_reposiroty_irgr_final_v1.zip -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | BLEURT @ git+https://github.com/google-research/bleurt.git@b610120347ef22b494b6d69b4316e303f5932516 2 | datasets==1.6.2 3 | deepspeed==0.5.0 4 | huggingface-hub==0.0.12 5 | multiprocess==0.70.12.2 6 | sagemaker==2.59.4 7 | scikit-learn==0.24.2 8 | scipy==1.7.1 9 | sentence-transformers==2.0.0 10 | sentencepiece==0.1.95 11 | tensorboardX 12 | tensorflow 13 | tensorflow-estimator 14 | -------------------------------------------------------------------------------- /setup.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import requests 4 | import shutil 5 | from pathlib import Path 6 | 7 | ENTAILMENT_BANK_REPO = 'https://github.com/allenai/entailment_bank.git' 8 | BLEURT_MODEL_PATH = 'bleurt-large-512.zip' 9 | BLEURT_MODEL_URL = 'https://storage.googleapis.com/bleurt-oss/bleurt-large-512.zip' 10 | 11 | ENT_BANK_DATASET_PATH = './entailment_bank/data/public_dataset/entailment_trees_emnlp2021_data_v2/dataset/' 12 | DATASET_PATH = './data/arc_entail/dataset/' 13 | ENT_BANK_WT_CORPUS_PATH = './entailment_bank/data/public_dataset/entailment_trees_emnlp2021_data_v2/supporting_data/' 14 | WT_CORPUS_PATH = './data/arc_entail/supporting_data/' 15 | ENT_BANK_SRC_DATA_PATH = './entailment_bank/data/' 16 | SRC_DATA_PATH = './src/entailment_bank/data/' 17 | 18 | BLEURT_FOLDER = './bleurt-large-512' 19 | SRC_BLEURT_FOLDER = './src/entailment_bank/bleurt-large-512' 20 | 21 | def prepare_path(path): 22 | path = Path(path) 23 | os.makedirs(path.parents[0], exist_ok=True) 24 | return path 25 | 26 | def doanlod_and_save_to_path(url, path): 27 | path = prepare_path(path) 28 | print(f'Downloading:\n{url}') 29 | response = requests.get(url) 30 | print(f'Saving to:\n{path}') 31 | open(path, 'wb').write(response.content) 32 | 33 | def copy_folders(from_path, to_path): 34 | print(f'{from_path} => {to_path}') 35 | from_path = prepare_path(from_path) 36 | to_path = prepare_path(to_path) 37 | shutil.copytree(from_path, to_path, dirs_exist_ok=True) 38 | 39 | def move_folders(from_path, to_path): 40 | print(f'{from_path} => {to_path}') 41 | from_path = prepare_path(from_path) 42 | to_path = prepare_path(to_path) 43 | shutil.move(from_path, to_path) 44 | 45 | def clone_git_repo(git_repo): 46 | # Cloning 47 | os.system(f'git clone {git_repo}') 48 | 49 | def unzip_file(path): 50 | path = Path(path) 51 | # Unzipping 52 | os.system(f'unzip {path}') 53 | 54 | def setup_entailment_bank_eval(): 55 | print('\nCloning EntailmentBank evaluation repository...') 56 | clone_git_repo(ENTAILMENT_BANK_REPO) 57 | print('\nCopying files...') 58 | copy_folders(ENT_BANK_DATASET_PATH, DATASET_PATH) 59 | copy_folders(ENT_BANK_WT_CORPUS_PATH,WT_CORPUS_PATH) 60 | copy_folders(ENT_BANK_SRC_DATA_PATH, SRC_DATA_PATH) 61 | 62 | print('\nDownloading BLEURT model...') 63 | # downlaad bleurt model 64 | doanlod_and_save_to_path(BLEURT_MODEL_URL, BLEURT_MODEL_PATH) 65 | # unzip bleurt model 66 | unzip_file(BLEURT_MODEL_PATH) 67 | # copy files to src folder 68 | print('\nMoving model files...') 69 | move_folders(BLEURT_FOLDER, SRC_BLEURT_FOLDER) 70 | 71 | def main(): 72 | setup_entailment_bank_eval() 73 | print('\nSetup finished!') 74 | 75 | if __name__ == '__main__': 76 | main() -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/irgr/fd385a77880c92e92167a5766859b2862af6ff26/src/__init__.py -------------------------------------------------------------------------------- /src/base_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import sys 4 | import re 5 | 6 | #################################################################### 7 | # File Manipulation 8 | #################################################################### 9 | 10 | def save_to_jsonl_file(output_list, file_path): 11 | print('saving data to 
file:', file_path) 12 | with open(file_path, 'w') as file: 13 | for obj in output_list: 14 | file.write(json.dumps(obj) + '\n') 15 | 16 | def run_funtion_redirect_stdout(fun, args=[], kargs={}, filename='logs.txt'): 17 | ''' 18 | Runs functions while redirecting output to file 19 | ''' 20 | # redirect stdout to file 21 | orig_stdout = sys.stdout 22 | f = open(filename, 'w') 23 | sys.stdout = f 24 | 25 | # execute function 26 | output = fun(*args, **kargs) 27 | 28 | # restore stdout 29 | sys.stdout = orig_stdout 30 | f.close() 31 | return output 32 | 33 | #################################################################### 34 | # Strings 35 | #################################################################### 36 | 37 | def str_replace_single_pass(string, substitutions): 38 | ''' 39 | A Python function that does multiple string replace ops in a single pass. 40 | E.g. substitutions = {"foo": "FOO", "bar": "BAR"} 41 | ''' 42 | substrings = sorted(substitutions, key=len, reverse=True) 43 | regex = re.compile('|'.join(map(re.escape, substrings))) 44 | return regex.sub(lambda match: substitutions[match.group(0)], string) 45 | 46 | 47 | #################################################################### 48 | # Resource Locator 49 | #################################################################### 50 | 51 | def get_results_file_path(params, test_split=False, result_only=False, 52 | temp=False, uuid=False): 53 | suffix = '' 54 | if test_split: 55 | suffix += '_test' 56 | if result_only: 57 | suffix += '_result_only' 58 | if temp: 59 | suffix += '_temp' 60 | if uuid: 61 | suffix += '_uuid' 62 | results_file_path = params.results_file_path.format( 63 | model_name = params.model_name, 64 | task_name = params.task_name, 65 | dataset_name = params.dataset_name, 66 | approach_name = params.approach_name, 67 | suffix = suffix, 68 | extension = 'tsv') 69 | return results_file_path 70 | 71 | def get_logs_file_path(params, test_split=False, temp=False, epoch_num = None): 72 | suffix = '' 73 | prefix = '' 74 | if test_split: 75 | suffix += '_test' 76 | if temp: 77 | suffix += '_temp' 78 | prefix += 'temp/' 79 | if epoch_num is not None: 80 | suffix += f'_{epoch_num}' 81 | logs_file_path = params.logs_file_path.format( 82 | model_name = params.model_name, 83 | task_name = params.task_name, 84 | dataset_name = params.dataset_name, 85 | approach_name = params.approach_name, 86 | prefix = prefix, 87 | suffix = suffix 88 | ) 89 | return logs_file_path 90 | 91 | def get_proof_ranking_data_path(params, split = 'dev'): 92 | logs_file_path = params.proof_ranking_data_filepath.format( 93 | model_name = params.model_name, 94 | task_name = params.task_name, 95 | dataset_name = params.dataset_name, 96 | approach_name = params.approach_name, 97 | split = split 98 | ) 99 | return logs_file_path -------------------------------------------------------------------------------- /src/entailment_bank/NOTICE: -------------------------------------------------------------------------------- 1 | ** entailment_bank; version 1 -- https://github.com/allenai/entailment_bank 2 | Original Copyright 2021 Allen Institute for AI. Licensed under the Apache-2.0 3 | License. 4 | 5 | Apache License 6 | Version 2.0, January 2004 7 | http://www.apache.org/licenses/ 8 | 9 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 10 | 11 | 1. Definitions. 12 | 13 | "License" shall mean the terms and conditions for use, reproduction, 14 | and distribution as defined by Sections 1 through 9 of this document. 
15 | 16 | "Licensor" shall mean the copyright owner or entity authorized by 17 | the copyright owner that is granting the License. 18 | 19 | "Legal Entity" shall mean the union of the acting entity and all 20 | other entities that control, are controlled by, or are under common 21 | control with that entity. For the purposes of this definition, 22 | "control" means (i) the power, direct or indirect, to cause the 23 | direction or management of such entity, whether by contract or 24 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 25 | outstanding shares, or (iii) beneficial ownership of such entity. 26 | 27 | "You" (or "Your") shall mean an individual or Legal Entity 28 | exercising permissions granted by this License. 29 | 30 | "Source" form shall mean the preferred form for making modifications, 31 | including but not limited to software source code, documentation 32 | source, and configuration files. 33 | 34 | "Object" form shall mean any form resulting from mechanical 35 | transformation or translation of a Source form, including but 36 | not limited to compiled object code, generated documentation, 37 | and conversions to other media types. 38 | 39 | "Work" shall mean the work of authorship, whether in Source or 40 | Object form, made available under the License, as indicated by a 41 | copyright notice that is included in or attached to the work 42 | (an example is provided in the Appendix below). 43 | 44 | "Derivative Works" shall mean any work, whether in Source or Object 45 | form, that is based on (or derived from) the Work and for which the 46 | editorial revisions, annotations, elaborations, or other modifications 47 | represent, as a whole, an original work of authorship. For the purposes 48 | of this License, Derivative Works shall not include works that remain 49 | separable from, or merely link (or bind by name) to the interfaces of, 50 | the Work and Derivative Works thereof. 51 | 52 | "Contribution" shall mean any work of authorship, including 53 | the original version of the Work and any modifications or additions 54 | to that Work or Derivative Works thereof, that is intentionally 55 | submitted to Licensor for inclusion in the Work by the copyright owner 56 | or by an individual or Legal Entity authorized to submit on behalf of 57 | the copyright owner. For the purposes of this definition, "submitted" 58 | means any form of electronic, verbal, or written communication sent 59 | to the Licensor or its representatives, including but not limited to 60 | communication on electronic mailing lists, source code control systems, 61 | and issue tracking systems that are managed by, or on behalf of, the 62 | Licensor for the purpose of discussing and improving the Work, but 63 | excluding communication that is conspicuously marked or otherwise 64 | designated in writing by the copyright owner as "Not a Contribution." 65 | 66 | "Contributor" shall mean Licensor and any individual or Legal Entity 67 | on behalf of whom a Contribution has been received by Licensor and 68 | subsequently incorporated within the Work. 69 | 70 | 2. Grant of Copyright License. Subject to the terms and conditions of 71 | this License, each Contributor hereby grants to You a perpetual, 72 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 73 | copyright license to reproduce, prepare Derivative Works of, 74 | publicly display, publicly perform, sublicense, and distribute the 75 | Work and such Derivative Works in Source or Object form. 76 | 77 | 3. Grant of Patent License. 
Subject to the terms and conditions of 78 | this License, each Contributor hereby grants to You a perpetual, 79 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 80 | (except as stated in this section) patent license to make, have made, 81 | use, offer to sell, sell, import, and otherwise transfer the Work, 82 | where such license applies only to those patent claims licensable 83 | by such Contributor that are necessarily infringed by their 84 | Contribution(s) alone or by combination of their Contribution(s) 85 | with the Work to which such Contribution(s) was submitted. If You 86 | institute patent litigation against any entity (including a 87 | cross-claim or counterclaim in a lawsuit) alleging that the Work 88 | or a Contribution incorporated within the Work constitutes direct 89 | or contributory patent infringement, then any patent licenses 90 | granted to You under this License for that Work shall terminate 91 | as of the date such litigation is filed. 92 | 93 | 4. Redistribution. You may reproduce and distribute copies of the 94 | Work or Derivative Works thereof in any medium, with or without 95 | modifications, and in Source or Object form, provided that You 96 | meet the following conditions: 97 | 98 | (a) You must give any other recipients of the Work or 99 | Derivative Works a copy of this License; and 100 | 101 | (b) You must cause any modified files to carry prominent notices 102 | stating that You changed the files; and 103 | 104 | (c) You must retain, in the Source form of any Derivative Works 105 | that You distribute, all copyright, patent, trademark, and 106 | attribution notices from the Source form of the Work, 107 | excluding those notices that do not pertain to any part of 108 | the Derivative Works; and 109 | 110 | (d) If the Work includes a "NOTICE" text file as part of its 111 | distribution, then any Derivative Works that You distribute must 112 | include a readable copy of the attribution notices contained 113 | within such NOTICE file, excluding those notices that do not 114 | pertain to any part of the Derivative Works, in at least one 115 | of the following places: within a NOTICE text file distributed 116 | as part of the Derivative Works; within the Source form or 117 | documentation, if provided along with the Derivative Works; or, 118 | within a display generated by the Derivative Works, if and 119 | wherever such third-party notices normally appear. The contents 120 | of the NOTICE file are for informational purposes only and 121 | do not modify the License. You may add Your own attribution 122 | notices within Derivative Works that You distribute, alongside 123 | or as an addendum to the NOTICE text from the Work, provided 124 | that such additional attribution notices cannot be construed 125 | as modifying the License. 126 | 127 | You may add Your own copyright statement to Your modifications and 128 | may provide additional or different license terms and conditions 129 | for use, reproduction, or distribution of Your modifications, or 130 | for any such Derivative Works as a whole, provided Your use, 131 | reproduction, and distribution of the Work otherwise complies with 132 | the conditions stated in this License. 133 | 134 | 5. Submission of Contributions. Unless You explicitly state otherwise, 135 | any Contribution intentionally submitted for inclusion in the Work 136 | by You to the Licensor shall be under the terms and conditions of 137 | this License, without any additional terms or conditions. 
138 | Notwithstanding the above, nothing herein shall supersede or modify 139 | the terms of any separate license agreement you may have executed 140 | with Licensor regarding such Contributions. 141 | 142 | 6. Trademarks. This License does not grant permission to use the trade 143 | names, trademarks, service marks, or product names of the Licensor, 144 | except as required for reasonable and customary use in describing the 145 | origin of the Work and reproducing the content of the NOTICE file. 146 | 147 | 7. Disclaimer of Warranty. Unless required by applicable law or 148 | agreed to in writing, Licensor provides the Work (and each 149 | Contributor provides its Contributions) on an "AS IS" BASIS, 150 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 151 | implied, including, without limitation, any warranties or conditions 152 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 153 | PARTICULAR PURPOSE. You are solely responsible for determining the 154 | appropriateness of using or redistributing the Work and assume any 155 | risks associated with Your exercise of permissions under this License. 156 | 157 | 8. Limitation of Liability. In no event and under no legal theory, 158 | whether in tort (including negligence), contract, or otherwise, 159 | unless required by applicable law (such as deliberate and grossly 160 | negligent acts) or agreed to in writing, shall any Contributor be 161 | liable to You for damages, including any direct, indirect, special, 162 | incidental, or consequential damages of any character arising as a 163 | result of this License or out of the use or inability to use the 164 | Work (including but not limited to damages for loss of goodwill, 165 | work stoppage, computer failure or malfunction, or any and all 166 | other commercial damages or losses), even if such Contributor 167 | has been advised of the possibility of such damages. 168 | 169 | 9. Accepting Warranty or Additional Liability. While redistributing 170 | the Work or Derivative Works thereof, You may choose to offer, 171 | and charge a fee for, acceptance of support, warranty, indemnity, 172 | or other liability obligations and/or rights consistent with this 173 | License. However, in accepting such obligations, You may act only 174 | on Your own behalf and on Your sole responsibility, not on behalf 175 | of any other Contributor, and only if You agree to indemnify, 176 | defend, and hold each Contributor harmless for any liability 177 | incurred by, or claims asserted against, such Contributor by reason 178 | of your accepting any such warranty or additional liability. 179 | 180 | END OF TERMS AND CONDITIONS 181 | 182 | APPENDIX: How to apply the Apache License to your work. 183 | 184 | To apply the Apache License to your work, attach the following 185 | boilerplate notice, with the fields enclosed by brackets "[]" 186 | replaced with your own identifying information. (Don't include 187 | the brackets!) The text should be enclosed in the appropriate 188 | comment syntax for the file format. We also recommend that a 189 | file or class name and description of purpose be included on the 190 | same "printed page" as the copyright notice for easier 191 | identification within third-party archives. 192 | 193 | Copyright [yyyy] [name of copyright owner] 194 | 195 | Licensed under the Apache License, Version 2.0 (the "License"); 196 | you may not use this file except in compliance with the License. 
197 | You may obtain a copy of the License at 198 | 199 | http://www.apache.org/licenses/LICENSE-2.0 200 | 201 | Unless required by applicable law or agreed to in writing, software 202 | distributed under the License is distributed on an "AS IS" BASIS, 203 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 204 | See the License for the specific language governing permissions and 205 | limitations under the License. -------------------------------------------------------------------------------- /src/entailment_bank/README.md: -------------------------------------------------------------------------------- 1 | # Entailment Bank Evaluation 2 | 3 | This evaluation code is part of the entailment bank repository available at: 4 | 5 | [https://github.com/allenai/entailment_bank](https://github.com/allenai/entailment_bank) 6 | 7 | Small changes made to support further metrics and small improvements in "task-3" evaluation. -------------------------------------------------------------------------------- /src/entailment_bank/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/irgr/fd385a77880c92e92167a5766859b2862af6ff26/src/entailment_bank/__init__.py -------------------------------------------------------------------------------- /src/entailment_bank/eval/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/irgr/fd385a77880c92e92167a5766859b2862af6ff26/src/entailment_bank/eval/__init__.py -------------------------------------------------------------------------------- /src/entailment_bank/eval/run_scorer.py: -------------------------------------------------------------------------------- 1 | # Original Copyright 2021 Allen Institute for AI. Licensed under the Apache-2.0 License. 2 | # Modifications Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved. 3 | 4 | """ Evaluation script for EntailmentBank models. """ 5 | 6 | import argparse 7 | import glob 8 | import logging 9 | import json 10 | import os 11 | import re 12 | import sys 13 | from bleurt import score 14 | 15 | from tqdm import tqdm 16 | 17 | sys.path.append(os.path.join(os.getcwd())) # noqa: E402 # isort:skip 18 | 19 | # 2022-05-26: Amazon addition. 20 | from entailment_bank.utils.angle_utils import decompose_slots, load_jsonl, save_json, shortform_angle, formatting 21 | from entailment_bank.utils.eval_utils import collate_scores, score_prediction_whole_proof 22 | from entailment_bank.utils.retrieval_utils import convert_datapoint_sent_to_uuid 23 | # End of Amazon addition. 
24 | 25 | logger = logging.getLogger(__name__) 26 | logging.basicConfig(level=logging.INFO) 27 | 28 | 29 | def get_args(): 30 | parser = argparse.ArgumentParser() 31 | parser.add_argument("--task", default=None, required=True, type=str, 32 | help="Task name: task_1, task_2, task_3") 33 | parser.add_argument("--output_dir", default=None, required=True, type=str, 34 | help="Directory to store scores.") 35 | parser.add_argument("--split", default=None, required=True, type=str, help="Which split (train/dev/test) to evaluate.") 36 | parser.add_argument("--prediction_file", default=None, required=True, type=str, 37 | help="Prediction file(s) to score.") 38 | parser.add_argument("--bleurt_checkpoint", default="true", type=str, 39 | help="Path to the BLEURT model checkpoint (Download from https://github.com/google-research/bleurt#checkpoints) " 40 | "We use bleurt-large-512 model for EntailmentBank evaluation") 41 | 42 | args = parser.parse_args() 43 | return args 44 | 45 | 46 | def split_info_sentences(context): 47 | words_list = context.split(" ") 48 | sentence_ids = re.findall(r'[\w\.]+[0-9]+:', context) 49 | sentence_dict = dict() 50 | prev_sid = "" 51 | prev_sentence_parts = [] 52 | for word in words_list: 53 | if word in sentence_ids: 54 | if prev_sid: 55 | sentence_dict[prev_sid] = ' '.join(prev_sentence_parts) 56 | prev_sentence_parts = [] 57 | prev_sid = word 58 | else: 59 | prev_sentence_parts.append(word) 60 | 61 | if prev_sid: 62 | sentence_dict[prev_sid] = ' '.join(prev_sentence_parts) 63 | return sentence_dict 64 | # 2022-05-26: Amazon addition. 65 | def score_predictions(args, predictions_file, score_file, gold_file, angle_file=None, dataset=None, bleurt_checkpoint="", task_name = None): 66 | # End of Amazon addition. 67 | if args.bleurt_checkpoint: 68 | bleurt_scorer = score.BleurtScorer(bleurt_checkpoint) 69 | else: 70 | bleurt_scorer = None 71 | 72 | gold_data = load_jsonl(gold_file) 73 | 74 | # 2022-05-26: Amazon addition. 75 | ###### to fix task_3 evaluation 76 | print('\n\task_name =', task_name) 77 | if task_name == 'task_3': 78 | print('UPDATING DATAPOINTS') 79 | gold_data = [convert_datapoint_sent_to_uuid(g) for g in gold_data] 80 | ###### 81 | # End of Amazon addition. 82 | 83 | gold_by_id = {g['id']: g for g in gold_data} 84 | 85 | gold_train_file = gold_file.replace("dev", "train") 86 | gold_train_data = load_jsonl(gold_train_file) 87 | # gold_train_by_id = {g['id']: g for g in gold_train_data} 88 | train_context_dict = dict() 89 | train_answers_dict = dict() 90 | for g in gold_train_data: 91 | # print(f"{g['meta']['triples']}") 92 | context_dict = dict(g['meta']['triples']) 93 | for context in context_dict.values(): 94 | train_context_dict[context] = 1 95 | train_answers_dict[g['answer']] = 1 96 | 97 | is_jsonl = predictions_file.endswith(".jsonl") 98 | if not is_jsonl: 99 | angle_data = load_jsonl(angle_file) 100 | scores = [] 101 | sort_angle = False 102 | 103 | num_dev_answers = 0 104 | num_dev_answers_seen_in_train_context = 0 105 | num_dev_answers_seen_in_train_answers = 0 106 | diagnostics_tsv = open(score_file+".diagnostics.tsv", "w") 107 | 108 | # 2022-05-26: Amazon addition. 
109 | diagnostics_tsv.write(f"Q-ID" 110 | f"\tI:Context" 111 | f"\tI:Question" 112 | f"\tI:Answer" 113 | f"\tI:Hypothesis" 114 | f"\tO:Gold Proof" 115 | f"\tO:Predicted Proof" 116 | f"\tPredicted to gold int alignment" 117 | f"\tRewritten predicted proof after alignment" 118 | f"\t% Leaves-P" 119 | f"\t% Leaves-R" 120 | f"\t% Leaves-F1" 121 | f"\t% Leaves-Correct" 122 | f"\t% Steps-F1" 123 | f"\t% Steps-Correct" 124 | f"\t% Interm-BLEURT-P" 125 | f"\t% Interm-BLEURT-R" 126 | f"\t% Interm-BLEURT-F1" 127 | f"\tInterm-BLEURT-score" 128 | f"\t% Interm-BLEURT-Acc" 129 | f"\t% perfect alignment" 130 | f"\t% Overall Correct" 131 | f"\tNum Distractors" 132 | f"\tContext Length" 133 | f"\tFraction of distractors" 134 | f"\tDistractor Ids" 135 | f"\tDepth of Proof" 136 | f"\tLength of Proof" 137 | "\n") 138 | # End of Amazon addition. 139 | 140 | with open(f"{score_file}.json", "w") as score_file, open(predictions_file, "r") as preds_file: 141 | for line_idx, line in tqdm(enumerate(preds_file)): 142 | if is_jsonl: 143 | pred = json.loads(line.strip()) 144 | else: 145 | pred = {'id': angle_data[line_idx]['id'], 146 | 'angle': angle_data[line_idx]['angle'], 147 | 'prediction': line.strip()} 148 | 149 | angle = pred['angle'] 150 | angle_canonical = shortform_angle(angle, sort_angle=sort_angle) 151 | pred['angle_str'] = angle_canonical 152 | item_id = pred['id'] 153 | 154 | # if item_id not in ['CSZ20680']: 155 | # continue 156 | 157 | if item_id not in gold_by_id: 158 | continue 159 | raise ValueError(f"Missing id in gold data: {item_id}") 160 | slots = decompose_slots(pred['prediction']) 161 | 162 | pred['slots'] = slots 163 | 164 | num_dev_answers += 1 165 | # print(f"======= pred: {pred}") 166 | # print(f">>>>>>>>>>>> id:{item_id}") 167 | metrics = score_prediction_whole_proof(pred, gold_by_id[item_id], dataset, 168 | scoring_spec={ 169 | "hypothesis_eval": "nlg", 170 | "proof_eval": "entail_whole_proof_align_eval", 171 | # "proof_eval": "entail_whole_polish_proof_align_eval", 172 | }, 173 | bleurt_scorer=bleurt_scorer) 174 | 175 | pred['metrics'] = metrics 176 | score_file.write(json.dumps(pred) + "\n") 177 | id = angle_data[line_idx]['id'] 178 | goldslot_record = gold_by_id[id] 179 | # print(f"goldslot_record:{goldslot_record}") 180 | question_before_json = "" 181 | if 'meta' in goldslot_record and 'question' in goldslot_record['meta']: 182 | question_before_json = goldslot_record['meta']['question'] 183 | 184 | question_json = {} 185 | if 'meta' in goldslot_record and 'question' in goldslot_record['meta']: 186 | question_json = goldslot_record['meta']['question'] 187 | 188 | question_json['gold_proofs'] = question_json.get('proofs', "") 189 | question_json['proofs'] = "" 190 | 191 | hypothesis_f1 = metrics.get('hypothesis', dict()).get('ROUGE_L_F', -1) 192 | question_json['ROUGE_L_F'] = hypothesis_f1 193 | 194 | sentences_dict = split_info_sentences(goldslot_record['context']) 195 | sentence_set = [] 196 | for sid, sent in sentences_dict.items(): 197 | sentence_set.append(f"{sid}: {sent}") 198 | sent_str = formatting(sentence_set) 199 | 200 | gold_triples = goldslot_record['meta']['triples'] 201 | gold_ints = goldslot_record.get('meta', dict()).get('intermediate_conclusions', dict()) 202 | gold_ints['hypothesis'] = goldslot_record['hypothesis'] 203 | gold_triples.update(gold_ints) 204 | gold_proof_str = goldslot_record['proof'] 205 | # if '; ' in gold_proof_str: 206 | if True: 207 | gold_proof_steps = gold_proof_str.split(';') 208 | 209 | gold_proof_str_list = [] 210 | for step in 
gold_proof_steps: 211 | step = step.strip() 212 | if step.strip() and len(step.split(' -> '))==2: 213 | print(f"step:{step}") 214 | 215 | parts = step.split(' -> ') 216 | lhs_ids = parts[0].split('&') 217 | rhs = parts[1] 218 | if rhs == "hypothesis": 219 | rhs = f"hypothesis: {gold_triples['hypothesis']}" 220 | for lid in lhs_ids: 221 | lhs_id = lid.strip() 222 | print(f"QID:{item_id}") 223 | print(f"gold_triples:{gold_triples}") 224 | print(f"step:{step}") 225 | # gold_proof_str_list.append(f"{lhs_id}: {gold_triples[lhs_id]} &") 226 | gold_proof_str_list.append(f"{lhs_id}: {gold_triples[lhs_id] if lhs_id in gold_triples else ''} &") 227 | gold_proof_str_list.append(f"-> {rhs}") 228 | gold_proof_str_list.append(f"-----------------") 229 | gold_proof_str_to_output = formatting(gold_proof_str_list) 230 | 231 | pred_triples = goldslot_record['meta']['triples'] 232 | pred_triples['hypothesis'] = goldslot_record['hypothesis'] 233 | pred_proof_str = pred['slots'].get('proof', "") 234 | print(f"^^^^^^^^^^^^^^^^^^pred_proof_str:{pred_proof_str}") 235 | # if '; ' in pred_proof_str: 236 | if True: 237 | pred_proof_steps = pred_proof_str.split(';') 238 | pred_proof_str_list = [] 239 | # print(f"\n\n=================") 240 | # print(f"pred_proof_str:{pred_proof_str}") 241 | for step in pred_proof_steps: 242 | step = step.strip() 243 | if step.strip() and len(step.split(' -> '))==2: 244 | print(f"step:{step}") 245 | parts = step.split(' -> ') 246 | lhs_ids = parts[0].split('&') 247 | if ',' in parts[0]: 248 | lhs_ids = parts[0].split(',') 249 | rhs = parts[1] 250 | # 2022-05-26: Amazon addition. 251 | if rhs == "hypothesis" or "hypothesis" in rhs: 252 | rhs = f"hypothesis: {pred_triples['hypothesis']}" 253 | else: 254 | if rhs.count(":") == 1: 255 | rhs_parts = rhs.split(":") 256 | print(rhs_parts) 257 | int_id = rhs_parts[0] 258 | int_str = rhs_parts[1].strip() 259 | pred_triples[int_id] = int_str 260 | # else: 261 | # pred_triples[int_id] = rhs.strip() 262 | for lid in lhs_ids: 263 | lhs_id = lid.strip() 264 | pred_proof_str_list.append(f"{lhs_id}: {pred_triples.get(lhs_id, 'NULL')} &") 265 | pred_proof_str_list.append(f"-> {rhs}") 266 | pred_proof_str_list.append(f"-----------------") 267 | # End of Amazon addition. 268 | pred_proof_str_to_output = formatting(pred_proof_str_list) 269 | 270 | if '; ' in pred_proof_str: 271 | pred_proof_steps = pred_proof_str.split('; ') 272 | pred_step_list = [] 273 | 274 | for step in pred_proof_steps: 275 | if step.strip(): 276 | pred_step_list.append(f"{step}; ") 277 | pred_proof_str = formatting(pred_step_list) 278 | 279 | relevance_f1 = "-" 280 | relevance_accuracy = "-" 281 | if 'relevance' in metrics: 282 | relevance_f1 = metrics['relevance']['F1'] 283 | relevance_accuracy = metrics['relevance']['acc'] 284 | 285 | proof_acc = "-" 286 | proof_f1 = "-" 287 | proof_alignements = "-" 288 | if 'aligned_proof' in metrics: 289 | proof_acc = metrics['aligned_proof']['acc'] 290 | proof_f1 = metrics['aligned_proof']['F1'] 291 | proof_alignements = metrics['aligned_proof']['pred_to_gold_mapping'] 292 | 293 | inference_type = "none" 294 | if "abduction" in id: 295 | inference_type = "abduction" 296 | elif "deduction" in id: 297 | inference_type = "deduction" 298 | 299 | # 2022-05-26: Amazon addition. 300 | depth_of_proof = goldslot_record['depth_of_proof'] 301 | length_of_proof = goldslot_record['length_of_proof'] 302 | # End of Amazon addition. 
303 | 304 | num_distractors = 0 305 | fraction_distractors = 0.0 306 | num_context_sent = len(goldslot_record['meta']['triples']) 307 | if 'distractors' in goldslot_record['meta']: 308 | num_distractors = len(goldslot_record['meta']['distractors']) 309 | fraction_distractors = 1.0 * num_distractors / num_context_sent 310 | 311 | distractor_ids = goldslot_record['meta'].get('distractors', []) 312 | pred_to_gold_mapping = metrics['proof-steps']['pred_to_gold_mapping'] 313 | pred_to_gold_mapping_str = "" 314 | for pred_int, gold_int in pred_to_gold_mapping.items(): 315 | pred_to_gold_mapping_str += f"p_{pred_int} -> g_{gold_int} ;; " 316 | diagnostics_tsv.write(f"{id}" 317 | f"\t{sent_str}" 318 | f"\t{goldslot_record['question']}" 319 | f"\t{goldslot_record['answer']}" 320 | f"\t{goldslot_record['hypothesis']}" 321 | f"\t{gold_proof_str_to_output}" 322 | f"\t{pred_proof_str_to_output}" 323 | f"\t{' ;; '.join(metrics['proof-steps']['sentences_pred_aligned'])}" 324 | f"\t{pred_to_gold_mapping_str}" 325 | f"\t{metrics['proof-leaves']['P']*100}" 326 | f"\t{metrics['proof-leaves']['R']*100}" 327 | f"\t{metrics['proof-leaves']['F1']*100}" 328 | f"\t{metrics['proof-leaves']['acc']*100}" 329 | f"\t{metrics['proof-steps']['F1']*100}" 330 | f"\t{metrics['proof-steps']['acc']*100}" 331 | f"\t{metrics['proof-intermediates']['BLEURT_P']*100}" 332 | f"\t{metrics['proof-intermediates']['BLEURT_R']*100}" 333 | f"\t{metrics['proof-intermediates']['BLEURT_F1']*100}" 334 | f"\t{metrics['proof-intermediates']['BLEURT']}" 335 | f"\t{metrics['proof-intermediates']['BLEURT_acc']*100}" 336 | f"\t{metrics['proof-intermediates']['fraction_perfect_align']*100}" 337 | f"\t{metrics['proof-overall']['acc']*100}" 338 | f"\t{num_distractors}" 339 | f"\t{num_context_sent}" 340 | f"\t{fraction_distractors}" 341 | f"\t{', '.join(distractor_ids)}" 342 | f"\t{depth_of_proof}" 343 | f"\t{length_of_proof}" 344 | "\n") 345 | 346 | scores.append(pred) 347 | 348 | print("\n=================\n" 349 | "Percentage recall per gold proof depth\n" 350 | "Gold_proof_depth\t#Gold answers\t#Correct predictions\t%accuracy (recall)\t%Gold answers\t%Correct Predictions") 351 | 352 | print(f"=========================") 353 | print(f"num_dev_answers:{num_dev_answers}") 354 | print(f"num_dev_answers_seen_in_train_context:{num_dev_answers_seen_in_train_context}") 355 | print(f"num_dev_answers_seen_in_train_answers:{num_dev_answers_seen_in_train_answers}") 356 | 357 | return scores 358 | 359 | 360 | # Sample command 361 | # python multi_angle/run_scorer_mf_all_at_once.py --angle_data_dir /Users/bhavanad/research_data/ruletaker/missing_facts/data/angles/OWA_d3_run3 --output_dir /Users/bhavanad/research_data/ruletaker/missing_facts/data/scorings/OWA_d3_run2/OWA_d3_run1_on_d3 --slot_root_dir /Users/bhavanad/research_data/ruletaker/missing_facts/data/slots/ --slot_data_dir OWA_d3_run3-slots --split test --prediction_file /Users/bhavanad/research_data/ruletaker/missing_facts/data/predictions/OWA_d3_T5large/OWA_d3_run1.on_d3.15k.pred.test.tsv 362 | def main(args): 363 | prediction_files = args.prediction_file 364 | if "," in prediction_files: 365 | prediction_files = prediction_files.split(",") 366 | elif os.path.isdir(prediction_files): 367 | dir_path = prediction_files 368 | prediction_files = [f"{dir_path}/{f}" for f in os.listdir(prediction_files) if re.match(r'.*_predictions', f)] 369 | else: 370 | prediction_files = glob.glob(prediction_files) 371 | 372 | 373 | prediction_files.sort() 374 | # 2022-05-26: Amazon addition. 
375 | root_dir = "entailment_bank/data/processed_data" 376 | # End of Amazon addition. 377 | angle_data_dir = f"{root_dir}/angles/{args.task}/" 378 | slot_data_dir = f"{root_dir}/slots/{args.task}-slots/" 379 | angle_base_name = os.path.basename(angle_data_dir) 380 | slot_file = os.path.join(slot_data_dir, args.split + '.jsonl') 381 | if not os.path.exists(slot_file): 382 | if args.split == 'val' and os.path.exists(os.path.join(slot_data_dir, "dev.jsonl")): 383 | slot_file = os.path.join(slot_data_dir, "dev.jsonl") 384 | else: 385 | raise ValueError(f"Slot data file {slot_file} does not exist!") 386 | predictions_jsonl_format = True 387 | for prediction_file in prediction_files: 388 | if not prediction_file.endswith(".jsonl"): 389 | predictions_jsonl_format = False 390 | angle_file = None 391 | # If predictions not in jsonl format, need angle data to get correct ids and angles 392 | if not predictions_jsonl_format: 393 | angle_file = os.path.join(angle_data_dir, args.split + '.jsonl') 394 | if not os.path.exists(angle_file): 395 | if args.split == 'val' and os.path.exists(os.path.join(angle_data_dir, "dev.jsonl")): 396 | slot_file = os.path.join(angle_data_dir, "dev.jsonl") 397 | else: 398 | raise ValueError(f"Angle data file {angle_file} does not exist!") 399 | if not os.path.exists(args.output_dir): 400 | os.makedirs(args.output_dir) 401 | 402 | logger.info("Scoring the following files: %s", prediction_files) 403 | all_metrics_aggregated = {} 404 | 405 | split = args.split 406 | output_dir = args.output_dir 407 | bleurt_checkpoint = args.bleurt_checkpoint 408 | 409 | sys.argv = sys.argv[:1] 410 | 411 | for prediction_file in prediction_files: 412 | if not os.path.exists(prediction_file): 413 | logger.warning(f" File not found: {prediction_file}") 414 | continue 415 | score_file_base = f"scores-{split}" 416 | score_file = os.path.join(output_dir, score_file_base) 417 | logger.info(f"***** Scoring predictions in {prediction_file} *****") 418 | logger.info(f" Gold data from: {slot_file}") 419 | logger.info(f" Full output in: {score_file}") 420 | 421 | # 2022-05-26: Amazon addition. 422 | scores = score_predictions( 423 | args=args, 424 | predictions_file=prediction_file, 425 | score_file=score_file, 426 | gold_file=slot_file, 427 | angle_file=angle_file, 428 | dataset=angle_base_name, 429 | bleurt_checkpoint=bleurt_checkpoint, 430 | task_name=args.task) 431 | # End of Amazon addition. 
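# collate_scores below averages the per-example metrics by angle (e.g. 'QAHC->P');
# the rounded leaves / steps / intermediate-BLEURT / overall numbers are then
# printed as one tab-separated summary row and saved to <score_file>.metrics.json.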
432 | collated = collate_scores(scores) 433 | all_metrics_aggregated[score_file] = collated['metrics_aggregated'] 434 | logger.info(f" Aggregated metrics:") 435 | for key, val in collated['metrics_aggregated'].items(): 436 | logger.info(f" {key}: {val}") 437 | print(f"\n======================") 438 | colmns_str = '\t'.join([ 439 | # 'leave-P', 'leave-R', 440 | 'leave-F1', 'leaves-Acc', 441 | 'steps-F1', 'steps-Acc', 442 | 'int-BLEURT-F1', 'int-BLEURT-Acc', 443 | #'int-BLEURT_align', 'int-BLEURT-Acc_align', 444 | 'overall-Acc', 445 | # 'overall-Acc_align','int-fraction-align' 446 | ]) 447 | print(f"collated:{collated['metrics_aggregated']}") 448 | aggr_metrics = collated['metrics_aggregated']['QAHC->P'] 449 | metrics_str = '\t'.join([ 450 | # prediction_file, 451 | # str(round(aggr_metrics['proof-leaves']['P']*100.0, 2)), 452 | # str(round(aggr_metrics['proof-leaves']['R']*100.0, 2)), 453 | str(round(aggr_metrics['proof-leaves']['F1']*100.0, 2)), 454 | str(round(aggr_metrics['proof-leaves']['acc']*100.0, 2)), 455 | str(round(aggr_metrics['proof-steps']['F1']*100.0, 2)), 456 | str(round(aggr_metrics['proof-steps']['acc']*100.0, 2)), 457 | str(round(aggr_metrics['proof-intermediates']['BLEURT_F1'] * 100.0, 2)), 458 | str(round(aggr_metrics['proof-intermediates']['BLEURT_acc'] * 100.0, 2)), 459 | #str(round(aggr_metrics['proof-intermediates']['BLEURT_perfect_align'], 2)), 460 | #str(round(aggr_metrics['proof-intermediates']['BLEURT_acc_perfect_align']*100.0,2)), 461 | str(round(aggr_metrics['proof-overall']['acc']*100.0, 2)), 462 | #str(round(aggr_metrics['proof-overall']['acc_perfect_align']*100.0, 2)), 463 | #str(round(aggr_metrics['proof-intermediates']['fraction_perfect_align'] * 100.0, 2)), 464 | ]) 465 | print(f"{colmns_str}") 466 | print(f"{metrics_str}") 467 | save_json(f"{score_file}.metrics.json", all_metrics_aggregated) 468 | # 2022-05-26: Amazon addition. 469 | return all_metrics_aggregated 470 | # End of Amazon addition. 471 | 472 | if __name__ == "__main__": 473 | args = get_args() 474 | main(args) -------------------------------------------------------------------------------- /src/entailment_bank/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/irgr/fd385a77880c92e92167a5766859b2862af6ff26/src/entailment_bank/utils/__init__.py -------------------------------------------------------------------------------- /src/entailment_bank/utils/angle_utils.py: -------------------------------------------------------------------------------- 1 | # Original Copyright 2021 Allen Institute for AI. Licensed under the Apache-2.0 License. 2 | # Modifications Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved. 
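# Utilities for converting slot-structured records into "angle" training
# instances and for parsing model output back into slots (see shortform_angle,
# decompose_slots and SlotDataInstance defined below). Illustrative examples,
# with assumed values that are not taken from the repository data:
#   shortform_angle([['context', 'question', 'mcoptions'], ['explanation', 'answer']])
#       -> 'CMQ->AE'
#   decompose_slots("$proof$ = sent1 & sent2 -> int1: x ; $hypothesis$ = x")
#       -> {'proof': 'sent1 & sent2 -> int1: x', 'hypothesis': 'x'}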
3 | 4 | from typing import Dict 5 | import json 6 | import logging 7 | import pickle 8 | import os 9 | import random 10 | import re 11 | 12 | from transformers import BartTokenizer, T5Tokenizer 13 | 14 | logger = logging.getLogger(__name__) 15 | 16 | DEFAULT_SLOT_FORMAT = { 17 | "slot": "$SLOT$", 18 | "assign": " = ", 19 | "separator": " ; ", 20 | "missing_value": "N/A" 21 | } 22 | 23 | SLOT_SHORTFORMS = {"Q": "question", "C": "context", "A": "answer", "E": "explanation", 24 | "M": "mcoptions", "R": "rationale", "P": "proof", 25 | "O": "original_question", 26 | "H": "hypothesis", 27 | "F": "full_text_proof", 28 | "V": "valid" 29 | } 30 | 31 | 32 | def save_jsonl(file_name, data): 33 | with open(file_name, 'w') as file: 34 | for d in data: 35 | file.write(json.dumps(d)) 36 | file.write("\n") 37 | 38 | 39 | def load_jsonl(file_name): 40 | with open(file_name, 'r') as file: 41 | return [json.loads(line.strip()) for line in file] 42 | 43 | 44 | def save_json(file_name, data): 45 | with open(file_name, 'w') as file: 46 | file.write(json.dumps(data)) 47 | 48 | 49 | ### From https://github.com/huggingface/transformers/blob/master/examples/rag/utils.py 50 | 51 | def encode_line(tokenizer, line, max_length, padding_side, pad_to_max_length=True, return_tensors="pt"): 52 | extra_kw = {"add_prefix_space": True} if isinstance(tokenizer, BartTokenizer) and not line.startswith(" ") else {} 53 | tokenizer.padding_side = padding_side 54 | return tokenizer( 55 | [line], 56 | max_length=max_length, 57 | padding="max_length" if pad_to_max_length else None, 58 | truncation=True, 59 | return_tensors=return_tensors, 60 | add_special_tokens=True, 61 | **extra_kw, 62 | ) 63 | 64 | 65 | def trim_batch( 66 | input_ids, 67 | pad_token_id, 68 | attention_mask=None, 69 | ): 70 | """Remove columns that are populated exclusively by pad_token_id""" 71 | keep_column_mask = input_ids.ne(pad_token_id).any(dim=0) 72 | if attention_mask is None: 73 | return input_ids[:, keep_column_mask] 74 | else: 75 | return (input_ids[:, keep_column_mask], attention_mask[:, keep_column_mask]) 76 | 77 | 78 | def pickle_save(obj, path): 79 | """pickle.dump(obj, path)""" 80 | with open(path, "wb") as f: 81 | return pickle.dump(obj, f) 82 | 83 | 84 | def scramble_order(data, keep_last=None): 85 | keep_last = keep_last or [] 86 | last = [] 87 | other = [] 88 | for d in data: 89 | if d in keep_last: 90 | last.append(d) 91 | else: 92 | other.append(d) 93 | random.shuffle(other) 94 | return other + last 95 | 96 | 97 | def scramble_context_sentences(sentences, random_seed=137): 98 | scrambled_sentence_dict = dict() 99 | scrambled_sentence_list = [] 100 | old_to_new_id_map = dict() 101 | sent_ids_orig = list(sentences.keys()) 102 | random.shuffle(sent_ids_orig) 103 | for idx, sent_id_orig in enumerate(sent_ids_orig, start=1): 104 | new_sent_id = "sent" + str(idx) 105 | old_to_new_id_map[sent_id_orig] = new_sent_id 106 | scrambled_sentence_dict[new_sent_id] = sentences[sent_id_orig] 107 | scrambled_sentence_list.append(sentences[sent_id_orig]) 108 | return scrambled_sentence_dict, scrambled_sentence_list, old_to_new_id_map 109 | 110 | 111 | # Turns [['context', 'question','mcoptions'],['explanation', 'answer]] into 'CMQ->AE' 112 | def shortform_angle(angle_full, sort_angle=True, overrides=None): 113 | if angle_full is None: 114 | return "" 115 | if sort_angle: 116 | return "->".join(["".join(sorted([angle[0].upper() for angle in angles])) for angles in angle_full]) 117 | return "->".join(["".join([angle[0].upper() for angle in angles]) for angles in 
angle_full]) 118 | 119 | def decompose_slots(string, fmt=None): 120 | fmt = fmt or DEFAULT_SLOT_FORMAT 121 | string = string.strip() 122 | no_slot = "PREFIX" 123 | slot_re = re.compile('(?i)'+re.escape(fmt['slot']).replace("SLOT", "(\\w*?)")) 124 | assign_re = re.escape(fmt['assign']).replace('\\ ','\\s*') 125 | separator_re = re.escape(fmt['separator']).replace('\\ ','\\s*') 126 | strip_re = re.compile(f"^({assign_re})?(.*?)({separator_re})?$") 127 | slot_pos = [] 128 | for m in slot_re.finditer(string): 129 | slot_pos.append((m.span(), m.group(1))) 130 | if len(slot_pos) == 0: 131 | return {no_slot: string} 132 | if slot_pos[0][0][0] > 0: 133 | slot_pos = [((0,-1), no_slot)] + slot_pos 134 | res = {} 135 | for idx, (pos, slot_name) in enumerate(slot_pos): 136 | if idx == len(slot_pos) - 1: 137 | value = string[pos[1]+1:] 138 | else: 139 | value = string[pos[1]+1:slot_pos[idx+1][0][0]-1] 140 | m = strip_re.match(value) 141 | if m is not None: 142 | value = m.group(2) 143 | value = value.strip() 144 | if slot_name in res: 145 | value = res[slot_name] + " ~AND~ " + value 146 | res[slot_name] = value 147 | return res 148 | 149 | 150 | def slot_file_to_angles(slot_file, slot_shortforms, 151 | angle_distribution, 152 | split, 153 | full_train_first_angle=False, 154 | meta_fields=None, 155 | id_filter_regex=None, 156 | train_replicas=1, 157 | random_seed=137, **kwparams): 158 | 159 | res = [] 160 | random.seed(random_seed) 161 | if split == 'train': 162 | if full_train_first_angle: 163 | angle_distributions = [angle_distribution[0][0]] + [angle_distribution] * train_replicas 164 | else: 165 | angle_distributions = [angle_distribution] * train_replicas 166 | else: 167 | angle_distributions = angle_distribution[0] 168 | with open(slot_file, 'r') as file: 169 | for line in file: 170 | fields = json.loads(line.strip()) 171 | if id_filter_regex is not None and "id" in fields and not re.match(id_filter_regex, fields['id']): 172 | continue 173 | slot_data = SlotDataInstance(fields) 174 | for ad in angle_distributions: 175 | instance = slot_data.sample_angle_instance(ad, slot_shortforms, **kwparams) 176 | instance.update({"id": fields.get('id', 'NA')}) 177 | res.append(instance) 178 | if meta_fields: 179 | instance['meta'] = {x:fields['meta'][x] for x in meta_fields} 180 | return res 181 | 182 | 183 | ANGLE_SPEC_DEFAULT = {'angle_distribution': None, 184 | 'full_train_first_angle': False, 185 | 'id_filter_regex': None, 186 | 'train_replicas': 1, 187 | 'meta_fields': [], 188 | 'random_seed': 137, 189 | 'keep_last': ['context'], 190 | 'scramble_slots': True, 191 | 'multi_value_sampling': None 192 | } 193 | 194 | 195 | def build_angle_dir(slot_dir, angle_dir, angle_spec, debug_print=2): 196 | if os.path.exists(angle_dir): 197 | raise ValueError(f"Angle data directory {angle_dir} already exist!") 198 | os.makedirs(angle_dir) 199 | angle_spec = {**ANGLE_SPEC_DEFAULT, **angle_spec} 200 | angle_spec_file = os.path.join(angle_dir, "angle_spec.json") 201 | 202 | save_json(angle_spec_file, angle_spec) 203 | made_splits = [] 204 | for split in ['train', 'dev', 'val', 'test']: 205 | slot_file = os.path.join(slot_dir, split+".jsonl") 206 | if os.path.exists(slot_file): 207 | logger.info(f"Creating angle data for {slot_file}") 208 | angle_data = slot_file_to_angles(slot_file, SLOT_SHORTFORMS, split=split, **angle_spec) 209 | if debug_print > 0: 210 | logger.info(f"Sample angle data: {angle_data[:debug_print]}") 211 | angle_file = os.path.join(angle_dir, split+".jsonl") 212 | save_jsonl(angle_file, angle_data) 213 | 
made_splits.append((split, len(angle_data))) 214 | logger.info(f"Created angle data for splits {made_splits}.") 215 | return made_splits 216 | 217 | 218 | def save_tsv_file(file_name, data): 219 | with open(file_name, "w") as f: 220 | for d in data: 221 | out = "\t".join([s.replace('\n', ' ').replace('\t', ' ') for s in d]) 222 | f.write(out + '\n') 223 | 224 | # Use small_dev = 2000 to save a smaller size-2000 dev set 225 | def convert_angle_dir_tsv(angle_dir, tsv_dir, small_dev=False): 226 | if os.path.exists(tsv_dir): 227 | raise ValueError(f"TSV data directory {tsv_dir} already exist!") 228 | os.makedirs(tsv_dir) 229 | counts = {} 230 | for split in ['train', 'dev', 'val', 'test']: 231 | angle_file = os.path.join(angle_dir, split+".jsonl") 232 | if os.path.exists(angle_file): 233 | logger.info(f"Creating tsv data for {angle_file}") 234 | angle_data = load_jsonl(angle_file) 235 | tsv_data = [[x['input'], x['output']] for x in angle_data] 236 | meta_data =[[x['id'], shortform_angle(x['angle'], sort_angle=False)] for x in angle_data] 237 | if small_dev and split in ['dev', 'val']: 238 | num_dev = small_dev if isinstance(small_dev, int) else 1000 239 | counts[split+"-full"] = len(tsv_data) 240 | save_tsv_file(os.path.join(tsv_dir, split+"-full.tsv"), tsv_data) 241 | save_tsv_file(os.path.join(tsv_dir, "meta-"+split+"-full.tsv"), meta_data) 242 | tsv_data = tsv_data[:num_dev] 243 | meta_data = meta_data[:num_dev] 244 | counts[split] = len(tsv_data) 245 | save_tsv_file(os.path.join(tsv_dir, split + ".tsv"), tsv_data) 246 | save_tsv_file(os.path.join(tsv_dir, "meta-" + split + ".tsv"), meta_data) 247 | save_json(os.path.join(tsv_dir, "counts.json"), counts) 248 | logger.info(f"Created angle data for splits {counts}.") 249 | return counts 250 | 251 | 252 | class SlotDataInstance(): 253 | 254 | def __init__(self, fields: Dict): 255 | self.fields = fields 256 | self.slot_value_sampling = {} 257 | 258 | def get_slot_value(self, slot, default=None, multi_value_sampling=None): 259 | res = self.fields.get(slot, default) 260 | if isinstance(res, list): 261 | if multi_value_sampling is not None and slot in multi_value_sampling: 262 | fn = multi_value_sampling[slot] 263 | if fn == "random": 264 | res = random.choice(res) 265 | elif "random-with" in fn: 266 | other_slots = fn.split("-")[2:] 267 | value_index = -1 268 | for other_slot in other_slots: 269 | if other_slot in self.slot_value_sampling: 270 | value_index = self.slot_value_sampling[other_slot] 271 | if value_index == -1: 272 | value_index = random.choice(range(len(res))) 273 | self.slot_value_sampling[slot] = value_index 274 | res = res[value_index] 275 | else: 276 | raise ValueError(f"Unknown multi_value_sampling function {fn}") 277 | else: 278 | res = res[0] 279 | return res 280 | 281 | def convert_shortform_angle(self, angle, slot_shortforms): 282 | if isinstance(angle, str): 283 | arrowpos = angle.index("->") 284 | lhs = angle[:arrowpos].strip() 285 | rhs = angle[arrowpos + 2:].strip() 286 | lhs = [slot_shortforms[c] for c in lhs] 287 | rhs = [slot_shortforms[c] for c in rhs] 288 | else: 289 | lhs = angle[0] 290 | rhs = angle[1] 291 | missing = [slot for slot in lhs + rhs if slot not in self.fields] 292 | return ((lhs, rhs), missing) 293 | 294 | def make_angle_instance(self, angle, fmt=None, multi_value_sampling=None): 295 | fmt = fmt or DEFAULT_SLOT_FORMAT 296 | lhs = [] 297 | rhs = [] 298 | for slot in angle[1]: 299 | slot_name = fmt['slot'].replace("SLOT", slot) 300 | slot_value = self.get_slot_value(slot, fmt['missing_value'], 
multi_value_sampling) 301 | lhs.append(slot_name) 302 | rhs.append(f"{slot_name}{fmt['assign']}{slot_value}") 303 | for slot in angle[0]: 304 | slot_name = fmt['slot'].replace("SLOT", slot) 305 | slot_value = self.get_slot_value(slot, fmt['missing_value'], multi_value_sampling) 306 | lhs.append(f"{slot_name}{fmt['assign']}{slot_value}") 307 | return {"input": fmt['separator'].join(lhs), 308 | "output": fmt['separator'].join(rhs), 309 | "angle": angle} 310 | 311 | def sample_angle_instance(self, angle_distribution, slot_shortforms, 312 | scramble_slots=True, 313 | keep_last=None, 314 | missing_retries=100, 315 | fmt=None, 316 | multi_value_sampling=None): 317 | keep_last = keep_last or ["context"] 318 | fmt = fmt or DEFAULT_SLOT_FORMAT 319 | if isinstance(angle_distribution, str): 320 | angle, missing = self.convert_shortform_angle(angle_distribution, slot_shortforms) 321 | else: 322 | angle_distribution = tuple(x.copy() for x in angle_distribution) 323 | retries = missing_retries 324 | missing = [1] 325 | while retries >= 0 and len(missing) > 0 and len(angle_distribution[0]) > 0: 326 | retries -= 1 327 | angle_shortform = random.choices(*angle_distribution)[0] 328 | angle, missing = self.convert_shortform_angle(angle_shortform, slot_shortforms) 329 | if len(missing) > 0: 330 | ind = angle_distribution[0].index(angle_shortform) 331 | angle_distribution[0].pop(ind) 332 | angle_distribution[1].pop(ind) 333 | if scramble_slots: 334 | angle = [scramble_order(a, keep_last) for a in angle] 335 | res = self.make_angle_instance(angle, fmt, multi_value_sampling) 336 | return res 337 | 338 | 339 | def formatting(a_set): 340 | return '=CONCATENATE("' + ('" ; CHAR(10); "').join(a_set) + '"' + ')' if len(a_set) > 0 else "" 341 | 342 | 343 | def get_selected_str(data_dict, selected_keys, format=False): 344 | selected=[] 345 | for key in selected_keys: 346 | selected.append(f"{key}: {data_dict[key]['text']} ") 347 | if format: 348 | selected_str = formatting(selected) 349 | else: 350 | selected_str = ''.join(selected) 351 | return selected_str 352 | 353 | 354 | def get_selected_keys(data_dict, selected_keys, format=False): 355 | selected=[] 356 | for key in selected_keys: 357 | selected.append(f"{key}: {data_dict[key]} ") 358 | if format: 359 | selected_str = formatting(selected) 360 | else: 361 | selected_str = ''.join(selected) 362 | return selected_str -------------------------------------------------------------------------------- /src/entailment_bank/utils/entail_trees_utils.py: -------------------------------------------------------------------------------- 1 | # Original Copyright 2021 Allen Institute for AI. Licensed under the Apache-2.0 License. 2 | # Modifications Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved. 3 | 4 | from collections import defaultdict 5 | from copy import deepcopy 6 | from nltk.stem import PorterStemmer 7 | from nltk.tokenize import word_tokenize 8 | import random 9 | import re 10 | 11 | # 2022-05-26: Amazon addition. 12 | from entailment_bank.utils.proof_utils import parse_lisp, decouple_proof_struct_ints, polish_notation_to_proof 13 | # End of Amazon addition. 
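# Helpers in this module operate on EntailmentBank proof structures:
# get_intermediate_dependencies and get_core_proofs walk the lisp-style proof,
# make_inference_steps expands a full entailment tree into single-step training
# instances, and get_entailment_steps_from_polish_proof converts a Polish-notation
# proof into "sent1 & sent2 -> int1: <conclusion>" style steps.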
14 | 15 | # Count phrase appearing in a reference string, making sure word boundaries are respected 16 | def count_phrase_matches(phrase, reference): 17 | regex = "(\\b|(?!\\w))" + re.escape(phrase) + "((?": 63 | new_dependency = proof[2] 64 | for k, v in dependencies.items(): 65 | dependencies[k] = v + [new_dependency] 66 | if new_dependency not in dependencies: 67 | dependencies[new_dependency] = [] 68 | return get_parents_recursive(proof[0], dependencies) 69 | elif len(proof) == 0: 70 | return dependencies 71 | elif len(proof) == 1: 72 | return get_parents_recursive(proof[0], dependencies) 73 | else: 74 | all_dependencies = [] 75 | for sub_proof in proof: 76 | dep1 = deepcopy(dependencies) 77 | get_parents_recursive(sub_proof, dep1) 78 | all_dependencies.append(dep1) 79 | dd = defaultdict(list) 80 | for d in all_dependencies: 81 | for key, value in d.items(): 82 | dd[key] += value 83 | return dd 84 | 85 | 86 | def get_intermediate_dependencies(proof): 87 | list_proof = parse_lisp(proof) 88 | dependencies = {} 89 | res = get_parents_recursive(list_proof, dependencies) 90 | return res 91 | 92 | 93 | def get_stripped_recursive(proof, stripped): 94 | if isinstance(proof, str): 95 | return proof 96 | if len(proof) == 3 and proof[1] == "->": 97 | stripped[proof[2]] = [get_stripped_recursive(x, stripped) for x in proof[0]] 98 | return proof[2] 99 | elif len(proof) == 1: 100 | return get_stripped_recursive(proof[0], stripped) 101 | else: 102 | raise ValueError(f"Nonsense found: {proof}") 103 | 104 | 105 | def get_core_proofs(proof): 106 | list_proof = parse_lisp(proof) 107 | stripped = {} 108 | get_stripped_recursive(list_proof, stripped) 109 | return stripped 110 | 111 | 112 | def remove_distractors(qdata, num_removed_distractors=0): 113 | if num_removed_distractors < 1: 114 | return qdata 115 | else: 116 | new_q = deepcopy(qdata) 117 | distractors = new_q['meta']['distractors'] 118 | sentences_removed = list(reversed(distractors))[:num_removed_distractors] 119 | sentences_remaining = {k: v for k, v in new_q['meta']['triples'].items() if k not in sentences_removed} 120 | new_distractors = [k for k in new_q['meta']['distractors'] if k not in sentences_removed] 121 | sentence_map = {k: f"sent{i + 1}" for i, k in enumerate(sentences_remaining.keys())} 122 | new_q['meta']['triples'] = sentences_remaining 123 | new_q['meta']['distractors'] = new_distractors 124 | new_q = remap_sentences(new_q, sentence_map) 125 | return new_q 126 | 127 | 128 | # Break down an entailment tree data instance into one-step inference steps 129 | def make_inference_steps(qdata, rescramble_sentences=False, num_removed_distractors=0): 130 | proof = qdata['meta']['lisp_proof'] 131 | core_proofs = list(get_core_proofs(proof).items()) 132 | random.shuffle(core_proofs) 133 | sentences = qdata['meta']['triples'].copy() 134 | intermediates = qdata['meta']['intermediate_conclusions'] 135 | q_id = qdata['id'] 136 | hypothesis_id = qdata['meta']['hypothesis_id'] 137 | res = [] 138 | while len(core_proofs) > 0: 139 | selected = None 140 | for proof in core_proofs: 141 | selected = proof 142 | for dep in proof[1]: 143 | if 'int' in dep: 144 | selected = None 145 | if selected is not None: 146 | break 147 | if selected is None: 148 | raise ValueError(f"No resolved proofs in {core_proofs}") 149 | new_res = selected[0] 150 | if new_res == hypothesis_id: 151 | new_res_text = "hypothesis" 152 | assert len(core_proofs) == 1 153 | else: 154 | new_res_text = "int1: " + intermediates[new_res] 155 | 156 | new_proof = selected[1] 157 | 
new_proof_text = " & ".join(new_proof) + " -> " + new_res_text 158 | new_context = " ".join([f"{k}: {v}" for k, v in sentences.items()]) 159 | new_q = deepcopy(qdata) 160 | new_q['id'] = f"{q_id}-add{len(res)}" 161 | new_q['meta'] = {'triples': sentences.copy(), 162 | 'distractors': new_q['meta'].get('distractors', [])} 163 | new_q['proof'] = new_proof_text 164 | new_q['meta']['hypothesis_id'] = "int1" 165 | new_q['depth_of_proof'] = 1 166 | new_q['length_of_proof'] = 1 167 | new_q['context'] = new_context 168 | if rescramble_sentences: 169 | new_q = scramble_sentences_in_entail_tree_q(new_q) 170 | if num_removed_distractors > 0: 171 | new_q = remove_distractors(new_q, num_removed_distractors) 172 | 173 | res.append(new_q) 174 | new_sentence = "sent" + str(len(sentences) + 1) 175 | sentences[new_sentence] = intermediates[selected[0]] 176 | new_core_proofs = [] 177 | for proof in core_proofs: 178 | if proof[0] == new_res: 179 | continue 180 | new_parents = [] 181 | for parent in proof[1]: 182 | new_parents.append(new_sentence if parent == new_res else parent) 183 | new_core_proofs.append((proof[0], new_parents)) 184 | core_proofs = new_core_proofs 185 | return res 186 | 187 | 188 | def normalize_sentence(sent): 189 | return sent.replace(" ", " ").replace(".", "").replace('\n', '').replace("( ", "").replace(" )", "").lower().strip() 190 | 191 | 192 | def get_entailment_steps_from_polish_proof(polish_proof): 193 | print(f"POLISH_PROOF:{polish_proof}") 194 | pn_without_ints, int_dict = decouple_proof_struct_ints(polish_proof) 195 | print(f"POLISH_PROOF without INTS:{pn_without_ints}") 196 | try: 197 | recursive_proof = polish_notation_to_proof(pn_without_ints)[0] 198 | except: 199 | return [] 200 | print(f"recursive_proof:{recursive_proof}") 201 | return get_entailment_steps_from_recursive_proof(recursive_proof, int_dict) 202 | 203 | 204 | def append_list(list_obj, to_be_added): 205 | if isinstance(to_be_added, list): 206 | for add_item in to_be_added: 207 | if add_item != '->': 208 | print(f"\t**********adding {add_item} to lhs") 209 | list_obj += append_list(list_obj, add_item) 210 | else: 211 | if to_be_added != '->': 212 | print(f"\t**********adding {to_be_added} to lhs") 213 | list_obj.append(to_be_added) 214 | return list_obj 215 | 216 | 217 | def get_entailment_steps_from_recursive_proof(recursive_proof, int_dict): 218 | entailment_steps = [] 219 | print(f"======Calling recursion: recursive_proof:{recursive_proof}") 220 | 221 | lhs = recursive_proof[0] 222 | rhs = recursive_proof[2] 223 | rhs_str = int_dict.get(rhs, "") 224 | print(f"======lhs:{lhs}") 225 | print(f"======rhs:{rhs}") 226 | lhs_ids = [] 227 | if isinstance(lhs, str): 228 | append_list(lhs_ids, lhs) 229 | else: 230 | for l in lhs: 231 | print(f"\tl:{l}") 232 | if isinstance(l, str): 233 | append_list(lhs_ids, l) 234 | else: 235 | if '->' in l: 236 | entailment_steps += get_entailment_steps_from_recursive_proof(l, int_dict) 237 | append_list(lhs_ids, l[2]) 238 | else: 239 | print(f"\t^^^^lhs:{lhs}") 240 | print(f"\t^^^^rhs:{rhs}") 241 | print(f"\t^^^^l{l}") 242 | print(f"\t^^^^^lhs_ids{lhs_ids}") 243 | for l_part in l: 244 | if isinstance(l_part, list): 245 | if '->' in l_part: 246 | entailment_steps += get_entailment_steps_from_recursive_proof(l_part, int_dict) 247 | append_list(lhs_ids, l_part[2]) 248 | else: 249 | append_list(lhs_ids, l_part) 250 | else: 251 | append_list(lhs_ids, l_part) # for cases like ['sent20', 'sent8'] in [['sent14', ['sent20', 'sent8']], '->', 'int1'] 252 | print(f"\tlhs_ids{lhs_ids}") 253 | 254 | 
print(f"\tlhs_ids:{lhs_ids}") 255 | print(f"\trhs:{rhs}") 256 | lhs_str = ' & '.join(lhs_ids) 257 | print(f"++++Adding step: {lhs_str} -> {rhs}: {rhs_str}") 258 | entailment_steps.append(f"{lhs_str} -> {rhs}: {rhs_str}") 259 | return entailment_steps -------------------------------------------------------------------------------- /src/entailment_bank/utils/eval_utils.py: -------------------------------------------------------------------------------- 1 | # Original Copyright 2021 Allen Institute for AI. Licensed under the Apache-2.0 License. 2 | # Modifications Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved. 3 | 4 | import collections 5 | import copy 6 | import json 7 | import re 8 | import rouge 9 | import string 10 | 11 | # 2022-05-26: Amazon addition. 12 | from entailment_bank.utils.angle_utils import decompose_slots 13 | from entailment_bank.utils.proof_utils import parse_entailment_step_proof, \ 14 | align_conclusions_across_proofs, rewrite_aligned_proof, score_sentence_overlaps, ruletaker_inferences_scores, \ 15 | parse_entailment_step_proof_remove_ids, rewrite_aligned_proof_noids 16 | # End of Amazon addition. 17 | 18 | INCLUDE_NLG_EVAL = False 19 | 20 | if INCLUDE_NLG_EVAL: 21 | from nlgeval import NLGEval 22 | # Initialize: 23 | nlgeval = NLGEval(no_skipthoughts=True, no_glove=True) 24 | 25 | def load_slot_data_by_id(slot_file): 26 | res = {} 27 | with open(slot_file) as file: 28 | for line in file: 29 | data = json.loads(line.strip()) 30 | res[data['id']] = data 31 | return res 32 | 33 | USE_GOOGLE_ROUGE_CODE = False 34 | 35 | if USE_GOOGLE_ROUGE_CODE: 36 | # The dependencies in https://github.com/google-research/google-research/blob/master/rouge/requirements.txt 37 | from rouge_score import rouge_scorer 38 | rouge_scoring_fun = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True) 39 | 40 | def rouge_google_metric_max_over_ground_truths(prediction, ground_truths): 41 | if len(ground_truths) == 0: 42 | return 0 43 | scores_for_ground_truths = [] 44 | for ground_truth in ground_truths: 45 | score = rouge_scoring_fun.score(prediction, ground_truth) 46 | scores_for_ground_truths.append(score['rougeL'].fmeasure) 47 | return max(scores_for_ground_truths) 48 | 49 | 50 | def score_string_similarity(str1, str2): 51 | if str1 == str2: 52 | return 3.0 # Better than perfect token match 53 | str1 = fix_t5_unk_characters(str1) 54 | str2 = fix_t5_unk_characters(str2) 55 | if str1 == str2: 56 | return 2.0 57 | str1 = str1.lower() 58 | str2 = str2.lower() 59 | if str1 == str2: 60 | return 1.5 61 | if " " in str1 or " " in str2: 62 | str1_split = str1.split(" ") 63 | str2_split = str2.split(" ") 64 | overlap = list(set(str1_split) & set(str2_split)) 65 | return len(overlap) / max(len(str1_split), len(str2_split)) 66 | else: 67 | return 0.0 68 | 69 | 70 | def replace_punctuation(str): 71 | return str.replace("\"", "").replace("'", "") 72 | 73 | # remove characters tokenized as unknown (\u2047) character 74 | T5_GOOD_CHARS=[32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 75 | 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 61, 62, 63, 64, 65, 66, 76 | 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 77 | 84, 85, 86, 87, 88, 89, 90, 91, 93, 95, 97, 98, 99, 100, 101, 102, 78 | 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 79 | 117, 118, 119, 120, 121, 122, 124, 163, 171, 173, 174, 176, 187, 201, 80 | 206, 220, 223, 224, 225, 226, 228, 231, 232, 233, 234, 238, 243, 244, 81 | 246, 249, 251, 252, 259, 351, 355, 537, 539, 
1072, 1074, 1076, 1077, 82 | 1080, 1082, 1083, 1084, 1085, 1086, 1088, 1089, 1090, 1091, 8211, 83 | 8212, 8216, 8217, 8220, 8221, 8222, 8226, 8242, 8364, 9601] 84 | T5_BAD_REGEX = re.compile("[^"+re.escape(".".join([chr(x) for x in T5_GOOD_CHARS]))+"]") 85 | 86 | 87 | def fix_t5_unk_characters(str): 88 | return re.sub(" {2,}", " ", re.sub(T5_BAD_REGEX, " ", str)) 89 | 90 | # Rouge evaluator copied from UnifiedQA 91 | rouge_l_evaluator = rouge.Rouge( 92 | metrics=["rouge-l"], 93 | max_n=4, 94 | limit_length=True, 95 | length_limit=100, 96 | length_limit_type="words", 97 | apply_avg=True, 98 | apply_best=True, 99 | alpha=0.5, 100 | weight_factor=1.2, 101 | stemming=True, 102 | ) 103 | 104 | def rouge_l(p, g): 105 | return rouge_l_evaluator.get_scores(p, g) 106 | 107 | def rouge_metric_max_over_ground_truths(metric_fn, prediction, ground_truths): 108 | scores_for_ground_truths = [] 109 | for ground_truth in ground_truths: 110 | score = metric_fn(prediction, [ground_truth]) 111 | scores_for_ground_truths.append(score) 112 | if isinstance(score, dict) and "rouge-l" in score: 113 | max_score = copy.deepcopy(score) 114 | max_score["rouge-l"]["f"] = round( 115 | max([score["rouge-l"]["f"] for score in scores_for_ground_truths]), 2 116 | ) 117 | max_score["rouge-l"]["p"] = round( 118 | max([score["rouge-l"]["p"] for score in scores_for_ground_truths]), 2 119 | ) 120 | max_score["rouge-l"]["r"] = round( 121 | max([score["rouge-l"]["r"] for score in scores_for_ground_truths]), 2 122 | ) 123 | return max_score 124 | else: 125 | return round(max(scores_for_ground_truths), 2) 126 | 127 | def nlg_string_similarities(prediction, gold_strings, normalize=True): 128 | if gold_strings is None: 129 | res = {"missing_gold": 1} 130 | if prediction is None: 131 | res['missing'] = 1 132 | return res 133 | if prediction is None: 134 | return {"missing": 1} 135 | if normalize: 136 | gold_strings = [fix_t5_unk_characters(x.lower()) for x in gold_strings] 137 | prediction = fix_t5_unk_characters(prediction.lower()) 138 | # gold_strings = gold_strings[:1] 139 | res = {} 140 | if INCLUDE_NLG_EVAL: 141 | res = nlgeval.compute_individual_metrics(gold_strings, prediction) 142 | if 'CIDEr' in res: 143 | del res['CIDEr'] 144 | rouge_l_score = rouge_metric_max_over_ground_truths(rouge_l, prediction, gold_strings) 145 | res['ROUGE_L_F'] = rouge_l_score["rouge-l"]["f"] 146 | if USE_GOOGLE_ROUGE_CODE: 147 | res['ROUGE_L_G'] = rouge_google_metric_max_over_ground_truths(prediction, gold_strings) 148 | res['pred'] = prediction 149 | res['gold'] = gold_strings 150 | if not gold_strings[0] and not prediction: 151 | res['ROUGE_L_F'] = 1.0 152 | return res 153 | 154 | 155 | def nlg_string_similarities_intermediates_with_F1(prediction_to_aligned_gold: dict, 156 | id_to_int_gold: dict(), 157 | id_to_int_pred: dict(), 158 | prediction_to_perfect_match: dict, 159 | normalize=True, 160 | bleurt_scorer=None, 161 | bleurt_threshold=0.28 #for original BLEURT 162 | ): 163 | num_perfect_aligns = 0 164 | 165 | sum_rouge_l_score = 0.0 166 | sum_perfect_align_rouge_l_score = 0.0 167 | sum_bleurt_score = 0.0 168 | sum_perfect_align_bleurt_score = 0.0 169 | num_bleurt_correct = 0.0 170 | num_perfect_align_bleurt_correct = 0.0 171 | 172 | preds = [] 173 | golds = [] 174 | res = {} 175 | # print(f"prediction_to_aligned_gold:{prediction_to_aligned_gold}") 176 | pred_precise = set() 177 | gold_covered = set() 178 | for prediction, gold in prediction_to_aligned_gold.items(): 179 | preds.append(prediction) 180 | golds.append(gold) 181 | gold_strings = 
[gold] 182 | if normalize: 183 | gold_strings = [fix_t5_unk_characters(x.lower()) for x in gold_strings] 184 | prediction_norm = fix_t5_unk_characters(prediction.lower()) 185 | #res = nlgeval.compute_individual_metrics(gold_strings, prediction) 186 | #if 'CIDEr' in res: 187 | # del res['CIDEr'] 188 | rouge_l_score = rouge_metric_max_over_ground_truths(rouge_l, prediction_norm, gold_strings) 189 | if bleurt_scorer: 190 | bleurt_score = bleurt_scorer.score(references=gold_strings, candidates=[prediction_norm], batch_size=1)[0] 191 | else: 192 | bleurt_score = -1 193 | # bleurt_score = max(0.0, min(1.0, unnorm_bleurt_score)) 194 | # bleurt_score = [0.0] 195 | sum_rouge_l_score += rouge_l_score["rouge-l"]["f"] 196 | sum_bleurt_score += bleurt_score 197 | # if gold == "": 198 | # print(f"@@@@@@@@@@@@@@@@@@@@@@@PREDICTION:{prediction}\tGOLD:{gold}\tBLEURT:{bleurt_score}") 199 | # print(f"bleurt_score:{unnorm_bleurt_score}\t{bleurt_score}") 200 | 201 | if bleurt_score >= bleurt_threshold: 202 | if gold != "": 203 | num_bleurt_correct += 1 204 | pred_precise.add(prediction) 205 | gold_covered.add(gold) 206 | 207 | # print(f"prediction_to_perfect_match:{prediction_to_perfect_match}") 208 | # print(f"prediction:{prediction}") 209 | if prediction_to_perfect_match.get(prediction, False): 210 | sum_perfect_align_rouge_l_score += rouge_l_score["rouge-l"]["f"] 211 | sum_perfect_align_bleurt_score += bleurt_score 212 | num_perfect_aligns += 1 213 | if bleurt_score >= bleurt_threshold: 214 | num_perfect_align_bleurt_correct += 1 215 | 216 | bleurt_P = len(pred_precise)/max(1, len(id_to_int_pred)) 217 | bleurt_R = len(gold_covered)/max(1, len(id_to_int_gold)) 218 | if (bleurt_P + bleurt_R) == 0.0: 219 | bleurt_F1 = 0.0 220 | else: 221 | bleurt_F1 = (2 * bleurt_P * bleurt_R) / (bleurt_P + bleurt_R) 222 | # print(f"@@@@@@@@@@@@@@@@@@@@@@@pred_precise:{len(pred_precise)}\tpred:{len(id_to_int_pred)}\tgold_covered:{len(gold_covered)}\tgold:{len(id_to_int_gold)}") 223 | # print(f"@@@@@@@@@@@@@@@@@@@@@@@bleurt_P:{bleurt_P}\tbleurt_R:{bleurt_R}\tbleurt_F1:{bleurt_F1}") 224 | res['ROUGE_L_F'] = sum_rouge_l_score / max(1, len(prediction_to_aligned_gold.keys())) 225 | res['ROUGE_L_F_perfect_align'] = sum_perfect_align_rouge_l_score / max(1, num_perfect_aligns) 226 | res['BLEURT'] = sum_bleurt_score / max(1, len(prediction_to_aligned_gold.keys())) 227 | res['BLEURT_perfect_align'] = sum_perfect_align_bleurt_score / max(1, num_perfect_aligns) 228 | res['BLEURT_P'] = bleurt_P 229 | res['BLEURT_R'] = bleurt_R 230 | res['BLEURT_F1'] = bleurt_F1 231 | # res['BLEURT_acc'] = int(num_bleurt_correct == len(prediction_to_aligned_gold.keys())) 232 | res['BLEURT_acc'] = int(bleurt_F1==1) 233 | res['BLEURT_acc_perfect_align'] = int(num_perfect_align_bleurt_correct == num_perfect_aligns) 234 | res['fraction_perfect_align'] = num_perfect_aligns/max(1, len(prediction_to_aligned_gold.keys())) 235 | res['pred'] = preds 236 | res['gold'] = golds 237 | # print(f"res:{res}") 238 | return res 239 | 240 | 241 | def nlg_string_similarities_intermediates(prediction_to_aligned_gold: dict, 242 | prediction_to_perfect_match: dict, 243 | normalize=True, 244 | bleurt_scorer=None, 245 | bleurt_threshold=0.28 #for original BLEURT 246 | ): 247 | num_perfect_aligns = 0 248 | 249 | sum_rouge_l_score = 0.0 250 | sum_perfect_align_rouge_l_score = 0.0 251 | sum_bleurt_score = 0.0 252 | sum_perfect_align_bleurt_score = 0.0 253 | num_bleurt_correct = 0.0 254 | num_perfect_align_bleurt_correct = 0.0 255 | 256 | preds = [] 257 | golds = [] 258 | res = {} 259 | 
# print(f"prediction_to_aligned_gold:{prediction_to_aligned_gold}") 260 | for prediction, gold in prediction_to_aligned_gold.items(): 261 | preds.append(prediction) 262 | golds.append(gold) 263 | gold_strings = [gold] 264 | if normalize: 265 | gold_strings = [fix_t5_unk_characters(x.lower()) for x in gold_strings] 266 | prediction = fix_t5_unk_characters(prediction.lower()) 267 | #res = nlgeval.compute_individual_metrics(gold_strings, prediction) 268 | #if 'CIDEr' in res: 269 | # del res['CIDEr'] 270 | rouge_l_score = rouge_metric_max_over_ground_truths(rouge_l, prediction, gold_strings) 271 | if bleurt_scorer: 272 | bleurt_score = bleurt_scorer.score(gold_strings, [prediction], batch_size=1)[0] 273 | else: 274 | bleurt_score = -1 275 | # bleurt_score = max(0.0, min(1.0, unnorm_bleurt_score)) 276 | # bleurt_score = [0.0] 277 | sum_rouge_l_score += rouge_l_score["rouge-l"]["f"] 278 | sum_bleurt_score += bleurt_score 279 | # print(f"bleurt_score:{unnorm_bleurt_score}\t{bleurt_score}") 280 | 281 | if bleurt_score >= bleurt_threshold: 282 | num_bleurt_correct += 1 283 | # print(f"prediction_to_perfect_match:{prediction_to_perfect_match}") 284 | # print(f"prediction:{prediction}") 285 | if prediction_to_perfect_match[prediction]: 286 | sum_perfect_align_rouge_l_score += rouge_l_score["rouge-l"]["f"] 287 | sum_perfect_align_bleurt_score += bleurt_score 288 | num_perfect_aligns += 1 289 | if bleurt_score >= bleurt_threshold: 290 | num_perfect_align_bleurt_correct += 1 291 | 292 | res['ROUGE_L_F'] = sum_rouge_l_score / max(1, len(prediction_to_aligned_gold.keys())) 293 | res['ROUGE_L_F_perfect_align'] = sum_perfect_align_rouge_l_score / max(1, num_perfect_aligns) 294 | res['BLEURT'] = sum_bleurt_score / max(1, len(prediction_to_aligned_gold.keys())) 295 | res['BLEURT_perfect_align'] = sum_perfect_align_bleurt_score / max(1, num_perfect_aligns) 296 | res['BLEURT_acc'] = int(num_bleurt_correct == len(prediction_to_aligned_gold.keys())) 297 | res['BLEURT_acc_perfect_align'] = int(num_perfect_align_bleurt_correct == num_perfect_aligns) 298 | res['fraction_perfect_align'] = num_perfect_aligns/max(1, len(prediction_to_aligned_gold.keys())) 299 | res['pred'] = preds 300 | res['gold'] = golds 301 | # print(f"res:{res}") 302 | return res 303 | 304 | 305 | def squad_normalize_answer(s): 306 | """Lower text and remove punctuation, articles and extra whitespace.""" 307 | def remove_articles(text): 308 | regex = re.compile(r'\b(a|an|the)\b', re.UNICODE) 309 | return re.sub(regex, ' ', text) 310 | def white_space_fix(text): 311 | return ' '.join(text.split()) 312 | def remove_punc(text): 313 | exclude = set(string.punctuation) 314 | return ''.join(ch for ch in text if ch not in exclude) 315 | def lower(text): 316 | return text.lower() 317 | return fix_t5_unk_characters(white_space_fix(remove_articles(remove_punc(lower(s))))) 318 | 319 | def get_tokens(s): 320 | if not s: return [] 321 | return squad_normalize_answer(s).split() 322 | 323 | def compute_exact(a_gold, a_pred): 324 | return int(squad_normalize_answer(a_gold) == squad_normalize_answer(a_pred)) 325 | 326 | def compute_f1(a_gold, a_pred): 327 | gold_toks = get_tokens(a_gold) 328 | pred_toks = get_tokens(a_pred) 329 | common = collections.Counter(gold_toks) & collections.Counter(pred_toks) 330 | num_same = sum(common.values()) 331 | if len(gold_toks) == 0 or len(pred_toks) == 0: 332 | # If either is no-answer, then F1 is 1 if they agree, 0 otherwise 333 | return int(gold_toks == pred_toks) 334 | if num_same == 0: 335 | return 0 336 | precision = 1.0 * 
num_same / len(pred_toks) 337 | recall = 1.0 * num_same / len(gold_toks) 338 | f1 = (2 * precision * recall) / (precision + recall) 339 | return f1 340 | 341 | 342 | def split_mcoptions(mcoptions): 343 | first_option = ord(mcoptions.strip()[1]) 344 | labels = "".join([chr(x) for x in range(first_option, first_option+10)]) 345 | choices = re.split("\\s*\\(["+labels+"]\\)\\s*", mcoptions)[1:] 346 | return (choices, chr(first_option)) 347 | 348 | 349 | def mcq_answer_accuracy(slots, gold): 350 | answer = slots.get('answer') 351 | if answer is None: 352 | return {"acc": 0, "missing": 1} 353 | mcoptions, first_label = split_mcoptions(gold['mcoptions']) 354 | best = -1 355 | selected = None 356 | selected_key = None 357 | for idx, option in enumerate(mcoptions): 358 | score = score_string_similarity(answer, option) 359 | if score > best: 360 | best = score 361 | selected = option 362 | selected_key = chr(ord(first_label) + idx) 363 | acc = 1 if selected == gold['answer'] else 0 364 | res = {"acc": acc, "answerkey": selected_key, "align_score": best} 365 | return res 366 | 367 | 368 | def bool_accuracy(slots, gold): 369 | pred_answer = str(slots.get('answer')) 370 | gold_answer = str(gold['answer']) 371 | 372 | res = {"acc": float(pred_answer==gold_answer), 373 | "ROUGE_L_F": float(pred_answer==gold_answer)} 374 | #print(f"---- {type(pred_answer)}\t{type(gold_answer)}\t{res}") 375 | return res 376 | 377 | 378 | def squad_em_f1(answer, gold_answers): 379 | best_em = -1 380 | best_f1 = -1 381 | best_match = "" 382 | for gold in gold_answers: 383 | if gold.lower() == "noanswer": 384 | if answer.strip().lower() == "noanswer": 385 | em = f1 = 1.0 386 | else: 387 | em = f1 = 0.0 388 | else: 389 | em = compute_exact(gold, answer) 390 | f1 = compute_f1(gold, answer) 391 | if em > best_em: 392 | best_em = em 393 | best_match = gold 394 | if f1 > best_f1: 395 | best_f1 = f1 396 | best_match = gold 397 | res = {"EM": best_em, "F1": best_f1, "matched_gold": best_match} 398 | return res 399 | 400 | 401 | def rc_answer_accuracy(slots, gold, slot_name='answer'): 402 | answer = str(slots.get(slot_name)) 403 | if answer is None: 404 | return {"EM": 0, "F1": 0, "missing": 1} 405 | gold_answers = str(gold[slot_name]) 406 | if isinstance(gold_answers, str): 407 | gold_answers = [gold_answers] 408 | return squad_em_f1(answer, gold_answers) 409 | 410 | 411 | def extact_string_match_accuracy(slots, gold, slot_name='answer'): 412 | answer = str(slots.get(slot_name)) 413 | if answer is None: 414 | return {"EM": 0, "F1": 0, "missing": 1} 415 | gold_answers = str(gold[slot_name]) 416 | if isinstance(gold_answers, str): 417 | gold_answers = [gold_answers] 418 | return exact_match(answer, gold_answers) 419 | 420 | 421 | def exact_match(answer, gold_answers): 422 | best_em = -1 423 | best_match = "" 424 | for gold in gold_answers: 425 | em = 1.0 if gold.lower() == answer.lower() else 0.0 426 | if em > best_em: 427 | best_em = em 428 | best_match = gold 429 | 430 | res = {"EM": best_em, "matched_gold": best_match} 431 | return res 432 | 433 | 434 | def basic_split_mcoptions(mcoptions): 435 | splits = re.split("\\s*\\(\\w\\)\\s*", mcoptions) 436 | res = [s for s in splits if s != ''] 437 | return res 438 | 439 | 440 | # Very basic matching algorithm to gold MC options, not very meaningful 441 | # (finds best matching gold option, only best score kept for each matched olgd) 442 | def rough_mcoptions_f1(pred_mcoptions, gold_mcoptions): 443 | if pred_mcoptions is None: 444 | return {"F1": 0, "missing": 1} 445 | gold_split = 
basic_split_mcoptions(gold_mcoptions) 446 | pred_split = basic_split_mcoptions(pred_mcoptions) 447 | scores = {} 448 | for pred in pred_split: 449 | score = squad_em_f1(pred, gold_split) 450 | matched = score['matched_gold'] 451 | f1 = score['F1'] 452 | old_f1 = scores.get(matched, 0) 453 | if f1 > old_f1: 454 | scores[matched] = f1 455 | tot = sum(scores.values()) 456 | return {"F1": tot / max(len(pred_split), 2)} 457 | 458 | 459 | SCORING_SPECS = { 460 | "ruletaker_inferences": ruletaker_inferences_scores 461 | } 462 | 463 | 464 | def collate_scores(predictions): 465 | res_by_angle = {} 466 | metrics_by_angle = {} 467 | averaged_metrics = ['acc', 'EM', "F1", 'Bleu_1', 'Bleu_2', 'Bleu_3', 'Bleu_4', 'METEOR', 'ROUGE_L', 468 | 'CIDEr', 'ROUGE_L_F', 'ROUGE_L_G', 'bad_parse', 'P', 'R', 469 | 'ROUGE_L_F_perfect_align', 'BLEURT', 'BLEURT_P', 'BLEURT_R', 'BLEURT_F1', 'BLEURT_perfect_align', 470 | 'BLEURT_acc', 'BLEURT_acc_perfect_align', 'acc_perfect_align', 'fraction_perfect_align', 471 | 'edited_or_not_acc', 472 | 'num_fever_queries', 'num_fever_queries_no_answer', 'num_fever_queries_no_rationale', 473 | 'num_premises', 'num_premises_valid', 'num_premises_valid_support', 'percent_valid_premises'] 474 | aggregated_metrics = ['missing'] 475 | for pred in predictions: 476 | angle = pred['angle_str'] 477 | metrics = pred.get('metrics',{}) 478 | if angle not in res_by_angle: 479 | res_by_angle[angle] = [] 480 | res_by_angle[angle].append(pred) 481 | if angle not in metrics_by_angle: 482 | metrics_by_angle[angle] = {} 483 | metrics_by_angle[angle]['counter'] = metrics_by_angle[angle].get('counter', 0) + 1 484 | for slot, slot_metrics in metrics.items(): 485 | if slot == "extra_slots": 486 | continue 487 | if slot not in metrics_by_angle[angle]: 488 | metrics_by_angle[angle][slot] = {} 489 | for metric in averaged_metrics + aggregated_metrics: 490 | if metric in slot_metrics: 491 | metrics_by_angle[angle][slot][metric] = \ 492 | metrics_by_angle[angle][slot].get(metric, 0) + slot_metrics[metric] 493 | for angle, metrics in metrics_by_angle.items(): 494 | counter = metrics['counter'] 495 | for slot, slot_metrics in metrics.items(): 496 | if not isinstance(slot_metrics, dict): 497 | continue 498 | for slot_metric, value in slot_metrics.items(): 499 | if slot_metric in averaged_metrics: 500 | slot_metrics[slot_metric] = value/counter 501 | return {"metrics_aggregated": metrics_by_angle, "by_angle": res_by_angle} 502 | 503 | 504 | def score_aligned_entail_tree_proof(prediction, gold_list, angle, gold_json_record:dict, bleurt_scorer=None): 505 | res = {} 506 | if gold_list is None: 507 | res[angle] = {"missing_gold": 1} 508 | if prediction is None: 509 | res[angle]['missing'] = 1 510 | return res 511 | if prediction is None: 512 | res[angle] = {"missing": 1} 513 | return res 514 | print(f"\n\n\n======================\n") 515 | print(f"pred:{prediction}") 516 | print(f"gold:{gold_list[0]}") 517 | print(f"\n\n\n======================\n") 518 | print(f"Reading predicted proof") 519 | sentences_pred, inferences_pred, int_to_all_ancestors_pred, relevant_sentences_pred, id_to_int_pred = \ 520 | parse_entailment_step_proof(prediction, gold_json_record=gold_json_record) 521 | 522 | print(f"\n\n\n||||||||||||||||||||||\n") 523 | print(f"Reading gold proof") 524 | sentences_gold, inferences_gold, int_to_all_ancestors_gold, relevant_sentences_gold, id_to_int_gold = \ 525 | parse_entailment_step_proof(gold_list[0], gold_json_record=gold_json_record) 526 | 527 | pred_int_to_gold_int_mapping, 
prediction_to_aligned_gold, prediction_to_perfect_match = \ 528 | align_conclusions_across_proofs(int_to_all_ancestors_pred, int_to_all_ancestors_gold, 529 | id_to_int_pred, id_to_int_gold) 530 | 531 | # res[angle+'-steps-unaligned'] = score_sentence_overlaps(sentences=sentences_pred, sentences_gold=sentences_gold) 532 | 533 | sentences_pred_aligned = rewrite_aligned_proof(prediction, pred_int_to_gold_int_mapping) 534 | print(f"\n\n\n++++++++++++++++++++++++++++++++++++") 535 | print(f"pred_int_to_gold_int_mapping:{pred_int_to_gold_int_mapping}") 536 | print(f"relevant_sentences_pred:{relevant_sentences_pred}") 537 | print(f"relevant_sentences_gold:{relevant_sentences_gold}") 538 | res[angle+'-leaves'] = score_sentence_overlaps(sentences=sorted(list(relevant_sentences_pred)), 539 | sentences_gold=sorted(list(relevant_sentences_gold))) 540 | 541 | res[angle + '-steps'] = score_sentence_overlaps(sentences=sorted(list(sentences_pred_aligned)), 542 | sentences_gold=sorted(list(sentences_gold))) 543 | 544 | res[angle + '-steps']['pred_to_gold_mapping'] = pred_int_to_gold_int_mapping 545 | res[angle + '-steps']['sentences_pred_aligned'] = sentences_pred_aligned 546 | 547 | res[angle+'-intermediates'] = nlg_string_similarities_intermediates_with_F1(prediction_to_aligned_gold=prediction_to_aligned_gold, 548 | id_to_int_gold=id_to_int_gold, 549 | id_to_int_pred=id_to_int_pred, 550 | prediction_to_perfect_match=prediction_to_perfect_match, 551 | bleurt_scorer=bleurt_scorer) 552 | res[angle+'-overall'] = overall_proof_score(leaves=res[angle+'-leaves'], 553 | edges=res[angle+'-steps'], 554 | intermediates=res[angle+'-intermediates']) 555 | return res 556 | 557 | 558 | def score_aligned_entail_tree_proof_onlyIR(prediction, gold_list, angle, gold_json_record:dict, pred_json_record: dict, bleurt_scorer=None): 559 | res = {} 560 | if gold_list is None: 561 | res[angle] = {"missing_gold": 1} 562 | if prediction is None: 563 | res[angle]['missing'] = 1 564 | return res 565 | if prediction is None: 566 | res[angle] = {"missing": 1} 567 | return res 568 | print(f"\n\n++++++++++++++++++\nprediction:{prediction}") 569 | # print(f"pred_json_record:{pred_json_record}") 570 | sentences_pred, inferences_pred, int_to_all_ancestors_pred, relevant_sentences_pred, id_to_int_pred = \ 571 | parse_entailment_step_proof_remove_ids(prediction, slot_json_record=pred_json_record) 572 | 573 | print(f"gold_json_record:{gold_json_record}") 574 | # print(f"gold_json_record:{gold_json_record}") 575 | sentences_gold, inferences_gold, int_to_all_ancestors_gold, relevant_sentences_gold, id_to_int_gold = \ 576 | parse_entailment_step_proof_remove_ids(gold_list[0], slot_json_record=gold_json_record) 577 | 578 | print(f"^^^^^^^pred:{prediction}") 579 | print(f"========sentences_pred:{sentences_pred}") 580 | print(f"^^^^^^^gold:{gold_list[0]}") 581 | print(f"========sentences_gold:{sentences_gold}") 582 | 583 | print(f"Q: {pred_json_record['id']}") 584 | pred_int_to_gold_int_mapping, prediction_to_aligned_gold, prediction_to_perfect_match = \ 585 | align_conclusions_across_proofs(int_to_all_ancestors_pred, int_to_all_ancestors_gold, 586 | id_to_int_pred, id_to_int_gold) 587 | 588 | # res[angle+'-steps-unaligned'] = score_sentence_overlaps(sentences=sentences_pred, sentences_gold=sentences_gold) 589 | 590 | print(f"\n\n+++++++++++++++++++++++++\n") 591 | print(f"prediction:{prediction}") 592 | print(f"pred_int_to_gold_int_mapping:{pred_int_to_gold_int_mapping}") 593 | pred_sentences = pred_json_record['meta']['triples'] 594 | 
pred_sentences['hypothesis'] = gold_json_record['hypothesis'] 595 | sentences_pred_aligned, sentences_pred_aligned_strings = rewrite_aligned_proof_noids(prediction, 596 | pred_int_to_gold_int_mapping, 597 | pred_sentences=pred_sentences, 598 | gold_ints=gold_json_record['meta']['intermediate_conclusions'] 599 | ) 600 | res[angle+'-leaves'] = score_sentence_overlaps(sentences=sorted(list(relevant_sentences_pred)), 601 | sentences_gold=sorted(list(relevant_sentences_gold))) 602 | 603 | print(f"*********ID:{gold_json_record['id']}") 604 | print(f"*********sentences_pred_aligned:{sentences_pred_aligned_strings}") 605 | print(f"*********sentences_gold:{sentences_gold}") 606 | res[angle + '-steps'] = score_sentence_overlaps(sentences=sorted(list(sentences_pred_aligned_strings)), 607 | sentences_gold=sorted(list(sentences_gold))) 608 | res[angle + '-steps']['pred_to_gold_mapping'] = pred_int_to_gold_int_mapping 609 | res[angle + '-steps']['sentences_pred_aligned'] = sentences_pred_aligned 610 | 611 | res[angle+'-intermediates'] = nlg_string_similarities_intermediates_with_F1(prediction_to_aligned_gold=prediction_to_aligned_gold, 612 | id_to_int_gold=id_to_int_gold, 613 | id_to_int_pred=id_to_int_pred, 614 | prediction_to_perfect_match=prediction_to_perfect_match, 615 | bleurt_scorer=bleurt_scorer) 616 | res[angle+'-overall'] = overall_proof_score(leaves=res[angle+'-leaves'], 617 | edges=res[angle+'-steps'], 618 | intermediates=res[angle+'-intermediates']) 619 | return res 620 | 621 | 622 | def overall_proof_score(leaves, edges, intermediates): 623 | res = {} 624 | leaves_acc = leaves['acc'] 625 | edges_acc = edges['acc'] 626 | accuracy = leaves_acc * edges_acc * intermediates['BLEURT_acc'] 627 | # accuracy = leaves_acc * edges_acc 628 | accuracy_align = leaves_acc * edges_acc * intermediates['BLEURT_acc_perfect_align'] 629 | 630 | res['acc'] = accuracy 631 | res['acc_perfect_align'] = accuracy_align 632 | return res 633 | 634 | 635 | def score_prediction_whole_proof(prediction, gold, prediction_json=None, dataset=None, scoring_spec=None, bleurt_scorer=None): 636 | angle = prediction.get('angle') 637 | if 'slots' in prediction: 638 | slots = prediction['slots'] 639 | else: 640 | slots = decompose_slots(prediction['prediction']) 641 | answer_eval = "emf1" 642 | if scoring_spec is not None and "answer_eval" in scoring_spec: 643 | answer_eval = scoring_spec['answer_eval'] 644 | elif "mcoptions" in gold: 645 | answer_eval = "mcq" 646 | elif dataset is not None and "narrative" in dataset: 647 | answer_eval = "nlg" 648 | 649 | hypothesis_eval = "nlg" 650 | if scoring_spec is not None and "hypothesis_eval" in scoring_spec: 651 | hypothesis_eval = scoring_spec['hypothesis_eval'] 652 | 653 | proof_eval = "pn_eval" 654 | if scoring_spec is not None and "proof_eval" in scoring_spec: 655 | proof_eval = scoring_spec['proof_eval'] 656 | 657 | res = {} 658 | angles_out = angle[1] if angle is not None else list(slots.keys()) 659 | for angle in angles_out: 660 | gold_str = gold.get(angle) 661 | gold_list = [gold_str] if isinstance(gold_str, str) else gold_str 662 | slot = slots.get(angle) 663 | if angle == 'hypothesis': 664 | if hypothesis_eval == 'old_emf1': 665 | res[angle] = rc_answer_accuracy(slots, gold) 666 | elif hypothesis_eval == 'emf1': 667 | sentences_pred = slots.get('hypothesis').split(' , ') 668 | sentences_gold = gold['hypothesis'].split(' , ') 669 | res[angle] = score_sentence_overlaps(sentences=sentences_pred, 670 | sentences_gold=sentences_gold) 671 | elif hypothesis_eval == 'mcq': 672 | 
res[angle] = mcq_answer_accuracy(slots, gold) 673 | elif hypothesis_eval == 'nlg': 674 | res[angle] = nlg_string_similarities(slot, gold_list) 675 | else: 676 | raise ValueError(f"Unknown hypothesis_eval setting: {hypothesis_eval}") 677 | elif angle in ['question', 'explanation']: 678 | res[angle] = nlg_string_similarities(slot, gold_list) 679 | elif angle in ['mcoptions']: 680 | res[angle] = rough_mcoptions_f1(slot, gold_str) 681 | elif angle in ['proof']: 682 | if proof_eval == "entail_whole_proof_align_eval": 683 | # This writes in multiple metrics 684 | res.update(score_aligned_entail_tree_proof(slot, gold_list, angle, gold_json_record=gold, bleurt_scorer=bleurt_scorer)) 685 | elif proof_eval == "entail_whole_proof_align_eval_onlyIR": 686 | # This writes in multiple metrics 687 | res.update(score_aligned_entail_tree_proof_onlyIR(slot, gold_list, angle, 688 | gold_json_record=gold, 689 | pred_json_record=prediction_json, 690 | bleurt_scorer=bleurt_scorer)) 691 | 692 | else: 693 | gold_proofs_list = [] 694 | for gold_proof in gold_str.split(' OR '): 695 | gold_parts = gold_proof.replace(', ', ',').split(',') 696 | if len(gold_parts) == 2: 697 | gold_proofs_list.append(gold_proof) 698 | gold_proofs_list.append(f"{gold_parts[1]}, {gold_parts[0]}") 699 | res[angle] = squad_em_f1(slot, gold_proofs_list) 700 | elif angle in ['rationale']: 701 | pass # Not implemented yet 702 | else: 703 | res[angle] = {} 704 | extras = [] 705 | for slot in slots: 706 | if slot not in res: 707 | extras.append(slot) 708 | if len(extras) > 0: 709 | res['extra_slots'] = extras 710 | return res -------------------------------------------------------------------------------- /src/entailment_baseline.py: -------------------------------------------------------------------------------- 1 | # Imports 2 | import os 3 | import sys 4 | import json 5 | import re 6 | import base_utils 7 | import random 8 | import numpy as np 9 | import pandas as pd 10 | from types import SimpleNamespace 11 | from tqdm.notebook import tqdm 12 | 13 | # Importing DL libraries 14 | import torch 15 | import torch.nn.functional as F 16 | from torch import cuda 17 | from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler 18 | from transformers import AutoTokenizer, T5Tokenizer, T5ForConditionalGeneration 19 | from sentence_transformers import SentenceTransformer 20 | from sentence_transformers import util as st_util 21 | from entailment_bank.eval.run_scorer import main 22 | from retrieval_utils import sent_text_as_counter, convert_datapoint_to_sent_to_text 23 | 24 | class CustomDataset(Dataset): 25 | 26 | def __init__(self, source_text, target_text, tokenizer, source_len, summ_len): 27 | self.tokenizer = tokenizer 28 | self.source_len = source_len 29 | self.summ_len = summ_len 30 | self.source_text = source_text 31 | self.target_text = target_text 32 | 33 | def __len__(self): 34 | return len(self.source_text) 35 | 36 | def __getitem__(self, index): 37 | source_text = str(self.source_text[index]) 38 | source_text = ' '.join(source_text.split()) 39 | 40 | target_text = str(self.target_text[index]) 41 | target_text = ' '.join(target_text.split()) 42 | 43 | source = self.tokenizer.batch_encode_plus([source_text], max_length= self.source_len, 44 | padding='max_length',return_tensors='pt', truncation=True) 45 | target = self.tokenizer.batch_encode_plus([target_text], max_length= self.summ_len, 46 | padding='max_length',return_tensors='pt', truncation=True) 47 | 48 | source_ids = source['input_ids'].squeeze() 49 | source_mask = 
source['attention_mask'].squeeze() 50 | target_ids = target['input_ids'].squeeze() 51 | target_mask = target['attention_mask'].squeeze() 52 | 53 | return { 54 | 'source_ids': source_ids.to(dtype=torch.long), 55 | 'source_mask': source_mask.to(dtype=torch.long), 56 | 'target_ids': target_ids.to(dtype=torch.long), 57 | 'target_ids_y': target_ids.to(dtype=torch.long) 58 | } 59 | 60 | class Trainer(): 61 | ''' 62 | Based on: 63 | https://github.com/abhimishra91/transformers-tutorials/blob/master/transformers_summarization_wandb.ipynb 64 | ''' 65 | def __init__(self, tokenizer=None, model=None, optimizer=None, params = None, config = None): 66 | self.params = params 67 | self.config = config 68 | 69 | # tokenzier for encoding the text 70 | self.tokenizer = tokenizer 71 | if tokenizer is None: 72 | self.tokenizer = T5Tokenizer.from_pretrained(self.params.model_name) 73 | 74 | # Defining the model. 75 | self.model = model 76 | if model is None: 77 | self.model = T5ForConditionalGeneration.from_pretrained( 78 | self.params.model_name, 79 | ) 80 | 81 | # Defining the optimizer that will be used to tune the weights of 82 | # the network in the training session. 83 | self.optimizer = optimizer 84 | if optimizer is None: 85 | self.optimizer = torch.optim.Adam(params = self.model.parameters(), 86 | lr=self.config.LEARNING_RATE) 87 | 88 | self.set_random_seed() 89 | 90 | def train(self, epoch, train_loader, val_loader, model = None, 91 | prefix_constrained_generator = None): 92 | if model is None: 93 | model = self.model 94 | model.train() 95 | running_loss = 0.0 96 | tqdm_loader = tqdm(train_loader) 97 | for step_num, data in enumerate(tqdm_loader, 0): 98 | if prefix_constrained_generator is not None: 99 | prefix_constrained_generator.set_batch_number(step_num) 100 | y = data['target_ids'].to(self.params.device, dtype = torch.long) 101 | y_ids = y[:, :-1].contiguous() 102 | labels = y[:, 1:].clone().detach() 103 | labels[y[:, 1:] == self.tokenizer.pad_token_id] = -100 104 | ids = data['source_ids'].to(self.params.device, dtype = torch.long) 105 | mask = data['source_mask'].to(self.params.device, dtype = torch.long) 106 | 107 | outputs = model(input_ids = ids, attention_mask = mask, 108 | decoder_input_ids=y_ids, labels=labels) 109 | loss = outputs[0] 110 | 111 | running_loss += loss.item() 112 | if step_num % 100==0: 113 | avg_loss = running_loss / ((step_num + 1) * self.config.TRAIN_BATCH_SIZE) 114 | tqdm_loader.set_description("Loss %.4f" % avg_loss) 115 | 116 | self.optimizer.zero_grad() 117 | loss.backward() 118 | self.optimizer.step() 119 | train_avg_loss = running_loss / ((step_num + 1) * self.config.TRAIN_BATCH_SIZE) 120 | val_avg_loss = self.validate(val_loader, verbose = False) 121 | print('Epoch: %d, Train Loss: %.4f, Eval Loss %.4f' % (epoch, train_avg_loss, val_avg_loss)) 122 | return train_avg_loss, val_avg_loss 123 | 124 | def validate(self, val_loader, verbose = True, model = None, 125 | prefix_constrained_generator = None): 126 | if model is None: 127 | model = self.model 128 | model.eval() 129 | running_loss = 0.0 130 | tqdm_loader = tqdm(val_loader) if verbose else val_loader 131 | for step_num, data in enumerate(tqdm_loader, 0): 132 | if prefix_constrained_generator is not None: 133 | prefix_constrained_generator.set_batch_number(step_num) 134 | y = data['target_ids'].to(self.params.device, dtype = torch.long) 135 | y_ids = y[:, :-1].contiguous() 136 | labels = y[:, 1:].clone().detach() 137 | labels[y[:, 1:] == self.tokenizer.pad_token_id] = -100 138 | ids = 
data['source_ids'].to(self.params.device, dtype = torch.long) 139 | mask = data['source_mask'].to(self.params.device, dtype = torch.long) 140 | 141 | outputs = model(input_ids = ids, attention_mask = mask, 142 | decoder_input_ids=y_ids, labels=labels) 143 | loss = outputs[0] 144 | 145 | running_loss += loss.item() 146 | if verbose and step_num % 100==0: 147 | avg_loss = running_loss / ((step_num + 1) * self.config.TRAIN_BATCH_SIZE) 148 | tqdm_loader.set_description("Loss %.4f" % avg_loss) 149 | 150 | avg_loss = running_loss / ((step_num + 1) * self.config.TRAIN_BATCH_SIZE) 151 | if verbose: 152 | print('Loss: %.4f' % (avg_loss,)) 153 | return avg_loss 154 | 155 | 156 | def predict(self, loader, generation_args = None, model = None, 157 | prefix_constrained_generator = None): 158 | if model is None: 159 | model = self.model 160 | model.eval() 161 | context = [] 162 | predictions = [] 163 | actuals = [] 164 | 165 | if generation_args is None: 166 | generation_args = { 167 | 'max_length': self.config.SUMMARY_LEN, 168 | 'num_beams': 3, 169 | # repetition_penalty': 2.5, 170 | 'length_penalty': 1.0, 171 | 'early_stopping': True 172 | } 173 | with torch.no_grad(): 174 | for step_num, data in enumerate(tqdm(loader), 0): 175 | if prefix_constrained_generator is not None: 176 | prefix_constrained_generator.set_batch_number(step_num) 177 | y = data['target_ids'].to(self.params.device, dtype = torch.long) 178 | ids = data['source_ids'].to(self.params.device, dtype = torch.long) 179 | mask = data['source_mask'].to(self.params.device, dtype = torch.long) 180 | generation_args.update({ 181 | 'input_ids': ids, 182 | 'attention_mask': mask, 183 | }) 184 | generated_ids = model.generate(**generation_args) 185 | inputs = [self.tokenizer.decode( 186 | i, skip_special_tokens=True, 187 | clean_up_tokenization_spaces=True) for i in ids] 188 | preds = [self.tokenizer.decode( 189 | g, skip_special_tokens=True, 190 | clean_up_tokenization_spaces=True) for g in generated_ids] 191 | target = [self.tokenizer.decode( 192 | t, skip_special_tokens=True, 193 | clean_up_tokenization_spaces=True) for t in y] 194 | context.extend(inputs) 195 | predictions.extend(preds) 196 | actuals.extend(target) 197 | return predictions, actuals, context 198 | 199 | def set_random_seed(self): 200 | # Set random seeds and deterministic pytorch for reproducibility 201 | torch.manual_seed(self.config.SEED) # pytorch random seed 202 | np.random.seed(self.config.SEED) # numpy random seed 203 | torch.backends.cudnn.deterministic = True 204 | 205 | def save_model(self, file_path_suffix = ''): 206 | model_path = self.params.model_file_path.format( 207 | model_name = self.params.model_name, 208 | task_name = self.params.task_name, 209 | dataset_name = self.params.dataset_name, 210 | approach_name = self.params.approach_name, 211 | suffix = file_path_suffix) 212 | torch.save(self.model.state_dict(), model_path) 213 | print('state dict saved to: %s' % model_path) 214 | 215 | def load_model(self, file_path_suffix = ''): 216 | model_path = self.params.model_file_path.format( 217 | model_name = self.params.model_name, 218 | task_name = self.params.task_name, 219 | dataset_name = self.params.dataset_name, 220 | approach_name = self.params.approach_name, 221 | suffix = file_path_suffix) 222 | print('Loading state dict from: %s' % model_path) 223 | self.model.load_state_dict(torch.load(model_path)) 224 | 225 | class EntailmentARCDataset(): 226 | 227 | ROOT_PATH = "../data/arc_entail" 228 | DATASET_PATH = os.path.join(ROOT_PATH, "dataset") 229 | TASK_PATH = 
os.path.join(DATASET_PATH, "task_{task_num}") 230 | PARTITION_DATA_PATH = os.path.join(TASK_PATH, "{partition}.jsonl") 231 | 232 | def __init__(self, semantic_search = None, params = None, config = None): 233 | self.params = params 234 | self.config = config 235 | self.data = {self.get_task_name(task_num): 236 | {partition: [] for partition in ['train', 'dev', 'test']} 237 | for task_num in range(1, 4)} 238 | self.load_dataset() 239 | self.semantic_search = semantic_search 240 | 241 | def get_task_name(self, task_num): 242 | return "task_" + str(task_num) 243 | 244 | def get_task_number(self, task_name): 245 | return int(task_name[-1]) 246 | 247 | def get_dataset_path(self, task_num = 1, partition = 'train'): 248 | path = self.PARTITION_DATA_PATH.format(task_num = task_num, partition = partition) 249 | return path 250 | 251 | def load_dataset(self): 252 | for task_name in self.data: 253 | for partition in self.data[task_name]: 254 | path = self.get_dataset_path(self.get_task_number(task_name), partition) 255 | with open(path, 'r', encoding='utf8') as f: 256 | for line in f: 257 | datapoint = json.loads(line) 258 | self.data[task_name][partition].append(datapoint) 259 | 260 | def combine_existing_and_search_context(self, datapoint, retrieved): 261 | # break down existing context to get sent texts 262 | sent_to_text = convert_datapoint_to_sent_to_text(datapoint) 263 | # merge retrieved and existing sent texts 264 | original_ret_size = len(retrieved) 265 | new_sents = [s['text'] for s in sent_to_text.values()] 266 | new_sents_lowered = [s['text'].lower() for s in sent_to_text.values()] 267 | # add sents retrieved by search to new_sents 268 | for ret in retrieved: 269 | if ret.lower() in new_sents_lowered: 270 | continue 271 | new_sents.append(ret) 272 | if len(new_sents) >= original_ret_size: 273 | break 274 | assert len(new_sents) <= original_ret_size 275 | 276 | # shuffles order, saving original index 277 | # new_sent_order[i] == j means original j-th sentence 278 | # now in i-th position 279 | new_sent_order = list(range(len(new_sents))) 280 | new_context = [] 281 | random.shuffle(new_sent_order) 282 | for i in range(len(new_sents)): 283 | new_context.append(new_sents[new_sent_order[i]]) 284 | # create new contex 285 | new_context = ' '.join(['sent%d: %s' % (i+1, r) for i, r in enumerate(new_context)]) 286 | # create new proof, modify index according to new context 287 | old_to_new_sent_map = {} 288 | for i in range(len(new_sents)): 289 | old_to_new_sent_map['sent%d ' % (new_sent_order[i] + 1,)] = 'sent%d ' % (i+1,) 290 | old_proof = datapoint['proof'] 291 | new_proof = base_utils.str_replace_single_pass(old_proof, old_to_new_sent_map) 292 | return new_context, new_proof 293 | 294 | def update_dataset_with_search(self, dataset, include_existing_context = False): 295 | # use retrieved context instead of goden context 296 | new_dataset = [dict(dp) for dp in dataset] 297 | retrieved_lst = self.semantic_search.search( 298 | dataset, top_k = self.params.max_retrieved_sentences) 299 | for retrived_it, retrieved in enumerate(retrieved_lst): 300 | if include_existing_context: 301 | new_context, new_proof = self.combine_existing_and_search_context( 302 | dataset[retrived_it], retrieved) 303 | if retrived_it < 20: 304 | print('OLD CONTEXT = ', dataset[retrived_it]['context']) 305 | print('NEW CONTEXT = ', new_context) 306 | print('OLD PROOF = ', dataset[retrived_it]['proof']) 307 | print('NEW PROOF = ', new_proof) 308 | print() 309 | new_dataset[retrived_it]['context'] = new_context 310 | 
new_dataset[retrived_it]['proof'] = new_proof 311 | else: 312 | sents = ['sent%d: %s' % (i+1, r) for i, r in enumerate(retrieved)] 313 | new_dataset[retrived_it]['context'] = ' '.join(sents) 314 | # makes sure proof is empty since original proof is unrelated to context 315 | new_dataset[retrived_it]['proof'] = '' 316 | return new_dataset 317 | 318 | def get_source_text(self, task_name, partition): 319 | source_text = [] 320 | if self.semantic_search is not None: 321 | new_contexts = self.get_contexts_from_search( 322 | self.data[task_name][partition]) 323 | for dp_it, data_point in enumerate(self.data[task_name][partition]): 324 | if self.semantic_search is not None: 325 | context = new_contexts[dp_it] 326 | else: 327 | context = data_point['context'] 328 | hypothesis = data_point['hypothesis'] 329 | source_text.append( 330 | 'hypothesis: %s, %s' % (hypothesis, context)) 331 | return source_text 332 | 333 | def get_target_text(self, task_name, partition): 334 | source_text = [] 335 | for data_point in self.data[task_name][partition]: 336 | source_text.append('$proof$ = %s' % (data_point['proof'],)) 337 | return source_text 338 | 339 | def get_torch_dataloaders(self, task_name, tokenizer): 340 | ''' 341 | Creation of Dataset and Dataloader for a certain entailment task. 342 | ''' 343 | # Creating the Training and Validation dataset for further creation of Dataloader 344 | 345 | train_source_text = self.get_source_text(task_name, 'train') 346 | train_target_text = self.get_target_text(task_name, 'train') 347 | training_set = CustomDataset(train_source_text, train_target_text, tokenizer, 348 | self.config.MAX_LEN, self.config.SUMMARY_LEN) 349 | 350 | dev_source_text = self.get_source_text(task_name, 'dev') 351 | dev_target_text = self.get_target_text(task_name, 'dev') 352 | val_set = CustomDataset(dev_source_text, dev_target_text, tokenizer, 353 | self.config.MAX_LEN, self.config.SUMMARY_LEN) 354 | 355 | # Defining the parameters for creation of dataloaders 356 | train_params = { 357 | 'batch_size': self.config.TRAIN_BATCH_SIZE, 358 | 'shuffle': True, 359 | 'num_workers': 0 360 | } 361 | 362 | val_params = { 363 | 'batch_size': self.config.VALID_BATCH_SIZE, 364 | 'shuffle': False, 365 | 'num_workers': 0 366 | } 367 | 368 | # Creation of Dataloaders for testing and validation. 369 | # This will be used down for training and validation stage for the model. 
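        # Illustrative sketch of how these dataloaders feed the Trainer (comments only;
        # `params` and `config` are assumed to be SimpleNamespace objects like the ones
        # built in the accompanying notebooks, and `tokenizer` a T5 tokenizer):
        #   dataset = EntailmentARCDataset(params=params, config=config)
        #   training_loader, val_loader = dataset.get_torch_dataloaders('task_1', tokenizer)
        #   trainer = Trainer(tokenizer=tokenizer, params=params, config=config)
        #   trainer.train(epoch=0, train_loader=training_loader, val_loader=val_loader)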
370 | training_loader = DataLoader(training_set, **train_params) 371 | val_loader = DataLoader(val_set, **val_params) 372 | return training_loader, val_loader 373 | 374 | class SemanticSearch(): 375 | 376 | def __init__(self, corpus = None, encoder_model = None, params = None, config = None): 377 | self.params = params 378 | self.config = config 379 | self.encoder_model = encoder_model 380 | if encoder_model is None: 381 | self.encoder_model = SentenceTransformer(self.params.sent_trans_name) 382 | if corpus is not None: 383 | self.update_corpus_embeddings(corpus) 384 | 385 | def load_wt_corpus_file(self): 386 | wt_corpus = {} 387 | with open(self.params.wt_corpus_file_path, 'r', encoding='utf8') as f: 388 | wt_corpus = json.loads(f.readline()) 389 | return wt_corpus 390 | 391 | def load_wt_corpus(self, extra_facts = None): 392 | wt_corpus = self.load_wt_corpus_file() 393 | corpus = list(wt_corpus.values()) 394 | if extra_facts is not None: 395 | corpus.extend(extra_facts) 396 | corpus = list(set(corpus)) 397 | self.update_corpus_embeddings(corpus) 398 | 399 | def update_corpus_embeddings(self, corpus): 400 | self.corpus = corpus 401 | #Encode all sentences in corpus 402 | self.corpus_embeddings = self.encoder_model.encode( 403 | corpus, convert_to_tensor=True, show_progress_bar = True) 404 | self.corpus_embeddings = self.corpus_embeddings.to(self.params.device) 405 | self.corpus_embeddings = st_util.normalize_embeddings(self.corpus_embeddings) 406 | 407 | def search_with_id_and_scores(self, queries, top_k = 1): 408 | ''' 409 | Search for best semantically similar sentences in corpus. 410 | 411 | returns corpus ids (index in input corpus) and scores 412 | ''' 413 | if type(queries) != list: 414 | queries = [queries] 415 | 416 | #Encode all queries 417 | query_embeddings = self.encoder_model.encode( 418 | queries, convert_to_tensor=True, show_progress_bar = False) 419 | query_embeddings = query_embeddings.to(self.params.device) 420 | query_embeddings = st_util.normalize_embeddings(query_embeddings) 421 | hits = st_util.semantic_search(query_embeddings, self.corpus_embeddings, 422 | top_k=top_k, score_function=st_util.dot_score) 423 | return hits 424 | 425 | def search(self, *args, **kwargs): 426 | ''' 427 | Search for best semantically similar sentences in corpus. 428 | 429 | Only returns elements from corpus (no score or id) 430 | ''' 431 | hits = self.search_with_id_and_scores(*args, **kwargs) 432 | elements = [[self.corpus[ret['corpus_id']] for ret in hit] for hit in hits] 433 | return elements 434 | 435 | def run_test(self): 436 | corpus = [ 437 | 'A man is eating food.', 'A man is eating a piece of bread.', 438 | 'The girl is carrying a baby.', 'A man is riding a horse.', 'A woman is playing violin.', 439 | 'Two men pushed carts through the woods.', 'A man is riding a white horse on an enclosed ground.', 440 | 'A monkey is playing drums.', 'Someone in a gorilla costume is playing a set of drums.' 441 | ] 442 | self.update_corpus_embeddings(corpus) 443 | queries = ['A woman enjoys her meal', 'A primate is performing at a concert'] 444 | results = self.search(queries, top_k = 2) 445 | for i in range(len(queries)): 446 | print('Query:', queries[i]) 447 | print('Best results:', results[i]) 448 | print() 449 | 450 | class PrefixConstrainedGenerator: 451 | ''' 452 | Constrains the beam search to allowed tokens only at each step. 
453 | Enforces entailment dataset expected format (important for evaluation code) 454 | ''' 455 | 456 | def __init__(self, tokenizer, source_text, batch_size): 457 | # tokenizer for encoding the text 458 | self.tokenizer = tokenizer 459 | self.source_text = source_text 460 | self.batch_size = batch_size 461 | self.batch_num = 0 462 | 463 | def get_first_token_id(self, text): 464 | toks = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(text)) 465 | return toks[0] 466 | 467 | def get_last_token_id(self, text): 468 | toks = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(text)) 469 | return toks[-1] 470 | 471 | def set_batch_number(self, batch_num): 472 | self.batch_num = batch_num 473 | 474 | def set_source_text(self, source_text): 475 | self.source_text = source_text 476 | 477 | def fixed_prefix_allowed_tokens_fn(self, batch_id, inputs_ids): 478 | ''' 479 | Constrain the next token for beam search depending on currently generated prefix (input_ids) 480 | The output is loosely formatted according to dataset specification. 481 | ''' 482 | # print(inputs_ids, batch_id) 483 | prefix = self.tokenizer.decode(inputs_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True) 484 | # print(prefix) 485 | if prefix.strip() == '': 486 | return [self.get_first_token_id('sent')] 487 | if prefix.endswith(' & ') or prefix.endswith(' ; '): 488 | return [self.get_first_token_id('sent'), self.get_first_token_id('int')] 489 | if prefix.endswith('sent') or prefix.endswith('int'): 490 | return [self.get_last_token_id('sent' + str(num)) for num in range(10)] 491 | if prefix.endswith(' -> '): 492 | return [self.get_first_token_id('hypothesis'), self.get_first_token_id('int')] 493 | return list(range(self.tokenizer.vocab_size)) 494 | 495 | def iterative_prefix_allowed_tokens_fn(self, batch_id, inputs_ids): 496 | ''' 497 | Constrain the next token for beam search depending on currently generated prefix (input_ids) 498 | The output is loosely formatted according to dataset specification. 
499 | ''' 500 | prefix = self.tokenizer.decode(inputs_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True) 501 | source_idx = self.batch_size * self.batch_num + batch_id 502 | source_text = self.source_text[source_idx] 503 | available_sent_nums = [source_text[match.span()[0] + len('sent'): match.span()[1]-1] 504 | for match in re.finditer("(sent)[0-9]+:", source_text)] 505 | avaliable_int_nums = [source_text[match.span()[0] + len('int'): match.span()[1]-1] 506 | for match in re.finditer("(int)[0-9]+:", source_text)] 507 | 508 | if prefix.strip() == 'in': 509 | return [self.get_last_token_id('int')] 510 | if prefix.strip() == '' or prefix.endswith(' & ') or prefix.endswith(' ; '): 511 | return [self.get_first_token_id('sent'), self.get_first_token_id('int')] 512 | if prefix.endswith('sent'): 513 | return list(set([self.get_last_token_id('sent' + num) for num in available_sent_nums])) 514 | # return list(set([self.get_last_token_id('sent' + str(num)) for num in range(10)])) 515 | if prefix.endswith('int') and not prefix.endswith('-> int'): 516 | return list(set([self.get_last_token_id('int' + num) for num in avaliable_int_nums])) 517 | # return list(set([self.get_last_token_id('int' + str(num)) for num in range(10)])) 518 | if prefix.endswith(' -> '): 519 | return [self.get_first_token_id('hypothesis'), self.get_first_token_id('int')] 520 | if not ' -> ' in prefix: 521 | all_toks = list(range(self.tokenizer.vocab_size)) 522 | all_toks.remove(self.get_last_token_id('int1:')) 523 | return all_toks 524 | return list(range(self.tokenizer.vocab_size)) 525 | 526 | # Training loop 527 | def run_training_loop(config): 528 | print('Initiating Fine-Tuning for the model on our dataset') 529 | 530 | min_val_avg_loss = 1e10 531 | 532 | for epoch in range(config.TRAIN_EPOCHS): 533 | _, val_avg_loss = trainer.train(epoch, training_loader, val_loader) 534 | if params.save_min_val_loss and val_avg_loss < min_val_avg_loss: 535 | min_val_avg_loss = val_avg_loss 536 | # Saving trained model with lowest validation loss 537 | trainer.save_model(file_path_suffix = '_min_val_loss') 538 | 539 | # Saving trained model 540 | trainer.save_model() 541 | 542 | print('min_val_avg_loss', min_val_avg_loss) 543 | -------------------------------------------------------------------------------- /src/entailment_retrieval.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "# Imports\n", 10 | "import os\n", 11 | "import re\n", 12 | "import sys\n", 13 | "import json\n", 14 | "import torch\n", 15 | "import base_utils\n", 16 | "import random\n", 17 | "import string\n", 18 | "import numpy as np\n", 19 | "import pandas as pd\n", 20 | "from types import SimpleNamespace\n", 21 | "from tqdm.notebook import tqdm\n", 22 | "from collections import Counter, defaultdict\n", 23 | "\n", 24 | "# Importing DL libraries\n", 25 | "import torch\n", 26 | "from torch import nn\n", 27 | "import torch.nn.functional as F\n", 28 | "from torch import cuda\n", 29 | "from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler\n", 30 | "from transformers import AutoTokenizer, T5Tokenizer, T5ForConditionalGeneration, Adafactor\n", 31 | "from sentence_transformers import SentenceTransformer\n", 32 | "from sentence_transformers import util as st_util\n", 33 | "from sentence_transformers import evaluation as st_evaluation\n", 34 | "from 
sentence_transformers import SentenceTransformer, InputExample, losses\n", 35 | "\n", 36 | "from retrieval_utils import convert_datapoint_to_sent_to_text, sent_text_as_counter" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": null, 42 | "metadata": {}, 43 | "outputs": [], 44 | "source": [ 45 | "params = SimpleNamespace(\n", 46 | " # options: \"arc_entail\" (entailment bank), \"eqasc\"\n", 47 | " dataset_name = 'arc_entail', \n", 48 | " # task_1, task_2, task_3\n", 49 | " task_name = 'task_3',\n", 50 | " # use test instead of dev data to evaluate model\n", 51 | " use_test_data = True, \n", 52 | " encoder_model_path = '../data/arc_entail/models/%s_fine_tuned_v6/',\n", 53 | " encoder_checkpoint_path = '../data/arc_entail/models/%s_checkpoint_v6',\n", 54 | " device = 'cuda' if cuda.is_available() else 'cpu',\n", 55 | " # full list of sentence transformers: https://www.sbert.net/docs/pretrained_models.html\n", 56 | " sent_trans_name = 'all-mpnet-base-v2',\n", 57 | " wt_corpus_file_path = '../data/arc_entail/supporting_data/worldtree_corpus_sentences_extended.json', \n", 58 | " max_retrieved_sentences = 25\n", 59 | ")\n", 60 | "\n", 61 | "config = SimpleNamespace(\n", 62 | " TRAIN_EPOCHS = 5, # number of epochs to train\n", 63 | " LEARNING_RATE = 4e-5, # learning rate\n", 64 | " SEED = 39, # random seed\n", 65 | ")" 66 | ] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "execution_count": null, 71 | "metadata": {}, 72 | "outputs": [], 73 | "source": [ 74 | "def set_random_seed():\n", 75 | " # Set random seeds and deterministic pytorch for reproducibility\n", 76 | " torch.manual_seed(config.SEED) # pytorch random seed\n", 77 | " np.random.seed(config.SEED) # numpy random seed\n", 78 | " torch.backends.cudnn.deterministic = True \n", 79 | "\n", 80 | "set_random_seed()" 81 | ] 82 | }, 83 | { 84 | "cell_type": "code", 85 | "execution_count": null, 86 | "metadata": {}, 87 | "outputs": [], 88 | "source": [ 89 | "class CustomDataset(Dataset):\n", 90 | "\n", 91 | " def __init__(self, source_text, target_text, tokenizer, source_len, summ_len):\n", 92 | " self.tokenizer = tokenizer\n", 93 | " self.source_len = source_len\n", 94 | " self.summ_len = summ_len\n", 95 | " self.source_text = source_text\n", 96 | " self.target_text = target_text\n", 97 | "\n", 98 | " def __len__(self):\n", 99 | " return len(self.source_text)\n", 100 | "\n", 101 | " def __getitem__(self, index):\n", 102 | " source_text = str(self.source_text[index])\n", 103 | " source_text = ' '.join(source_text.split())\n", 104 | "\n", 105 | " target_text = str(self.target_text[index])\n", 106 | " target_text = ' '.join(target_text.split())\n", 107 | "\n", 108 | " source = self.tokenizer.batch_encode_plus([source_text], max_length= self.source_len, \n", 109 | " padding='max_length',return_tensors='pt', truncation=True)\n", 110 | " target = self.tokenizer.batch_encode_plus([target_text], max_length= self.summ_len, \n", 111 | " padding='max_length',return_tensors='pt', truncation=True)\n", 112 | "\n", 113 | " source_ids = source['input_ids'].squeeze()\n", 114 | " source_mask = source['attention_mask'].squeeze()\n", 115 | " target_ids = target['input_ids'].squeeze()\n", 116 | " target_mask = target['attention_mask'].squeeze()\n", 117 | "\n", 118 | " return {\n", 119 | " 'source_ids': source_ids.to(dtype=torch.long),\n", 120 | " 'source_mask': source_mask.to(dtype=torch.long), \n", 121 | " 'target_ids': target_ids.to(dtype=torch.long),\n", 122 | " 'target_ids_y': target_ids.to(dtype=torch.long)\n", 123 | " }" 124 | ] 125 | }, 
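{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Illustrative sketch (left commented out; the toy texts and 't5-small' checkpoint below\n", "# are assumptions for demonstration, not part of this notebook's pipeline):\n", "# toy_tokenizer = T5Tokenizer.from_pretrained('t5-small')\n", "# toy_set = CustomDataset(['hypothesis: a, sent1: b'], ['$proof$ = sent1 -> hypothesis'],\n", "#                         toy_tokenizer, source_len=512, summ_len=128)\n", "# toy_loader = DataLoader(toy_set, batch_size=1, shuffle=False)\n", "# print(next(iter(toy_loader))['source_ids'].shape)  # torch.Size([1, 512])\n" ] },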
126 | { 127 | "cell_type": "code", 128 | "execution_count": null, 129 | "metadata": {}, 130 | "outputs": [], 131 | "source": [ 132 | "class SemanticSearch():\n", 133 | " \n", 134 | " def __init__(self, corpus = None, encoder_model = None, params = None):\n", 135 | " self.params = params\n", 136 | " self.encoder_model = encoder_model\n", 137 | " if encoder_model is None:\n", 138 | " self.encoder_model = SentenceTransformer(self.params.sent_trans_name)\n", 139 | " if corpus is not None:\n", 140 | " self.update_corpus_embeddings(corpus)\n", 141 | "\n", 142 | " def load_wt_corpus_file(self):\n", 143 | " wt_corpus = {}\n", 144 | " with open(self.params.wt_corpus_file_path, 'r', encoding='utf8') as f:\n", 145 | " wt_corpus = json.loads(f.readline())\n", 146 | " return wt_corpus\n", 147 | " \n", 148 | " def load_wt_corpus(self, extra_facts = None):\n", 149 | " wt_corpus = self.load_wt_corpus_file()\n", 150 | " corpus = list(wt_corpus.values())\n", 151 | " if extra_facts is not None:\n", 152 | " corpus.extend(extra_facts)\n", 153 | " corpus = list(set(corpus))\n", 154 | " self.update_corpus_embeddings(corpus)\n", 155 | " \n", 156 | " def update_corpus_embeddings(self, corpus):\n", 157 | " self.corpus = corpus\n", 158 | " # Encode all sentences in corpus\n", 159 | " self.corpus_embeddings = self.encoder_model.encode(\n", 160 | " corpus, convert_to_tensor=True, show_progress_bar = False)\n", 161 | " self.corpus_embeddings = self.corpus_embeddings.to(self.params.device)\n", 162 | " self.corpus_embeddings = st_util.normalize_embeddings(self.corpus_embeddings) \n", 163 | " \n", 164 | " def search_with_id_and_scores(self, queries, top_k = 1):\n", 165 | " '''\n", 166 | " Search for best semantically similar sentences in corpus.\n", 167 | " \n", 168 | " returns corpus ids (index in input corpus) and scoress\n", 169 | " '''\n", 170 | " if type(queries) != list:\n", 171 | " queries = [queries]\n", 172 | "\n", 173 | " #Encode all queries\n", 174 | " query_embeddings = self.encoder_model.encode(queries, convert_to_tensor=True)\n", 175 | " query_embeddings = query_embeddings.to(self.params.device)\n", 176 | " query_embeddings = st_util.normalize_embeddings(query_embeddings)\n", 177 | " hits = st_util.semantic_search(query_embeddings, self.corpus_embeddings, \n", 178 | " top_k=top_k, score_function=st_util.dot_score)\n", 179 | " return hits\n", 180 | " \n", 181 | " def search(self, *args, **kwargs):\n", 182 | " '''\n", 183 | " Search for best semantically similar sentences in corpus.\n", 184 | " \n", 185 | " Only returns elements from corpus (no score or id)\n", 186 | " '''\n", 187 | " hits = self.search_with_id_and_scores(*args, **kwargs)\n", 188 | " elements = [[self.corpus[ret['corpus_id']] for ret in hit] for hit in hits]\n", 189 | " return elements\n", 190 | " \n", 191 | " def run_test(self):\n", 192 | " corpus = [\n", 193 | " 'A man is eating food.', 'A man is eating a piece of bread.',\n", 194 | " 'The girl is carrying a baby.', 'A man is riding a horse.', 'A woman is playing violin.',\n", 195 | " 'Two men pushed carts through the woods.', 'A man is riding a white horse on an enclosed ground.',\n", 196 | " 'A monkey is playing drums.', 'Someone in a gorilla costume is playing a set of drums.',\n", 197 | " 'matter in the gas phase has variable shape',\n", 198 | " ]\n", 199 | " self.update_corpus_embeddings(corpus)\n", 200 | " queries = ['A woman enjoys her meal', 'A primate is performing at a concert', 'matter in gas phase has no definite volume and no definite shape']\n", 201 | " results = 
self.search(queries, top_k = 2)\n", 202 | " for i in range(len(queries)):\n", 203 | " print('Query:', queries[i])\n", 204 | " print('Best results:', results[i])\n", 205 | " print()\n" 206 | ] 207 | }, 208 | { 209 | "cell_type": "code", 210 | "execution_count": null, 211 | "metadata": {}, 212 | "outputs": [], 213 | "source": [ 214 | "class EntailmentARCDataset():\n", 215 | " \n", 216 | " ROOT_PATH = \"../data/arc_entail\"\n", 217 | " DATASET_PATH = os.path.join(ROOT_PATH, \"dataset\")\n", 218 | " TASK_PATH = os.path.join(DATASET_PATH, \"task_{task_num}\")\n", 219 | " PARTITION_DATA_PATH = os.path.join(TASK_PATH, \"{partition}.jsonl\")\n", 220 | " \n", 221 | " def __init__(self, semantic_search = None, params = None, config = None):\n", 222 | " self.params = params\n", 223 | " self.config = config\n", 224 | " self.data = {self.get_task_name(task_num): \n", 225 | " {partition: [] for partition in ['train', 'dev', 'test']} \n", 226 | " for task_num in range(1, 4)}\n", 227 | " self.load_dataset()\n", 228 | " self.semantic_search = semantic_search\n", 229 | " \n", 230 | " def get_task_name(self, task_num):\n", 231 | " return \"task_\" + str(task_num)\n", 232 | " \n", 233 | " def get_task_number(self, task_name):\n", 234 | " return int(task_name[-1])\n", 235 | " \n", 236 | " def get_dataset_path(self, task_num = 1, partition = 'train'):\n", 237 | " path = self.PARTITION_DATA_PATH.format(task_num = task_num, partition = partition)\n", 238 | " return path\n", 239 | " \n", 240 | " def load_dataset(self):\n", 241 | " for task_name in self.data:\n", 242 | " for partition in self.data[task_name]:\n", 243 | " path = self.get_dataset_path(self.get_task_number(task_name), partition)\n", 244 | " with open(path, 'r', encoding='utf8') as f:\n", 245 | " for line in f:\n", 246 | " datapoint = json.loads(line)\n", 247 | " self.data[task_name][partition].append(datapoint)" 248 | ] 249 | }, 250 | { 251 | "cell_type": "markdown", 252 | "metadata": {}, 253 | "source": [ 254 | "# Loading Data\n", 255 | "\n", 256 | "### **NOTE**: if not fine-tunning encoder, set ``load_model = True`` to load an existing trained encoder model" 257 | ] 258 | }, 259 | { 260 | "cell_type": "code", 261 | "execution_count": null, 262 | "metadata": {}, 263 | "outputs": [], 264 | "source": [ 265 | "load_model = False\n", 266 | "encoder_model = None\n", 267 | "\n", 268 | "if load_model:\n", 269 | " encoder_model_path = params.encoder_model_path % params.sent_trans_name\n", 270 | " print(f'loading model from: {encoder_model_path}')\n", 271 | " encoder_model = SentenceTransformer(encoder_model_path)\n", 272 | "\n", 273 | "semantic_search = SemanticSearch(encoder_model = encoder_model, params = params)" 274 | ] 275 | }, 276 | { 277 | "cell_type": "code", 278 | "execution_count": null, 279 | "metadata": {}, 280 | "outputs": [], 281 | "source": [ 282 | "entail_dataset = EntailmentARCDataset(\n", 283 | " semantic_search = semantic_search, params = params, config = config)\n", 284 | "print(entail_dataset.data[params.task_name]['train'][0])" 285 | ] 286 | }, 287 | { 288 | "cell_type": "markdown", 289 | "metadata": {}, 290 | "source": [ 291 | "# Retrieval Utils" 292 | ] 293 | }, 294 | { 295 | "cell_type": "code", 296 | "execution_count": null, 297 | "metadata": {}, 298 | "outputs": [], 299 | "source": [ 300 | "def counter_jaccard_similarity(c1, c2):\n", 301 | " inter = c1 & c2\n", 302 | " union = c1 | c2\n", 303 | " return sum(inter.values()) / float(sum(union.values()))\n", 304 | "\n", 305 | "def construct_sent_context_mapping(matching_texts, 
sent_to_text, new_mapping_key, matching_uuids = None):\n", 306 | " alignment = []\n", 307 | " for match_it, matching_text in enumerate(matching_texts):\n", 308 | " text_counter = sent_text_as_counter(matching_text)\n", 309 | " matching_uuid = None\n", 310 | " if matching_uuids is not None:\n", 311 | " matching_uuid = matching_uuids[match_it]\n", 312 | " for sent_k, sent_v in sent_to_text.items():\n", 313 | " alignment.append((sent_k, sent_v['text'], matching_text, match_it, matching_uuid,\n", 314 | " counter_jaccard_similarity(sent_v['text_counter'], text_counter)))\n", 315 | " sorted_alignment = sorted(alignment, key= lambda x: x[-1], reverse=True)\n", 316 | " matches_it_used = []\n", 317 | " for align_item in sorted_alignment:\n", 318 | " sent_key = align_item[0]\n", 319 | " matching_text = align_item[2]\n", 320 | " match_it = align_item[3]\n", 321 | " matching_uuid = align_item[4]\n", 322 | " if new_mapping_key not in sent_to_text[sent_key].keys():\n", 323 | " if match_it not in matches_it_used: \n", 324 | " sent_to_text[sent_key][new_mapping_key] = matching_text\n", 325 | " if matching_uuid is not None:\n", 326 | " sent_to_text[sent_key][new_mapping_key + '_uuid'] = matching_uuid\n", 327 | " matches_it_used.append(match_it)\n", 328 | " \n", 329 | " assert all([new_mapping_key in v for v in sent_to_text.values()])\n", 330 | " return sent_to_text\n", 331 | "\n", 332 | "\n", 333 | "def construct_datapoint_context_mapping(datapoint):\n", 334 | " sent_to_text = convert_datapoint_to_sent_to_text(datapoint)\n", 335 | " triples = datapoint['meta']['triples'].values()\n", 336 | " assert len(sent_to_text.keys()) == len(triples) \n", 337 | " sent_to_text = construct_sent_context_mapping(\n", 338 | " triples, sent_to_text, new_mapping_key = 'triple_text')\n", 339 | " \n", 340 | " wt_p_items = [wt_p_item['original_text'] \n", 341 | " for wt_p_item in datapoint['meta']['worldtree_provenance'].values()]\n", 342 | " wt_p_uuids = [wt_p_item['uuid'] \n", 343 | " for wt_p_item in datapoint['meta']['worldtree_provenance'].values()]\n", 344 | " assert len(sent_to_text.keys()) == len(wt_p_items)\n", 345 | " sent_to_text = construct_sent_context_mapping(\n", 346 | " wt_p_items, sent_to_text, new_mapping_key = 'wt_p_text', matching_uuids = wt_p_uuids)\n", 347 | " \n", 348 | " for sent_to_text_v in sent_to_text.values():\n", 349 | " del sent_to_text_v['text_counter']\n", 350 | " return sent_to_text\n", 351 | "\n", 352 | "def create_context_mapping(dataset = entail_dataset.data['task_1']['test'], verbose = False):\n", 353 | " context_mapping = []\n", 354 | " wt_corpus = {}\n", 355 | " with open(params.wt_corpus_file_path, 'r', encoding='utf8') as f:\n", 356 | " wt_corpus = json.loads(f.readline())\n", 357 | " \n", 358 | " for datapoint in dataset:\n", 359 | " datapoint_context_mapping = construct_datapoint_context_mapping(datapoint)\n", 360 | " context_mapping.append(datapoint_context_mapping)\n", 361 | " \n", 362 | " for k, v in datapoint_context_mapping.items(): \n", 363 | " if 'wt_p_text_uuid' in v and v['wt_p_text_uuid'] in wt_corpus.keys():\n", 364 | " datapoint_context_mapping[k]['wt_corpus_text'] = wt_corpus[v['wt_p_text_uuid']]\n", 365 | " \n", 366 | " if verbose:\n", 367 | " for k, v in datapoint_context_mapping.items(): \n", 368 | " for item_k, item_v in v.items():\n", 369 | " print(item_k, '=', item_v)\n", 370 | " print()\n", 371 | " print('======')\n", 372 | " return context_mapping" 373 | ] 374 | }, 375 | { 376 | "cell_type": "code", 377 | "execution_count": null, 378 | "metadata": {}, 379 | 
"outputs": [], 380 | "source": [ 381 | "def fix_wt_corpus_with_task_1_data(split = 'test'):\n", 382 | " dataset = entail_dataset.data['task_1'][split] \n", 383 | " context_mapping = create_context_mapping(dataset)\n", 384 | " removal_sents = [v for cm in context_mapping for p in cm.values() \n", 385 | " for k,v in p.items() if k != 'text']\n", 386 | " include_sents = [v for cm in context_mapping for p in cm.values() \n", 387 | " for k,v in p.items() if k == 'text']\n", 388 | " print(removal_sents[:20])\n", 389 | " wt_corpus = semantic_search.load_wt_corpus_file()\n", 390 | " corpus = list(set(list(wt_corpus.values())) - set(removal_sents))\n", 391 | " corpus.extend(include_sents)\n", 392 | " semantic_search.update_corpus_embeddings(corpus)\n", 393 | " \n", 394 | " \n", 395 | "fix_wt_corpus_with_task_1_data()\n", 396 | "print('corpus size = ', len(semantic_search.corpus))" 397 | ] 398 | }, 399 | { 400 | "cell_type": "code", 401 | "execution_count": null, 402 | "metadata": {}, 403 | "outputs": [], 404 | "source": [ 405 | "def get_sents_height_on_tree(sent_text_lst, data_point):\n", 406 | " proof = data_point['proof']\n", 407 | " context = data_point['context']\n", 408 | " conclusion_to_antecedent_int_map = {}\n", 409 | " sents_to_int_map = {}\n", 410 | " int_to_height_map = {}\n", 411 | " \n", 412 | " steps = proof.split(';')[:-1]\n", 413 | " for step in steps:\n", 414 | " antecedent, conclusion = step.split(' -> ')\n", 415 | " if conclusion.strip() == 'hypothesis':\n", 416 | " conclusion_int = 'hypothesis'\n", 417 | " else:\n", 418 | " conclusion_int = re.findall(r'int[0-9]+', conclusion)[0]\n", 419 | " antecedent_ints = re.findall(r'int[0-9]+', antecedent)\n", 420 | " antecedent_sents = re.findall(r'sent[0-9]+', antecedent)\n", 421 | " for ant_sent in antecedent_sents:\n", 422 | " sents_to_int_map[ant_sent] = conclusion_int\n", 423 | " conclusion_to_antecedent_int_map[conclusion_int] = antecedent_ints\n", 424 | " \n", 425 | " cur_height = 0\n", 426 | " current_ints = ['hypothesis'] \n", 427 | " while len(current_ints) > 0:\n", 428 | " next_ints = []\n", 429 | " for cur_int in current_ints:\n", 430 | " int_to_height_map[cur_int] = cur_height\n", 431 | " if cur_int in conclusion_to_antecedent_int_map.keys():\n", 432 | " for next_int in conclusion_to_antecedent_int_map[cur_int]:\n", 433 | " next_ints.append(next_int)\n", 434 | " current_ints = next_ints\n", 435 | " cur_height += 1\n", 436 | " \n", 437 | " heights = []\n", 438 | " for sent_text in sent_text_lst:\n", 439 | " sent_match = re.findall(\n", 440 | " '(sent[0-9]+): (%s)' % re.escape(sent_text.strip()), context)\n", 441 | " if len(sent_match) == 0:\n", 442 | " print('MISSING!!!')\n", 443 | " print('sent_text =', sent_text)\n", 444 | " print('context =', context)\n", 445 | " continue\n", 446 | " sent_symb = sent_match[0][0]\n", 447 | " int_symb = sents_to_int_map[sent_symb]\n", 448 | " if not int_symb in int_to_height_map:\n", 449 | " # this might happen when proof has antecedent missing \"int\"\n", 450 | " continue\n", 451 | " heights.append({\n", 452 | " 'sent':sent_symb, 'text': sent_text,\n", 453 | " 'height': int_to_height_map[int_symb] + 1,\n", 454 | " })\n", 455 | " \n", 456 | " return heights, conclusion_to_antecedent_int_map, sents_to_int_map, int_to_height_map\n", 457 | "\n", 458 | "def compute_retrieval_metrics(retrieved_sentences_lst, split = 'test', verbose = False):\n", 459 | " dataset = entail_dataset.data['task_1'][split]\n", 460 | " context_mapping = create_context_mapping(dataset)\n", 461 | " \n", 462 | " assert 
len(retrieved_sentences_lst) == len(dataset)\n", 463 | " \n", 464 | " tot_sent = 0\n", 465 | " tot_sent_correct = 0\n", 466 | " tot_sent_missing = 0\n", 467 | " tot_no_missing = 0\n", 468 | " tot_sent_not_in_wt = 0\n", 469 | " tot_missing_sent_height = 0\n", 470 | " tot_correct_sent_height = 0\n", 471 | " \n", 472 | " correct_retrieved_lst = []\n", 473 | " errors_lst = [] # in retreived but not in gold\n", 474 | " missing_lst = [] # in gold but not in retrieved\n", 475 | " \n", 476 | " for ret_sentences, dp_context_mapping, datapoint in zip(retrieved_sentences_lst, context_mapping, dataset):\n", 477 | " correct_retrieved = []\n", 478 | " errors = []\n", 479 | " for ret_sentence in ret_sentences:\n", 480 | " is_correct = False\n", 481 | " for mapping_texts in dp_context_mapping.values():\n", 482 | " if ret_sentence in mapping_texts.values():\n", 483 | " is_correct = True\n", 484 | " if mapping_texts['text'] not in correct_retrieved:\n", 485 | " correct_retrieved.append(mapping_texts['text'])\n", 486 | " if len(mapping_texts['wt_p_text_uuid']) < 2:\n", 487 | " tot_sent_not_in_wt += 1\n", 488 | " break\n", 489 | " if not is_correct:\n", 490 | " errors.append(ret_sentence)\n", 491 | " all_sents = [v['text'] for v in dp_context_mapping.values()]\n", 492 | " missing = list(set(all_sents) - set(correct_retrieved))\n", 493 | " \n", 494 | " correct_retrieved_lst.append(correct_retrieved)\n", 495 | " errors_lst.append(errors)\n", 496 | " missing_lst.append(missing)\n", 497 | " \n", 498 | " tot_sent += len(dp_context_mapping.keys())\n", 499 | " tot_sent_correct += len(correct_retrieved)\n", 500 | " tot_sent_missing += len(missing)\n", 501 | " tot_no_missing += 0 if len(missing) > 0 else 1\n", 502 | " \n", 503 | " missing_heights, _, _, _ = get_sents_height_on_tree(missing, datapoint)\n", 504 | " tot_missing_sent_height += sum([mh['height'] for mh in missing_heights])\n", 505 | " correct_heights, _, _, _ = get_sents_height_on_tree(correct_retrieved, datapoint)\n", 506 | " tot_correct_sent_height += sum([ch['height'] for ch in correct_heights])\n", 507 | " \n", 508 | " # if verbose and len(missing) > 0:\n", 509 | " if verbose and len(missing) > 0:\n", 510 | " hypothesis = datapoint['hypothesis']\n", 511 | " question = datapoint['question']\n", 512 | " answer = datapoint['answer']\n", 513 | " \n", 514 | " print('hypothesis', hypothesis)\n", 515 | " print('Q + A', question + ' -> ' + answer)\n", 516 | " print('=====')\n", 517 | " print('retrieved:', correct_retrieved)\n", 518 | " print('missing:', missing)\n", 519 | " print()\n", 520 | " \n", 521 | " \n", 522 | " recall = tot_sent_correct / float(tot_sent)\n", 523 | " all_correct = tot_no_missing / float(len(dataset))\n", 524 | " avg_correct_sent_height = tot_correct_sent_height / (float(tot_sent_correct) + 1e-9)\n", 525 | " avg_missing_sent_height = tot_missing_sent_height / (float(tot_sent_missing) + 1e-9)\n", 526 | " print('recall:', recall)\n", 527 | " print('all correct:', all_correct)\n", 528 | " print('number of retrieved not in corpus:', tot_sent_not_in_wt)\n", 529 | " print('avg height of correct sentences:', avg_correct_sent_height)\n", 530 | " print('avg height of missing sentences:', avg_missing_sent_height)\n", 531 | " \n", 532 | " return recall, correct_retrieved_lst, errors_lst, missing_lst" 533 | ] 534 | }, 535 | { 536 | "cell_type": "code", 537 | "execution_count": null, 538 | "metadata": {}, 539 | "outputs": [], 540 | "source": [ 541 | "def test_task_3_paper_recall():\n", 542 | " split = 'test'\n", 543 | " t3_data = 
entail_dataset.data['task_3'][split]\n", 544 | " retrieved_sentences_lst = [t3['meta']['triples'].values() for t3 in t3_data]\n", 545 | " compute_retrieval_metrics(retrieved_sentences_lst, split = split, verbose=False)" 546 | ] 547 | }, 548 | { 549 | "cell_type": "code", 550 | "execution_count": null, 551 | "metadata": {}, 552 | "outputs": [], 553 | "source": [ 554 | "def test_sent_transformer_recall(split = 'test', verbose = True, top_k = 25):\n", 555 | " print('top_k', top_k)\n", 556 | " \n", 557 | " t1_data = entail_dataset.data['task_1'][split]\n", 558 | " ret_data = semantic_search.search([t1['hypothesis'] for t1 in t1_data], top_k = top_k)\n", 559 | " assert len(t1_data) == len(ret_data)\n", 560 | " \n", 561 | " return compute_retrieval_metrics(ret_data, split = split)" 562 | ] 563 | }, 564 | { 565 | "cell_type": "markdown", 566 | "metadata": {}, 567 | "source": [ 568 | "# Retrieval Training" 569 | ] 570 | }, 571 | { 572 | "cell_type": "code", 573 | "execution_count": null, 574 | "metadata": {}, 575 | "outputs": [], 576 | "source": [ 577 | "def get_dataset_examples(split = 'train', use_hard_negative = True, hn_top_k = 25):\n", 578 | " # Define training / dev examples for sentence transformer. \n", 579 | " examples = []\n", 580 | " if use_hard_negative:\n", 581 | " # retrieved by current model but that is not part of the dataset\n", 582 | " _, _, errors, _ = test_sent_transformer_recall(\n", 583 | " split = split, verbose = False, top_k = hn_top_k)\n", 584 | " \n", 585 | " for t1_it, t1 in enumerate(tqdm(entail_dataset.data['task_1'][split])):\n", 586 | " hyp = t1['hypothesis']\n", 587 | " sents = t1['meta']['triples'].values()\n", 588 | " hard_negs = []\n", 589 | " for sent in sents:\n", 590 | " examples.append(InputExample(texts=[hyp, sent], label=1.0))\n", 591 | " \n", 592 | " if use_hard_negative:\n", 593 | " for error in errors[t1_it]:\n", 594 | " if error not in sents:\n", 595 | " hard_negs.append(error)\n", 596 | " examples.append(InputExample(texts=[hyp, error], label=0.75))\n", 597 | " \n", 598 | " neg_sents = random.sample(semantic_search.corpus, len(sents) * 4)\n", 599 | " for neg_sent in neg_sents:\n", 600 | " if neg_sent not in sents and neg_sent not in hard_negs:\n", 601 | " examples.append(InputExample(texts=[hyp, neg_sent], label=0.0)) \n", 602 | " return examples\n", 603 | "\n", 604 | "def train_sent_trans(get_examples_fn = get_dataset_examples):\n", 605 | "\n", 606 | " #Define the model. 
Either from scratch of by loading a pre-trained model\n", 607 | " model = semantic_search.encoder_model\n", 608 | " \n", 609 | " train_examples = get_examples_fn(split = 'train')\n", 610 | " dev_examples = get_examples_fn(split = 'dev') \n", 611 | " \n", 612 | " print('train_examples sz =', len(train_examples))\n", 613 | " print('dev_examples sz =', len(dev_examples))\n", 614 | " \n", 615 | " for example in train_examples[:10]:\n", 616 | " print(example)\n", 617 | " print()\n", 618 | " \n", 619 | " #Define your train dataset, the dataloader and the train loss\n", 620 | " train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=32)\n", 621 | " dev_dataloader = DataLoader(dev_examples, shuffle=False, batch_size=32)\n", 622 | " train_loss = losses.CosineSimilarityLoss(model)\n", 623 | " \n", 624 | " eval_sent_1 = [e.texts[0] for e in dev_examples]\n", 625 | " eval_sent_2 = [e.texts[1] for e in dev_examples]\n", 626 | " eval_scores = [e.label for e in dev_examples]\n", 627 | " \n", 628 | " evaluator = st_evaluation.EmbeddingSimilarityEvaluator(eval_sent_1, eval_sent_2, eval_scores)\n", 629 | " \n", 630 | " print('evaluating pre-trained model') \n", 631 | " results = model.evaluate(evaluator)\n", 632 | " print('results = ', results)\n", 633 | " \n", 634 | " callback = lambda score, epoch, steps: print('callback =', score, epoch, steps)\n", 635 | " \n", 636 | " print('fine-tunning model')\n", 637 | " #Tune the model\n", 638 | "\n", 639 | " model_trained_path = params.encoder_model_path % params.sent_trans_name\n", 640 | " checkpoint_path = params.encoder_checkpoint_path % params.sent_trans_name\n", 641 | " \n", 642 | " model.fit(train_objectives=[(train_dataloader, train_loss)], \n", 643 | " epochs=config.TRAIN_EPOCHS, warmup_steps=1000, \n", 644 | " save_best_model=True, output_path=model_trained_path, \n", 645 | " optimizer_params = {'lr': config.LEARNING_RATE},\n", 646 | " evaluator=evaluator, evaluation_steps=500,\n", 647 | " checkpoint_path=checkpoint_path,\n", 648 | " checkpoint_save_total_limit = 2,\n", 649 | " #callback = callback\n", 650 | " )\n", 651 | " model_trained_path_final = '../data/arc_entail/models/%s_fine_tuned_all_steps_v6' % params.sent_trans_name\n", 652 | " model.save(model_trained_path_final)\n", 653 | " \n", 654 | " print('evaluating fine-tuned model')\n", 655 | " results = model.evaluate(evaluator)\n", 656 | " print('results = ', results) " 657 | ] 658 | }, 659 | { 660 | "cell_type": "code", 661 | "execution_count": null, 662 | "metadata": {}, 663 | "outputs": [], 664 | "source": [ 665 | "train_sent_trans(get_examples_fn = get_dataset_examples)" 666 | ] 667 | }, 668 | { 669 | "cell_type": "markdown", 670 | "metadata": {}, 671 | "source": [ 672 | "# Retrieval Evaluation" 673 | ] 674 | }, 675 | { 676 | "cell_type": "markdown", 677 | "metadata": {}, 678 | "source": [ 679 | "## EntailmentWriter evaluation" 680 | ] 681 | }, 682 | { 683 | "cell_type": "code", 684 | "execution_count": null, 685 | "metadata": {}, 686 | "outputs": [], 687 | "source": [ 688 | "test_task_3_paper_recall()" 689 | ] 690 | }, 691 | { 692 | "cell_type": "markdown", 693 | "metadata": {}, 694 | "source": [ 695 | "## Retrieval evaluation (single)" 696 | ] 697 | }, 698 | { 699 | "cell_type": "code", 700 | "execution_count": null, 701 | "metadata": {}, 702 | "outputs": [], 703 | "source": [ 704 | "_, _, _, _ = test_sent_transformer_recall(split = 'test', verbose = False)" 705 | ] 706 | }, 707 | { 708 | "cell_type": "markdown", 709 | "metadata": {}, 710 | "source": [ 711 | "## Multi-step 
retrieval evaluation (conditional)" 712 | ] 713 | }, 714 | { 715 | "cell_type": "code", 716 | "execution_count": null, 717 | "metadata": {}, 718 | "outputs": [], 719 | "source": [ 720 | "def test_sent_transformer_recall(split = 'test', verbose = True, top_k = 25):\n", 721 | " t1_data = entail_dataset.data['task_1'][split]\n", 722 | " \n", 723 | " ret_data = []\n", 724 | " for _ in t1_data:\n", 725 | " ret_data.append([])\n", 726 | "# probes = [t1['hypothesis'] for t1 in t1_data]\n", 727 | " probes = [t1['question'] + ' ' + t1['answer'] for t1 in t1_data]\n", 728 | " keep_top_from_hyp = 15\n", 729 | " \n", 730 | " for k_step in range(1, top_k - keep_top_from_hyp + 1):\n", 731 | " temp_ret_data = semantic_search.search(probes, top_k = k_step)\n", 732 | " for ret_it, rets in enumerate(temp_ret_data):\n", 733 | " for ret in rets:\n", 734 | " if ret not in ret_data[ret_it]:\n", 735 | " ret_data[ret_it].append(ret)\n", 736 | " probes[ret_it] += ' ' + ret\n", 737 | " break\n", 738 | " \n", 739 | " # now gather \"keep_top_from_hyp\" by using only hypothesis as probe\n", 740 | " probes = [t1['hypothesis'] for t1 in t1_data]\n", 741 | "# probes = [t1['question'] + ' ' + t1['answer'] for t1 in t1_data]\n", 742 | " temp_ret_data = semantic_search.search(probes, top_k = top_k * 3)\n", 743 | " for ret_it, rets in enumerate(temp_ret_data):\n", 744 | " for ret in rets:\n", 745 | " if ret not in ret_data[ret_it]:\n", 746 | " ret_data[ret_it].append(ret)\n", 747 | " if len(ret_data[ret_it]) == top_k:\n", 748 | " break\n", 749 | " \n", 750 | " assert all([len(x) == top_k for x in ret_data])\n", 751 | " \n", 752 | " compute_retrieval_metrics(ret_data, split = split, verbose = verbose)\n", 753 | "\n", 754 | "# _, _, _, _ = test_sent_transformer_recall(split = 'test', verbose = False)\n", 755 | "test_sent_transformer_recall(split = 'test', verbose = False, top_k = 25)" 756 | ] 757 | }, 758 | { 759 | "cell_type": "code", 760 | "execution_count": null, 761 | "metadata": {}, 762 | "outputs": [], 763 | "source": [] 764 | } 765 | ], 766 | "metadata": { 767 | "kernelspec": { 768 | "display_name": "Environment (conda_csr)", 769 | "language": "python", 770 | "name": "conda_csr" 771 | }, 772 | "language_info": { 773 | "codemirror_mode": { 774 | "name": "ipython", 775 | "version": 3 776 | }, 777 | "file_extension": ".py", 778 | "mimetype": "text/x-python", 779 | "name": "python", 780 | "nbconvert_exporter": "python", 781 | "pygments_lexer": "ipython3", 782 | "version": "3.8.10" 783 | } 784 | }, 785 | "nbformat": 4, 786 | "nbformat_minor": 4 787 | } 788 | -------------------------------------------------------------------------------- /src/retrieval_utils.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import copy 3 | import json 4 | import re 5 | import rouge 6 | import string 7 | from collections import Counter 8 | 9 | def sent_text_as_counter(text): 10 | text = text.lower().replace(' \'', '\'') 11 | text = re.sub(r"[.,;/\\?]", "", text) 12 | tokens = text.split() 13 | return Counter(tokens) 14 | 15 | def create_sentence_uuid_map_from_corpus(world_tree_file): 16 | uuid_to_sent = {} 17 | with open(world_tree_file, 'r') as file: 18 | line = file.readline() 19 | uuid_to_sent = json.loads(line) 20 | sent_txt_to_uuid = [] 21 | num_uuid = 1 22 | for uuid, wt_text in uuid_to_sent.items(): 23 | sent_txt_to_uuid.append({ 24 | 'text': wt_text, 25 | 'text_counter': sent_text_as_counter(wt_text), 26 | 'uuid': uuid, 27 | 'num_uuid': str(num_uuid) 28 | }) 29 | num_uuid 
+= 1 30 | return sent_txt_to_uuid 31 | 32 | def convert_datapoint_to_sent_to_text(datapoint): 33 | ''' 34 | creates a mapping from sentences in context to their text 35 | (e.g. {'sent1': {'text': 'leo is a kind of constellation', 'text_counter': [...]}}) 36 | ''' 37 | context = datapoint['context'] 38 | matches = list(re.finditer("(sent)[0-9]+:", context)) 39 | sent_to_text = {} 40 | for match_idx, match in enumerate(matches): 41 | sent_match = match.group() 42 | sent_symb = sent_match[:-1] # remove the ':' in "sentX:' 43 | sent_span = match.span() 44 | start_pos = sent_span[0] + len(sent_match) 45 | end_pos = None 46 | if match_idx + 1 < len(matches): 47 | end_pos = matches[match_idx + 1].span()[0] 48 | sent_text = context[start_pos: end_pos].strip() 49 | sent_to_text[sent_symb] = { 50 | 'text': sent_text, 51 | 'text_counter': sent_text_as_counter(sent_text), 52 | } 53 | return sent_to_text 54 | 55 | def search_for_sent_uuid(probe, sent_txt_to_uuid): 56 | ''' 57 | Returns the uuid from worldtree corpus that best match text represented by 58 | probe_counter input 59 | ''' 60 | probe_counter = probe['text_counter'] 61 | best_uuid = None 62 | best_match = None 63 | best_match_score = 0 64 | for wt_item in sent_txt_to_uuid: 65 | wt_counter = wt_item['text_counter']; wt_uuid = wt_item['num_uuid'] 66 | match_counter = wt_counter & probe_counter 67 | match_score = sum(match_counter.values()) 68 | if match_score > best_match_score: 69 | best_uuid = wt_uuid 70 | best_match = wt_item 71 | best_match_score = match_score 72 | ratio = float(best_match_score) / float(sum(probe_counter.values())) 73 | if ratio < 0.8: 74 | print('ratio',ratio) 75 | print('probe', probe) 76 | print('best_match_counter', best_match) 77 | print() 78 | 79 | return best_uuid 80 | 81 | def convert_sentences_num_to_uuid(expression, sent_to_text, sent_txt_to_uuid): 82 | ''' 83 | Inputs: 84 | - expression: text containing sentence symbol (e.g. 'sent1 & sen2 -> hypothesis') 85 | - sent_to_text: dictionary mapping sentence symbols to sentence text 86 | - sent_txt_to_uuid: dictionary mapping sentence text to worldtree uuid 87 | ''' 88 | new_expression = expression 89 | 90 | # print('sent_to_text', sent_to_text) 91 | matches = list(re.finditer("(sent)[0-9]+ ", expression)) 92 | for match_idx, match in enumerate(matches): 93 | sent_symb = match.group() 94 | if sent_symb[:-1] in sent_to_text: 95 | sent_text_item = sent_to_text[sent_symb[:-1]] 96 | sent_uuid = search_for_sent_uuid(sent_text_item, sent_txt_to_uuid) 97 | else: 98 | sent_uuid = 1 99 | new_expression = new_expression.replace(sent_symb, f"sent{sent_uuid} " ) 100 | return new_expression 101 | 102 | def convert_datapoint_sent_to_uuid(datapoint, world_tree_file='data/arc_entail/supporting_data/worldtree_corpus_sentences_extended.json'): 103 | ''' 104 | converts sentence symbols (e.g. 'sent1') in datapoint text to uuid (e.g. 'sent0239-6af2-d042-caf6') 105 | ''' 106 | sent_txt_to_uuid = create_sentence_uuid_map_from_corpus(world_tree_file) 107 | sent_to_text = convert_datapoint_to_sent_to_text(datapoint) 108 | new_datapoint = dict(datapoint) 109 | new_datapoint['proof'] = convert_sentences_num_to_uuid( 110 | new_datapoint['proof'], sent_to_text, sent_txt_to_uuid) 111 | return new_datapoint --------------------------------------------------------------------------------