├── .flake8 ├── .gitignore ├── Dockerfile ├── LICENSE ├── README.md ├── aries ├── __init__.py ├── alignment │ ├── biencoder.py │ ├── bm25.py │ ├── cross_encoder.py │ ├── doc_edits.py │ ├── eval.py │ ├── gpt.py │ ├── other.py │ └── precomputed.py └── util │ ├── __init__.py │ ├── color.py │ ├── data.py │ ├── edit.py │ ├── gensim.py │ ├── gpt3.py │ ├── logging.py │ ├── s2orc.py │ └── training.py ├── data └── configs │ ├── bm25.json │ ├── bm25_ao.json │ ├── bm25_high_recall.json │ ├── bm25_high_recall_ao.json │ ├── deberta_biencoder.json │ ├── deberta_biencoder_ao.json │ ├── deberta_cross_encoder.json │ ├── deberta_cross_encoder_ao.json │ ├── edit_generation_paper.json │ ├── gpt_multiedit.json │ ├── gpt_multiedit_ao.json │ ├── gpt_pairwise_0shot_ao.json │ ├── gpt_pairwise_1shot_ao.json │ ├── human.json │ ├── human_ao.json │ ├── linkbert_cross_encoder.json │ ├── linkbert_cross_encoder_ao.json │ ├── specter2_biencoder.json │ ├── specter2_biencoder_ao.json │ ├── specter2_untrained.json │ └── specter2_untrained_ao.json ├── requirements.txt ├── scripts ├── generate_edits.py ├── generate_synthetic_data.py └── train_revision_alignment.py ├── setup.py └── tests └── test_edit.py /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-line-length = 150 3 | ignore = E265, E501, F401, E266, W503, E203 4 | exclude = .git, __pycache__, .pytest_cache, docs 5 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | .git/ 3 | data/ 4 | 5 | *.png 6 | *.o 7 | *.so 8 | *.egg-info 9 | 10 | tmp* 11 | _levenshtein.c 12 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM ghcr.io/allenai/cuda:11.3-cudnn8-dev-ubuntu20.04-v0.0.15 2 | 3 | # Set up the main python environment 4 | SHELL ["/bin/sh", "-c"] 5 | 6 | COPY requirements.txt /aries/requirements.txt 7 | WORKDIR /aries 8 | RUN pip install -r requirements.txt 9 | 10 | RUN python -m nltk.downloader -d /opt/miniconda3/share/nltk_data stopwords punkt book popular 11 | 12 | RUN curl https://sh.rustup.rs -sSf | bash -s -- -y 13 | ENV PATH="/root/.cargo/bin:${PATH}" 14 | RUN bash -c "cd /tmp/; git clone https://github.com/openai/tiktoken tiktoken; cd tiktoken; git checkout 0.3.3; pip install ." 15 | 16 | RUN aws s3 sync --no-sign-request s3://ai2-s2-research-public/aries/ data/aries/ 17 | RUN tar -C data/aries -xf data/aries/s2orc.tar.gz 18 | 19 | COPY . /aries 20 | 21 | RUN pip install -e . 22 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. 
For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. 
Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 
135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 
194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ARIES 2 | 3 | Data and code for [ARIES: A Corpus of Scientific Paper Edits Made in Response to Peer Reviews](https://arxiv.org/pdf/2306.12587.pdf) 4 | 5 | ## Dataset 6 | 7 | To download the dataset, run `aws s3 sync --no-sign-request s3://ai2-s2-research-public/aries/ data/aries/`. Below, we provide an overview of the files contained in ARIES. 8 | 9 | `s2orc.tar.gz` contains the S2ORC (full-text) parses of papers with comment-aligned edits, can be extracted with `tar -C data/aries -xf data/aries/s2orc.tar.gz`. 10 | 11 | `paper_edits.jsonl` contains edits for each document. Each document has a source pdf id and a target pdf id corresponding to the original and revised versions, respectively; these are the same as OpenReview pdf ids, so the corresponding PDFs are available at https://openreview.net/references/pdf?id=PDF_ID. Each edit consists of an edit id, a list of source paragraph ids (which are indexes into the "body_text" list of the corresponding paper S2ORC file), and a list of target paragraph ids. 12 | 13 | `review_comments.jsonl` contains the review comments, each uniquely identified by its `(doc_id, comment_id)` tuple. 14 | 15 | `edit_labels_*.jsonl` files contain the comment-edit alignment labels for each data split (train, dev, test). Each has a doc_id and comment id (corresponding to a comment in `review_comments.jsonl`), a list of aligned (positive) edit ids, and a list of negative edit ids. For the synthetic data (train and dev) the negative ids are empty, indicating that there are no edits from the same document that can safely be treated as negative (due to the low recall of the data). Code that loads these files and merges them with comments and paper_edits is in `scripts/train_revision_alignment.py`. 16 | 17 | `alignment_human_eval.jsonl` contains the human evaluation labels (created without PDFs or author responses, only using the same information available to models), formatted in the same way as the `edit_labels_*` files. 18 | 19 | `generated_edits.jsonl` contains edits generated by GPT for each comment in the test split. It is the output of `scripts/generate_edits.py`, although changes in the OpenAI API may cause fluctuations. Some records in the file additionally have an "annotations" field, which contains labels for the edit generation analysis in the ARIES paper. 20 | 21 | `raw_split_ids.json` and `review_replies.jsonl` contains the raw split document ids and author responses/reviews used to construct the synthetic dataset. They are consumed by `scripts/generate_synthetic_data.py`, although the outputs of that script are already available in the train and dev label files. 22 | 23 | `gpt3_cache.sqlite` is a cache of the GPT inputs and responses that should be necessary to reproduce the main paper results. 
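The GPT experiment configs locate this cache through their `cache_db_path` entry. As a minimal sketch (all other keys of the config are omitted here), a config that reuses the downloaded cache might contain:

```json
{
  "cache_db_path": "data/aries/gpt3_cache.sqlite"
}
```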
To use it, make sure it is in the path pointed by `"cache_db_path"` in the GPT experiment configs (by default, it isn't, so GPT responses would be re-generated). 24 | 25 | 26 | ## Running experiments 27 | 28 | This codebase is intended to be installed as a module that can be imported by the scripts in `scripts`. After cloning the repo, first install dependencies with `pip install -r requirements.txt`, and then install this repository with `pip install -e .`. Alternatively, the provided Dockerfile can be used to build a suitable environment. Please also note that these instructions assume the dataset has been downloaded and extracted as described above into the `data/aries` directory. 29 | 30 | To train models on the alignment task, run `python scripts/train_revision_alignment.py ` with an appropriate config file. Config files corresponding to the experiments in the paper are in the `data/configs` directory. Results for the experiment are stored in the path specified by the `output_dir` of the config, and the metrics are stored in the `test_metrics.json` file in that directory. The main metrics to consider are the `devthresh_*` metrics, which are based on the decision threshold that maximizes f1 on the dev set. 31 | 32 | 33 | ## Inference 34 | 35 | To predict which edits of a document align to a given comment, use the predict_many method of an aligner. Example using a SPECTER bi-encoder: 36 | 37 | ```python 38 | import transformers 39 | 40 | import aries.util.s2orc 41 | import aries.util.edit 42 | from aries.alignment.doc_edits import DocEdits 43 | from aries.util.data import iter_jsonl_files, index_by 44 | from aries.alignment.biencoder import BiencoderTransformerAligner 45 | 46 | doc_id = "EYCm0AFjaSS" 47 | paper_edits = index_by(iter_jsonl_files(["data/aries/paper_edits.jsonl"]), "doc_id", one_to_one=True)[doc_id] 48 | # Use aries.util.s2orc loader to handle back_matter merging 49 | with aries.util.s2orc.S2orcFetcherFilesystem("data/aries/s2orc/") as fetcher: 50 | s2orc1 = aries.util.s2orc.load_s2orc(paper_edits["source_pdf_id"], fetcher) 51 | s2orc2 = aries.util.s2orc.load_s2orc(paper_edits["target_pdf_id"], fetcher) 52 | doc_edits = DocEdits.from_list(s2orc1, s2orc2, paper_edits["edits"]) 53 | candidate_edits = [edit for edit in doc_edits.paragraph_edits if not edit.is_identical()] 54 | 55 | comment = index_by(iter_jsonl_files(["data/aries/review_comments.jsonl"]), "doc_id")[doc_id][0] 56 | 57 | aligner = BiencoderTransformerAligner( 58 | { 59 | "edit_input_format": "diff", 60 | "query_input_format": "comment_with_context", 61 | "add_diff_tokens": False, 62 | "max_seq_length": 512, 63 | }, 64 | transformers.AutoModel.from_pretrained("allenai/specter"), 65 | transformers.AutoTokenizer.from_pretrained("allenai/specter"), 66 | ) 67 | 68 | predictions = aligner.predict_many( 69 | [ 70 | { 71 | "review_comment": comment["comment"], 72 | "context": comment["comment_context"], 73 | "candidates": candidate_edits, 74 | } 75 | ] 76 | )[0]["predictions"] 77 | 78 | predicted_edits = [(record["score"], record["edit"]) for record in predictions if record["pred"] == 1] 79 | print("Comment:", comment["comment"]) 80 | # Expected result: edits 75, 78, 77 81 | for score, edit in sorted(predicted_edits, key=lambda x: x[0], reverse=True)[:3]: 82 | print("\nEdit {} ({:0.2f}):".format(edit.edit_id, score)) 83 | aries.util.edit.print_word_diff(edit.get_source_text(), edit.get_target_text(), color_format="ansi") 84 | ``` 85 | 86 | ## Edit Generation 87 | 88 | Edits can be generated with GPT models using 
`scripts/generate_edits.py`. We provide an example config with the prompt used in the paper, which can be run with `python scripts/generate_edits.py configs/edit_generation_paper.json`. However, to get the actual edits used for the paper analysis we recommend using the `generated_edits.jsonl` file in the dataset. 89 | 90 | ## Citation 91 | 92 | ``` 93 | @misc{darcy2023aries, 94 | title={ARIES: A Corpus of Scientific Paper Edits Made in Response to Peer Reviews}, 95 | author={Mike D'Arcy and Alexis Ross and Erin Bransom and Bailey Kuehl and Jonathan Bragg and Tom Hope and Doug Downey}, 96 | year={2023}, 97 | eprint={2306.12587}, 98 | archivePrefix={arXiv}, 99 | primaryClass={cs.CL} 100 | } 101 | ``` 102 | 103 | ## License 104 | 105 | The ARIES dataset is licensed under [ODC-BY 1.0](https://opendatacommons.org/licenses/by/1-0/). The code in this repo is licensed under [Apache 2.0](https://apache.org/licenses/LICENSE-2.0). 106 | -------------------------------------------------------------------------------- /aries/__init__.py: -------------------------------------------------------------------------------- 1 | from . import * 2 | -------------------------------------------------------------------------------- /aries/alignment/bm25.py: -------------------------------------------------------------------------------- 1 | import difflib 2 | import itertools 3 | import json 4 | import logging 5 | import os 6 | import sys 7 | 8 | import gensim 9 | import numpy as np 10 | import tqdm 11 | 12 | import aries.util.data 13 | import aries.util.edit 14 | import aries.util.gensim 15 | from aries.alignment.eval import full_tune_optimal_thresholds 16 | 17 | logger = logging.getLogger(__name__) 18 | 19 | 20 | class BM25Aligner: 21 | def __init__(self, config): 22 | self.config = config 23 | 24 | self.fixed_pred_threshold = self.config.get("fixed_pred_threshold", None) 25 | self.pred_threshold = self.config.get("fixed_pred_threshold", 0.5) 26 | 27 | self.fixed_rel_pred_threshold = self.config.get("fixed_rel_pred_threshold", None) 28 | self.rel_pred_threshold = self.config.get("fixed_rel_pred_threshold", 0.2) 29 | 30 | self.tune_on_dev = self.config.get("tune_on_dev", False) 31 | self.tuning_minimum_recall = self.config.get("tuning_minimum_recall", 0.0) 32 | 33 | self.query_input_format = self.config["query_input_format"] 34 | self.edit_input_format = self.config["edit_input_format"] 35 | 36 | self.output_dir = self.config.get("output_dir", None) 37 | 38 | # Check for conflicts between tune_on_dev and fixed_*_thresholds 39 | if self.tune_on_dev and (self.fixed_pred_threshold is not None or self.fixed_rel_pred_threshold is not None): 40 | logger.warning("tune_on_dev is set to True, but fixed_pred_threshold and/or fixed_rel_pred_threshold are set. 
Ignoring fixed thresholds.") 41 | 42 | self.bm25_model = None 43 | self.bm25_index = None 44 | self.tfidf_model = None 45 | 46 | self.dictionary = None 47 | if self.config.get("bm25_dictionary", None) is not None: 48 | logger.info("Loading dictionary from {}".format(self.config["bm25_dictionary"])) 49 | self.dictionary = gensim.corpora.Dictionary.load(self.config["bm25_dictionary"]) 50 | 51 | def _candidate_record_to_input_text(self, rec): 52 | if self.query_input_format == "comment_only": 53 | return rec["review_comment"] 54 | elif self.query_input_format == "comment_with_canonical": 55 | return rec["review_comment"] + "\ncanonicalized: " + rec["canonical"]["canonicalized"] 56 | elif self.query_input_format == "reply_comment_or_extracted_comment": 57 | return rec.get("reply_comment_line", rec["review_comment"]) 58 | elif self.query_input_format == "reply_comment_or_extracted_comment_with_canonical": 59 | return rec.get("reply_comment_line", rec["review_comment"]) + "\ncanonicalized: " + rec["canonical"]["canonicalized"] 60 | elif self.query_input_format == "comment_with_context": 61 | comment_str = rec["review_comment"].strip() 62 | if rec.get("context_side", "none") == "left": 63 | comment_str = rec["context"].strip() + " " + comment_str 64 | else: 65 | comment_str = comment_str + " " + rec["context"].strip() 66 | 67 | return "review comment: " + comment_str 68 | raise ValueError("Unknown query_input_format {}".format(self.query_input_format)) 69 | 70 | def _edit_to_input_text(self, edit): 71 | if self.edit_input_format == "added_tokens": 72 | return " ".join(edit.get_added_tokens()) 73 | if self.edit_input_format == "source_text": 74 | return edit.get_source_text() 75 | if self.edit_input_format == "target_text": 76 | return edit.get_target_text() 77 | if self.edit_input_format == "target_text_with_context": 78 | context = "context: none" 79 | if len(edit.target_idxs) != 0 and min(edit.target_idxs) != 0: 80 | context = "context: " + edit.texts2[min(edit.target_idxs) - 1] 81 | return edit.get_target_text() + "\n\n" + context 82 | elif self.edit_input_format == "diff": 83 | return aries.util.edit.make_word_diff( 84 | edit.get_source_text(), 85 | edit.get_target_text(), 86 | color_format="none", 87 | ) 88 | elif self.edit_input_format == "tokens_union": 89 | text1 = edit.get_source_text() 90 | text2 = edit.get_target_text() 91 | textw = text1.split(" ") if len(text1) != 0 else [] 92 | outtextw = text2.split(" ") if len(text2) != 0 else [] 93 | tokens = [] 94 | for idx, x in enumerate(difflib.ndiff(textw, outtextw)): 95 | tokens.append(x[2:]) 96 | return " ".join(tokens) 97 | raise ValueError("Unknown edit_input_format {}".format(self.edit_input_format)) 98 | 99 | def train(self, train_recs, dev_recs): 100 | logger.info("Getting corpus statistics from training documents...") 101 | # Pull the full doc text from the training set 102 | all_doc_edits = dict() 103 | for rec in train_recs: 104 | # We only need one edit to get the DocEdits for the whole doc 105 | if rec["doc_id"] in all_doc_edits: 106 | continue 107 | edits = rec["positives"] + rec["negatives"] 108 | if len(edits) == 0: 109 | continue 110 | all_doc_edits[rec["doc_id"]] = edits[0].doc_edits 111 | 112 | docs = [] 113 | for doc_id, doc_edits in all_doc_edits.items(): 114 | docs.append("\n\n".join([x["text"] for x in doc_edits.s2orc2["pdf_parse"]["body_text"]])) 115 | 116 | corpus = aries.util.gensim.InMemoryTextCorpus(docs, dictionary=self.dictionary) 117 | self.dictionary = corpus.dictionary 118 | 119 | # Save dictionary 120 | 
self.dictionary.save(os.path.join(self.output_dir, "dictionary.pk")) 121 | 122 | # Tune the thresholds, if needed 123 | if self.tune_on_dev: 124 | logger.info("Tuning thresholds on dev set...") 125 | self.pred_threshold, self.rel_pred_threshold = self._tune_thresholds(dev_recs) 126 | logger.info("Tuned thresholds: pred_threshold={}, rel_pred_threshold={}".format(self.pred_threshold, self.rel_pred_threshold)) 127 | 128 | with open(os.path.join(self.output_dir, "thresholds.json"), "w") as f: 129 | json.dump( 130 | { 131 | "pred_threshold": self.pred_threshold, 132 | "rel_pred_threshold": self.rel_pred_threshold, 133 | }, 134 | f, 135 | ) 136 | 137 | def _tune_thresholds(self, dev_recs): 138 | eval_records = [] 139 | for rec in dev_recs: 140 | eval_records.append( 141 | { 142 | "doc_id": rec["doc_id"], 143 | "review_comment": rec["review_comment"], 144 | "context": rec["context"], 145 | "context_side": rec.get("context_side", "none"), 146 | "candidates": rec["positives"] + rec["negatives"] + rec.get("unknowns", []), 147 | "candidate_labels": [1] * len(rec["positives"]) + [0] * len(rec["negatives"]) + [None] * len(rec.get("unknowns", [])), 148 | } 149 | ) 150 | all_results = self.predict_many(eval_records) 151 | 152 | all_candidates = [] 153 | for rec in all_results: 154 | for idx, ex in enumerate(rec["predictions"]): 155 | ex["label"] = rec["input_record"]["candidate_labels"][idx] 156 | all_candidates.append(ex) 157 | 158 | pred_threshold, rel_pred_threshold, _ = full_tune_optimal_thresholds( 159 | all_candidates, 160 | min_recall=self.tuning_minimum_recall, 161 | num_abs_thresholds=20, 162 | num_rel_thresholds=20, 163 | abs_thresh=self.fixed_pred_threshold, 164 | rel_thresh=self.fixed_rel_pred_threshold, 165 | ) 166 | return pred_threshold, rel_pred_threshold 167 | 168 | def _init_vector_models(self): 169 | self.bm25_model = gensim.models.OkapiBM25Model(dictionary=self.dictionary) 170 | self.tfidf_model = gensim.models.TfidfModel(dictionary=self.dictionary, normalize=True, smartirs="bnn") 171 | 172 | def predict_many(self, *args, **kwargs): 173 | if self.bm25_model is None: 174 | self._init_vector_models() 175 | 176 | results = self._predict_many(*args, **kwargs) 177 | return results 178 | 179 | def _predict_many(self, test_recs, quiet=False): 180 | out_recs = [] 181 | 182 | logger.info("Doing inference with pred_threshold={}, rel_pred_threshold={}".format(self.pred_threshold, self.rel_pred_threshold)) 183 | 184 | for rec in tqdm.tqdm(test_recs, "predicting", disable=quiet): 185 | outrec = { 186 | "input_record": rec, 187 | "predictions": [{"edit": cand, "pred": None, "score": None} for cand in rec["candidates"]], 188 | } 189 | out_recs.append(outrec) 190 | 191 | if len(outrec["predictions"]) == 0: 192 | continue 193 | 194 | candidate_texts = [self._edit_to_input_text(x["edit"]) for x in outrec["predictions"]] 195 | corpus = aries.util.gensim.InMemoryTextCorpus(candidate_texts, dictionary=self.dictionary) 196 | 197 | query_vec = self.tfidf_model[self.dictionary.doc2bow(corpus.preprocess_text(self._candidate_record_to_input_text(rec)))] 198 | 199 | candidate_vectors = self.bm25_model[list(corpus)] 200 | bm25_index = gensim.similarities.SparseMatrixSimilarity(None) 201 | bm25_index.normalize = True 202 | bm25_index.index = gensim.matutils.corpus2csc(candidate_vectors, num_docs=len(corpus), num_terms=len(self.dictionary), dtype=float).T 203 | 204 | cosine_similarities = bm25_index[query_vec].tolist() 205 | 206 | best_candidxs = np.argsort(cosine_similarities).tolist() 207 | best_candidx_score = 
cosine_similarities[best_candidxs[-1]] 208 | 209 | for candidx, predr in enumerate(outrec["predictions"]): 210 | predr["best_group_score"] = best_candidx_score 211 | predr["cosine_score"] = cosine_similarities[candidx] 212 | 213 | predr["pred"] = ( 214 | 1 215 | if cosine_similarities[candidx] > self.pred_threshold 216 | and cosine_similarities[candidx] >= (best_candidx_score - self.rel_pred_threshold) 217 | else 0 218 | ) 219 | predr["score"] = cosine_similarities[candidx] 220 | 221 | return out_recs 222 | -------------------------------------------------------------------------------- /aries/alignment/cross_encoder.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import logging 3 | import os 4 | import sys 5 | 6 | import datasets 7 | import numpy as np 8 | import torch 9 | import tqdm 10 | import transformers 11 | 12 | from aries.alignment.eval import AlignerEvalCallback 13 | from aries.util.edit import make_word_diff 14 | from aries.util.training import TrainLoggerCallback 15 | 16 | logger = logging.getLogger(__name__) 17 | 18 | 19 | class PairwiseTransformerAligner: 20 | def __init__(self, config, model, tokenizer): 21 | self.config = config 22 | self.model = model 23 | self.tokenizer = tokenizer 24 | self.max_length = self.config["max_seq_length"] 25 | 26 | @staticmethod 27 | def preprocess_fn(examples, tokenizer, max_length): 28 | model_inputs = tokenizer( 29 | examples["first_text"], 30 | max_length=max_length, 31 | padding=False, 32 | truncation=True, 33 | ) 34 | model_inputs["labels"] = examples["label"] 35 | return model_inputs 36 | 37 | def _candidate_record_to_input_text(self, rec): 38 | if self.config["query_input_format"] == "comment_only": 39 | return rec["review_comment"] 40 | elif self.config["query_input_format"] == "comment_with_canonical": 41 | return rec["review_comment"] + "\ncanonicalized: " + rec["canonical"]["canonicalized"] 42 | elif self.config["query_input_format"] == "reply_comment_or_extracted_comment": 43 | return rec.get("reply_comment_line", rec["review_comment"]) 44 | elif self.config["query_input_format"] == "reply_comment_or_extracted_comment_with_canonical": 45 | return rec.get("reply_comment_line", rec["review_comment"]) + "\ncanonicalized: " + rec["canonical"]["canonicalized"] 46 | elif self.config["query_input_format"] == "comment_with_context": 47 | comment_str = rec["review_comment"].strip() 48 | if rec.get("context_side", "none") == "left": 49 | comment_str = rec["context"].strip() + " " + comment_str 50 | else: 51 | comment_str = comment_str + " " + rec["context"].strip() 52 | 53 | return "review comment: " + comment_str 54 | raise ValueError("Unknown query_input_format {}".format(self.config["query_input_format"])) 55 | 56 | def _edit_to_input_text(self, edit): 57 | if self.config["edit_input_format"] == "added_tokens": 58 | return " ".join(edit.get_added_tokens()) 59 | if self.config["edit_input_format"] == "source_text": 60 | return edit.get_source_text() 61 | if self.config["edit_input_format"] == "target_text": 62 | return edit.get_target_text() 63 | if self.config["edit_input_format"] == "target_text_with_context": 64 | context = "context: none" 65 | if len(edit.target_idxs) != 0 and min(edit.target_idxs) != 0: 66 | context = "context: " + edit.texts2[min(edit.target_idxs) - 1] 67 | return edit.get_target_text() + "\n\n" + context 68 | elif self.config["edit_input_format"] == "diff": 69 | return make_word_diff( 70 | edit.get_source_text(), 71 | edit.get_target_text(), 72 | 
color_format="none", 73 | ) 74 | raise ValueError("Unknown edit_input_format {}".format(self.config["edit_input_format"])) 75 | 76 | def _make_example_for_rec_edit(self, rec, edit, label=None): 77 | query_text = self._candidate_record_to_input_text(rec) 78 | edit_text = self._edit_to_input_text(edit) 79 | return { 80 | "doc_id": rec["doc_id"], 81 | "source_pdf_id": rec["source_pdf_id"], 82 | "target_pdf_id": rec["target_pdf_id"], 83 | "review_comment": rec["review_comment"], 84 | "first_text": "review comment: {}\n\nparagraph: {}".format(query_text, edit_text), 85 | "label": label, 86 | } 87 | 88 | def _make_dataset(self, recs, name="dataset", shuffle=False): 89 | if isinstance(recs, dict): 90 | recs = list(recs.values()) 91 | exs = [] 92 | 93 | for rec in recs: 94 | edit_with_labels = [] 95 | edit_with_labels.extend([(x, 1) for x in rec["positives"]]) 96 | edit_with_labels.extend([(x, 0) for x in rec["negatives"]]) 97 | for edit, label in edit_with_labels: 98 | exs.append(self._make_example_for_rec_edit(rec, edit, label=label)) 99 | 100 | tmp = {k: [] for k in exs[0].keys()} 101 | for ex in exs: 102 | for k, v in ex.items(): 103 | tmp[k].append(v) 104 | dset = datasets.Dataset.from_dict(tmp) 105 | 106 | if shuffle: 107 | dset = dset.shuffle() 108 | 109 | dset = dset.map( 110 | functools.partial(PairwiseTransformerAligner.preprocess_fn, tokenizer=self.tokenizer, max_length=self.max_length), 111 | batched=True, 112 | num_proc=4, 113 | load_from_cache_file=False, 114 | desc="Processing {}".format(name), 115 | ) 116 | return dset 117 | 118 | def train(self, train_recs, dev_recs): 119 | if len(train_recs) == 0: 120 | raise ValueError("Got empty train_recs") 121 | if len(dev_recs) == 0: 122 | raise ValueError("Got empty dev_recs") 123 | 124 | training_args_dict = transformers.TrainingArguments(output_dir=self.config["output_dir"], log_level="passive").to_dict() 125 | training_args_dict.update(self.config.get("training_args", dict())) 126 | training_args = transformers.HfArgumentParser(transformers.TrainingArguments).parse_dict(training_args_dict)[0] 127 | 128 | self.rng = np.random.default_rng(self.config["seed"]) 129 | for rec in train_recs: 130 | rec["negatives"] = [x for x in rec["negatives"] if x.is_full_addition()] 131 | train_dset = self._make_dataset(train_recs, shuffle=True) 132 | 133 | self.rng = np.random.default_rng(self.config["seed"]) 134 | dev_dset = self._make_dataset(dev_recs) 135 | 136 | logger.info("{} | {}".format(self.tokenizer.decode(train_dset["input_ids"][0]), self.tokenizer.decode(train_dset["labels"][0]))) 137 | 138 | data_collator = transformers.DataCollatorWithPadding( 139 | self.tokenizer, 140 | pad_to_multiple_of=None, 141 | ) 142 | 143 | model_selector_callback = AlignerEvalCallback( 144 | self.config, 145 | self, 146 | dev_recs, 147 | model_selection_metric_fn=lambda x: x["optimal_f1"], 148 | ) 149 | 150 | # TODO: Make training args configurable from model_config 151 | trainer = transformers.Trainer( 152 | model=self.model, 153 | args=training_args, 154 | train_dataset=train_dset, 155 | eval_dataset=dev_dset, 156 | tokenizer=self.tokenizer, 157 | data_collator=data_collator, 158 | callbacks=[model_selector_callback, TrainLoggerCallback(logger)], 159 | compute_metrics=None, 160 | ) 161 | 162 | _ = trainer.train() 163 | 164 | self.model.load_state_dict(model_selector_callback._best_model_state) 165 | self.model.save_pretrained(os.path.join(self.config["output_dir"], "ptmodel")) 166 | self.tokenizer.save_pretrained(os.path.join(self.config["output_dir"], "ptmodel")) 
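# Example (untested sketch; the model path and the record values below are illustrative,
# not fixed by this repo): after train() completes, the dev-selected best checkpoint is
# saved to <output_dir>/ptmodel and can be reloaded for inference roughly like this:
#
#   model = transformers.AutoModelForSequenceClassification.from_pretrained("out/ptmodel")
#   tok = transformers.AutoTokenizer.from_pretrained("out/ptmodel")
#   aligner = PairwiseTransformerAligner(
#       {"max_seq_length": 512, "query_input_format": "comment_only", "edit_input_format": "diff"},
#       model,
#       tok,
#   )
#   results = aligner.predict_many(
#       [{"doc_id": doc_id, "source_pdf_id": pdf1, "target_pdf_id": pdf2,
#         "review_comment": comment_text, "candidates": candidate_edits}]
#   )
#   # each entry of results[0]["predictions"] carries "pred" (0/1), "score", and "logits" for one candidate edit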
167 | 168 | def predict_many(self, test_recs): 169 | was_training = self.model.training 170 | self.model.eval() 171 | 172 | out_recs = [] 173 | with tqdm.trange(sum(len(x["candidates"]) for x in test_recs), miniters=1, desc="{}.predict_many".format(self.__class__.__name__)) as pbar: 174 | with torch.no_grad(): 175 | for rec in test_recs: 176 | outrec = { 177 | "input_record": rec, 178 | "predictions": [{"edit": cand, "pred": None, "score": None} for cand in rec["candidates"]], 179 | } 180 | 181 | out_recs.append(outrec) 182 | 183 | for pred_rec in outrec["predictions"]: 184 | tensors = self.tokenizer( 185 | self._make_example_for_rec_edit(rec, pred_rec["edit"])["first_text"], 186 | max_length=self.max_length, 187 | padding=False, 188 | truncation=True, 189 | ) 190 | out = self.model( 191 | input_ids=torch.tensor(tensors["input_ids"], device=self.model.device, dtype=torch.long).unsqueeze(0), 192 | attention_mask=torch.tensor(tensors["attention_mask"], device=self.model.device, dtype=torch.long).unsqueeze(0), 193 | ) 194 | pred_rec["pred"] = torch.argmax(out.logits, dim=-1)[0].item() 195 | pred_rec["score"] = torch.nn.functional.softmax(out.logits, dim=-1)[0].tolist()[1] 196 | pred_rec["logits"] = [out.logits[0][0].item(), out.logits[0][1].item()] 197 | 198 | pbar.update(1) 199 | 200 | self.model.train(was_training) 201 | 202 | return out_recs 203 | -------------------------------------------------------------------------------- /aries/alignment/eval.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import json 3 | import logging 4 | import os 5 | import sys 6 | 7 | import numpy as np 8 | import sklearn.exceptions 9 | import sklearn.metrics 10 | import transformers 11 | 12 | from aries.util.data import index_by 13 | from aries.util.logging import pprint_metrics 14 | 15 | logger = logging.getLogger(__name__) 16 | 17 | 18 | class AlignerEvalCallback(transformers.TrainerCallback): 19 | def __init__(self, config, model, eval_records, model_selection_metric_fn=None, model_to_save=None): 20 | self.config = config 21 | self.model = model 22 | self.model_to_save = model_to_save or model.model 23 | self.eval_records = eval_records 24 | # self.eval_precached_dataset = self.model._make_dataset(self.eval_records) 25 | self.eval_precached_dataset = None 26 | 27 | self.model_selection_metric_fn = model_selection_metric_fn 28 | if isinstance(model_selection_metric_fn, str): 29 | self.model_selection_metric_fn = lambda x: x[model_selection_metric_fn] 30 | 31 | self._best_metric_val = float("-inf") 32 | self._best_model_state = None 33 | 34 | @staticmethod 35 | def _clone_cpu_model_state_dict(model): 36 | return collections.OrderedDict((k, v.clone().cpu().detach()) for k, v in model.state_dict().items()) 37 | 38 | def on_evaluate(self, args, state, control, **kwargs): 39 | metrics, all_results, _ = do_model_eval(self.model, self.eval_records, eval_precached_dataset=self.eval_precached_dataset) 40 | 41 | if self.config.get("write_examples_on_eval", False): 42 | with open(os.path.join(self.config["output_dir"], "{}_inferences.jsonl".format("tmp_mid_eval")), "w") as f: 43 | for res in all_results: 44 | f.write(json.dumps(res) + "\n") 45 | 46 | pprint_metrics(metrics, logger, name="dev (mid-train)") 47 | metrics["global_step"] = state.global_step 48 | metrics["epoch"] = state.epoch 49 | metrics["total_flos"] = state.total_flos 50 | with open(os.path.join(self.config["output_dir"], "{}_metrics.jsonl".format("mid_eval")), "a") as f: 51 | 
f.write(json.dumps(metrics) + "\n") 52 | 53 | if self.model_selection_metric_fn is not None: 54 | metric_val = self.model_selection_metric_fn(metrics) 55 | if metric_val > self._best_metric_val: 56 | logger.info( 57 | "Got new best model at global step {} (epoch {}, {:0.2f} TFLOs)".format(state.global_step, state.epoch, state.total_flos / 1e12) 58 | ) 59 | state.best_metric = metric_val 60 | self._best_metric_val = metric_val 61 | self._best_model_state = AlignerEvalCallback._clone_cpu_model_state_dict(self.model_to_save) 62 | 63 | 64 | def _get_possible_optimal_thresholds(all_candidates): 65 | return _get_possible_optimal_thresholds_smart(all_candidates) 66 | 67 | 68 | def _get_possible_optimal_thresholds_smart(all_candidates): 69 | """Gets the thresholds that have a chance of maximizing f1; that is, 70 | thresholds at positive-negative boundaries (in descending order of score) 71 | and thresholds at the extremes.""" 72 | # Sort by descending score 73 | all_scored_candidates = sorted([x for x in all_candidates if x["score"] is not None], key=lambda x: x["score"], reverse=True) 74 | 75 | if len(all_scored_candidates) == 0: 76 | return [] 77 | 78 | # return list(range(min(x['score'] for x in all_scored_candidates), max(x['score'] for x in all_scored_candidates), 0.05)) 79 | 80 | # The possible thresholds should be the midpoints between each pos-label score and the next-lowest-scoring point, plus the endpoints 81 | possible_thresholds = [] 82 | possible_thresholds.append(all_scored_candidates[0]["score"] + 0.0001) 83 | possible_thresholds.append(all_scored_candidates[-1]["score"] - 0.0001) 84 | # We only need to consider pos-neg boundaries; if there is a run of 85 | # consecutive positive examples, it is never worse to include all of them. 86 | for candidx in range(len(all_scored_candidates)): 87 | cand0 = all_scored_candidates[candidx - 1] 88 | cand1 = all_scored_candidates[candidx] 89 | if cand0["label"] == 1 and cand1["label"] == 0: 90 | thresh = (cand0["score"] + cand1["score"]) / 2 91 | if thresh not in possible_thresholds: 92 | possible_thresholds.append(thresh) 93 | 94 | return possible_thresholds 95 | 96 | 97 | def get_pred_labels_for_threshold(thresh, all_candidates, rel_thresh=0.2): 98 | pred_labels = [] 99 | for x in all_candidates: 100 | if "score" not in x or x.get("base_pred", None) == 0: 101 | pred_labels.append(x["pred"]) 102 | elif "best_group_score" in x: 103 | pred_labels.append(1 if x["score"] > thresh and x["score"] >= (x["best_group_score"] - rel_thresh) else 0) 104 | else: 105 | pred_labels.append(1 if x["score"] > thresh else 0) 106 | return pred_labels 107 | 108 | 109 | def tune_optimal_f1_threshold(all_candidates): 110 | """Find the absolute decision threshold that maximizes F1.""" 111 | if len(all_candidates) == 0: 112 | return None, [] 113 | 114 | possible_thresholds = _get_possible_optimal_thresholds(all_candidates) 115 | 116 | if len(possible_thresholds) == 0: 117 | logger.info("Couldn't get optimal threshold because there were no scores on positive examples") 118 | return None, [x["pred"] for x in all_candidates] 119 | possible_thresholds = sorted(possible_thresholds) 120 | 121 | true_labels = [x["label"] for x in all_candidates] 122 | best = (-float("inf"), None, None) 123 | for thresh in possible_thresholds: 124 | pred_labels = get_pred_labels_for_threshold(thresh, all_candidates) 125 | 126 | f1 = sklearn.metrics.f1_score(true_labels, pred_labels) 127 | if f1 > best[0]: 128 | best = (f1, thresh, pred_labels) 129 | 130 | return best[1], best[2] 131 | 132 | 133 
| def full_tune_optimal_thresholds(all_candidates, min_recall=None, num_abs_thresholds=100, num_rel_thresholds=100, abs_thresh=None, rel_thresh=None): 134 | """Find the combination of absolute and relative decision thresholds that 135 | maximize F1. If abs_thresh or rel_thresh are set, only the other one will 136 | be tuned. However, note that this is less efficient and precise than 137 | tune_optimal_f1_threshold if only the absolute threshold needs to be tuned. 138 | To tune the relative threshold, records in all_candidates must have 139 | a "best_group_score" field set.""" 140 | 141 | if len(all_candidates) == 0: 142 | return None, None, [] 143 | 144 | if abs_thresh is not None and rel_thresh is not None: 145 | raise ValueError("Cannot specify both abs_thresh and rel_thresh") 146 | 147 | possible_abs_threshs = [abs_thresh] 148 | possible_rel_threshs = [rel_thresh] 149 | 150 | if abs_thresh is None: 151 | # First, find the maximum pred_threshold that achieves the minimum recall 152 | max_threshold = max(x["score"] for x in all_candidates) 153 | if min_recall > 0: 154 | # We can be efficient by just going down the list in score order 155 | # until we have enough positives (min_recall 156 | # * num positives in all_candidates) 157 | all_candidates.sort(key=lambda x: x["score"], reverse=True) 158 | num_positives = sum(x["label"] == 1 for x in all_candidates) 159 | num_positives_needed = min_recall * num_positives 160 | num_positives_found = 0 161 | for idx, x in enumerate(all_candidates): 162 | if x["label"] == 1: 163 | num_positives_found += 1 164 | if num_positives_found >= num_positives_needed: 165 | max_threshold = x["score"] 166 | break 167 | if num_positives_found < num_positives_needed: 168 | logger.warning("Unable to find enough positives to achieve tuning_minimum_recall of {}".format(min_recall)) 169 | # We're done; thresholds must be low enough to predict positive for everything 170 | min_score = min(x["score"] for x in all_candidates) 171 | max_score = max(x["score"] for x in all_candidates) 172 | return min_score, (max_score - min_score), [1] * len(all_candidates) 173 | possible_abs_threshs = np.linspace(0, max_threshold, num_abs_thresholds) 174 | 175 | if rel_thresh is None: 176 | max_rel_pred_threshold = max(x["score"] for x in all_candidates) - max_threshold 177 | # Iterate rel thresholds from high to low; if we miss the recall target 178 | # we can exit early 179 | possible_rel_threshs = np.linspace(max_rel_pred_threshold, 0, num_rel_thresholds) 180 | 181 | # Now find the combination of pred_threshold and rel_pred_threshold 182 | # that maximizes f1 while achieving the minimum recall 183 | best_f1 = 0 184 | best_thresholds = (0, 0) 185 | best_pred_labels = [] 186 | for pred_threshold in possible_abs_threshs: 187 | for rel_pred_threshold in possible_rel_threshs: 188 | labels = [x["label"] for x in all_candidates] 189 | pred_labels = get_pred_labels_for_threshold(pred_threshold, all_candidates, rel_pred_threshold) 190 | 191 | recall = sklearn.metrics.recall_score(labels, pred_labels) 192 | if recall < min_recall: 193 | break 194 | 195 | f1 = sklearn.metrics.f1_score(labels, pred_labels) 196 | if f1 > best_f1: 197 | best_f1 = f1 198 | best_thresholds = (pred_threshold, rel_pred_threshold) 199 | best_pred_labels = pred_labels 200 | 201 | return best_thresholds[0], best_thresholds[1], best_pred_labels 202 | 203 | 204 | def group_macro_prf1(labels, preds, group_ids, include_empty=False): 205 | grouped_comments = {gid: [] for gid in set(group_ids)} 206 | if not (len(labels) == 
len(preds) == len(group_ids)): 207 | raise ValueError("need len(labels) ({}) == len(preds) ({}) == len(group_ids) ({})".format(len(labels), len(preds), len(group_ids))) 208 | 209 | if len(labels) == 0: 210 | return float("nan"), float("nan"), float("nan"), float("nan") 211 | 212 | for idx in range(len(labels)): 213 | grouped_comments[group_ids[idx]].append((labels[idx], preds[idx])) 214 | group_prf1s = [] 215 | group_ps = [] 216 | group_rs = [] 217 | group_f1s = [] 218 | group_ems = [] 219 | for gid, group in sorted(grouped_comments.items()): 220 | labels, preds = list(zip(*group)) 221 | if any(x == 1 for x in preds): 222 | p = sklearn.metrics.precision_score(labels, preds) 223 | group_ps.append(p) 224 | else: 225 | p = 1 226 | if include_empty: 227 | group_ps.append(p) 228 | 229 | if any(x == 1 for x in labels): 230 | r = sklearn.metrics.recall_score(labels, preds) 231 | group_rs.append(r) 232 | else: 233 | r = 1 234 | if include_empty: 235 | group_rs.append(r) 236 | 237 | if any(x == 1 for x in preds) or any(x == 1 for x in labels): 238 | f1 = sklearn.metrics.f1_score(labels, preds, zero_division="warn") 239 | group_f1s.append(f1) 240 | else: 241 | f1 = 1 242 | if include_empty: 243 | group_f1s.append(f1) 244 | 245 | group_ems.append(1 if all(x == y for x, y in zip(labels, preds)) else 0) 246 | 247 | group_prf1s.append( 248 | ( 249 | p, 250 | r, 251 | sklearn.metrics.f1_score(labels, preds, zero_division=1), 252 | ) 253 | ) 254 | 255 | if include_empty: 256 | pmean, rmean, f1mean = np.mean(np.array(group_prf1s), axis=0).tolist() 257 | else: 258 | pmean = np.mean(group_ps).tolist() 259 | rmean = np.mean(group_rs).tolist() 260 | f1mean = np.mean(group_f1s).tolist() 261 | 262 | return pmean, rmean, f1mean, np.mean(group_ems).tolist() 263 | 264 | 265 | def do_model_eval(model, eval_records, eval_precached_dataset=None, custom_decision_threshold=None, custom_threshold_name="custom_threshold"): 266 | for rec in eval_records: 267 | rec["candidates"] = rec["positives"] + rec["negatives"] + rec.get("unknowns", []) 268 | rec["candidate_labels"] = [1] * len(rec["positives"]) + [0] * len(rec["negatives"]) + [None] * len(rec.get("unknowns", [])) 269 | all_results = model.predict_many(eval_records) 270 | 271 | if len(all_results) != len(eval_records): 272 | raise ValueError("Number of results ({}) does not match number of records ({})".format(len(all_results), len(eval_records))) 273 | 274 | comment2id = dict() 275 | all_candidates = [] 276 | candidate_comment_ids = [] 277 | for rec in all_results: 278 | if rec["input_record"]["review_comment"] not in comment2id: 279 | comment2id[rec["input_record"]["review_comment"]] = len(comment2id) 280 | 281 | for idx, ex in enumerate(rec["predictions"]): 282 | ex["label"] = rec["input_record"]["candidate_labels"][idx] 283 | all_candidates.append(ex) 284 | candidate_comment_ids.append(comment2id[rec["input_record"]["review_comment"]]) 285 | 286 | true_labels = [x["label"] for x in all_candidates] 287 | 288 | def metrics_for_predictions(pred_labels, prefix=""): 289 | nonlocal true_labels 290 | _, _, _, exactmatch = group_macro_prf1(true_labels, pred_labels, candidate_comment_ids, include_empty=False) 291 | ie_macro_p, ie_macro_r, ie_macro_f1, _ = group_macro_prf1(true_labels, pred_labels, candidate_comment_ids, include_empty=True) 292 | metrics = { 293 | "accuracy": sklearn.metrics.accuracy_score(true_labels, pred_labels), 294 | "precision": sklearn.metrics.precision_score(true_labels, pred_labels), 295 | "recall":
sklearn.metrics.recall_score(true_labels, pred_labels), 296 | "f1": sklearn.metrics.f1_score(true_labels, pred_labels), 297 | "macro_precision": ie_macro_p, 298 | "macro_recall": ie_macro_r, 299 | "macro_f1": ie_macro_f1, 300 | "exact_match": exactmatch, 301 | "n_pred_positive": sum(1 for x in pred_labels if x == 1), 302 | } 303 | 304 | return {(prefix + k): v for k, v in metrics.items()} 305 | 306 | metrics = dict() 307 | pred_labels = [x["pred"] for x in all_candidates] 308 | metrics.update(metrics_for_predictions(pred_labels, prefix="")) 309 | 310 | optimal_threshold, optimal_pred_labels = tune_optimal_f1_threshold(all_candidates) 311 | if optimal_threshold is not None: 312 | logger.info("Got optimal threshold: {:0.3f}".format(optimal_threshold)) 313 | metrics.update(metrics_for_predictions(optimal_pred_labels, prefix="optimal_")) 314 | metrics["optimal_decision_threshold"] = optimal_threshold 315 | 316 | if custom_decision_threshold is not None: 317 | custom_pred_labels = get_pred_labels_for_threshold(custom_decision_threshold, all_candidates) 318 | metrics.update(metrics_for_predictions(custom_pred_labels, prefix=(custom_threshold_name + "_"))) 319 | metrics[(custom_threshold_name + "_decision_threshold")] = custom_decision_threshold 320 | 321 | metrics.update( 322 | { 323 | "n_true_positive": sum(1 for x in all_candidates if x["label"] == 1), 324 | "n_candidates": len(all_candidates), 325 | "n_comments": len(eval_records), 326 | } 327 | ) 328 | 329 | serializable_results = [] 330 | for res in all_results: 331 | sres = dict() 332 | for k, v in res["input_record"].items(): 333 | try: 334 | json.dumps(v) 335 | sres[k] = v 336 | except TypeError: 337 | pass 338 | 339 | cands = [] 340 | for pred_rec in res["predictions"]: 341 | edit = pred_rec["edit"] 342 | scand = {k: v for k, v in pred_rec.items() if k not in ["edit"]} 343 | scand["edit_source_idxs"] = edit.source_idxs 344 | scand["edit_target_idxs"] = edit.target_idxs 345 | if any(x >= len(edit.doc_edits.s2orc2["pdf_parse"]["body_text"]) and x != 9999 for x in edit.target_idxs): 346 | raise KeyError( 347 | "Out of bounds! 
{} {} {} {}".format( 348 | edit.doc_edits.s2orc2["paper_id"], 349 | len(edit.doc_edits.s2orc2["pdf_parse"]["body_text"]), 350 | str(edit.target_idxs), 351 | edit.get_target_text(), 352 | ) 353 | ) 354 | scand["edit_source_pdf_id"] = edit.doc_edits.s2orc1["paper_id"] 355 | scand["edit_target_pdf_id"] = edit.doc_edits.s2orc2["paper_id"] 356 | cands.append(scand) 357 | sres["candidates"] = cands 358 | 359 | serializable_results.append(sres) 360 | 361 | return ( 362 | metrics, 363 | serializable_results, 364 | # pair_results, 365 | index_by( 366 | serializable_results, 367 | lambda x: (x["doc_id"], x["source_pdf_id"], x["target_pdf_id"], x["review_comment"]), 368 | ), 369 | ) 370 | -------------------------------------------------------------------------------- /aries/alignment/gpt.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import json 3 | import logging 4 | import os 5 | import sys 6 | 7 | import tqdm 8 | 9 | from aries.util.data import index_by 10 | from aries.util.edit import make_word_diff 11 | from aries.util.gpt3 import Gpt3CacheClient 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | 16 | class GptChatAligner: 17 | def __init__(self, config): 18 | self.config = config 19 | self.system_prompt = self.config["gpt_system_prompt"] 20 | self.prompt_template = self.config["gpt_prompt_template"] 21 | self.model_name = self.config["gpt_model"] 22 | self.max_length = self.config["gpt_max_length"] 23 | self.cache_db_path = self.config["cache_db_path"] 24 | 25 | def train(self, train_recs, dev_recs): 26 | logger.warning("GPT doesn't train; ignoring call to train()") 27 | 28 | def _predict_one(self, comment, edit, gptcli): 29 | tags = { 30 | "{{review_comment}}": comment, 31 | "{{target_paragraph}}": edit.get_target_text(), 32 | "{{source_paragraph}}": edit.get_source_text(), 33 | "{{diff_paragraph}}": make_word_diff( 34 | edit.get_source_text(), 35 | edit.get_target_text(), 36 | color_format="none", 37 | ), 38 | } 39 | msg = self.prompt_template 40 | for k, v in sorted(tags.items()): 41 | msg = msg.replace(k, v) 42 | messages = [ 43 | {"role": "system", "content": self.system_prompt}, 44 | {"role": "user", "content": msg}, 45 | ] 46 | resp = gptcli.chat_completion( 47 | model=self.model_name, 48 | messages=messages, 49 | temperature=0.0, 50 | max_tokens=self.max_length, 51 | top_p=1, 52 | frequency_penalty=0, 53 | presence_penalty=0, 54 | ) 55 | result_text = resp["choices"][0]["message"]["content"] 56 | result_words = set(result_text.lower().replace(".", " ").replace(",", " ").replace("\n", " ").replace('"', " ").replace("'", " ").split(" ")) 57 | # Extract yes/no answer from response text 58 | has_yes = "yes" in result_words or "answer=yes" in result_words 59 | has_no = "no" in result_words or "answer=no" in result_words 60 | pred = None 61 | if has_yes and has_no: 62 | pred = None 63 | raise ValueError("Got both yes and no in response") 64 | elif has_yes: 65 | pred = 1 66 | elif has_no: 67 | pred = 0 68 | else: 69 | logger.error("Bad response: {}".format(result_text)) 70 | raise ValueError("Got neither yes nor no in response") 71 | return pred, resp 72 | 73 | def predict_many(self, test_recs): 74 | out_recs = [] 75 | 76 | total_tokens, uncached_total_tokens = 0, 0 77 | loopname = "{}.predict_many".format(self.__class__.__name__) 78 | with tqdm.trange(sum(len(x["candidates"]) for x in test_recs), miniters=1, desc=loopname) as pbar: 79 | with Gpt3CacheClient(self.cache_db_path) as gptcli: 80 | for rec in test_recs: 81 | 
outrec = { 82 | "input_record": rec, 83 | "predictions": [{"edit": cand, "pred": None, "score": None} for cand in rec["candidates"]], 84 | } 85 | out_recs.append(outrec) 86 | 87 | for pred_rec in outrec["predictions"]: 88 | pred_label, resp = self._predict_one(rec["review_comment"], pred_rec["edit"], gptcli) 89 | total_tokens += resp["usage"]["total_tokens"] 90 | uncached_total_tokens += resp["usage"]["uncached_total_tokens"] 91 | 92 | pred_rec["pred"] = pred_label 93 | pred_rec["score"] = pred_label 94 | 95 | pbar.set_description(f"{loopname} tt={total_tokens} utt={uncached_total_tokens}", refresh=False) 96 | pbar.update(1) 97 | 98 | return out_recs 99 | 100 | 101 | class GptChatFullPaperAligner: 102 | def __init__(self, config): 103 | self.config = config 104 | self.system_prompt = self.config["gpt_system_prompt"] 105 | self.prompt_template = self.config["gpt_prompt_template"] 106 | self.model_name = self.config["gpt_model"] 107 | self.max_length = self.config["gpt_max_length"] 108 | self.cache_db_path = self.config["cache_db_path"] 109 | self.output_dir = self.config.get("output_dir", None) 110 | self.max_response_length = 500 111 | 112 | self.raw_responses = [] 113 | 114 | def train(self, train_recs, dev_recs): 115 | logger.warning("GPT doesn't train; ignoring call to train()") 116 | 117 | def _make_chunked_paper_diff(self, doc_edits, chunk_size, gptcli): 118 | full_diff_string, edits_by_id = doc_edits.make_paper_diff_string( 119 | color_format="none", 120 | print_ids_only=True, 121 | return_edit_ids=True, 122 | ) 123 | 124 | para_chunks = full_diff_string.split("\n\n") 125 | 126 | diff_chunks = [] 127 | cur_chunk = [] 128 | cur_chunk_len = 0 129 | # Note: we don't account for individual paras being bigger than 130 | # chunk_size; that probably never happens anyway 131 | for para_chunk in para_chunks: 132 | # Add 2 for the stripped \n\n 133 | new_chunk_len = gptcli.estimate_num_tokens(para_chunk, self.model_name) + 2 134 | if cur_chunk_len + new_chunk_len > chunk_size: 135 | diff_chunks.append("\n\n".join(cur_chunk)) 136 | cur_chunk = [] 137 | cur_chunk_len = 0 138 | cur_chunk.append(para_chunk) 139 | cur_chunk_len += new_chunk_len 140 | 141 | if len(cur_chunk) != 0: 142 | diff_chunks.append("\n\n".join(cur_chunk)) 143 | 144 | return diff_chunks, edits_by_id 145 | 146 | def _make_comments_text_blob(self, recs): 147 | comments_text_blob = "" 148 | for idx, comment in enumerate(recs): 149 | comments_text_blob += comment.replace("\n", " ") + "\ncomment id: {}\n\n".format(idx) 150 | return comments_text_blob 151 | 152 | def _predict_one_doc(self, doc_edits, comments, gptcli): 153 | comments_text = self._make_comments_text_blob(comments) 154 | 155 | base_length = gptcli.estimate_num_tokens(self.prompt_template, self.model_name) + gptcli.estimate_num_tokens( 156 | self.system_prompt, self.model_name 157 | ) 158 | if "{{review_comments}}" in self.prompt_template: 159 | base_length += gptcli.estimate_num_tokens(comments_text, self.model_name) 160 | chunk_size = self.max_length - base_length - self.max_response_length 161 | 162 | diff_chunks, edits_by_id = self._make_chunked_paper_diff(doc_edits, chunk_size=chunk_size, gptcli=gptcli) 163 | 164 | all_response_lines_by_comment = {idx: [] for idx in range(len(comments))} 165 | total_tokens, uncached_total_tokens = 0, 0 166 | for chunk in diff_chunks: 167 | tags = { 168 | "{{review_comments}}": comments_text, 169 | "{{paper_diff_chunk}}": chunk, 170 | } 171 | msg = self.prompt_template 172 | for k, v in sorted(tags.items()): 173 | msg = 
msg.replace(k, v) 174 | messages = [ 175 | {"role": "system", "content": self.system_prompt}, 176 | {"role": "user", "content": msg}, 177 | ] 178 | if base_length + gptcli.estimate_num_tokens(chunk, self.model_name) + self.max_response_length > 8150: 179 | print(base_length, gptcli.estimate_num_tokens(chunk, self.model_name), self.max_response_length) 180 | print() 181 | try: 182 | resp = gptcli.chat_completion( 183 | model=self.model_name, 184 | messages=messages, 185 | temperature=0.0, 186 | # max_tokens=self.max_length, 187 | max_tokens=self.max_response_length, 188 | top_p=1, 189 | frequency_penalty=0, 190 | presence_penalty=0, 191 | ) 192 | except Exception as e: 193 | breakpoint() 194 | print(e) 195 | total_tokens += resp["usage"]["total_tokens"] 196 | uncached_total_tokens += resp["usage"]["uncached_total_tokens"] 197 | result_text = resp["choices"][0]["message"]["content"] 198 | 199 | self.raw_responses.append( 200 | { 201 | # "doc_id": doc_id, 202 | "source_pdf_id": doc_edits.s2orc1["paper_id"], 203 | "target_pdf_id": doc_edits.s2orc2["paper_id"], 204 | "comments": comments, 205 | "comments_text": comments_text, 206 | "response_text": result_text, 207 | } 208 | ) 209 | 210 | for line in result_text.split("\n"): 211 | # Imperfect but good-enough detection of JSON lines 212 | if not line.startswith("{"): 213 | continue 214 | 215 | # Hacky; fix some specific known failures 216 | line = line.replace(" \\phi", " \\\\phi") 217 | try: 218 | obj = json.loads(line) 219 | except json.JSONDecodeError as e: 220 | logger.error("Failed to parse JSON line: {}".format(line)) 221 | # raise e 222 | continue 223 | all_response_lines_by_comment[obj["comment_id"]].append(obj) 224 | 225 | results = [] 226 | for comment_id, resps in all_response_lines_by_comment.items(): 227 | # Ignore the abstract (9999) since it isn't diffed in DocEdits 228 | all_edit_ids = sorted(set(itertools.chain(*[x["edit_ids"] for x in resps])) - {9999}) 229 | results.append( 230 | { 231 | "review_comment": comments[comment_id], 232 | "predicted_positive_edits": [ 233 | { 234 | "source_idxs": edits_by_id[x].source_idxs, 235 | "target_idxs": edits_by_id[x].target_idxs, 236 | } 237 | for x in all_edit_ids 238 | ], 239 | } 240 | ) 241 | 242 | usage_info = { 243 | "total_tokens": total_tokens, 244 | "uncached_total_tokens": uncached_total_tokens, 245 | } 246 | return results, usage_info 247 | 248 | def predict_many(self, test_recs): 249 | out_recs = [] 250 | 251 | # We need to run the model for the pdf pair of each *candidate*, since 252 | # it is possible to have candidates sampled from other documents than 253 | # the one the comment was for. 
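        # (Concretely: the code below groups comments by (source_pdf_id, target_pdf_id),
        # calls _predict_one_doc once per pair, and then joins the per-pair results
        # back onto each comment's candidate edits.)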
254 | comment_pdf_pairs = [] 255 | for rec in test_recs: 256 | for edit in rec["candidates"]: 257 | comment_pdf_pairs.append( 258 | { 259 | "comment": rec["review_comment"], 260 | "pdf_pair": (edit.doc_edits.s2orc1["paper_id"], edit.doc_edits.s2orc2["paper_id"]), 261 | "doc_edits": edit.doc_edits, 262 | } 263 | ) 264 | # For consistency, include comments in the many-to-many alignment 265 | # even when no candidates are given 266 | if "source_pdf_id" in rec and "target_pdf_id" in rec: 267 | comment_pdf_pairs.append( 268 | { 269 | "comment": rec["review_comment"], 270 | "pdf_pair": (rec["source_pdf_id"], rec["target_pdf_id"]), 271 | "doc_edits": None, 272 | } 273 | ) 274 | 275 | comment_pdf_pairs_by_pdf = index_by(comment_pdf_pairs, "pdf_pair") 276 | 277 | total_tokens, uncached_total_tokens = 0, 0 278 | with Gpt3CacheClient(self.cache_db_path) as gptcli: 279 | loopname = "{}.predict_many".format(self.__class__.__name__) 280 | predictions_by_pdf = dict() 281 | pbar = tqdm.tqdm(comment_pdf_pairs_by_pdf.items(), miniters=1, desc=loopname) 282 | for pdf_pair, comment_recs in pbar: 283 | if all(x["doc_edits"] is None for x in comment_recs): 284 | # No candidates for this pdf pair, so skip it 285 | continue 286 | predictions_by_pdf[pdf_pair], token_usage = self._predict_one_doc( 287 | [x for x in comment_recs if x["doc_edits"] is not None][0]["doc_edits"], 288 | sorted(set([x["comment"] for x in comment_recs])), 289 | gptcli, 290 | ) 291 | predictions_by_pdf[pdf_pair] = index_by(predictions_by_pdf[pdf_pair], "review_comment", one_to_one=True) 292 | 293 | total_tokens += token_usage["total_tokens"] 294 | uncached_total_tokens += token_usage["uncached_total_tokens"] 295 | pbar.set_description(f"{loopname} tt={total_tokens} utt={uncached_total_tokens}", refresh=False) 296 | 297 | for rec in test_recs: 298 | outrec = { 299 | "input_record": rec, 300 | "predictions": [{"edit": cand, "pred": None, "score": None} for cand in rec["candidates"]], 301 | } 302 | 303 | out_recs.append(outrec) 304 | 305 | for pred in outrec["predictions"]: 306 | pred_rec = predictions_by_pdf[(pred["edit"].doc_edits.s2orc1["paper_id"], pred["edit"].doc_edits.s2orc2["paper_id"])][ 307 | rec["review_comment"] 308 | ] 309 | pos_edits = [] if pred_rec is None else pred_rec["predicted_positive_edits"] 310 | pred_label = 0 311 | for edit in pos_edits: 312 | if (sorted(edit["source_idxs"]) == sorted(pred["edit"].source_idxs)) and ( 313 | sorted(edit["target_idxs"]) == sorted(pred["edit"].target_idxs) 314 | ): 315 | pred_label = 1 316 | break 317 | pred["pred"] = pred_label 318 | pred["score"] = pred_label 319 | 320 | if self.output_dir is not None: 321 | with open(os.path.join(self.output_dir, "raw_gpt_outputs.json"), "w") as f: 322 | json.dump(self.raw_responses, f) 323 | 324 | return out_recs 325 | -------------------------------------------------------------------------------- /aries/alignment/other.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | logger = logging.getLogger(__name__) 4 | 5 | 6 | class MultiStageAligner: 7 | def __init__(self, config, aligners): 8 | self.config = config 9 | self.aligners = aligners 10 | 11 | self.prune_candidates = config.get("prune_candidates", False) 12 | 13 | def train(self, train_recs, dev_recs): 14 | logger.info("Multi-stage aligner doesn't train; skipping...") 15 | 16 | def _update_candidate_scores(self, candidate): 17 | # Fill base_pred, pred, and score based on the stack of aligner predictions 18 | candidate["base_pred"] = None 
19 | candidate["pred"] = None 20 | candidate["score"] = None 21 | if len(candidate["predictions"]) == 0: 22 | return 23 | 24 | # If any aligner predicts 0, then the candidate's pred is 0. The 25 | # base_pred is 0 if any aligner other than the last one predicts 0 (1 otherwise). 26 | # The score is the final aligner's score. 27 | for pred_idx, pred_rec in enumerate(candidate["predictions"]): 28 | if pred_rec is None: 29 | continue 30 | if pred_rec["pred"] == 0: 31 | if pred_idx < len(candidate["predictions"]) - 1: 32 | candidate["base_pred"] = 0 33 | candidate["pred"] = 0 34 | elif pred_rec["pred"] == 1 and candidate["base_pred"] is None: 35 | if pred_idx < len(candidate["predictions"]) - 1: 36 | candidate["base_pred"] = 1 37 | candidate["pred"] = 1 38 | 39 | if candidate["predictions"][-1] is not None: 40 | candidate["score"] = candidate["predictions"][-1]["score"] 41 | 42 | def predict_many(self, *args, **kwargs): 43 | results = self._predict_many(*args, **kwargs) 44 | return results 45 | 46 | def _predict_many(self, test_recs): 47 | out_recs = [] 48 | 49 | for rec in test_recs: 50 | out_recs.append( 51 | { 52 | "input_record": rec, 53 | "predictions": [{"edit": x, "predictions": [], "base_pred": None, "pred": None, "score": None} for x in rec["candidates"]], 54 | } 55 | ) 56 | 57 | backmaps = [list(range(len(x["candidates"]))) for x in test_recs] 58 | 59 | # Don't modify the input test_recs if we need to prune 60 | cur_recs = test_recs 61 | if self.prune_candidates: 62 | cur_recs = [x.copy() for x in test_recs] 63 | for rec in cur_recs: 64 | rec["candidates"] = rec["candidates"].copy() 65 | 66 | pruned_idxs = [set() for x in test_recs] 67 | for aligner_idx, aligner in enumerate(self.aligners): 68 | logger.info(f"Running aligner {aligner_idx + 1} of {len(self.aligners)} ({aligner.__class__.__name__})") 69 | predictions = aligner.predict_many(cur_recs) 70 | 71 | # Update the corresponding prediction lists, keeping track of the 72 | # back-mappings from pruned candidates 73 | for recidx, rec in enumerate(predictions): 74 | for candidx, cand in enumerate(rec["predictions"]): 75 | out_cand = out_recs[recidx]["predictions"][backmaps[recidx][candidx]] 76 | 77 | # Hack: need to remove 'edit' to make the cands 78 | # JSON-serializable 79 | assert out_cand["edit"] == cand["edit"] 80 | del cand["edit"] 81 | 82 | out_cand["predictions"].append(cand) 83 | self._update_candidate_scores(out_cand) 84 | if out_cand["pred"] is None: 85 | breakpoint() 86 | print(out_cand["pred"]) 87 | 88 | if self.prune_candidates: 89 | # Append None to predictions for any candidates that were pruned 90 | # by previous aligners 91 | for recidx, rec in enumerate(out_recs): 92 | for candidx in pruned_idxs[recidx]: 93 | rec["predictions"][candidx]["predictions"].append(None) 94 | self._update_candidate_scores(rec["predictions"][candidx]) 95 | 96 | if aligner_idx < len(self.aligners) - 1: 97 | # Prune anything that was predicted to be 0 98 | candidates_to_prune = [] 99 | for recidx, rec in enumerate(predictions): 100 | for candidx, cand in enumerate(rec["predictions"]): 101 | if cand["pred"] == 0: 102 | candidates_to_prune.append((recidx, candidx)) 103 | 104 | # Reverse sort is important to ensure indices don't shift as we prune them 105 | for recidx, candidx in sorted(candidates_to_prune, key=lambda x: x[1], reverse=True): 106 | backmaps[recidx].pop(candidx) 107 | cur_recs[recidx]["candidates"].pop(candidx) 108 | pruned_idxs[recidx].add(candidx) 109 | 110 | return out_recs 111 | 
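Usage sketch for MultiStageAligner: it runs each aligner in order and, when prune_candidates is set, drops candidates that an earlier stage predicted as 0 before the next (typically more expensive) stage runs. The snippet below is illustrative only; the file paths, model name, and prompt text are placeholders, while the constructor signatures and config keys mirror the classes shown in this package.

    from aries.alignment.gpt import GptChatAligner
    from aries.alignment.other import MultiStageAligner
    from aries.alignment.precomputed import PrecomputedEditsAligner

    # Stage 1: cheap filter based on precomputed edit predictions (path is a placeholder).
    stage1 = PrecomputedEditsAligner({"precomputed_predictions_jsonl_path": "data/precomputed_preds.jsonl"})

    # Stage 2: GPT yes/no judgment on the candidates that survive pruning.
    stage2 = GptChatAligner({
        "gpt_system_prompt": "You decide whether a paper edit addresses a reviewer comment.",
        "gpt_prompt_template": "comment: {{review_comment}}\ndiff: {{diff_paragraph}}\nAnswer yes or no.",
        "gpt_model": "gpt-3.5-turbo",
        "gpt_max_length": 4096,
        "cache_db_path": "data/gpt_cache.sqlite",
    })

    aligner = MultiStageAligner({"prune_candidates": True}, [stage1, stage2])
    out_recs = aligner.predict_many(test_recs)  # test_recs: records with "doc_id", "review_comment", "candidates"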
-------------------------------------------------------------------------------- /aries/alignment/precomputed.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import os 4 | import sys 5 | 6 | from aries.util.data import index_by, openc 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | 11 | class PrecomputedEditsAligner: 12 | def __init__(self, config): 13 | self.config = config 14 | 15 | def train(self, train_recs, dev_recs): 16 | logger.warning("{} doesn't train; ignoring call to train()".format(self.__class__.__name__)) 17 | 18 | def predict_many(self, test_recs): 19 | out_recs = [] 20 | 21 | predictions_by_docid = dict() 22 | with openc(self.config["precomputed_predictions_jsonl_path"], "rt") as f: 23 | predictions_by_docid = index_by(map(json.loads, f), "doc_id") 24 | 25 | warned_docs = set() 26 | for rec in test_recs: 27 | outrec = { 28 | "input_record": rec, 29 | "predictions": [{"edit": cand, "pred": None, "score": None} for cand in rec["candidates"]], 30 | } 31 | out_recs.append(outrec) 32 | 33 | if rec["doc_id"] not in predictions_by_docid: 34 | if rec["doc_id"] not in warned_docs: 35 | logger.warning("missing prediction for doc: {}".format(rec["doc_id"])) 36 | warned_docs.add(rec["doc_id"]) 37 | for cand_rec in outrec["predictions"]: 38 | cand_rec["pred"] = 0 39 | cand_rec["score"] = 0 40 | continue 41 | 42 | pred_recs = predictions_by_docid[rec["doc_id"]] 43 | pred_rec = None 44 | for rec2 in pred_recs: 45 | # if rec["review_comment"] == dset_rec["review_comment"]: 46 | # if rec["review_comment"].strip(".\n ") == rec2["review_comment"].strip(".\n "): 47 | if rec["review_comment"].strip() == rec2["comment"].strip(): 48 | pred_rec = rec2 49 | break 50 | if pred_rec is None: 51 | logger.warning("Missing prediction match for comment: {}".format(rec["review_comment"])) 52 | 53 | for cand_rec in outrec["predictions"]: 54 | pred_label = 0 55 | for edit_id in pred_rec["positive_edits"]: 56 | if edit_id == cand_rec["edit"].edit_id: 57 | pred_label = 1 58 | break 59 | if cand_rec["edit"].is_identical(): 60 | pred_label = 0 61 | cand_rec["pred"] = pred_label 62 | cand_rec["score"] = pred_label 63 | 64 | return out_recs 65 | -------------------------------------------------------------------------------- /aries/util/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/aries/5691bae71a101225ed345d0ffc42e47609f03bbb/aries/util/__init__.py -------------------------------------------------------------------------------- /aries/util/color.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | # Adapted from jupyterlab css 4 | COLOR_TABLE = { 5 | "black": {"hex": "3e424d", "ansi": "30"}, 6 | "red": {"hex": "e75c58", "ansi": "31"}, 7 | "green": {"hex": "00a050", "ansi": "32"}, 8 | "yellow": {"hex": "ddbb33", "ansi": "33"}, 9 | "blue": {"hex": "2090ff", "ansi": "34"}, 10 | "magenta": {"hex": "d060c0", "ansi": "35"}, 11 | "cyan": {"hex": "60c7c7", "ansi": "36"}, 12 | "white": {"hex": "c0c0b0", "ansi": "37"}, 13 | "strong-black": {"hex": "303030", "ansi": "90"}, 14 | "strong-red": {"hex": "b03030", "ansi": "91"}, 15 | "strong-green": {"hex": "007030", "ansi": "92"}, 16 | "strong-yellow": {"hex": "b08010", "ansi": "93"}, 17 | "strong-blue": {"hex": "0070dd", "ansi": "94"}, 18 | "strong-magenta": {"hex": "a03090", "ansi": "95"}, 19 | "strong-cyan": {"hex": "209090", "ansi": "96"}, 20 | 
"strong-white": {"hex": "a0a0b0", "ansi": "97"}, 21 | } 22 | 23 | 24 | def colorify(s: str, color: str, bold: bool = False, form="html", tag_side="both"): 25 | """if tag_side is 'left', only the left tag is added. If tag_side irght 26 | 'right', only the right tag is added. This is useful if, for example, 27 | a list of tokens needs to be colored without joining the tokens. Raises an 28 | error if this is not possible for the given form.""" 29 | if color is None or form == "none": 30 | return s 31 | 32 | m = re.match(r"#(?P[0-9a-fA-F]{6})", color) 33 | valid_ansi = False 34 | if not m: 35 | if color in COLOR_TABLE: 36 | valid_ansi = True 37 | hex_color = COLOR_TABLE[color]["hex"] 38 | else: 39 | raise ValueError("Invalid color {}".format(color)) 40 | else: 41 | hex_color = m.group("hexcode") 42 | 43 | left_tag, right_tag = "", "" 44 | if form == "html": 45 | bold_code = "font-weight: bold;" if bold else "" 46 | left_tag = ''.format(code=hex_color, boldness=bold_code) 47 | right_tag = "" 48 | elif form == "ansi" and valid_ansi: 49 | bold_code = "1" if bold else "0" 50 | left_tag = "\033[{boldness};{code}m".format(code=COLOR_TABLE[color]["ansi"], boldness=bold_code) 51 | right_tag = "\033[0m" 52 | else: 53 | raise ValueError("Invalid format {}".format(form)) 54 | 55 | if tag_side == "left": 56 | return left_tag + s 57 | elif tag_side == "right": 58 | return s + right_tag 59 | elif tag_side == "both": 60 | return left_tag + s + right_tag 61 | raise ValueError("Invalid tag_side {}".format(tag_side)) 62 | 63 | 64 | def colorprint(s, color=None, bold=False, form="ansi", *print_args, **print_kwargs): 65 | return print(colorify(s, color, bold=bold, form=form), *print_args, **print_kwargs) 66 | -------------------------------------------------------------------------------- /aries/util/data.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import gzip 3 | import itertools 4 | import json 5 | import lzma 6 | import os 7 | import sqlite3 8 | from typing import Any, Callable, Dict, Iterable, Iterator, List, Union 9 | 10 | import numpy as np 11 | 12 | try: 13 | import zstandard 14 | except ImportError: 15 | zstandard = None 16 | 17 | try: 18 | import orjson 19 | except ImportError: 20 | orjson = json 21 | 22 | 23 | class ReservoirSampler: 24 | def __init__(self, size, rng=None): 25 | self.size = size 26 | self.rng = rng or np.random.default_rng() 27 | 28 | self.reservoir = [] 29 | self.n_seen = 0 30 | 31 | def add(self, x): 32 | self.n_seen += 1 33 | if len(self.reservoir) < self.size: 34 | self.reservoir.append(x) 35 | else: 36 | idx = self.rng.integers(0, self.n_seen) 37 | if idx < self.size: 38 | self.reservoir[idx] = x 39 | 40 | def add_many(self, xs): 41 | for x in xs: 42 | self.add(x) 43 | 44 | def get_reservoir(self): 45 | return self.reservoir 46 | 47 | 48 | def openc(fname, mode="rt", *, compression="auto", **kwargs): 49 | """Opens a file, transparently handling a variety of possible compression schemes.""" 50 | if mode == "w": 51 | mode = "wt" 52 | 53 | if mode == "x": 54 | mode = "xt" 55 | 56 | kwargs["mode"] = mode 57 | if compression == "auto": 58 | # TODO: Maybe use magic number instead of extension 59 | if fname.lower().endswith(".gz"): 60 | compression = "gzip" 61 | elif fname.lower().endswith(".xz"): 62 | compression = "lzma" 63 | elif fname.lower().endswith(".zst"): 64 | compression = "zstd" 65 | else: 66 | compression = "none" 67 | 68 | open_fn = open 69 | if compression == "gzip": 70 | open_fn = gzip.open 71 | elif compression 
== "lzma": 72 | open_fn = lzma.open 73 | elif compression == "zstd": 74 | if zstandard is None: 75 | raise ValueError("zstandard module is not available") 76 | open_fn = zstandard.open 77 | 78 | return open_fn(fname, **kwargs) 79 | 80 | 81 | def iter_jsonl_files(infiles): 82 | if isinstance(infiles, str): 83 | infiles = [infiles] 84 | for infile in infiles: 85 | with openc(infile) as f: 86 | for obj in map(orjson.loads, f): 87 | yield obj 88 | 89 | 90 | def zip_strict(*iterables): 91 | # Until python 3.10, seems like there's no builtin way to do this, but 92 | # there's a fairly simple workaround implementation: 93 | # https://stackoverflow.com/a/32954700 94 | canary = object() 95 | for tup in itertools.zip_longest(*iterables, fillvalue=canary): 96 | if canary in tup: 97 | raise ValueError("Iterables have different lengths") 98 | yield tup 99 | 100 | 101 | def downsample_recs(recs: List[Any], downsample_config: Dict[str, Any]): 102 | if downsample_config is None: 103 | # Return recs, for consistency with old configs before downsampling was added 104 | return recs.copy() 105 | 106 | if downsample_config.get("keep_n", -1) != -1 and downsample_config.get("keep_ratio", -1) != -1: 107 | raise ValueError("Need only one of keep_n and keep_ratio (not both)") 108 | 109 | keep_n = len(recs) 110 | if "keep_n" in downsample_config: 111 | keep_n = downsample_config["keep_n"] 112 | elif "keep_ratio" in downsample_config: 113 | keep_n = max(1, int(downsample_config["keep_ratio"] * len(recs))) 114 | 115 | assert isinstance(keep_n, int) and keep_n > 0 116 | 117 | if keep_n > len(recs): 118 | raise ValueError("Can't sample more data points than the dataset has") 119 | 120 | rng = np.random.default_rng(downsample_config.get("seed", None)) 121 | return [recs[idx] for idx in rng.choice(len(recs), size=keep_n, replace=False)] 122 | 123 | 124 | def batch_iter(iterable, batch_size): 125 | batch = [] 126 | for rec in iterable: 127 | if len(batch) >= batch_size: 128 | yield batch 129 | batch = [] 130 | batch.append(rec) 131 | 132 | if len(batch) != 0: 133 | yield batch 134 | 135 | 136 | def index_by( 137 | lst: Union[Iterable, Iterator], 138 | key: Union[str, Callable], 139 | one_to_one=False, 140 | ) -> Dict: 141 | key_fn = key 142 | if isinstance(key_fn, str): 143 | key_fn = lambda x: x[key] 144 | 145 | index = dict() 146 | if one_to_one: 147 | for rec in lst: 148 | k = key_fn(rec) 149 | if k in index: 150 | raise ValueError("Duplicate key: {}".format(k)) 151 | index[k] = rec 152 | else: 153 | for rec in lst: 154 | k = key_fn(rec) 155 | if k not in index: 156 | index[k] = [] 157 | index[k].append(rec) 158 | return index 159 | 160 | 161 | def deduplicate_by( 162 | lst: Union[Iterable, Iterator], 163 | key: Union[str, Callable], 164 | ) -> List: 165 | key_fn = key 166 | if isinstance(key_fn, str): 167 | key_fn = lambda x: x[key] 168 | 169 | new_lst = [] 170 | used_keys = set() 171 | for rec in lst: 172 | k = key_fn(rec) 173 | if k not in used_keys: 174 | used_keys.add(k) 175 | new_lst.append(rec) 176 | return new_lst 177 | 178 | 179 | def counter_jaccard(counter1: Dict, counter2: Dict) -> float: 180 | """Computes the jaccard overlap of two dict objects.""" 181 | if len(counter1) == 0 and len(counter2) == 0: 182 | return float("nan") 183 | if len(counter1) == 0 or len(counter2) == 0: 184 | return 0.0 185 | 186 | intersection = sum((counter1 & counter2).values()) 187 | if intersection == 0: 188 | return 0.0 189 | return intersection / (sum(counter1.values()) + sum(counter2.values()) - intersection) 190 | 
-------------------------------------------------------------------------------- /aries/util/edit.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import difflib 3 | import itertools 4 | from typing import Iterable, List, Tuple, Union 5 | 6 | import numpy as np 7 | import tqdm 8 | from cffi import FFI 9 | 10 | from .color import colorify, colorprint 11 | 12 | 13 | def init_levenshtein_c(): 14 | ffibuilder = FFI() 15 | ffibuilder.set_source( 16 | "_levenshtein", 17 | r""" 18 | int levenshtein(int *seq1, int seq1_len, int *seq2, int seq2_len, int *v0) 19 | { 20 | // Adapted from https://en.wikipedia.org/wiki/Levenshtein_distance (CC-BY-SA) 21 | 22 | // v0 is just a buffer for temporary calculations; easier to 23 | // ask the caller to allocate it than to deal with C mem 24 | // management 25 | 26 | int substitutionCost, insertionCost, deletionCost; 27 | int tmpval; 28 | 29 | for (int i = 0; i < seq2_len+1; i++) { 30 | v0[i] = i; 31 | } 32 | 33 | for (int i = 0; i < seq1_len; i++) { 34 | // calculate v1 (current row distances) from the previous row v0 35 | 36 | // first element of v1 is A[i+1][0] 37 | // edit distance is delete (i+1) chars from s to match empty t 38 | tmpval = i + 1; 39 | 40 | // use formula to fill in the rest of the row 41 | for(int j = 0; j < seq2_len; j++) { 42 | // calculating costs for A[i+1][j+1] 43 | deletionCost = v0[j + 1] + 1; 44 | insertionCost = tmpval + 1; 45 | substitutionCost = v0[j]; 46 | if (seq1[i] != seq2[j]) { 47 | substitutionCost++; 48 | } 49 | 50 | v0[j] = tmpval; 51 | 52 | tmpval = deletionCost; 53 | if (insertionCost < tmpval) { 54 | tmpval = insertionCost; 55 | } 56 | if (substitutionCost < tmpval) { 57 | tmpval = substitutionCost; 58 | } 59 | } 60 | v0[seq2_len] = tmpval; 61 | } 62 | // after the last swap, the results of v1 are now in v0 63 | return v0[seq2_len]; 64 | } 65 | """, 66 | ) 67 | 68 | ffibuilder.cdef("int levenshtein(int*, int, int*, int, int*);") 69 | 70 | # Compile the C module and import it 71 | ffibuilder.compile(verbose=True) 72 | from _levenshtein import ffi, lib 73 | 74 | return ffi, lib 75 | 76 | 77 | levenshtein_ffi, levenshtein_lib = None, None 78 | 79 | 80 | def levenshtein_distance(seq1, seq2): 81 | # We call a C function for levenshtein via CFFI because it is about 1000x 82 | # faster than the python version (the difference between running in an hour 83 | # vs running in a month) 84 | 85 | global levenshtein_ffi, levenshtein_lib 86 | 87 | if levenshtein_ffi is None: 88 | levenshtein_ffi, levenshtein_lib = init_levenshtein_c() 89 | 90 | if isinstance(seq1, str): 91 | seq1 = [ord(c) for c in seq1] 92 | 93 | if isinstance(seq2, str): 94 | seq2 = [ord(c) for c in seq2] 95 | 96 | if len(seq1) > len(seq2): 97 | seq1, seq2 = seq2, seq1 98 | 99 | # Important: these arrs need to be in their own variables, NOT inlined with 100 | # the levenshtein_ffi.from_buffer, or else the GC will free the memory and 101 | # memory will get corrupted (often manifests as seq2 overwriting seq1, but 102 | # also can segfault) 103 | seq1_arr = np.array(seq1, dtype=np.int32) 104 | seq2_arr = np.array(seq2, dtype=np.int32) 105 | v0_arr = np.zeros(len(seq2) + 1, dtype=np.int32) 106 | 107 | seq1_buf = levenshtein_ffi.cast("int*", levenshtein_ffi.from_buffer(seq1_arr)) 108 | seq2_buf = levenshtein_ffi.cast("int*", levenshtein_ffi.from_buffer(seq2_arr)) 109 | v0 = levenshtein_ffi.cast("int*", levenshtein_ffi.from_buffer(v0_arr)) 110 | 111 | result = levenshtein_lib.levenshtein(seq1_buf, 
len(seq1), seq2_buf, len(seq2), v0) 112 | return result 113 | 114 | 115 | def basic_token_align(seq1, seq2, seq2_ignored_ids: Iterable = None): 116 | """Aligns the tokens of seq1 and seq2 assuming that seq2 contains all the 117 | characters of seq1, but possibly with some extra tokens (e.g., special 118 | whitespace markers from a huggingface transformers tokenizer) and possibly 119 | partitioned differently. 120 | 121 | In cases where the boundaries are mismatched, this maps to the token with 122 | largest overlap, and breaks ties in favor of earlier tokens. 123 | 124 | if seq2_ignored_ids is given, the specified token indexes in seq2 are 125 | ignored and will not be aligned to anything in seq1. 126 | 127 | Returns a tuple (dist, alignment) where dist is the total of mismatches 128 | (number of characters that seq2 token boundaries had to be moved to 129 | complete alignment) and `alignment` is a list of the same length as seq2 130 | containing the indexes of the aligned tokens from seq1 (or None if the 131 | token did not overlap seq1 at all).""" 132 | 133 | if seq2_ignored_ids is None: 134 | seq2_ignored_ids = set() 135 | 136 | # if seq1[0] == 'numerous': 137 | # breakpoint() 138 | 139 | seq1idxs = list(itertools.chain(*[[(idx, c) for c in tok] for idx, tok in enumerate(seq1)])) 140 | seq2idxs = list(itertools.chain(*[[(idx, c) for c in tok] for idx, tok in enumerate(seq2)])) 141 | 142 | seq2_seq1_char_align = [None] * len(seq2idxs) 143 | idx1 = 0 144 | last_valid = None 145 | for chridx2, (idx2, c2) in enumerate(seq2idxs): 146 | if idx1 >= len(seq1idxs): 147 | break 148 | if c2 == seq1idxs[idx1][1] and idx2 not in seq2_ignored_ids: 149 | seq2_seq1_char_align[chridx2] = idx1 150 | last_valid = idx1 151 | idx1 += 1 152 | 153 | # Ensure that all chars of seq1 were mapped to a char in seq2 154 | # if ''.join(seq1) != ''.join(seq2): 155 | if last_valid != (len(seq1idxs) - 1): 156 | raise ValueError("Cannot align: Sequences didn't contain the same characters") 157 | 158 | # Align the sequences 159 | alignment_counts = {idx: collections.Counter() for idx in range(len(seq2))} 160 | # for idx1, idx2 in zip(seq1idxs, seq2idxs): 161 | for chridx1, (idx2, c2) in zip(seq2_seq1_char_align, seq2idxs): 162 | idx1 = seq1idxs[chridx1][0] if chridx1 is not None else None 163 | alignment_counts[idx2][idx1] += 1 164 | 165 | alignments = [] 166 | n_mismatch_total = 0 167 | for idx2 in range(len(seq2)): 168 | best_idxs = sorted( 169 | alignment_counts[idx2].keys(), reverse=True, key=lambda x: (alignment_counts[idx2][x], -x if x is not None else float("-inf")) 170 | ) 171 | best_idx1 = best_idxs[0] 172 | if best_idx1 is None and len(best_idxs) > 1: 173 | best_idx1 = best_idxs[1] 174 | n_mismatch_total += sum(alignment_counts[idx2].values()) - alignment_counts[idx2][best_idx1] 175 | alignments.append(best_idx1) 176 | 177 | return (n_mismatch_total, alignments) 178 | 179 | 180 | def print_word_diff(text1, text2, color_format="ansi", **print_kwargs): 181 | print(make_word_diff(text1, text2, color_format=color_format), **print_kwargs) 182 | 183 | 184 | def make_word_diff(text1, text2, color_format="ansi"): 185 | if not isinstance(text1, list): 186 | text1 = text1.split(" ") if len(text1) != 0 else [] 187 | 188 | if not isinstance(text2, list): 189 | text2 = text2.split(" ") if len(text2) != 0 else [] 190 | 191 | prevtok = " " 192 | parity = 0 193 | 194 | def color_for_tok(tok): 195 | if color_format == "none": 196 | return None 197 | 198 | if tok == "+": 199 | return "green" 200 | elif tok == "-": 201 | return 
"red" 202 | elif tok == "?": 203 | return "blue" 204 | return None 205 | 206 | s = "" 207 | for idx, x in enumerate(difflib.ndiff(text1, text2)): 208 | if prevtok != x[0] and prevtok in ("+", "-"): 209 | s += colorify(prevtok + "]", color=color_for_tok(prevtok), form=color_format) 210 | if prevtok != x[0] and x[0] in ("+", "-"): 211 | if parity == 0 and idx > 0: 212 | s += " " 213 | s += colorify("[" + x[0], color=color_for_tok(x[0]), form=color_format) 214 | 215 | if x[0] == " ": 216 | if idx != 0: 217 | s += " " 218 | s += x[2:] 219 | parity = 0 220 | elif x[0] == "?": 221 | pass 222 | else: 223 | # s = '['+x[0]+x[1:]+x[0]+']' 224 | if prevtok != x[0]: 225 | parity = parity ^ 1 226 | else: 227 | s += " " 228 | s += colorify(x[2:], color=color_for_tok(x[0]), form=color_format) 229 | prevtok = x[0] 230 | 231 | if prevtok in ("+", "-"): 232 | s += colorify(prevtok + "]", color=color_for_tok(prevtok), form=color_format) 233 | 234 | return s 235 | 236 | 237 | def build_offsets( 238 | toks: Union[str, List[str]], 239 | chunk_length: int, 240 | ) -> dict: 241 | offsets = dict() 242 | for idx in range(len(toks) - chunk_length + 1): 243 | chunk = tuple(toks[idx : idx + chunk_length]) 244 | if chunk not in offsets: 245 | offsets[chunk] = [] 246 | offsets[chunk].append(idx) 247 | return offsets 248 | 249 | 250 | def update_overlaps( 251 | cur_overlaps: List[Tuple[int, int]], 252 | toks1: Union[str, List[str]], 253 | toks2: Union[str, List[str]], 254 | idx2: int, 255 | min_length: int, 256 | ) -> Tuple[List[Tuple[int, int]], List[Tuple[Tuple[int, int], Tuple[int, int]]]]: 257 | overlaps = [] 258 | new_overlaps = [] 259 | for overlap in cur_overlaps: 260 | overlap_length = idx2 - overlap[1] 261 | end1 = overlap[0] + overlap_length 262 | if end1 < len(toks1) and idx2 < len(toks2) and toks1[end1] == toks2[idx2]: 263 | new_overlaps.append(overlap) 264 | elif overlap_length >= min_length: 265 | overlaps.append(((overlap[0], overlap[0] + overlap_length), (overlap[1], overlap[1] + overlap_length))) 266 | return new_overlaps, overlaps 267 | 268 | 269 | def find_overlapping_substrings( 270 | toks1: Union[str, List[str]], 271 | toks2: Union[str, List[str]], 272 | min_length: int = 32, 273 | ): 274 | """ 275 | Finds overlapping substrings of toks1 and toks2, where toks1 and toks2 are 276 | lists of tokens. 277 | 278 | min_length is the minimum number of tokens that a match must span in order 279 | to be returned 280 | 281 | Returns a list of pairs of spans, e.g. [((10, 20), (14, 24))]. Each span 282 | pair is a (start_idx, end_idx) tuple representing a half-open interval. 283 | 284 | Any long match technically contains many shorter matches. This function 285 | returns only the longest match for each set; for each returned pair of 286 | spans (span1, span2), there will be no other returned pair (span3, span4) 287 | such that span3 contains span1 AND span4 contains span2. 288 | """ 289 | if len(toks1) == 0 or len(toks2) == 0: 290 | return [] 291 | 292 | # Use chunks to reduce number of hits per token, but don't go too high 293 | # since mem usage is len(toks1)*chunk_length. If character tokenization and 294 | # long chunk_length (e.g., 1000), then we would use 1000x the memory needed 295 | # to store toks1. 
296 | chunk_length = min(min_length, 10) 297 | offsets1 = build_offsets(toks1, chunk_length) 298 | overlaps = [] 299 | cur_overlaps = [] 300 | 301 | for idx2, tk2 in enumerate(toks2): 302 | cur_overlaps, new_overlaps = update_overlaps(cur_overlaps, toks1, toks2, idx2, min_length) 303 | overlaps.extend(new_overlaps) 304 | 305 | if idx2 <= (len(toks2) - min_length): 306 | chunk = tuple(toks2[idx2 : idx2 + chunk_length]) 307 | for idx1 in offsets1.get(chunk, []): 308 | has_overlap = False 309 | for overlap in cur_overlaps: 310 | overlap_length = idx2 - overlap[1] 311 | if idx1 - overlap_length == overlap[0]: 312 | has_overlap = True 313 | break 314 | if not has_overlap: 315 | cur_overlaps.append((idx1, idx2)) 316 | 317 | idx2 = len(toks2) 318 | _, new_overlaps = update_overlaps(cur_overlaps, toks1, toks2, idx2, min_length) 319 | overlaps.extend(new_overlaps) 320 | 321 | final_overlaps = [] 322 | for o1 in overlaps: 323 | is_subset = False 324 | for o2 in overlaps: 325 | if o1 != o2 and o1[0][0] >= o2[0][0] and o1[0][1] <= o2[0][1] and o1[1][0] >= o2[1][0] and o1[1][1] <= o2[1][1]: 326 | is_subset = True 327 | break 328 | if not is_subset: 329 | final_overlaps.append(o1) 330 | return final_overlaps 331 | -------------------------------------------------------------------------------- /aries/util/gensim.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import logging 3 | 4 | import gensim 5 | 6 | logger = logging.getLogger(__name__) 7 | 8 | 9 | def stem_tokens(tokens): 10 | return list(map(gensim.parsing.preprocessing.stem, tokens)) 11 | 12 | 13 | class InMemoryTextCorpus(gensim.corpora.textcorpus.TextCorpus): 14 | def __init__(self, texts, dictionary=None, **kwargs): 15 | self.texts = texts 16 | if "token_filters" not in kwargs: 17 | kwargs["token_filters"] = [stem_tokens] 18 | if "character_filters" not in kwargs: 19 | kwargs["character_filters"] = [ 20 | gensim.parsing.preprocessing.lower_to_unicode, 21 | gensim.utils.deaccent, 22 | gensim.parsing.preprocessing.strip_multiple_whitespaces, 23 | gensim.parsing.preprocessing.strip_punctuation, 24 | ] 25 | super().__init__(dictionary=dictionary, **kwargs) 26 | # self.token_filters = [gensim.parsing.preprocessing.remove_short_tokens, gensim.parsing.preprocessing.remove_stopword_tokens] 27 | 28 | def __getitem__(self, item): 29 | return self.dictionary.doc2bow(self.preprocess_text(self.texts[item])) 30 | 31 | def init_dictionary(self, dictionary): 32 | self.dictionary = dictionary if dictionary is not None else gensim.corpora.Dictionary() 33 | if dictionary is None: 34 | logger.debug("Initializing dictionary") 35 | metadata_setting = self.metadata 36 | self.metadata = False 37 | self.dictionary.add_documents(self.get_texts()) 38 | self.metadata = metadata_setting 39 | else: 40 | logger.debug("Dictionary already initialized") 41 | 42 | def get_texts(self): 43 | return list(map(self.preprocess_text, self.texts)) 44 | 45 | def __len__(self): 46 | return len(self.texts) 47 | -------------------------------------------------------------------------------- /aries/util/gpt3.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import json 3 | import logging 4 | import os 5 | import sqlite3 6 | import time 7 | 8 | import openai 9 | import tiktoken 10 | import tqdm 11 | 12 | logger = logging.getLogger(__name__) 13 | 14 | 15 | class Gpt3CacheClient: 16 | def __init__(self, cache_db_path): 17 | self.cache_db = self._init_cache_db(cache_db_path) 
18 | 19 | if openai.api_key is None: 20 | if "OPENAI_API_KEY" not in os.environ: 21 | logger.error("Need OpenAI key in OPENAI_API_KEY") 22 | openai.api_key = os.environ["OPENAI_API_KEY"] 23 | 24 | self.tokenizer = None 25 | self.tokenizers_by_model = dict() 26 | 27 | def estimate_num_tokens(self, text, model="text-davinci-003"): 28 | return len(self._get_tokenizer(model).encode(text)) 29 | 30 | def _get_tokenizer(self, model): 31 | if model not in self.tokenizers_by_model: 32 | self.tokenizers_by_model[model] = tiktoken.encoding_for_model(model) 33 | return self.tokenizers_by_model[model] 34 | 35 | def __enter__(self): 36 | self.cache_db.__enter__() 37 | return self 38 | 39 | def __exit__(self, *args, **kwargs): 40 | self.cache_db.__exit__(*args, **kwargs) 41 | 42 | def close(self): 43 | self.cache_db.close() 44 | 45 | def _init_cache_db(self, cache_db_path): 46 | db = sqlite3.connect(cache_db_path) 47 | try: 48 | cur = db.cursor() 49 | cur.execute( 50 | """create table if not exists gpt3_cache ( 51 | model text not null, 52 | prompt text not null, 53 | temperature real not null, 54 | top_p real not null, 55 | max_tokens integer not null, 56 | total_tokens integer not null, 57 | frequency_penalty real not null, 58 | presence_penalty real not null, 59 | logprobs integer not null, 60 | response_json text not null, 61 | response_timestamp real 62 | )""" 63 | ) 64 | cur.execute("create index if not exists prompt_index on gpt3_cache (prompt)") 65 | cur.execute( 66 | """create table if not exists chat_gpt3_cache ( 67 | model text not null, 68 | messages_json text not null, 69 | temperature real not null, 70 | top_p real not null, 71 | max_tokens integer not null, 72 | total_tokens integer not null, 73 | frequency_penalty real not null, 74 | presence_penalty real not null, 75 | response_json text not null, 76 | response_timestamp real 77 | )""" 78 | ) 79 | cur.execute("create index if not exists messages_json_index on chat_gpt3_cache (messages_json)") 80 | db.commit() 81 | return db 82 | except Exception as e: 83 | db.close() 84 | raise e 85 | 86 | def get_gpt3_result(self, *args, **kwargs): 87 | """Deprecated. 
Use prompt_completion() instead.""" 88 | return self.prompt_completion(*args, **kwargs) 89 | 90 | def prompt_completion( 91 | self, 92 | model, 93 | prompt, 94 | temperature, 95 | max_tokens, 96 | top_p, 97 | frequency_penalty, 98 | presence_penalty, 99 | prompt_token_count=-1, 100 | logprobs=0, 101 | ): 102 | """Works like openai.Completion.create, but adds a caching layer.""" 103 | if prompt_token_count < 0: 104 | prompt_token_count = self.estimate_num_tokens(prompt, model) 105 | 106 | db_keyvals = { 107 | "model": model, 108 | "prompt": prompt, 109 | "temperature": temperature, 110 | "max_tokens": max_tokens, 111 | "top_p": top_p, 112 | "frequency_penalty": frequency_penalty, 113 | "presence_penalty": presence_penalty, 114 | "logprobs": logprobs, 115 | } 116 | cur = self.cache_db.cursor() 117 | 118 | cache_json = None 119 | from_cache = False 120 | # Cache only makes sense if temperature==0 (deterministic result) 121 | if temperature == 0.0: 122 | select_keyvals = db_keyvals.copy() 123 | select_keyvals["prompt_token_count"] = prompt_token_count 124 | dbrecs = cur.execute( 125 | """select response_json from gpt3_cache 126 | where 127 | model = :model and 128 | prompt = :prompt and 129 | temperature = :temperature and 130 | ((:prompt_token_count+max_tokens) > total_tokens or max_tokens = :max_tokens) and 131 | total_tokens <= (:prompt_token_count+:max_tokens) and 132 | top_p = :top_p and 133 | frequency_penalty = :frequency_penalty and 134 | presence_penalty = :presence_penalty and 135 | logprobs >= :logprobs""", 136 | select_keyvals, 137 | ).fetchall() 138 | if len(dbrecs) == 1: 139 | cache_json = dbrecs[0][0] 140 | elif len(dbrecs) >= 2: 141 | logger.warning("Got {} recs for gpt3 query when only one was expected.".format(len(dbrecs))) 142 | cache_json = dbrecs[0][0] 143 | if cache_json is None: 144 | logger.debug("UNCACHED prompt completion") 145 | resp = openai.Completion.create(**db_keyvals) 146 | insert_keyvals = db_keyvals.copy() 147 | cache_json = json.dumps(resp) 148 | insert_keyvals["response_json"] = cache_json 149 | insert_keyvals["response_timestamp"] = datetime.datetime.timestamp(datetime.datetime.utcnow()) 150 | insert_keyvals["total_tokens"] = resp["usage"]["total_tokens"] 151 | cur.execute( 152 | """INSERT INTO gpt3_cache ( model, prompt, temperature, top_p, max_tokens, frequency_penalty, presence_penalty, logprobs, response_json, response_timestamp, total_tokens) 153 | VALUES (:model, :prompt, :temperature, :top_p, :max_tokens, :frequency_penalty, :presence_penalty, :logprobs, :response_json, :response_timestamp, :total_tokens)""", 154 | insert_keyvals, 155 | ) 156 | self.cache_db.commit() 157 | else: 158 | from_cache = True 159 | 160 | resp = json.loads(cache_json) 161 | if from_cache: 162 | resp["usage"]["uncached_total_tokens"] = 0 163 | else: 164 | resp["usage"]["uncached_total_tokens"] = resp["usage"]["total_tokens"] 165 | return resp 166 | 167 | def chat_completion( 168 | self, 169 | model, 170 | messages, 171 | temperature, 172 | max_tokens, 173 | top_p, 174 | frequency_penalty, 175 | presence_penalty, 176 | messages_token_count=-1, 177 | max_retries=3, 178 | ): 179 | """Works like openai.ChatCompletion.create, but adds a caching layer.""" 180 | 181 | # Sort keys when serializing to maximize cache hits 182 | messages_json = json.dumps(messages, sort_keys=True) 183 | 184 | if messages_token_count < 0: 185 | messages_token_count = sum(self.estimate_num_tokens(x["content"], model) for x in messages) 186 | 187 | db_keyvals = { 188 | "model": model, 189 | 
"messages_json": messages_json, 190 | "temperature": temperature, 191 | "max_tokens": max_tokens, 192 | "top_p": top_p, 193 | "frequency_penalty": frequency_penalty, 194 | "presence_penalty": presence_penalty, 195 | } 196 | cur = self.cache_db.cursor() 197 | 198 | cache_json = None 199 | from_cache = False 200 | # Cache only makes sense if temperature==0 (deterministic result) 201 | if temperature == 0.0: 202 | select_keyvals = db_keyvals.copy() 203 | select_keyvals["messages_token_count"] = messages_token_count 204 | dbrecs = cur.execute( 205 | """select response_json from chat_gpt3_cache 206 | where 207 | model = :model and 208 | messages_json = :messages_json and 209 | temperature = :temperature and 210 | ((:messages_token_count+max_tokens) > total_tokens or max_tokens = :max_tokens) and 211 | total_tokens <= (:messages_token_count+:max_tokens) and 212 | top_p = :top_p and 213 | frequency_penalty = :frequency_penalty and 214 | presence_penalty = :presence_penalty 215 | """, 216 | select_keyvals, 217 | ).fetchall() 218 | if len(dbrecs) == 1: 219 | cache_json = dbrecs[0][0] 220 | elif len(dbrecs) >= 2: 221 | logger.warning("Got {} recs for gpt3 query when only one was expected.".format(len(dbrecs))) 222 | cache_json = dbrecs[0][0] 223 | if cache_json is None: 224 | logger.debug("UNCACHED chat completion") 225 | 226 | model_keyvals = db_keyvals.copy() 227 | del model_keyvals["messages_json"] 228 | model_keyvals["messages"] = messages 229 | 230 | resp = None 231 | while resp is None and max_retries >= 0: 232 | try: 233 | resp = openai.ChatCompletion.create(**model_keyvals) 234 | except openai.error.RateLimitError: 235 | logger.warning("Rate limit error on openai request, waiting 60 seconds and trying again") 236 | time.sleep(60) 237 | max_retries -= 1 238 | 239 | insert_keyvals = db_keyvals.copy() 240 | cache_json = json.dumps(resp) 241 | insert_keyvals["response_json"] = cache_json 242 | insert_keyvals["response_timestamp"] = datetime.datetime.timestamp(datetime.datetime.utcnow()) 243 | insert_keyvals["total_tokens"] = resp["usage"]["total_tokens"] 244 | cur.execute( 245 | """INSERT INTO chat_gpt3_cache ( model, messages_json, temperature, top_p, max_tokens, frequency_penalty, presence_penalty, response_json, response_timestamp, total_tokens) 246 | VALUES (:model, :messages_json, :temperature, :top_p, :max_tokens, :frequency_penalty, :presence_penalty, :response_json, :response_timestamp, :total_tokens)""", 247 | insert_keyvals, 248 | ) 249 | self.cache_db.commit() 250 | else: 251 | from_cache = True 252 | 253 | resp = json.loads(cache_json) 254 | if from_cache: 255 | resp["usage"]["uncached_total_tokens"] = 0 256 | else: 257 | resp["usage"]["uncached_total_tokens"] = resp["usage"]["total_tokens"] 258 | return resp 259 | -------------------------------------------------------------------------------- /aries/util/logging.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import pprint 4 | from typing import Callable, List, Union 5 | 6 | logger = logging.getLogger(__name__) 7 | 8 | 9 | def init_logging(logfile=None, level=logging.INFO): 10 | handlers = [ 11 | logging.StreamHandler(), 12 | ] 13 | if logfile: 14 | handlers.append(logging.FileHandler(logfile)) 15 | 16 | logging.basicConfig( 17 | format="%(asctime)s [%(levelname)s] (%(name)s): %(message)s", 18 | datefmt="%Y-%m-%d %H:%M:%S", 19 | level=level, 20 | handlers=handlers, 21 | ) 22 | 23 | 24 | def pprint_metrics(metrics, print_fn: Union[Callable[[str], None], 
logging.Logger] = print, val_format="{:0.4f}", int_format="{:d}", name="eval"): 25 | if isinstance(print_fn, logging.Logger): 26 | print_fn = print_fn.info 27 | 28 | if name != "": 29 | name += " " 30 | 31 | for k, v in metrics.items(): 32 | vstr = str(v) 33 | if isinstance(v, float) or isinstance(v, int): 34 | vstr = val_format.format(v) 35 | if isinstance(v, int) and int_format is not None: 36 | vstr = int_format.format(v) 37 | 38 | print_fn("{name}{metric_name}: {val}".format(name=name, metric_name=k, val=vstr)) 39 | 40 | 41 | class PrettyFloatPrinter(pprint.PrettyPrinter): 42 | def __init__(self, *args, **kwargs): 43 | if "sort_dicts" not in kwargs: 44 | kwargs["sort_dicts"] = False 45 | super().__init__(*args, **kwargs) 46 | 47 | def format(self, obj, ctx, maxlvl, lvl): 48 | if isinstance(obj, float): 49 | return "{:.4f}".format(obj), True, False 50 | # elif isinstance(obj, dict): 51 | # print('gd', obj) 52 | # v = '{' + ',\n'.join(["'{}': {}".format(k, self.format(v, ctx, maxlvl, lvl+1)[0]) for k, v in obj.items()]) + '}', True, False 53 | # print(v[0]) 54 | # return v 55 | return pprint.PrettyPrinter.format(self, obj, ctx, maxlvl, lvl + 1) 56 | 57 | 58 | def table2str(grid, format_fn=str, col_names=None, row_names=None, colsep=" | ", rowend="", header_row_sep="-"): 59 | if col_names is None: 60 | col_names = ["" for _ in range(len(grid[0]))] 61 | col_names = list(map(str, col_names)) 62 | if row_names is None: 63 | row_names = ["" for _ in range(len(grid))] 64 | row_names = list(map(str, row_names)) 65 | 66 | new_grid = [[""] + col_names] 67 | for rowidx, row in enumerate(grid): 68 | new_grid.append([row_names[rowidx]] + [format_fn(cell) for cell in row]) 69 | return raw_table2str(new_grid, colsep=colsep, rowend=rowend, header_row_sep=header_row_sep) 70 | 71 | 72 | def raw_table2str(grid, colsep=" | ", rowend="", header_row_sep="-"): 73 | s = "" 74 | 75 | col_widths = [max(len(grid[y][x]) for y in range(len(grid))) for x in range(len(grid[0]))] 76 | for y, row in enumerate(grid): 77 | if all(cell == "" for cell in row[1:]): 78 | continue 79 | # s += ' ' 80 | s += colsep.join(["{text:>{width}s}".format(width=col_widths[x], text=cell) if col_widths[x] != 0 else "" for x, cell in enumerate(row)]) 81 | s += "{}\n".format(rowend) 82 | if y == 0: 83 | if len(header_row_sep) == 1: 84 | s += header_row_sep * (sum(col_widths) + len(colsep) * (len(col_widths) - 1) + 1) + "\n" 85 | elif len(header_row_sep) == 0: 86 | continue 87 | else: 88 | s += header_row_sep + ("\n" if not header_row_sep.endswith("\n") else "") 89 | return s 90 | -------------------------------------------------------------------------------- /aries/util/s2orc.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import os 4 | import sqlite3 5 | import sys 6 | 7 | import tqdm 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | 12 | def fuse_back_matter(s2json): 13 | """Fuse back matter into body text (mutating the input object) and return 14 | the mutated s2orc object. 
Often the parser puts whatever is on the last 15 | pdf page into back matter even if it is clearly part of the appendix, so 16 | this function tries to fix that.""" 17 | 18 | s2json["pdf_parse"]["body_text"] = s2json["pdf_parse"]["body_text"] + s2json["pdf_parse"]["back_matter"] 19 | return s2json 20 | 21 | 22 | def load_s2orc(pdf_id, fetcher): 23 | s = fetcher.get(pdf_id) 24 | if s is None: 25 | return None 26 | return fuse_back_matter(s) 27 | 28 | 29 | def iter_s2orc_pairs(base_path, paper_records, error_on_missing=True): 30 | with S2orcFetcherFilesystem(base_path) as fetcher: 31 | for record in tqdm.tqdm(paper_records, desc="loading papers"): 32 | doc_id = record["doc_id"] 33 | 34 | if not all(fetcher.has(pdf_id) for pdf_id in [record["source_pdf_id"], record["target_pdf_id"]]): 35 | if error_on_missing: 36 | raise RuntimeError("missing pdf ids for doc {} ({}, {})".format(doc_id, record["source_pdf_id"], record["target_pdf_id"])) 37 | else: 38 | logger.warning("missing pdf ids for doc {} ({}, {})".format(doc_id, record["source_pdf_id"], record["target_pdf_id"])) 39 | continue 40 | 41 | s2orc1 = load_s2orc(record["source_pdf_id"], fetcher) 42 | 43 | s2orc2 = load_s2orc(record["target_pdf_id"], fetcher) 44 | 45 | yield doc_id, s2orc1, s2orc2 46 | 47 | 48 | def iter_s2orc_docs(config, pdf_ids): 49 | with S2orcFetcherSqlite( 50 | config.get("s2orc_db_path", ":memory:"), 51 | fallback_fetcher=S2orcFetcherFilesystem(config["s2orc_base_path"]) if config.get("s2orc_base_path", None) else None, 52 | update_db=False, 53 | ) as fetcher: 54 | for pdf_id in tqdm.tqdm(pdf_ids, desc="loading papers"): 55 | if not fetcher.has(pdf_id): 56 | logger.warning("missing pdf ids for doc {}".format(pdf_id)) 57 | continue 58 | 59 | s2orc2 = load_s2orc(pdf_id, fetcher) 60 | 61 | yield pdf_id, s2orc2 62 | 63 | 64 | class S2orcFetcher: 65 | def get(self, pdf_id): 66 | raise NotImplementedError() 67 | 68 | def has(self, pdf_id): 69 | raise NotImplementedError() 70 | 71 | 72 | class S2orcFetcherDummy(S2orcFetcher): 73 | def get(self, pdf_id): 74 | return None 75 | 76 | def has(self, pdf_id): 77 | return False 78 | 79 | 80 | class S2orcFetcherSqlite(S2orcFetcher): 81 | def __init__(self, s2orc_db_path, fallback_fetcher=None, update_db=False): 82 | self.s2orc_db_path = s2orc_db_path 83 | self.fallback_fetcher = fallback_fetcher or S2orcFetcherDummy() 84 | self.update_db = update_db 85 | self.db = None 86 | self.cur = None 87 | 88 | def __enter__(self): 89 | self.db = sqlite3.connect(self.s2orc_db_path) 90 | self.cur = self.db.cursor() 91 | self.cur.execute("BEGIN") 92 | # We create the table/index regardless of update_db, since otherwise we hit errors later 93 | self.cur.execute("CREATE TABLE IF NOT EXISTS pdf_records (pdf_id TEXT PRIMARY KEY NOT NULL, title TEXT, json TEXT)") 94 | self.cur.execute("CREATE INDEX IF NOT EXISTS pdf_records_by_id ON pdf_records (pdf_id)") 95 | return self 96 | 97 | def __exit__(self, exc_type, exc_val, exc_tb): 98 | self.db.commit() 99 | self.db.close() 100 | 101 | def get(self, pdf_id): 102 | rec = self.cur.execute("SELECT json FROM pdf_records WHERE pdf_id=?", (pdf_id,)).fetchone() 103 | if rec is not None: 104 | return json.loads(rec[0]) 105 | s2orc_json = self.fallback_fetcher.get(pdf_id) 106 | 107 | if self.update_db and s2orc_json is not None: 108 | self.cur.execute("INSERT INTO pdf_records (pdf_id, title, json) VALUES (?, ?, ?)", (pdf_id, s2orc_json["title"], json.dumps(s2orc_json))) 109 | return s2orc_json 110 | 111 | def has(self, pdf_id): 112 | rec = self.cur.execute("SELECT 1 
FROM pdf_records WHERE pdf_id=?", (pdf_id,)).fetchone() 113 | if rec is not None: 114 | return True 115 | return self.fallback_fetcher.has(pdf_id) 116 | 117 | 118 | class S2orcFetcherFilesystem(S2orcFetcher): 119 | def __init__(self, s2orc_base_path): 120 | self.s2orc_base_path = s2orc_base_path 121 | 122 | def __enter__(self): 123 | return self 124 | 125 | def __exit__(self, exc_type, exc_val, exc_tb): 126 | return 127 | 128 | def get(self, pdf_id): 129 | if not self.s2orc_base_path: 130 | return None 131 | 132 | path = os.path.join(self.s2orc_base_path, "{}.json".format(pdf_id)) 133 | 134 | try: 135 | with open(path) as f: 136 | return json.load(f) 137 | except FileNotFoundError: 138 | return None 139 | 140 | def has(self, pdf_id): 141 | if not self.s2orc_base_path: 142 | return False 143 | path = os.path.join(self.s2orc_base_path, "{}.json".format(pdf_id)) 144 | return os.path.exists(path) 145 | -------------------------------------------------------------------------------- /aries/util/training.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import json 3 | import logging 4 | import os 5 | 6 | import torch 7 | import transformers 8 | 9 | from .logging import pprint_metrics 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | 14 | class TrainLoggerCallback(transformers.TrainerCallback): 15 | def __init__(self, logger): 16 | self.logger = logger 17 | 18 | def on_log(self, args, state, control, logs=None, **kwargs): 19 | if not logs: 20 | return 21 | self.logger.info( 22 | "Logs at global step {} (epoch {}, {:0.2f} TFLOs): {}".format(state.global_step, state.epoch, state.total_flos / 1e12, json.dumps(logs)) 23 | ) 24 | 25 | 26 | class Seq2SeqEvalCallback(transformers.TrainerCallback): 27 | def __init__(self, config, model, eval_records, model_eval_fn, model_selection_metric_fn=None): 28 | self.config = config 29 | self.model = model 30 | self.eval_records = eval_records 31 | self.model_eval_fn = model_eval_fn 32 | self.eval_precached_dataset = self.model._make_dataset(self.eval_records) 33 | 34 | self.model_selection_metric_fn = model_selection_metric_fn 35 | if isinstance(model_selection_metric_fn, str): 36 | self.model_selection_metric_fn = lambda x: x[model_selection_metric_fn] 37 | 38 | self._best_metric_val = float("-inf") 39 | self._best_model_state = None 40 | 41 | @staticmethod 42 | def _clone_cpu_model_state_dict(model): 43 | return collections.OrderedDict((k, v.clone().cpu().detach()) for k, v in model.state_dict().items()) 44 | 45 | def on_evaluate(self, args, state, control, **kwargs): 46 | metrics, all_results, _ = self.model_eval_fn(self.model, self.eval_records, eval_precached_dataset=self.eval_precached_dataset) 47 | 48 | if self.config.get("write_examples_on_eval", False): 49 | with open(os.path.join(self.config["output_dir"], "{}_inferences.jsonl".format("tmp_mid_eval")), "w") as f: 50 | for res in all_results: 51 | f.write(json.dumps(res) + "\n") 52 | 53 | pprint_metrics(metrics, logger, name="dev (mid-train)") 54 | if self.model_selection_metric_fn is not None: 55 | metric_val = self.model_selection_metric_fn(metrics) 56 | if metric_val > self._best_metric_val: 57 | logger.info( 58 | "Got new best model at global step {} (epoch {}, {:0.2f} TFLOs)".format(state.global_step, state.epoch, state.total_flos / 1e12) 59 | ) 60 | state.best_metric = metric_val 61 | self._best_metric_val = metric_val 62 | self._best_model_state = Seq2SeqEvalCallback._clone_cpu_model_state_dict(self.model.model) 63 | 64 | 65 | 
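# Usage sketch (hypothetical variable names; only the callback API mirrors the class
# above): Seq2SeqEvalCallback keeps the best weights in memory rather than on disk,
# so after training one would typically restore them manually, e.g.:
#
#     callback = Seq2SeqEvalCallback(config, model, dev_recs, model_eval_fn, "optimal_f1")
#     trainer.add_callback(callback)
#     trainer.train()
#     if callback._best_model_state is not None:
#         model.model.load_state_dict(callback._best_model_state)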
class SequentialTrainer(transformers.Trainer): 66 | def _get_train_sampler(self): 67 | if self.train_dataset is None: 68 | return None 69 | return torch.utils.data.SequentialSampler(self.train_dataset) 70 | -------------------------------------------------------------------------------- /data/configs/bm25.json: -------------------------------------------------------------------------------- 1 | { 2 | "s2orc_base_path": "data/aries/s2orc/", 3 | "paper_edits_file": "data/aries/paper_edits.jsonl", 4 | "review_comments_file": "data/aries/review_comments.jsonl", 5 | "train_edit_labels_file": "data/aries/edit_labels_train.jsonl", 6 | "dev_edit_labels_file": "data/aries/edit_labels_dev.jsonl", 7 | "test_edit_labels_file": "data/aries/edit_labels_test.jsonl", 8 | "train_negative_sample_method": "other_docs", 9 | "dev_negative_sample_method": "other_docs", 10 | "test_negative_sample_method": "same_doc", 11 | "candidate_edit_type": "diffs", 12 | "candidate_min_chars": 100, 13 | "max_negatives": 10000, 14 | "prune_candidates": true, 15 | "write_examples_on_eval": true, 16 | "dev_max_negatives": 100, 17 | "dev_seed": 1, 18 | "model_pipeline": [ 19 | { 20 | "model_type": "bm25", 21 | "query_input_format": "comment_with_context", 22 | "edit_input_format": "tokens_union" 23 | } 24 | ], 25 | "output_dir": "data/experiments/bm25/", 26 | "seed": 1 27 | } 28 | 29 | -------------------------------------------------------------------------------- /data/configs/bm25_ao.json: -------------------------------------------------------------------------------- 1 | { 2 | "s2orc_base_path": "data/aries/s2orc/", 3 | "paper_edits_file": "data/aries/paper_edits.jsonl", 4 | "review_comments_file": "data/aries/review_comments.jsonl", 5 | "train_edit_labels_file": "data/aries/edit_labels_train.jsonl", 6 | "dev_edit_labels_file": "data/aries/edit_labels_dev.jsonl", 7 | "test_edit_labels_file": "data/aries/edit_labels_test.jsonl", 8 | "train_negative_sample_method": "other_docs", 9 | "dev_negative_sample_method": "other_docs", 10 | "test_negative_sample_method": "same_doc", 11 | "candidate_edit_type": "full_additions", 12 | "candidate_min_chars": 100, 13 | "max_negatives": 10000, 14 | "prune_candidates": true, 15 | "write_examples_on_eval": true, 16 | "dev_max_negatives": 100, 17 | "dev_seed": 1, 18 | "model_pipeline": [ 19 | { 20 | "model_type": "bm25", 21 | "query_input_format": "comment_with_context", 22 | "edit_input_format": "tokens_union" 23 | } 24 | ], 25 | "output_dir": "data/experiments/bm25_ao/", 26 | "seed": 1 27 | } 28 | 29 | -------------------------------------------------------------------------------- /data/configs/bm25_high_recall.json: -------------------------------------------------------------------------------- 1 | { 2 | "s2orc_base_path": "data/aries/s2orc/", 3 | "paper_edits_file": "data/aries/paper_edits.jsonl", 4 | "review_comments_file": "data/aries/review_comments.jsonl", 5 | "train_edit_labels_file": "data/aries/edit_labels_train.jsonl", 6 | "dev_edit_labels_file": "data/aries/edit_labels_dev.jsonl", 7 | "test_edit_labels_file": "data/aries/edit_labels_test.jsonl", 8 | "train_negative_sample_method": "other_docs", 9 | "dev_negative_sample_method": "other_docs", 10 | "test_negative_sample_method": "same_doc", 11 | "candidate_edit_type": "diffs", 12 | "candidate_min_chars": 100, 13 | "max_negatives": 10000, 14 | "prune_candidates": true, 15 | "write_examples_on_eval": true, 16 | "dev_max_negatives": 100, 17 | "dev_seed": 1, 18 | "model_pipeline": [ 19 | { 20 | "model_type": "bm25", 21 | 
"tune_on_dev": true, 22 | "tuning_minimum_recall": 0.9, 23 | "query_input_format": "comment_with_context", 24 | "edit_input_format": "tokens_union" 25 | } 26 | ], 27 | "output_dir": "data/experiments/bm25_high_recall/", 28 | "seed": 1 29 | } 30 | 31 | -------------------------------------------------------------------------------- /data/configs/bm25_high_recall_ao.json: -------------------------------------------------------------------------------- 1 | { 2 | "s2orc_base_path": "data/aries/s2orc/", 3 | "paper_edits_file": "data/aries/paper_edits.jsonl", 4 | "review_comments_file": "data/aries/review_comments.jsonl", 5 | "train_edit_labels_file": "data/aries/edit_labels_train.jsonl", 6 | "dev_edit_labels_file": "data/aries/edit_labels_dev.jsonl", 7 | "test_edit_labels_file": "data/aries/edit_labels_test.jsonl", 8 | "train_negative_sample_method": "other_docs", 9 | "dev_negative_sample_method": "other_docs", 10 | "test_negative_sample_method": "same_doc", 11 | "candidate_edit_type": "full_additions", 12 | "candidate_min_chars": 100, 13 | "max_negatives": 10000, 14 | "prune_candidates": true, 15 | "write_examples_on_eval": true, 16 | "dev_max_negatives": 100, 17 | "dev_seed": 1, 18 | "model_pipeline": [ 19 | { 20 | "model_type": "bm25", 21 | "tune_on_dev": true, 22 | "tuning_minimum_recall": 0.9, 23 | "query_input_format": "comment_with_context", 24 | "edit_input_format": "tokens_union" 25 | } 26 | ], 27 | "output_dir": "data/experiments/bm25_high_recall_ao/", 28 | "seed": 1 29 | } 30 | 31 | -------------------------------------------------------------------------------- /data/configs/deberta_biencoder.json: -------------------------------------------------------------------------------- 1 | { 2 | "s2orc_base_path": "data/aries/s2orc/", 3 | "paper_edits_file": "data/aries/paper_edits.jsonl", 4 | "review_comments_file": "data/aries/review_comments.jsonl", 5 | "train_edit_labels_file": "data/aries/edit_labels_train.jsonl", 6 | "dev_edit_labels_file": "data/aries/edit_labels_dev.jsonl", 7 | "test_edit_labels_file": "data/aries/edit_labels_test.jsonl", 8 | "candidate_edit_type": "diffs", 9 | "max_negatives": 20, 10 | "prune_candidates": true, 11 | "write_examples_on_eval": true, 12 | "candidate_min_chars": 100, 13 | "train_negative_sample_method": "other_docs", 14 | "dev_negative_sample_method": "other_docs", 15 | "test_negative_sample_method": "same_doc", 16 | "train_hard_negative_ratio": 0.0, 17 | "dev_max_negatives": 100, 18 | "dev_seed": 1, 19 | "model_pipeline": [ 20 | { 21 | "model_name_or_path": "microsoft/deberta-v3-large", 22 | "model_type": "biencoder", 23 | "max_seq_length": 1024, 24 | "query_input_format": "comment_with_context", 25 | "edit_input_format": "diff", 26 | "add_diff_tokens": false, 27 | "training_args": { 28 | "learning_rate": 2e-05, 29 | "per_device_train_batch_size": 2, 30 | "gradient_accumulation_steps": 8, 31 | "num_train_epochs": -1, 32 | "max_steps": 8192, 33 | "adam_beta1": 0.9, 34 | "adam_beta2": 0.999, 35 | "lr_scheduler_type": "linear", 36 | "warmup_ratio": 0.0, 37 | "warmup_steps": 256, 38 | "logging_steps": 256, 39 | "save_steps": 10000, 40 | "evaluation_strategy": "steps", 41 | "max_grad_norm": 1.0, 42 | "log_level": "passive", 43 | "log_level_replica": "passive", 44 | "full_determinism": false, 45 | "fp16": true 46 | } 47 | } 48 | ], 49 | "seed": 1, 50 | "output_dir": "data/experiments/deberta_biencoder/" 51 | } -------------------------------------------------------------------------------- /data/configs/deberta_biencoder_ao.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "s2orc_base_path": "data/aries/s2orc/", 3 | "paper_edits_file": "data/aries/paper_edits.jsonl", 4 | "review_comments_file": "data/aries/review_comments.jsonl", 5 | "train_edit_labels_file": "data/aries/edit_labels_train.jsonl", 6 | "dev_edit_labels_file": "data/aries/edit_labels_dev.jsonl", 7 | "test_edit_labels_file": "data/aries/edit_labels_test.jsonl", 8 | "candidate_edit_type": "full_additions", 9 | "max_negatives": 20, 10 | "prune_candidates": true, 11 | "write_examples_on_eval": true, 12 | "candidate_min_chars": 100, 13 | "train_negative_sample_method": "other_docs", 14 | "dev_negative_sample_method": "other_docs", 15 | "test_negative_sample_method": "same_doc", 16 | "train_hard_negative_ratio": 0.0, 17 | "dev_max_negatives": 100, 18 | "dev_seed": 1, 19 | "model_pipeline": [ 20 | { 21 | "model_name_or_path": "microsoft/deberta-v3-large", 22 | "model_type": "biencoder", 23 | "max_seq_length": 1024, 24 | "query_input_format": "comment_with_context", 25 | "edit_input_format": "diff", 26 | "add_diff_tokens": false, 27 | "training_args": { 28 | "learning_rate": 2e-05, 29 | "per_device_train_batch_size": 2, 30 | "gradient_accumulation_steps": 8, 31 | "num_train_epochs": -1, 32 | "max_steps": 8192, 33 | "adam_beta1": 0.9, 34 | "adam_beta2": 0.999, 35 | "lr_scheduler_type": "linear", 36 | "warmup_ratio": 0.0, 37 | "warmup_steps": 256, 38 | "logging_steps": 256, 39 | "save_steps": 10000, 40 | "evaluation_strategy": "steps", 41 | "max_grad_norm": 1.0, 42 | "log_level": "passive", 43 | "log_level_replica": "passive", 44 | "full_determinism": false, 45 | "fp16": true 46 | } 47 | } 48 | ], 49 | "seed": 1, 50 | "output_dir": "data/experiments/deberta_biencoder_ao/" 51 | } -------------------------------------------------------------------------------- /data/configs/deberta_cross_encoder.json: -------------------------------------------------------------------------------- 1 | { 2 | "s2orc_base_path": "data/aries/s2orc/", 3 | "paper_edits_file": "data/aries/paper_edits.jsonl", 4 | "review_comments_file": "data/aries/review_comments.jsonl", 5 | "train_edit_labels_file": "data/aries/edit_labels_train.jsonl", 6 | "dev_edit_labels_file": "data/aries/edit_labels_dev.jsonl", 7 | "test_edit_labels_file": "data/aries/edit_labels_test.jsonl", 8 | "candidate_edit_type": "diffs", 9 | "max_negatives": 20, 10 | "prune_candidates": true, 11 | "write_examples_on_eval": true, 12 | "candidate_min_chars": 100, 13 | "train_negative_sample_method": "other_docs", 14 | "dev_negative_sample_method": "other_docs", 15 | "test_negative_sample_method": "same_doc", 16 | "train_hard_negative_ratio": 0.0, 17 | "dev_max_negatives": 100, 18 | "dev_seed": 1, 19 | "model_pipeline": [ 20 | { 21 | "model_name_or_path": "microsoft/deberta-v3-large", 22 | "model_type": "cross_encoder", 23 | "max_seq_length": 1024, 24 | "query_input_format": "comment_with_context", 25 | "edit_input_format": "diff", 26 | "add_diff_tokens": false, 27 | "training_args": { 28 | "learning_rate": 2e-05, 29 | "per_device_train_batch_size": 2, 30 | "gradient_accumulation_steps": 8, 31 | "num_train_epochs": -1, 32 | "max_steps": 8192, 33 | "adam_beta1": 0.9, 34 | "adam_beta2": 0.999, 35 | "lr_scheduler_type": "linear", 36 | "warmup_ratio": 0.0, 37 | "warmup_steps": 256, 38 | "logging_steps": 256, 39 | "save_steps": 10000, 40 | "evaluation_strategy": "steps", 41 | "max_grad_norm": 1.0, 42 | "log_level": "passive", 43 | "log_level_replica": "passive", 44 | 
"full_determinism": false, 45 | "fp16": true 46 | } 47 | } 48 | ], 49 | "seed": 1, 50 | "output_dir": "data/experiments/deberta_cross_encoder/" 51 | } -------------------------------------------------------------------------------- /data/configs/deberta_cross_encoder_ao.json: -------------------------------------------------------------------------------- 1 | { 2 | "s2orc_base_path": "data/aries/s2orc/", 3 | "paper_edits_file": "data/aries/paper_edits.jsonl", 4 | "review_comments_file": "data/aries/review_comments.jsonl", 5 | "train_edit_labels_file": "data/aries/edit_labels_train.jsonl", 6 | "dev_edit_labels_file": "data/aries/edit_labels_dev.jsonl", 7 | "test_edit_labels_file": "data/aries/edit_labels_test.jsonl", 8 | "candidate_edit_type": "full_additions", 9 | "max_negatives": 20, 10 | "prune_candidates": true, 11 | "write_examples_on_eval": true, 12 | "candidate_min_chars": 100, 13 | "train_negative_sample_method": "other_docs", 14 | "dev_negative_sample_method": "other_docs", 15 | "test_negative_sample_method": "same_doc", 16 | "train_hard_negative_ratio": 0.0, 17 | "dev_max_negatives": 100, 18 | "dev_seed": 1, 19 | "model_pipeline": [ 20 | { 21 | "model_name_or_path": "microsoft/deberta-v3-large", 22 | "model_type": "cross_encoder", 23 | "max_seq_length": 1024, 24 | "query_input_format": "comment_with_context", 25 | "edit_input_format": "diff", 26 | "add_diff_tokens": false, 27 | "training_args": { 28 | "learning_rate": 2e-05, 29 | "per_device_train_batch_size": 2, 30 | "gradient_accumulation_steps": 8, 31 | "num_train_epochs": -1, 32 | "max_steps": 8192, 33 | "adam_beta1": 0.9, 34 | "adam_beta2": 0.999, 35 | "lr_scheduler_type": "linear", 36 | "warmup_ratio": 0.0, 37 | "warmup_steps": 256, 38 | "logging_steps": 256, 39 | "save_steps": 10000, 40 | "evaluation_strategy": "steps", 41 | "max_grad_norm": 1.0, 42 | "log_level": "passive", 43 | "log_level_replica": "passive", 44 | "full_determinism": false, 45 | "fp16": true 46 | } 47 | } 48 | ], 49 | "seed": 1, 50 | "output_dir": "data/experiments/deberta_cross_encoder_ao/" 51 | } -------------------------------------------------------------------------------- /data/configs/edit_generation_paper.json: -------------------------------------------------------------------------------- 1 | { 2 | "cache_db_path": "data/gpt3_cache.sqlite", 3 | "s2orc_base_path": "data/aries/s2orc/", 4 | "seed": 42, 5 | "output_dir": "data/experiments/edit_generation/", 6 | "split_ids_file": "data/aries/split_ids.json", 7 | "split_name": "test", 8 | "review_comments_file": "data/aries/review_comments.jsonl", 9 | "prompt_template": "Consider the following excerpt of a scientific paper which is under review for a conference:\n\n--- START ---\nAbstract: {__abstract}\n\nBody: {__body_chunk}\n--- END ---\n\n---\nA reviewer made the following comment about the paper: {__comment_with_context}\n\nWrite a response to the reviewer and an edit (or edits) that could be added somewhere in the paper (or Appendix) to resolve the reviewer's comment. Above an edit, write the location in the paper where it should be added. The edit should not explicitly say that it is written in response to a reviewer comment; it just needs to improve the paper such that a future reviewer would be unlikely to make the same comment. If addressing the comment requires additional experiments or information that you do not have access to, you can use placeholders or fill in reasonable guesses for that information. 
An edit may be a new sentence, paragraph, or section, depending on the comment.\n\nFor ease of parsing, write \"Response:\" before the reviewer response, \"Location:\" before the edit location(s), and \"Edit:\" before the edit(s)." 10 | } 11 | 12 | -------------------------------------------------------------------------------- /data/configs/gpt_multiedit.json: -------------------------------------------------------------------------------- 1 | { 2 | "s2orc_base_path": "data/aries/s2orc/", 3 | "paper_edits_file": "data/aries/paper_edits.jsonl", 4 | "review_comments_file": "data/aries/review_comments.jsonl", 5 | "train_edit_labels_file": "data/aries/edit_labels_dummy.jsonl", 6 | "dev_edit_labels_file": "data/aries/edit_labels_dummy.jsonl", 7 | "test_edit_labels_file": "data/aries/edit_labels_test.jsonl", 8 | "train_negative_sample_method": "other_docs", 9 | "dev_negative_sample_method": "other_docs", 10 | "test_negative_sample_method": "same_doc", 11 | "candidate_edit_type": "diffs", 12 | "candidate_min_chars": 100, 13 | "max_negatives": 1, 14 | "prune_candidates": true, 15 | "write_examples_on_eval": true, 16 | "model_pipeline": [ 17 | { 18 | "model_type": "bm25", 19 | "bm25_dictionary": "data/experiments/bm25_high_recall/dictionary.pk", 20 | "fixed_pred_threshold": 3.3, 21 | "fixed_rel_pred_threshold": 4.4, 22 | "query_input_format": "comment_with_context", 23 | "edit_input_format": "tokens_union" 24 | }, 25 | { 26 | "model_type": "gpt_full_paper", 27 | "cache_db_path": "data/gpt3_cache.sqlite", 28 | "gpt_model": "gpt-4-0314", 29 | "gpt_max_length": 8000, 30 | "gpt_system_prompt": "You are a helpful research assistant. You must determine whether a given review comment is relevant to a given paper revision.", 31 | "gpt_prompt_template": "Consider the following comments that a reviewer made about a scientific paper (each followed by a unique comment id):\n\n{{review_comments}}\n\nBelow is a partial diff of the original paper text and the paper text after the authors made revisions in response to various reviews. Changes are indicated with brackets \"[]\" with a \"+\" for additions and a \"-\" for deletions. Below each paragraph is a unique \"edit id\". Determine which edits were meant to address the given reviewer comments above.\n\n---BEGIN PAPER DIFF---\n{{paper_diff_chunk}}\n---END PAPER DIFF---\n\nWhich edit ids correspond to each of the reviewer's comments? The relationship is many-to-many; one comment could correspond to several edits, and several comment could correspond to the same edit. There could also be comments that the authors didn't address at all or edits that were not made in response to any particular comment.\n\nWrite the answer as JSON lines with the format {\"comment_id\": , \"edit_ids\": [], \"notes\": \"\"} where each record has a comment id and the list of edit ids that correspond to it. The \"notes\" field is optional and can contain any notes about edits you weren't sure about or reasons for including/omitting certain edits." 
32 | } 33 | ], 34 | "output_dir": "data/experiments/gpt_multiedit/", 35 | "seed": 1 36 | } 37 | 38 | -------------------------------------------------------------------------------- /data/configs/gpt_multiedit_ao.json: -------------------------------------------------------------------------------- 1 | { 2 | "s2orc_base_path": "data/aries/s2orc/", 3 | "paper_edits_file": "data/aries/paper_edits.jsonl", 4 | "review_comments_file": "data/aries/review_comments.jsonl", 5 | "train_edit_labels_file": "data/aries/edit_labels_dummy.jsonl", 6 | "dev_edit_labels_file": "data/aries/edit_labels_dummy.jsonl", 7 | "test_edit_labels_file": "data/aries/edit_labels_test.jsonl", 8 | "train_negative_sample_method": "other_docs", 9 | "dev_negative_sample_method": "other_docs", 10 | "test_negative_sample_method": "same_doc", 11 | "candidate_edit_type": "full_additions", 12 | "candidate_min_chars": 100, 13 | "max_negatives": 1, 14 | "prune_candidates": true, 15 | "write_examples_on_eval": true, 16 | "model_pipeline": [ 17 | { 18 | "model_type": "bm25", 19 | "bm25_dictionary": "data/experiments/bm25_high_recall_ao/dictionary.pk", 20 | "fixed_pred_threshold": 3.1, 21 | "fixed_rel_pred_threshold": 4.1, 22 | "query_input_format": "comment_with_context", 23 | "edit_input_format": "tokens_union" 24 | }, 25 | { 26 | "model_type": "gpt_full_paper", 27 | "cache_db_path": "data/gpt3_cache.sqlite", 28 | "gpt_model": "gpt-4-0314", 29 | "gpt_max_length": 8000, 30 | "gpt_system_prompt": "You are a helpful research assistant. You must determine whether a given review comment is relevant to a given paper revision.", 31 | "gpt_prompt_template": "Consider the following comments that a reviewer made about a scientific paper (each followed by a unique comment id):\n\n{{review_comments}}\n\nBelow is a partial diff of the original paper text and the paper text after the authors made revisions in response to various reviews. Changes are indicated with brackets \"[]\" with a \"+\" for additions and a \"-\" for deletions. Below each paragraph is a unique \"edit id\". Determine which edits were meant to address the given reviewer comments above.\n\n---BEGIN PAPER DIFF---\n{{paper_diff_chunk}}\n---END PAPER DIFF---\n\nWhich edit ids correspond to each of the reviewer's comments? The relationship is many-to-many; one comment could correspond to several edits, and several comment could correspond to the same edit. There could also be comments that the authors didn't address at all or edits that were not made in response to any particular comment.\n\nWrite the answer as JSON lines with the format {\"comment_id\": , \"edit_ids\": [], \"notes\": \"\"} where each record has a comment id and the list of edit ids that correspond to it. The \"notes\" field is optional and can contain any notes about edits you weren't sure about or reasons for including/omitting certain edits." 
32 | } 33 | ], 34 | "output_dir": "data/experiments/gpt_multiedit_ao/", 35 | "seed": 1 36 | } 37 | 38 | -------------------------------------------------------------------------------- /data/configs/gpt_pairwise_0shot_ao.json: -------------------------------------------------------------------------------- 1 | { 2 | "s2orc_base_path": "data/aries/s2orc/", 3 | "paper_edits_file": "data/aries/paper_edits.jsonl", 4 | "review_comments_file": "data/aries/review_comments.jsonl", 5 | "train_edit_labels_file": "data/aries/edit_labels_dummy.jsonl", 6 | "dev_edit_labels_file": "data/aries/edit_labels_dummy.jsonl", 7 | "test_edit_labels_file": "data/aries/edit_labels_test.jsonl", 8 | "train_negative_sample_method": "other_docs", 9 | "dev_negative_sample_method": "other_docs", 10 | "test_negative_sample_method": "same_doc", 11 | "candidate_edit_type": "full_additions", 12 | "candidate_min_chars": 100, 13 | "max_negatives": 1, 14 | "prune_candidates": true, 15 | "write_examples_on_eval": true, 16 | "model_pipeline": [ 17 | { 18 | "model_type": "bm25", 19 | "bm25_dictionary": "data/experiments/bm25_high_recall_ao/dictionary.pk", 20 | "fixed_pred_threshold": 3.1, 21 | "fixed_rel_pred_threshold": 4.1, 22 | "query_input_format": "comment_with_context", 23 | "edit_input_format": "tokens_union" 24 | }, 25 | { 26 | "model_type": "gpt", 27 | "cache_db_path": "data/gpt3_cache.sqlite", 28 | "gpt_model": "gpt-4-0314", 29 | "gpt_max_length": 512, 30 | "gpt_system_prompt": "You are a helpful research assistant. You must determine whether a given review comment is relevant to a given paper revision.", 31 | "gpt_prompt_template": "Consider the following review comment for a scientific paper: {{review_comment}}\n\nConsider the following paragraph, which was added to the paper after the review: {{target_paragraph}}\n Is the new paragraph likely to have been added for the purpose of addressing this review comment? Answer with \"yes\" or \"no\"." 32 | } 33 | ], 34 | "output_dir": "data/experiments/gpt_pairwise_0shot_ao/", 35 | "seed": 1 36 | } 37 | 38 | -------------------------------------------------------------------------------- /data/configs/gpt_pairwise_1shot_ao.json: -------------------------------------------------------------------------------- 1 | { 2 | "s2orc_base_path": "data/aries/s2orc/", 3 | "paper_edits_file": "data/aries/paper_edits.jsonl", 4 | "review_comments_file": "data/aries/review_comments.jsonl", 5 | "train_edit_labels_file": "data/aries/edit_labels_dummy.jsonl", 6 | "dev_edit_labels_file": "data/aries/edit_labels_dummy.jsonl", 7 | "test_edit_labels_file": "data/aries/edit_labels_test.jsonl", 8 | "train_negative_sample_method": "other_docs", 9 | "dev_negative_sample_method": "other_docs", 10 | "test_negative_sample_method": "same_doc", 11 | "candidate_edit_type": "full_additions", 12 | "candidate_min_chars": 100, 13 | "max_negatives": 1, 14 | "prune_candidates": true, 15 | "write_examples_on_eval": true, 16 | "model_pipeline": [ 17 | { 18 | "model_type": "bm25", 19 | "bm25_dictionary": "data/experiments/bm25_high_recall_ao/dictionary.pk", 20 | "fixed_pred_threshold": 3.1, 21 | "fixed_rel_pred_threshold": 4.1, 22 | "query_input_format": "comment_with_context", 23 | "edit_input_format": "tokens_union" 24 | }, 25 | { 26 | "model_type": "gpt", 27 | "cache_db_path": "data/gpt3_cache.sqlite", 28 | "gpt_model": "gpt-4-0314", 29 | "gpt_max_length": 512, 30 | "gpt_system_prompt": "You are a helpful research assistant. 
You must determine whether a given review comment is relevant to a given paper revision.", 31 | "gpt_prompt_template": "You need to determine which edits correspond to a given reviewer comment for a scientific paper. Given a comment and a paper edit (where changes are enclosed by brackets with +/- to indicate additions/deletions), you must determine whether the edit was likely added to the paper to address the comment. Here are some examples:\n\ncomment: Relatedly, the conclusion mentions \"... random freeze can reduce the computational cost and memory footprint of the power method in GEP \" but this is not explored in much detail in the results.\n\nedit: [+We also observe the potential benefit of random freeze from the perspective of the gradient distribution. Compared to no freeze, random freeze leads to less clipping bias and gradient distortion, as shown in Figure 2 . We adopt the same clipping bound for random freeze and no freeze. As fewer dimensions contribute to the norm computation, random freeze reduces the clipping probability and therefore alleviates clipping bias (Zhang et al., 2020) . We also note that the norm of sparse gradients are not equally scaled down, weak gradients can spontaneously become larger during training, which mitigates the distortion of gradients due to perturbation, while the perturbation of random freeze is already moderated compared to no freeze. With the random freeze strategy, in later epochs the variance of the norm magnitude decreases. A lower number of high-magnitude gradient norms implies less clipping bias, while the decrease in low magnitude gradient norms implies a higher signal-to-noise ratio of the perturbed gradients. The two plots overlap in the subfigure corresponding to the first epoch as the freeze rate is 0 and the networks are initialized equally. The freeze rate at the 20th epoch is 0.45 and reaches 0.9 at the 40th epoch. Note that both axes are in log scale.+]\n\nDoes the edit address the comment (yes/no)?\nA: No\n\ncomment: Relatedly, the conclusion mentions \"... random freeze can reduce the computational cost and memory footprint of the power method in GEP \" but this is not explored in much detail in the results.\n\nedit: [+Advantages of sparsity Projected DP-SGD induces a significant additional computation cost by running the power method and projecting gradients into and out of subspace. For the power method, the basic operation is W W V , W \u2208 R d\u00d7s denotes sample gradients, V \u2208 R d\u00d7b denotes eigenvectors, the computational cost is O(dbs). Similarly, for projection V X; X \u2208 R d\u00d7n denotes original gradients, the computational cost is O(dbn). Applying random freeze, a random selection of rows of X are deleted, while corresponding rows of V, W can be removed as no information of gradient exits in that subspace. We note that b, s might also be able to be reduced. Overall, the computational cost is between O(1 \u2212 r) and O((1 \u2212 r) 3 ). Another issue of projected DP-SGD is the memory footprint of V . Saving sparse V by random freeze can be achieved by storing non-zero values and indices of zeros. The cost of indexing is logarithmic of the number of parameters, consider that log 2 10 9 < 32, we can decrease the memory footprint by removing a single 32 bit float gradient. Communication overhead can similarly be reduced. We note that random freeze uses the same mask during one training epoch, which could contain multiple groups of eigenvectors and communication rounds. 
Therefore, the cost of indexing is negligible: communication overhead and memory footprint are \u00d5(1 \u2212 r). Further, we define the total density as the total amount of non-zero gradients by random freeze over the total amount of gradients by the original dense representation to reflect these advantages of sparsity.+]\n\nDoes the edit address the comment (yes/no)?\nA: Yes\n\nNow give the answer for the following example:\n\ncomment: {{review_comment}}\n\nedit: {{diff_paragraph}}\n\nDoes the edit address the comment (yes/no)?" 32 | } 33 | ], 34 | "output_dir": "data/experiments/gpt_pairwise_1shot_ao/", 35 | "seed": 1 36 | } 37 | 38 | -------------------------------------------------------------------------------- /data/configs/human.json: -------------------------------------------------------------------------------- 1 | { 2 | "s2orc_base_path": "data/aries/s2orc/", 3 | "paper_edits_file": "data/aries/paper_edits.jsonl", 4 | "review_comments_file": "data/aries/review_comments.jsonl", 5 | "train_edit_labels_file": "data/aries/edit_labels_dummy.jsonl", 6 | "dev_edit_labels_file": "data/aries/edit_labels_dummy.jsonl", 7 | "test_edit_labels_file": "data/aries/edit_labels_test.jsonl", 8 | "train_negative_sample_method": "other_docs", 9 | "dev_negative_sample_method": "other_docs", 10 | "test_negative_sample_method": "same_doc", 11 | "candidate_edit_type": "diffs", 12 | "candidate_min_chars": 100, 13 | "max_negatives": 10000, 14 | "prune_candidates": true, 15 | "write_examples_on_eval": true, 16 | "model_pipeline": [ 17 | { 18 | "model_type": "precomputed", 19 | "precomputed_predictions_jsonl_path": "data/aries/alignment_human_eval.jsonl" 20 | } 21 | ], 22 | "output_dir": "data/experiments/human/", 23 | "seed": 1 24 | } 25 | 26 | -------------------------------------------------------------------------------- /data/configs/human_ao.json: -------------------------------------------------------------------------------- 1 | { 2 | "s2orc_base_path": "data/aries/s2orc/", 3 | "paper_edits_file": "data/aries/paper_edits.jsonl", 4 | "review_comments_file": "data/aries/review_comments.jsonl", 5 | "train_edit_labels_file": "data/aries/edit_labels_dummy.jsonl", 6 | "dev_edit_labels_file": "data/aries/edit_labels_dummy.jsonl", 7 | "test_edit_labels_file": "data/aries/edit_labels_test.jsonl", 8 | "train_negative_sample_method": "other_docs", 9 | "dev_negative_sample_method": "other_docs", 10 | "test_negative_sample_method": "same_doc", 11 | "candidate_edit_type": "full_additions", 12 | "candidate_min_chars": 100, 13 | "max_negatives": 10000, 14 | "prune_candidates": true, 15 | "write_examples_on_eval": true, 16 | "model_pipeline": [ 17 | { 18 | "model_type": "precomputed", 19 | "precomputed_predictions_jsonl_path": "data/aries/alignment_human_eval.jsonl" 20 | } 21 | ], 22 | "output_dir": "data/experiments/human_ao/", 23 | "seed": 1 24 | } 25 | 26 | -------------------------------------------------------------------------------- /data/configs/linkbert_cross_encoder.json: -------------------------------------------------------------------------------- 1 | { 2 | "s2orc_base_path": "data/aries/s2orc/", 3 | "paper_edits_file": "data/aries/paper_edits.jsonl", 4 | "review_comments_file": "data/aries/review_comments.jsonl", 5 | "train_edit_labels_file": "data/aries/edit_labels_train.jsonl", 6 | "dev_edit_labels_file": "data/aries/edit_labels_dev.jsonl", 7 | "test_edit_labels_file": "data/aries/edit_labels_test.jsonl", 8 | "candidate_edit_type": "diffs", 9 | "max_negatives": 20, 10 | "prune_candidates": true, 
11 | "write_examples_on_eval": true, 12 | "candidate_min_chars": 100, 13 | "train_negative_sample_method": "other_docs", 14 | "dev_negative_sample_method": "other_docs", 15 | "test_negative_sample_method": "same_doc", 16 | "train_hard_negative_ratio": 0.0, 17 | "dev_max_negatives": 100, 18 | "dev_seed": 1, 19 | "model_pipeline": [ 20 | { 21 | "model_name_or_path": "michiyasunaga/LinkBERT-large", 22 | "model_type": "cross_encoder", 23 | "max_seq_length": 512, 24 | "query_input_format": "comment_with_context", 25 | "edit_input_format": "diff", 26 | "add_diff_tokens": false, 27 | "training_args": { 28 | "learning_rate": 2e-05, 29 | "per_device_train_batch_size": 2, 30 | "gradient_accumulation_steps": 8, 31 | "num_train_epochs": -1, 32 | "max_steps": 8192, 33 | "adam_beta1": 0.9, 34 | "adam_beta2": 0.999, 35 | "lr_scheduler_type": "linear", 36 | "warmup_ratio": 0.0, 37 | "warmup_steps": 256, 38 | "logging_steps": 256, 39 | "save_steps": 10000, 40 | "evaluation_strategy": "steps", 41 | "max_grad_norm": 1.0, 42 | "log_level": "passive", 43 | "log_level_replica": "passive", 44 | "full_determinism": false, 45 | "fp16": true 46 | } 47 | } 48 | ], 49 | "seed": 1, 50 | "output_dir": "data/experiments/linkbert_cross_encoder/" 51 | } -------------------------------------------------------------------------------- /data/configs/linkbert_cross_encoder_ao.json: -------------------------------------------------------------------------------- 1 | { 2 | "s2orc_base_path": "data/aries/s2orc/", 3 | "paper_edits_file": "data/aries/paper_edits.jsonl", 4 | "review_comments_file": "data/aries/review_comments.jsonl", 5 | "train_edit_labels_file": "data/aries/edit_labels_train.jsonl", 6 | "dev_edit_labels_file": "data/aries/edit_labels_dev.jsonl", 7 | "test_edit_labels_file": "data/aries/edit_labels_test.jsonl", 8 | "candidate_edit_type": "full_additions", 9 | "max_negatives": 20, 10 | "prune_candidates": true, 11 | "write_examples_on_eval": true, 12 | "candidate_min_chars": 100, 13 | "train_negative_sample_method": "other_docs", 14 | "dev_negative_sample_method": "other_docs", 15 | "test_negative_sample_method": "same_doc", 16 | "train_hard_negative_ratio": 0.0, 17 | "dev_max_negatives": 100, 18 | "dev_seed": 1, 19 | "model_pipeline": [ 20 | { 21 | "model_name_or_path": "michiyasunaga/LinkBERT-large", 22 | "model_type": "cross_encoder", 23 | "max_seq_length": 512, 24 | "query_input_format": "comment_with_context", 25 | "edit_input_format": "diff", 26 | "add_diff_tokens": false, 27 | "training_args": { 28 | "learning_rate": 2e-05, 29 | "per_device_train_batch_size": 2, 30 | "gradient_accumulation_steps": 8, 31 | "num_train_epochs": -1, 32 | "max_steps": 8192, 33 | "adam_beta1": 0.9, 34 | "adam_beta2": 0.999, 35 | "lr_scheduler_type": "linear", 36 | "warmup_ratio": 0.0, 37 | "warmup_steps": 256, 38 | "logging_steps": 256, 39 | "save_steps": 10000, 40 | "evaluation_strategy": "steps", 41 | "max_grad_norm": 1.0, 42 | "log_level": "passive", 43 | "log_level_replica": "passive", 44 | "full_determinism": false, 45 | "fp16": true 46 | } 47 | } 48 | ], 49 | "seed": 1, 50 | "output_dir": "data/experiments/linkbert_cross_encoder_ao/" 51 | } -------------------------------------------------------------------------------- /data/configs/specter2_biencoder.json: -------------------------------------------------------------------------------- 1 | { 2 | "s2orc_base_path": "data/aries/s2orc/", 3 | "paper_edits_file": "data/aries/paper_edits.jsonl", 4 | "review_comments_file": "data/aries/review_comments.jsonl", 5 | 
"train_edit_labels_file": "data/aries/edit_labels_train.jsonl", 6 | "dev_edit_labels_file": "data/aries/edit_labels_dev.jsonl", 7 | "test_edit_labels_file": "data/aries/edit_labels_test.jsonl", 8 | "candidate_edit_type": "diffs", 9 | "max_negatives": 20, 10 | "prune_candidates": true, 11 | "write_examples_on_eval": true, 12 | "candidate_min_chars": 100, 13 | "train_negative_sample_method": "other_docs", 14 | "dev_negative_sample_method": "other_docs", 15 | "test_negative_sample_method": "same_doc", 16 | "train_hard_negative_ratio": 0.0, 17 | "dev_max_negatives": 100, 18 | "dev_seed": 1, 19 | "model_pipeline": [ 20 | { 21 | "model_name_or_path": "allenai/specter2", 22 | "model_type": "biencoder", 23 | "max_seq_length": 512, 24 | "query_input_format": "comment_with_context", 25 | "edit_input_format": "diff", 26 | "add_diff_tokens": false, 27 | "training_args": { 28 | "learning_rate": 2e-05, 29 | "per_device_train_batch_size": 16, 30 | "gradient_accumulation_steps": 1, 31 | "num_train_epochs": -1, 32 | "max_steps": 8192, 33 | "adam_beta1": 0.9, 34 | "adam_beta2": 0.999, 35 | "lr_scheduler_type": "linear", 36 | "warmup_ratio": 0.0, 37 | "warmup_steps": 256, 38 | "logging_steps": 256, 39 | "save_steps": 10000, 40 | "evaluation_strategy": "steps", 41 | "max_grad_norm": 1.0, 42 | "log_level": "passive", 43 | "log_level_replica": "passive", 44 | "full_determinism": false, 45 | "fp16": true 46 | } 47 | } 48 | ], 49 | "seed": 1, 50 | "output_dir": "data/experiments/specter2_biencoder/" 51 | } -------------------------------------------------------------------------------- /data/configs/specter2_biencoder_ao.json: -------------------------------------------------------------------------------- 1 | { 2 | "s2orc_base_path": "data/aries/s2orc/", 3 | "paper_edits_file": "data/aries/paper_edits.jsonl", 4 | "review_comments_file": "data/aries/review_comments.jsonl", 5 | "train_edit_labels_file": "data/aries/edit_labels_train.jsonl", 6 | "dev_edit_labels_file": "data/aries/edit_labels_dev.jsonl", 7 | "test_edit_labels_file": "data/aries/edit_labels_test.jsonl", 8 | "candidate_edit_type": "full_additions", 9 | "max_negatives": 20, 10 | "prune_candidates": true, 11 | "write_examples_on_eval": true, 12 | "candidate_min_chars": 100, 13 | "train_negative_sample_method": "other_docs", 14 | "dev_negative_sample_method": "other_docs", 15 | "test_negative_sample_method": "same_doc", 16 | "train_hard_negative_ratio": 0.0, 17 | "dev_max_negatives": 100, 18 | "dev_seed": 1, 19 | "model_pipeline": [ 20 | { 21 | "model_name_or_path": "allenai/specter2", 22 | "model_type": "biencoder", 23 | "max_seq_length": 512, 24 | "query_input_format": "comment_with_context", 25 | "edit_input_format": "diff", 26 | "add_diff_tokens": false, 27 | "training_args": { 28 | "learning_rate": 2e-05, 29 | "per_device_train_batch_size": 16, 30 | "gradient_accumulation_steps": 1, 31 | "num_train_epochs": -1, 32 | "max_steps": 8192, 33 | "adam_beta1": 0.9, 34 | "adam_beta2": 0.999, 35 | "lr_scheduler_type": "linear", 36 | "warmup_ratio": 0.0, 37 | "warmup_steps": 256, 38 | "logging_steps": 256, 39 | "save_steps": 10000, 40 | "evaluation_strategy": "steps", 41 | "max_grad_norm": 1.0, 42 | "log_level": "passive", 43 | "log_level_replica": "passive", 44 | "full_determinism": false, 45 | "fp16": true 46 | } 47 | } 48 | ], 49 | "seed": 1, 50 | "output_dir": "data/experiments/specter2_biencoder_ao/" 51 | } -------------------------------------------------------------------------------- /data/configs/specter2_untrained.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "s2orc_base_path": "data/aries/s2orc/", 3 | "paper_edits_file": "data/aries/paper_edits.jsonl", 4 | "review_comments_file": "data/aries/review_comments.jsonl", 5 | "train_edit_labels_file": "data/aries/edit_labels_train.jsonl", 6 | "dev_edit_labels_file": "data/aries/edit_labels_dev.jsonl", 7 | "test_edit_labels_file": "data/aries/edit_labels_test.jsonl", 8 | "candidate_edit_type": "diffs", 9 | "max_negatives": 20, 10 | "prune_candidates": true, 11 | "write_examples_on_eval": true, 12 | "candidate_min_chars": 100, 13 | "train_negative_sample_method": "other_docs", 14 | "dev_negative_sample_method": "other_docs", 15 | "test_negative_sample_method": "same_doc", 16 | "train_hard_negative_ratio": 0.0, 17 | "dev_max_negatives": 100, 18 | "dev_seed": 1, 19 | "model_pipeline": [ 20 | { 21 | "model_name_or_path": "allenai/specter2", 22 | "model_type": "biencoder", 23 | "max_seq_length": 512, 24 | "query_input_format": "comment_with_context", 25 | "edit_input_format": "diff", 26 | "add_diff_tokens": false, 27 | "training_args": { 28 | "learning_rate": 2e-05, 29 | "per_device_train_batch_size": 16, 30 | "gradient_accumulation_steps": 1, 31 | "num_train_epochs": -1, 32 | "max_steps": 0, 33 | "adam_beta1": 0.9, 34 | "adam_beta2": 0.999, 35 | "lr_scheduler_type": "linear", 36 | "warmup_ratio": 0.0, 37 | "warmup_steps": 256, 38 | "logging_steps": 256, 39 | "save_steps": 10000, 40 | "evaluation_strategy": "steps", 41 | "max_grad_norm": 1.0, 42 | "log_level": "passive", 43 | "log_level_replica": "passive", 44 | "full_determinism": false, 45 | "fp16": true 46 | } 47 | } 48 | ], 49 | "seed": 1, 50 | "output_dir": "data/experiments/specter2_untrained/" 51 | } -------------------------------------------------------------------------------- /data/configs/specter2_untrained_ao.json: -------------------------------------------------------------------------------- 1 | { 2 | "s2orc_base_path": "data/aries/s2orc/", 3 | "paper_edits_file": "data/aries/paper_edits.jsonl", 4 | "review_comments_file": "data/aries/review_comments.jsonl", 5 | "train_edit_labels_file": "data/aries/edit_labels_train.jsonl", 6 | "dev_edit_labels_file": "data/aries/edit_labels_dev.jsonl", 7 | "test_edit_labels_file": "data/aries/edit_labels_test.jsonl", 8 | "candidate_edit_type": "full_additions", 9 | "max_negatives": 20, 10 | "prune_candidates": true, 11 | "write_examples_on_eval": true, 12 | "candidate_min_chars": 100, 13 | "train_negative_sample_method": "other_docs", 14 | "dev_negative_sample_method": "other_docs", 15 | "test_negative_sample_method": "same_doc", 16 | "train_hard_negative_ratio": 0.0, 17 | "dev_max_negatives": 100, 18 | "dev_seed": 1, 19 | "model_pipeline": [ 20 | { 21 | "model_name_or_path": "allenai/specter2", 22 | "model_type": "biencoder", 23 | "max_seq_length": 512, 24 | "query_input_format": "comment_with_context", 25 | "edit_input_format": "diff", 26 | "add_diff_tokens": false, 27 | "training_args": { 28 | "learning_rate": 2e-05, 29 | "per_device_train_batch_size": 16, 30 | "gradient_accumulation_steps": 1, 31 | "num_train_epochs": -1, 32 | "max_steps": 0, 33 | "adam_beta1": 0.9, 34 | "adam_beta2": 0.999, 35 | "lr_scheduler_type": "linear", 36 | "warmup_ratio": 0.0, 37 | "warmup_steps": 256, 38 | "logging_steps": 256, 39 | "save_steps": 10000, 40 | "evaluation_strategy": "steps", 41 | "max_grad_norm": 1.0, 42 | "log_level": "passive", 43 | "log_level_replica": "passive", 44 | "full_determinism": false, 45 | "fp16": true 46 | 
} 47 | } 48 | ], 49 | "seed": 1, 50 | "output_dir": "data/experiments/specter2_untrained_ao/" 51 | } -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.10.2 --extra-index-url https://download.pytorch.org/whl/cu113 2 | numpy==1.21 3 | gensim==4.3 4 | scipy==1.7 5 | scikit-learn==1.0 6 | matplotlib==3.4 7 | seaborn==0.11 8 | tokenizers==0.12.1 9 | transformers==4.21 10 | tqdm 11 | pytest 12 | pylint 13 | 14 | nltk 15 | 16 | openai>=0.27 17 | 18 | requests 19 | openreview-py==1.0.23 20 | jupyterlab 21 | jupyterlab-vim 22 | datasets 23 | dill<0.3.5 24 | sacrebleu 25 | protobuf==3.20 26 | openpyxl 27 | -------------------------------------------------------------------------------- /scripts/generate_edits.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import os 4 | import sys 5 | 6 | import tqdm 7 | 8 | import aries.util.data 9 | import aries.util.s2orc 10 | from aries.util.data import index_by, iter_jsonl_files 11 | from aries.util.gpt3 import Gpt3CacheClient 12 | from aries.util.logging import init_logging 13 | 14 | logger = logging.getLogger(__name__) 15 | 16 | 17 | def generate_edits_for_doc_comment( 18 | doc_s2orc, 19 | comment_record, 20 | prompt_template, 21 | gptcli, 22 | ): 23 | prompt = make_gpt_prompt(doc_s2orc, comment_record, prompt_template, gptcli) 24 | messages = [ 25 | { 26 | "role": "system", 27 | "content": "You are SciGPT, a research assistant that specializes in helping authors to improve their scientific papers. Follow the user's instructions carefully.", 28 | }, 29 | { 30 | "role": "user", 31 | "content": prompt, 32 | }, 33 | ] 34 | result = { 35 | "source_pdf_id": doc_s2orc["paper_id"], 36 | "comment_record": comment_record, 37 | "openreview_base_pdf": "https://openreview.net/references/pdf?id={}".format(doc_s2orc["paper_id"]), 38 | "gpt_edit": None, 39 | } 40 | try: 41 | response = gptcli.chat_completion( 42 | model="gpt-4-0314", 43 | messages=messages, 44 | temperature=0, 45 | max_tokens=1024, 46 | top_p=1, 47 | frequency_penalty=0, 48 | presence_penalty=0, 49 | ) 50 | result_text = response["choices"][0]["message"]["content"] 51 | except Exception: 52 | logging.exception("Error generating edit for doc_id={}".format(doc_s2orc["paper_id"])) 53 | return result, None 54 | parsed_result = parse_result_text(result_text) 55 | 56 | result["gpt_edit"] = result_text[result_text.find("Location:") :] 57 | 58 | return result, response 59 | 60 | 61 | def make_gpt_prompt(doc_s2orc, comment_record, template, gptcli): 62 | abstract = doc_s2orc["pdf_parse"]["abstract"][0]["text"] 63 | body_text_blob = "" 64 | prev_section = "unknown" 65 | for idx, x in enumerate(doc_s2orc["pdf_parse"]["body_text"]): 66 | secheader = "" 67 | secname = x["section"] if x["section"] else "unknown" 68 | if secname != prev_section: 69 | secheader = "section: {}\n".format(secname) 70 | prev_section = secname 71 | newtext = "{}{}\nparagraph id: {}\n\n".format(secheader, x["text"], idx) 72 | if gptcli.estimate_num_tokens(body_text_blob + newtext, model="gpt-4") < 6 * (2**10): 73 | body_text_blob += newtext 74 | 75 | comment_with_context = comment_record["comment"].strip() 76 | if comment_record["comment_context"] != "": 77 | comment_with_context += "\ncontext: {}".format(comment_record["comment_context"]) 78 | 79 | variables = { 80 | "__abstract": abstract, 81 | #'__comment': 
comment_record['comment'], 82 | "__comment_with_context": comment_with_context, 83 | "__body_chunk": body_text_blob, 84 | #'__full_review': None, 85 | } 86 | s = template.format(**variables) 87 | return s 88 | 89 | 90 | def parse_result_text(result_text): 91 | result = {"response": "", "edits": []} 92 | lines = result_text.split("\n") 93 | i = 0 94 | while i < len(lines): 95 | line = lines[i].strip() 96 | if line.startswith("Response:"): 97 | if result["response"] != "": 98 | raise ValueError("Multiple 'Response' tags") 99 | result["response"] = line[9:].strip() 100 | i += 1 101 | elif line.startswith("Location:"): 102 | location = line[9:].strip() 103 | i += 1 104 | while i < len(lines) and not lines[i].strip().startswith("Edit:"): 105 | location += "\n" + lines[i].strip() 106 | i += 1 107 | if i < len(lines) and lines[i].strip().startswith("Edit:"): 108 | edit = lines[i][5:].strip() 109 | i += 1 110 | while i < len(lines) and not lines[i].strip().startswith("Location:"): 111 | edit += "\n" + lines[i].strip() 112 | i += 1 113 | result["edits"].append({"location": location.strip(), "edit": edit.strip()}) 114 | else: 115 | i += 1 116 | return result 117 | 118 | 119 | def augment_config(config): 120 | config_defaults = { 121 | "seed": 42, 122 | } 123 | for k, v in config_defaults.items(): 124 | config[k] = config.get(k, v) 125 | 126 | NEEDED_KEYS = ["s2orc_base_path", "cache_db_path", "output_dir", "split_ids_file", "split_name", "review_comments_file"] 127 | missing_keys = [x for x in NEEDED_KEYS if x not in config] 128 | if len(missing_keys) > 0: 129 | raise ValueError("Missing config keys: %s" % missing_keys) 130 | 131 | 132 | def main(): 133 | with open(sys.argv[1]) as f: 134 | config = json.load(f) 135 | 136 | augment_config(config) 137 | 138 | os.makedirs(config["output_dir"], exist_ok=True) 139 | 140 | init_logging( 141 | logfile=os.path.join(config["output_dir"], "logging_output.log"), 142 | level=logging.INFO, 143 | ) 144 | 145 | with open(config["split_ids_file"]) as f: 146 | pdf_pair_ids = json.load(f)[config["split_name"]] 147 | 148 | pair_ids_by_doc = index_by(pdf_pair_ids, "doc_id", one_to_one=True) 149 | 150 | review_comments = [x for x in iter_jsonl_files([config["review_comments_file"]]) if x["doc_id"] in pair_ids_by_doc] 151 | review_comments_by_docid = index_by(review_comments, "doc_id") 152 | 153 | all_outputs = [] 154 | tt = 0 155 | utt = 0 156 | 157 | with aries.util.s2orc.S2orcFetcherSqlite( 158 | config.get("s2orc_db_path", ":memory:"), 159 | fallback_fetcher=aries.util.s2orc.S2orcFetcherFilesystem(config["s2orc_base_path"]), 160 | update_db=False, 161 | ) as fetcher: 162 | with Gpt3CacheClient(config["cache_db_path"]) as gptcli: 163 | with tqdm.trange(sum(map(len, review_comments_by_docid.values()))) as pbar: 164 | for doc_id, comment_records in review_comments_by_docid.items(): 165 | doc_s2orc = fetcher.get(pair_ids_by_doc[doc_id]["source_pdf_id"]) 166 | for idx, comment_record in enumerate(comment_records): 167 | record, response = generate_edits_for_doc_comment(doc_s2orc, comment_record, config["prompt_template"], gptcli) 168 | all_outputs.append(record) 169 | if response is None: 170 | raise ValueError("GPT returned no response") 171 | tt += response["usage"]["total_tokens"] 172 | utt += response["usage"]["uncached_total_tokens"] 173 | pbar.set_description("tt={} utt={}, doc={}".format(tt, utt, doc_s2orc["paper_id"])) 174 | pbar.update(1) 175 | 176 | with open(os.path.join(config["output_dir"], "edits.jsonl"), "w") as f: 177 | for record in all_outputs: 178 | 
f.write(json.dumps(record) + "\n") 179 | 180 | 181 | if __name__ == "__main__": 182 | main() 183 | -------------------------------------------------------------------------------- /scripts/generate_synthetic_data.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import itertools 3 | import json 4 | import logging 5 | import os 6 | import re 7 | import sys 8 | import time 9 | 10 | import nltk 11 | import nltk.util 12 | import numpy as np 13 | import tqdm 14 | from nltk.util import ngrams 15 | 16 | import aries.util.data 17 | import aries.util.s2orc 18 | from aries.alignment.doc_edits import make_full_aligns 19 | from aries.util.data import deduplicate_by, index_by, iter_jsonl_files 20 | from aries.util.edit import find_overlapping_substrings 21 | from aries.util.logging import init_logging 22 | 23 | logger = logging.getLogger(__name__) 24 | 25 | STOPWORDS = set(nltk.corpus.stopwords.words("english")) | set(",./<>?!@#$%^&*()_+-={}|[]\\,") 26 | 27 | 28 | def get_author_replies(review_note, forum_notes): 29 | replies = [x for x in forum_notes if x["replyto"] == review_note["id"] and any("author" in xx.lower() for xx in x["writers"])] 30 | # Sometimes authors break up their response by replying to their own message 31 | nested_replies = [] 32 | for reply in replies: 33 | nested_replies.extend(get_author_replies(reply, forum_notes)) 34 | return replies + nested_replies 35 | 36 | 37 | def _get_combined_text2_overlap_spans_overlapping_span(all_overlaps, span, is_sorted=False): 38 | """Given a set of overlap span pairs, find all of the ones for which the 39 | text2 span overlaps the given span, and return a list of just the text2 40 | spans of those overlaps, merged such that any partially-overlapping spans 41 | are collapsed into a single span in the list.""" 42 | if not is_sorted: 43 | # Sort by text2 idxs; this allows fast lookups for overlaps contained within a span 44 | all_overlaps = sorted(all_overlaps, key=lambda x: x[1]) 45 | 46 | overlaps = [] 47 | for ov in all_overlaps: 48 | if ov[1][0] >= span[0] and ov[1][0] < span[1]: 49 | overlaps.append(ov) 50 | elif ov[1][1] > span[0] and ov[1][1] <= span[1]: 51 | overlaps.append(ov) 52 | 53 | if len(overlaps) <= 1: 54 | return [x[1] for x in overlaps] 55 | 56 | combined = [] 57 | last_span = overlaps[0][1] 58 | for ov in overlaps[1:]: 59 | if ov[1][0] < last_span[1]: 60 | last_span = (last_span[0], ov[1][1]) 61 | else: 62 | combined.append(last_span) 63 | last_span = ov[1] 64 | combined.append(last_span) 65 | return combined 66 | 67 | 68 | TextReplyMatch = collections.namedtuple("TextReplyMatch", "doc_id review_id reply_id match_text spans line_span next_line_span reply_text") 69 | 70 | 71 | def get_tight_span(toks, line_overlaps, prevnewline, nextnewline): 72 | # There are two main ways authors mess with the text at this 73 | # point: (1) they correct typos, causing non-exact matches, and 74 | # (2) they add prefixes or quotation marks to the text (e.g., 75 | # "Comment 1:"). To deal with the first case, we want to 76 | # include the whole line even if there are some non-matched 77 | # spans in the middle. But, to deal with the second case we 78 | # want to omit the start or end of the line if those aren't 79 | # matched and they don't occur in the middle of a word. 
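# Illustrative example (made up, not from the data): if the review said
# "The ablation in Tabel 2 is unclear" and the author reply quotes it as
# 'Comment 1: "The ablation in Table 2 is unclear"', we want to keep the whole
# quoted sentence despite the typo mismatch, but drop the "Comment 1:" prefix.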
80 | tight_span = [max(min(x[0] for x in line_overlaps), prevnewline), min(max(x[1] for x in line_overlaps), nextnewline)] 81 | while toks[tight_span[0]].isspace() or toks[tight_span[0]] in ".:)*": 82 | tight_span[0] += 1 83 | 84 | # Expand back if the span started mid-word (usually a capitalization difference) 85 | while toks[tight_span[0] - 1].isalpha(): 86 | tight_span[0] -= 1 87 | 88 | # Citations are weird; never prune when we have one 89 | if re.search(r" et |[0-9]{4}", toks[prevnewline : tight_span[0]]): 90 | tight_span[0] = prevnewline 91 | 92 | while toks[tight_span[1] - 1].isspace(): 93 | tight_span[1] -= 1 94 | if tight_span[0] > prevnewline + 20: 95 | tight_span[0] = prevnewline 96 | if tight_span[1] < nextnewline - 10: 97 | tight_span[1] = nextnewline 98 | 99 | return tight_span 100 | 101 | 102 | def _get_match_for_overlap(toks1, toks2, overlap, all_overlaps, min_line_overlap_ratio, doc_id, review_id, reply_id): 103 | # Check if it takes up most of a line (quotes usually do) 104 | prevnewline = max(0, toks2.rfind("\n", 0, overlap[1][0])) 105 | 106 | nextnewline = toks2.find("\n", overlap[1][1] - 1) 107 | nextnewline = nextnewline if nextnewline >= 0 else len(toks2) 108 | while nextnewline > prevnewline and toks2[nextnewline - 1] == "\n": 109 | nextnewline -= 1 110 | 111 | if nextnewline == prevnewline: 112 | print(nextnewline, prevnewline, overlap, len(toks2)) 113 | line_overlaps = _get_combined_text2_overlap_spans_overlapping_span(all_overlaps, (prevnewline, nextnewline)) 114 | total_line_overlap = sum(max(x[1], prevnewline) - min(x[0], nextnewline) for x in line_overlaps) 115 | lineratio = total_line_overlap / (nextnewline - prevnewline) 116 | 117 | if lineratio < min_line_overlap_ratio: 118 | return None 119 | 120 | tight_span = get_tight_span(toks2, line_overlaps, prevnewline, nextnewline) 121 | # if abs(tight_span[0] - prevnewline) > 2 or abs(tight_span[0] - nextnewline) > 2: 122 | # print('difference! 
oldline={},\nnewline={}'.format(toks2[prevnewline:nextnewline], toks2[tight_span[0]:tight_span[1]],)) 123 | 124 | nextnextnewline = nextnewline 125 | while nextnextnewline < len(toks2) and toks2[nextnextnewline] == "\n": 126 | nextnextnewline += 1 127 | nnlend = nextnextnewline 128 | while nextnextnewline < len(toks2) and toks2[nextnextnewline] != "\n": 129 | nextnextnewline += 1 130 | # print(toks1[overlap[0][0]:overlap[0][1]]) 131 | # all_matches.append((toks1[overlap[0][0]:overlap[0][1]], docrec['docid'], revrec['id'], reply['id'], overlap)) 132 | # all_matches.append((toks2[prevnewline:nextnewline], docrec["docid"], revrec["id"], reply["id"], overlap, (prevnewline, nextnewline), toks2)) 133 | return TextReplyMatch( 134 | doc_id, 135 | review_id, 136 | reply_id, 137 | # None, 138 | # None, 139 | # None, 140 | # docrec["docid"], 141 | # revrec["id"], 142 | # reply["id"], 143 | # toks2[prevnewline:nextnewline], 144 | toks2[tight_span[0] : tight_span[1]], 145 | overlap, 146 | # (prevnewline, nextnewline), 147 | tuple(tight_span), 148 | (nnlend, nextnextnewline), 149 | toks2, 150 | ) 151 | 152 | 153 | def get_author_comment_replies_for_doc(forum_id, review_replies, min_length=80, min_line_overlap_ratio=0.9): 154 | all_matches = [] 155 | for review_rec in review_replies: 156 | replies = review_rec["author_replies"] 157 | used_spans = set() 158 | for reply in replies: 159 | toks1 = "\n".join([str(x) for x in review_rec["content"].values()]) 160 | toks2 = reply["content"]["comment"] 161 | 162 | overlaps = find_overlapping_substrings(toks1, toks2, min_length=min_length) 163 | overlaps.sort(key=lambda x: x[1]) 164 | 165 | for overlap in overlaps: 166 | m = _get_match_for_overlap( 167 | toks1, toks2, overlap, overlaps, min_line_overlap_ratio, review_rec["forum"], review_rec["id"], reply["id"] 168 | ) 169 | if m is not None: 170 | sp = (m.doc_id, m.review_id, m.reply_id, m.line_span) 171 | if sp not in used_spans: 172 | all_matches.append(m) 173 | used_spans.add(sp) 174 | else: 175 | logger.debug("Skipping duplicate match: %s", sp) 176 | return all_matches 177 | 178 | 179 | def make_bow(txt): 180 | return collections.Counter(ngrams([x for x in txt.split() if x.lower() not in STOPWORDS], 1)) 181 | 182 | 183 | def _similarity_many_many_minl(txts1, txts2, match_denom=False): 184 | ngs1 = [make_bow(txt) for txt in txts1] 185 | ngs2 = [make_bow(txt) for txt in txts2] 186 | sim_mat = np.zeros((len(txts1), len(txts2))) 187 | 188 | if match_denom: 189 | denom = max(sum(x.values()) for x in itertools.chain(ngs1, ngs2)) 190 | 191 | def sim_fn(counter1, counter2): 192 | return sum((counter1 & counter2).values()) / denom 193 | 194 | else: 195 | 196 | def sim_fn(counter1, counter2): 197 | return sum((counter1 & counter2).values()) / max(40, min(sum(counter1.values()), sum(counter2.values()))) 198 | 199 | for idx1, idx2 in itertools.product(range(len(txts1)), range(len(txts2))): 200 | ng1 = ngs1[idx1] 201 | ng2 = ngs2[idx2] 202 | if len(ng1) == 0 and len(ng2) == 0: 203 | sim_mat[idx1, idx2] = sim_fn(collections.Counter(txts1[idx1]), collections.Counter(txts2[idx2])) 204 | sim_mat[idx1, idx2] = sim_fn(ng1, ng2) 205 | return sim_mat 206 | 207 | 208 | def _get_high_similarity_comment_edit_texts(comments, edits, sim_threshold): 209 | output_matches = [] 210 | 211 | t2s = [] 212 | for edit_idx, edit in enumerate(edits): 213 | try: 214 | output_text = " ".join(edit.get_added_tokens()) 215 | except RecursionError: 216 | logger.error("Recursion error for edit %s", edit_idx) 217 | output_text = "" 218 | 
t2s.append(output_text) 219 | sim_mat = _similarity_many_many_minl(comments, t2s, match_denom=False) 220 | for cidx, rec in enumerate(comments): 221 | # If there are multiple matches, take only the best; others are sometimes spurious 222 | best = None 223 | for eidx in range(len(t2s)): 224 | if sim_mat[cidx, eidx] <= sim_threshold: 225 | continue 226 | 227 | edit = edits[eidx] 228 | # We allow a little wiggle room for off-by-1's (could come from bad splits/parses), 229 | # but it's unlikely to be correct if there were many non-consecutive matches 230 | if (sorted(edit.target_idxs)[-1] - sorted(edit.target_idxs)[0]) >= (len(edit.target_idxs) * 2 - 1): 231 | continue 232 | 233 | if edits[eidx].is_full_deletion() or edits[eidx].is_identical() or len(edits[eidx].get_added_tokens()) < 5: 234 | continue 235 | 236 | if best is None or best[2] < sim_mat[cidx, eidx]: 237 | best = (cidx, eidx, sim_mat[cidx, eidx]) 238 | 239 | if best is not None: 240 | output_matches.append(best) 241 | return output_matches 242 | 243 | 244 | def get_high_precision_reply_based_alignments_for_doc( 245 | pdf_pair_record, 246 | review_replies, 247 | sim_threshold, 248 | min_reply_overlap_chars, 249 | min_line_overlap_ratio, 250 | s2orc_fetcher, 251 | ): 252 | doc_id = pdf_pair_record["doc_id"] 253 | s2orc1 = aries.util.s2orc.load_s2orc(pdf_pair_record["source_pdf_id"], s2orc_fetcher) 254 | s2orc2 = aries.util.s2orc.load_s2orc(pdf_pair_record["target_pdf_id"], s2orc_fetcher) 255 | 256 | if s2orc1 is None or s2orc2 is None or s2orc1["paper_id"] == s2orc2["paper_id"]: 257 | return [] 258 | 259 | forum_replies = get_author_comment_replies_for_doc( 260 | doc_id, 261 | review_replies, 262 | min_length=min_reply_overlap_chars, 263 | min_line_overlap_ratio=min_line_overlap_ratio, 264 | ) 265 | 266 | review_recs = [ 267 | { 268 | "review_id": x.review_id, 269 | "review_comment": x.match_text, 270 | "reply": x.reply_text[x.next_line_span[0] : x.next_line_span[1]], 271 | "full_match": x, 272 | } 273 | for x in forum_replies 274 | ] 275 | 276 | # If we don't even have enough tokens to form a sentence, it's probably invalid 277 | review_recs = [x for x in review_recs if len(x["review_comment"].split()) >= 4] 278 | 279 | review_recs = deduplicate_by(review_recs, "review_comment") 280 | 281 | aligns = make_full_aligns(s2orc1, s2orc2) 282 | 283 | aug_review_comments = [x["reply"] for x in review_recs] 284 | 285 | output_matches = [] 286 | 287 | for cidx, eidx, sim in _get_high_similarity_comment_edit_texts(aug_review_comments, aligns.paragraph_edits, sim_threshold): 288 | edit = aligns.paragraph_edits[eidx] 289 | 290 | output_matches.append( 291 | { 292 | "doc_id": doc_id, 293 | "source_pdf_id": pdf_pair_record["source_pdf_id"], 294 | "target_pdf_id": pdf_pair_record["target_pdf_id"], 295 | "review_id": review_recs[cidx]["review_id"], 296 | "edit": edit, 297 | "review_comment": review_recs[cidx]["review_comment"], 298 | "reply": review_recs[cidx]["reply"], 299 | "forum_reply": review_recs[cidx]["full_match"], 300 | "similarity": sim, 301 | } 302 | ) 303 | 304 | paper_edit_record = aligns.to_json() 305 | paper_edit_record["doc_id"] = doc_id 306 | 307 | review_comment_records = [ 308 | { 309 | "comment_id": cidx, 310 | "doc_id": doc_id, 311 | "annotation": "synthetic", 312 | "comment": x["review_comment"], 313 | "comment_context": "", 314 | "review_id": x["review_id"], 315 | } 316 | for cidx, x in enumerate(output_matches) 317 | ] 318 | 319 | edit_label_records = [ 320 | {"doc_id": doc_id, "comment_id": cidx, "positive_edits": 
[x["edit"].edit_id], "negative_edits": [], "annotation": "synthetic"} 321 | for cidx, x in enumerate(output_matches) 322 | ] 323 | return paper_edit_record, review_comment_records, edit_label_records, output_matches 324 | 325 | 326 | def augment_config(config): 327 | config_defaults = { 328 | "similarity_threshold": 0.26, 329 | "min_reply_overlap_chars": 40, 330 | "min_line_overlap_ratio": 0.9, 331 | "seed": 42, 332 | } 333 | for k, v in config_defaults.items(): 334 | config[k] = config.get(k, v) 335 | 336 | NEEDED_KEYS = ["s2orc_base_path", "output_dir", "split_ids_file", "split_name", "review_replies_file"] 337 | missing_keys = [x for x in NEEDED_KEYS if x not in config] 338 | if len(missing_keys) > 0: 339 | raise ValueError("Missing config keys: %s" % missing_keys) 340 | 341 | 342 | def main(): 343 | with open(sys.argv[1]) as f: 344 | config = json.load(f) 345 | 346 | augment_config(config) 347 | 348 | init_logging( 349 | logfile=os.path.join(config["output_dir"], "logging_output.log"), 350 | level=logging.INFO, 351 | ) 352 | 353 | with open(config["split_ids_file"]) as f: 354 | pdf_pair_ids = json.load(f)[config["split_name"]] 355 | 356 | review_replies = list(iter_jsonl_files([config["review_replies_file"]])) 357 | review_replies_by_docid = index_by(review_replies, "forum") 358 | 359 | paper_edit_records = [] 360 | review_comment_records = [] 361 | edit_label_records = [] 362 | full_match_records = [] 363 | 364 | pbar = tqdm.tqdm(pdf_pair_ids) 365 | with aries.util.s2orc.S2orcFetcherSqlite( 366 | config.get("s2orc_db_path", ":memory:"), 367 | fallback_fetcher=aries.util.s2orc.S2orcFetcherFilesystem(config["s2orc_base_path"]), 368 | update_db=False, 369 | ) as fetcher: 370 | for pdf_pair in pbar: 371 | if pdf_pair["doc_id"] not in review_replies_by_docid: 372 | continue 373 | 374 | per, rcr, elr, fmr = get_high_precision_reply_based_alignments_for_doc( 375 | pdf_pair, 376 | review_replies_by_docid[pdf_pair["doc_id"]], 377 | sim_threshold=config["similarity_threshold"], 378 | min_reply_overlap_chars=config["min_reply_overlap_chars"], 379 | min_line_overlap_ratio=config["min_line_overlap_ratio"], 380 | s2orc_fetcher=fetcher, 381 | ) 382 | 383 | paper_edit_records.append(per) 384 | review_comment_records.extend(rcr) 385 | 386 | for elr_i in elr: 387 | elr_i["split"] = config["split_name"] 388 | edit_label_records.extend(elr) 389 | 390 | full_match_records.extend(fmr) 391 | 392 | pbar.set_description("n_results={}".format(len(edit_label_records)), refresh=False) 393 | 394 | with open(os.path.join(config["output_dir"], "paper_edits.jsonl"), "w") as f: 395 | for rec in paper_edit_records: 396 | f.write(json.dumps(rec) + "\n") 397 | 398 | with open(os.path.join(config["output_dir"], "review_comments.jsonl"), "w") as f: 399 | for rec in review_comment_records: 400 | f.write(json.dumps(rec) + "\n") 401 | 402 | with open(os.path.join(config["output_dir"], "edit_labels.jsonl"), "w") as f: 403 | for rec in edit_label_records: 404 | f.write(json.dumps(rec) + "\n") 405 | 406 | 407 | if __name__ == "__main__": 408 | main() 409 | -------------------------------------------------------------------------------- /scripts/train_revision_alignment.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import itertools 3 | import json 4 | import logging 5 | import os 6 | import sys 7 | 8 | import numpy as np 9 | import sacrebleu 10 | import torch 11 | import tqdm 12 | import transformers 13 | 14 | import aries.util.data 15 | import aries.util.s2orc 16 
| from aries.alignment.biencoder import BiencoderTransformerAligner 17 | from aries.alignment.bm25 import BM25Aligner 18 | from aries.alignment.cross_encoder import PairwiseTransformerAligner 19 | from aries.alignment.doc_edits import DocEdits 20 | from aries.alignment.eval import do_model_eval 21 | from aries.alignment.gpt import GptChatAligner, GptChatFullPaperAligner 22 | from aries.alignment.other import MultiStageAligner 23 | from aries.alignment.precomputed import PrecomputedEditsAligner 24 | from aries.util.data import index_by, iter_jsonl_files 25 | from aries.util.logging import init_logging, pprint_metrics 26 | 27 | logger = logging.getLogger(__name__) 28 | 29 | 30 | try: 31 | # Needed for SLED models 32 | import sled 33 | except ImportError: 34 | sled = None 35 | 36 | 37 | def _load_transformer(config, cls): 38 | transformer_model = cls.from_pretrained(config["model_name_or_path"]) 39 | if torch.cuda.device_count() > 0: 40 | transformer_model = transformer_model.to(torch.device("cuda")) 41 | if config.get("model_adapter", None) is not None: 42 | logger.info("initializing adapter: {}".format(config["model_adapter"])) 43 | transformer_model.load_adapter(config["model_adapter"], source="hf", load_as="adapter", set_active=True) 44 | logger.info(transformer_model.adapter_summary()) 45 | tokenizer = transformers.AutoTokenizer.from_pretrained(config["model_name_or_path"]) 46 | 47 | return transformer_model, tokenizer 48 | 49 | 50 | def init_model_from_config(config): 51 | model = None 52 | if config["model_type"] == "cross_encoder": 53 | transformer_model, tokenizer = _load_transformer(config, transformers.AutoModelForSequenceClassification) 54 | model = PairwiseTransformerAligner(config, transformer_model, tokenizer) 55 | elif config["model_type"] == "biencoder": 56 | transformer_model, tokenizer = _load_transformer(config, transformers.AutoModel) 57 | model = BiencoderTransformerAligner(config, transformer_model, tokenizer) 58 | elif config["model_type"] == "gpt": 59 | model = GptChatAligner(config) 60 | elif config["model_type"] == "gpt_full_paper": 61 | model = GptChatFullPaperAligner(config) 62 | elif config["model_type"] == "bm25": 63 | model = BM25Aligner(config) 64 | elif config["model_type"] == "precomputed": 65 | model = PrecomputedEditsAligner(config) 66 | else: 67 | raise ValueError("Unknown model type: {}".format(config["model_type"])) 68 | 69 | return model 70 | 71 | 72 | def train_eval_main(config, split_data): 73 | models = [] 74 | # Initial models are treated as pre-processing filters and do not train 75 | for model_conf in config["model_pipeline"][:-1]: 76 | model_conf["output_dir"] = config["output_dir"] 77 | model_conf["seed"] = config["seed"] 78 | model = init_model_from_config(model_conf) 79 | logger.info("model_cls: {}".format(str(model.__class__.__name__))) 80 | models.append(model) 81 | 82 | model_conf = config["model_pipeline"][-1] 83 | model_conf["output_dir"] = config["output_dir"] 84 | model_conf["seed"] = config["seed"] 85 | model = init_model_from_config(model_conf) 86 | logger.info("main model_cls: {}".format(str(model.__class__.__name__))) 87 | model.train(split_data["train"], split_data["dev"]) 88 | models.append(model) 89 | 90 | model = MultiStageAligner(config, models) 91 | 92 | eval_splits = ["dev", "test"] 93 | if config.get("do_final_evaluation_on_train", False): 94 | eval_splits.append("train") 95 | dev_threshold = None 96 | for name in eval_splits: 97 | recs = split_data[name] 98 | 99 | metrics, all_results, by_review = do_model_eval(model, 
recs, custom_decision_threshold=dev_threshold, custom_threshold_name="devthresh") 100 | if name == "dev": 101 | dev_threshold = metrics.get("optimal_decision_threshold", None) 102 | 103 | logger.info("Done. Writing output...") 104 | with open(os.path.join(config["output_dir"], "{}_inferences.jsonl".format(name)), "w") as f: 105 | for res in all_results: 106 | f.write(json.dumps(res) + "\n") 107 | 108 | logger.info("Final {} metrics:".format(name)) 109 | pprint_metrics(metrics, logger, name=name) 110 | 111 | with open(os.path.join(config["output_dir"], "{}_metrics.json".format(name)), "w") as f: 112 | if "bleu" in metrics and isinstance(metrics["bleu"], sacrebleu.metrics.bleu.BLEUScore): 113 | metrics["bleu"] = metrics["bleu"].score 114 | json.dump(metrics, f) 115 | 116 | with open(os.path.join(config["output_dir"], "{}_inferences_by_review.jsonl".format(name)), "w") as f: 117 | for rec in by_review.values(): 118 | f.write(json.dumps(rec) + "\n") 119 | 120 | 121 | def make_revision_alignment_prediction_data( 122 | config, 123 | review_comments_by_doc, 124 | paper_edits_by_doc, 125 | edit_labels_file, 126 | max_negatives, 127 | negative_sample_method="same_doc", 128 | hard_negative_ratio=0.0, 129 | seed=None, 130 | ): 131 | if seed is None: 132 | seed = config["seed"] 133 | 134 | edit_labels = list(iter_jsonl_files([edit_labels_file])) 135 | edit_labels_by_doc = index_by(edit_labels, "doc_id") 136 | 137 | all_split_edits = list( 138 | itertools.chain(*[[(doc_id, x) for x in paper_edits_by_doc[doc_id].paragraph_edits] for doc_id in edit_labels_by_doc.keys()]) 139 | ) 140 | 141 | examples = [] 142 | 143 | rng = np.random.default_rng(seed) 144 | for doc_id in edit_labels_by_doc.keys(): 145 | distractor_idxs = rng.choice( 146 | len(all_split_edits), 147 | size=min(len(all_split_edits), config["distractor_reservoir_size"]), 148 | replace=False, 149 | ) 150 | distractor_pool = [all_split_edits[i][1] for i in distractor_idxs if all_split_edits[i][0] != doc_id] 151 | doc_examples = get_alignments_for_revision( 152 | config, 153 | rng, 154 | doc_id, 155 | review_comments_by_doc[doc_id], 156 | paper_edits_by_doc[doc_id], 157 | edit_labels=edit_labels_by_doc[doc_id], 158 | max_negatives=max_negatives, 159 | distractor_pool=distractor_pool, 160 | negative_sample_method=negative_sample_method, 161 | hard_negative_ratio=hard_negative_ratio, 162 | ) 163 | examples.extend(doc_examples) 164 | 165 | return examples 166 | 167 | 168 | def _filter_edits_by_type(edit_list, keep_type, min_length=0): 169 | newlist = None 170 | if keep_type == "full_additions": 171 | newlist = [edit for edit in edit_list if edit.is_full_addition()] 172 | elif keep_type == "diffs": 173 | newlist = [edit for edit in edit_list if not edit.is_identical()] 174 | elif keep_type == "source_diffs": 175 | newlist = [edit for edit in edit_list if (len(edit.source_idxs) != 0 and not edit.is_identical())] 176 | else: 177 | raise ValueError("Invalid candidate edit type {}".format(keep_type)) 178 | 179 | if min_length > 0: 180 | newlist = [edit for edit in newlist if len(edit.get_source_text() + edit.get_target_text()) >= min_length] 181 | 182 | return newlist 183 | 184 | 185 | def get_alignments_for_revision( 186 | config, 187 | rng, 188 | doc_id, 189 | review_comments, 190 | edits, 191 | edit_labels, 192 | max_negatives=999999, 193 | distractor_pool=None, 194 | negative_sample_method="same_doc", 195 | hard_negative_ratio=0.0, 196 | ): 197 | review_comments_by_id = index_by(review_comments, "comment_id", one_to_one=True) 198 | 199 | examples = 
[]
200 | 
201 |     for record in edit_labels:
202 |         positives = [edits.by_id(x) for x in record["positive_edits"]]
203 |         positives = _filter_edits_by_type(positives, config["candidate_edit_type"], min_length=config["candidate_min_chars"])
204 |         pos_ids = set([x.edit_id for x in positives])
205 | 
206 |         if negative_sample_method == "same_doc":
207 |             # Assume all non-positive paras from the same doc are negatives (appropriate when positives are high-recall)
208 |             negatives = [x for idx, x in enumerate(edits.paragraph_edits) if x.edit_id not in pos_ids]
209 |         elif negative_sample_method == "other_docs":
210 |             # Only sample negatives from edits to other docs (appropriate when positives are low-recall)
211 |             if distractor_pool is None:
212 |                 raise ValueError("Need distractor edits from other docs to use other_docs negative sampling")
213 |             negatives = distractor_pool.copy()
214 | 
215 |         negatives = _filter_edits_by_type(negatives, config["candidate_edit_type"], min_length=config["candidate_min_chars"])
216 | 
217 |         rng.shuffle(negatives)
218 | 
219 |         if len(negatives) <= max_negatives:
220 |             final_negatives = negatives
221 |         elif config["hard_negative_strategy"] == "none" or hard_negative_ratio == 0:
222 |             final_negatives = negatives
223 |             if hard_negative_ratio != 0:
224 |                 logger.warning(
225 |                     "hard_negative_ratio was {} but hard_negative_strategy is {}; no hard negatives will be used".format(
226 |                         hard_negative_ratio, config["hard_negative_strategy"]
227 |                     )
228 |                 )
229 |         else:
230 |             hard_negatives = _get_hard_negatives(negatives, positives, strategy=config["hard_negative_strategy"])[:max_negatives]
231 |             n_hard = min(len(hard_negatives), int(max_negatives * hard_negative_ratio))
232 |             n_easy = max_negatives - n_hard
233 | 
234 |             # note: Could potentially duplicate an example between easy and
235 |             # hard negatives since hard are just sorted; maybe try to dedup
236 |             final_negatives = negatives[:n_easy] + hard_negatives[:n_hard]
237 | 
238 |         final_negatives = final_negatives[:max_negatives]
239 | 
240 |         comment = review_comments_by_id[record["comment_id"]]
241 | 
242 |         example = {
243 |             "source_pdf_id": edits.s2orc1["paper_id"],
244 |             "target_pdf_id": edits.s2orc2["paper_id"],
245 |             "doc_id": doc_id,
246 |             "comment_id": comment["comment_id"],
247 |             "review_comment": comment["comment"],
248 |             "context": comment["comment_context"],
249 |             "context_side": comment.get("context_side", None),
250 |             "positives": positives,
251 |             "negatives": final_negatives,
252 |         }
253 |         if example["context_side"] is None:
254 |             if example["context"] != "" and example["context"].strip().startswith("["):
255 |                 example["context_side"] = "right"
256 |             else:
257 |                 example["context_side"] = "left"
258 | 
259 |         examples.append(example)
260 | 
261 |     return examples
262 | 
263 | 
264 | def _get_hard_negatives(negatives, positives, strategy="none"):
265 |     """Returns the negatives sorted by hardness, and possibly also filtered by hardness"""
266 |     if len(positives) == 0 or strategy == "none":
267 |         return []
268 |     elif strategy == "length":
269 |         pos_lengths = [len(x.get_target_text()) for x in positives]
270 |         return sorted(negatives, key=lambda x: min(abs(len(x.get_target_text()) - pl) for pl in pos_lengths))
271 |     elif strategy == "aggregate_unigram_overlap":
272 |         all_pos_tokens = collections.Counter(itertools.chain(*[x.get_target_text().lower().split() for x in positives]))
273 |         return sorted(
274 |             negatives, key=lambda x: -aries.util.data.counter_jaccard(all_pos_tokens, 
collections.Counter(x.get_target_text().lower().split())) 275 | ) 276 | 277 | raise ValueError("Unknown strategy {}".format(strategy)) 278 | 279 | 280 | def init_data(config): 281 | review_comments = list(iter_jsonl_files([config["review_comments_file"]])) 282 | review_comments_by_doc = index_by(review_comments, "doc_id") 283 | 284 | paper_edits = iter_jsonl_files([config["paper_edits_file"]]) 285 | 286 | paper_edits_by_doc = index_by(paper_edits, "doc_id", one_to_one=True) 287 | 288 | for doc_id, s2orc1, s2orc2 in aries.util.s2orc.iter_s2orc_pairs(config["s2orc_base_path"], [x[1] for x in sorted(paper_edits_by_doc.items())]): 289 | paper_edits_by_doc[doc_id] = DocEdits.from_list(s2orc1, s2orc2, paper_edits_by_doc[doc_id]["edits"]) 290 | 291 | all_data = dict() 292 | all_data["dev"] = make_revision_alignment_prediction_data( 293 | config, 294 | review_comments_by_doc, 295 | paper_edits_by_doc, 296 | config["dev_edit_labels_file"], 297 | max_negatives=config.get("dev_max_negatives", config["max_negatives"]), 298 | seed=config.get("dev_seed", config["seed"]), 299 | negative_sample_method=config.get("dev_negative_sample_method", config["default_negative_sample_method"]), 300 | ) 301 | logger.info("dev data size: {}".format(len(all_data["dev"]))) 302 | 303 | all_data["test"] = make_revision_alignment_prediction_data( 304 | config, 305 | review_comments_by_doc, 306 | paper_edits_by_doc, 307 | config["test_edit_labels_file"], 308 | max_negatives=9999, 309 | seed=config["seed"], 310 | negative_sample_method=config.get("test_negative_sample_method", config["default_negative_sample_method"]), 311 | ) 312 | logger.info("test data size: {}".format(len(all_data["test"]))) 313 | 314 | all_data["train"] = make_revision_alignment_prediction_data( 315 | config, 316 | review_comments_by_doc, 317 | paper_edits_by_doc, 318 | config["train_edit_labels_file"], 319 | max_negatives=config["max_negatives"], 320 | seed=config["seed"], 321 | negative_sample_method=config.get("train_negative_sample_method", config["default_negative_sample_method"]), 322 | ) 323 | logger.info("train data size: {}".format(len(all_data["train"]))) 324 | 325 | return all_data 326 | 327 | 328 | def augment_config(config): 329 | config_defaults = { 330 | "max_negatives": 999999, 331 | "candidate_edit_type": "diffs", 332 | "candidate_min_chars": 100, 333 | "prune_candidates": False, 334 | "write_examples_on_eval": True, 335 | "distractor_reservoir_size": 1000, 336 | "default_negative_sample_method": "same_doc", 337 | "train_hard_negative_ratio": 0.0, 338 | "hard_negative_strategy": ("length" if config.get("train_hard_negative_ratio", 0.0) != 0 else "none"), 339 | } 340 | for k, v in config_defaults.items(): 341 | config[k] = config.get(k, v) 342 | 343 | NEEDED_KEYS = [ 344 | "dev_edit_labels_file", 345 | "test_edit_labels_file", 346 | "train_edit_labels_file", 347 | "model_pipeline", 348 | "output_dir", 349 | "paper_edits_file", 350 | "review_comments_file", 351 | "s2orc_base_path", 352 | "seed", 353 | ] 354 | missing_keys = [x for x in NEEDED_KEYS if x not in config] 355 | if len(missing_keys) > 0: 356 | raise ValueError("Missing config keys: %s" % missing_keys) 357 | 358 | 359 | def main(): 360 | with open(sys.argv[1]) as f: 361 | config = json.load(f) 362 | 363 | augment_config(config) 364 | 365 | os.makedirs(config["output_dir"], exist_ok=True) 366 | 367 | init_logging( 368 | logfile=os.path.join(config["output_dir"], "logging_output.log"), 369 | level=logging.INFO, 370 | ) 371 | 372 | transformers.set_seed(config["seed"]) 373 | 374 | 
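    # Build the train/dev/test alignment examples, then train the last model in the pipeline and evaluate on dev and test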
all_data = init_data(config) 375 | train_eval_main(config, all_data) 376 | 377 | 378 | if __name__ == "__main__": 379 | main() 380 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | 3 | with open("README.md", "r") as fh: 4 | long_description = fh.read() 5 | 6 | setuptools.setup( 7 | name="aries", 8 | version="0.1.0", 9 | author="Mike D'Arcy", 10 | author_email="miked@collaborator.allenai.org", 11 | description="Code for the ARIES project", 12 | long_description=long_description, 13 | long_description_content_type="text/markdown", 14 | license="Apache 2.0", 15 | url="https://github.com/allenai/aries", 16 | packages=setuptools.find_packages(), 17 | classifiers=[ 18 | "Topic :: Scientific/Engineering", 19 | "Development Status :: 2 - Pre-Alpha", 20 | "Programming Language :: Python :: 3", 21 | "Operating System :: POSIX :: Linux", 22 | ], 23 | ) 24 | 25 | -------------------------------------------------------------------------------- /tests/test_edit.py: -------------------------------------------------------------------------------- 1 | from aries.util.edit import levenshtein_distance, basic_token_align, find_overlapping_substrings 2 | 3 | def test_basic_token_align(): 4 | seq1 = ['this', 'is', 'my', 'sentence'] 5 | seq2 = ['this', 'is', 'my', 'sentence'] 6 | d, align = basic_token_align(seq1, seq2) 7 | assert d == 0 8 | assert align == [0, 1, 2, 3] 9 | 10 | 11 | seq2 = ['t', 'h', 'i', 's', 'i', 's', 'm', 'y', 's', 'e', 'n', 't', 'e', 'n', 'c', 'e'] 12 | d, align = basic_token_align(seq1, seq2) 13 | assert d == 0 14 | assert align == [0]*4 + [1]*2 + [2]*2 + [3]*8 15 | 16 | seq2 = ['thisi', 's', 'mys', 'entence'] 17 | d, align = basic_token_align(seq1, seq2) 18 | assert d == 2 19 | assert align == [0, 1, 2, 3] 20 | 21 | seq2 = ['this', '_is', '_my', '_sentence'] 22 | d, align = basic_token_align(seq1, seq2) 23 | assert d == 3 24 | assert align == [0, 1, 2, 3] 25 | 26 | seq2 = ['this', 'is', 'my'] 27 | try: 28 | d, align = basic_token_align(seq1, seq2) 29 | assert False, "Expected error since characters didn't match" 30 | except ValueError: 31 | pass 32 | 33 | seq2 = ['[this]', 'this', 'is', '[smy]', 'my', 'sentence', '[e]'] 34 | d, align = basic_token_align(seq1, seq2, seq2_ignored_ids=[0,3,6]) 35 | assert d == 0 36 | assert align == [None, 0, 1, None, 2, 3, None] 37 | 38 | def test_levenshtein(): 39 | assert levenshtein_distance('', '') == 0 40 | assert levenshtein_distance('', 'text') == 4 41 | assert levenshtein_distance('text', '') == 4 42 | assert levenshtein_distance('text', 'text') == 0 43 | assert levenshtein_distance('text', 'textb') == 1 44 | assert levenshtein_distance('textb', 'text') == 1 45 | assert levenshtein_distance('texta', 'textb') == 1 46 | assert levenshtein_distance('abba', 'acca') == 2 47 | 48 | def test_find_overlapping_substrings(): 49 | assert find_overlapping_substrings('', '', min_length=1) == [] 50 | assert find_overlapping_substrings('', 'text', min_length=1) == [] 51 | assert find_overlapping_substrings('text', '', min_length=1) == [] 52 | assert find_overlapping_substrings('text', 'text', min_length=1) == [((0, 4), (0, 4))] 53 | assert find_overlapping_substrings('text', 'text', min_length=4) == [((0, 4), (0, 4))] 54 | assert find_overlapping_substrings('text', 'text', min_length=5) == [] 55 | 56 | assert find_overlapping_substrings('atext', 'text', min_length=2) == [((1, 5), (0, 4))] 57 | assert 
find_overlapping_substrings('texta', 'text', min_length=2) == [((0, 4), (0, 4))] 58 | assert find_overlapping_substrings('text', 'atext', min_length=2) == [((0, 4), (1, 5))] 59 | assert find_overlapping_substrings('text', 'texta', min_length=2) == [((0, 4), (0, 4))] 60 | assert find_overlapping_substrings('btext', 'atext', min_length=2) == [((1, 5), (1, 5))] 61 | 62 | assert sorted(find_overlapping_substrings('the man and the cat', 'the cat and the man', min_length=4)) == [((0, 4), (0, 4)), ((0, 7), (12, 19)), ((7, 16), (7, 16)), ((12, 19), (0, 7))] 63 | 64 | --------------------------------------------------------------------------------
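To illustrate how the pieces above fit together, here is a small, self-contained sketch of the quote-matching idea behind scripts/generate_synthetic_data.py: a review comment quoted verbatim in an author reply is found as a long exact character overlap, and the reply text is then compared to an edit's added text with a stopword-filtered unigram-overlap score. Only find_overlapping_substrings comes from the package; the helper names (bow, unigram_overlap_sim), the example strings, and the thresholds are illustrative assumptions rather than repository code.

import collections

from nltk.corpus import stopwords

from aries.util.edit import find_overlapping_substrings

# Assumes the NLTK "stopwords" corpus has been downloaded (the Dockerfile does this)
STOPWORDS = set(stopwords.words("english"))


def bow(text):
    # Stopword-filtered unigram bag, analogous to make_bow() in generate_synthetic_data.py
    return collections.Counter(w for w in text.split() if w.lower() not in STOPWORDS)


def unigram_overlap_sim(text1, text2):
    # Overlap count normalized by the smaller bag (floored at 40), as in _similarity_many_many_minl()
    c1, c2 = bow(text1), bow(text2)
    return sum((c1 & c2).values()) / max(40, min(sum(c1.values()), sum(c2.values())))


review = "The experiments should compare against the baseline method on the larger benchmark.\nAlso, Section 3 has several typos."
reply = (
    "> The experiments should compare against the baseline method on the larger benchmark.\n"
    "We added this comparison; Section 5 now reports results for the baseline method on the larger benchmark."
)
edit_added_text = "We additionally compare against the baseline method on the larger benchmark and report the results in Table 3."

# Step 1: a review comment quoted verbatim in the author reply shows up as a long exact overlap
print(find_overlapping_substrings(review, reply, min_length=40))

# Step 2: the reply text that follows the quote is compared against each edit's added text;
# a high unigram overlap suggests the edit was made in response to the quoted comment
print(round(unigram_overlap_sim(reply, edit_added_text), 3))

In the real script, the candidate texts are the added tokens of paragraph edits between the two S2ORC parses of a paper, and only matches above the configured similarity_threshold (0.26 by default) are kept as synthetic comment-edit alignment labels.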