├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── LICENSE
├── NOTICE
├── README.md
├── images
│   ├── system_overview.png
│   └── task_definition.png
├── irgr.yml
├── public_reposiroty_irgr_final_v1.zip
├── requirements.txt
├── setup.py
└── src
    ├── __init__.py
    ├── base_utils.py
    ├── entailment_bank
    │   ├── NOTICE
    │   ├── README.md
    │   ├── __init__.py
    │   ├── eval
    │   │   ├── __init__.py
    │   │   └── run_scorer.py
    │   └── utils
    │       ├── __init__.py
    │       ├── angle_utils.py
    │       ├── entail_trees_utils.py
    │       ├── eval_utils.py
    │       └── proof_utils.py
    ├── entailment_baseline.py
    ├── entailment_iterative.ipynb
    ├── entailment_retrieval.ipynb
    └── retrieval_utils.py
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | ## Code of Conduct
2 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct).
3 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact
4 | opensource-codeofconduct@amazon.com with any additional questions or comments.
5 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing Guidelines
2 |
3 | Thank you for your interest in contributing to our project. Whether it's a bug report, new feature, correction, or additional
4 | documentation, we greatly value feedback and contributions from our community.
5 |
6 | Please read through this document before submitting any issues or pull requests to ensure we have all the necessary
7 | information to effectively respond to your bug report or contribution.
8 |
9 |
10 | ## Reporting Bugs/Feature Requests
11 |
12 | We welcome you to use the GitHub issue tracker to report bugs or suggest features.
13 |
14 | When filing an issue, please check existing open, or recently closed, issues to make sure somebody else hasn't already
15 | reported the issue. Please try to include as much information as you can. Details like these are incredibly useful:
16 |
17 | * A reproducible test case or series of steps
18 | * The version of our code being used
19 | * Any modifications you've made relevant to the bug
20 | * Anything unusual about your environment or deployment
21 |
22 |
23 | ## Contributing via Pull Requests
24 | Contributions via pull requests are much appreciated. Before sending us a pull request, please ensure that:
25 |
26 | 1. You are working against the latest source on the *main* branch.
27 | 2. You check existing open, and recently merged, pull requests to make sure someone else hasn't addressed the problem already.
28 | 3. You open an issue to discuss any significant work - we would hate for your time to be wasted.
29 |
30 | To send us a pull request, please:
31 |
32 | 1. Fork the repository.
33 | 2. Modify the source; please focus on the specific change you are contributing. If you also reformat all the code, it will be hard for us to focus on your change.
34 | 3. Ensure local tests pass.
35 | 4. Commit to your fork using clear commit messages.
36 | 5. Send us a pull request, answering any default questions in the pull request interface.
37 | 6. Pay attention to any automated CI failures reported in the pull request, and stay involved in the conversation.
38 |
39 | GitHub provides additional documentation on [forking a repository](https://help.github.com/articles/fork-a-repo/) and
40 | [creating a pull request](https://help.github.com/articles/creating-a-pull-request/).
41 |
42 |
43 | ## Finding contributions to work on
44 | Looking at the existing issues is a great way to find something to contribute on. As our projects, by default, use the default GitHub issue labels (enhancement/bug/duplicate/help wanted/invalid/question/wontfix), looking at any 'help wanted' issues is a great place to start.
45 |
46 |
47 | ## Code of Conduct
48 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct).
49 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact
50 | opensource-codeofconduct@amazon.com with any additional questions or comments.
51 |
52 |
53 | ## Security issue notifications
54 | If you discover a potential security issue in this project we ask that you notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/). Please do **not** create a public github issue.
55 |
56 |
57 | ## Licensing
58 |
59 | See the [LICENSE](LICENSE) file for our project's licensing. We will ask you to confirm the licensing of your contribution.
60 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 |
2 | Apache License
3 | Version 2.0, January 2004
4 | http://www.apache.org/licenses/
5 |
6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 |
8 | 1. Definitions.
9 |
10 | "License" shall mean the terms and conditions for use, reproduction,
11 | and distribution as defined by Sections 1 through 9 of this document.
12 |
13 | "Licensor" shall mean the copyright owner or entity authorized by
14 | the copyright owner that is granting the License.
15 |
16 | "Legal Entity" shall mean the union of the acting entity and all
17 | other entities that control, are controlled by, or are under common
18 | control with that entity. For the purposes of this definition,
19 | "control" means (i) the power, direct or indirect, to cause the
20 | direction or management of such entity, whether by contract or
21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
22 | outstanding shares, or (iii) beneficial ownership of such entity.
23 |
24 | "You" (or "Your") shall mean an individual or Legal Entity
25 | exercising permissions granted by this License.
26 |
27 | "Source" form shall mean the preferred form for making modifications,
28 | including but not limited to software source code, documentation
29 | source, and configuration files.
30 |
31 | "Object" form shall mean any form resulting from mechanical
32 | transformation or translation of a Source form, including but
33 | not limited to compiled object code, generated documentation,
34 | and conversions to other media types.
35 |
36 | "Work" shall mean the work of authorship, whether in Source or
37 | Object form, made available under the License, as indicated by a
38 | copyright notice that is included in or attached to the work
39 | (an example is provided in the Appendix below).
40 |
41 | "Derivative Works" shall mean any work, whether in Source or Object
42 | form, that is based on (or derived from) the Work and for which the
43 | editorial revisions, annotations, elaborations, or other modifications
44 | represent, as a whole, an original work of authorship. For the purposes
45 | of this License, Derivative Works shall not include works that remain
46 | separable from, or merely link (or bind by name) to the interfaces of,
47 | the Work and Derivative Works thereof.
48 |
49 | "Contribution" shall mean any work of authorship, including
50 | the original version of the Work and any modifications or additions
51 | to that Work or Derivative Works thereof, that is intentionally
52 | submitted to Licensor for inclusion in the Work by the copyright owner
53 | or by an individual or Legal Entity authorized to submit on behalf of
54 | the copyright owner. For the purposes of this definition, "submitted"
55 | means any form of electronic, verbal, or written communication sent
56 | to the Licensor or its representatives, including but not limited to
57 | communication on electronic mailing lists, source code control systems,
58 | and issue tracking systems that are managed by, or on behalf of, the
59 | Licensor for the purpose of discussing and improving the Work, but
60 | excluding communication that is conspicuously marked or otherwise
61 | designated in writing by the copyright owner as "Not a Contribution."
62 |
63 | "Contributor" shall mean Licensor and any individual or Legal Entity
64 | on behalf of whom a Contribution has been received by Licensor and
65 | subsequently incorporated within the Work.
66 |
67 | 2. Grant of Copyright License. Subject to the terms and conditions of
68 | this License, each Contributor hereby grants to You a perpetual,
69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70 | copyright license to reproduce, prepare Derivative Works of,
71 | publicly display, publicly perform, sublicense, and distribute the
72 | Work and such Derivative Works in Source or Object form.
73 |
74 | 3. Grant of Patent License. Subject to the terms and conditions of
75 | this License, each Contributor hereby grants to You a perpetual,
76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77 | (except as stated in this section) patent license to make, have made,
78 | use, offer to sell, sell, import, and otherwise transfer the Work,
79 | where such license applies only to those patent claims licensable
80 | by such Contributor that are necessarily infringed by their
81 | Contribution(s) alone or by combination of their Contribution(s)
82 | with the Work to which such Contribution(s) was submitted. If You
83 | institute patent litigation against any entity (including a
84 | cross-claim or counterclaim in a lawsuit) alleging that the Work
85 | or a Contribution incorporated within the Work constitutes direct
86 | or contributory patent infringement, then any patent licenses
87 | granted to You under this License for that Work shall terminate
88 | as of the date such litigation is filed.
89 |
90 | 4. Redistribution. You may reproduce and distribute copies of the
91 | Work or Derivative Works thereof in any medium, with or without
92 | modifications, and in Source or Object form, provided that You
93 | meet the following conditions:
94 |
95 | (a) You must give any other recipients of the Work or
96 | Derivative Works a copy of this License; and
97 |
98 | (b) You must cause any modified files to carry prominent notices
99 | stating that You changed the files; and
100 |
101 | (c) You must retain, in the Source form of any Derivative Works
102 | that You distribute, all copyright, patent, trademark, and
103 | attribution notices from the Source form of the Work,
104 | excluding those notices that do not pertain to any part of
105 | the Derivative Works; and
106 |
107 | (d) If the Work includes a "NOTICE" text file as part of its
108 | distribution, then any Derivative Works that You distribute must
109 | include a readable copy of the attribution notices contained
110 | within such NOTICE file, excluding those notices that do not
111 | pertain to any part of the Derivative Works, in at least one
112 | of the following places: within a NOTICE text file distributed
113 | as part of the Derivative Works; within the Source form or
114 | documentation, if provided along with the Derivative Works; or,
115 | within a display generated by the Derivative Works, if and
116 | wherever such third-party notices normally appear. The contents
117 | of the NOTICE file are for informational purposes only and
118 | do not modify the License. You may add Your own attribution
119 | notices within Derivative Works that You distribute, alongside
120 | or as an addendum to the NOTICE text from the Work, provided
121 | that such additional attribution notices cannot be construed
122 | as modifying the License.
123 |
124 | You may add Your own copyright statement to Your modifications and
125 | may provide additional or different license terms and conditions
126 | for use, reproduction, or distribution of Your modifications, or
127 | for any such Derivative Works as a whole, provided Your use,
128 | reproduction, and distribution of the Work otherwise complies with
129 | the conditions stated in this License.
130 |
131 | 5. Submission of Contributions. Unless You explicitly state otherwise,
132 | any Contribution intentionally submitted for inclusion in the Work
133 | by You to the Licensor shall be under the terms and conditions of
134 | this License, without any additional terms or conditions.
135 | Notwithstanding the above, nothing herein shall supersede or modify
136 | the terms of any separate license agreement you may have executed
137 | with Licensor regarding such Contributions.
138 |
139 | 6. Trademarks. This License does not grant permission to use the trade
140 | names, trademarks, service marks, or product names of the Licensor,
141 | except as required for reasonable and customary use in describing the
142 | origin of the Work and reproducing the content of the NOTICE file.
143 |
144 | 7. Disclaimer of Warranty. Unless required by applicable law or
145 | agreed to in writing, Licensor provides the Work (and each
146 | Contributor provides its Contributions) on an "AS IS" BASIS,
147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148 | implied, including, without limitation, any warranties or conditions
149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150 | PARTICULAR PURPOSE. You are solely responsible for determining the
151 | appropriateness of using or redistributing the Work and assume any
152 | risks associated with Your exercise of permissions under this License.
153 |
154 | 8. Limitation of Liability. In no event and under no legal theory,
155 | whether in tort (including negligence), contract, or otherwise,
156 | unless required by applicable law (such as deliberate and grossly
157 | negligent acts) or agreed to in writing, shall any Contributor be
158 | liable to You for damages, including any direct, indirect, special,
159 | incidental, or consequential damages of any character arising as a
160 | result of this License or out of the use or inability to use the
161 | Work (including but not limited to damages for loss of goodwill,
162 | work stoppage, computer failure or malfunction, or any and all
163 | other commercial damages or losses), even if such Contributor
164 | has been advised of the possibility of such damages.
165 |
166 | 9. Accepting Warranty or Additional Liability. While redistributing
167 | the Work or Derivative Works thereof, You may choose to offer,
168 | and charge a fee for, acceptance of support, warranty, indemnity,
169 | or other liability obligations and/or rights consistent with this
170 | License. However, in accepting such obligations, You may act only
171 | on Your own behalf and on Your sole responsibility, not on behalf
172 | of any other Contributor, and only if You agree to indemnify,
173 | defend, and hold each Contributor harmless for any liability
174 | incurred by, or claims asserted against, such Contributor by reason
175 | of your accepting any such warranty or additional liability.
176 |
--------------------------------------------------------------------------------
/NOTICE:
--------------------------------------------------------------------------------
1 | Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Iterative Retrieval-Generation Reasoner (NAACL 2022)
2 |
3 | This repository contains the code and data for the paper:
4 |
5 | [Entailment Tree Explanations via Iterative Retrieval-Generation Reasoner](https://assets.amazon.science/6e/5d/b055e2644da985a210e15b825422/entailment-tree-explanations-via-iterative-retrieval-generation-reasoner.pdf)
6 |
7 | **Entailment Trees** represent chains of reasoning that show how a hypothesis (or an answer to a question) can be explained from simpler textual evidence.
8 |
9 |
10 | ![Task definition](images/task_definition.png)
11 |
12 |
13 | **Iterative Retrieval-Generation Reasoner (IRGR)** is our proposed architecture, which iteratively searches for suitable premises, constructing a single entailment step at a time. At every generation step, the model retrieves a distinct set of premises to support the generation of that step, mitigating the language model’s input size limit and improving generation correctness.
14 |
15 | ![System overview](images/system_overview.png)
16 |
17 |
18 |
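At a high level, the generation loop alternates premise retrieval with single-step generation until the hypothesis is concluded. Below is a minimal pseudocode-style sketch of that loop (the names `retrieve_premises`, `generate_step`, and the `Step` container are placeholders for illustration, not functions defined in this repository):

```python
from dataclasses import dataclass

@dataclass
class Step:
    premise_ids: list          # ids of the premises used in this entailment step
    conclusion: str            # generated intermediate conclusion
    concludes_hypothesis: bool # True if this step entails the hypothesis itself

def irgr_sketch(hypothesis, corpus, retrieve_premises, generate_step, max_steps=10):
    """Illustrative IRGR-style loop: retrieve premises for one step, generate that
    single entailment step, and repeat until the hypothesis is concluded."""
    proof, intermediates = [], []
    for _ in range(max_steps):
        # Retrieve a small, step-specific premise set so the generator input stays short.
        premises = retrieve_premises(hypothesis, intermediates, corpus)
        # Generate exactly one entailment step conditioned on the retrieved premises.
        step = generate_step(hypothesis, premises, intermediates)
        proof.append(step)
        if step.concludes_hypothesis:
            break
        intermediates.append(step.conclusion)
    return proof
```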
19 | ## Setting Up the Environment
20 |
21 | First, install the project's dependencies:
22 |
23 | ```bash
24 | conda env create --file irgr.yml
25 | pip install -r requirements.txt
26 | ```
27 |
28 | Then activate your conda environment:
29 |
30 | ```bash
31 | conda activate irgr
32 | ```
33 |
34 | ### Setting up Jupyter
35 |
36 | Most of the code is wrapped inside Jupyter Notebooks.
37 |
38 | You can either start a Jupyter server locally or follow the AWS instructions for setting up Jupyter on an EC2 instance and accessing it through your browser:
39 |
40 | https://docs.aws.amazon.com/dlami/latest/devguide/setup-jupyter.html
41 |
42 | ### Data Folder Structure
43 |
44 | You can download the [Entailment Bank](https://allenai.org/data/entailmentbank) data and evaluation code by running:
45 |
46 | ```
47 | python setup.py
48 | ```
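Based on the paths defined in `setup.py`, running it should place the downloaded resources roughly in the following locations (a sketch; exact contents depend on the upstream EntailmentBank release):

```
data/arc_entail/dataset/               # EntailmentBank entailment-tree splits
data/arc_entail/supporting_data/       # WorldTree corpus sentences
src/entailment_bank/data/              # data expected by the evaluation code
src/entailment_bank/bleurt-large-512/  # BLEURT checkpoint used by the scorer
```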
49 |
50 | ### Running Experiments
51 |
52 | You can restart the kernel and run the whole notebook to execute data loading, training, and evaluation.
53 |
54 | Data loading has to be done before training and evaluation.
55 |
56 | #### Entailment Tree Generation
57 |
58 | The main model's code is in `src/entailment_iterative.ipynb`.
59 |
60 | The model can generate explanations and proofs for the EntailmentBank dataset.
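Proofs are evaluated in EntailmentBank's linearized format, which the scoring code in `src/entailment_bank/eval/run_scorer.py` parses: each step joins premise ids with `&`, followed by `->` and either a new intermediate conclusion (`intN: ...`) or the hypothesis, with steps separated by `;`. An illustrative, made-up example:

```
sent1 & sent3 -> int1: the eclipse blocks sunlight from reaching the region; int1 & sent2 -> hypothesis;
```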
61 |
62 | #### Premise Retrieval
63 |
64 | The main model's code is in `src/entailment_retrieval.ipynb`.
65 |
66 | This model retrieves a set of premises from the corpus. Training uses the EntailmentBank + WorldTree V2 corpus.
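The notebook contains the actual retrieval model; purely as an illustration of dense premise retrieval (a sketch with toy inputs, not code from this repository), corpus sentences could be ranked against a hypothesis with `sentence-transformers`, which is pinned in `requirements.txt`:

```python
from sentence_transformers import SentenceTransformer, util

# Toy corpus and hypothesis, for illustration only.
corpus = [
    "an eclipse of the sun is when the moon blocks sunlight from reaching the earth",
    "a tree is a kind of plant",
    "plants require sunlight to grow",
]
hypothesis = "a plant will not grow well during a solar eclipse"

model = SentenceTransformer("all-MiniLM-L6-v2")  # any sentence encoder works for this sketch
corpus_emb = model.encode(corpus, convert_to_tensor=True)
query_emb = model.encode(hypothesis, convert_to_tensor=True)

# Rank corpus sentences by cosine similarity to the hypothesis and keep the top ones.
hits = util.semantic_search(query_emb, corpus_emb, top_k=2)[0]
for hit in hits:
    print(round(hit["score"], 3), corpus[hit["corpus_id"]])
```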
67 |
68 | ## Citation
69 |
70 | ```
71 | @inproceedings{neves-ribeiro-etal-2022-entailment,
72 | title = "Entailment Tree Explanations via Iterative Retrieval-Generation Reasoner",
73 | author = "Neves Ribeiro, Danilo and
74 | Wang, Shen and
75 | Ma, Xiaofei and
76 | Dong, Rui and
77 | Wei, Xiaokai and
78 | Zhu, Henghui and
79 | Chen, Xinchi and
80 | Xu, Peng and
81 | Huang, Zhiheng and
82 | Arnold, Andrew and
83 | Roth, Dan",
84 | booktitle = "Findings of the Association for Computational Linguistics: NAACL 2022",
85 | month = jul,
86 | year = "2022",
87 | address = "Seattle, United States",
88 | publisher = "Association for Computational Linguistics",
89 | url = "https://aclanthology.org/2022.findings-naacl.35",
90 | doi = "10.18653/v1/2022.findings-naacl.35",
91 | pages = "465--475",
92 | abstract = "Large language models have achieved high performance on various question answering (QA) benchmarks, but the explainability of their output remains elusive. Structured explanations, called entailment trees, were recently suggested as a way to explain the reasoning behind a QA system{'}s answer. In order to better generate such entailment trees, we propose an architecture called Iterative Retrieval-Generation Reasoner (IRGR). Our model is able to explain a given hypothesis by systematically generating a step-by-step explanation from textual premises. The IRGR model iteratively searches for suitable premises, constructing a single entailment step at a time. Contrary to previous approaches, our method combines generation steps and retrieval of premises, allowing the model to leverage intermediate conclusions, and mitigating the input size limit of baseline encoder-decoder models. We conduct experiments using the EntailmentBank dataset, where we outperform existing benchmarks on premise retrieval and entailment tree generation, with around 300{\%} gain in overall correctness.",
93 | }
94 | ```
95 |
--------------------------------------------------------------------------------
/images/system_overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/irgr/fd385a77880c92e92167a5766859b2862af6ff26/images/system_overview.png
--------------------------------------------------------------------------------
/images/task_definition.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/irgr/fd385a77880c92e92167a5766859b2862af6ff26/images/task_definition.png
--------------------------------------------------------------------------------
/irgr.yml:
--------------------------------------------------------------------------------
1 | name: irgr
2 | channels:
3 | - pytorch
4 | - nvidia
5 | - anaconda
6 | - conda-forge
7 | - defaults
8 | dependencies:
9 | - _libgcc_mutex=0.1=conda_forge
10 | - _openmp_mutex=4.5=1_gnu
11 | - anyio=3.3.0=py38h578d9bd_0
12 | - argon2-cffi=20.1.0=py38h497a2fe_2
13 | - async_generator=1.10=py_0
14 | - attrs=21.2.0=pyhd8ed1ab_0
15 | - babel=2.9.1=pyh44b312d_0
16 | - backcall=0.2.0=pyh9f0ad1d_0
17 | - backports=1.0=py_2
18 | - backports.functools_lru_cache=1.6.4=pyhd8ed1ab_0
19 | - blas=1.0=mkl
20 | - bleach=4.0.0=pyhd8ed1ab_0
21 | - brotlipy=0.7.0=py38h497a2fe_1001
22 | - ca-certificates=2020.10.14=0
23 | - certifi=2020.6.20=py38_0
24 | - cffi=1.14.6=py38ha65f79e_0
25 | - chardet=4.0.0=py38h578d9bd_1
26 | - charset-normalizer=2.0.0=pyhd8ed1ab_0
27 | - click=8.0.1=py38h578d9bd_0
28 | - colorama=0.4.4=pyh9f0ad1d_0
29 | - cryptography=3.4.7=py38ha5dfef3_0
30 | - cudatoolkit=11.1.74=h6bb024c_0
31 | - dataclasses=0.8=pyhc8e2a94_1
32 | - debugpy=1.4.1=py38h709712a_0
33 | - decorator=5.0.9=pyhd8ed1ab_0
34 | - defusedxml=0.7.1=pyhd8ed1ab_0
35 | - entrypoints=0.3=pyhd8ed1ab_1003
36 | - filelock=3.0.12=pyh9f0ad1d_0
37 | - huggingface_hub=0.0.15=pyhd8ed1ab_0
38 | - idna=3.1=pyhd3deb0d_0
39 | - importlib-metadata=4.6.3=py38h578d9bd_0
40 | - importlib_metadata=4.6.3=hd8ed1ab_0
41 | - intel-openmp=2021.3.0=h06a4308_3350
42 | - ipykernel=6.0.3=py38hd0cf306_0
43 | - ipython=7.26.0=py38he5a9106_0
44 | - ipython_genutils=0.2.0=py_1
45 | - ipywidgets=7.6.3=pyhd3deb0d_0
46 | - jedi=0.18.0=py38h578d9bd_2
47 | - jinja2=3.0.1=pyhd8ed1ab_0
48 | - joblib=1.0.1=pyhd8ed1ab_0
49 | - json5=0.9.5=pyh9f0ad1d_0
50 | - jsonschema=3.2.0=pyhd8ed1ab_3
51 | - jupyter_client=6.1.12=pyhd8ed1ab_0
52 | - jupyter_core=4.7.1=py38h578d9bd_0
53 | - jupyter_server=1.10.2=pyhd8ed1ab_0
54 | - jupyterlab=3.1.4=pyhd8ed1ab_0
55 | - jupyterlab_pygments=0.1.2=pyh9f0ad1d_0
56 | - jupyterlab_server=2.7.0=pyhd8ed1ab_0
57 | - jupyterlab_widgets=1.0.0=pyhd8ed1ab_1
58 | - ld_impl_linux-64=2.36.1=hea4e1c9_2
59 | - libblas=3.9.0=11_linux64_mkl
60 | - libcblas=3.9.0=11_linux64_mkl
61 | - libffi=3.3=h58526e2_2
62 | - libgcc-ng=11.1.0=hc902ee8_8
63 | - libgomp=11.1.0=hc902ee8_8
64 | - liblapack=3.9.0=11_linux64_mkl
65 | - libprotobuf=3.13.0.1=hd408876_0
66 | - libsodium=1.0.18=h36c2ea0_1
67 | - libstdcxx-ng=11.1.0=h56837e0_8
68 | - libuv=1.42.0=h7f98852_0
69 | - markupsafe=2.0.1=py38h497a2fe_0
70 | - matplotlib-inline=0.1.2=pyhd8ed1ab_2
71 | - mistune=0.8.4=py38h497a2fe_1004
72 | - mkl=2021.3.0=h06a4308_520
73 | - nbclassic=0.3.1=pyhd8ed1ab_1
74 | - nbclient=0.5.3=pyhd8ed1ab_0
75 | - nbconvert=6.1.0=py38h578d9bd_0
76 | - nbformat=5.1.3=pyhd8ed1ab_0
77 | - ncurses=6.2=h58526e2_4
78 | - nest-asyncio=1.5.1=pyhd8ed1ab_0
79 | - ninja=1.10.2=h4bd325d_0
80 | - notebook=6.4.3=pyha770c72_0
81 | - numpy=1.21.1=py38h9894fe3_0
82 | - openssl=1.1.1k=h7f98852_1
83 | - packaging=21.0=pyhd8ed1ab_0
84 | - pandas=1.1.3=py38he6710b0_0
85 | - pandoc=2.14.1=h7f98852_0
86 | - pandocfilters=1.4.2=py_1
87 | - parso=0.8.2=pyhd8ed1ab_0
88 | - pexpect=4.8.0=pyh9f0ad1d_2
89 | - pickleshare=0.7.5=py_1003
90 | - pip=21.2.3=pyhd8ed1ab_0
91 | - prometheus_client=0.11.0=pyhd8ed1ab_0
92 | - prompt-toolkit=3.0.19=pyha770c72_0
93 | - protobuf=3.13.0.1=py38he6710b0_1
94 | - ptyprocess=0.7.0=pyhd3deb0d_0
95 | - pycparser=2.20=pyh9f0ad1d_2
96 | - pygments=2.9.0=pyhd8ed1ab_0
97 | - pyopenssl=20.0.1=pyhd8ed1ab_0
98 | - pyparsing=2.4.7=pyh9f0ad1d_0
99 | - pyrsistent=0.17.3=py38h497a2fe_2
100 | - pysocks=1.7.1=py38h578d9bd_3
101 | - python=3.8.10=h49503c6_1_cpython
102 | - python-dateutil=2.8.2=pyhd8ed1ab_0
103 | - python_abi=3.8=2_cp38
104 | - pytorch=1.9.0=py3.8_cuda11.1_cudnn8.0.5_0
105 | - pytz=2021.1=pyhd8ed1ab_0
106 | - pyyaml=5.4.1=py38h497a2fe_0
107 | - pyzmq=22.2.1=py38h2035c66_0
108 | - readline=8.1=h46c0cb4_0
109 | - regex=2021.8.3=py38h497a2fe_0
110 | - requests=2.26.0=pyhd8ed1ab_0
111 | - requests-unixsocket=0.2.0=py_0
112 | - sacremoses=0.0.43=pyh9f0ad1d_0
113 | - send2trash=1.8.0=pyhd8ed1ab_0
114 | - sentencepiece=0.1.95=py38h1fd1430_0
115 | - setuptools=49.6.0=py38h578d9bd_3
116 | - six=1.16.0=pyh6c4a22f_0
117 | - sniffio=1.2.0=py38h578d9bd_1
118 | - sqlite=3.36.0=h9cd32fc_0
119 | - terminado=0.10.1=py38h578d9bd_0
120 | - testpath=0.5.0=pyhd8ed1ab_0
121 | - tk=8.6.10=h21135ba_1
122 | - tokenizers=0.10.1=py38hb63a372_0
123 | - tornado=6.1=py38h497a2fe_1
124 | - tqdm=4.62.0=pyhd8ed1ab_0
125 | - traitlets=5.0.5=py_0
126 | - transformers=4.9.2=pyhd8ed1ab_0
127 | - typing-extensions=3.10.0.0=hd8ed1ab_0
128 | - typing_extensions=3.10.0.0=pyha770c72_0
129 | - urllib3=1.26.6=pyhd8ed1ab_0
130 | - wcwidth=0.2.5=pyh9f0ad1d_2
131 | - webencodings=0.5.1=py_1
132 | - websocket-client=0.57.0=py38h578d9bd_4
133 | - wheel=0.37.0=pyhd8ed1ab_0
134 | - widgetsnbextension=3.5.1=py38h578d9bd_4
135 | - xz=5.2.5=h516909a_1
136 | - yaml=0.2.5=h516909a_0
137 | - zeromq=4.3.4=h9c3ff4c_0
138 | - zipp=3.5.0=pyhd8ed1ab_0
139 | - zlib=1.2.11=h516909a_1010
140 | - pip:
141 | - environment-kernels==1.1.1
142 | prefix: /home/ec2-user/anaconda3/envs/csr
143 |
--------------------------------------------------------------------------------
/public_reposiroty_irgr_final_v1.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/irgr/fd385a77880c92e92167a5766859b2862af6ff26/public_reposiroty_irgr_final_v1.zip
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | BLEURT @ git+https://github.com/google-research/bleurt.git@b610120347ef22b494b6d69b4316e303f5932516
2 | datasets==1.6.2
3 | deepspeed==0.5.0
4 | huggingface-hub==0.0.12
5 | multiprocess==0.70.12.2
6 | sagemaker==2.59.4
7 | scikit-learn==0.24.2
8 | scipy==1.7.1
9 | sentence-transformers==2.0.0
10 | sentencepiece==0.1.95
11 | tensorboardX
12 | tensorflow
13 | tensorflow-estimator
14 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import requests
4 | import shutil
5 | from pathlib import Path
6 |
7 | ENTAILMENT_BANK_REPO = 'https://github.com/allenai/entailment_bank.git'
8 | BLEURT_MODEL_PATH = 'bleurt-large-512.zip'
9 | BLEURT_MODEL_URL = 'https://storage.googleapis.com/bleurt-oss/bleurt-large-512.zip'
10 |
11 | ENT_BANK_DATASET_PATH = './entailment_bank/data/public_dataset/entailment_trees_emnlp2021_data_v2/dataset/'
12 | DATASET_PATH = './data/arc_entail/dataset/'
13 | ENT_BANK_WT_CORPUS_PATH = './entailment_bank/data/public_dataset/entailment_trees_emnlp2021_data_v2/supporting_data/'
14 | WT_CORPUS_PATH = './data/arc_entail/supporting_data/'
15 | ENT_BANK_SRC_DATA_PATH = './entailment_bank/data/'
16 | SRC_DATA_PATH = './src/entailment_bank/data/'
17 |
18 | BLEURT_FOLDER = './bleurt-large-512'
19 | SRC_BLEURT_FOLDER = './src/entailment_bank/bleurt-large-512'
20 |
21 | def prepare_path(path):
22 | path = Path(path)
23 | os.makedirs(path.parents[0], exist_ok=True)
24 | return path
25 |
26 | def doanlod_and_save_to_path(url, path):
27 | path = prepare_path(path)
28 | print(f'Downloading:\n{url}')
29 | response = requests.get(url)
30 | print(f'Saving to:\n{path}')
31 | open(path, 'wb').write(response.content)
32 |
33 | def copy_folders(from_path, to_path):
34 | print(f'{from_path} => {to_path}')
35 | from_path = prepare_path(from_path)
36 | to_path = prepare_path(to_path)
37 | shutil.copytree(from_path, to_path, dirs_exist_ok=True)
38 |
39 | def move_folders(from_path, to_path):
40 | print(f'{from_path} => {to_path}')
41 | from_path = prepare_path(from_path)
42 | to_path = prepare_path(to_path)
43 | shutil.move(from_path, to_path)
44 |
45 | def clone_git_repo(git_repo):
46 | # Cloning
47 | os.system(f'git clone {git_repo}')
48 |
49 | def unzip_file(path):
50 | path = Path(path)
51 | # Unzipping
52 | os.system(f'unzip {path}')
53 |
54 | def setup_entailment_bank_eval():
55 | print('\nCloning EntailmentBank evaluation repository...')
56 | clone_git_repo(ENTAILMENT_BANK_REPO)
57 | print('\nCopying files...')
58 | copy_folders(ENT_BANK_DATASET_PATH, DATASET_PATH)
59 | copy_folders(ENT_BANK_WT_CORPUS_PATH,WT_CORPUS_PATH)
60 | copy_folders(ENT_BANK_SRC_DATA_PATH, SRC_DATA_PATH)
61 |
62 | print('\nDownloading BLEURT model...')
63 |     # download bleurt model
64 | doanlod_and_save_to_path(BLEURT_MODEL_URL, BLEURT_MODEL_PATH)
65 | # unzip bleurt model
66 | unzip_file(BLEURT_MODEL_PATH)
67 | # copy files to src folder
68 | print('\nMoving model files...')
69 | move_folders(BLEURT_FOLDER, SRC_BLEURT_FOLDER)
70 |
71 | def main():
72 | setup_entailment_bank_eval()
73 | print('\nSetup finished!')
74 |
75 | if __name__ == '__main__':
76 | main()
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/irgr/fd385a77880c92e92167a5766859b2862af6ff26/src/__init__.py
--------------------------------------------------------------------------------
/src/base_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import sys
4 | import re
5 |
6 | ####################################################################
7 | # File Manipulation
8 | ####################################################################
9 |
10 | def save_to_jsonl_file(output_list, file_path):
11 | print('saving data to file:', file_path)
12 | with open(file_path, 'w') as file:
13 | for obj in output_list:
14 | file.write(json.dumps(obj) + '\n')
15 |
16 | def run_funtion_redirect_stdout(fun, args=[], kargs={}, filename='logs.txt'):
17 | '''
18 | Runs functions while redirecting output to file
19 | '''
20 | # redirect stdout to file
21 | orig_stdout = sys.stdout
22 | f = open(filename, 'w')
23 | sys.stdout = f
24 |
25 | # execute function
26 | output = fun(*args, **kargs)
27 |
28 | # restore stdout
29 | sys.stdout = orig_stdout
30 | f.close()
31 | return output
32 |
33 | ####################################################################
34 | # Strings
35 | ####################################################################
36 |
37 | def str_replace_single_pass(string, substitutions):
38 | '''
39 | A Python function that does multiple string replace ops in a single pass.
40 | E.g. substitutions = {"foo": "FOO", "bar": "BAR"}
41 | '''
42 | substrings = sorted(substitutions, key=len, reverse=True)
43 | regex = re.compile('|'.join(map(re.escape, substrings)))
44 | return regex.sub(lambda match: substitutions[match.group(0)], string)
45 |
46 |
47 | ####################################################################
48 | # Resource Locator
49 | ####################################################################
50 |
51 | def get_results_file_path(params, test_split=False, result_only=False,
52 | temp=False, uuid=False):
53 | suffix = ''
54 | if test_split:
55 | suffix += '_test'
56 | if result_only:
57 | suffix += '_result_only'
58 | if temp:
59 | suffix += '_temp'
60 | if uuid:
61 | suffix += '_uuid'
62 | results_file_path = params.results_file_path.format(
63 | model_name = params.model_name,
64 | task_name = params.task_name,
65 | dataset_name = params.dataset_name,
66 | approach_name = params.approach_name,
67 | suffix = suffix,
68 | extension = 'tsv')
69 | return results_file_path
70 |
71 | def get_logs_file_path(params, test_split=False, temp=False, epoch_num = None):
72 | suffix = ''
73 | prefix = ''
74 | if test_split:
75 | suffix += '_test'
76 | if temp:
77 | suffix += '_temp'
78 | prefix += 'temp/'
79 | if epoch_num is not None:
80 | suffix += f'_{epoch_num}'
81 | logs_file_path = params.logs_file_path.format(
82 | model_name = params.model_name,
83 | task_name = params.task_name,
84 | dataset_name = params.dataset_name,
85 | approach_name = params.approach_name,
86 | prefix = prefix,
87 | suffix = suffix
88 | )
89 | return logs_file_path
90 |
91 | def get_proof_ranking_data_path(params, split = 'dev'):
92 | logs_file_path = params.proof_ranking_data_filepath.format(
93 | model_name = params.model_name,
94 | task_name = params.task_name,
95 | dataset_name = params.dataset_name,
96 | approach_name = params.approach_name,
97 | split = split
98 | )
99 | return logs_file_path
--------------------------------------------------------------------------------
/src/entailment_bank/NOTICE:
--------------------------------------------------------------------------------
1 | ** entailment_bank; version 1 -- https://github.com/allenai/entailment_bank
2 | Original Copyright 2021 Allen Institute for AI. Licensed under the Apache-2.0
3 | License.
4 |
5 | Apache License
6 | Version 2.0, January 2004
7 | http://www.apache.org/licenses/
8 |
9 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
10 |
11 | 1. Definitions.
12 |
13 | "License" shall mean the terms and conditions for use, reproduction,
14 | and distribution as defined by Sections 1 through 9 of this document.
15 |
16 | "Licensor" shall mean the copyright owner or entity authorized by
17 | the copyright owner that is granting the License.
18 |
19 | "Legal Entity" shall mean the union of the acting entity and all
20 | other entities that control, are controlled by, or are under common
21 | control with that entity. For the purposes of this definition,
22 | "control" means (i) the power, direct or indirect, to cause the
23 | direction or management of such entity, whether by contract or
24 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
25 | outstanding shares, or (iii) beneficial ownership of such entity.
26 |
27 | "You" (or "Your") shall mean an individual or Legal Entity
28 | exercising permissions granted by this License.
29 |
30 | "Source" form shall mean the preferred form for making modifications,
31 | including but not limited to software source code, documentation
32 | source, and configuration files.
33 |
34 | "Object" form shall mean any form resulting from mechanical
35 | transformation or translation of a Source form, including but
36 | not limited to compiled object code, generated documentation,
37 | and conversions to other media types.
38 |
39 | "Work" shall mean the work of authorship, whether in Source or
40 | Object form, made available under the License, as indicated by a
41 | copyright notice that is included in or attached to the work
42 | (an example is provided in the Appendix below).
43 |
44 | "Derivative Works" shall mean any work, whether in Source or Object
45 | form, that is based on (or derived from) the Work and for which the
46 | editorial revisions, annotations, elaborations, or other modifications
47 | represent, as a whole, an original work of authorship. For the purposes
48 | of this License, Derivative Works shall not include works that remain
49 | separable from, or merely link (or bind by name) to the interfaces of,
50 | the Work and Derivative Works thereof.
51 |
52 | "Contribution" shall mean any work of authorship, including
53 | the original version of the Work and any modifications or additions
54 | to that Work or Derivative Works thereof, that is intentionally
55 | submitted to Licensor for inclusion in the Work by the copyright owner
56 | or by an individual or Legal Entity authorized to submit on behalf of
57 | the copyright owner. For the purposes of this definition, "submitted"
58 | means any form of electronic, verbal, or written communication sent
59 | to the Licensor or its representatives, including but not limited to
60 | communication on electronic mailing lists, source code control systems,
61 | and issue tracking systems that are managed by, or on behalf of, the
62 | Licensor for the purpose of discussing and improving the Work, but
63 | excluding communication that is conspicuously marked or otherwise
64 | designated in writing by the copyright owner as "Not a Contribution."
65 |
66 | "Contributor" shall mean Licensor and any individual or Legal Entity
67 | on behalf of whom a Contribution has been received by Licensor and
68 | subsequently incorporated within the Work.
69 |
70 | 2. Grant of Copyright License. Subject to the terms and conditions of
71 | this License, each Contributor hereby grants to You a perpetual,
72 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
73 | copyright license to reproduce, prepare Derivative Works of,
74 | publicly display, publicly perform, sublicense, and distribute the
75 | Work and such Derivative Works in Source or Object form.
76 |
77 | 3. Grant of Patent License. Subject to the terms and conditions of
78 | this License, each Contributor hereby grants to You a perpetual,
79 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
80 | (except as stated in this section) patent license to make, have made,
81 | use, offer to sell, sell, import, and otherwise transfer the Work,
82 | where such license applies only to those patent claims licensable
83 | by such Contributor that are necessarily infringed by their
84 | Contribution(s) alone or by combination of their Contribution(s)
85 | with the Work to which such Contribution(s) was submitted. If You
86 | institute patent litigation against any entity (including a
87 | cross-claim or counterclaim in a lawsuit) alleging that the Work
88 | or a Contribution incorporated within the Work constitutes direct
89 | or contributory patent infringement, then any patent licenses
90 | granted to You under this License for that Work shall terminate
91 | as of the date such litigation is filed.
92 |
93 | 4. Redistribution. You may reproduce and distribute copies of the
94 | Work or Derivative Works thereof in any medium, with or without
95 | modifications, and in Source or Object form, provided that You
96 | meet the following conditions:
97 |
98 | (a) You must give any other recipients of the Work or
99 | Derivative Works a copy of this License; and
100 |
101 | (b) You must cause any modified files to carry prominent notices
102 | stating that You changed the files; and
103 |
104 | (c) You must retain, in the Source form of any Derivative Works
105 | that You distribute, all copyright, patent, trademark, and
106 | attribution notices from the Source form of the Work,
107 | excluding those notices that do not pertain to any part of
108 | the Derivative Works; and
109 |
110 | (d) If the Work includes a "NOTICE" text file as part of its
111 | distribution, then any Derivative Works that You distribute must
112 | include a readable copy of the attribution notices contained
113 | within such NOTICE file, excluding those notices that do not
114 | pertain to any part of the Derivative Works, in at least one
115 | of the following places: within a NOTICE text file distributed
116 | as part of the Derivative Works; within the Source form or
117 | documentation, if provided along with the Derivative Works; or,
118 | within a display generated by the Derivative Works, if and
119 | wherever such third-party notices normally appear. The contents
120 | of the NOTICE file are for informational purposes only and
121 | do not modify the License. You may add Your own attribution
122 | notices within Derivative Works that You distribute, alongside
123 | or as an addendum to the NOTICE text from the Work, provided
124 | that such additional attribution notices cannot be construed
125 | as modifying the License.
126 |
127 | You may add Your own copyright statement to Your modifications and
128 | may provide additional or different license terms and conditions
129 | for use, reproduction, or distribution of Your modifications, or
130 | for any such Derivative Works as a whole, provided Your use,
131 | reproduction, and distribution of the Work otherwise complies with
132 | the conditions stated in this License.
133 |
134 | 5. Submission of Contributions. Unless You explicitly state otherwise,
135 | any Contribution intentionally submitted for inclusion in the Work
136 | by You to the Licensor shall be under the terms and conditions of
137 | this License, without any additional terms or conditions.
138 | Notwithstanding the above, nothing herein shall supersede or modify
139 | the terms of any separate license agreement you may have executed
140 | with Licensor regarding such Contributions.
141 |
142 | 6. Trademarks. This License does not grant permission to use the trade
143 | names, trademarks, service marks, or product names of the Licensor,
144 | except as required for reasonable and customary use in describing the
145 | origin of the Work and reproducing the content of the NOTICE file.
146 |
147 | 7. Disclaimer of Warranty. Unless required by applicable law or
148 | agreed to in writing, Licensor provides the Work (and each
149 | Contributor provides its Contributions) on an "AS IS" BASIS,
150 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
151 | implied, including, without limitation, any warranties or conditions
152 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
153 | PARTICULAR PURPOSE. You are solely responsible for determining the
154 | appropriateness of using or redistributing the Work and assume any
155 | risks associated with Your exercise of permissions under this License.
156 |
157 | 8. Limitation of Liability. In no event and under no legal theory,
158 | whether in tort (including negligence), contract, or otherwise,
159 | unless required by applicable law (such as deliberate and grossly
160 | negligent acts) or agreed to in writing, shall any Contributor be
161 | liable to You for damages, including any direct, indirect, special,
162 | incidental, or consequential damages of any character arising as a
163 | result of this License or out of the use or inability to use the
164 | Work (including but not limited to damages for loss of goodwill,
165 | work stoppage, computer failure or malfunction, or any and all
166 | other commercial damages or losses), even if such Contributor
167 | has been advised of the possibility of such damages.
168 |
169 | 9. Accepting Warranty or Additional Liability. While redistributing
170 | the Work or Derivative Works thereof, You may choose to offer,
171 | and charge a fee for, acceptance of support, warranty, indemnity,
172 | or other liability obligations and/or rights consistent with this
173 | License. However, in accepting such obligations, You may act only
174 | on Your own behalf and on Your sole responsibility, not on behalf
175 | of any other Contributor, and only if You agree to indemnify,
176 | defend, and hold each Contributor harmless for any liability
177 | incurred by, or claims asserted against, such Contributor by reason
178 | of your accepting any such warranty or additional liability.
179 |
180 | END OF TERMS AND CONDITIONS
181 |
182 | APPENDIX: How to apply the Apache License to your work.
183 |
184 | To apply the Apache License to your work, attach the following
185 | boilerplate notice, with the fields enclosed by brackets "[]"
186 | replaced with your own identifying information. (Don't include
187 | the brackets!) The text should be enclosed in the appropriate
188 | comment syntax for the file format. We also recommend that a
189 | file or class name and description of purpose be included on the
190 | same "printed page" as the copyright notice for easier
191 | identification within third-party archives.
192 |
193 | Copyright [yyyy] [name of copyright owner]
194 |
195 | Licensed under the Apache License, Version 2.0 (the "License");
196 | you may not use this file except in compliance with the License.
197 | You may obtain a copy of the License at
198 |
199 | http://www.apache.org/licenses/LICENSE-2.0
200 |
201 | Unless required by applicable law or agreed to in writing, software
202 | distributed under the License is distributed on an "AS IS" BASIS,
203 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
204 | See the License for the specific language governing permissions and
205 | limitations under the License.
--------------------------------------------------------------------------------
/src/entailment_bank/README.md:
--------------------------------------------------------------------------------
1 | # Entailment Bank Evaluation
2 |
3 | This evaluation code is part of the entailment bank repository available at:
4 |
5 | [https://github.com/allenai/entailment_bank](https://github.com/allenai/entailment_bank)
6 |
7 | Small changes were made to support additional metrics, along with small improvements to the "task-3" evaluation.
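The scorer's command-line arguments are defined in `eval/run_scorer.py` (`--task`, `--split`, `--prediction_file`, `--output_dir`, `--bleurt_checkpoint`). An illustrative invocation, assumed to be run from the `src` directory and using placeholder paths:

```bash
python entailment_bank/eval/run_scorer.py \
  --task task_1 \
  --split dev \
  --prediction_file path/to/predictions.jsonl \
  --output_dir path/to/scores \
  --bleurt_checkpoint path/to/bleurt-large-512
```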
--------------------------------------------------------------------------------
/src/entailment_bank/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/irgr/fd385a77880c92e92167a5766859b2862af6ff26/src/entailment_bank/__init__.py
--------------------------------------------------------------------------------
/src/entailment_bank/eval/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/irgr/fd385a77880c92e92167a5766859b2862af6ff26/src/entailment_bank/eval/__init__.py
--------------------------------------------------------------------------------
/src/entailment_bank/eval/run_scorer.py:
--------------------------------------------------------------------------------
1 | # Original Copyright 2021 Allen Institute for AI. Licensed under the Apache-2.0 License.
2 | # Modifications Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
3 |
4 | """ Evaluation script for EntailmentBank models. """
5 |
6 | import argparse
7 | import glob
8 | import logging
9 | import json
10 | import os
11 | import re
12 | import sys
13 | from bleurt import score
14 |
15 | from tqdm import tqdm
16 |
17 | sys.path.append(os.path.join(os.getcwd())) # noqa: E402 # isort:skip
18 |
19 | # 2022-05-26: Amazon addition.
20 | from entailment_bank.utils.angle_utils import decompose_slots, load_jsonl, save_json, shortform_angle, formatting
21 | from entailment_bank.utils.eval_utils import collate_scores, score_prediction_whole_proof
22 | from entailment_bank.utils.retrieval_utils import convert_datapoint_sent_to_uuid
23 | # End of Amazon addition.
24 |
25 | logger = logging.getLogger(__name__)
26 | logging.basicConfig(level=logging.INFO)
27 |
28 |
29 | def get_args():
30 | parser = argparse.ArgumentParser()
31 | parser.add_argument("--task", default=None, required=True, type=str,
32 | help="Task name: task_1, task_2, task_3")
33 | parser.add_argument("--output_dir", default=None, required=True, type=str,
34 | help="Directory to store scores.")
35 | parser.add_argument("--split", default=None, required=True, type=str, help="Which split (train/dev/test) to evaluate.")
36 | parser.add_argument("--prediction_file", default=None, required=True, type=str,
37 | help="Prediction file(s) to score.")
38 | parser.add_argument("--bleurt_checkpoint", default="true", type=str,
39 | help="Path to the BLEURT model checkpoint (Download from https://github.com/google-research/bleurt#checkpoints) "
40 | "We use bleurt-large-512 model for EntailmentBank evaluation")
41 |
42 | args = parser.parse_args()
43 | return args
44 |
45 |
46 | def split_info_sentences(context):
47 | words_list = context.split(" ")
48 | sentence_ids = re.findall(r'[\w\.]+[0-9]+:', context)
49 | sentence_dict = dict()
50 | prev_sid = ""
51 | prev_sentence_parts = []
52 | for word in words_list:
53 | if word in sentence_ids:
54 | if prev_sid:
55 | sentence_dict[prev_sid] = ' '.join(prev_sentence_parts)
56 | prev_sentence_parts = []
57 | prev_sid = word
58 | else:
59 | prev_sentence_parts.append(word)
60 |
61 | if prev_sid:
62 | sentence_dict[prev_sid] = ' '.join(prev_sentence_parts)
63 | return sentence_dict
64 | # 2022-05-26: Amazon addition.
65 | def score_predictions(args, predictions_file, score_file, gold_file, angle_file=None, dataset=None, bleurt_checkpoint="", task_name = None):
66 | # End of Amazon addition.
67 | if args.bleurt_checkpoint:
68 | bleurt_scorer = score.BleurtScorer(bleurt_checkpoint)
69 | else:
70 | bleurt_scorer = None
71 |
72 | gold_data = load_jsonl(gold_file)
73 |
74 | # 2022-05-26: Amazon addition.
75 | ###### to fix task_3 evaluation
76 |     print('\ntask_name =', task_name)
77 | if task_name == 'task_3':
78 | print('UPDATING DATAPOINTS')
79 | gold_data = [convert_datapoint_sent_to_uuid(g) for g in gold_data]
80 | ######
81 | # End of Amazon addition.
82 |
83 | gold_by_id = {g['id']: g for g in gold_data}
84 |
85 | gold_train_file = gold_file.replace("dev", "train")
86 | gold_train_data = load_jsonl(gold_train_file)
87 | # gold_train_by_id = {g['id']: g for g in gold_train_data}
88 | train_context_dict = dict()
89 | train_answers_dict = dict()
90 | for g in gold_train_data:
91 | # print(f"{g['meta']['triples']}")
92 | context_dict = dict(g['meta']['triples'])
93 | for context in context_dict.values():
94 | train_context_dict[context] = 1
95 | train_answers_dict[g['answer']] = 1
96 |
97 | is_jsonl = predictions_file.endswith(".jsonl")
98 | if not is_jsonl:
99 | angle_data = load_jsonl(angle_file)
100 | scores = []
101 | sort_angle = False
102 |
103 | num_dev_answers = 0
104 | num_dev_answers_seen_in_train_context = 0
105 | num_dev_answers_seen_in_train_answers = 0
106 | diagnostics_tsv = open(score_file+".diagnostics.tsv", "w")
107 |
108 | # 2022-05-26: Amazon addition.
109 | diagnostics_tsv.write(f"Q-ID"
110 | f"\tI:Context"
111 | f"\tI:Question"
112 | f"\tI:Answer"
113 | f"\tI:Hypothesis"
114 | f"\tO:Gold Proof"
115 | f"\tO:Predicted Proof"
116 | f"\tPredicted to gold int alignment"
117 | f"\tRewritten predicted proof after alignment"
118 | f"\t% Leaves-P"
119 | f"\t% Leaves-R"
120 | f"\t% Leaves-F1"
121 | f"\t% Leaves-Correct"
122 | f"\t% Steps-F1"
123 | f"\t% Steps-Correct"
124 | f"\t% Interm-BLEURT-P"
125 | f"\t% Interm-BLEURT-R"
126 | f"\t% Interm-BLEURT-F1"
127 | f"\tInterm-BLEURT-score"
128 | f"\t% Interm-BLEURT-Acc"
129 | f"\t% perfect alignment"
130 | f"\t% Overall Correct"
131 | f"\tNum Distractors"
132 | f"\tContext Length"
133 | f"\tFraction of distractors"
134 | f"\tDistractor Ids"
135 | f"\tDepth of Proof"
136 | f"\tLength of Proof"
137 | "\n")
138 | # End of Amazon addition.
139 |
140 | with open(f"{score_file}.json", "w") as score_file, open(predictions_file, "r") as preds_file:
141 | for line_idx, line in tqdm(enumerate(preds_file)):
142 | if is_jsonl:
143 | pred = json.loads(line.strip())
144 | else:
145 | pred = {'id': angle_data[line_idx]['id'],
146 | 'angle': angle_data[line_idx]['angle'],
147 | 'prediction': line.strip()}
148 |
149 | angle = pred['angle']
150 | angle_canonical = shortform_angle(angle, sort_angle=sort_angle)
151 | pred['angle_str'] = angle_canonical
152 | item_id = pred['id']
153 |
154 | # if item_id not in ['CSZ20680']:
155 | # continue
156 |
157 | if item_id not in gold_by_id:
158 | continue
159 | raise ValueError(f"Missing id in gold data: {item_id}")
160 | slots = decompose_slots(pred['prediction'])
161 |
162 | pred['slots'] = slots
163 |
164 | num_dev_answers += 1
165 | # print(f"======= pred: {pred}")
166 | # print(f">>>>>>>>>>>> id:{item_id}")
167 | metrics = score_prediction_whole_proof(pred, gold_by_id[item_id], dataset,
168 | scoring_spec={
169 | "hypothesis_eval": "nlg",
170 | "proof_eval": "entail_whole_proof_align_eval",
171 | # "proof_eval": "entail_whole_polish_proof_align_eval",
172 | },
173 | bleurt_scorer=bleurt_scorer)
174 |
175 | pred['metrics'] = metrics
176 | score_file.write(json.dumps(pred) + "\n")
177 | id = angle_data[line_idx]['id']
178 | goldslot_record = gold_by_id[id]
179 | # print(f"goldslot_record:{goldslot_record}")
180 | question_before_json = ""
181 | if 'meta' in goldslot_record and 'question' in goldslot_record['meta']:
182 | question_before_json = goldslot_record['meta']['question']
183 |
184 | question_json = {}
185 | if 'meta' in goldslot_record and 'question' in goldslot_record['meta']:
186 | question_json = goldslot_record['meta']['question']
187 |
188 | question_json['gold_proofs'] = question_json.get('proofs', "")
189 | question_json['proofs'] = ""
190 |
191 | hypothesis_f1 = metrics.get('hypothesis', dict()).get('ROUGE_L_F', -1)
192 | question_json['ROUGE_L_F'] = hypothesis_f1
193 |
194 | sentences_dict = split_info_sentences(goldslot_record['context'])
195 | sentence_set = []
196 | for sid, sent in sentences_dict.items():
197 | sentence_set.append(f"{sid}: {sent}")
198 | sent_str = formatting(sentence_set)
199 |
200 | gold_triples = goldslot_record['meta']['triples']
201 | gold_ints = goldslot_record.get('meta', dict()).get('intermediate_conclusions', dict())
202 | gold_ints['hypothesis'] = goldslot_record['hypothesis']
203 | gold_triples.update(gold_ints)
204 | gold_proof_str = goldslot_record['proof']
205 | # if '; ' in gold_proof_str:
206 | if True:
207 | gold_proof_steps = gold_proof_str.split(';')
208 |
209 | gold_proof_str_list = []
210 | for step in gold_proof_steps:
211 | step = step.strip()
212 | if step.strip() and len(step.split(' -> '))==2:
213 | print(f"step:{step}")
214 |
215 | parts = step.split(' -> ')
216 | lhs_ids = parts[0].split('&')
217 | rhs = parts[1]
218 | if rhs == "hypothesis":
219 | rhs = f"hypothesis: {gold_triples['hypothesis']}"
220 | for lid in lhs_ids:
221 | lhs_id = lid.strip()
222 | print(f"QID:{item_id}")
223 | print(f"gold_triples:{gold_triples}")
224 | print(f"step:{step}")
225 | # gold_proof_str_list.append(f"{lhs_id}: {gold_triples[lhs_id]} &")
226 | gold_proof_str_list.append(f"{lhs_id}: {gold_triples[lhs_id] if lhs_id in gold_triples else ''} &")
227 | gold_proof_str_list.append(f"-> {rhs}")
228 | gold_proof_str_list.append(f"-----------------")
229 | gold_proof_str_to_output = formatting(gold_proof_str_list)
230 |
231 | pred_triples = goldslot_record['meta']['triples']
232 | pred_triples['hypothesis'] = goldslot_record['hypothesis']
233 | pred_proof_str = pred['slots'].get('proof', "")
234 | print(f"^^^^^^^^^^^^^^^^^^pred_proof_str:{pred_proof_str}")
235 | # if '; ' in pred_proof_str:
236 | if True:
237 | pred_proof_steps = pred_proof_str.split(';')
238 | pred_proof_str_list = []
239 | # print(f"\n\n=================")
240 | # print(f"pred_proof_str:{pred_proof_str}")
241 | for step in pred_proof_steps:
242 | step = step.strip()
243 | if step.strip() and len(step.split(' -> '))==2:
244 | print(f"step:{step}")
245 | parts = step.split(' -> ')
246 | lhs_ids = parts[0].split('&')
247 | if ',' in parts[0]:
248 | lhs_ids = parts[0].split(',')
249 | rhs = parts[1]
250 | # 2022-05-26: Amazon addition.
251 | if rhs == "hypothesis" or "hypothesis" in rhs:
252 | rhs = f"hypothesis: {pred_triples['hypothesis']}"
253 | else:
254 | if rhs.count(":") == 1:
255 | rhs_parts = rhs.split(":")
256 | print(rhs_parts)
257 | int_id = rhs_parts[0]
258 | int_str = rhs_parts[1].strip()
259 | pred_triples[int_id] = int_str
260 | # else:
261 | # pred_triples[int_id] = rhs.strip()
262 | for lid in lhs_ids:
263 | lhs_id = lid.strip()
264 | pred_proof_str_list.append(f"{lhs_id}: {pred_triples.get(lhs_id, 'NULL')} &")
265 | pred_proof_str_list.append(f"-> {rhs}")
266 | pred_proof_str_list.append(f"-----------------")
267 | # End of Amazon addition.
268 | pred_proof_str_to_output = formatting(pred_proof_str_list)
269 |
270 | if '; ' in pred_proof_str:
271 | pred_proof_steps = pred_proof_str.split('; ')
272 | pred_step_list = []
273 |
274 | for step in pred_proof_steps:
275 | if step.strip():
276 | pred_step_list.append(f"{step}; ")
277 | pred_proof_str = formatting(pred_step_list)
278 |
279 | relevance_f1 = "-"
280 | relevance_accuracy = "-"
281 | if 'relevance' in metrics:
282 | relevance_f1 = metrics['relevance']['F1']
283 | relevance_accuracy = metrics['relevance']['acc']
284 |
285 | proof_acc = "-"
286 | proof_f1 = "-"
287 | proof_alignements = "-"
288 | if 'aligned_proof' in metrics:
289 | proof_acc = metrics['aligned_proof']['acc']
290 | proof_f1 = metrics['aligned_proof']['F1']
291 | proof_alignements = metrics['aligned_proof']['pred_to_gold_mapping']
292 |
293 | inference_type = "none"
294 | if "abduction" in id:
295 | inference_type = "abduction"
296 | elif "deduction" in id:
297 | inference_type = "deduction"
298 |
299 | # 2022-05-26: Amazon addition.
300 | depth_of_proof = goldslot_record['depth_of_proof']
301 | length_of_proof = goldslot_record['length_of_proof']
302 | # End of Amazon addition.
303 |
304 | num_distractors = 0
305 | fraction_distractors = 0.0
306 | num_context_sent = len(goldslot_record['meta']['triples'])
307 | if 'distractors' in goldslot_record['meta']:
308 | num_distractors = len(goldslot_record['meta']['distractors'])
309 | fraction_distractors = 1.0 * num_distractors / num_context_sent
310 |
311 | distractor_ids = goldslot_record['meta'].get('distractors', [])
312 | pred_to_gold_mapping = metrics['proof-steps']['pred_to_gold_mapping']
313 | pred_to_gold_mapping_str = ""
314 | for pred_int, gold_int in pred_to_gold_mapping.items():
315 | pred_to_gold_mapping_str += f"p_{pred_int} -> g_{gold_int} ;; "
316 | diagnostics_tsv.write(f"{id}"
317 | f"\t{sent_str}"
318 | f"\t{goldslot_record['question']}"
319 | f"\t{goldslot_record['answer']}"
320 | f"\t{goldslot_record['hypothesis']}"
321 | f"\t{gold_proof_str_to_output}"
322 | f"\t{pred_proof_str_to_output}"
323 | f"\t{' ;; '.join(metrics['proof-steps']['sentences_pred_aligned'])}"
324 | f"\t{pred_to_gold_mapping_str}"
325 | f"\t{metrics['proof-leaves']['P']*100}"
326 | f"\t{metrics['proof-leaves']['R']*100}"
327 | f"\t{metrics['proof-leaves']['F1']*100}"
328 | f"\t{metrics['proof-leaves']['acc']*100}"
329 | f"\t{metrics['proof-steps']['F1']*100}"
330 | f"\t{metrics['proof-steps']['acc']*100}"
331 | f"\t{metrics['proof-intermediates']['BLEURT_P']*100}"
332 | f"\t{metrics['proof-intermediates']['BLEURT_R']*100}"
333 | f"\t{metrics['proof-intermediates']['BLEURT_F1']*100}"
334 | f"\t{metrics['proof-intermediates']['BLEURT']}"
335 | f"\t{metrics['proof-intermediates']['BLEURT_acc']*100}"
336 | f"\t{metrics['proof-intermediates']['fraction_perfect_align']*100}"
337 | f"\t{metrics['proof-overall']['acc']*100}"
338 | f"\t{num_distractors}"
339 | f"\t{num_context_sent}"
340 | f"\t{fraction_distractors}"
341 | f"\t{', '.join(distractor_ids)}"
342 | f"\t{depth_of_proof}"
343 | f"\t{length_of_proof}"
344 | "\n")
345 |
346 | scores.append(pred)
347 |
348 | print("\n=================\n"
349 | "Percentage recall per gold proof depth\n"
350 | "Gold_proof_depth\t#Gold answers\t#Correct predictions\t%accuracy (recall)\t%Gold answers\t%Correct Predictions")
351 |
352 | print(f"=========================")
353 | print(f"num_dev_answers:{num_dev_answers}")
354 | print(f"num_dev_answers_seen_in_train_context:{num_dev_answers_seen_in_train_context}")
355 | print(f"num_dev_answers_seen_in_train_answers:{num_dev_answers_seen_in_train_answers}")
356 |
357 | return scores
358 |
359 |
360 | # Sample command
361 | # python multi_angle/run_scorer_mf_all_at_once.py --angle_data_dir /Users/bhavanad/research_data/ruletaker/missing_facts/data/angles/OWA_d3_run3 --output_dir /Users/bhavanad/research_data/ruletaker/missing_facts/data/scorings/OWA_d3_run2/OWA_d3_run1_on_d3 --slot_root_dir /Users/bhavanad/research_data/ruletaker/missing_facts/data/slots/ --slot_data_dir OWA_d3_run3-slots --split test --prediction_file /Users/bhavanad/research_data/ruletaker/missing_facts/data/predictions/OWA_d3_T5large/OWA_d3_run1.on_d3.15k.pred.test.tsv
362 | def main(args):
363 | prediction_files = args.prediction_file
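    # Descriptive note (added for clarity): args.prediction_file may be a single path,
    # a comma-separated list of files, a directory (every file whose name contains
    # '_predictions' is scored), or a glob pattern, as handled below.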
364 | if "," in prediction_files:
365 | prediction_files = prediction_files.split(",")
366 | elif os.path.isdir(prediction_files):
367 | dir_path = prediction_files
368 | prediction_files = [f"{dir_path}/{f}" for f in os.listdir(prediction_files) if re.match(r'.*_predictions', f)]
369 | else:
370 | prediction_files = glob.glob(prediction_files)
371 |
372 |
373 | prediction_files.sort()
374 | # 2022-05-26: Amazon addition.
375 | root_dir = "entailment_bank/data/processed_data"
376 | # End of Amazon addition.
377 | angle_data_dir = f"{root_dir}/angles/{args.task}/"
378 | slot_data_dir = f"{root_dir}/slots/{args.task}-slots/"
379 | angle_base_name = os.path.basename(angle_data_dir)
380 | slot_file = os.path.join(slot_data_dir, args.split + '.jsonl')
381 | if not os.path.exists(slot_file):
382 | if args.split == 'val' and os.path.exists(os.path.join(slot_data_dir, "dev.jsonl")):
383 | slot_file = os.path.join(slot_data_dir, "dev.jsonl")
384 | else:
385 | raise ValueError(f"Slot data file {slot_file} does not exist!")
386 | predictions_jsonl_format = True
387 | for prediction_file in prediction_files:
388 | if not prediction_file.endswith(".jsonl"):
389 | predictions_jsonl_format = False
390 | angle_file = None
391 | # If predictions not in jsonl format, need angle data to get correct ids and angles
392 | if not predictions_jsonl_format:
393 | angle_file = os.path.join(angle_data_dir, args.split + '.jsonl')
394 | if not os.path.exists(angle_file):
395 | if args.split == 'val' and os.path.exists(os.path.join(angle_data_dir, "dev.jsonl")):
396 |                 angle_file = os.path.join(angle_data_dir, "dev.jsonl")
397 | else:
398 | raise ValueError(f"Angle data file {angle_file} does not exist!")
399 | if not os.path.exists(args.output_dir):
400 | os.makedirs(args.output_dir)
401 |
402 | logger.info("Scoring the following files: %s", prediction_files)
403 | all_metrics_aggregated = {}
404 |
405 | split = args.split
406 | output_dir = args.output_dir
407 | bleurt_checkpoint = args.bleurt_checkpoint
408 |
409 | sys.argv = sys.argv[:1]
410 |
411 | for prediction_file in prediction_files:
412 | if not os.path.exists(prediction_file):
413 | logger.warning(f" File not found: {prediction_file}")
414 | continue
415 | score_file_base = f"scores-{split}"
416 | score_file = os.path.join(output_dir, score_file_base)
417 | logger.info(f"***** Scoring predictions in {prediction_file} *****")
418 | logger.info(f" Gold data from: {slot_file}")
419 | logger.info(f" Full output in: {score_file}")
420 |
421 | # 2022-05-26: Amazon addition.
422 | scores = score_predictions(
423 | args=args,
424 | predictions_file=prediction_file,
425 | score_file=score_file,
426 | gold_file=slot_file,
427 | angle_file=angle_file,
428 | dataset=angle_base_name,
429 | bleurt_checkpoint=bleurt_checkpoint,
430 | task_name=args.task)
431 | # End of Amazon addition.
432 | collated = collate_scores(scores)
433 | all_metrics_aggregated[score_file] = collated['metrics_aggregated']
434 | logger.info(f" Aggregated metrics:")
435 | for key, val in collated['metrics_aggregated'].items():
436 | logger.info(f" {key}: {val}")
437 | print(f"\n======================")
438 |         columns_str = '\t'.join([
439 | # 'leave-P', 'leave-R',
440 | 'leave-F1', 'leaves-Acc',
441 | 'steps-F1', 'steps-Acc',
442 | 'int-BLEURT-F1', 'int-BLEURT-Acc',
443 | #'int-BLEURT_align', 'int-BLEURT-Acc_align',
444 | 'overall-Acc',
445 | # 'overall-Acc_align','int-fraction-align'
446 | ])
447 | print(f"collated:{collated['metrics_aggregated']}")
448 | aggr_metrics = collated['metrics_aggregated']['QAHC->P']
449 | metrics_str = '\t'.join([
450 | # prediction_file,
451 | # str(round(aggr_metrics['proof-leaves']['P']*100.0, 2)),
452 | # str(round(aggr_metrics['proof-leaves']['R']*100.0, 2)),
453 | str(round(aggr_metrics['proof-leaves']['F1']*100.0, 2)),
454 | str(round(aggr_metrics['proof-leaves']['acc']*100.0, 2)),
455 | str(round(aggr_metrics['proof-steps']['F1']*100.0, 2)),
456 | str(round(aggr_metrics['proof-steps']['acc']*100.0, 2)),
457 | str(round(aggr_metrics['proof-intermediates']['BLEURT_F1'] * 100.0, 2)),
458 | str(round(aggr_metrics['proof-intermediates']['BLEURT_acc'] * 100.0, 2)),
459 | #str(round(aggr_metrics['proof-intermediates']['BLEURT_perfect_align'], 2)),
460 | #str(round(aggr_metrics['proof-intermediates']['BLEURT_acc_perfect_align']*100.0,2)),
461 | str(round(aggr_metrics['proof-overall']['acc']*100.0, 2)),
462 | #str(round(aggr_metrics['proof-overall']['acc_perfect_align']*100.0, 2)),
463 | #str(round(aggr_metrics['proof-intermediates']['fraction_perfect_align'] * 100.0, 2)),
464 | ])
465 |         print(f"{columns_str}")
466 | print(f"{metrics_str}")
467 | save_json(f"{score_file}.metrics.json", all_metrics_aggregated)
468 | # 2022-05-26: Amazon addition.
469 | return all_metrics_aggregated
470 | # End of Amazon addition.
471 |
472 | if __name__ == "__main__":
473 | args = get_args()
474 | main(args)
--------------------------------------------------------------------------------
/src/entailment_bank/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/irgr/fd385a77880c92e92167a5766859b2862af6ff26/src/entailment_bank/utils/__init__.py
--------------------------------------------------------------------------------
/src/entailment_bank/utils/angle_utils.py:
--------------------------------------------------------------------------------
1 | # Original Copyright 2021 Allen Institute for AI. Licensed under the Apache-2.0 License.
2 | # Modifications Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
3 |
4 | from typing import Dict
5 | import json
6 | import logging
7 | import pickle
8 | import os
9 | import random
10 | import re
11 |
12 | from transformers import BartTokenizer, T5Tokenizer
13 |
14 | logger = logging.getLogger(__name__)
15 |
16 | DEFAULT_SLOT_FORMAT = {
17 | "slot": "$SLOT$",
18 | "assign": " = ",
19 | "separator": " ; ",
20 | "missing_value": "N/A"
21 | }
22 |
23 | SLOT_SHORTFORMS = {"Q": "question", "C": "context", "A": "answer", "E": "explanation",
24 | "M": "mcoptions", "R": "rationale", "P": "proof",
25 | "O": "original_question",
26 | "H": "hypothesis",
27 | "F": "full_text_proof",
28 | "V": "valid"
29 | }
30 |
31 |
32 | def save_jsonl(file_name, data):
33 | with open(file_name, 'w') as file:
34 | for d in data:
35 | file.write(json.dumps(d))
36 | file.write("\n")
37 |
38 |
39 | def load_jsonl(file_name):
40 | with open(file_name, 'r') as file:
41 | return [json.loads(line.strip()) for line in file]
42 |
43 |
44 | def save_json(file_name, data):
45 | with open(file_name, 'w') as file:
46 | file.write(json.dumps(data))
47 |
48 |
49 | ### From https://github.com/huggingface/transformers/blob/master/examples/rag/utils.py
50 |
51 | def encode_line(tokenizer, line, max_length, padding_side, pad_to_max_length=True, return_tensors="pt"):
52 | extra_kw = {"add_prefix_space": True} if isinstance(tokenizer, BartTokenizer) and not line.startswith(" ") else {}
53 | tokenizer.padding_side = padding_side
54 | return tokenizer(
55 | [line],
56 | max_length=max_length,
57 | padding="max_length" if pad_to_max_length else None,
58 | truncation=True,
59 | return_tensors=return_tensors,
60 | add_special_tokens=True,
61 | **extra_kw,
62 | )
63 |
64 |
65 | def trim_batch(
66 | input_ids,
67 | pad_token_id,
68 | attention_mask=None,
69 | ):
70 | """Remove columns that are populated exclusively by pad_token_id"""
71 | keep_column_mask = input_ids.ne(pad_token_id).any(dim=0)
72 | if attention_mask is None:
73 | return input_ids[:, keep_column_mask]
74 | else:
75 | return (input_ids[:, keep_column_mask], attention_mask[:, keep_column_mask])
76 |
77 |
78 | def pickle_save(obj, path):
79 | """pickle.dump(obj, path)"""
80 | with open(path, "wb") as f:
81 | return pickle.dump(obj, f)
82 |
83 |
84 | def scramble_order(data, keep_last=None):
85 | keep_last = keep_last or []
86 | last = []
87 | other = []
88 | for d in data:
89 | if d in keep_last:
90 | last.append(d)
91 | else:
92 | other.append(d)
93 | random.shuffle(other)
94 | return other + last
95 |
96 |
97 | def scramble_context_sentences(sentences, random_seed=137):
98 | scrambled_sentence_dict = dict()
99 | scrambled_sentence_list = []
100 | old_to_new_id_map = dict()
101 | sent_ids_orig = list(sentences.keys())
102 | random.shuffle(sent_ids_orig)
103 | for idx, sent_id_orig in enumerate(sent_ids_orig, start=1):
104 | new_sent_id = "sent" + str(idx)
105 | old_to_new_id_map[sent_id_orig] = new_sent_id
106 | scrambled_sentence_dict[new_sent_id] = sentences[sent_id_orig]
107 | scrambled_sentence_list.append(sentences[sent_id_orig])
108 | return scrambled_sentence_dict, scrambled_sentence_list, old_to_new_id_map
109 |
110 |
111 | # Turns [['context', 'question', 'mcoptions'], ['explanation', 'answer']] into 'CMQ->AE'
112 | def shortform_angle(angle_full, sort_angle=True, overrides=None):
113 | if angle_full is None:
114 | return ""
115 | if sort_angle:
116 | return "->".join(["".join(sorted([angle[0].upper() for angle in angles])) for angles in angle_full])
117 | return "->".join(["".join([angle[0].upper() for angle in angles]) for angles in angle_full])
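# Illustrative example for shortform_angle above (the angle lists are hypothetical):
#   shortform_angle([['question', 'context'], ['proof']])  ->  'CQ->P'
# (with sort_angle=True the letters on each side of '->' are sorted alphabetically)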
118 |
119 | def decompose_slots(string, fmt=None):
120 | fmt = fmt or DEFAULT_SLOT_FORMAT
121 | string = string.strip()
122 | no_slot = "PREFIX"
123 | slot_re = re.compile('(?i)'+re.escape(fmt['slot']).replace("SLOT", "(\\w*?)"))
124 | assign_re = re.escape(fmt['assign']).replace('\\ ','\\s*')
125 | separator_re = re.escape(fmt['separator']).replace('\\ ','\\s*')
126 | strip_re = re.compile(f"^({assign_re})?(.*?)({separator_re})?$")
127 | slot_pos = []
128 | for m in slot_re.finditer(string):
129 | slot_pos.append((m.span(), m.group(1)))
130 | if len(slot_pos) == 0:
131 | return {no_slot: string}
132 | if slot_pos[0][0][0] > 0:
133 | slot_pos = [((0,-1), no_slot)] + slot_pos
134 | res = {}
135 | for idx, (pos, slot_name) in enumerate(slot_pos):
136 | if idx == len(slot_pos) - 1:
137 | value = string[pos[1]+1:]
138 | else:
139 | value = string[pos[1]+1:slot_pos[idx+1][0][0]-1]
140 | m = strip_re.match(value)
141 | if m is not None:
142 | value = m.group(2)
143 | value = value.strip()
144 | if slot_name in res:
145 | value = res[slot_name] + " ~AND~ " + value
146 | res[slot_name] = value
147 | return res
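# Illustrative example for decompose_slots above (the input string is hypothetical):
#   decompose_slots("$question$ = Who left? ; $answer$ = Bob")
#   ->  {'question': 'Who left?', 'answer': 'Bob'}
# Any text before the first slot marker is returned under the key 'PREFIX'.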
148 |
149 |
150 | def slot_file_to_angles(slot_file, slot_shortforms,
151 | angle_distribution,
152 | split,
153 | full_train_first_angle=False,
154 | meta_fields=None,
155 | id_filter_regex=None,
156 | train_replicas=1,
157 | random_seed=137, **kwparams):
158 |
159 | res = []
160 | random.seed(random_seed)
161 | if split == 'train':
162 | if full_train_first_angle:
163 | angle_distributions = [angle_distribution[0][0]] + [angle_distribution] * train_replicas
164 | else:
165 | angle_distributions = [angle_distribution] * train_replicas
166 | else:
167 | angle_distributions = angle_distribution[0]
168 | with open(slot_file, 'r') as file:
169 | for line in file:
170 | fields = json.loads(line.strip())
171 | if id_filter_regex is not None and "id" in fields and not re.match(id_filter_regex, fields['id']):
172 | continue
173 | slot_data = SlotDataInstance(fields)
174 | for ad in angle_distributions:
175 | instance = slot_data.sample_angle_instance(ad, slot_shortforms, **kwparams)
176 | instance.update({"id": fields.get('id', 'NA')})
177 | res.append(instance)
178 | if meta_fields:
179 | instance['meta'] = {x:fields['meta'][x] for x in meta_fields}
180 | return res
181 |
182 |
183 | ANGLE_SPEC_DEFAULT = {'angle_distribution': None,
184 | 'full_train_first_angle': False,
185 | 'id_filter_regex': None,
186 | 'train_replicas': 1,
187 | 'meta_fields': [],
188 | 'random_seed': 137,
189 | 'keep_last': ['context'],
190 | 'scramble_slots': True,
191 | 'multi_value_sampling': None
192 | }
193 |
194 |
195 | def build_angle_dir(slot_dir, angle_dir, angle_spec, debug_print=2):
196 | if os.path.exists(angle_dir):
197 |         raise ValueError(f"Angle data directory {angle_dir} already exists!")
198 | os.makedirs(angle_dir)
199 | angle_spec = {**ANGLE_SPEC_DEFAULT, **angle_spec}
200 | angle_spec_file = os.path.join(angle_dir, "angle_spec.json")
201 |
202 | save_json(angle_spec_file, angle_spec)
203 | made_splits = []
204 | for split in ['train', 'dev', 'val', 'test']:
205 | slot_file = os.path.join(slot_dir, split+".jsonl")
206 | if os.path.exists(slot_file):
207 | logger.info(f"Creating angle data for {slot_file}")
208 | angle_data = slot_file_to_angles(slot_file, SLOT_SHORTFORMS, split=split, **angle_spec)
209 | if debug_print > 0:
210 | logger.info(f"Sample angle data: {angle_data[:debug_print]}")
211 | angle_file = os.path.join(angle_dir, split+".jsonl")
212 | save_jsonl(angle_file, angle_data)
213 | made_splits.append((split, len(angle_data)))
214 | logger.info(f"Created angle data for splits {made_splits}.")
215 | return made_splits
216 |
217 |
218 | def save_tsv_file(file_name, data):
219 | with open(file_name, "w") as f:
220 | for d in data:
221 | out = "\t".join([s.replace('\n', ' ').replace('\t', ' ') for s in d])
222 | f.write(out + '\n')
223 |
224 | # Use small_dev = 2000 to save a smaller size-2000 dev set
225 | def convert_angle_dir_tsv(angle_dir, tsv_dir, small_dev=False):
226 | if os.path.exists(tsv_dir):
227 |         raise ValueError(f"TSV data directory {tsv_dir} already exists!")
228 | os.makedirs(tsv_dir)
229 | counts = {}
230 | for split in ['train', 'dev', 'val', 'test']:
231 | angle_file = os.path.join(angle_dir, split+".jsonl")
232 | if os.path.exists(angle_file):
233 | logger.info(f"Creating tsv data for {angle_file}")
234 | angle_data = load_jsonl(angle_file)
235 | tsv_data = [[x['input'], x['output']] for x in angle_data]
236 | meta_data =[[x['id'], shortform_angle(x['angle'], sort_angle=False)] for x in angle_data]
237 | if small_dev and split in ['dev', 'val']:
238 | num_dev = small_dev if isinstance(small_dev, int) else 1000
239 | counts[split+"-full"] = len(tsv_data)
240 | save_tsv_file(os.path.join(tsv_dir, split+"-full.tsv"), tsv_data)
241 | save_tsv_file(os.path.join(tsv_dir, "meta-"+split+"-full.tsv"), meta_data)
242 | tsv_data = tsv_data[:num_dev]
243 | meta_data = meta_data[:num_dev]
244 | counts[split] = len(tsv_data)
245 | save_tsv_file(os.path.join(tsv_dir, split + ".tsv"), tsv_data)
246 | save_tsv_file(os.path.join(tsv_dir, "meta-" + split + ".tsv"), meta_data)
247 | save_json(os.path.join(tsv_dir, "counts.json"), counts)
248 | logger.info(f"Created angle data for splits {counts}.")
249 | return counts
250 |
251 |
252 | class SlotDataInstance():
253 |
254 | def __init__(self, fields: Dict):
255 | self.fields = fields
256 | self.slot_value_sampling = {}
257 |
258 | def get_slot_value(self, slot, default=None, multi_value_sampling=None):
259 | res = self.fields.get(slot, default)
260 | if isinstance(res, list):
261 | if multi_value_sampling is not None and slot in multi_value_sampling:
262 | fn = multi_value_sampling[slot]
263 | if fn == "random":
264 | res = random.choice(res)
265 | elif "random-with" in fn:
266 | other_slots = fn.split("-")[2:]
267 | value_index = -1
268 | for other_slot in other_slots:
269 | if other_slot in self.slot_value_sampling:
270 | value_index = self.slot_value_sampling[other_slot]
271 | if value_index == -1:
272 | value_index = random.choice(range(len(res)))
273 | self.slot_value_sampling[slot] = value_index
274 | res = res[value_index]
275 | else:
276 | raise ValueError(f"Unknown multi_value_sampling function {fn}")
277 | else:
278 | res = res[0]
279 | return res
280 |
281 | def convert_shortform_angle(self, angle, slot_shortforms):
282 | if isinstance(angle, str):
283 | arrowpos = angle.index("->")
284 | lhs = angle[:arrowpos].strip()
285 | rhs = angle[arrowpos + 2:].strip()
286 | lhs = [slot_shortforms[c] for c in lhs]
287 | rhs = [slot_shortforms[c] for c in rhs]
288 | else:
289 | lhs = angle[0]
290 | rhs = angle[1]
291 | missing = [slot for slot in lhs + rhs if slot not in self.fields]
292 | return ((lhs, rhs), missing)
293 |
294 | def make_angle_instance(self, angle, fmt=None, multi_value_sampling=None):
295 | fmt = fmt or DEFAULT_SLOT_FORMAT
296 | lhs = []
297 | rhs = []
298 | for slot in angle[1]:
299 | slot_name = fmt['slot'].replace("SLOT", slot)
300 | slot_value = self.get_slot_value(slot, fmt['missing_value'], multi_value_sampling)
301 | lhs.append(slot_name)
302 | rhs.append(f"{slot_name}{fmt['assign']}{slot_value}")
303 | for slot in angle[0]:
304 | slot_name = fmt['slot'].replace("SLOT", slot)
305 | slot_value = self.get_slot_value(slot, fmt['missing_value'], multi_value_sampling)
306 | lhs.append(f"{slot_name}{fmt['assign']}{slot_value}")
307 | return {"input": fmt['separator'].join(lhs),
308 | "output": fmt['separator'].join(rhs),
309 | "angle": angle}
310 |
311 | def sample_angle_instance(self, angle_distribution, slot_shortforms,
312 | scramble_slots=True,
313 | keep_last=None,
314 | missing_retries=100,
315 | fmt=None,
316 | multi_value_sampling=None):
317 | keep_last = keep_last or ["context"]
318 | fmt = fmt or DEFAULT_SLOT_FORMAT
319 | if isinstance(angle_distribution, str):
320 | angle, missing = self.convert_shortform_angle(angle_distribution, slot_shortforms)
321 | else:
322 | angle_distribution = tuple(x.copy() for x in angle_distribution)
323 | retries = missing_retries
324 | missing = [1]
325 | while retries >= 0 and len(missing) > 0 and len(angle_distribution[0]) > 0:
326 | retries -= 1
327 | angle_shortform = random.choices(*angle_distribution)[0]
328 | angle, missing = self.convert_shortform_angle(angle_shortform, slot_shortforms)
329 | if len(missing) > 0:
330 | ind = angle_distribution[0].index(angle_shortform)
331 | angle_distribution[0].pop(ind)
332 | angle_distribution[1].pop(ind)
333 | if scramble_slots:
334 | angle = [scramble_order(a, keep_last) for a in angle]
335 | res = self.make_angle_instance(angle, fmt, multi_value_sampling)
336 | return res
337 |
338 |
339 | def formatting(a_set):
340 | return '=CONCATENATE("' + ('" ; CHAR(10); "').join(a_set) + '"' + ')' if len(a_set) > 0 else ""
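# Illustrative example for formatting above (the list entries are hypothetical): it emits an
# Excel CONCATENATE formula so that multi-line diagnostic cells show line breaks in a spreadsheet:
#   formatting(["sent1: a", "-> int1"])  ->  '=CONCATENATE("sent1: a" ; CHAR(10); "-> int1")'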
341 |
342 |
343 | def get_selected_str(data_dict, selected_keys, format=False):
344 | selected=[]
345 | for key in selected_keys:
346 | selected.append(f"{key}: {data_dict[key]['text']} ")
347 | if format:
348 | selected_str = formatting(selected)
349 | else:
350 | selected_str = ''.join(selected)
351 | return selected_str
352 |
353 |
354 | def get_selected_keys(data_dict, selected_keys, format=False):
355 | selected=[]
356 | for key in selected_keys:
357 | selected.append(f"{key}: {data_dict[key]} ")
358 | if format:
359 | selected_str = formatting(selected)
360 | else:
361 | selected_str = ''.join(selected)
362 | return selected_str
--------------------------------------------------------------------------------
/src/entailment_bank/utils/entail_trees_utils.py:
--------------------------------------------------------------------------------
1 | # Original Copyright 2021 Allen Institute for AI. Licensed under the Apache-2.0 License.
2 | # Modifications Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
3 |
4 | from collections import defaultdict
5 | from copy import deepcopy
6 | from nltk.stem import PorterStemmer
7 | from nltk.tokenize import word_tokenize
8 | import random
9 | import re
10 |
11 | # 2022-05-26: Amazon addition.
12 | from entailment_bank.utils.proof_utils import parse_lisp, decouple_proof_struct_ints, polish_notation_to_proof
13 | # End of Amazon addition.
14 |
15 | # Count phrase appearing in a reference string, making sure word boundaries are respected
16 | def count_phrase_matches(phrase, reference):
17 |     regex = "(\\b|(?!\\w))" + re.escape(phrase) + "(\\b|(?!\\w))"
18 |     return len(re.findall(regex, reference))
63 | new_dependency = proof[2]
64 | for k, v in dependencies.items():
65 | dependencies[k] = v + [new_dependency]
66 | if new_dependency not in dependencies:
67 | dependencies[new_dependency] = []
68 | return get_parents_recursive(proof[0], dependencies)
69 | elif len(proof) == 0:
70 | return dependencies
71 | elif len(proof) == 1:
72 | return get_parents_recursive(proof[0], dependencies)
73 | else:
74 | all_dependencies = []
75 | for sub_proof in proof:
76 | dep1 = deepcopy(dependencies)
77 | get_parents_recursive(sub_proof, dep1)
78 | all_dependencies.append(dep1)
79 | dd = defaultdict(list)
80 | for d in all_dependencies:
81 | for key, value in d.items():
82 | dd[key] += value
83 | return dd
84 |
85 |
86 | def get_intermediate_dependencies(proof):
87 | list_proof = parse_lisp(proof)
88 | dependencies = {}
89 | res = get_parents_recursive(list_proof, dependencies)
90 | return res
91 |
92 |
93 | def get_stripped_recursive(proof, stripped):
94 | if isinstance(proof, str):
95 | return proof
96 | if len(proof) == 3 and proof[1] == "->":
97 | stripped[proof[2]] = [get_stripped_recursive(x, stripped) for x in proof[0]]
98 | return proof[2]
99 | elif len(proof) == 1:
100 | return get_stripped_recursive(proof[0], stripped)
101 | else:
102 | raise ValueError(f"Nonsense found: {proof}")
103 |
104 |
105 | def get_core_proofs(proof):
106 | list_proof = parse_lisp(proof)
107 | stripped = {}
108 | get_stripped_recursive(list_proof, stripped)
109 | return stripped
110 |
111 |
112 | def remove_distractors(qdata, num_removed_distractors=0):
113 | if num_removed_distractors < 1:
114 | return qdata
115 | else:
116 | new_q = deepcopy(qdata)
117 | distractors = new_q['meta']['distractors']
118 | sentences_removed = list(reversed(distractors))[:num_removed_distractors]
119 | sentences_remaining = {k: v for k, v in new_q['meta']['triples'].items() if k not in sentences_removed}
120 | new_distractors = [k for k in new_q['meta']['distractors'] if k not in sentences_removed]
121 | sentence_map = {k: f"sent{i + 1}" for i, k in enumerate(sentences_remaining.keys())}
122 | new_q['meta']['triples'] = sentences_remaining
123 | new_q['meta']['distractors'] = new_distractors
124 | new_q = remap_sentences(new_q, sentence_map)
125 | return new_q
126 |
127 |
128 | # Break down an entailment tree data instance into one-step inference steps
129 | def make_inference_steps(qdata, rescramble_sentences=False, num_removed_distractors=0):
130 | proof = qdata['meta']['lisp_proof']
131 | core_proofs = list(get_core_proofs(proof).items())
132 | random.shuffle(core_proofs)
133 | sentences = qdata['meta']['triples'].copy()
134 | intermediates = qdata['meta']['intermediate_conclusions']
135 | q_id = qdata['id']
136 | hypothesis_id = qdata['meta']['hypothesis_id']
137 | res = []
138 | while len(core_proofs) > 0:
139 | selected = None
140 | for proof in core_proofs:
141 | selected = proof
142 | for dep in proof[1]:
143 | if 'int' in dep:
144 | selected = None
145 | if selected is not None:
146 | break
147 | if selected is None:
148 | raise ValueError(f"No resolved proofs in {core_proofs}")
149 | new_res = selected[0]
150 | if new_res == hypothesis_id:
151 | new_res_text = "hypothesis"
152 | assert len(core_proofs) == 1
153 | else:
154 | new_res_text = "int1: " + intermediates[new_res]
155 |
156 | new_proof = selected[1]
157 | new_proof_text = " & ".join(new_proof) + " -> " + new_res_text
158 | new_context = " ".join([f"{k}: {v}" for k, v in sentences.items()])
159 | new_q = deepcopy(qdata)
160 | new_q['id'] = f"{q_id}-add{len(res)}"
161 | new_q['meta'] = {'triples': sentences.copy(),
162 | 'distractors': new_q['meta'].get('distractors', [])}
163 | new_q['proof'] = new_proof_text
164 | new_q['meta']['hypothesis_id'] = "int1"
165 | new_q['depth_of_proof'] = 1
166 | new_q['length_of_proof'] = 1
167 | new_q['context'] = new_context
168 | if rescramble_sentences:
169 | new_q = scramble_sentences_in_entail_tree_q(new_q)
170 | if num_removed_distractors > 0:
171 | new_q = remove_distractors(new_q, num_removed_distractors)
172 |
173 | res.append(new_q)
174 | new_sentence = "sent" + str(len(sentences) + 1)
175 | sentences[new_sentence] = intermediates[selected[0]]
176 | new_core_proofs = []
177 | for proof in core_proofs:
178 | if proof[0] == new_res:
179 | continue
180 | new_parents = []
181 | for parent in proof[1]:
182 | new_parents.append(new_sentence if parent == new_res else parent)
183 | new_core_proofs.append((proof[0], new_parents))
184 | core_proofs = new_core_proofs
185 | return res
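# Sketch of the output of make_inference_steps above (field values are illustrative):
# each returned item is a copy of qdata restricted to a single entailment step, e.g.
#   new_q['proof']  == "sent3 & sent7 -> int1: <intermediate conclusion text>"
#   new_q['id']     == "<original id>-add0", "-add1", ...
# with previously derived intermediates appended to new_q['meta']['triples'] as new sentN
# entries; the final step's right-hand side is just "hypothesis".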
186 |
187 |
188 | def normalize_sentence(sent):
189 |     return sent.replace("  ", " ").replace(".", "").replace('\n', '').replace("( ", "").replace(" )", "").lower().strip()
190 |
191 |
192 | def get_entailment_steps_from_polish_proof(polish_proof):
193 | print(f"POLISH_PROOF:{polish_proof}")
194 | pn_without_ints, int_dict = decouple_proof_struct_ints(polish_proof)
195 | print(f"POLISH_PROOF without INTS:{pn_without_ints}")
196 | try:
197 | recursive_proof = polish_notation_to_proof(pn_without_ints)[0]
198 | except:
199 | return []
200 | print(f"recursive_proof:{recursive_proof}")
201 | return get_entailment_steps_from_recursive_proof(recursive_proof, int_dict)
202 |
203 |
204 | def append_list(list_obj, to_be_added):
205 | if isinstance(to_be_added, list):
206 | for add_item in to_be_added:
207 | if add_item != '->':
208 | print(f"\t**********adding {add_item} to lhs")
209 |                 append_list(list_obj, add_item)
210 | else:
211 | if to_be_added != '->':
212 | print(f"\t**********adding {to_be_added} to lhs")
213 | list_obj.append(to_be_added)
214 | return list_obj
215 |
216 |
217 | def get_entailment_steps_from_recursive_proof(recursive_proof, int_dict):
218 | entailment_steps = []
219 | print(f"======Calling recursion: recursive_proof:{recursive_proof}")
220 |
221 | lhs = recursive_proof[0]
222 | rhs = recursive_proof[2]
223 | rhs_str = int_dict.get(rhs, "")
224 | print(f"======lhs:{lhs}")
225 | print(f"======rhs:{rhs}")
226 | lhs_ids = []
227 | if isinstance(lhs, str):
228 | append_list(lhs_ids, lhs)
229 | else:
230 | for l in lhs:
231 | print(f"\tl:{l}")
232 | if isinstance(l, str):
233 | append_list(lhs_ids, l)
234 | else:
235 | if '->' in l:
236 | entailment_steps += get_entailment_steps_from_recursive_proof(l, int_dict)
237 | append_list(lhs_ids, l[2])
238 | else:
239 | print(f"\t^^^^lhs:{lhs}")
240 | print(f"\t^^^^rhs:{rhs}")
241 | print(f"\t^^^^l{l}")
242 | print(f"\t^^^^^lhs_ids{lhs_ids}")
243 | for l_part in l:
244 | if isinstance(l_part, list):
245 | if '->' in l_part:
246 | entailment_steps += get_entailment_steps_from_recursive_proof(l_part, int_dict)
247 | append_list(lhs_ids, l_part[2])
248 | else:
249 | append_list(lhs_ids, l_part)
250 | else:
251 | append_list(lhs_ids, l_part) # for cases like ['sent20', 'sent8'] in [['sent14', ['sent20', 'sent8']], '->', 'int1']
252 | print(f"\tlhs_ids{lhs_ids}")
253 |
254 | print(f"\tlhs_ids:{lhs_ids}")
255 | print(f"\trhs:{rhs}")
256 | lhs_str = ' & '.join(lhs_ids)
257 | print(f"++++Adding step: {lhs_str} -> {rhs}: {rhs_str}")
258 | entailment_steps.append(f"{lhs_str} -> {rhs}: {rhs_str}")
259 | return entailment_steps
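# Illustrative example for get_entailment_steps_from_recursive_proof above
# (the proof structure and int_dict are hypothetical):
#   get_entailment_steps_from_recursive_proof([['sent1', 'sent2'], '->', 'int1'],
#                                             {'int1': 'plants need sunlight'})
#   ->  ['sent1 & sent2 -> int1: plants need sunlight']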
--------------------------------------------------------------------------------
/src/entailment_bank/utils/eval_utils.py:
--------------------------------------------------------------------------------
1 | # Original Copyright 2021 Allen Institute for AI. Licensed under the Apache-2.0 License.
2 | # Modifications Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
3 |
4 | import collections
5 | import copy
6 | import json
7 | import re
8 | import rouge
9 | import string
10 |
11 | # 2022-05-26: Amazon addition.
12 | from entailment_bank.utils.angle_utils import decompose_slots
13 | from entailment_bank.utils.proof_utils import parse_entailment_step_proof, \
14 | align_conclusions_across_proofs, rewrite_aligned_proof, score_sentence_overlaps, ruletaker_inferences_scores, \
15 | parse_entailment_step_proof_remove_ids, rewrite_aligned_proof_noids
16 | # End of Amazon addition.
17 |
18 | INCLUDE_NLG_EVAL = False
19 |
20 | if INCLUDE_NLG_EVAL:
21 | from nlgeval import NLGEval
22 | # Initialize:
23 | nlgeval = NLGEval(no_skipthoughts=True, no_glove=True)
24 |
25 | def load_slot_data_by_id(slot_file):
26 | res = {}
27 | with open(slot_file) as file:
28 | for line in file:
29 | data = json.loads(line.strip())
30 | res[data['id']] = data
31 | return res
32 |
33 | USE_GOOGLE_ROUGE_CODE = False
34 |
35 | if USE_GOOGLE_ROUGE_CODE:
36 | # The dependencies in https://github.com/google-research/google-research/blob/master/rouge/requirements.txt
37 | from rouge_score import rouge_scorer
38 | rouge_scoring_fun = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True)
39 |
40 | def rouge_google_metric_max_over_ground_truths(prediction, ground_truths):
41 | if len(ground_truths) == 0:
42 | return 0
43 | scores_for_ground_truths = []
44 | for ground_truth in ground_truths:
45 | score = rouge_scoring_fun.score(prediction, ground_truth)
46 | scores_for_ground_truths.append(score['rougeL'].fmeasure)
47 | return max(scores_for_ground_truths)
48 |
49 |
50 | def score_string_similarity(str1, str2):
51 | if str1 == str2:
52 | return 3.0 # Better than perfect token match
53 | str1 = fix_t5_unk_characters(str1)
54 | str2 = fix_t5_unk_characters(str2)
55 | if str1 == str2:
56 | return 2.0
57 | str1 = str1.lower()
58 | str2 = str2.lower()
59 | if str1 == str2:
60 | return 1.5
61 | if " " in str1 or " " in str2:
62 | str1_split = str1.split(" ")
63 | str2_split = str2.split(" ")
64 | overlap = list(set(str1_split) & set(str2_split))
65 | return len(overlap) / max(len(str1_split), len(str2_split))
66 | else:
67 | return 0.0
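# Illustrative examples for score_string_similarity above (the strings are hypothetical):
#   score_string_similarity("A cat", "a cat")        ->  1.5  (match after lowercasing)
#   score_string_similarity("red apple", "red pear") ->  0.5  (1 of 2 tokens overlap)
#   score_string_similarity("cat", "dog")            ->  0.0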
68 |
69 |
70 | def replace_punctuation(str):
71 | return str.replace("\"", "").replace("'", "")
72 |
73 | # remove characters tokenized as unknown (\u2047) character
74 | T5_GOOD_CHARS=[32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48,
75 | 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 61, 62, 63, 64, 65, 66,
76 | 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83,
77 | 84, 85, 86, 87, 88, 89, 90, 91, 93, 95, 97, 98, 99, 100, 101, 102,
78 | 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116,
79 | 117, 118, 119, 120, 121, 122, 124, 163, 171, 173, 174, 176, 187, 201,
80 | 206, 220, 223, 224, 225, 226, 228, 231, 232, 233, 234, 238, 243, 244,
81 | 246, 249, 251, 252, 259, 351, 355, 537, 539, 1072, 1074, 1076, 1077,
82 | 1080, 1082, 1083, 1084, 1085, 1086, 1088, 1089, 1090, 1091, 8211,
83 | 8212, 8216, 8217, 8220, 8221, 8222, 8226, 8242, 8364, 9601]
84 | T5_BAD_REGEX = re.compile("[^"+re.escape(".".join([chr(x) for x in T5_GOOD_CHARS]))+"]")
85 |
86 |
87 | def fix_t5_unk_characters(str):
88 | return re.sub(" {2,}", " ", re.sub(T5_BAD_REGEX, " ", str))
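# Illustrative example for fix_t5_unk_characters above (the input is hypothetical):
# characters outside T5_GOOD_CHARS are replaced by spaces and repeated spaces collapsed,
# e.g. fix_t5_unk_characters("a^b")  ->  "a b"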
89 |
90 | # Rouge evaluator copied from UnifiedQA
91 | rouge_l_evaluator = rouge.Rouge(
92 | metrics=["rouge-l"],
93 | max_n=4,
94 | limit_length=True,
95 | length_limit=100,
96 | length_limit_type="words",
97 | apply_avg=True,
98 | apply_best=True,
99 | alpha=0.5,
100 | weight_factor=1.2,
101 | stemming=True,
102 | )
103 |
104 | def rouge_l(p, g):
105 | return rouge_l_evaluator.get_scores(p, g)
106 |
107 | def rouge_metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
108 | scores_for_ground_truths = []
109 | for ground_truth in ground_truths:
110 | score = metric_fn(prediction, [ground_truth])
111 | scores_for_ground_truths.append(score)
112 | if isinstance(score, dict) and "rouge-l" in score:
113 | max_score = copy.deepcopy(score)
114 | max_score["rouge-l"]["f"] = round(
115 | max([score["rouge-l"]["f"] for score in scores_for_ground_truths]), 2
116 | )
117 | max_score["rouge-l"]["p"] = round(
118 | max([score["rouge-l"]["p"] for score in scores_for_ground_truths]), 2
119 | )
120 | max_score["rouge-l"]["r"] = round(
121 | max([score["rouge-l"]["r"] for score in scores_for_ground_truths]), 2
122 | )
123 | return max_score
124 | else:
125 | return round(max(scores_for_ground_truths), 2)
126 |
127 | def nlg_string_similarities(prediction, gold_strings, normalize=True):
128 | if gold_strings is None:
129 | res = {"missing_gold": 1}
130 | if prediction is None:
131 | res['missing'] = 1
132 | return res
133 | if prediction is None:
134 | return {"missing": 1}
135 | if normalize:
136 | gold_strings = [fix_t5_unk_characters(x.lower()) for x in gold_strings]
137 | prediction = fix_t5_unk_characters(prediction.lower())
138 | # gold_strings = gold_strings[:1]
139 | res = {}
140 | if INCLUDE_NLG_EVAL:
141 | res = nlgeval.compute_individual_metrics(gold_strings, prediction)
142 | if 'CIDEr' in res:
143 | del res['CIDEr']
144 | rouge_l_score = rouge_metric_max_over_ground_truths(rouge_l, prediction, gold_strings)
145 | res['ROUGE_L_F'] = rouge_l_score["rouge-l"]["f"]
146 | if USE_GOOGLE_ROUGE_CODE:
147 | res['ROUGE_L_G'] = rouge_google_metric_max_over_ground_truths(prediction, gold_strings)
148 | res['pred'] = prediction
149 | res['gold'] = gold_strings
150 | if not gold_strings[0] and not prediction:
151 | res['ROUGE_L_F'] = 1.0
152 | return res
153 |
154 |
155 | def nlg_string_similarities_intermediates_with_F1(prediction_to_aligned_gold: dict,
156 | id_to_int_gold: dict(),
157 | id_to_int_pred: dict(),
158 | prediction_to_perfect_match: dict,
159 | normalize=True,
160 | bleurt_scorer=None,
161 | bleurt_threshold=0.28 #for original BLEURT
162 | ):
163 | num_perfect_aligns = 0
164 |
165 | sum_rouge_l_score = 0.0
166 | sum_perfect_align_rouge_l_score = 0.0
167 | sum_bleurt_score = 0.0
168 | sum_perfect_align_bleurt_score = 0.0
169 | num_bleurt_correct = 0.0
170 | num_perfect_align_bleurt_correct = 0.0
171 |
172 | preds = []
173 | golds = []
174 | res = {}
175 | # print(f"prediction_to_aligned_gold:{prediction_to_aligned_gold}")
176 | pred_precise = set()
177 | gold_covered = set()
178 | for prediction, gold in prediction_to_aligned_gold.items():
179 | preds.append(prediction)
180 | golds.append(gold)
181 | gold_strings = [gold]
182 | if normalize:
183 | gold_strings = [fix_t5_unk_characters(x.lower()) for x in gold_strings]
184 | prediction_norm = fix_t5_unk_characters(prediction.lower())
185 | #res = nlgeval.compute_individual_metrics(gold_strings, prediction)
186 | #if 'CIDEr' in res:
187 | # del res['CIDEr']
188 | rouge_l_score = rouge_metric_max_over_ground_truths(rouge_l, prediction_norm, gold_strings)
189 | if bleurt_scorer:
190 | bleurt_score = bleurt_scorer.score(references=gold_strings, candidates=[prediction_norm], batch_size=1)[0]
191 | else:
192 | bleurt_score = -1
193 | # bleurt_score = max(0.0, min(1.0, unnorm_bleurt_score))
194 | # bleurt_score = [0.0]
195 | sum_rouge_l_score += rouge_l_score["rouge-l"]["f"]
196 | sum_bleurt_score += bleurt_score
197 | # if gold == "":
198 | # print(f"@@@@@@@@@@@@@@@@@@@@@@@PREDICTION:{prediction}\tGOLD:{gold}\tBLEURT:{bleurt_score}")
199 | # print(f"bleurt_score:{unnorm_bleurt_score}\t{bleurt_score}")
200 |
201 | if bleurt_score >= bleurt_threshold:
202 | if gold != "":
203 | num_bleurt_correct += 1
204 | pred_precise.add(prediction)
205 | gold_covered.add(gold)
206 |
207 | # print(f"prediction_to_perfect_match:{prediction_to_perfect_match}")
208 | # print(f"prediction:{prediction}")
209 | if prediction_to_perfect_match.get(prediction, False):
210 | sum_perfect_align_rouge_l_score += rouge_l_score["rouge-l"]["f"]
211 | sum_perfect_align_bleurt_score += bleurt_score
212 | num_perfect_aligns += 1
213 | if bleurt_score >= bleurt_threshold:
214 | num_perfect_align_bleurt_correct += 1
215 |
216 | bleurt_P = len(pred_precise)/max(1, len(id_to_int_pred))
217 | bleurt_R = len(gold_covered)/max(1, len(id_to_int_gold))
218 | if (bleurt_P + bleurt_R) == 0.0:
219 | bleurt_F1 = 0.0
220 | else:
221 | bleurt_F1 = (2 * bleurt_P * bleurt_R) / (bleurt_P + bleurt_R)
222 | # print(f"@@@@@@@@@@@@@@@@@@@@@@@pred_precise:{len(pred_precise)}\tpred:{len(id_to_int_pred)}\tgold_covered:{len(gold_covered)}\tgold:{len(id_to_int_gold)}")
223 | # print(f"@@@@@@@@@@@@@@@@@@@@@@@bleurt_P:{bleurt_P}\tbleurt_R:{bleurt_R}\tbleurt_F1:{bleurt_F1}")
224 | res['ROUGE_L_F'] = sum_rouge_l_score / max(1, len(prediction_to_aligned_gold.keys()))
225 | res['ROUGE_L_F_perfect_align'] = sum_perfect_align_rouge_l_score / max(1, num_perfect_aligns)
226 | res['BLEURT'] = sum_bleurt_score / max(1, len(prediction_to_aligned_gold.keys()))
227 | res['BLEURT_perfect_align'] = sum_perfect_align_bleurt_score / max(1, num_perfect_aligns)
228 | res['BLEURT_P'] = bleurt_P
229 | res['BLEURT_R'] = bleurt_R
230 | res['BLEURT_F1'] = bleurt_F1
231 | # res['BLEURT_acc'] = int(num_bleurt_correct == len(prediction_to_aligned_gold.keys()))
232 | res['BLEURT_acc'] = int(bleurt_F1==1)
233 | res['BLEURT_acc_perfect_align'] = int(num_perfect_align_bleurt_correct == num_perfect_aligns)
234 | res['fraction_perfect_align'] = num_perfect_aligns/max(1, len(prediction_to_aligned_gold.keys()))
235 | res['pred'] = preds
236 | res['gold'] = golds
237 | # print(f"res:{res}")
238 | return res
239 |
240 |
241 | def nlg_string_similarities_intermediates(prediction_to_aligned_gold: dict,
242 | prediction_to_perfect_match: dict,
243 | normalize=True,
244 | bleurt_scorer=None,
245 | bleurt_threshold=0.28 #for original BLEURT
246 | ):
247 | num_perfect_aligns = 0
248 |
249 | sum_rouge_l_score = 0.0
250 | sum_perfect_align_rouge_l_score = 0.0
251 | sum_bleurt_score = 0.0
252 | sum_perfect_align_bleurt_score = 0.0
253 | num_bleurt_correct = 0.0
254 | num_perfect_align_bleurt_correct = 0.0
255 |
256 | preds = []
257 | golds = []
258 | res = {}
259 | # print(f"prediction_to_aligned_gold:{prediction_to_aligned_gold}")
260 | for prediction, gold in prediction_to_aligned_gold.items():
261 | preds.append(prediction)
262 | golds.append(gold)
263 | gold_strings = [gold]
264 | if normalize:
265 | gold_strings = [fix_t5_unk_characters(x.lower()) for x in gold_strings]
266 | prediction = fix_t5_unk_characters(prediction.lower())
267 | #res = nlgeval.compute_individual_metrics(gold_strings, prediction)
268 | #if 'CIDEr' in res:
269 | # del res['CIDEr']
270 | rouge_l_score = rouge_metric_max_over_ground_truths(rouge_l, prediction, gold_strings)
271 | if bleurt_scorer:
272 | bleurt_score = bleurt_scorer.score(gold_strings, [prediction], batch_size=1)[0]
273 | else:
274 | bleurt_score = -1
275 | # bleurt_score = max(0.0, min(1.0, unnorm_bleurt_score))
276 | # bleurt_score = [0.0]
277 | sum_rouge_l_score += rouge_l_score["rouge-l"]["f"]
278 | sum_bleurt_score += bleurt_score
279 | # print(f"bleurt_score:{unnorm_bleurt_score}\t{bleurt_score}")
280 |
281 | if bleurt_score >= bleurt_threshold:
282 | num_bleurt_correct += 1
283 | # print(f"prediction_to_perfect_match:{prediction_to_perfect_match}")
284 | # print(f"prediction:{prediction}")
285 | if prediction_to_perfect_match[prediction]:
286 | sum_perfect_align_rouge_l_score += rouge_l_score["rouge-l"]["f"]
287 | sum_perfect_align_bleurt_score += bleurt_score
288 | num_perfect_aligns += 1
289 | if bleurt_score >= bleurt_threshold:
290 | num_perfect_align_bleurt_correct += 1
291 |
292 | res['ROUGE_L_F'] = sum_rouge_l_score / max(1, len(prediction_to_aligned_gold.keys()))
293 | res['ROUGE_L_F_perfect_align'] = sum_perfect_align_rouge_l_score / max(1, num_perfect_aligns)
294 | res['BLEURT'] = sum_bleurt_score / max(1, len(prediction_to_aligned_gold.keys()))
295 | res['BLEURT_perfect_align'] = sum_perfect_align_bleurt_score / max(1, num_perfect_aligns)
296 | res['BLEURT_acc'] = int(num_bleurt_correct == len(prediction_to_aligned_gold.keys()))
297 | res['BLEURT_acc_perfect_align'] = int(num_perfect_align_bleurt_correct == num_perfect_aligns)
298 | res['fraction_perfect_align'] = num_perfect_aligns/max(1, len(prediction_to_aligned_gold.keys()))
299 | res['pred'] = preds
300 | res['gold'] = golds
301 | # print(f"res:{res}")
302 | return res
303 |
304 |
305 | def squad_normalize_answer(s):
306 | """Lower text and remove punctuation, articles and extra whitespace."""
307 | def remove_articles(text):
308 | regex = re.compile(r'\b(a|an|the)\b', re.UNICODE)
309 | return re.sub(regex, ' ', text)
310 | def white_space_fix(text):
311 | return ' '.join(text.split())
312 | def remove_punc(text):
313 | exclude = set(string.punctuation)
314 | return ''.join(ch for ch in text if ch not in exclude)
315 | def lower(text):
316 | return text.lower()
317 | return fix_t5_unk_characters(white_space_fix(remove_articles(remove_punc(lower(s)))))
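# Illustrative example for squad_normalize_answer above (the input is hypothetical):
#   squad_normalize_answer("The  Red Apple!")  ->  "red apple"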
318 |
319 | def get_tokens(s):
320 | if not s: return []
321 | return squad_normalize_answer(s).split()
322 |
323 | def compute_exact(a_gold, a_pred):
324 | return int(squad_normalize_answer(a_gold) == squad_normalize_answer(a_pred))
325 |
326 | def compute_f1(a_gold, a_pred):
327 | gold_toks = get_tokens(a_gold)
328 | pred_toks = get_tokens(a_pred)
329 | common = collections.Counter(gold_toks) & collections.Counter(pred_toks)
330 | num_same = sum(common.values())
331 | if len(gold_toks) == 0 or len(pred_toks) == 0:
332 | # If either is no-answer, then F1 is 1 if they agree, 0 otherwise
333 | return int(gold_toks == pred_toks)
334 | if num_same == 0:
335 | return 0
336 | precision = 1.0 * num_same / len(pred_toks)
337 | recall = 1.0 * num_same / len(gold_toks)
338 | f1 = (2 * precision * recall) / (precision + recall)
339 | return f1
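# Illustrative example for compute_f1 above (the strings are hypothetical):
#   compute_f1("red apple", "red")  ->  2/3   (precision 1.0, recall 0.5)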
340 |
341 |
342 | def split_mcoptions(mcoptions):
343 | first_option = ord(mcoptions.strip()[1])
344 | labels = "".join([chr(x) for x in range(first_option, first_option+10)])
345 | choices = re.split("\\s*\\(["+labels+"]\\)\\s*", mcoptions)[1:]
346 | return (choices, chr(first_option))
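# Illustrative example for split_mcoptions above (the option string is hypothetical):
#   split_mcoptions("(A) red (B) blue")  ->  (['red', 'blue'], 'A')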
347 |
348 |
349 | def mcq_answer_accuracy(slots, gold):
350 | answer = slots.get('answer')
351 | if answer is None:
352 | return {"acc": 0, "missing": 1}
353 | mcoptions, first_label = split_mcoptions(gold['mcoptions'])
354 | best = -1
355 | selected = None
356 | selected_key = None
357 | for idx, option in enumerate(mcoptions):
358 | score = score_string_similarity(answer, option)
359 | if score > best:
360 | best = score
361 | selected = option
362 | selected_key = chr(ord(first_label) + idx)
363 | acc = 1 if selected == gold['answer'] else 0
364 | res = {"acc": acc, "answerkey": selected_key, "align_score": best}
365 | return res
366 |
367 |
368 | def bool_accuracy(slots, gold):
369 | pred_answer = str(slots.get('answer'))
370 | gold_answer = str(gold['answer'])
371 |
372 | res = {"acc": float(pred_answer==gold_answer),
373 | "ROUGE_L_F": float(pred_answer==gold_answer)}
374 | #print(f"---- {type(pred_answer)}\t{type(gold_answer)}\t{res}")
375 | return res
376 |
377 |
378 | def squad_em_f1(answer, gold_answers):
379 | best_em = -1
380 | best_f1 = -1
381 | best_match = ""
382 | for gold in gold_answers:
383 | if gold.lower() == "noanswer":
384 | if answer.strip().lower() == "noanswer":
385 | em = f1 = 1.0
386 | else:
387 | em = f1 = 0.0
388 | else:
389 | em = compute_exact(gold, answer)
390 | f1 = compute_f1(gold, answer)
391 | if em > best_em:
392 | best_em = em
393 | best_match = gold
394 | if f1 > best_f1:
395 | best_f1 = f1
396 | best_match = gold
397 | res = {"EM": best_em, "F1": best_f1, "matched_gold": best_match}
398 | return res
399 |
400 |
401 | def rc_answer_accuracy(slots, gold, slot_name='answer'):
402 | answer = str(slots.get(slot_name))
403 | if answer is None:
404 | return {"EM": 0, "F1": 0, "missing": 1}
405 | gold_answers = str(gold[slot_name])
406 | if isinstance(gold_answers, str):
407 | gold_answers = [gold_answers]
408 | return squad_em_f1(answer, gold_answers)
409 |
410 |
411 | def extact_string_match_accuracy(slots, gold, slot_name='answer'):
412 | answer = str(slots.get(slot_name))
413 | if answer is None:
414 | return {"EM": 0, "F1": 0, "missing": 1}
415 | gold_answers = str(gold[slot_name])
416 | if isinstance(gold_answers, str):
417 | gold_answers = [gold_answers]
418 | return exact_match(answer, gold_answers)
419 |
420 |
421 | def exact_match(answer, gold_answers):
422 | best_em = -1
423 | best_match = ""
424 | for gold in gold_answers:
425 | em = 1.0 if gold.lower() == answer.lower() else 0.0
426 | if em > best_em:
427 | best_em = em
428 | best_match = gold
429 |
430 | res = {"EM": best_em, "matched_gold": best_match}
431 | return res
432 |
433 |
434 | def basic_split_mcoptions(mcoptions):
435 | splits = re.split("\\s*\\(\\w\\)\\s*", mcoptions)
436 | res = [s for s in splits if s != '']
437 | return res
438 |
439 |
440 | # Very basic matching algorithm to gold MC options, not very meaningful
441 | # (finds best matching gold option, only the best score kept for each matched gold)
442 | def rough_mcoptions_f1(pred_mcoptions, gold_mcoptions):
443 | if pred_mcoptions is None:
444 | return {"F1": 0, "missing": 1}
445 | gold_split = basic_split_mcoptions(gold_mcoptions)
446 | pred_split = basic_split_mcoptions(pred_mcoptions)
447 | scores = {}
448 | for pred in pred_split:
449 | score = squad_em_f1(pred, gold_split)
450 | matched = score['matched_gold']
451 | f1 = score['F1']
452 | old_f1 = scores.get(matched, 0)
453 | if f1 > old_f1:
454 | scores[matched] = f1
455 | tot = sum(scores.values())
456 | return {"F1": tot / max(len(pred_split), 2)}
457 |
458 |
459 | SCORING_SPECS = {
460 | "ruletaker_inferences": ruletaker_inferences_scores
461 | }
462 |
463 |
464 | def collate_scores(predictions):
465 | res_by_angle = {}
466 | metrics_by_angle = {}
467 | averaged_metrics = ['acc', 'EM', "F1", 'Bleu_1', 'Bleu_2', 'Bleu_3', 'Bleu_4', 'METEOR', 'ROUGE_L',
468 | 'CIDEr', 'ROUGE_L_F', 'ROUGE_L_G', 'bad_parse', 'P', 'R',
469 | 'ROUGE_L_F_perfect_align', 'BLEURT', 'BLEURT_P', 'BLEURT_R', 'BLEURT_F1', 'BLEURT_perfect_align',
470 | 'BLEURT_acc', 'BLEURT_acc_perfect_align', 'acc_perfect_align', 'fraction_perfect_align',
471 | 'edited_or_not_acc',
472 | 'num_fever_queries', 'num_fever_queries_no_answer', 'num_fever_queries_no_rationale',
473 | 'num_premises', 'num_premises_valid', 'num_premises_valid_support', 'percent_valid_premises']
474 | aggregated_metrics = ['missing']
475 | for pred in predictions:
476 | angle = pred['angle_str']
477 | metrics = pred.get('metrics',{})
478 | if angle not in res_by_angle:
479 | res_by_angle[angle] = []
480 | res_by_angle[angle].append(pred)
481 | if angle not in metrics_by_angle:
482 | metrics_by_angle[angle] = {}
483 | metrics_by_angle[angle]['counter'] = metrics_by_angle[angle].get('counter', 0) + 1
484 | for slot, slot_metrics in metrics.items():
485 | if slot == "extra_slots":
486 | continue
487 | if slot not in metrics_by_angle[angle]:
488 | metrics_by_angle[angle][slot] = {}
489 | for metric in averaged_metrics + aggregated_metrics:
490 | if metric in slot_metrics:
491 | metrics_by_angle[angle][slot][metric] = \
492 | metrics_by_angle[angle][slot].get(metric, 0) + slot_metrics[metric]
493 | for angle, metrics in metrics_by_angle.items():
494 | counter = metrics['counter']
495 | for slot, slot_metrics in metrics.items():
496 | if not isinstance(slot_metrics, dict):
497 | continue
498 | for slot_metric, value in slot_metrics.items():
499 | if slot_metric in averaged_metrics:
500 | slot_metrics[slot_metric] = value/counter
501 | return {"metrics_aggregated": metrics_by_angle, "by_angle": res_by_angle}
502 |
503 |
504 | def score_aligned_entail_tree_proof(prediction, gold_list, angle, gold_json_record:dict, bleurt_scorer=None):
505 | res = {}
506 | if gold_list is None:
507 | res[angle] = {"missing_gold": 1}
508 | if prediction is None:
509 | res[angle]['missing'] = 1
510 | return res
511 | if prediction is None:
512 | res[angle] = {"missing": 1}
513 | return res
514 | print(f"\n\n\n======================\n")
515 | print(f"pred:{prediction}")
516 | print(f"gold:{gold_list[0]}")
517 | print(f"\n\n\n======================\n")
518 | print(f"Reading predicted proof")
519 | sentences_pred, inferences_pred, int_to_all_ancestors_pred, relevant_sentences_pred, id_to_int_pred = \
520 | parse_entailment_step_proof(prediction, gold_json_record=gold_json_record)
521 |
522 | print(f"\n\n\n||||||||||||||||||||||\n")
523 | print(f"Reading gold proof")
524 | sentences_gold, inferences_gold, int_to_all_ancestors_gold, relevant_sentences_gold, id_to_int_gold = \
525 | parse_entailment_step_proof(gold_list[0], gold_json_record=gold_json_record)
526 |
527 | pred_int_to_gold_int_mapping, prediction_to_aligned_gold, prediction_to_perfect_match = \
528 | align_conclusions_across_proofs(int_to_all_ancestors_pred, int_to_all_ancestors_gold,
529 | id_to_int_pred, id_to_int_gold)
530 |
531 | # res[angle+'-steps-unaligned'] = score_sentence_overlaps(sentences=sentences_pred, sentences_gold=sentences_gold)
532 |
533 | sentences_pred_aligned = rewrite_aligned_proof(prediction, pred_int_to_gold_int_mapping)
534 | print(f"\n\n\n++++++++++++++++++++++++++++++++++++")
535 | print(f"pred_int_to_gold_int_mapping:{pred_int_to_gold_int_mapping}")
536 | print(f"relevant_sentences_pred:{relevant_sentences_pred}")
537 | print(f"relevant_sentences_gold:{relevant_sentences_gold}")
538 | res[angle+'-leaves'] = score_sentence_overlaps(sentences=sorted(list(relevant_sentences_pred)),
539 | sentences_gold=sorted(list(relevant_sentences_gold)))
540 |
541 | res[angle + '-steps'] = score_sentence_overlaps(sentences=sorted(list(sentences_pred_aligned)),
542 | sentences_gold=sorted(list(sentences_gold)))
543 |
544 | res[angle + '-steps']['pred_to_gold_mapping'] = pred_int_to_gold_int_mapping
545 | res[angle + '-steps']['sentences_pred_aligned'] = sentences_pred_aligned
546 |
547 | res[angle+'-intermediates'] = nlg_string_similarities_intermediates_with_F1(prediction_to_aligned_gold=prediction_to_aligned_gold,
548 | id_to_int_gold=id_to_int_gold,
549 | id_to_int_pred=id_to_int_pred,
550 | prediction_to_perfect_match=prediction_to_perfect_match,
551 | bleurt_scorer=bleurt_scorer)
552 | res[angle+'-overall'] = overall_proof_score(leaves=res[angle+'-leaves'],
553 | edges=res[angle+'-steps'],
554 | intermediates=res[angle+'-intermediates'])
555 | return res
556 |
557 |
558 | def score_aligned_entail_tree_proof_onlyIR(prediction, gold_list, angle, gold_json_record:dict, pred_json_record: dict, bleurt_scorer=None):
559 | res = {}
560 | if gold_list is None:
561 | res[angle] = {"missing_gold": 1}
562 | if prediction is None:
563 | res[angle]['missing'] = 1
564 | return res
565 | if prediction is None:
566 | res[angle] = {"missing": 1}
567 | return res
568 | print(f"\n\n++++++++++++++++++\nprediction:{prediction}")
569 | # print(f"pred_json_record:{pred_json_record}")
570 | sentences_pred, inferences_pred, int_to_all_ancestors_pred, relevant_sentences_pred, id_to_int_pred = \
571 | parse_entailment_step_proof_remove_ids(prediction, slot_json_record=pred_json_record)
572 |
573 | print(f"gold_json_record:{gold_json_record}")
574 | # print(f"gold_json_record:{gold_json_record}")
575 | sentences_gold, inferences_gold, int_to_all_ancestors_gold, relevant_sentences_gold, id_to_int_gold = \
576 | parse_entailment_step_proof_remove_ids(gold_list[0], slot_json_record=gold_json_record)
577 |
578 | print(f"^^^^^^^pred:{prediction}")
579 | print(f"========sentences_pred:{sentences_pred}")
580 | print(f"^^^^^^^gold:{gold_list[0]}")
581 | print(f"========sentences_gold:{sentences_gold}")
582 |
583 | print(f"Q: {pred_json_record['id']}")
584 | pred_int_to_gold_int_mapping, prediction_to_aligned_gold, prediction_to_perfect_match = \
585 | align_conclusions_across_proofs(int_to_all_ancestors_pred, int_to_all_ancestors_gold,
586 | id_to_int_pred, id_to_int_gold)
587 |
588 | # res[angle+'-steps-unaligned'] = score_sentence_overlaps(sentences=sentences_pred, sentences_gold=sentences_gold)
589 |
590 | print(f"\n\n+++++++++++++++++++++++++\n")
591 | print(f"prediction:{prediction}")
592 | print(f"pred_int_to_gold_int_mapping:{pred_int_to_gold_int_mapping}")
593 | pred_sentences = pred_json_record['meta']['triples']
594 | pred_sentences['hypothesis'] = gold_json_record['hypothesis']
595 | sentences_pred_aligned, sentences_pred_aligned_strings = rewrite_aligned_proof_noids(prediction,
596 | pred_int_to_gold_int_mapping,
597 | pred_sentences=pred_sentences,
598 | gold_ints=gold_json_record['meta']['intermediate_conclusions']
599 | )
600 | res[angle+'-leaves'] = score_sentence_overlaps(sentences=sorted(list(relevant_sentences_pred)),
601 | sentences_gold=sorted(list(relevant_sentences_gold)))
602 |
603 | print(f"*********ID:{gold_json_record['id']}")
604 | print(f"*********sentences_pred_aligned:{sentences_pred_aligned_strings}")
605 | print(f"*********sentences_gold:{sentences_gold}")
606 | res[angle + '-steps'] = score_sentence_overlaps(sentences=sorted(list(sentences_pred_aligned_strings)),
607 | sentences_gold=sorted(list(sentences_gold)))
608 | res[angle + '-steps']['pred_to_gold_mapping'] = pred_int_to_gold_int_mapping
609 | res[angle + '-steps']['sentences_pred_aligned'] = sentences_pred_aligned
610 |
611 | res[angle+'-intermediates'] = nlg_string_similarities_intermediates_with_F1(prediction_to_aligned_gold=prediction_to_aligned_gold,
612 | id_to_int_gold=id_to_int_gold,
613 | id_to_int_pred=id_to_int_pred,
614 | prediction_to_perfect_match=prediction_to_perfect_match,
615 | bleurt_scorer=bleurt_scorer)
616 | res[angle+'-overall'] = overall_proof_score(leaves=res[angle+'-leaves'],
617 | edges=res[angle+'-steps'],
618 | intermediates=res[angle+'-intermediates'])
619 | return res
620 |
621 |
622 | def overall_proof_score(leaves, edges, intermediates):
623 | res = {}
624 | leaves_acc = leaves['acc']
625 | edges_acc = edges['acc']
626 | accuracy = leaves_acc * edges_acc * intermediates['BLEURT_acc']
627 | # accuracy = leaves_acc * edges_acc
628 | accuracy_align = leaves_acc * edges_acc * intermediates['BLEURT_acc_perfect_align']
629 |
630 | res['acc'] = accuracy
631 | res['acc_perfect_align'] = accuracy_align
632 | return res
633 |
634 |
635 | def score_prediction_whole_proof(prediction, gold, prediction_json=None, dataset=None, scoring_spec=None, bleurt_scorer=None):
636 | angle = prediction.get('angle')
637 | if 'slots' in prediction:
638 | slots = prediction['slots']
639 | else:
640 | slots = decompose_slots(prediction['prediction'])
641 | answer_eval = "emf1"
642 | if scoring_spec is not None and "answer_eval" in scoring_spec:
643 | answer_eval = scoring_spec['answer_eval']
644 | elif "mcoptions" in gold:
645 | answer_eval = "mcq"
646 | elif dataset is not None and "narrative" in dataset:
647 | answer_eval = "nlg"
648 |
649 | hypothesis_eval = "nlg"
650 | if scoring_spec is not None and "hypothesis_eval" in scoring_spec:
651 | hypothesis_eval = scoring_spec['hypothesis_eval']
652 |
653 | proof_eval = "pn_eval"
654 | if scoring_spec is not None and "proof_eval" in scoring_spec:
655 | proof_eval = scoring_spec['proof_eval']
656 |
657 | res = {}
658 | angles_out = angle[1] if angle is not None else list(slots.keys())
659 | for angle in angles_out:
660 | gold_str = gold.get(angle)
661 | gold_list = [gold_str] if isinstance(gold_str, str) else gold_str
662 | slot = slots.get(angle)
663 | if angle == 'hypothesis':
664 | if hypothesis_eval == 'old_emf1':
665 | res[angle] = rc_answer_accuracy(slots, gold)
666 | elif hypothesis_eval == 'emf1':
667 | sentences_pred = slots.get('hypothesis').split(' , ')
668 | sentences_gold = gold['hypothesis'].split(' , ')
669 | res[angle] = score_sentence_overlaps(sentences=sentences_pred,
670 | sentences_gold=sentences_gold)
671 | elif hypothesis_eval == 'mcq':
672 | res[angle] = mcq_answer_accuracy(slots, gold)
673 | elif hypothesis_eval == 'nlg':
674 | res[angle] = nlg_string_similarities(slot, gold_list)
675 | else:
676 | raise ValueError(f"Unknown hypothesis_eval setting: {hypothesis_eval}")
677 | elif angle in ['question', 'explanation']:
678 | res[angle] = nlg_string_similarities(slot, gold_list)
679 | elif angle in ['mcoptions']:
680 | res[angle] = rough_mcoptions_f1(slot, gold_str)
681 | elif angle in ['proof']:
682 | if proof_eval == "entail_whole_proof_align_eval":
683 | # This writes multiple metrics into res
684 | res.update(score_aligned_entail_tree_proof(slot, gold_list, angle, gold_json_record=gold,bleurt_scorer=bleurt_scorer))
685 | elif proof_eval == "entail_whole_proof_align_eval_onlyIR":
686 | # This writes multiple metrics into res
687 | res.update(score_aligned_entail_tree_proof_onlyIR(slot, gold_list, angle,
688 | gold_json_record=gold,
689 | pred_json_record=prediction_json,
690 | bleurt_scorer=bleurt_scorer))
691 |
692 | else:
693 | gold_proofs_list = []
694 | for gold_proof in gold_str.split(' OR '):
695 | gold_parts = gold_proof.replace(', ', ',').split(',')
696 | if len(gold_parts) == 2:
697 | gold_proofs_list.append(gold_proof)
698 | gold_proofs_list.append(f"{gold_parts[1]}, {gold_parts[0]}")
699 | res[angle] = squad_em_f1(slot, gold_proofs_list)
700 | elif angle in ['rationale']:
701 | pass # Not implemented yet
702 | else:
703 | res[angle] = {}
704 | extras = []
705 | for slot in slots:
706 | if slot not in res:
707 | extras.append(slot)
708 | if len(extras) > 0:
709 | res['extra_slots'] = extras
710 | return res
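# Illustrative call shape for score_prediction_whole_proof (added for clarity;
# the field values below are hypothetical, only the keys and slot names are
# taken from the code above):
#
#   prediction = {'angle': ('QH', ['proof']),
#                 'prediction': '$proof$ = sent1 & sent2 -> hypothesis; '}
#   gold = full gold jsonl record, e.g. {'hypothesis': ..., 'proof': ..., 'meta': ...}
#   res = score_prediction_whole_proof(prediction, gold,
#                                      scoring_spec={'proof_eval': 'entail_whole_proof_align_eval'},
#                                      bleurt_scorer=bleurt_scorer)
#
# With the entail_whole_proof_align_eval setting, res then carries per-slot
# metrics such as res['proof-leaves'], res['proof-steps'],
# res['proof-intermediates'] and res['proof-overall'].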
--------------------------------------------------------------------------------
/src/entailment_baseline.py:
--------------------------------------------------------------------------------
1 | # Imports
2 | import os
3 | import sys
4 | import json
import re  # needed by PrefixConstrainedGenerator below (regex over sentN/intN ids)
5 | import torch
6 | import base_utils
7 | import random
8 | import numpy as np
9 | import pandas as pd
10 | from types import SimpleNamespace
11 | from tqdm.notebook import tqdm
12 |
13 | # Importing DL libraries
14 | import torch
15 | import torch.nn.functional as F
16 | from torch import cuda
17 | from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler
18 | from transformers import AutoTokenizer, T5Tokenizer, T5ForConditionalGeneration
19 | from sentence_transformers import SentenceTransformer
20 | from sentence_transformers import util as st_util
21 | from entailment_bank.eval.run_scorer import main
22 | from retrieval_utils import sent_text_as_counter, convert_datapoint_to_sent_to_text
23 |
24 | class CustomDataset(Dataset):
25 |
26 | def __init__(self, source_text, target_text, tokenizer, source_len, summ_len):
27 | self.tokenizer = tokenizer
28 | self.source_len = source_len
29 | self.summ_len = summ_len
30 | self.source_text = source_text
31 | self.target_text = target_text
32 |
33 | def __len__(self):
34 | return len(self.source_text)
35 |
36 | def __getitem__(self, index):
37 | source_text = str(self.source_text[index])
38 | source_text = ' '.join(source_text.split())
39 |
40 | target_text = str(self.target_text[index])
41 | target_text = ' '.join(target_text.split())
42 |
43 | source = self.tokenizer.batch_encode_plus([source_text], max_length= self.source_len,
44 | padding='max_length',return_tensors='pt', truncation=True)
45 | target = self.tokenizer.batch_encode_plus([target_text], max_length= self.summ_len,
46 | padding='max_length',return_tensors='pt', truncation=True)
47 |
48 | source_ids = source['input_ids'].squeeze()
49 | source_mask = source['attention_mask'].squeeze()
50 | target_ids = target['input_ids'].squeeze()
51 | target_mask = target['attention_mask'].squeeze()
52 |
53 | return {
54 | 'source_ids': source_ids.to(dtype=torch.long),
55 | 'source_mask': source_mask.to(dtype=torch.long),
56 | 'target_ids': target_ids.to(dtype=torch.long),
57 | 'target_ids_y': target_ids.to(dtype=torch.long)
58 | }
59 |
60 | class Trainer():
61 | '''
62 | Based on:
63 | https://github.com/abhimishra91/transformers-tutorials/blob/master/transformers_summarization_wandb.ipynb
64 | '''
65 | def __init__(self, tokenizer=None, model=None, optimizer=None, params = None, config = None):
66 | self.params = params
67 | self.config = config
68 |
69 | # tokenizer for encoding the text
70 | self.tokenizer = tokenizer
71 | if tokenizer is None:
72 | self.tokenizer = T5Tokenizer.from_pretrained(self.params.model_name)
73 |
74 | # Defining the model.
75 | self.model = model
76 | if model is None:
77 | self.model = T5ForConditionalGeneration.from_pretrained(
78 | self.params.model_name,
79 | )
80 |
81 | # Defining the optimizer that will be used to tune the weights of
82 | # the network in the training session.
83 | self.optimizer = optimizer
84 | if optimizer is None:
85 | self.optimizer = torch.optim.Adam(params = self.model.parameters(),
86 | lr=self.config.LEARNING_RATE)
87 |
88 | self.set_random_seed()
89 |
90 | def train(self, epoch, train_loader, val_loader, model = None,
91 | prefix_constrained_generator = None):
92 | if model is None:
93 | model = self.model
94 | model.train()
95 | running_loss = 0.0
96 | tqdm_loader = tqdm(train_loader)
97 | for step_num, data in enumerate(tqdm_loader, 0):
98 | if prefix_constrained_generator is not None:
99 | prefix_constrained_generator.set_batch_number(step_num)
100 | y = data['target_ids'].to(self.params.device, dtype = torch.long)
101 | y_ids = y[:, :-1].contiguous()
102 | labels = y[:, 1:].clone().detach()
103 | labels[y[:, 1:] == self.tokenizer.pad_token_id] = -100
104 | ids = data['source_ids'].to(self.params.device, dtype = torch.long)
105 | mask = data['source_mask'].to(self.params.device, dtype = torch.long)
106 |
107 | outputs = model(input_ids = ids, attention_mask = mask,
108 | decoder_input_ids=y_ids, labels=labels)
109 | loss = outputs[0]
110 |
111 | running_loss += loss.item()
112 | if step_num % 100==0:
113 | avg_loss = running_loss / ((step_num + 1) * self.config.TRAIN_BATCH_SIZE)
114 | tqdm_loader.set_description("Loss %.4f" % avg_loss)
115 |
116 | self.optimizer.zero_grad()
117 | loss.backward()
118 | self.optimizer.step()
119 | train_avg_loss = running_loss / ((step_num + 1) * self.config.TRAIN_BATCH_SIZE)
120 | val_avg_loss = self.validate(val_loader, verbose = False)
121 | print('Epoch: %d, Train Loss: %.4f, Eval Loss %.4f' % (epoch, train_avg_loss, val_avg_loss))
122 | return train_avg_loss, val_avg_loss
123 |
124 | def validate(self, val_loader, verbose = True, model = None,
125 | prefix_constrained_generator = None):
126 | if model is None:
127 | model = self.model
128 | model.eval()
129 | running_loss = 0.0
130 | tqdm_loader = tqdm(val_loader) if verbose else val_loader
131 | for step_num, data in enumerate(tqdm_loader, 0):
132 | if prefix_constrained_generator is not None:
133 | prefix_constrained_generator.set_batch_number(step_num)
134 | y = data['target_ids'].to(self.params.device, dtype = torch.long)
135 | y_ids = y[:, :-1].contiguous()
136 | labels = y[:, 1:].clone().detach()
137 | labels[y[:, 1:] == self.tokenizer.pad_token_id] = -100
138 | ids = data['source_ids'].to(self.params.device, dtype = torch.long)
139 | mask = data['source_mask'].to(self.params.device, dtype = torch.long)
140 |
141 | outputs = model(input_ids = ids, attention_mask = mask,
142 | decoder_input_ids=y_ids, labels=labels)
143 | loss = outputs[0]
144 |
145 | running_loss += loss.item()
146 | if verbose and step_num % 100==0:
147 | avg_loss = running_loss / ((step_num + 1) * self.config.VALID_BATCH_SIZE)
148 | tqdm_loader.set_description("Loss %.4f" % avg_loss)
149 |
150 | avg_loss = running_loss / ((step_num + 1) * self.config.VALID_BATCH_SIZE)
151 | if verbose:
152 | print('Loss: %.4f' % (avg_loss,))
153 | return avg_loss
154 |
155 |
156 | def predict(self, loader, generation_args = None, model = None,
157 | prefix_constrained_generator = None):
158 | if model is None:
159 | model = self.model
160 | model.eval()
161 | context = []
162 | predictions = []
163 | actuals = []
164 |
165 | if generation_args is None:
166 | generation_args = {
167 | 'max_length': self.config.SUMMARY_LEN,
168 | 'num_beams': 3,
169 | # 'repetition_penalty': 2.5,
170 | 'length_penalty': 1.0,
171 | 'early_stopping': True
172 | }
173 | with torch.no_grad():
174 | for step_num, data in enumerate(tqdm(loader), 0):
175 | if prefix_constrained_generator is not None:
176 | prefix_constrained_generator.set_batch_number(step_num)
177 | y = data['target_ids'].to(self.params.device, dtype = torch.long)
178 | ids = data['source_ids'].to(self.params.device, dtype = torch.long)
179 | mask = data['source_mask'].to(self.params.device, dtype = torch.long)
180 | generation_args.update({
181 | 'input_ids': ids,
182 | 'attention_mask': mask,
183 | })
184 | generated_ids = model.generate(**generation_args)
185 | inputs = [self.tokenizer.decode(
186 | i, skip_special_tokens=True,
187 | clean_up_tokenization_spaces=True) for i in ids]
188 | preds = [self.tokenizer.decode(
189 | g, skip_special_tokens=True,
190 | clean_up_tokenization_spaces=True) for g in generated_ids]
191 | target = [self.tokenizer.decode(
192 | t, skip_special_tokens=True,
193 | clean_up_tokenization_spaces=True) for t in y]
194 | context.extend(inputs)
195 | predictions.extend(preds)
196 | actuals.extend(target)
197 | return predictions, actuals, context
198 |
199 | def set_random_seed(self):
200 | # Set random seeds and deterministic pytorch for reproducibility
201 | torch.manual_seed(self.config.SEED) # pytorch random seed
202 | np.random.seed(self.config.SEED) # numpy random seed
203 | torch.backends.cudnn.deterministic = True
204 |
205 | def save_model(self, file_path_suffix = ''):
206 | model_path = self.params.model_file_path.format(
207 | model_name = self.params.model_name,
208 | task_name = self.params.task_name,
209 | dataset_name = self.params.dataset_name,
210 | approach_name = self.params.approach_name,
211 | suffix = file_path_suffix)
212 | torch.save(self.model.state_dict(), model_path)
213 | print('state dict saved to: %s' % model_path)
214 |
215 | def load_model(self, file_path_suffix = ''):
216 | model_path = self.params.model_file_path.format(
217 | model_name = self.params.model_name,
218 | task_name = self.params.task_name,
219 | dataset_name = self.params.dataset_name,
220 | approach_name = self.params.approach_name,
221 | suffix = file_path_suffix)
222 | print('Loading state dict from: %s' % model_path)
223 | self.model.load_state_dict(torch.load(model_path))
224 |
225 | class EntailmentARCDataset():
226 |
227 | ROOT_PATH = "../data/arc_entail"
228 | DATASET_PATH = os.path.join(ROOT_PATH, "dataset")
229 | TASK_PATH = os.path.join(DATASET_PATH, "task_{task_num}")
230 | PARTITION_DATA_PATH = os.path.join(TASK_PATH, "{partition}.jsonl")
231 |
232 | def __init__(self, semantic_search = None, params = None, config = None):
233 | self.params = params
234 | self.config = config
235 | self.data = {self.get_task_name(task_num):
236 | {partition: [] for partition in ['train', 'dev', 'test']}
237 | for task_num in range(1, 4)}
238 | self.load_dataset()
239 | self.semantic_search = semantic_search
240 |
241 | def get_task_name(self, task_num):
242 | return "task_" + str(task_num)
243 |
244 | def get_task_number(self, task_name):
245 | return int(task_name[-1])
246 |
247 | def get_dataset_path(self, task_num = 1, partition = 'train'):
248 | path = self.PARTITION_DATA_PATH.format(task_num = task_num, partition = partition)
249 | return path
250 |
251 | def load_dataset(self):
252 | for task_name in self.data:
253 | for partition in self.data[task_name]:
254 | path = self.get_dataset_path(self.get_task_number(task_name), partition)
255 | with open(path, 'r', encoding='utf8') as f:
256 | for line in f:
257 | datapoint = json.loads(line)
258 | self.data[task_name][partition].append(datapoint)
259 |
260 | def combine_existing_and_search_context(self, datapoint, retrieved):
261 | # break down existing context to get sent texts
262 | sent_to_text = convert_datapoint_to_sent_to_text(datapoint)
263 | # merge retrieved and existing sent texts
264 | original_ret_size = len(retrieved)
265 | new_sents = [s['text'] for s in sent_to_text.values()]
266 | new_sents_lowered = [s['text'].lower() for s in sent_to_text.values()]
267 | # add sents retrieved by search to new_sents
268 | for ret in retrieved:
269 | if ret.lower() in new_sents_lowered:
270 | continue
271 | new_sents.append(ret)
272 | if len(new_sents) >= original_ret_size:
273 | break
274 | assert len(new_sents) <= original_ret_size
275 |
276 | # shuffles order, saving original index
277 | # new_sent_order[i] == j means original j-th sentence
278 | # now in i-th position
279 | new_sent_order = list(range(len(new_sents)))
280 | new_context = []
281 | random.shuffle(new_sent_order)
282 | for i in range(len(new_sents)):
283 | new_context.append(new_sents[new_sent_order[i]])
284 | # create new context
285 | new_context = ' '.join(['sent%d: %s' % (i+1, r) for i, r in enumerate(new_context)])
286 | # create new proof, modify index according to new context
287 | old_to_new_sent_map = {}
288 | for i in range(len(new_sents)):
289 | old_to_new_sent_map['sent%d ' % (new_sent_order[i] + 1,)] = 'sent%d ' % (i+1,)
290 | old_proof = datapoint['proof']
291 | new_proof = base_utils.str_replace_single_pass(old_proof, old_to_new_sent_map)
292 | return new_context, new_proof
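    # Worked example of the remapping above (illustrative): if new_sent_order
    # is [2, 0, 1], the original sent3 is placed first, sent1 second and sent2
    # third, so old_to_new_sent_map sends sent3 -> sent1, sent1 -> sent2 and
    # sent2 -> sent3, and a proof step "sent1 & sent3 -> int1" is rewritten to
    # "sent2 & sent1 -> int1".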
293 |
294 | def update_dataset_with_search(self, dataset, include_existing_context = False):
295 | # use retrieved context instead of golden context
296 | new_dataset = [dict(dp) for dp in dataset]
297 | retrieved_lst = self.semantic_search.search(
298 | dataset, top_k = self.params.max_retrieved_sentences)
299 | for retrived_it, retrieved in enumerate(retrieved_lst):
300 | if include_existing_context:
301 | new_context, new_proof = self.combine_existing_and_search_context(
302 | dataset[retrived_it], retrieved)
303 | if retrived_it < 20:
304 | print('OLD CONTEXT = ', dataset[retrived_it]['context'])
305 | print('NEW CONTEXT = ', new_context)
306 | print('OLD PROOF = ', dataset[retrived_it]['proof'])
307 | print('NEW PROOF = ', new_proof)
308 | print()
309 | new_dataset[retrived_it]['context'] = new_context
310 | new_dataset[retrived_it]['proof'] = new_proof
311 | else:
312 | sents = ['sent%d: %s' % (i+1, r) for i, r in enumerate(retrieved)]
313 | new_dataset[retrived_it]['context'] = ' '.join(sents)
314 | # makes sure proof is empty since original proof is unrelated to context
315 | new_dataset[retrived_it]['proof'] = ''
316 | return new_dataset
317 |
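    def get_contexts_from_search(self, dataset):
        '''
        Reconstructed helper (assumption): get_source_text below calls this
        method but it is missing from this file. A minimal sketch that probes
        the retriever with each hypothesis and returns one retrieved context
        string per datapoint, in the same "sentN: ..." format as the gold
        context.
        '''
        retrieved_lst = self.semantic_search.search(
            [dp['hypothesis'] for dp in dataset],
            top_k=self.params.max_retrieved_sentences)
        return [' '.join('sent%d: %s' % (i + 1, r) for i, r in enumerate(retrieved))
                for retrieved in retrieved_lst]
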
318 | def get_source_text(self, task_name, partition):
319 | source_text = []
320 | if self.semantic_search is not None:
321 | new_contexts = self.get_contexts_from_search(
322 | self.data[task_name][partition])
323 | for dp_it, data_point in enumerate(self.data[task_name][partition]):
324 | if self.semantic_search is not None:
325 | context = new_contexts[dp_it]
326 | else:
327 | context = data_point['context']
328 | hypothesis = data_point['hypothesis']
329 | source_text.append(
330 | 'hypothesis: %s, %s' % (hypothesis, context))
331 | return source_text
332 |
333 | def get_target_text(self, task_name, partition):
334 | source_text = []
335 | for data_point in self.data[task_name][partition]:
336 | source_text.append('$proof$ = %s' % (data_point['proof'],))
337 | return source_text
338 |
339 | def get_torch_dataloaders(self, task_name, tokenizer):
340 | '''
341 | Creation of Dataset and Dataloader for a certain entailment task.
342 | '''
343 | # Creating the Training and Validation dataset for further creation of Dataloader
344 |
345 | train_source_text = self.get_source_text(task_name, 'train')
346 | train_target_text = self.get_target_text(task_name, 'train')
347 | training_set = CustomDataset(train_source_text, train_target_text, tokenizer,
348 | self.config.MAX_LEN, self.config.SUMMARY_LEN)
349 |
350 | dev_source_text = self.get_source_text(task_name, 'dev')
351 | dev_target_text = self.get_target_text(task_name, 'dev')
352 | val_set = CustomDataset(dev_source_text, dev_target_text, tokenizer,
353 | self.config.MAX_LEN, self.config.SUMMARY_LEN)
354 |
355 | # Defining the parameters for creation of dataloaders
356 | train_params = {
357 | 'batch_size': self.config.TRAIN_BATCH_SIZE,
358 | 'shuffle': True,
359 | 'num_workers': 0
360 | }
361 |
362 | val_params = {
363 | 'batch_size': self.config.VALID_BATCH_SIZE,
364 | 'shuffle': False,
365 | 'num_workers': 0
366 | }
367 |
368 | # Creation of Dataloaders for training and validation.
369 | # These will be used in the training and validation stages of the model.
370 | training_loader = DataLoader(training_set, **train_params)
371 | val_loader = DataLoader(val_set, **val_params)
372 | return training_loader, val_loader
373 |
374 | class SemanticSearch():
375 |
376 | def __init__(self, corpus = None, encoder_model = None, params = None, config = None):
377 | self.params = params
378 | self.config = config
379 | self.encoder_model = encoder_model
380 | if encoder_model is None:
381 | self.encoder_model = SentenceTransformer(self.params.sent_trans_name)
382 | if corpus is not None:
383 | self.update_corpus_embeddings(corpus)
384 |
385 | def load_wt_corpus_file(self):
386 | wt_corpus = {}
387 | with open(self.params.wt_corpus_file_path, 'r', encoding='utf8') as f:
388 | wt_corpus = json.loads(f.readline())
389 | return wt_corpus
390 |
391 | def load_wt_corpus(self, extra_facts = None):
392 | wt_corpus = self.load_wt_corpus_file()
393 | corpus = list(wt_corpus.values())
394 | if extra_facts is not None:
395 | corpus.extend(extra_facts)
396 | corpus = list(set(corpus))
397 | self.update_corpus_embeddings(corpus)
398 |
399 | def update_corpus_embeddings(self, corpus):
400 | self.corpus = corpus
401 | #Encode all sentences in corpus
402 | self.corpus_embeddings = self.encoder_model.encode(
403 | corpus, convert_to_tensor=True, show_progress_bar = True)
404 | self.corpus_embeddings = self.corpus_embeddings.to(self.params.device)
405 | self.corpus_embeddings = st_util.normalize_embeddings(self.corpus_embeddings)
406 |
407 | def search_with_id_and_scores(self, queries, top_k = 1):
408 | '''
409 | Search for best semantically similar sentences in corpus.
410 |
411 | returns corpus ids (index in input corpus) and scores
412 | '''
413 | if type(queries) != list:
414 | queries = [queries]
415 |
416 | #Encode all queries
417 | query_embeddings = self.encoder_model.encode(
418 | queries, convert_to_tensor=True, show_progress_bar = False)
419 | query_embeddings = query_embeddings.to(self.params.device)
420 | query_embeddings = st_util.normalize_embeddings(query_embeddings)
421 | hits = st_util.semantic_search(query_embeddings, self.corpus_embeddings,
422 | top_k=top_k, score_function=st_util.dot_score)
423 | return hits
424 |
425 | def search(self, *args, **kwargs):
426 | '''
427 | Search for best semantically similar sentences in corpus.
428 |
429 | Only returns elements from corpus (no score or id)
430 | '''
431 | hits = self.search_with_id_and_scores(*args, **kwargs)
432 | elements = [[self.corpus[ret['corpus_id']] for ret in hit] for hit in hits]
433 | return elements
434 |
435 | def run_test(self):
436 | corpus = [
437 | 'A man is eating food.', 'A man is eating a piece of bread.',
438 | 'The girl is carrying a baby.', 'A man is riding a horse.', 'A woman is playing violin.',
439 | 'Two men pushed carts through the woods.', 'A man is riding a white horse on an enclosed ground.',
440 | 'A monkey is playing drums.', 'Someone in a gorilla costume is playing a set of drums.'
441 | ]
442 | self.update_corpus_embeddings(corpus)
443 | queries = ['A woman enjoys her meal', 'A primate is performing at a concert']
444 | results = self.search(queries, top_k = 2)
445 | for i in range(len(queries)):
446 | print('Query:', queries[i])
447 | print('Best results:', results[i])
448 | print()
449 |
450 | class PrefixConstrainedGenerator:
451 | '''
452 | Constrains the beam search to allowed tokens only at each step.
453 | Enforces the entailment dataset's expected format (important for the evaluation code)
454 | '''
455 |
456 | def __init__(self, tokenizer, source_text, batch_size):
458 | # tokenizer for encoding the text
458 | self.tokenizer = tokenizer
459 | self.source_text = source_text
460 | self.batch_size = batch_size
461 | self.batch_num = 0
462 |
463 | def get_first_token_id(self, text):
464 | toks = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(text))
465 | return toks[0]
466 |
467 | def get_last_token_id(self, text):
468 | toks = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(text))
469 | return toks[-1]
470 |
471 | def set_batch_number(self, batch_num):
472 | self.batch_num = batch_num
473 |
474 | def set_source_text(self, source_text):
475 | self.source_text = source_text
476 |
477 | def fixed_prefix_allowed_tokens_fn(self, batch_id, inputs_ids):
478 | '''
479 | Constrain the next token for beam search depending on currently generated prefix (input_ids)
480 | The output is loosely formatted according to the dataset specification.
481 | '''
482 | # print(inputs_ids, batch_id)
483 | prefix = self.tokenizer.decode(inputs_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
484 | # print(prefix)
485 | if prefix.strip() == '':
486 | return [self.get_first_token_id('sent')]
487 | if prefix.endswith(' & ') or prefix.endswith(' ; '):
488 | return [self.get_first_token_id('sent'), self.get_first_token_id('int')]
489 | if prefix.endswith('sent') or prefix.endswith('int'):
490 | return [self.get_last_token_id('sent' + str(num)) for num in range(10)]
491 | if prefix.endswith(' -> '):
492 | return [self.get_first_token_id('hypothesis'), self.get_first_token_id('int')]
493 | return list(range(self.tokenizer.vocab_size))
494 |
495 | def iterative_prefix_allowed_tokens_fn(self, batch_id, inputs_ids):
496 | '''
497 | Constrain the next token for beam search depending on currently generated prefix (input_ids)
498 | The output is loosely formatted according to the dataset specification.
499 | '''
500 | prefix = self.tokenizer.decode(inputs_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
501 | source_idx = self.batch_size * self.batch_num + batch_id
502 | source_text = self.source_text[source_idx]
503 | available_sent_nums = [source_text[match.span()[0] + len('sent'): match.span()[1]-1]
504 | for match in re.finditer("(sent)[0-9]+:", source_text)]
505 | avaliable_int_nums = [source_text[match.span()[0] + len('int'): match.span()[1]-1]
506 | for match in re.finditer("(int)[0-9]+:", source_text)]
507 |
508 | if prefix.strip() == 'in':
509 | return [self.get_last_token_id('int')]
510 | if prefix.strip() == '' or prefix.endswith(' & ') or prefix.endswith(' ; '):
511 | return [self.get_first_token_id('sent'), self.get_first_token_id('int')]
512 | if prefix.endswith('sent'):
513 | return list(set([self.get_last_token_id('sent' + num) for num in available_sent_nums]))
514 | # return list(set([self.get_last_token_id('sent' + str(num)) for num in range(10)]))
515 | if prefix.endswith('int') and not prefix.endswith('-> int'):
516 | return list(set([self.get_last_token_id('int' + num) for num in avaliable_int_nums]))
517 | # return list(set([self.get_last_token_id('int' + str(num)) for num in range(10)]))
518 | if prefix.endswith(' -> '):
519 | return [self.get_first_token_id('hypothesis'), self.get_first_token_id('int')]
520 | if not ' -> ' in prefix:
521 | all_toks = list(range(self.tokenizer.vocab_size))
522 | all_toks.remove(self.get_last_token_id('int1:'))
523 | return all_toks
524 | return list(range(self.tokenizer.vocab_size))
525 |
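# Illustrative wiring (added for clarity, not part of the original module): the
# generator above is presumably plugged into Hugging Face generation via the
# prefix_allowed_tokens_fn argument, which generate() calls with
# (batch_id, input_ids) and which must return the allowed next-token ids.
# The names trainer, loader and source_text are assumed to come from the
# driving notebook.
def example_constrained_generation(trainer, loader, source_text):
    gen = PrefixConstrainedGenerator(trainer.tokenizer, source_text,
                                     batch_size=loader.batch_size)
    generation_args = {
        'max_length': trainer.config.SUMMARY_LEN,
        'num_beams': 3,
        'early_stopping': True,
        'prefix_allowed_tokens_fn': gen.iterative_prefix_allowed_tokens_fn,
    }
    return trainer.predict(loader, generation_args=generation_args,
                           prefix_constrained_generator=gen)
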
526 | # Training loop
527 | def run_training_loop(config):
528 | print('Initiating Fine-Tuning for the model on our dataset')
529 |
530 | min_val_avg_loss = 1e10
531 |
532 | for epoch in range(config.TRAIN_EPOCHS):
533 | _, val_avg_loss = trainer.train(epoch, training_loader, val_loader)
534 | if params.save_min_val_loss and val_avg_loss < min_val_avg_loss:
535 | min_val_avg_loss = val_avg_loss
536 | # Saving trained model with lowest validation loss
537 | trainer.save_model(file_path_suffix = '_min_val_loss')
538 |
539 | # Saving trained model
540 | trainer.save_model()
541 |
542 | print('min_val_avg_loss', min_val_avg_loss)
543 |
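# Illustrative wiring (added for clarity, not part of the original module):
# run_training_loop() above relies on module-level names `params`, `config`,
# `trainer`, `training_loader` and `val_loader`, which are expected to be set
# up by the driving notebook. A minimal sketch, with hypothetical values:
#
#   params = SimpleNamespace(
#       model_name='t5-large', task_name='task_1', dataset_name='arc_entail',
#       approach_name='baseline', save_min_val_loss=True,
#       device='cuda' if cuda.is_available() else 'cpu',
#       model_file_path='../data/arc_entail/models/{model_name}_{task_name}_{dataset_name}_{approach_name}{suffix}.pt',
#       max_retrieved_sentences=25)
#   config = SimpleNamespace(TRAIN_EPOCHS=5, LEARNING_RATE=1e-4, SEED=39,
#                            TRAIN_BATCH_SIZE=2, VALID_BATCH_SIZE=2,
#                            MAX_LEN=512, SUMMARY_LEN=256)
#   trainer = Trainer(params=params, config=config)
#   entail_dataset = EntailmentARCDataset(params=params, config=config)
#   training_loader, val_loader = entail_dataset.get_torch_dataloaders(
#       params.task_name, trainer.tokenizer)
#   run_training_loop(config)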
--------------------------------------------------------------------------------
/src/entailment_retrieval.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "# Imports\n",
10 | "import os\n",
11 | "import re\n",
12 | "import sys\n",
13 | "import json\n",
14 | "import torch\n",
15 | "import base_utils\n",
16 | "import random\n",
17 | "import string\n",
18 | "import numpy as np\n",
19 | "import pandas as pd\n",
20 | "from types import SimpleNamespace\n",
21 | "from tqdm.notebook import tqdm\n",
22 | "from collections import Counter, defaultdict\n",
23 | "\n",
24 | "# Importing DL libraries\n",
25 | "import torch\n",
26 | "from torch import nn\n",
27 | "import torch.nn.functional as F\n",
28 | "from torch import cuda\n",
29 | "from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler\n",
30 | "from transformers import AutoTokenizer, T5Tokenizer, T5ForConditionalGeneration, Adafactor\n",
31 | "from sentence_transformers import SentenceTransformer\n",
32 | "from sentence_transformers import util as st_util\n",
33 | "from sentence_transformers import evaluation as st_evaluation\n",
34 | "from sentence_transformers import SentenceTransformer, InputExample, losses\n",
35 | "\n",
36 | "from retrieval_utils import convert_datapoint_to_sent_to_text, sent_text_as_counter"
37 | ]
38 | },
39 | {
40 | "cell_type": "code",
41 | "execution_count": null,
42 | "metadata": {},
43 | "outputs": [],
44 | "source": [
45 | "params = SimpleNamespace(\n",
46 | " # options: \"arc_entail\" (entailment bank), \"eqasc\"\n",
47 | " dataset_name = 'arc_entail', \n",
48 | " # task_1, task_2, task_3\n",
49 | " task_name = 'task_3',\n",
50 | " # use test instead of dev data to evaluate model\n",
51 | " use_test_data = True, \n",
52 | " encoder_model_path = '../data/arc_entail/models/%s_fine_tuned_v6/',\n",
53 | " encoder_checkpoint_path = '../data/arc_entail/models/%s_checkpoint_v6',\n",
54 | " device = 'cuda' if cuda.is_available() else 'cpu',\n",
55 | " # full list of sentence transformers: https://www.sbert.net/docs/pretrained_models.html\n",
56 | " sent_trans_name = 'all-mpnet-base-v2',\n",
57 | " wt_corpus_file_path = '../data/arc_entail/supporting_data/worldtree_corpus_sentences_extended.json', \n",
58 | " max_retrieved_sentences = 25\n",
59 | ")\n",
60 | "\n",
61 | "config = SimpleNamespace(\n",
62 | " TRAIN_EPOCHS = 5, # number of epochs to train\n",
63 | " LEARNING_RATE = 4e-5, # learning rate\n",
64 | " SEED = 39, # random seed\n",
65 | ")"
66 | ]
67 | },
68 | {
69 | "cell_type": "code",
70 | "execution_count": null,
71 | "metadata": {},
72 | "outputs": [],
73 | "source": [
74 | "def set_random_seed():\n",
75 | " # Set random seeds and deterministic pytorch for reproducibility\n",
76 | " torch.manual_seed(config.SEED) # pytorch random seed\n",
77 | " np.random.seed(config.SEED) # numpy random seed\n",
78 | " torch.backends.cudnn.deterministic = True \n",
79 | "\n",
80 | "set_random_seed()"
81 | ]
82 | },
83 | {
84 | "cell_type": "code",
85 | "execution_count": null,
86 | "metadata": {},
87 | "outputs": [],
88 | "source": [
89 | "class CustomDataset(Dataset):\n",
90 | "\n",
91 | " def __init__(self, source_text, target_text, tokenizer, source_len, summ_len):\n",
92 | " self.tokenizer = tokenizer\n",
93 | " self.source_len = source_len\n",
94 | " self.summ_len = summ_len\n",
95 | " self.source_text = source_text\n",
96 | " self.target_text = target_text\n",
97 | "\n",
98 | " def __len__(self):\n",
99 | " return len(self.source_text)\n",
100 | "\n",
101 | " def __getitem__(self, index):\n",
102 | " source_text = str(self.source_text[index])\n",
103 | " source_text = ' '.join(source_text.split())\n",
104 | "\n",
105 | " target_text = str(self.target_text[index])\n",
106 | " target_text = ' '.join(target_text.split())\n",
107 | "\n",
108 | " source = self.tokenizer.batch_encode_plus([source_text], max_length= self.source_len, \n",
109 | " padding='max_length',return_tensors='pt', truncation=True)\n",
110 | " target = self.tokenizer.batch_encode_plus([target_text], max_length= self.summ_len, \n",
111 | " padding='max_length',return_tensors='pt', truncation=True)\n",
112 | "\n",
113 | " source_ids = source['input_ids'].squeeze()\n",
114 | " source_mask = source['attention_mask'].squeeze()\n",
115 | " target_ids = target['input_ids'].squeeze()\n",
116 | " target_mask = target['attention_mask'].squeeze()\n",
117 | "\n",
118 | " return {\n",
119 | " 'source_ids': source_ids.to(dtype=torch.long),\n",
120 | " 'source_mask': source_mask.to(dtype=torch.long), \n",
121 | " 'target_ids': target_ids.to(dtype=torch.long),\n",
122 | " 'target_ids_y': target_ids.to(dtype=torch.long)\n",
123 | " }"
124 | ]
125 | },
126 | {
127 | "cell_type": "code",
128 | "execution_count": null,
129 | "metadata": {},
130 | "outputs": [],
131 | "source": [
132 | "class SemanticSearch():\n",
133 | " \n",
134 | " def __init__(self, corpus = None, encoder_model = None, params = None):\n",
135 | " self.params = params\n",
136 | " self.encoder_model = encoder_model\n",
137 | " if encoder_model is None:\n",
138 | " self.encoder_model = SentenceTransformer(self.params.sent_trans_name)\n",
139 | " if corpus is not None:\n",
140 | " self.update_corpus_embeddings(corpus)\n",
141 | "\n",
142 | " def load_wt_corpus_file(self):\n",
143 | " wt_corpus = {}\n",
144 | " with open(self.params.wt_corpus_file_path, 'r', encoding='utf8') as f:\n",
145 | " wt_corpus = json.loads(f.readline())\n",
146 | " return wt_corpus\n",
147 | " \n",
148 | " def load_wt_corpus(self, extra_facts = None):\n",
149 | " wt_corpus = self.load_wt_corpus_file()\n",
150 | " corpus = list(wt_corpus.values())\n",
151 | " if extra_facts is not None:\n",
152 | " corpus.extend(extra_facts)\n",
153 | " corpus = list(set(corpus))\n",
154 | " self.update_corpus_embeddings(corpus)\n",
155 | " \n",
156 | " def update_corpus_embeddings(self, corpus):\n",
157 | " self.corpus = corpus\n",
158 | " # Encode all sentences in corpus\n",
159 | " self.corpus_embeddings = self.encoder_model.encode(\n",
160 | " corpus, convert_to_tensor=True, show_progress_bar = False)\n",
161 | " self.corpus_embeddings = self.corpus_embeddings.to(self.params.device)\n",
162 | " self.corpus_embeddings = st_util.normalize_embeddings(self.corpus_embeddings) \n",
163 | " \n",
164 | " def search_with_id_and_scores(self, queries, top_k = 1):\n",
165 | " '''\n",
166 | " Search for best semantically similar sentences in corpus.\n",
167 | " \n",
168 | " returns corpus ids (index in input corpus) and scoress\n",
169 | " '''\n",
170 | " if type(queries) != list:\n",
171 | " queries = [queries]\n",
172 | "\n",
173 | " #Encode all queries\n",
174 | " query_embeddings = self.encoder_model.encode(queries, convert_to_tensor=True)\n",
175 | " query_embeddings = query_embeddings.to(self.params.device)\n",
176 | " query_embeddings = st_util.normalize_embeddings(query_embeddings)\n",
177 | " hits = st_util.semantic_search(query_embeddings, self.corpus_embeddings, \n",
178 | " top_k=top_k, score_function=st_util.dot_score)\n",
179 | " return hits\n",
180 | " \n",
181 | " def search(self, *args, **kwargs):\n",
182 | " '''\n",
183 | " Search for best semantically similar sentences in corpus.\n",
184 | " \n",
185 | " Only returns elements from corpus (no score or id)\n",
186 | " '''\n",
187 | " hits = self.search_with_id_and_scores(*args, **kwargs)\n",
188 | " elements = [[self.corpus[ret['corpus_id']] for ret in hit] for hit in hits]\n",
189 | " return elements\n",
190 | " \n",
191 | " def run_test(self):\n",
192 | " corpus = [\n",
193 | " 'A man is eating food.', 'A man is eating a piece of bread.',\n",
194 | " 'The girl is carrying a baby.', 'A man is riding a horse.', 'A woman is playing violin.',\n",
195 | " 'Two men pushed carts through the woods.', 'A man is riding a white horse on an enclosed ground.',\n",
196 | " 'A monkey is playing drums.', 'Someone in a gorilla costume is playing a set of drums.',\n",
197 | " 'matter in the gas phase has variable shape',\n",
198 | " ]\n",
199 | " self.update_corpus_embeddings(corpus)\n",
200 | " queries = ['A woman enjoys her meal', 'A primate is performing at a concert', 'matter in gas phase has no definite volume and no definite shape']\n",
201 | " results = self.search(queries, top_k = 2)\n",
202 | " for i in range(len(queries)):\n",
203 | " print('Query:', queries[i])\n",
204 | " print('Best results:', results[i])\n",
205 | " print()\n"
206 | ]
207 | },
208 | {
209 | "cell_type": "code",
210 | "execution_count": null,
211 | "metadata": {},
212 | "outputs": [],
213 | "source": [
214 | "class EntailmentARCDataset():\n",
215 | " \n",
216 | " ROOT_PATH = \"../data/arc_entail\"\n",
217 | " DATASET_PATH = os.path.join(ROOT_PATH, \"dataset\")\n",
218 | " TASK_PATH = os.path.join(DATASET_PATH, \"task_{task_num}\")\n",
219 | " PARTITION_DATA_PATH = os.path.join(TASK_PATH, \"{partition}.jsonl\")\n",
220 | " \n",
221 | " def __init__(self, semantic_search = None, params = None, config = None):\n",
222 | " self.params = params\n",
223 | " self.config = config\n",
224 | " self.data = {self.get_task_name(task_num): \n",
225 | " {partition: [] for partition in ['train', 'dev', 'test']} \n",
226 | " for task_num in range(1, 4)}\n",
227 | " self.load_dataset()\n",
228 | " self.semantic_search = semantic_search\n",
229 | " \n",
230 | " def get_task_name(self, task_num):\n",
231 | " return \"task_\" + str(task_num)\n",
232 | " \n",
233 | " def get_task_number(self, task_name):\n",
234 | " return int(task_name[-1])\n",
235 | " \n",
236 | " def get_dataset_path(self, task_num = 1, partition = 'train'):\n",
237 | " path = self.PARTITION_DATA_PATH.format(task_num = task_num, partition = partition)\n",
238 | " return path\n",
239 | " \n",
240 | " def load_dataset(self):\n",
241 | " for task_name in self.data:\n",
242 | " for partition in self.data[task_name]:\n",
243 | " path = self.get_dataset_path(self.get_task_number(task_name), partition)\n",
244 | " with open(path, 'r', encoding='utf8') as f:\n",
245 | " for line in f:\n",
246 | " datapoint = json.loads(line)\n",
247 | " self.data[task_name][partition].append(datapoint)"
248 | ]
249 | },
250 | {
251 | "cell_type": "markdown",
252 | "metadata": {},
253 | "source": [
254 | "# Loading Data\n",
255 | "\n",
256 | "### **NOTE**: if not fine-tunning encoder, set ``load_model = True`` to load an existing trained encoder model"
257 | ]
258 | },
259 | {
260 | "cell_type": "code",
261 | "execution_count": null,
262 | "metadata": {},
263 | "outputs": [],
264 | "source": [
265 | "load_model = False\n",
266 | "encoder_model = None\n",
267 | "\n",
268 | "if load_model:\n",
269 | " encoder_model_path = params.encoder_model_path % params.sent_trans_name\n",
270 | " print(f'loading model from: {encoder_model_path}')\n",
271 | " encoder_model = SentenceTransformer(encoder_model_path)\n",
272 | "\n",
273 | "semantic_search = SemanticSearch(encoder_model = encoder_model, params = params)"
274 | ]
275 | },
276 | {
277 | "cell_type": "code",
278 | "execution_count": null,
279 | "metadata": {},
280 | "outputs": [],
281 | "source": [
282 | "entail_dataset = EntailmentARCDataset(\n",
283 | " semantic_search = semantic_search, params = params, config = config)\n",
284 | "print(entail_dataset.data[params.task_name]['train'][0])"
285 | ]
286 | },
287 | {
288 | "cell_type": "markdown",
289 | "metadata": {},
290 | "source": [
291 | "# Retrieval Utils"
292 | ]
293 | },
294 | {
295 | "cell_type": "code",
296 | "execution_count": null,
297 | "metadata": {},
298 | "outputs": [],
299 | "source": [
300 | "def counter_jaccard_similarity(c1, c2):\n",
301 | " inter = c1 & c2\n",
302 | " union = c1 | c2\n",
303 | " return sum(inter.values()) / float(sum(union.values()))\n",
304 | "\n",
305 | "def construct_sent_context_mapping(matching_texts, sent_to_text, new_mapping_key, matching_uuids = None):\n",
306 | " alignment = []\n",
307 | " for match_it, matching_text in enumerate(matching_texts):\n",
308 | " text_counter = sent_text_as_counter(matching_text)\n",
309 | " matching_uuid = None\n",
310 | " if matching_uuids is not None:\n",
311 | " matching_uuid = matching_uuids[match_it]\n",
312 | " for sent_k, sent_v in sent_to_text.items():\n",
313 | " alignment.append((sent_k, sent_v['text'], matching_text, match_it, matching_uuid,\n",
314 | " counter_jaccard_similarity(sent_v['text_counter'], text_counter)))\n",
315 | " sorted_alignment = sorted(alignment, key= lambda x: x[-1], reverse=True)\n",
316 | " matches_it_used = []\n",
317 | " for align_item in sorted_alignment:\n",
318 | " sent_key = align_item[0]\n",
319 | " matching_text = align_item[2]\n",
320 | " match_it = align_item[3]\n",
321 | " matching_uuid = align_item[4]\n",
322 | " if new_mapping_key not in sent_to_text[sent_key].keys():\n",
323 | " if match_it not in matches_it_used: \n",
324 | " sent_to_text[sent_key][new_mapping_key] = matching_text\n",
325 | " if matching_uuid is not None:\n",
326 | " sent_to_text[sent_key][new_mapping_key + '_uuid'] = matching_uuid\n",
327 | " matches_it_used.append(match_it)\n",
328 | " \n",
329 | " assert all([new_mapping_key in v for v in sent_to_text.values()])\n",
330 | " return sent_to_text\n",
331 | "\n",
332 | "\n",
333 | "def construct_datapoint_context_mapping(datapoint):\n",
334 | " sent_to_text = convert_datapoint_to_sent_to_text(datapoint)\n",
335 | " triples = datapoint['meta']['triples'].values()\n",
336 | " assert len(sent_to_text.keys()) == len(triples) \n",
337 | " sent_to_text = construct_sent_context_mapping(\n",
338 | " triples, sent_to_text, new_mapping_key = 'triple_text')\n",
339 | " \n",
340 | " wt_p_items = [wt_p_item['original_text'] \n",
341 | " for wt_p_item in datapoint['meta']['worldtree_provenance'].values()]\n",
342 | " wt_p_uuids = [wt_p_item['uuid'] \n",
343 | " for wt_p_item in datapoint['meta']['worldtree_provenance'].values()]\n",
344 | " assert len(sent_to_text.keys()) == len(wt_p_items)\n",
345 | " sent_to_text = construct_sent_context_mapping(\n",
346 | " wt_p_items, sent_to_text, new_mapping_key = 'wt_p_text', matching_uuids = wt_p_uuids)\n",
347 | " \n",
348 | " for sent_to_text_v in sent_to_text.values():\n",
349 | " del sent_to_text_v['text_counter']\n",
350 | " return sent_to_text\n",
351 | "\n",
352 | "def create_context_mapping(dataset = entail_dataset.data['task_1']['test'], verbose = False):\n",
353 | " context_mapping = []\n",
354 | " wt_corpus = {}\n",
355 | " with open(params.wt_corpus_file_path, 'r', encoding='utf8') as f:\n",
356 | " wt_corpus = json.loads(f.readline())\n",
357 | " \n",
358 | " for datapoint in dataset:\n",
359 | " datapoint_context_mapping = construct_datapoint_context_mapping(datapoint)\n",
360 | " context_mapping.append(datapoint_context_mapping)\n",
361 | " \n",
362 | " for k, v in datapoint_context_mapping.items(): \n",
363 | " if 'wt_p_text_uuid' in v and v['wt_p_text_uuid'] in wt_corpus.keys():\n",
364 | " datapoint_context_mapping[k]['wt_corpus_text'] = wt_corpus[v['wt_p_text_uuid']]\n",
365 | " \n",
366 | " if verbose:\n",
367 | " for k, v in datapoint_context_mapping.items(): \n",
368 | " for item_k, item_v in v.items():\n",
369 | " print(item_k, '=', item_v)\n",
370 | " print()\n",
371 | " print('======')\n",
372 | " return context_mapping"
373 | ]
374 | },
375 | {
376 | "cell_type": "code",
377 | "execution_count": null,
378 | "metadata": {},
379 | "outputs": [],
380 | "source": [
381 | "def fix_wt_corpus_with_task_1_data(split = 'test'):\n",
382 | " dataset = entail_dataset.data['task_1'][split] \n",
383 | " context_mapping = create_context_mapping(dataset)\n",
384 | " removal_sents = [v for cm in context_mapping for p in cm.values() \n",
385 | " for k,v in p.items() if k != 'text']\n",
386 | " include_sents = [v for cm in context_mapping for p in cm.values() \n",
387 | " for k,v in p.items() if k == 'text']\n",
388 | " print(removal_sents[:20])\n",
389 | " wt_corpus = semantic_search.load_wt_corpus_file()\n",
390 | " corpus = list(set(list(wt_corpus.values())) - set(removal_sents))\n",
391 | " corpus.extend(include_sents)\n",
392 | " semantic_search.update_corpus_embeddings(corpus)\n",
393 | " \n",
394 | " \n",
395 | "fix_wt_corpus_with_task_1_data()\n",
396 | "print('corpus size = ', len(semantic_search.corpus))"
397 | ]
398 | },
399 | {
400 | "cell_type": "code",
401 | "execution_count": null,
402 | "metadata": {},
403 | "outputs": [],
404 | "source": [
405 | "def get_sents_height_on_tree(sent_text_lst, data_point):\n",
406 | " proof = data_point['proof']\n",
407 | " context = data_point['context']\n",
408 | " conclusion_to_antecedent_int_map = {}\n",
409 | " sents_to_int_map = {}\n",
410 | " int_to_height_map = {}\n",
411 | " \n",
412 | " steps = proof.split(';')[:-1]\n",
413 | " for step in steps:\n",
414 | " antecedent, conclusion = step.split(' -> ')\n",
415 | " if conclusion.strip() == 'hypothesis':\n",
416 | " conclusion_int = 'hypothesis'\n",
417 | " else:\n",
418 | " conclusion_int = re.findall(r'int[0-9]+', conclusion)[0]\n",
419 | " antecedent_ints = re.findall(r'int[0-9]+', antecedent)\n",
420 | " antecedent_sents = re.findall(r'sent[0-9]+', antecedent)\n",
421 | " for ant_sent in antecedent_sents:\n",
422 | " sents_to_int_map[ant_sent] = conclusion_int\n",
423 | " conclusion_to_antecedent_int_map[conclusion_int] = antecedent_ints\n",
424 | " \n",
425 | " cur_height = 0\n",
426 | " current_ints = ['hypothesis'] \n",
427 | " while len(current_ints) > 0:\n",
428 | " next_ints = []\n",
429 | " for cur_int in current_ints:\n",
430 | " int_to_height_map[cur_int] = cur_height\n",
431 | " if cur_int in conclusion_to_antecedent_int_map.keys():\n",
432 | " for next_int in conclusion_to_antecedent_int_map[cur_int]:\n",
433 | " next_ints.append(next_int)\n",
434 | " current_ints = next_ints\n",
435 | " cur_height += 1\n",
436 | " \n",
437 | " heights = []\n",
438 | " for sent_text in sent_text_lst:\n",
439 | " sent_match = re.findall(\n",
440 | " '(sent[0-9]+): (%s)' % re.escape(sent_text.strip()), context)\n",
441 | " if len(sent_match) == 0:\n",
442 | " print('MISSING!!!')\n",
443 | " print('sent_text =', sent_text)\n",
444 | " print('context =', context)\n",
445 | " continue\n",
446 | " sent_symb = sent_match[0][0]\n",
447 | " int_symb = sents_to_int_map[sent_symb]\n",
448 | " if not int_symb in int_to_height_map:\n",
449 | " # this might happen when proof has antecedent missing \"int\"\n",
450 | " continue\n",
451 | " heights.append({\n",
452 | " 'sent':sent_symb, 'text': sent_text,\n",
453 | " 'height': int_to_height_map[int_symb] + 1,\n",
454 | " })\n",
455 | " \n",
456 | " return heights, conclusion_to_antecedent_int_map, sents_to_int_map, int_to_height_map\n",
457 | "\n",
458 | "def compute_retrieval_metrics(retrieved_sentences_lst, split = 'test', verbose = False):\n",
459 | " dataset = entail_dataset.data['task_1'][split]\n",
460 | " context_mapping = create_context_mapping(dataset)\n",
461 | " \n",
462 | " assert len(retrieved_sentences_lst) == len(dataset)\n",
463 | " \n",
464 | " tot_sent = 0\n",
465 | " tot_sent_correct = 0\n",
466 | " tot_sent_missing = 0\n",
467 | " tot_no_missing = 0\n",
468 | " tot_sent_not_in_wt = 0\n",
469 | " tot_missing_sent_height = 0\n",
470 | " tot_correct_sent_height = 0\n",
471 | " \n",
472 | " correct_retrieved_lst = []\n",
473 | " errors_lst = [] # in retreived but not in gold\n",
474 | " missing_lst = [] # in gold but not in retrieved\n",
475 | " \n",
476 | " for ret_sentences, dp_context_mapping, datapoint in zip(retrieved_sentences_lst, context_mapping, dataset):\n",
477 | " correct_retrieved = []\n",
478 | " errors = []\n",
479 | " for ret_sentence in ret_sentences:\n",
480 | " is_correct = False\n",
481 | " for mapping_texts in dp_context_mapping.values():\n",
482 | " if ret_sentence in mapping_texts.values():\n",
483 | " is_correct = True\n",
484 | " if mapping_texts['text'] not in correct_retrieved:\n",
485 | " correct_retrieved.append(mapping_texts['text'])\n",
486 | " if len(mapping_texts['wt_p_text_uuid']) < 2:\n",
487 | " tot_sent_not_in_wt += 1\n",
488 | " break\n",
489 | " if not is_correct:\n",
490 | " errors.append(ret_sentence)\n",
491 | " all_sents = [v['text'] for v in dp_context_mapping.values()]\n",
492 | " missing = list(set(all_sents) - set(correct_retrieved))\n",
493 | " \n",
494 | " correct_retrieved_lst.append(correct_retrieved)\n",
495 | " errors_lst.append(errors)\n",
496 | " missing_lst.append(missing)\n",
497 | " \n",
498 | " tot_sent += len(dp_context_mapping.keys())\n",
499 | " tot_sent_correct += len(correct_retrieved)\n",
500 | " tot_sent_missing += len(missing)\n",
501 | " tot_no_missing += 0 if len(missing) > 0 else 1\n",
502 | " \n",
503 | " missing_heights, _, _, _ = get_sents_height_on_tree(missing, datapoint)\n",
504 | " tot_missing_sent_height += sum([mh['height'] for mh in missing_heights])\n",
505 | " correct_heights, _, _, _ = get_sents_height_on_tree(correct_retrieved, datapoint)\n",
506 | " tot_correct_sent_height += sum([ch['height'] for ch in correct_heights])\n",
507 | " \n",
508 | " # if verbose and len(missing) > 0:\n",
509 | " if verbose and len(missing) > 0:\n",
510 | " hypothesis = datapoint['hypothesis']\n",
511 | " question = datapoint['question']\n",
512 | " answer = datapoint['answer']\n",
513 | " \n",
514 | " print('hypothesis', hypothesis)\n",
515 | " print('Q + A', question + ' -> ' + answer)\n",
516 | " print('=====')\n",
517 | " print('retrieved:', correct_retrieved)\n",
518 | " print('missing:', missing)\n",
519 | " print()\n",
520 | " \n",
521 | " \n",
522 | " recall = tot_sent_correct / float(tot_sent)\n",
523 | " all_correct = tot_no_missing / float(len(dataset))\n",
524 | " avg_correct_sent_height = tot_correct_sent_height / (float(tot_sent_correct) + 1e-9)\n",
525 | " avg_missing_sent_height = tot_missing_sent_height / (float(tot_sent_missing) + 1e-9)\n",
526 | " print('recall:', recall)\n",
527 | " print('all correct:', all_correct)\n",
528 | " print('number of retrieved not in corpus:', tot_sent_not_in_wt)\n",
529 | " print('avg height of correct sentences:', avg_correct_sent_height)\n",
530 | " print('avg height of missing sentences:', avg_missing_sent_height)\n",
531 | " \n",
532 | " return recall, correct_retrieved_lst, errors_lst, missing_lst"
533 | ]
534 | },
535 | {
536 | "cell_type": "code",
537 | "execution_count": null,
538 | "metadata": {},
539 | "outputs": [],
540 | "source": [
541 | "def test_task_3_paper_recall():\n",
542 | " split = 'test'\n",
543 | " t3_data = entail_dataset.data['task_3'][split]\n",
544 | " retrieved_sentences_lst = [t3['meta']['triples'].values() for t3 in t3_data]\n",
545 | " compute_retrieval_metrics(retrieved_sentences_lst, split = split, verbose=False)"
546 | ]
547 | },
548 | {
549 | "cell_type": "code",
550 | "execution_count": null,
551 | "metadata": {},
552 | "outputs": [],
553 | "source": [
554 | "def test_sent_transformer_recall(split = 'test', verbose = True, top_k = 25):\n",
555 | " print('top_k', top_k)\n",
556 | " \n",
557 | " t1_data = entail_dataset.data['task_1'][split]\n",
558 | " ret_data = semantic_search.search([t1['hypothesis'] for t1 in t1_data], top_k = top_k)\n",
559 | " assert len(t1_data) == len(ret_data)\n",
560 | " \n",
561 | " return compute_retrieval_metrics(ret_data, split = split)"
562 | ]
563 | },
564 | {
565 | "cell_type": "markdown",
566 | "metadata": {},
567 | "source": [
568 | "# Retrieval Training"
569 | ]
570 | },
571 | {
572 | "cell_type": "code",
573 | "execution_count": null,
574 | "metadata": {},
575 | "outputs": [],
576 | "source": [
577 | "def get_dataset_examples(split = 'train', use_hard_negative = True, hn_top_k = 25):\n",
578 | " # Define training / dev examples for sentence transformer. \n",
579 | " examples = []\n",
580 | " if use_hard_negative:\n",
581 | " # retrieved by current model but that is not part of the dataset\n",
582 | " _, _, errors, _ = test_sent_transformer_recall(\n",
583 | " split = split, verbose = False, top_k = hn_top_k)\n",
584 | " \n",
585 | " for t1_it, t1 in enumerate(tqdm(entail_dataset.data['task_1'][split])):\n",
586 | " hyp = t1['hypothesis']\n",
587 | " sents = t1['meta']['triples'].values()\n",
588 | " hard_negs = []\n",
589 | " for sent in sents:\n",
590 | " examples.append(InputExample(texts=[hyp, sent], label=1.0))\n",
591 | " \n",
592 | " if use_hard_negative:\n",
593 | " for error in errors[t1_it]:\n",
594 | " if error not in sents:\n",
595 | " hard_negs.append(error)\n",
596 | " examples.append(InputExample(texts=[hyp, error], label=0.75))\n",
597 | " \n",
598 | " neg_sents = random.sample(semantic_search.corpus, len(sents) * 4)\n",
599 | " for neg_sent in neg_sents:\n",
600 | " if neg_sent not in sents and neg_sent not in hard_negs:\n",
601 | " examples.append(InputExample(texts=[hyp, neg_sent], label=0.0)) \n",
602 | " return examples\n",
603 | "\n",
604 | "def train_sent_trans(get_examples_fn = get_dataset_examples):\n",
605 | "\n",
606 | " #Define the model. Either from scratch of by loading a pre-trained model\n",
607 | " model = semantic_search.encoder_model\n",
608 | " \n",
609 | " train_examples = get_examples_fn(split = 'train')\n",
610 | " dev_examples = get_examples_fn(split = 'dev') \n",
611 | " \n",
612 | " print('train_examples sz =', len(train_examples))\n",
613 | " print('dev_examples sz =', len(dev_examples))\n",
614 | " \n",
615 | " for example in train_examples[:10]:\n",
616 | " print(example)\n",
617 | " print()\n",
618 | " \n",
619 | " #Define your train dataset, the dataloader and the train loss\n",
620 | " train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=32)\n",
621 | " dev_dataloader = DataLoader(dev_examples, shuffle=False, batch_size=32)\n",
622 | " train_loss = losses.CosineSimilarityLoss(model)\n",
623 | " \n",
624 | " eval_sent_1 = [e.texts[0] for e in dev_examples]\n",
625 | " eval_sent_2 = [e.texts[1] for e in dev_examples]\n",
626 | " eval_scores = [e.label for e in dev_examples]\n",
627 | " \n",
628 | " evaluator = st_evaluation.EmbeddingSimilarityEvaluator(eval_sent_1, eval_sent_2, eval_scores)\n",
629 | " \n",
630 | " print('evaluating pre-trained model') \n",
631 | " results = model.evaluate(evaluator)\n",
632 | " print('results = ', results)\n",
633 | " \n",
634 | " callback = lambda score, epoch, steps: print('callback =', score, epoch, steps)\n",
635 | " \n",
636 | " print('fine-tunning model')\n",
637 | " #Tune the model\n",
638 | "\n",
639 | " model_trained_path = params.encoder_model_path % params.sent_trans_name\n",
640 | " checkpoint_path = params.encoder_checkpoint_path % params.sent_trans_name\n",
641 | " \n",
642 | " model.fit(train_objectives=[(train_dataloader, train_loss)], \n",
643 | " epochs=config.TRAIN_EPOCHS, warmup_steps=1000, \n",
644 | " save_best_model=True, output_path=model_trained_path, \n",
645 | " optimizer_params = {'lr': config.LEARNING_RATE},\n",
646 | " evaluator=evaluator, evaluation_steps=500,\n",
647 | " checkpoint_path=checkpoint_path,\n",
648 | " checkpoint_save_total_limit = 2,\n",
649 | " #callback = callback\n",
650 | " )\n",
651 | " model_trained_path_final = '../data/arc_entail/models/%s_fine_tuned_all_steps_v6' % params.sent_trans_name\n",
652 | " model.save(model_trained_path_final)\n",
653 | " \n",
654 | " print('evaluating fine-tuned model')\n",
655 | " results = model.evaluate(evaluator)\n",
656 | " print('results = ', results) "
657 | ]
658 | },
659 | {
660 | "cell_type": "code",
661 | "execution_count": null,
662 | "metadata": {},
663 | "outputs": [],
664 | "source": [
665 | "train_sent_trans(get_examples_fn = get_dataset_examples)"
666 | ]
667 | },
668 | {
669 | "cell_type": "markdown",
670 | "metadata": {},
671 | "source": [
672 | "# Retrieval Evaluation"
673 | ]
674 | },
675 | {
676 | "cell_type": "markdown",
677 | "metadata": {},
678 | "source": [
679 | "## EntailmentWriter evaluation"
680 | ]
681 | },
682 | {
683 | "cell_type": "code",
684 | "execution_count": null,
685 | "metadata": {},
686 | "outputs": [],
687 | "source": [
688 | "test_task_3_paper_recall()"
689 | ]
690 | },
691 | {
692 | "cell_type": "markdown",
693 | "metadata": {},
694 | "source": [
695 | "## Retrieval evaluation (single)"
696 | ]
697 | },
698 | {
699 | "cell_type": "code",
700 | "execution_count": null,
701 | "metadata": {},
702 | "outputs": [],
703 | "source": [
704 | "_, _, _, _ = test_sent_transformer_recall(split = 'test', verbose = False)"
705 | ]
706 | },
707 | {
708 | "cell_type": "markdown",
709 | "metadata": {},
710 | "source": [
711 | "## Multi-step retrieval evaluation (conditional)"
712 | ]
713 | },
714 | {
715 | "cell_type": "code",
716 | "execution_count": null,
717 | "metadata": {},
718 | "outputs": [],
719 | "source": [
720 | "def test_sent_transformer_recall(split = 'test', verbose = True, top_k = 25):\n",
721 | " t1_data = entail_dataset.data['task_1'][split]\n",
722 | " \n",
723 | " ret_data = []\n",
724 | " for _ in t1_data:\n",
725 | " ret_data.append([])\n",
726 | "# probes = [t1['hypothesis'] for t1 in t1_data]\n",
727 | " probes = [t1['question'] + ' ' + t1['answer'] for t1 in t1_data]\n",
728 | " keep_top_from_hyp = 15\n",
729 | " \n",
730 | " for k_step in range(1, top_k - keep_top_from_hyp + 1):\n",
731 | " temp_ret_data = semantic_search.search(probes, top_k = k_step)\n",
732 | " for ret_it, rets in enumerate(temp_ret_data):\n",
733 | " for ret in rets:\n",
734 | " if ret not in ret_data[ret_it]:\n",
735 | " ret_data[ret_it].append(ret)\n",
736 | " probes[ret_it] += ' ' + ret\n",
737 | " break\n",
738 | " \n",
739 | " # now gather \"keep_top_from_hyp\" by using only hypothesis as probe\n",
740 | " probes = [t1['hypothesis'] for t1 in t1_data]\n",
741 | "# probes = [t1['question'] + ' ' + t1['answer'] for t1 in t1_data]\n",
742 | " temp_ret_data = semantic_search.search(probes, top_k = top_k * 3)\n",
743 | " for ret_it, rets in enumerate(temp_ret_data):\n",
744 | " for ret in rets:\n",
745 | " if ret not in ret_data[ret_it]:\n",
746 | " ret_data[ret_it].append(ret)\n",
747 | " if len(ret_data[ret_it]) == top_k:\n",
748 | " break\n",
749 | " \n",
750 | " assert all([len(x) == top_k for x in ret_data])\n",
751 | " \n",
752 | " compute_retrieval_metrics(ret_data, split = split, verbose = verbose)\n",
753 | "\n",
754 | "# _, _, _, _ = test_sent_transformer_recall(split = 'test', verbose = False)\n",
755 | "test_sent_transformer_recall(split = 'test', verbose = False, top_k = 25)"
756 | ]
757 | },
758 | {
759 | "cell_type": "code",
760 | "execution_count": null,
761 | "metadata": {},
762 | "outputs": [],
763 | "source": []
764 | }
765 | ],
766 | "metadata": {
767 | "kernelspec": {
768 | "display_name": "Environment (conda_csr)",
769 | "language": "python",
770 | "name": "conda_csr"
771 | },
772 | "language_info": {
773 | "codemirror_mode": {
774 | "name": "ipython",
775 | "version": 3
776 | },
777 | "file_extension": ".py",
778 | "mimetype": "text/x-python",
779 | "name": "python",
780 | "nbconvert_exporter": "python",
781 | "pygments_lexer": "ipython3",
782 | "version": "3.8.10"
783 | }
784 | },
785 | "nbformat": 4,
786 | "nbformat_minor": 4
787 | }
788 |
--------------------------------------------------------------------------------
/src/retrieval_utils.py:
--------------------------------------------------------------------------------
1 | import collections
2 | import copy
3 | import json
4 | import re
5 | import rouge
6 | import string
7 | from collections import Counter
8 |
9 | def sent_text_as_counter(text):
10 |     text = text.lower().replace(' \'', '\'')  # lowercase and re-attach detached apostrophes
11 |     text = re.sub(r"[.,;/\\?]", "", text)  # strip punctuation
12 | tokens = text.split()
13 | return Counter(tokens)
14 |
15 | def create_sentence_uuid_map_from_corpus(world_tree_file):
16 |     # the corpus file stores a single JSON object mapping uuid -> sentence text
17 |     with open(world_tree_file, 'r') as file:
18 |         line = file.readline()
19 |         uuid_to_sent = json.loads(line)
20 |     sent_txt_to_uuid = []  # list of {'text', 'text_counter', 'uuid', 'num_uuid'} entries
21 |     num_uuid = 1  # 1-based index used as a compact corpus sentence id
22 | for uuid, wt_text in uuid_to_sent.items():
23 | sent_txt_to_uuid.append({
24 | 'text': wt_text,
25 | 'text_counter': sent_text_as_counter(wt_text),
26 | 'uuid': uuid,
27 | 'num_uuid': str(num_uuid)
28 | })
29 | num_uuid += 1
30 | return sent_txt_to_uuid
31 |
32 | def convert_datapoint_to_sent_to_text(datapoint):
33 | '''
34 |     creates a mapping from sentence symbols in the datapoint's context to their text and token counter
35 | (e.g. {'sent1': {'text': 'leo is a kind of constellation', 'text_counter': [...]}})
36 | '''
37 | context = datapoint['context']
38 | matches = list(re.finditer("(sent)[0-9]+:", context))
39 | sent_to_text = {}
40 | for match_idx, match in enumerate(matches):
41 | sent_match = match.group()
42 |         sent_symb = sent_match[:-1]  # strip the trailing ':' from "sentX:"
43 | sent_span = match.span()
44 | start_pos = sent_span[0] + len(sent_match)
45 | end_pos = None
46 | if match_idx + 1 < len(matches):
47 | end_pos = matches[match_idx + 1].span()[0]
48 | sent_text = context[start_pos: end_pos].strip()
49 | sent_to_text[sent_symb] = {
50 | 'text': sent_text,
51 | 'text_counter': sent_text_as_counter(sent_text),
52 | }
53 | return sent_to_text
54 |
55 | def search_for_sent_uuid(probe, sent_txt_to_uuid):
56 | '''
57 |     Returns the 'num_uuid' of the worldtree corpus sentence whose tokens best
58 |     match the probe's 'text_counter'
59 | '''
60 | probe_counter = probe['text_counter']
61 | best_uuid = None
62 | best_match = None
63 | best_match_score = 0
64 | for wt_item in sent_txt_to_uuid:
65 | wt_counter = wt_item['text_counter']; wt_uuid = wt_item['num_uuid']
66 | match_counter = wt_counter & probe_counter
67 | match_score = sum(match_counter.values())
68 | if match_score > best_match_score:
69 | best_uuid = wt_uuid
70 | best_match = wt_item
71 | best_match_score = match_score
72 | ratio = float(best_match_score) / float(sum(probe_counter.values()))
73 |     if ratio < 0.8:  # warn when the best match covers less than 80% of the probe's tokens
74 | print('ratio',ratio)
75 | print('probe', probe)
76 | print('best_match_counter', best_match)
77 | print()
78 |
79 | return best_uuid
80 |
81 | def convert_sentences_num_to_uuid(expression, sent_to_text, sent_txt_to_uuid):
82 | '''
83 | Inputs:
84 |     - expression: text containing sentence symbols (e.g. 'sent1 & sent2 -> hypothesis')
85 |     - sent_to_text: dict mapping sentence symbols to their text and token counter
86 |     - sent_txt_to_uuid: list of corpus entries ('text', 'text_counter', 'uuid', 'num_uuid') used to look up ids
87 | '''
88 | new_expression = expression
89 |
90 |     # match sentence symbols followed by a space (e.g. 'sent1 '), as they appear in proof expressions
91 | matches = list(re.finditer("(sent)[0-9]+ ", expression))
92 | for match_idx, match in enumerate(matches):
93 | sent_symb = match.group()
94 | if sent_symb[:-1] in sent_to_text:
95 | sent_text_item = sent_to_text[sent_symb[:-1]]
96 | sent_uuid = search_for_sent_uuid(sent_text_item, sent_txt_to_uuid)
97 | else:
98 |             sent_uuid = 1  # fallback id when the symbol is not present in the datapoint context
99 | new_expression = new_expression.replace(sent_symb, f"sent{sent_uuid} " )
100 | return new_expression
101 |
102 | def convert_datapoint_sent_to_uuid(datapoint, world_tree_file='data/arc_entail/supporting_data/worldtree_corpus_sentences_extended.json'):
103 | '''
104 |     converts local sentence symbols (e.g. 'sent1') in the datapoint's proof to corpus-level symbols based on the 'num_uuid' of the best-matching worldtree sentence
105 | '''
106 | sent_txt_to_uuid = create_sentence_uuid_map_from_corpus(world_tree_file)
107 | sent_to_text = convert_datapoint_to_sent_to_text(datapoint)
108 | new_datapoint = dict(datapoint)
109 | new_datapoint['proof'] = convert_sentences_num_to_uuid(
110 | new_datapoint['proof'], sent_to_text, sent_txt_to_uuid)
111 | return new_datapoint
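112 | 
113 | 
114 | if __name__ == '__main__':
115 |     # Minimal usage sketch (illustrative only): the datapoint and the tiny in-memory
116 |     # corpus below are made-up stand-ins; the real corpus is normally loaded with
117 |     # create_sentence_uuid_map_from_corpus() from the worldtree JSON file.
118 |     example_datapoint = {
119 |         'context': 'sent1: leo is a kind of constellation sent2: a constellation contains stars',
120 |         'proof': 'sent1 & sent2 -> hypothesis; ',
121 |     }
122 |     sent_to_text = convert_datapoint_to_sent_to_text(example_datapoint)
123 |     print(sent_to_text['sent1']['text'])  # leo is a kind of constellation
124 |     fake_corpus = [
125 |         {'text': 'leo is a kind of constellation',
126 |          'text_counter': sent_text_as_counter('leo is a kind of constellation'),
127 |          'uuid': 'fake-uuid-1', 'num_uuid': '17'},
128 |         {'text': 'a constellation contains stars',
129 |          'text_counter': sent_text_as_counter('a constellation contains stars'),
130 |          'uuid': 'fake-uuid-2', 'num_uuid': '42'},
131 |     ]
132 |     # replaces each local symbol with the num_uuid of its best-matching corpus sentence
133 |     print(convert_sentences_num_to_uuid(example_datapoint['proof'], sent_to_text, fake_corpus))
134 |     # -> 'sent17 & sent42 -> hypothesis; '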
--------------------------------------------------------------------------------