├── .gitignore
├── README.md
├── license.txt
├── ms2
│   ├── __init__.py
│   ├── data
│   │   ├── munge.py
│   │   └── review_datasets.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── abstract_classifier.py
│   │   ├── evidence_inference_models.py
│   │   ├── pubmed_tagger.py
│   │   ├── transformer_summarizer.py
│   │   └── utils.py
│   └── utils.py
├── requirements.txt
├── sample.json
└── scripts
    ├── modeling
    │   ├── consistency_scorer.py
    │   ├── decode.py
    │   ├── f1_scorer.py
    │   ├── select_pubmed_types_of_interest.py
    │   ├── select_reviews.py
    │   ├── splits.py
    │   ├── summarizer_input_prep.py
    │   ├── table_to_text_summarizer_input.py
    │   ├── tabular_summarizer_input_prep.py
    │   └── text_to_table_input_prep.py
    └── run_ms2.sh

/.gitignore:
--------------------------------------------------------------------------------
1 | *.swp
2 | *.pyc
3 | .vscode
4 | /data
5 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## MS^2: Multi-Document Summarization of Medical Studies
2 | 
3 | **Note: This dataset is now part of the [MSLR2022 Shared Task](https://github.com/allenai/mslr-shared-task). We encourage you to use the data as modified for the task: available [here](https://github.com/allenai/mslr-shared-task#dataset-access). There is also a leaderboard for this task, available [here](https://leaderboard.allenai.org/mslr-ms2/submissions/public).**
4 | 
5 | ### Description
6 | 
7 | MS^2 is a dataset containing medical systematic reviews, their constituent studies, and a large amount of related markup. This repository contains code for attempting to produce summaries from this data. To find out more about how we created this dataset, please read our [preprint](https://arxiv.org/abs/2104.06486).
8 | 
9 | This dataset was created as an annotated subset of the Semantic Scholar research corpus. MS^2 is licensed under the following license agreement: [Semantic Scholar API and Dataset License Agreement](http://s2-public-api-prod.us-west-2.elasticbeanstalk.com/corpus/legal/)
10 | 
11 | All of the following commands are assumed to run in the same terminal session, so that variables such as `PYTHONPATH` carry over between steps.
12 | 
13 | ### Set Up
14 | 
15 | You might wish to create a conda env:
16 | ```
17 | conda create -n ms2 python=3.8
18 | # or conda create -p ms2 python=3.8
19 | conda activate ms2
20 | ```
21 | 
22 | You will need to install these packages:
23 | ```
24 | conda install pytorch torchvision torchaudio cudatoolkit=10.2 -c pytorch
25 | pip install -r requirements.txt
26 | wget https://ai2-s2-research.s3-us-west-2.amazonaws.com/longformer/longformer-encdec-base-16384.tar.gz
27 | wget https://ai2-s2-research.s3-us-west-2.amazonaws.com/longformer/longformer-encdec-large-16384.tar.gz
28 | ```
29 | 
30 | ### Data & Checkpoints
31 | 
32 | **We encourage you to use the cleaned-up data files provided [here](https://github.com/allenai/mslr-shared-task#dataset-access).**
33 | 
34 | The original data and model files associated with the paper are linked below.
35 | 
36 | | File | Description | sha1 | md5 |
37 | | ----------- | ----------- | ----------- | ----------- |
38 | | [ms_data_2021-04-12.zip](https://ai2-s2-ms2.s3-us-west-2.amazonaws.com/ms_data_2021-04-12.zip) | MS^2 Dataset Files | 6090fbea | 7cf243af |
39 | | [bart_base_ckpt_7.ckpt](https://ai2-s2-ms2.s3-us-west-2.amazonaws.com/bart_base_ckpt_7.ckpt) | BART checkpoint | 9698478c | 4a0d5630 |
40 | | [longformer_base_ckpt_7.ckpt](https://ai2-s2-ms2.s3-us-west-2.amazonaws.com/longformer_base_ckpt_7.ckpt) | Longformer (LED) checkpoint | 327f9f41 | 4558b0d4 |
41 | | [evidence_inference_models.zip](https://ai2-s2-ms2.s3-us-west-2.amazonaws.com/evidence_inference_models.zip) | Evidence inference (EI) models | bc7fecdc | 2bc1bdaf |
42 | | [decoded.zip](https://ai2-s2-ms2.s3-us-west-2.amazonaws.com/decoded.zip) | Decoded model outputs | a9e023e2 | 0725f2a4 |
43 | | [decoded_with_scores.zip](https://ai2-s2-ms2.s3-us-west-2.amazonaws.com/decoded_with_scores.zip) | Decoded model outputs with scores | 38715772 | 5808924e |
44 | 
45 | All files are on AWS S3, so you can also acquire them using the AWS CLI, e.g. `aws s3 cp s3://ai2-s2-ms2/ms_data_2021-04-12.zip $LOCALDIR/ms2_data/`
46 | 
47 | [comment]: <> (```)
48 | [comment]: <> (sha1sum ms_data_2021-04-12.zip)
49 | [comment]: <> (6090fbea367c7c52a4c3a9418591792d8dea90e7 ms_data_2021-04-12.zip)
50 | [comment]: <> (md5sum ms_data_2021-04-12.zip)
51 | [comment]: <> (7cf243af2373ad496d948fc73d7dcf31 ms_data_2021-04-12.zip)
52 | [comment]: <> (```)
53 | 
54 | ### Input Prep
55 | 
56 | The first step is to prepare model inputs for the summarizer. This converts the review structure into tensorized versions of the inputs and outputs, in either text or table form. The primary versions of interest are the text-to-text and table-to-table versions. See [sample.json](sample.json) for an example of the raw inputs.
57 | 
58 | This will need to be repeated for each subset (a loop sketch follows the commands below):
59 | ```
60 | input_subset=...
61 | output_reviews_file=...
62 | MAX_LENGTH="--max_length 500"
63 | # Run from the ms2 repository root, or point PYTHONPATH at the repository's location.
64 | export PYTHONPATH=./
65 | # text-text version
66 | python scripts/modeling/summarizer_input_prep.py --input $input_subset --output $output_reviews_file --tokenizer facebook/bart-base $MAX_LENGTH
67 | # table-table version
68 | python scripts/modeling/tabular_summarizer_input_prep.py --input $input_subset --output $output_reviews_file --tokenizer facebook/bart-base $MAX_LENGTH
69 | # text-table version
70 | python scripts/modeling/text_to_table_input_prep.py --input $input_subset --output $output_reviews_file --tokenizer facebook/bart-base $MAX_LENGTH
71 | # table-text version
72 | python scripts/modeling/table_to_text_summarizer_input.py --input $input_subset --output $output_reviews_file --tokenizer facebook/bart-base $MAX_LENGTH
73 | ```
74 | 
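Each of the prep commands above must be run once per data subset. A minimal sketch of such a loop is below; the subset file names and output paths are placeholders rather than the actual layout of the released data, so adapt them to wherever the unzipped dataset lives locally:
```
# Illustrative only: adjust subset names and paths to the local dataset layout.
for input_subset in training.jsonl validation.jsonl testing.jsonl; do
    output_reviews_file="${input_subset%.jsonl}_text_to_text_inputs"
    python scripts/modeling/summarizer_input_prep.py \
        --input $input_subset \
        --output $output_reviews_file \
        --tokenizer facebook/bart-base \
        --max_length 500
done
```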
75 | ### Modeling
76 | 
77 | All model training uses the same script; run it with `--help` for all options. Training requires at least one RTX8000 (with only one GPU, adjust GRAD_ACCUM accordingly).
78 | ```
79 | training_reviews_file="result of input prep"
80 | validation_reviews_file="result of input prep"
81 | training_root="place to store model artifacts"
82 | EPOCHS=8 # more doesn't seem to do much
83 | GRAD_ACCUM=16 # for 2x RTX8000; otherwise adjust to keep an effective batch size of 32
84 | MODEL_NAME= # options are facebook/bart-base, facebook/bart-large, /path/to/longformer/base, /path/to/longformer/large
85 | python ms2/models/transformer_summarizer.py \
86 |     --train $training_reviews_file \
87 |     --val $validation_reviews_file \
88 |     --training_root $training_root \
89 |     --epochs=$EPOCHS \
90 |     --grad_accum=$GRAD_ACCUM \
91 |     --fp16 \
92 |     --model_name $MODEL_NAME
93 | ```
94 | 
95 | ### Decoding
96 | 
97 | Make predictions via:
98 | ```
99 | INPUT=$validation_reviews_file
100 | OUTPUT="path to write predictions to"
101 | CHECKPOINT="path to the trained model checkpoint"
102 | NUM_BEAMS=6
103 | MODEL_NAME="same as in modeling"
104 | python scripts/modeling/decode.py --input $INPUT --output $OUTPUT --checkpoint $CHECKPOINT --num_beams=$NUM_BEAMS --model_name $MODEL_NAME
105 | ```
106 | 
107 | For tabular targets, add the extra arguments `--min_length 2 --max_length 10`.
108 | 
109 | ### Scoring
110 | 
111 | For tabular scoring:
112 | ```
113 | f="$OUTPUT from above"
114 | python scripts/modeling/f1_scorer.py --input $f --output $f.scores
115 | ```
116 | 
117 | For textual scoring (requires a GPU):
118 | ```
119 | f="$OUTPUT from above"
120 | evidence_inference_dir=...
121 | evidence_inference_classifier_params=...
122 | python scripts/modeling/consistency_scorer.py --model_outputs $f --output $f.scores --evidence_inference_dir $evidence_inference_dir --evidence_inference_classifier_params $evidence_inference_classifier_params
123 | ```
124 | 
125 | ### Evidence Inference
126 | 
127 | This section uses a modified version of the evidence inference dataset that discards the comparator. Clone evidence inference from the [ms2 tag](https://github.com/jayded/evidence-inference/releases/tag/ms2). After installing its requirements.txt, the models may be trained via:
128 | ```
129 | python evidence_inference/models/pipeline.py --params params/sampling_abstracts/bert_pipeline_8samples.json --output_dir $evidence_inference_dir
130 | ```
131 | 
132 | ### Citation
133 | 
134 | If using this dataset, please cite:
135 | 
136 | ```
137 | @inproceedings{deyoung-etal-2021-ms,
138 |     title = "{MS}{\^{}}2: Multi-Document Summarization of Medical Studies",
139 |     author = "DeYoung, Jay  and
140 |       Beltagy, Iz  and
141 |       van Zuylen, Madeleine  and
142 |       Kuehl, Bailey  and
143 |       Wang, Lucy Lu",
144 |     booktitle = "Proceedings of the 2021 Conference on Empirical Methods in Natural Language Processing",
145 |     month = nov,
146 |     year = "2021",
147 |     address = "Online and Punta Cana, Dominican Republic",
148 |     publisher = "Association for Computational Linguistics",
149 |     url = "https://aclanthology.org/2021.emnlp-main.594",
150 |     pages = "7494--7513"
151 | }
152 | ```
153 | 
--------------------------------------------------------------------------------
/license.txt:
--------------------------------------------------------------------------------
1 | 
2 | Apache License
3 | Version 2.0, January 2004
4 | http://www.apache.org/licenses/
5 | 
6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 | 
8 | 1. Definitions.
9 | 
10 | "License" shall mean the terms and conditions for use, reproduction,
11 | and distribution as defined by Sections 1 through 9 of this document.
12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. 
Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 
135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 
194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /ms2/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/ms2/a03ab009e00c5e412b4c55f6ec4f9b49c2d8a7f6/ms2/__init__.py -------------------------------------------------------------------------------- /ms2/data/munge.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | FIELDS = r"^(ABSTRACT|AIMS|AIM|ANALYSIS|AND|APPRAISAL|AREAS|ASSESSMENT|AUTHORS ' CONCLUSIONS|AUTHORS'|AUTHORS|" + \ 4 | "BACKGROUND|BENEFITS|BMC|BMI|" + \ 5 | "CAUTION|CENTRAL|CINAHL|CINHAL|CLINICAL|COLLECTION|Conclusions and Relevance|CONCLUSIONS|CONCLUSION|CONTEXT|COPD|CRITERIA|" + \ 6 | "DATA|DATE|DESIGN|DISCUSSION|DURATION|" + \ 7 | "ELIGIBILITY|EMBASE|EVIDENCE|EXTRACTION|" + \ 8 | "FDA|FINDINGS|FUNDING|" + \ 9 | "GOALS|GOAL|GRADE|GUIDELINES|GUIDELINE|" + \ 10 | "IDENTIFICATION|IMPLICATION|IMPLICATIONS|IMPORTANCE|INCLUSION|INTERPRETATION|INTERVENTIONS|INTRODUCTION|" + \ 11 | "LANGUAGE|LILACS|LIMITATIONS|LITERATURE|" + \ 12 | "MATERIAL|MATERIALS|MEASUREMENTS|MEASURES|MEDLINE|METHODS|METHODOLOGY|METHOD|" + \ 13 | "OBJECTIVES|OBJECTIVE|OUTCOMES|OUTCOME|" + \ 14 | "PARTICIPANTS|PATIENTS|POPULATION|PRACTICE|PRIMARY|PRISMA|PROCEDURES|PROSPECT|PROSPERO|PROTOCOL|PUBMED|PURPOSE|PsycINFO|" + \ 15 | "QUALITY|QUESTIONS|QUESTION|" + \ 16 | "RATIONALE|RCTS|RCTs|RCT|REASONS|RECENT FINDINGS|RECOMMENDATIONS|RECOMMENDATION|REGISTRATION|RELEVANCE|RESEARCH|RESULTS|RESULT|REVIEW ER 'S CONCLUSIONS|REVIEWER'S|REVIEWERS'|REVIEWERS|REVIEW|" + \ 17 | r"SAMPLE|SCOPUS|SEARCH|SECONDARY|SELECTING|SELECTION|SETTING|SIGNIFICANCE|SOURCES|SOURCE|SPORTD|STRATEGY|STUDIES|STUDY|SUMMARY|SYNTHESIS|SYSTEMATIC)" 18 | 19 | fields_re = re.compile(FIELDS, flags=re.IGNORECASE) 20 | spaces_re = re.compile(r'\s(\s)+', flags=re.DOTALL) 21 | 22 | def insert_spaces(s: str) -> str: 23 | # surround any likely offending section headings or paragraph starts with spaces 24 | with_spaces = re.sub(fields_re, r' \1 ', s) 25 | # replace all double (or n) space characters with a single matching one 26 | with_single_spaces = re.sub(spaces_re, r'\1', with_spaces) 27 | return with_single_spaces 28 | -------------------------------------------------------------------------------- /ms2/data/review_datasets.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import random 3 | 4 | from collections import namedtuple 5 | from dataclasses import replace 6 | from typing import List, Tuple 7 | 8 | import torch 9 | 10 | from torch.utils.data import Dataset 11 | 12 | from ms2.utils import TargetReference, TargetSummary 13 | from ms2.models.utils import pad_tensors 14 | 15 | random.seed(12345) 16 | # TODO allow reading from disk 17 | # TODO allow memory pinning 18 | class ReviewDataset(Dataset): 19 | """A Dataset of Partially generated Reviews""" 20 | #An element of the dataset is a three-tuple consisting of: 21 | #* a representation of the references 22 | #* a representation of the 
review question + a partial summary/conclusion state 23 | #* the target next word to generate 24 | 25 | Instance = namedtuple('Instance', ['refs', 'preface', 'target']) 26 | 27 | def __init__(self, data: List[TargetSummary], format_function): 28 | super(ReviewDataset).__init__() 29 | self.data = data 30 | random.shuffle(self.data) 31 | self.instances = list(itertools.chain.from_iterable(map(format_function, self.data))) 32 | 33 | def __len__(self): 34 | return len(self.instances) 35 | 36 | def __getitem__(self, idx): 37 | return self.instances[idx] 38 | 39 | @staticmethod 40 | def from_file(f: str, format_function) -> 'ReviewDataset': 41 | def tensorize_reference(reference: TargetReference) -> TargetReference: 42 | title_abstract = torch.LongTensor(reference.title_abstract) 43 | full_text = torch.LongTensor(reference.full_text) if reference.full_text is not None else None 44 | return replace(reference, 45 | title_abstract=title_abstract, 46 | full_text=full_text, 47 | ) 48 | def tensorize(summary: TargetSummary) -> TargetSummary: 49 | # TODO what about summaries with no preface 50 | preface = torch.LongTensor(summary.preface) if summary.preface is not None and len(summary.preface) > 0 else torch.LongTensor([0]) 51 | references = list(map(tensorize_reference, summary.references)) 52 | target_texts = list(map(torch.LongTensor, summary.target_texts)) 53 | return replace( 54 | summary, 55 | preface=preface, 56 | target_texts=target_texts, 57 | references=references, 58 | ) 59 | 60 | summaries = TargetSummary.read_summaries(f) 61 | summaries = list(map(tensorize, summaries)) 62 | return ReviewDataset(summaries, format_function) 63 | 64 | @staticmethod 65 | def to_flattened_model_inputs(instance: TargetSummary) -> List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]: 66 | # TODO this is a dumb concatenation. be smarter. add separators or something. 67 | ref_texts = torch.cat([ref.title_abstract for ref in instance.references], dim=0) 68 | preface = instance.preface 69 | ret = [] 70 | for txt in instance.target_texts: 71 | ret.append(ReviewDataset.Instance(ref_texts, preface, txt)) 72 | return ret 73 | 74 | class ToUnflattenedModelInputsFunction(object): 75 | def __init__(self, padding_value): 76 | self.padding_value = padding_value 77 | 78 | def __call__(self, instance: TargetSummary) -> List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]: 79 | ref_texts = [ref.title_abstract for ref in instance.references] 80 | ref_texts = pad_tensors(ref_texts, padding_value=self.padding_value) 81 | preface = instance.preface 82 | ret = [] 83 | for txt in instance.target_texts: 84 | ret.append((ReviewDataset.Instance(ref_texts, preface, txt))) 85 | return ret -------------------------------------------------------------------------------- /ms2/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/ms2/a03ab009e00c5e412b4c55f6ec4f9b49c2d8a7f6/ms2/models/__init__.py -------------------------------------------------------------------------------- /ms2/models/abstract_classifier.py: -------------------------------------------------------------------------------- 1 | """Models and simple code for working with a CSV to generate classifications. 2 | 3 | This code exists to train/evaluate two types of classifiers: 4 | - one on entire abstracts, as whether or not they are suitable for inclusions 5 | - one on abstract *sentences* to specify their types (e.g. 
background/goal/ 6 | results/varying types of conclusion statements) 7 | 8 | The classes are inferred from the input CSVs 9 | """ 10 | import argparse 11 | import logging 12 | import os 13 | import random 14 | import numpy as np 15 | 16 | from collections import Counter 17 | from typing import List 18 | 19 | import matplotlib.pyplot as plt 20 | import pandas as pd 21 | import pytorch_lightning as pl 22 | 23 | import torch 24 | import torch.nn as nn 25 | 26 | from pytorch_lightning.loggers import TensorBoardLogger 27 | from sklearn.metrics import accuracy_score, classification_report, roc_auc_score, confusion_matrix, precision_recall_curve 28 | from torch.utils.data import Dataset, DataLoader 29 | from transformers import AutoTokenizer, AutoModel 30 | from transformers.optimization import get_linear_schedule_with_warmup 31 | 32 | logging.basicConfig(format='%(asctime)s %(message)s', level=logging.DEBUG) 33 | 34 | 35 | # this separate model exists because PyTorch lightning wants more things in and 36 | # out of its serialization than we need 37 | class AbstractTagger(nn.Module): 38 | def __init__(self, model_name, classes, tokenizer, model_dir): 39 | super(AbstractTagger, self).__init__() 40 | self.model = AutoModel.from_pretrained(model_name) 41 | self.dropout = nn.Dropout(self.model.config.hidden_dropout_prob) 42 | self.classifier = nn.Linear(self.model.config.hidden_size, len(classes)) 43 | self.classes = classes 44 | self.tokenizer = tokenizer 45 | self.model_dir = model_dir 46 | 47 | def forward(self, input_ids, labels=None): 48 | attention_mask = (input_ids != self.tokenizer.pad_token_id) 49 | outputs = self.model(input_ids=input_ids, attention_mask=attention_mask) 50 | sent_sep_embeddings = outputs[0][input_ids == self.tokenizer.sep_token_id] 51 | sent_sep_embeddings = self.dropout(sent_sep_embeddings) 52 | logits = self.classifier(sent_sep_embeddings) 53 | loss = None 54 | flat_labels = labels 55 | if labels is not None: 56 | flat_labels = labels[labels != -100] 57 | assert flat_labels.size(0) == sent_sep_embeddings.size(0) 58 | loss_fct = nn.CrossEntropyLoss() 59 | loss = loss_fct(logits, flat_labels) 60 | return loss, logits, flat_labels 61 | 62 | def decode(self, texts: List[str], max_length: int, group_in_one_abstract: bool): 63 | input_ids_list = [] 64 | input_ids = [] 65 | for text in texts: 66 | tokens = self.tokenizer.encode(text, truncation=True, max_length=max_length) 67 | if len(input_ids) + len(tokens) > max_length or not group_in_one_abstract: 68 | input_ids_list.append(input_ids) 69 | input_ids = [] 70 | if len(input_ids) > 0: 71 | tokens = tokens[1:] # drop the leading 72 | input_ids.extend(tokens) 73 | if len(input_ids) > 0: 74 | input_ids_list.append(input_ids) 75 | input_ids = torch.nn.utils.rnn.pad_sequence([torch.tensor(t) for t in input_ids_list], batch_first=True, padding_value=self.tokenizer.pad_token_id).long() 76 | loss, logits, flat_labels = self.forward(input_ids.cuda()) 77 | assert loss is None 78 | assert flat_labels is None 79 | pred_labels = torch.argmax(logits, dim=1) 80 | dist = torch.softmax(logits, dim=1) 81 | assert len(pred_labels) == len(texts) 82 | return pred_labels, dist 83 | 84 | 85 | class LightningAbstractTagger(pl.LightningModule): 86 | 87 | def __init__(self, args, model_name, classes, tokenizer, model_dir): 88 | super(LightningAbstractTagger, self).__init__() 89 | self.save_hyperparameters() 90 | self.args = args 91 | self.model = AbstractTagger(model_name, classes, tokenizer, model_dir) 92 | 93 | def training_step(self, batch, 
batch_idx): 94 | names, input_ids, labels = batch 95 | loss, logits, flat_labels = self.forward(input_ids=input_ids, labels=labels) 96 | preds = torch.argmax(logits, dim=1) 97 | report = classification_report( 98 | flat_labels.cpu().numpy(), 99 | preds.detach().cpu().numpy(), 100 | labels=list(range(len(self.model.classes))), 101 | target_names=self.model.classes, 102 | output_dict=True, 103 | zero_division=0) 104 | accuracy = accuracy_score(flat_labels.cpu().numpy(), preds.detach().cpu().numpy()) 105 | 106 | return { 107 | 'scores': torch.softmax(logits, dim=1).detach().cpu(), 108 | 'accuracy': accuracy, 109 | 'loss': loss, 110 | 'preds': preds, 111 | 'labels': flat_labels, 112 | 'log': { 113 | **report['macro avg'], 114 | 'loss': loss, 115 | 'accuracy': accuracy, 116 | } 117 | } 118 | 119 | def forward(self, input_ids, labels=None): 120 | return self.model.forward(input_ids=input_ids, labels=labels) 121 | 122 | def decode(self, texts: List[str], max_length: int): 123 | return self.model.decode(texts, max_length, group_in_one_abstract=self.args.seq) 124 | 125 | def configure_optimizers(self): 126 | optimizer = torch.optim.Adam(self.parameters(), lr=1e-5) 127 | dataset_size = self.train_dataloader.dataloader.dataset.__len__() 128 | num_steps = dataset_size * self.args.epochs / self.args.grad_accum / self.args.batch_size 129 | scheduler = get_linear_schedule_with_warmup( 130 | optimizer, num_warmup_steps=num_steps * 0.1, num_training_steps=num_steps 131 | ) 132 | return [optimizer], [{"scheduler": scheduler, "interval": "step"}] 133 | 134 | def validation_step(self, batch, batch_idx): 135 | return self.test_step(batch, batch_idx) 136 | 137 | def validation_epoch_end(self, outputs): 138 | avg_loss = torch.stack([x['loss'] for x in outputs]).mean().cpu() 139 | preds = torch.cat([x['preds'] for x in outputs]).cpu().numpy() 140 | labels = torch.cat([x['labels'] for x in outputs]).cpu().numpy() 141 | report = classification_report( 142 | labels, 143 | preds, 144 | labels=list(range(len(self.model.classes))), 145 | target_names=self.model.classes, 146 | output_dict=True, 147 | zero_division=0) 148 | accuracy = accuracy_score(labels, preds) 149 | logging.info(f'loss: {avg_loss}, accuracy: {accuracy}, macro avg: {report["macro avg"]}') 150 | return { 151 | 'accuracy': accuracy, 152 | 'val_loss': avg_loss, 153 | 'log': report['macro avg'], 154 | } 155 | 156 | def test_step(self, batch, batch_idx): 157 | return self.training_step(batch, batch_idx) 158 | 159 | def test_epoch_end(self, outputs): 160 | avg_loss = torch.stack([x['loss'] for x in outputs]).mean().cpu() 161 | preds = torch.cat([x['preds'] for x in outputs]).cpu().numpy() 162 | labels = torch.cat([x['labels'] for x in outputs]).cpu().numpy() 163 | if len(self.model.classes) == 2: 164 | scores = torch.cat([x['scores'] for x in outputs]).cpu().numpy() 165 | scores = scores[:, 1] 166 | auc = roc_auc_score( 167 | labels, 168 | scores, 169 | average='macro') 170 | logging.info('auc: {}'.format(auc)) 171 | precisions, recalls, thresholds = precision_recall_curve( 172 | labels, 173 | scores, 174 | pos_label=1 175 | ) 176 | logging.info('precisions: {}'.format(precisions)) 177 | logging.info('recalls: {}'.format(recalls)) 178 | logging.info('thresholds: {}'.format(thresholds)) 179 | plt.ioff() 180 | fig, ax = plt.subplots() 181 | line_kwargs = {"drawstyle": "steps-post", 'label': 'prf'} 182 | line_ = ax.plot(recalls, precisions, **line_kwargs) 183 | ax.set(xlabel="Recall", ylabel="Precision") 184 | plt.savefig(os.path.join(self.model.model_dir, 
'test_prf.png')) 185 | report = classification_report( 186 | labels, 187 | preds, 188 | labels=list(range(len(self.model.classes))), 189 | target_names=self.model.classes, 190 | output_dict=True, 191 | zero_division=0) 192 | for key, value in report.items(): 193 | if type(value) != dict: 194 | logging.info(f'{key}: {value:.2f}') 195 | continue 196 | p = value['precision'] 197 | r = value['recall'] 198 | f = value['f1-score'] 199 | s = value['support'] 200 | logging.info(f'p: {p:.2f}, r: {r:.2f}, f: {f:.2f}, s: {s} - {key}') 201 | 202 | conf = confusion_matrix( 203 | labels, 204 | preds, 205 | normalize='true') 206 | logging.info('confusion matrix\n{}'.format(conf)) 207 | accuracy = accuracy_score(labels, preds) 208 | logging.info(f'loss: {avg_loss}, accuracy: {accuracy}, macro avg: {report["macro avg"]}') 209 | return { 210 | 'accuracy': accuracy, 211 | 'test_loss': avg_loss, 212 | 'log': report['macro avg'], 213 | } 214 | 215 | 216 | class AbstractsDataset(Dataset): 217 | """Reads strings and values from a CSV""" 218 | def __init__( 219 | self, 220 | csv_path: str, 221 | instance_name_field: str, 222 | instance_text_field: str, 223 | instance_cls_field: str, 224 | tokenizer, 225 | classes: List[str], 226 | max_length: int, 227 | seq: bool, 228 | limit_classes: bool): 229 | super(AbstractsDataset, self).__init__() 230 | df = pd.read_csv(csv_path).fillna(value="MISSING!") 231 | df = df[df[instance_cls_field] != "MISSING!"] 232 | if classes is None: 233 | classes = set(filter(lambda kls: ',' not in kls, df[instance_cls_field])) 234 | self.classes = list(classes) 235 | self.intern_class = dict(((x, i) for (i, x) in enumerate(classes))) 236 | AbstractsDataset.tokenizer = tokenizer 237 | self.max_length = max_length 238 | self.instances = [] 239 | for row in df[[instance_name_field, instance_text_field, instance_cls_field]].itertuples(index=False): 240 | self.instances.extend(self._elem_to_training_instance(row)) 241 | if seq: 242 | merged_instances = [] 243 | prev_s2id = None 244 | found_ids = set() 245 | merged_tokens_one_instance = [] 246 | merged_labels_one_instance = None 247 | for instance in self.instances: 248 | s2id, tokens, labels = instance 249 | if s2id != prev_s2id or len(merged_tokens_one_instance) + len(tokens) > self.max_length: 250 | if prev_s2id is not None: 251 | assert len(merged_labels_one_instance) == merged_tokens_one_instance.count(self.tokenizer.sep_token_id) 252 | merged_instances.append((s2id, merged_tokens_one_instance, merged_labels_one_instance)) 253 | 254 | merged_tokens_one_instance = [] 255 | merged_labels_one_instance = [] 256 | prev_s2id = s2id 257 | if s2id in found_ids: 258 | logging.error(f'repeated s2id: {s2id}') 259 | found_ids.add(s2id) 260 | if len(merged_tokens_one_instance) > 0: 261 | tokens = tokens[1:] # drop the leading 262 | merged_tokens_one_instance.extend(tokens) 263 | merged_labels_one_instance.extend(labels) 264 | merged_instances.append((s2id, merged_tokens_one_instance, merged_labels_one_instance)) 265 | self.instances = merged_instances 266 | 267 | if limit_classes: 268 | for instance in self.instances: 269 | labels = None 270 | new_labels = [] 271 | for label in instance[2]: 272 | if label == self.intern_class['BACKGROUND'] or label == self.intern_class['GOAL']: 273 | new_labels.append(1) 274 | elif label == self.intern_class['EFFECT']: 275 | new_labels.append(2) 276 | else: 277 | new_labels.append(0) 278 | instance[2].clear() 279 | instance[2].extend(new_labels) 280 | self.classes = ['ETC', 'BACKGROUND', 'EFFECT'] 281 | self.intern_class 
= {'ETC': 0, 'BACKGROUND': 1, 'EFFECT': 2} 282 | 283 | def __len__(self): 284 | return len(self.instances) 285 | 286 | def __getitem__(self, idx): 287 | return self.instances[idx] 288 | 289 | def _elem_to_training_instance(self, elem): 290 | name, text, kls = elem 291 | # in initial annotation of the abstract sentence classes, some contain a 292 | # mix of information so instead of making a decision we punted and gave 293 | # it two classes. 294 | # as that provides unclear signal, we omit these instances 295 | if ',' in kls: 296 | classes = kls.split(',') 297 | return [] 298 | else: 299 | classes = [kls] 300 | ret = [] 301 | # auto trim instances 302 | text = self.tokenizer(text, truncation=True, max_length=self.max_length)['input_ids'] 303 | for kls in classes: 304 | kls = self.intern_class[kls] 305 | ret.append((name, text, [kls])) 306 | return ret 307 | 308 | @staticmethod 309 | def collate_fn(instances): 310 | pad_token_id = AbstractsDataset.tokenizer.pad_token_id 311 | (names, texts, kls) = zip(*instances) 312 | input_ids = torch.nn.utils.rnn.pad_sequence([torch.tensor(t) for t in texts], batch_first=True, padding_value=pad_token_id) 313 | labels = torch.nn.utils.rnn.pad_sequence([torch.tensor(x) for x in kls], batch_first=True, padding_value=-100) 314 | 315 | return names, input_ids, labels 316 | 317 | 318 | def main(): 319 | parser = argparse.ArgumentParser(description='train and evaluate an abstract (or text) classifier from a CSV input') 320 | parser.add_argument('--model_dir', required=True, help='Training dir') 321 | parser.add_argument('--train', required=True, help='Training dataset') 322 | parser.add_argument('--test', required=True, help='Testing dataset') 323 | parser.add_argument('--name_field', required=True, help='Some field to grab an id') 324 | parser.add_argument('--text_field', required=True, help='Some field to grab text for classification') 325 | parser.add_argument('--label_field', required=True, help='Some field to grab the label') 326 | parser.add_argument('--model', default='roberta-large', help='BERT model?') 327 | parser.add_argument("--seed", type=int, default=1234, help="Seed") 328 | parser.add_argument('--batch_size', default=16, type=int, help='batch size') 329 | parser.add_argument('--grad_accum', default=1, type=int, help='gradient accumulation') 330 | parser.add_argument('--epochs', default=5, type=int, help='epochs') 331 | parser.add_argument('--save_file', required=False, help='Where to save an output file') 332 | parser.add_argument('--seq', action='store_true', help='Sequence labeling') 333 | parser.add_argument('--limit_classes', action='store_true', help='Background, effect, etc') 334 | args = parser.parse_args() 335 | 336 | random.seed(args.seed) 337 | np.random.seed(args.seed) 338 | torch.manual_seed(args.seed) 339 | if torch.cuda.is_available(): 340 | torch.cuda.manual_seed_all(args.seed) 341 | 342 | classes = None 343 | max_length = 512 344 | 345 | logger = TensorBoardLogger( 346 | save_dir=os.path.join(args.model_dir, 'logs') 347 | ) 348 | logging.info('Loading data') 349 | tokenizer = AutoTokenizer.from_pretrained(args.model, use_fast=True) 350 | 351 | def flatten_list_of_list(nested_list): 352 | return [item for sublist in nested_list for item in sublist] 353 | 354 | train_data = AbstractsDataset(args.train, args.name_field, args.text_field, args.label_field, tokenizer, classes, max_length, args.seq, args.limit_classes) 355 | classes = train_data.classes 356 | training_distribution = 
Counter(flatten_list_of_list([(train_data.classes[cls_id] for cls_id in inst[-1]) for inst in train_data.instances])) 357 | logging.info('Training distribution {}'.format(training_distribution)) 358 | 359 | if args.limit_classes: 360 | classes = None 361 | test_data = AbstractsDataset(args.test, args.name_field, args.text_field, args.label_field, tokenizer, classes, max_length, args.seq, args.limit_classes) 362 | if args.limit_classes: 363 | classes = train_data.classes 364 | assert train_data.classes == test_data.classes 365 | 366 | testing_distribution = Counter(flatten_list_of_list([(train_data.classes[cls_id] for cls_id in inst[-1]) for inst in test_data.instances])) 367 | logging.info('Testing distribution {}'.format(testing_distribution)) 368 | 369 | logging.info('Loading model') 370 | model = LightningAbstractTagger(args, args.model, classes, tokenizer, args.model_dir) 371 | 372 | logging.info('Loaded {} training examples, {} test examples'.format(len(train_data), len(test_data))) 373 | train_dataloader = DataLoader(train_data, batch_size=args.batch_size, shuffle=True, num_workers=0, collate_fn=AbstractsDataset.collate_fn) 374 | test_dataloader = DataLoader(test_data, batch_size=args.batch_size, shuffle=False, num_workers=0, collate_fn=AbstractsDataset.collate_fn) 375 | logging.info('Creating trainer') 376 | trainer = pl.Trainer( 377 | distributed_backend=None, 378 | replace_sampler_ddp=False, 379 | default_root_dir=args.model_dir, 380 | max_epochs=args.epochs, 381 | gpus=1, 382 | logger=logger, 383 | show_progress_bar=True, 384 | log_save_interval=1, 385 | row_log_interval=1, 386 | precision=16, amp_level='O2', 387 | accumulate_grad_batches=args.grad_accum, 388 | checkpoint_callback=None, 389 | ) 390 | # TODO resume from checkpoint! 391 | logging.info('Training!') 392 | trainer.fit(model=model, train_dataloader=train_dataloader, val_dataloaders=test_dataloader) # super cheating 393 | trainer.test(model=model, test_dataloaders=test_dataloader) 394 | if args.save_file: 395 | torch.save(model.model, args.save_file) 396 | 397 | sample_abstract = 'From three trials in more severe OAG, there is some evidence that medication was associated with more progressive visual field loss and 3 to 8 mmHg less IOP lowering than surgery. In the longer-term (two trials) the risk of failure of the randomised treatment was greater with medication than trabeculectomy (OR 3.90, 95% CI 1.60 to 9.53; hazard ratio (HR) 7.27, 95% CI 2.23 to 25.71). Medications and surgery have evolved since these trials were undertaken. Evidence from one trial suggests that, beyond five years, the risk of needing cataract surgery did not differ according to initial treatment policy (OR 0.63, 95% CI 0.15 to 2.62). Methodological weaknesses were identified in all the trials. AUTHORS CONCLUSIONS\nPrimary surgery lowers IOP more than primary medication but is associated with more eye discomfort. One trial suggests that visual field restriction at five years is not significantly different whether initial treatment is medication or trabeculectomy. There is some evidence from two small trials in more severe OAG, that initial medication (pilocarpine, now rarely used as first line medication) is associated with more glaucoma progression than surgery. Beyond five years, there is no evidence of a difference in the need for cataract surgery according to initial treatment. Further RCTs of current medical treatments compared with surgery are required, particularly for people with severe glaucoma and in black ethnic groups. 
Economic evaluations are required to inform treatment policy.' 398 | sentences = sample_abstract.split('. ') 399 | sentences = [s + '.' for s in sentences] # pyt the period back 400 | model = model.eval() 401 | for p in model.parameters(): 402 | p.requires_grad = False 403 | labels, _ = model.decode(sentences, max_length) 404 | for label_idx, sentence in zip(labels, sentences): 405 | label = classes[label_idx] 406 | logging.info(f'{label} - {sentence}') 407 | 408 | 409 | if __name__ == '__main__': 410 | main() 411 | -------------------------------------------------------------------------------- /ms2/models/evidence_inference_models.py: -------------------------------------------------------------------------------- 1 | """Copied from Evidence Inference 2 | 3 | """ 4 | from typing import List, Optional 5 | 6 | import json 7 | 8 | import torch 9 | import torch.nn as nn 10 | 11 | from transformers import RobertaForSequenceClassification, RobertaTokenizer, PretrainedConfig 12 | 13 | from ms2.models.utils import PaddedSequence 14 | 15 | def initialize_models(params: dict, unk_token=''): 16 | max_length = params['max_length'] 17 | tokenizer = RobertaTokenizer.from_pretrained(params['bert_vocab']) 18 | #tokenizer = BertTokenizer.from_pretrained(params['bert_vocab']) 19 | pad_token_id = tokenizer.pad_token_id 20 | cls_token_id = tokenizer.cls_token_id 21 | sep_token_id = tokenizer.sep_token_id 22 | evidence_classes = dict((y,x) for (x,y) in enumerate(params['evidence_classifier']['classes'])) 23 | if bool(params.get('random_init', 0)): 24 | with open(params['bert_config'], 'r') as inf: 25 | cfg = inf.read() 26 | id_config = PretrainedConfig.from_dict(json.loads(cfg), num_labels=2) 27 | cls_config = PretrainedConfig.from_dict(json.loads(cfg), num_labels=len(evidence_classes)) 28 | use_half_precision = bool(params['evidence_identifier'].get('use_half_precision', 0)) 29 | evidence_identifier = BertClassifier(bert_dir=None, 30 | pad_token_id=pad_token_id, 31 | cls_token_id=cls_token_id, 32 | sep_token_id=sep_token_id, 33 | num_labels=2, 34 | max_length=max_length, 35 | use_half_precision=use_half_precision, 36 | config=id_config) 37 | use_half_precision = bool(params['evidence_classifier'].get('use_half_precision', 0)) 38 | evidence_classifier = BertClassifier(bert_dir=None, 39 | pad_token_id=pad_token_id, 40 | cls_token_id=cls_token_id, 41 | sep_token_id=sep_token_id, 42 | num_labels=len(evidence_classes), 43 | max_length=max_length, 44 | use_half_precision=use_half_precision, 45 | config=cls_config) 46 | else: 47 | bert_dir = params['bert_dir'] 48 | use_half_precision = bool(params['evidence_identifier'].get('use_half_precision', 0)) 49 | evidence_identifier = BertClassifier(bert_dir=bert_dir, 50 | pad_token_id=pad_token_id, 51 | cls_token_id=cls_token_id, 52 | sep_token_id=sep_token_id, 53 | num_labels=2, 54 | max_length=max_length, 55 | use_half_precision=use_half_precision) 56 | use_half_precision = bool(params['evidence_classifier'].get('use_half_precision', 0)) 57 | evidence_classifier = BertClassifier(bert_dir=bert_dir, 58 | pad_token_id=pad_token_id, 59 | cls_token_id=cls_token_id, 60 | sep_token_id=sep_token_id, 61 | num_labels=len(evidence_classes), 62 | max_length=max_length, 63 | use_half_precision=use_half_precision) 64 | word_interner = tokenizer.get_vocab() 65 | de_interner = dict((x,y) for (y,x) in word_interner.items()) 66 | #de_interner = tokenizer.ids_to_tokens 67 | return evidence_identifier, evidence_classifier, word_interner, de_interner, evidence_classes, tokenizer 68 | 69 | 
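# For reference, a sketch of the `params` dict consumed by initialize_models() above.
# The keys mirror the lookups in that function; the values shown are illustrative
# placeholders, not the project's actual configuration:
#     params = {
#         'max_length': 512,
#         'bert_vocab': 'roberta-base',                       # tokenizer name or path
#         'bert_dir': 'roberta-base',                         # model name or path (unused when random_init)
#         'random_init': 0,                                   # if truthy, build models from 'bert_config' instead
#         'bert_config': '/path/to/config.json',              # only read when random_init is truthy
#         'evidence_identifier': {'use_half_precision': 0},
#         'evidence_classifier': {'classes': ['...'], 'use_half_precision': 0},
#     }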
class BertClassifier(nn.Module): 70 | """Thin wrapper around BertForSequenceClassification""" 71 | def __init__(self, 72 | bert_dir: Optional[str], 73 | pad_token_id: int, 74 | cls_token_id: int, 75 | sep_token_id: int, 76 | num_labels: int, 77 | max_length: int=512, 78 | use_half_precision=False, 79 | config: Optional[PretrainedConfig]=None): 80 | super(BertClassifier, self).__init__() 81 | if bert_dir is None: 82 | assert config is not None 83 | assert config.num_labels == num_labels 84 | bert = RobertaForSequenceClassification(config) 85 | #bert = BertForSequenceClassification(config) 86 | else: 87 | bert = RobertaForSequenceClassification.from_pretrained(bert_dir, num_labels=num_labels) 88 | #bert = BertForSequenceClassification.from_pretrained(bert_dir, num_labels=num_labels) 89 | if use_half_precision: 90 | import apex 91 | bert = bert.half() 92 | self.bert = bert 93 | self.pad_token_id = pad_token_id 94 | self.cls_token_id = cls_token_id 95 | self.sep_token_id = sep_token_id 96 | self.max_length = max_length 97 | 98 | def forward(self, 99 | query: List[torch.tensor], 100 | document_batch: List[torch.tensor]): 101 | assert len(query) == len(document_batch) 102 | # note about device management: 103 | # since distributed training is enabled, the inputs to this module can be on *any* device (preferably cpu, since we wrap and unwrap the module) 104 | # we want to keep these params on the input device (assuming CPU) for as long as possible for cheap memory access 105 | target_device = next(self.parameters()).device 106 | cls_token = torch.tensor([self.cls_token_id])#.to(device=document_batch[0].device) 107 | sep_token = torch.tensor([self.sep_token_id])#.to(device=document_batch[0].device) 108 | input_tensors = [] 109 | position_ids = [] 110 | for q, d in zip(query, document_batch): 111 | if len(q) + len(d) + 2 > self.max_length: 112 | d = d[:(self.max_length - len(q) - 2)] 113 | input_tensors.append(torch.cat([cls_token, q, sep_token, d.to(dtype=q.dtype)])) 114 | position_ids.append(torch.arange(0, input_tensors[-1].size().numel())) 115 | #position_ids.append(torch.tensor(list(range(0, len(q) + 1)) + list(range(0, len(d) + 1)))) 116 | bert_input = PaddedSequence.autopad(input_tensors, batch_first=True, padding_value=self.pad_token_id, device=target_device) 117 | positions = PaddedSequence.autopad(position_ids, batch_first=True, padding_value=0, device=target_device) 118 | (classes,) = self.bert(bert_input.data, attention_mask=bert_input.mask(on=1.0, off=0.0, dtype=torch.float, device=target_device), position_ids=positions.data) 119 | assert torch.all(classes == classes) # for nans 120 | return classes -------------------------------------------------------------------------------- /ms2/models/pubmed_tagger.py: -------------------------------------------------------------------------------- 1 | """Multiclass tagger for Pubmed Publication Types 2 | 3 | The pubmed publication types category seems to be under-populated in pubmed 4 | articles, so we developed this tagger based on abstracts. 5 | 6 | The abstracts classifier is an unsuitable replacement for this because this is 7 | (1) multiclass and (2) trained via negative sampling. 8 | 9 | There's no real way of evaluating performance at this time, and for the moment 10 | it is not used in the main review processing pipeline. 
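A typical invocation looks like the following (paths are placeholders, not checked-in files):

    python ms2/models/pubmed_tagger.py --model_dir /path/to/model_dir --config /path/to/config.jsonnet

The config is a jsonnet file supplying model_name, train, val, batch_size, epochs, classes,
and max_length (see main() below).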
11 | """ 12 | import argparse 13 | import glob 14 | import json 15 | import logging 16 | import operator 17 | import os 18 | import random 19 | 20 | from collections import defaultdict 21 | from typing import Dict, List, Union 22 | 23 | import _jsonnet 24 | import pytorch_lightning as pl 25 | 26 | import torch 27 | import torch.nn as nn 28 | 29 | from pytorch_lightning.loggers import TensorBoardLogger 30 | from sklearn.metrics import classification_report 31 | from torch.utils.data import Dataset, DataLoader 32 | from transformers import BertModel, BertTokenizerFast 33 | 34 | from ms2.models.utils import PaddedSequence 35 | 36 | logging.basicConfig(format='%(asctime)s %(message)s', level=logging.DEBUG) 37 | 38 | random.seed(2468) 39 | 40 | class PubmedTagger(pl.LightningModule): 41 | 42 | def __init__(self, bert_name, classes): 43 | super(PubmedTagger, self).__init__() 44 | self.bert = BertModel.from_pretrained(bert_name) 45 | self.lin = nn.Linear(self.bert.config.hidden_size, len(classes)) 46 | self.classes = classes 47 | 48 | def forward(self, tokens, token_mask, labels=None, labels_mask=None): 49 | """ 50 | 51 | Args: 52 | tokens (torch.LongTensor): bs * len token ids 53 | token_mask (torch.FloatTensor): bs * len mask; elements are 1.0 for on, 0.0 for off 54 | labels (torch.LongTensor, optional): bs * num_classes. The true classes associated with each instance in the batch. Defaults to None. 55 | labels_mask (torch.LongTensor, optional): bs * num_classes. A mask for classes to ignore for each instance in the batch. Elements are 1.0 for on, 0.0 for off. Defaults to None. 56 | 57 | Returns: 58 | Tuple matching HuggingFace BERTs: (?loss, logits, hidden states, attentions) 59 | loss (torch.FloatTensor of shape (1,), optional): returned when `labels` are present 60 | logits (torch.FloatTensor of shape (bs * num_classes)): pre-sigmoid output from the multiclass layer 61 | hidden states: as in HuggingFace BERT 62 | attentions: as in HuggingFace BERT 63 | """ 64 | bert_outputs = self.bert(tokens, token_mask) 65 | last_hidden_states = bert_outputs[0] 66 | cls_tokens = bert_outputs[0][:,0] 67 | logits = self.lin(cls_tokens) 68 | outputs = (logits, bert_outputs[-2], bert_outputs[-1]) 69 | if labels is not None: 70 | if labels_mask is not None: 71 | logits *= labels_mask 72 | #logits = logits.to(dtype=torch.LongTensor) 73 | loss = nn.functional.binary_cross_entropy_with_logits(logits, labels) 74 | outputs = (loss,) + outputs 75 | else: 76 | loss = None 77 | return outputs 78 | 79 | def training_step(self, batch, batch_idx): 80 | text: PaddedSequence 81 | labels: torch.LongTensor 82 | labels_mask: torch.FloatTensor 83 | text, labels, labels_mask = batch 84 | device = next(self.parameters()).device 85 | text = text.to(device=device) 86 | mask = text.mask(on=1, off=0, dtype=torch.float, device=device) 87 | labels = labels.to(device=device) 88 | loss, logits, _, _ = self.forward(text.data, mask, labels=labels, labels_mask=labels_mask) 89 | preds = torch.round(nn.functional.sigmoid(logits)) 90 | report = classification_report(labels.cpu().numpy(), preds.detach().cpu().numpy(), target_names=self.classes, output_dict=True, zero_division=0) 91 | acc = sum(preds.masked_select(labels_mask.to(torch.bool)) == labels.masked_select(labels_mask.to(torch.bool))) / labels_mask.sum() 92 | return { 93 | 'batch_log_metrics': loss.item(), 94 | 'loss': loss, 95 | 'acc': acc, 96 | 'labels': labels_mask.sum().item(), 97 | 'log': { 98 | 'f1': report['macro avg']['f1-score'], 99 | 'loss': loss.item(), 100 | 'train_loss': 
loss.item(), 101 | } 102 | } 103 | 104 | def validation_step(self, batch, batch_idx): 105 | text: PaddedSequence 106 | labels: torch.LongTensor 107 | labels_mask: torch.FloatTensor 108 | text, labels, labels_mask = batch 109 | device = next(self.parameters()).device 110 | text = text.to(device=device) 111 | mask = text.mask(on=1, off=0, dtype=torch.float, device=device) 112 | labels = labels.to(device) 113 | loss, logits, _, _ = self.forward(text.data, mask, labels=labels, labels_mask=None) 114 | preds = torch.round(nn.functional.sigmoid(logits)) 115 | report = classification_report(labels.cpu().numpy(), preds.detach().cpu().numpy(), target_names=self.classes, output_dict=True, zero_division=0) 116 | acc = sum(preds.masked_select(labels_mask.to(torch.bool)) == labels.masked_select(labels_mask.to(torch.bool))) / labels_mask.sum() 117 | return { 118 | 'val_loss': loss, 119 | 'val_acc': acc, 120 | 'labels': labels_mask.sum().item(), 121 | } 122 | 123 | def validation_epoch_end(self, outputs): 124 | return { 125 | 'val_loss': sum(output['val_loss'] for output in outputs), 126 | 'val_acc': torch.mean(torch.tensor([output['val_acc'] for output in outputs])), 127 | 'labels': torch.sum(torch.tensor([output['labels'] for output in outputs])), 128 | } 129 | 130 | def configure_optimizers(self): 131 | return torch.optim.Adam(self.parameters(), lr=1e-3) 132 | 133 | 134 | class PubmedDataset(Dataset): 135 | 136 | def __init__( 137 | self, 138 | inputs_dir: str, 139 | sample_negatives: bool, 140 | tokenizer, 141 | classes: List[str], 142 | max_length: int, 143 | truncate_extra: bool=True): 144 | self.classes = classes 145 | self.intern_class = dict(((x,i) for (i,x) in enumerate(classes))) 146 | self.tokenizer = tokenizer 147 | self.max_length = max_length 148 | positives = [] 149 | negatives = [] 150 | for f in glob.glob(os.path.join(inputs_dir, 'targets', '*')): 151 | positives.extend(read_jsonl(f)) 152 | if len(positives) > 1000000: 153 | logging.info('not loading all data due to memory constraints') 154 | break 155 | for f in glob.glob(os.path.join(inputs_dir, 'etc', '*')): 156 | negatives.extend(read_jsonl(f)) 157 | if len(negatives) > 5 * len(positives): 158 | logging.info('not loading all data due to memory constraints') 159 | break 160 | for p in positives: 161 | p['source'] = 'positive' 162 | for n in negatives: 163 | n['source'] = 'negative' 164 | if sample_negatives: 165 | negatives = random.sample(negatives, len(positives)) 166 | self.positives = positives 167 | self.negatives = negatives 168 | all_data = positives + negatives 169 | random.shuffle(all_data) 170 | # turn the data into a list of [text, labels, label_mask] 171 | self.instances = list(map(self._elem_to_training_instance, all_data)) 172 | 173 | def __len__(self): 174 | return len(self.instances) 175 | 176 | def __getitem__(self, idx): 177 | return self.instances[idx] 178 | 179 | def _elem_to_training_instance(self, elem): 180 | # auto trim instances 181 | text = torch.LongTensor(self.tokenizer.convert_tokens_to_ids(elem['text'])[:self.max_length]) 182 | publication_types = elem['publication_types'] 183 | publication_types = list(map(lambda t: self.intern_class[t], publication_types)) 184 | types = torch.zeros((len(self.classes),)) 185 | types[publication_types] = 1 186 | if elem['source'] == 'positive': 187 | mask = types 188 | elif elem['source'] == 'negative': 189 | mask = torch.ones((len(self.classes,))) 190 | else: 191 | raise ValueError('impossible state with unknown elem {}'.format(elem)) 192 | return text, types, mask 193 | 
194 | @staticmethod 195 | def collate_fn(instances): 196 | texts, labels, label_masks = zip(*instances) 197 | texts = PaddedSequence.autopad(texts, batch_first=True, padding_value=0) 198 | labels = torch.stack(labels) 199 | label_masks = torch.stack(label_masks) 200 | return texts, labels, label_masks 201 | 202 | def read_jsonl(f: str) -> List[Union[Dict, List]]: 203 | with open(f, 'r') as inf: 204 | return list(map(json.loads, inf)) 205 | 206 | def main(): 207 | parser = argparse.ArgumentParser(description='train and evaluate pubmed multiclass tagger') 208 | parser.add_argument('--model_dir', required=True, help='Training dir') 209 | parser.add_argument('--config', required=True) 210 | args = parser.parse_args() 211 | config = json.loads(_jsonnet.evaluate_file(args.config)) 212 | logging.info(config) 213 | 214 | model_name = config['model_name'] 215 | train_dir = config['train'] 216 | val_dir = config['val'] 217 | batch_size = config['batch_size'] 218 | max_epochs = config['epochs'] 219 | classes = config['classes'] 220 | max_length = config['max_length'] 221 | 222 | logger = TensorBoardLogger( 223 | save_dir=os.getcwd(), 224 | version=1, 225 | name='lightning_logs' 226 | ) 227 | 228 | logging.info('Loading model') 229 | model = PubmedTagger(model_name, classes) 230 | model = model.cuda() 231 | 232 | logging.info('Loading data') 233 | tokenizer = BertTokenizerFast.from_pretrained(model_name) 234 | train_data = PubmedDataset(train_dir, True, tokenizer, classes, max_length) 235 | val_data = PubmedDataset(val_dir, False, tokenizer, classes, max_length) 236 | logging.info('Loaded {} training examples, {} validation_examples'.format(len(train_data), len(val_data))) 237 | train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=4, collate_fn=PubmedDataset.collate_fn) 238 | val_dataloader = DataLoader(val_data, batch_size=batch_size, shuffle=False, num_workers=4, collate_fn=PubmedDataset.collate_fn) 239 | logging.info('Creating trainer') 240 | trainer = pl.Trainer(default_root_dir=args.model_dir, max_epochs=max_epochs, gpus=1, logger=logger) 241 | # TODO resume from checkpoint! 
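    # (A possible approach to the TODO above, not wired up here: pick the newest '*.ckpt'
    # under args.model_dir, e.g. with glob plus a filename regex as transformer_summarizer.py
    # does, and pass it to pl.Trainer(..., resume_from_checkpoint=ckpt).)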
242 | logging.info('Training!') 243 | trainer.fit(model, train_dataloader, val_dataloader) 244 | 245 | if __name__ == '__main__': 246 | main() 247 | -------------------------------------------------------------------------------- /ms2/models/transformer_summarizer.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import glob 3 | import itertools 4 | import logging 5 | import os 6 | import re 7 | import json 8 | 9 | from typing import Iterable, List, Optional, Tuple 10 | 11 | import numpy as np 12 | import torch 13 | import torch.nn as nn 14 | 15 | from sklearn.metrics import classification_report 16 | from torch.nn import functional as F 17 | from torch.utils.data import DataLoader, DistributedSampler 18 | 19 | from transformers import BartForConditionalGeneration, AutoConfig 20 | from transformers.file_utils import ModelOutput 21 | from transformers.modeling_bart import _prepare_bart_decoder_inputs 22 | from transformers.modeling_outputs import BaseModelOutput 23 | from transformers.optimization import get_linear_schedule_with_warmup 24 | 25 | from pytorch_lightning import Trainer 26 | from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateLogger 27 | from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel 28 | from pytorch_lightning.core.lightning import LightningModule 29 | 30 | from longformer import LongformerEncoderDecoderForConditionalGeneration, LongformerEncoderDecoderConfig 31 | from longformer.sliding_chunks import pad_to_window_size 32 | 33 | from ms2.data.review_datasets import ReviewDataset, ToUnflattenedModelInputsFunction 34 | from ms2.models.utils import rouge_scores 35 | from ms2.utils import get_tokenizer 36 | 37 | logging.basicConfig(format='%(asctime)s %(message)s', level=logging.INFO) 38 | 39 | 40 | class ReferenceInteractingBartSummarizer(nn.Module): 41 | # tokenizer passed separately to allow adding for special characters 42 | def __init__(self, model_type, tokenizer, args): 43 | super().__init__() 44 | config = AutoConfig.from_pretrained(model_type) 45 | config.attention_dropout = args.attention_dropout 46 | config.gradient_checkpointing = args.grad_ckpt 47 | if model_type == 'facebook/bart-base': # bug in HF configuration 48 | config.encoder_attention_heads = 12 49 | config.decoder_attention_heads = 12 50 | self.model = BartForConditionalGeneration.from_pretrained(model_type, config=config) 51 | self.model.resize_token_embeddings(len(tokenizer.get_vocab())) 52 | self.tokenizer = tokenizer 53 | self.config = self.model.config 54 | self.args = args 55 | 56 | def _encode_multiple( 57 | self, 58 | inputs: torch.LongTensor, 59 | preamble: torch.LongTensor, 60 | output_attentions=False, 61 | output_hidden_states=False, 62 | return_dict=False 63 | ): 64 | """ 65 | inputs: Padded reference texts 66 | preamble: single beginning/prompt 67 | """ 68 | inputs = inputs[:self.args.max_num_refs] 69 | preamble = preamble.repeat(inputs.size()[0], 1) 70 | encoder_input = torch.cat([preamble, inputs], dim=1)[:, :self.config.max_position_embeddings] 71 | encoder_outputs = self.model.model.encoder( 72 | encoder_input, 73 | output_attentions=output_attentions, 74 | output_hidden_states=output_hidden_states, 75 | attention_mask=None, 76 | return_dict=False 77 | ) 78 | selection_mask = encoder_input != self.config.pad_token_id 79 | input_ids = torch.masked_select(encoder_input, selection_mask).unsqueeze(0) 80 | if len(encoder_outputs) == 1: 81 | encoded = encoder_outputs[0] 82 | 
encoder_states, all_attentions = None, None
83 |         else:
84 |             encoded, encoder_states, all_attentions = encoder_outputs
85 |             encoder_states = tuple(torch.masked_select(hs, selection_mask) for hs in encoder_states)
86 |             all_attentions = tuple(torch.masked_select(attn, selection_mask) for attn in all_attentions)
87 |         encoded_sequences = torch.masked_select(encoded, selection_mask.unsqueeze(-1)).reshape(1, -1, encoded.size()[-1])
88 |         if torch.any(torch.isnan(encoded)):
89 |             raise RuntimeError('Found nans while encoding inputs!')
90 |         if return_dict:
91 |             return input_ids, BaseModelOutput(
92 |                 last_hidden_state=encoded_sequences,
93 |                 hidden_states=encoder_states,
94 |                 attentions=all_attentions,
95 |             )
96 |         else:
97 |             return input_ids, (encoded_sequences, encoder_states, all_attentions)
98 | 
99 |     def forward(self, inputs: torch.Tensor, preambles: torch.Tensor, targets: torch.Tensor):
100 |         # prep the decoder inputs and the loss labels and masks
101 |         # Note that `lm_labels` is similar to `decoder_input_ids` but shifted one step to the left.
102 |         # There's also a small difference in the use of <s> and </s> as shown in the following example
103 |         # For example,
104 |         #    decoder_input_ids = '<s> some text .'
105 |         #    lm_labels = 'some text . </s>'
106 |         targets = targets[:, :self.args.max_length] # limit target length for memory
107 |         decoder_input_ids = targets[:, :-1].contiguous()
108 |         lm_labels = targets[:, 1:].clone()
109 | 
110 |         decoder_input_ids, decoder_padding_mask, causal_mask = _prepare_bart_decoder_inputs(
111 |             self.model.config,
112 |             None, # this would be the input ids but we very much do not want them here
113 |             decoder_input_ids=decoder_input_ids,
114 |             decoder_padding_mask=None,
115 |             causal_mask_dtype=self.model.model.shared.weight.dtype,
116 |         )
117 |         _, (encoded_sequences, _, _) = self._encode_multiple(inputs, preambles, return_dict=False)
118 |         if torch.any(torch.isnan(encoded_sequences.data)):
119 |             raise RuntimeError('Found nans while encoding inputs!')
120 | 
121 |         # decoder output!
122 |         decoder_outputs = self.model.model.decoder(
123 |             input_ids=decoder_input_ids,
124 |             encoder_hidden_states=encoded_sequences,
125 |             encoder_padding_mask=None,
126 |             decoder_padding_mask=decoder_padding_mask,
127 |             decoder_causal_mask=causal_mask,
128 |             decoder_cached_states=None,
129 |             output_attentions=False,
130 |             output_hidden_states=False,
131 |             use_cache=False,
132 |         )
133 |         if torch.any(torch.isnan(decoder_outputs[0])):
134 |             raise RuntimeError('Found nans while decoding!')
135 | 
136 |         lm_logits = F.linear(decoder_outputs[0], self.model.model.shared.weight, bias=self.model.final_logits_bias)
137 |         if torch.any(torch.isnan(lm_logits)):
138 |             raise RuntimeError('Found nans while predicting lm weights!')
139 |         outputs = (lm_logits,) + decoder_outputs[1:] # Add cache, hidden states and attention if they are here
140 |         loss_fct = nn.CrossEntropyLoss(reduction='none')
141 |         # Note: masking will need to be re-added if bs > 1 (currently not possible!)
142 | masked_lm_loss = loss_fct(lm_logits.view(-1, self.model.config.vocab_size), lm_labels.view(-1)) 143 | if torch.any(torch.isnan(masked_lm_loss)): 144 | raise RuntimeError('Invalid loss!') 145 | masked_lm_loss = masked_lm_loss.mean() 146 | outputs = (masked_lm_loss,) + outputs 147 | 148 | return outputs 149 | 150 | @torch.no_grad() 151 | def generate_summary( 152 | self, 153 | inputs: torch.Tensor, 154 | preambles: torch.Tensor, 155 | max_length: Optional[int] = None, 156 | min_length: Optional[int] = None, 157 | do_sample: Optional[bool] = None, 158 | early_stopping: Optional[bool] = None, 159 | num_beams: Optional[int] = None, 160 | temperature: Optional[float] = None, 161 | top_k: Optional[int] = None, 162 | top_p: Optional[float] = None, 163 | repetition_penalty: Optional[float] = None, 164 | bad_words_ids: Optional[Iterable[int]] = None, 165 | bos_token_id: Optional[int] = None, 166 | pad_token_id: Optional[int] = None, 167 | eos_token_id: Optional[int] = None, 168 | length_penalty: Optional[float] = None, 169 | no_repeat_ngram_size: Optional[int] = None, 170 | num_return_sequences: Optional[int] = None, 171 | attention_mask: Optional[torch.LongTensor] = None, 172 | decoder_start_token_id: Optional[int] = None, 173 | use_cache: Optional[bool] = None, 174 | **model_kwargs 175 | ) -> torch.LongTensor: 176 | 177 | # We cannot generate if the model does not have a LM head 178 | if self.model.get_output_embeddings() is None: 179 | raise AttributeError( 180 | "You tried to generate sequences with a model that does not have a LM Head." 181 | "Please use another model class (e.g. `OpenAIGPTLMHeadModel`, `XLNetLMHeadModel`, `GPT2LMHeadModel`, `CTRLLMHeadModel`, `T5WithLMHeadModel`, `TransfoXLLMHeadModel`, `XLMWithLMHeadModel`, `BartForConditionalGeneration` )" 182 | ) 183 | 184 | max_length = max_length if max_length is not None else self.config.max_length 185 | min_length = min_length if min_length is not None else self.config.min_length 186 | do_sample = do_sample if do_sample is not None else self.config.do_sample 187 | early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping 188 | use_cache = use_cache if use_cache is not None else self.config.use_cache 189 | num_beams = num_beams if num_beams is not None else self.config.num_beams 190 | temperature = temperature if temperature is not None else self.config.temperature 191 | top_k = top_k if top_k is not None else self.config.top_k 192 | top_p = top_p if top_p is not None else self.config.top_p 193 | repetition_penalty = repetition_penalty if repetition_penalty is not None else self.config.repetition_penalty 194 | bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id 195 | pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id 196 | eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id 197 | length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty 198 | no_repeat_ngram_size = ( 199 | no_repeat_ngram_size if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size 200 | ) 201 | bad_words_ids = bad_words_ids if bad_words_ids is not None else self.config.bad_words_ids 202 | num_return_sequences = ( 203 | num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences 204 | ) 205 | decoder_start_token_id = ( 206 | decoder_start_token_id if decoder_start_token_id is not None else self.config.decoder_start_token_id 207 | ) 208 | 209 
| # just decode one at a time for sanity's sake! 210 | batch_size = 1 211 | #if input_ids is not None: 212 | # batch_size = input_ids.shape[0] # overriden by the input batch_size 213 | #else: 214 | # batch_size = 1 215 | 216 | assert isinstance(max_length, int) and max_length > 0, "`max_length` should be a strictly positive integer." 217 | assert isinstance(min_length, int) and min_length >= 0, "`min_length` should be a positive integer." 218 | assert isinstance(do_sample, bool), "`do_sample` should be a boolean." 219 | assert isinstance(early_stopping, bool), "`early_stopping` should be a boolean." 220 | assert isinstance(use_cache, bool), "`use_cache` should be a boolean." 221 | assert isinstance(num_beams, int) and num_beams > 0, "`num_beams` should be a strictly positive integer." 222 | assert temperature > 0, "`temperature` should be strictly positive." 223 | assert isinstance(top_k, int) and top_k >= 0, "`top_k` should be a positive integer." 224 | assert 0 <= top_p <= 1, "`top_p` should be between 0 and 1." 225 | assert repetition_penalty >= 1.0, "`repetition_penalty` should be >= 1." 226 | assert inputs is not None or preambles is not None 227 | assert inputs is not None or ( 228 | isinstance(bos_token_id, int) and bos_token_id >= 0 229 | ), "If inputs is not defined, `bos_token_id` should be a positive integer." 230 | assert pad_token_id is None or ( 231 | isinstance(pad_token_id, int) and (pad_token_id >= 0) 232 | ), "`pad_token_id` should be a positive integer." 233 | assert (eos_token_id is None) or ( 234 | isinstance(eos_token_id, int) and (eos_token_id >= 0) 235 | ), "`eos_token_id` should be a positive integer." 236 | assert length_penalty > 0, "`length_penalty` should be strictly positive." 237 | assert ( 238 | isinstance(no_repeat_ngram_size, int) and no_repeat_ngram_size >= 0 239 | ), "`no_repeat_ngram_size` should be a positive integer." 240 | assert ( 241 | isinstance(num_return_sequences, int) and num_return_sequences > 0 242 | ), "`num_return_sequences` should be a strictly positive integer." 243 | assert ( 244 | bad_words_ids is None or isinstance(bad_words_ids, list) and isinstance(bad_words_ids[0], list) 245 | ), "`bad_words_ids` is either `None` or a list of lists of tokens that should not be generated" 246 | 247 | # not allow to duplicate outputs when greedy decoding 248 | if do_sample is False: 249 | if num_beams == 1: 250 | # no_beam_search greedy generation conditions 251 | assert ( 252 | num_return_sequences == 1 253 | ), "Greedy decoding will always produce the same output for num_beams == 1 and num_return_sequences > 1. Please set num_return_sequences = 1" 254 | 255 | else: 256 | # beam_search greedy generation conditions 257 | assert ( 258 | num_beams >= num_return_sequences 259 | ), "Greedy beam search decoding cannot return more sequences than it has beams. Please set num_beams >= num_return_sequences" 260 | 261 | # create attention mask if necessary 262 | # TODO (PVP): this should later be handled by the forward fn() in each model in the future see PR 3140 263 | # if (attention_mask is None) and (pad_token_id is not None) and (pad_token_id in input_ids): 264 | # attention_mask = input_ids.ne(pad_token_id).long() 265 | # elif attention_mask is None: 266 | # attention_mask = input_ids.new_ones(input_ids.shape) 267 | 268 | # set pad_token_id to eos_token_id if not set. 
Important that this is done after 269 | # attention_mask is created 270 | if pad_token_id is None and eos_token_id is not None: 271 | logging.warning( 272 | "Setting `pad_token_id` to {} (first `eos_token_id`) to generate sequence".format(eos_token_id) 273 | ) 274 | pad_token_id = eos_token_id 275 | 276 | # current position and vocab size 277 | if hasattr(self.config, "vocab_size"): 278 | vocab_size = self.config.vocab_size 279 | elif ( 280 | self.config.is_encoder_decoder 281 | and hasattr(self.config, "decoder") 282 | and hasattr(self.config.decoder, "vocab_size") 283 | ): 284 | vocab_size = self.config.decoder.vocab_size 285 | 286 | # set effective batch size and effective batch multiplier according to do_sample 287 | if do_sample: 288 | effective_batch_size = batch_size * num_return_sequences 289 | effective_batch_mult = num_return_sequences 290 | else: 291 | effective_batch_size = batch_size 292 | effective_batch_mult = 1 293 | 294 | assert self.config.is_encoder_decoder 295 | if decoder_start_token_id is None: 296 | # see if BOS token can be used for decoder_start_token_id 297 | if bos_token_id is not None: 298 | decoder_start_token_id = bos_token_id 299 | elif hasattr(self.config, "decoder") and hasattr(self.config.decoder, "bos_token_id"): 300 | decoder_start_token_id = self.config.decoder.bos_token_id 301 | else: 302 | raise ValueError( 303 | "decoder_start_token_id or bos_token_id has to be defined for encoder-decoder generation" 304 | ) 305 | 306 | assert hasattr(self.model, "get_encoder"), "{} should have a 'get_encoder' function defined".format(self.model) 307 | assert callable(self.model.get_encoder), "{} should be a method".format(self.model.get_encoder) 308 | 309 | # get encoder and store encoder outputs 310 | encoder_outputs: ModelOutput 311 | input_ids, encoder_outputs = self._encode_multiple(inputs, preambles, return_dict=True) 312 | 313 | # Expand input ids if num_beams > 1 or num_return_sequences > 1 314 | if num_return_sequences > 1 or num_beams > 1: 315 | input_ids_len = input_ids.shape[-1] 316 | input_ids = input_ids.unsqueeze(1).expand(batch_size, effective_batch_mult * num_beams, input_ids_len) 317 | #attention_mask = attention_mask.unsqueeze(1).expand( 318 | # batch_size, effective_batch_mult * num_beams, input_ids_len 319 | #) 320 | 321 | input_ids = input_ids.contiguous().view( 322 | effective_batch_size * num_beams, input_ids_len 323 | ) # shape: (batch_size * num_return_sequences * num_beams, cur_len) 324 | #attention_mask = attention_mask.contiguous().view( 325 | # effective_batch_size * num_beams, input_ids_len 326 | #) # shape: (batch_size * num_return_sequences * num_beams, cur_len) 327 | 328 | attention_mask = None 329 | if self.config.is_encoder_decoder: 330 | # create empty decoder_input_ids 331 | input_ids = torch.full( 332 | (effective_batch_size * num_beams, 1), 333 | decoder_start_token_id, 334 | dtype=torch.long, 335 | device=next(self.parameters()).device, 336 | ) 337 | cur_len = 1 338 | 339 | assert ( 340 | batch_size == encoder_outputs.last_hidden_state.shape[0] 341 | ), f"expected encoder_outputs.last_hidden_state to have 1st dimension bs={batch_size}, got {encoder_outputs.last_hidden_state.shape[0]} " 342 | 343 | # expand batch_idx to assign correct encoder output for expanded input_ids (due to num_beams > 1 and num_return_sequences > 1) 344 | expanded_batch_idxs = ( 345 | torch.arange(batch_size) 346 | .view(-1, 1) 347 | .repeat(1, num_beams * effective_batch_mult) 348 | .view(-1) 349 | .to(input_ids.device) 350 | ) 351 | 352 | # expand 
encoder_outputs 353 | encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.index_select( 354 | 0, expanded_batch_idxs 355 | ) 356 | 357 | # save encoder_outputs in `model_kwargs` 358 | model_kwargs["encoder_outputs"] = encoder_outputs 359 | 360 | else: 361 | cur_len = input_ids.shape[-1] 362 | 363 | assert ( 364 | cur_len < max_length 365 | ), f"The context has {cur_len} number of tokens, but `max_length` is only {max_length}. Please make sure that `max_length` is bigger than the number of tokens, by setting either `generate(max_length=...,...)` or `config.max_length = ...`" 366 | 367 | if num_beams > 1: 368 | output = self.model._generate_beam_search( 369 | input_ids, 370 | cur_len=cur_len, 371 | max_length=max_length, 372 | min_length=min_length, 373 | do_sample=do_sample, 374 | early_stopping=early_stopping, 375 | temperature=temperature, 376 | top_k=top_k, 377 | top_p=top_p, 378 | repetition_penalty=repetition_penalty, 379 | no_repeat_ngram_size=no_repeat_ngram_size, 380 | bad_words_ids=bad_words_ids, 381 | pad_token_id=pad_token_id, 382 | eos_token_id=eos_token_id, 383 | batch_size=effective_batch_size, 384 | num_return_sequences=num_return_sequences, 385 | length_penalty=length_penalty, 386 | num_beams=num_beams, 387 | vocab_size=vocab_size, 388 | attention_mask=attention_mask, 389 | use_cache=use_cache, 390 | model_kwargs=model_kwargs, 391 | ) 392 | else: 393 | output = self.model._generate_no_beam_search( 394 | input_ids, 395 | cur_len=cur_len, 396 | max_length=max_length, 397 | min_length=min_length, 398 | do_sample=do_sample, 399 | temperature=temperature, 400 | top_k=top_k, 401 | top_p=top_p, 402 | repetition_penalty=repetition_penalty, 403 | no_repeat_ngram_size=no_repeat_ngram_size, 404 | bad_words_ids=bad_words_ids, 405 | pad_token_id=pad_token_id, 406 | eos_token_id=eos_token_id, 407 | batch_size=effective_batch_size, 408 | attention_mask=attention_mask, 409 | use_cache=use_cache, 410 | model_kwargs=model_kwargs, 411 | ) 412 | 413 | return output 414 | 415 | def collate_fn(self, batch: List[ReviewDataset.Instance]) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 416 | (inst,) = batch 417 | return inst.refs, inst.preface.unsqueeze(0), inst.target.unsqueeze(0) 418 | 419 | # TODO add a simple transformer 420 | # TODO add a SciBERT or BiomedRoBERTa transformer 421 | # TODO uh oh, size problems! 
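# Rough data flow through ReferenceInteractingBartSummarizer above (a sketch; shapes are
# illustrative rather than taken from the dataset):
#   collate_fn -> (refs, preface, target), roughly
#       refs:    (num_refs, ref_len)   padded per-reference token ids
#       preface: (1, preamble_len)     shared background / prompt ids
#       target:  (1, target_len)       review summary ids
#   _encode_multiple repeats the preface in front of every reference, encodes each
#   (preface + reference) row separately, drops pad positions, and concatenates the rest
#   into a single (1, total_len, hidden) memory that the BART decoder cross-attends over.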
422 | class LightningBartSummarizer(LightningModule): 423 | 424 | def __init__(self, args): 425 | super().__init__() 426 | self.save_hyperparameters() 427 | self.args = args 428 | self.tokenizer = get_tokenizer('facebook/bart-base') 429 | if 'long' in args.model_name: 430 | self.summarizer = SingleStreamBartSummarizer(args.model_name, self.tokenizer, args) 431 | else: 432 | self.summarizer = ReferenceInteractingBartSummarizer(args.model_name, self.tokenizer, args) 433 | 434 | self.config = self.summarizer.config 435 | self.generation_params = { 436 | 'num_beams': args.num_beams, 437 | 'length_penalty': args.length_penalty, 438 | 'no_repeat_ngram_size': args.no_repeat_ngram_size, 439 | 'early_stopping': True, 440 | 'decoder_start_token_id': self.config.bos_token_id, 441 | 'min_length': args.min_length, 442 | 'max_length': args.max_length, 443 | 'temperature': args.temperature, 444 | 'repetition_penalty': args.repetition_penalty, 445 | } 446 | self.predictions_file = None 447 | 448 | def forward(self, inputs, preambles, targets): 449 | return self.summarizer.forward(inputs=inputs, preambles=preambles, targets=targets) 450 | 451 | @torch.no_grad() 452 | def generate_summary(self, *args, **kwargs) -> torch.LongTensor: 453 | for p in self.summarizer.parameters(): 454 | p.requires_grad = False 455 | ret = self.summarizer.generate_summary(*args, **kwargs) 456 | for p in self.summarizer.parameters(): 457 | p.requires_grad = True 458 | return ret 459 | 460 | def training_step(self, batch, batch_idx): 461 | outputs = self.forward(*batch) 462 | loss = outputs[0] 463 | output = { 464 | 'loss': loss, 465 | 'train_loss': loss, 466 | 'log': { 467 | 'train_loss': loss, 468 | 'lr': loss.new_zeros(1) + self.trainer.optimizers[0].param_groups[0]['lr'], 469 | }, 470 | } 471 | return output 472 | 473 | @torch.no_grad() 474 | def validation_step(self, batch, batch_idx): 475 | for p in self.summarizer.parameters(): 476 | p.requires_grad = False 477 | outputs = self.forward(*batch) 478 | for p in self.summarizer.parameters(): 479 | p.requires_grad = True 480 | loss = outputs[0] 481 | lm_logits = outputs[1] 482 | generations = self.generate_summary( 483 | batch[0], 484 | batch[1], 485 | **self.generation_params, 486 | ) 487 | targets = batch[2] 488 | output = { 489 | 'val_loss': loss.cpu(), 490 | 'progress_bar': { 491 | 'val_loss': loss.cpu(), 492 | }, 493 | 'preambles': [x.cpu() for x in batch[1]], 494 | 'generations': [[x.cpu()] for x in generations], 495 | 'teacher_forced_generations': [torch.argmax(lm_logits, dim=1).detach().cpu()], 496 | 'targets': [[x.cpu()] for x in targets], 497 | } 498 | assert len(output['generations']) == len(output['targets']) 499 | assert list(map(len, output['generations'])) == list(map(len, output['targets'])) 500 | return output 501 | 502 | def validation_epoch_end(self, outputs): 503 | losses = np.mean([output['val_loss'] for output in outputs]) 504 | generated, teacher_forced_generations, targets = self._accumulate_generations(outputs) 505 | assert len(generated) > 0 506 | scores = rouge_scores(generated, targets, self.summarizer.tokenizer, use_aggregator=True) 507 | tf_scores = rouge_scores(teacher_forced_generations, targets, self.summarizer.tokenizer, use_aggregator=True) 508 | # TODO: if self.use_ddp: sync val_loss and rouge scores across GPUs 509 | output = { 510 | 'val_loss': losses, 511 | 'log': { 512 | 'val_loss': losses, 513 | }, 514 | } 515 | for rouge_type, prf_scores in scores.items(): 516 | output['val_' + rouge_type + '_p'] = prf_scores.mid.precision 517 | 
output['val_' + rouge_type + '_r'] = prf_scores.mid.recall 518 | output['val_' + rouge_type + '_f'] = prf_scores.mid.fmeasure 519 | output['log']['val_' + rouge_type + '_p'] = prf_scores.mid.precision 520 | output['log']['val_' + rouge_type + '_r'] = prf_scores.mid.recall 521 | output['log']['val_' + rouge_type + '_f'] = prf_scores.mid.fmeasure 522 | for rouge_type, prf_scores in tf_scores.items(): 523 | output['val_tf' + rouge_type + '_p'] = prf_scores.mid.precision 524 | output['val_tf' + rouge_type + '_r'] = prf_scores.mid.recall 525 | output['val_tf' + rouge_type + '_f'] = prf_scores.mid.fmeasure 526 | output['progress_bar'] = { 527 | 'val_loss': losses, 528 | 'val_rougeL_f': output['val_rougeL_f'], 529 | 'val_rougeLsum_f': output['val_rougeLsum_f'], 530 | 'val_rouge1_f': output['val_rouge1_f'], 531 | 'val_rouge2_f': output['val_rouge2_f'], 532 | #'val_rougeL_r': output['val_rougeL_r'], 533 | #'val_rouge1_r': output['val_rouge1_r'], 534 | #'val_rouge2_r': output['val_rouge2_r'], 535 | #'val_rougeL_p': output['val_rougeL_p'], 536 | #'val_rouge1_p': output['val_rouge1_p'], 537 | #'val_rouge2_p': output['val_rouge2_p'], 538 | } 539 | for k, v in output.items(): 540 | if 'rouge' in k: 541 | output['log'][k] = v 542 | if self.args.evidence_inference_eval: 543 | scores = self._evidence_inference_score(generated, targets) 544 | output['progress_bar']['macro_f1'] = scores['macro avg']['f1-score'] 545 | output['log']['macro_f1'] = scores['macro avg']['f1-score'] 546 | output['log']['macro_r'] = scores['macro avg']['recall'] 547 | output['log']['macro_p'] = scores['macro avg']['precision'] 548 | for i, j in scores.items(): 549 | if not isinstance(j, dict): 550 | output['log'][i] = j 551 | else: 552 | for k, l in j.items(): 553 | ik = i + '_' + k 554 | output['log'][ik] = l 555 | 556 | return output 557 | 558 | @torch.no_grad() 559 | def test_step(self, batch, batch_idx): 560 | if self.predictions_file is None: 561 | self.predictions_file = open(os.path.join(self.args.training_root, 'predictions.json'), 'w') 562 | output = self.validation_step(batch, batch_idx) 563 | data = { 564 | 'batch_idx': batch_idx, 565 | 'preamble': self.tokenizer.decode(batch[1].squeeze(), skip_special_tokens=False), 566 | 'generated': self.tokenizer.decode(output['generations'][0][0], skip_special_tokens=False), 567 | 'target': self.tokenizer.decode(output['targets'][0][0], skip_special_tokens=False) 568 | } 569 | json_record = json.dumps(data) 570 | self.predictions_file.write(json_record + '\n') 571 | self.predictions_file.flush() 572 | return {'test_loss': output['val_loss'], 'progress_bar': {'val_loss': output['val_loss']}, } 573 | 574 | def test_epoch_end(self, outputs): 575 | self.predictions_file.close() 576 | 577 | def _accumulate_generations(self, outputs) -> Tuple[List[List[torch.IntTensor]], List[List[torch.IntTensor]], List[List[torch.IntTensor]]]: 578 | generated = [] 579 | teacher_forced_generations = [] 580 | targets = [] 581 | for output in outputs: 582 | # both the generated and targets should be lists of lists of inttensors 583 | gen = output.get('generations', []) 584 | tf = output.get('teacher_forced_generations', []) 585 | if len(gen) > 0: 586 | generated.extend(gen) 587 | teacher_forced_generations.extend(tf) 588 | tgt = output.get('targets', []) 589 | targets.extend(tgt) 590 | assert len(tgt) == len(gen) 591 | return generated, teacher_forced_generations, targets 592 | 593 | def _evidence_inference_score(self, generations, truths): 594 | generations_labels = ['significantly decreased', 'no 
significant difference', 'significantly increased', 'broken generation'] 595 | generations_mapping = { 596 | 'significantly decreased': 0, 597 | 'no significant difference': 1, 598 | 'significantly increased': 2, 599 | 'broken generation': 3 600 | } 601 | generations = list(map(lambda s: s.replace('', '').replace('', ''), map(str.lower, map(self.tokenizer.decode, itertools.chain.from_iterable(generations))))) 602 | truths = list(map(lambda s: s.replace('', '').replace('', ''), map(str.lower, map(self.tokenizer.decode, itertools.chain.from_iterable(truths))))) 603 | pretty_generations = [] 604 | pretty_truths = [] 605 | for gen in generations: 606 | pretty_generations.append(generations_mapping.get(gen, 3)) 607 | for t in truths: 608 | pretty_truths.append(generations_mapping.get(t, 3)) 609 | all_labels = set(generations_labels[x] for x in (set(pretty_generations) | set(pretty_truths))) 610 | assert len(generations) == len(truths) 611 | return classification_report( 612 | pretty_truths, 613 | pretty_generations, 614 | target_names=all_labels, 615 | output_dict=True, 616 | digits=3, 617 | zero_division=0, 618 | ) 619 | 620 | def configure_optimizers(self): 621 | if self.args.debug: 622 | return torch.optim.Adam(self.summarizer.parameters(), lr=self.args.lr) # const LR 623 | optimizer = torch.optim.Adam(self.summarizer.parameters(), lr=self.args.lr, eps=self.args.adam_epsilon) 624 | num_gpus = torch.cuda.device_count() 625 | num_steps = self.args.dataset_size * self.args.epochs / num_gpus / self.args.grad_accum 626 | scheduler = get_linear_schedule_with_warmup( 627 | optimizer, num_warmup_steps=self.args.warmup_steps, num_training_steps=num_steps 628 | ) 629 | return [optimizer], [{"scheduler": scheduler, "interval": "step"}] 630 | 631 | def configure_ddp(self, model, device_ids): 632 | # Needs to override the default ddp to set `find_unused_parameters=False` for gradient checkpointing to work 633 | model = LightningDistributedDataParallel( 634 | model, 635 | device_ids=device_ids, 636 | find_unused_parameters=False 637 | ) 638 | return model 639 | 640 | def _get_loader(self, dataset, is_train): 641 | if self.trainer.use_ddp: 642 | sampler = DistributedSampler(dataset, shuffle=is_train) 643 | shuffle = False 644 | else: 645 | sampler = None 646 | shuffle = is_train 647 | loader = DataLoader( 648 | dataset, 649 | batch_size=1, 650 | shuffle=shuffle, 651 | sampler=sampler, 652 | num_workers=0, 653 | collate_fn=self.summarizer.collate_fn, 654 | drop_last=False, 655 | ) 656 | return loader 657 | 658 | def train_dataloader(self): 659 | return self._get_loader(self.train_dataset, True) 660 | 661 | def val_dataloader(self): 662 | return self._get_loader(self.val_dataset, False) 663 | 664 | def test_dataloader(self): 665 | return self._get_loader(self.val_dataset, False) 666 | 667 | def grad_norm(self, norm_type): 668 | # Override PTL `grad_norm` function to only return `total_grad_norm` instead norms of individual params 669 | # TODO: grad_norm reporting needs to take fp16 loss scale into account 670 | parameters = [p for p in self.parameters() if p.grad is not None] 671 | device = parameters[0].device 672 | total_norm = torch.zeros([], device=device if parameters else None) 673 | norm_type = float(norm_type) 674 | for p in parameters: 675 | param_norm = p.grad.data.pow(norm_type).sum() 676 | total_norm.add_(param_norm) 677 | total_norm = (total_norm ** (1.0 / norm_type)) 678 | return {'total_grad_norm': total_norm} 679 | 680 | @classmethod 681 | def add_args(cls, parser): 682 | # generation args 
683 |         parser.add_argument('--num_beams', default=4, type=int, help='How many beams to use when decoding during validation')
684 |         parser.add_argument('--min_length', type=int, default=20, help='Minimum summary length')
685 |         parser.add_argument('--max_length', type=int, default=512, help='Maximum target length')
686 |         parser.add_argument('--max_num_refs', type=int, default=25, help='Maximum number of reference texts')
687 |         parser.add_argument('--temperature', type=float, help="Sampling temperature")
688 |         parser.add_argument('--repetition_penalty', type=float, help="Repetition penalty when decoding during validation")
689 |         parser.add_argument('--length_penalty', default=2.0, type=float, help='Length penalty when decoding during validation')
690 |         parser.add_argument('--no_repeat_ngram_size', default=3, type=int, help='Size of ngram not to repeat when decoding during validation')
691 |         # training args
692 |         parser.add_argument('--train_rouge_eval_batches', default=100, type=int, help='How often (in batches) to generate from the training data for ROUGE scoring')
693 |         parser.add_argument('--grad_ckpt', action='store_true', help='Enable gradient checkpointing to save memory')
694 |         # model args
695 |         parser.add_argument('--attention_dropout', default=0.1, type=float, help='Attention dropout probability')
696 |         parser.add_argument('--lr', default=1e-5, type=float, help='Learning rate')
697 |         parser.add_argument('--adam_epsilon', default=1e-8, type=float, help='Adam epsilon')
698 |         parser.add_argument('--warmup_steps', default=1000, type=int, help='Batches for warmup')
699 |         parser.add_argument('--fp16', action='store_true', help='Use fp16')
700 |         parser.add_argument('--model_name', default='facebook/bart-base', help='Name or path of a model')
701 |         parser.add_argument('--evidence_inference_eval', default=False, action='store_true', help='When producing a significance classification, also report evidence inference classification metrics during validation')
702 | 
703 |         parser.add_argument('--debug', action='store_true', help='Debugging')
704 | 
705 | 
706 | class SingleStreamBartSummarizer(nn.Module):
707 | 
708 |     def __init__(self, model_path, tokenizer, args):
709 |         super().__init__()
710 |         # TODO(jayd) look into DistilBART https://github.com/huggingface/transformers/blob/5543b30aa6b52da3c8f7d9e525b0edc26226d717/examples/seq2seq/
711 |         config = LongformerEncoderDecoderConfig.from_pretrained(
712 |             model_path,
713 |             attention_mode='sliding_chunks_no_overlap',
714 |             attention_dropout=args.attention_dropout,
715 |             gradient_checkpointing=args.grad_ckpt,
716 |         )
717 |         # with `sliding_chunks_no_overlap`, attention size is 3 * attention_window. Use 340 if total amount of attention is 1024 (as in BART) or use 170 if you feel 170*3=510 is the average length of ref.
I used 340 in other experiments and it works well and haven't tried 170 718 | attention_size = 340 719 | config.attention_window = [attention_size] * config.encoder_layers 720 | logging.info('config:' + str(config)) 721 | self.model = LongformerEncoderDecoderForConditionalGeneration.from_pretrained( 722 | model_path, 723 | config=config, 724 | ) 725 | self.max_input_length = (config.max_encoder_position_embeddings // (2 * attention_size)) * 2 * attention_size 726 | self.model.resize_token_embeddings(len(tokenizer.get_vocab())) 727 | self.tokenizer = tokenizer 728 | self.config = self.model.config 729 | self.args = args 730 | 731 | def _prepare_input_ids(self, inputs, preambles): 732 | # TODO fix the global attention mask 733 | assert inputs.size(0) == preambles.size(0) == 1 734 | input_ids = torch.cat([preambles, inputs], dim=1) # combine preamble and refs in one long sequence 735 | input_ids = input_ids[:, :self.max_input_length] # limit to max input size 736 | attention_mask = input_ids.new_ones(input_ids.shape, dtype=torch.long) 737 | attention_mask[input_ids == self.tokenizer.pad_token_id] = 0 738 | attention_mask[0, :preambles.size()[1]] = 2 # global attention on preamble 739 | input_ids, attention_mask = pad_to_window_size( # ideally, should be moved inside the LongformerModel 740 | input_ids, attention_mask, self.config.attention_window[0], self.tokenizer.pad_token_id) 741 | assert all(list(map(lambda x: x <= self.max_input_length, input_ids.size()))) 742 | return input_ids, attention_mask 743 | 744 | def forward(self, inputs, preambles, targets): 745 | input_ids, attention_mask = self._prepare_input_ids(inputs, preambles) 746 | return self.model( 747 | input_ids=input_ids, attention_mask=attention_mask, 748 | decoder_input_ids=targets[:, :-1], labels=targets[:, 1:]) 749 | 750 | def collate_fn(self, batch: List[ReviewDataset.Instance]) -> Tuple[torch.Tensor, torch.Tensor]: 751 | assert len(batch) == 1 752 | instance = batch[0] 753 | refs = instance.refs.data 754 | refs = refs.masked_select(refs != 0) # remove padding and combine in one long sequence 755 | preface = instance.preface 756 | target = instance.target 757 | return refs.unsqueeze(dim=0), preface.unsqueeze(dim=0), target.unsqueeze(dim=0) # batch of size 1 758 | 759 | @torch.no_grad() 760 | def generate_summary(self, inputs, preambles, **kwargs): 761 | input_ids, attention_mask = self._prepare_input_ids(inputs, preambles) 762 | return self.model.generate(input_ids=input_ids, attention_mask=attention_mask, **kwargs) 763 | 764 | 765 | def get_args(): 766 | parser = argparse.ArgumentParser(description='Train a BART based summarization model!') 767 | parser.add_argument('--epochs', default=10, type=int, help='Train for how many epochs?') 768 | parser.add_argument('--train', required=True, help='jsonl serialized training files') 769 | parser.add_argument('--val', required=True) 770 | parser.add_argument('--training_root', required=True, help='Where to save checkpoints, etc.') 771 | parser.add_argument('--grad_accum', default=4, type=int, help='Number of gradient accumulation steps') 772 | parser.add_argument('--test', action='store_true', help='Skip training. Run prediction on the validation set') 773 | parser.add_argument('--test_all_ckpts', action='store_true', help='Skip training. 
Run prediction on the validation set over all ckpts') 774 | parser.add_argument('--dataset_size', default=14607, type=int, help='Number of instances in the training set') # TODO: read from the data 775 | 776 | LightningBartSummarizer.add_args(parser) 777 | return parser 778 | 779 | 780 | def main(): 781 | parser = get_args() 782 | args = parser.parse_args() 783 | model = LightningBartSummarizer(args) 784 | 785 | if args.test or args.test_all_ckpts: 786 | # loading the training dataset is expensive and unnecessary if we're only evaluating 787 | model.train_dataset = [] 788 | else: 789 | model.train_dataset = ReviewDataset.from_file(args.train, format_function=ToUnflattenedModelInputsFunction(model.config.pad_token_id)) 790 | model.val_dataset = ReviewDataset.from_file(args.val, format_function=ToUnflattenedModelInputsFunction(model.config.pad_token_id)) 791 | logging.info(f'Loaded training dataset of length {len(model.train_dataset)}, val: {len(model.val_dataset)}') 792 | 793 | resume_from_checkpoint = None 794 | ckpts = glob.glob(os.path.join(args.training_root, '*.ckpt')) 795 | logging.info('Found {} pre-existing checkpoints: {}'.format(len(ckpts), ckpts)) 796 | if len(ckpts) > 0: 797 | epochs = map(lambda ckpt: re.match('.*_([0-9]+)\.ckpt', ckpt).group(1), ckpts) 798 | ckpts = {int(e): c for (e, c) in zip(epochs, ckpts)} 799 | best = max(ckpts.keys()) 800 | resume_from_checkpoint = ckpts[best] 801 | logging.info('Resuming from existing checkpoint {}'.format(resume_from_checkpoint)) 802 | 803 | # single machine for the moment 804 | checkpoint_callback = ModelCheckpoint( 805 | filepath=os.path.join(args.training_root, 'model.ckpt'), 806 | verbose=True, 807 | # save_best_only=False, 808 | save_top_k=-1, 809 | monitor='val_loss', 810 | mode='min', 811 | ) 812 | trainer = Trainer( 813 | gpus=-1, 814 | num_sanity_val_steps=2, 815 | val_check_interval=0.5 if not args.debug else 1.0, 816 | check_val_every_n_epoch=1 if not args.debug else 10, 817 | distributed_backend='ddp', 818 | replace_sampler_ddp=False, 819 | num_nodes=1, 820 | default_root_dir=args.training_root, 821 | max_epochs=args.epochs, 822 | log_gpu_memory=True, 823 | show_progress_bar=True, 824 | log_save_interval=10, 825 | accumulate_grad_batches=args.grad_accum, 826 | precision=16 if args.fp16 else 32, amp_level='O2', 827 | checkpoint_callback=checkpoint_callback, 828 | callbacks=[LearningRateLogger()], 829 | resume_from_checkpoint=resume_from_checkpoint, 830 | track_grad_norm=2, 831 | ) 832 | 833 | if not (args.test or args.test_all_ckpts): 834 | trainer.fit(model) 835 | # Possibly a CUDA/pytorch bug: it seems after a recent update of the S2 servers 836 | # this code block would reliably trigger a crash of the nvidia drivers and require 837 | # a reboot to restore the server. scripts/modeling/decode.py still works. 
838 | 839 | #trainer.test(model) 840 | #if args.test_all_ckpts: 841 | # ckpts = glob.glob(os.path.join(args.training_root, '*.ckpt')) 842 | # epochs = map(lambda ckpt: re.match('.*_([0-9]+)\.ckpt', ckpt).group(1), ckpts) 843 | # ckpts = {int(e): c for (e, c) in zip(epochs, ckpts)} 844 | # if len(ckpts) == 0: 845 | # raise ValueError('Cannot restore from 0 checkpoints!') 846 | # logging.info('Testing over {} pre-existing checkpoints: {}'.format(len(ckpts), ckpts)) 847 | # for epoch, ckpt in ckpts.items(): 848 | # resume_from_checkpoint = ckpt 849 | # trainer = Trainer( 850 | # gpus=-1, 851 | # distributed_backend=None, 852 | # replace_sampler_ddp=False, 853 | # num_nodes=1, 854 | # default_root_dir=args.training_root, 855 | # log_gpu_memory=True, 856 | # show_progress_bar=True, 857 | # log_save_interval=10, 858 | # precision=16 if args.fp16 else 32, amp_level='O2', 859 | # resume_from_checkpoint=resume_from_checkpoint, 860 | # track_grad_norm=2, 861 | # ) 862 | # to_evaluate = LightningBartSummarizer.load_from_checkpoint(checkpoint_path=ckpt) 863 | # to_evaluate.eval() 864 | # to_evaluate.predictions_file = open(os.path.join(args.training_root, f'epoch_{epoch}_predictions.json'), 'w') 865 | # trainer.test(to_evaluate, test_dataloaders=model.val_dataloader()) 866 | 867 | 868 | if __name__ == '__main__': 869 | main() 870 | -------------------------------------------------------------------------------- /ms2/models/utils.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import List 3 | 4 | import torch 5 | 6 | from rouge_score import rouge_scorer 7 | from rouge_score import scoring 8 | 9 | """Citation 10 | @inproceedings{lin-2004-rouge, 11 | title = "{ROUGE}: A Package for Automatic Evaluation of Summaries", 12 | author = "Lin, Chin-Yew", 13 | booktitle = "Text Summarization Branches Out", 14 | month = jul, 15 | year = "2004", 16 | address = "Barcelona, Spain", 17 | publisher = "Association for Computational Linguistics", 18 | url = "https://www.aclweb.org/anthology/W04-1013", 19 | pages = "74--81", 20 | } 21 | """ 22 | 23 | from torch import nn 24 | from torch.nn.utils.rnn import pad_sequence, PackedSequence, pack_padded_sequence, pad_packed_sequence 25 | from torch.nn.utils.rnn import pad_sequence 26 | 27 | 28 | def rouge_scores(preds: List[List[torch.Tensor]], targets: List[List[torch.Tensor]], tokenizer, use_stemmer=False, use_aggregator=False): 29 | # largely copied from https://github.com/huggingface/nlp/blob/master/metrics/rouge/rouge.py#L84 30 | rouge_types = ['rouge1', 'rouge2', 'rougeL', 'rougeLsum'] 31 | scorer = rouge_scorer.RougeScorer(rouge_types=rouge_types, use_stemmer=use_stemmer) 32 | refs, hyps = [], [] 33 | for p, t in zip(preds, targets): 34 | assert len(p) == len(t) 35 | refs.extend(p) 36 | hyps.extend(t) 37 | 38 | if use_aggregator: 39 | aggregator = scoring.BootstrapAggregator() 40 | scores = None 41 | else: 42 | aggregator = None 43 | scores = [] 44 | 45 | for ref, pred in zip(refs, hyps): 46 | if isinstance(ref, torch.Tensor): 47 | ref = tokenizer.decode(ref).lower() 48 | if isinstance(pred, torch.Tensor): 49 | pred = tokenizer.decode(pred).lower() 50 | score = scorer.score(ref, pred) 51 | if use_aggregator: 52 | aggregator.add_scores(score) 53 | else: 54 | scores.append(score) 55 | 56 | if use_aggregator: 57 | result = aggregator.aggregate() 58 | else: 59 | result = {} 60 | for key in scores[0]: 61 | result[key] = list(score[key] for score in scores) 62 | 63 | return result 64 | 
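# Example usage of rouge_scores (a minimal sketch; `tok` stands for any HF tokenizer here and
# the strings are made up):
#   preds   = [[torch.tensor(tok.encode('no significant difference was found'))]]
#   targets = [[torch.tensor(tok.encode('there was no significant difference'))]]
#   agg     = rouge_scores(preds, targets, tok, use_aggregator=True)
#   agg['rouge1'].mid.fmeasure  # aggregated unigram F1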
65 | def pad_tensors(data: List[torch.Tensor], padding_value) -> torch.Tensor: 66 | data_ = [] 67 | for d in data: 68 | if len(d.size()) == 0: 69 | d = d.unsqueeze(0) 70 | data_.append(d) 71 | padded = pad_sequence(data_, batch_first=True, padding_value=padding_value) 72 | return padded 73 | 74 | # TODO(jayd) memory pinning? 75 | # borrowed from the EraserBenchmark 76 | @dataclass(eq=True, frozen=True) 77 | class PaddedSequence: 78 | """A utility class for padding variable length sequences mean for RNN input 79 | This class is in the style of PackedSequence from the PyTorch RNN Utils, 80 | but is somewhat more manual in approach. It provides the ability to generate masks 81 | for outputs of the same input dimensions. 82 | The constructor should never be called directly and should only be called via 83 | the autopad classmethod. 84 | We'd love to delete this, but we pad_sequence, pack_padded_sequence, and 85 | pad_packed_sequence all require shuffling around tuples of information, and some 86 | convenience methods using these are nice to have. 87 | """ 88 | 89 | data: torch.Tensor 90 | batch_sizes: torch.Tensor 91 | batch_first: bool = False 92 | 93 | @classmethod 94 | def autopad(cls, data, batch_first: bool = False, padding_value=0, device=None) -> 'PaddedSequence': 95 | # handle tensors of size 0 (single item) 96 | data_ = [] 97 | for d in data: 98 | if len(d.size()) == 0: 99 | d = d.unsqueeze(0) 100 | data_.append(d) 101 | padded = pad_sequence(data_, batch_first=batch_first, padding_value=padding_value) 102 | if batch_first: 103 | batch_lengths = torch.LongTensor([x.size()[0] for x in data_]) 104 | if any([x == 0 for x in batch_lengths]): 105 | raise ValueError( 106 | "Found a 0 length batch element, this can't possibly be right: {}".format(batch_lengths)) 107 | else: 108 | # TODO actually test this codepath 109 | batch_lengths = torch.LongTensor([len(x) for x in data]) 110 | return PaddedSequence(padded, batch_lengths, batch_first).to(device=device) 111 | 112 | def pack_other(self, data: torch.Tensor): 113 | return pack_padded_sequence(data, self.batch_sizes, batch_first=self.batch_first, enforce_sorted=False) 114 | 115 | @classmethod 116 | def from_packed_sequence(cls, ps: PackedSequence, batch_first: bool, padding_value=0) -> 'PaddedSequence': 117 | padded, batch_sizes = pad_packed_sequence(ps, batch_first, padding_value) 118 | return PaddedSequence(padded, batch_sizes, batch_first) 119 | 120 | def cuda(self) -> 'PaddedSequence': 121 | return PaddedSequence(self.data.cuda(), self.batch_sizes.cuda(), batch_first=self.batch_first) 122 | 123 | def to(self, device=None, dtype=None, copy=False, non_blocking=False) -> 'PaddedSequence': 124 | # TODO make to() support all of the torch.Tensor to() variants 125 | return PaddedSequence( 126 | self.data.to(device=device, dtype=dtype, copy=copy, non_blocking=non_blocking), 127 | self.batch_sizes.to(device=device, copy=copy, non_blocking=non_blocking), 128 | batch_first=self.batch_first) 129 | 130 | def mask(self, on=int(0), off=int(0), device='cpu', size=None, dtype=None) -> torch.Tensor: 131 | if size is None: 132 | size = self.data.size() 133 | out_tensor = torch.zeros(*size, dtype=dtype) 134 | # TODO this can be done more efficiently 135 | out_tensor.fill_(off) 136 | # note to self: these are probably less efficient than explicilty populating the off values instead of the on values. 
137 | if self.batch_first: 138 | for i, bl in enumerate(self.batch_sizes): 139 | out_tensor[i, :bl] = on 140 | else: 141 | for i, bl in enumerate(self.batch_sizes): 142 | out_tensor[:bl, i] = on 143 | return out_tensor.to(device) 144 | 145 | def unpad(self, other: torch.Tensor=None) -> List[torch.Tensor]: 146 | if other is None: 147 | other = self 148 | if isinstance(other, PaddedSequence): 149 | other = other.data 150 | out = [] 151 | for o, bl in zip(other, self.batch_sizes): 152 | out.append(o[:bl]) 153 | return out 154 | 155 | def flip(self) -> 'PaddedSequence': 156 | return PaddedSequence(self.data.transpose(0, 1), not self.batch_first, self.padding_value) 157 | 158 | def mangle_bart_with_longformer(bart_model, extend_encoder=True, extend_decoder=True): 159 | # TODO fix with https://github.com/allenai/longformer/blob/encoderdecoder/scripts/convert_bart_to_longformerencoderdecoder.py 160 | def replace_layers(model, config): 161 | for i, layer in enumerate(model.layers): 162 | self_attn = LongformerSelfAttention(config, layer_id=i) 163 | self_attn.query = layer.self_attn.q_proj 164 | self_attn.key = layer.self_attn.k_proj 165 | self_attn.value = layer.self_attn.v_proj 166 | # TODO should these parameters be tied? they aren't in the longformer source 167 | self_attn.query_global = layer.self_attn.q_proj 168 | self_attn.key_global = layer.self_attn.k_proj 169 | self_attn.value_global = layer.self_attn.v_proj 170 | # TODO longformer has no out_proj which seems odd 171 | layer.self_attn.self = self_attn 172 | bart_model.config.max_position_embeddings = 16 * bart_model.config.max_position_embeddings 173 | # this is a hack. 174 | # it might even get fixed eventually. 175 | if extend_decoder: 176 | new_decoder_embeds = torch.cat( 177 | [bart_model.model.decoder.embed_positions.weight] + 178 | 15 * [bart_model.model.decoder.embed_positions.weight[2:]], 179 | dim=0).clone() 180 | # TODO experiment with adding additional tokens, e.g. separating documents (or maybe an embedding per document?) + separating the prompt + ?? 181 | bart_model.model.decoder.embed_positions = LearnedPositionalEmbedding( 182 | num_embeddings=new_decoder_embeds.size()[0] - 2, 183 | embedding_dim=new_decoder_embeds.size()[1], 184 | padding_idx=bart_model.model.decoder.embed_positions.padding_idx, 185 | offset=0, 186 | ) 187 | bart_model.model.decoder.embed_positions.weight = nn.Parameter(new_decoder_embeds, requires_grad=True) 188 | 189 | if extend_encoder: 190 | new_encoder_embeds = torch.cat( 191 | [bart_model.model.encoder.embed_positions.weight] + 192 | 15 * [bart_model.model.encoder.embed_positions.weight[2:]], 193 | dim=0).clone() 194 | bart_model.model.encoder.embed_positions = LearnedPositionalEmbedding( 195 | num_embeddings=new_encoder_embeds.size()[0] - 2, 196 | embedding_dim=new_encoder_embeds.size()[1], 197 | padding_idx=bart_model.model.encoder.embed_positions.padding_idx, 198 | offset=0, 199 | ) 200 | bart_model.model.encoder.embed_positions.weight = nn.Parameter(new_encoder_embeds, requires_grad=True) 201 | # TODO(jayd,iz) are any of these parameters even in the right ballpark? 202 | bart_model.config.attention_probs_dropout_prob = bart_model.config.dropout 203 | bart_model.config.attention_window = [64] * len(bart_model.model.encoder.layers) 204 | bart_model.config.attention_dilation = [1] * len(bart_model.model.encoder.layers) 205 | bart_model.config.attention_mode = 'tvm' 206 | # bart_model.config.attention_mode = 'sliding_chunks' # TODO is this the right choice? 
207 |     bart_model.config.autoregressive = False
208 |     if extend_encoder:
209 |         replace_layers(bart_model.model.encoder, bart_model.config)
210 |     if extend_decoder:
211 |         replace_layers(bart_model.model.decoder, bart_model.config)
212 | 
--------------------------------------------------------------------------------
/ms2/utils.py:
--------------------------------------------------------------------------------
1 | import itertools
2 | import logging
3 | 
4 | from collections import defaultdict
5 | from dataclasses import dataclass, asdict, is_dataclass
6 | from typing import Any, Callable, Dict, List, Optional, Tuple, Union
7 | 
8 | import torch
9 | 
10 | from dataclasses_json import dataclass_json
11 | from transformers import AutoTokenizer
12 | 
13 | logging.basicConfig(format='%(asctime)s %(message)s', level=logging.DEBUG)
14 | NUM_PROCS = 8
15 | 
16 | START_POPULATION='<population>'
17 | END_POPULATION='</population>'
18 | START_INTERVENTION='<intervention>'
19 | END_INTERVENTION='</intervention>'
20 | START_OUTCOME='<outcome>'
21 | END_OUTCOME='</outcome>'
22 | START_BACKGROUND = '<background>'
23 | END_BACKGROUND = '</background>'
24 | START_REFERENCE = '<reference>'
25 | END_REFERENCE = '</reference>'
26 | START_EVIDENCE = '<evidence>'
27 | END_EVIDENCE = '</evidence>'
28 | SEP_TOKEN = '<sep>'
29 | EXTRA_TOKENS = [
30 |     START_BACKGROUND,
31 |     END_BACKGROUND,
32 |     START_REFERENCE,
33 |     END_REFERENCE,
34 |     SEP_TOKEN,
35 |     START_POPULATION,
36 |     END_POPULATION,
37 |     START_INTERVENTION,
38 |     END_INTERVENTION,
39 |     START_OUTCOME,
40 |     END_OUTCOME,
41 |     START_EVIDENCE,
42 |     END_EVIDENCE,
43 | ]
44 | 
45 | def get_tokenizer(tokenizer_type: str):
46 |     tokenizer = AutoTokenizer.from_pretrained(tokenizer_type, additional_special_tokens=EXTRA_TOKENS)
47 |     return tokenizer
48 | 
49 | @dataclass_json
50 | @dataclass
51 | class Significance:
52 |     intervention: str
53 |     outcome: str
54 |     classification: Dict[str, float]
55 |     evidence_sentence: Optional[str]=None
56 |     evidence_sentence_score: Optional[float]=None
57 | 
58 | @dataclass_json
59 | @dataclass
60 | class TargetReference:
61 |     title_abstract: Union[torch.Tensor, str]
62 |     full_text: Optional[Union[torch.Tensor, List[int], str]]
63 |     s2id: Optional[str]
64 |     s2hash: Optional[str]
65 | 
66 | @dataclass_json
67 | @dataclass
68 | class TargetSummary:
69 |     """Target input/output for a summarization model
70 | 
71 |     Preface:
72 |     """
73 |     preface: Optional[Union[str, List[int], torch.Tensor]]
74 |     target_texts: Union[List[str], List[int], List[torch.Tensor]]
75 |     review_id: str
76 |     references: List[TargetReference]
77 |     s2id: Optional[str]
78 |     s2hash: Optional[str]
79 | 
80 |     @staticmethod
81 |     def read_summaries(f: str) -> List['TargetSummary']:
82 |         with open(f, 'r') as inf:
83 |             summaries = map(TargetSummary.from_json, inf)
84 |             summaries = list(summaries)
85 |             return summaries
86 | 
87 | @dataclass_json
88 | @dataclass
89 | class Reference:
90 |     """Any kind of scientific paper
91 | 
92 |     Unfortunately, none of the fields are always present in the data, so we
93 |     are left guessing the best way to find the actual text of any
94 |     given reference.
95 | """ 96 | identifiers: List[Dict[str, str]] 97 | metadata: Dict[str, str] 98 | title: Optional[str]=None 99 | doi: Optional[str]=None 100 | pmid: Optional[str]=None 101 | # these must be populated later 102 | s2id: Optional[str]=None 103 | s2hash: Optional[str]=None 104 | abstract: Optional[str]=None 105 | content: Optional[str]=None 106 | publication_types: Optional[List[str]]=None 107 | significances: Optional[List[Significance]]=None 108 | interventions: Optional[List[str]]=None 109 | outcomes: Optional[List[str]]=None 110 | populations: Optional[List[str]]=None 111 | in_doc_significances: Optional[List[Significance]]=None 112 | 113 | @dataclass_json 114 | @dataclass 115 | class Study: 116 | """Any scientific study 117 | 118 | Typically there is a one-to-one mapping between references and studies, 119 | although not always. Some studies appear to be published multiple ways, at 120 | multiple times (e.g. a full paper and a later conference abstract). 121 | 122 | In this dataset, a study should always contain exactly one reference element. 123 | """ 124 | references: List[Reference] 125 | identifiers: List[Dict[str, str]] 126 | metadata: Dict[str, str] 127 | pmid: Optional[str]=None 128 | doi: Optional[str] = None 129 | 130 | @dataclass_json 131 | @dataclass 132 | class Review: 133 | """Systematic Review representation 134 | 135 | All reviews should have a structured abstract, a document name, and title. 136 | One would expect to be able to always have access to included studies, and 137 | one would be wrong. At least one review exists in the data where no studies 138 | could be found, and thus no review was performed 139 | """ 140 | docid: str 141 | title: str 142 | authors: str 143 | abstract: str 144 | # the final field is an optional distribution over model labels 145 | structured_abstract: List[Tuple[str, str, Optional[Dict[str, float]]]] 146 | summary: Optional[str] 147 | structured_summary: List[Tuple[str, str, Optional[Dict[str, float]]]] 148 | included_studies: List[Study] 149 | ongoing_studies: List[Study] 150 | awaiting_studies: List[Study] 151 | excluded_studies: List[Study] 152 | general_references: List[Reference] 153 | unattributed_references: Optional[List[Reference]]=None 154 | content: Optional[str]=None 155 | doi: Optional[str]=None 156 | content: Optional[str]=None # separate from the abstract 157 | s2id: Optional[str]=None 158 | s2hash: Optional[str]=None 159 | pmid: Optional[str]=None 160 | interventions: Optional[List[str]]=None 161 | outcomes: Optional[List[str]]=None 162 | populations: Optional[List[str]]=None 163 | significances: Optional[List[Significance]]=None 164 | 165 | def __post_init__(self): 166 | assert self.docid is not None 167 | assert self.title is not None 168 | assert self.abstract is not None and len(self.abstract.strip()) > 0 169 | #assert len(self.structured_abstract) > 0 170 | 171 | def extract_references(self) -> List[Reference]: 172 | studies = itertools.chain.from_iterable([ 173 | self.included_studies, 174 | self.ongoing_studies, 175 | self.excluded_studies, 176 | self.awaiting_studies, 177 | ]) 178 | study_refs = itertools.chain.from_iterable(map(lambda x: x.references, studies)) 179 | all_refs = list(study_refs) + self.general_references + self.unattributed_references 180 | return all_refs 181 | 182 | @staticmethod 183 | def read_reviews(f) -> List['Review']: 184 | with open(f, 'r') as inf: 185 | inf = filter(lambda line: len(line.strip()) > 0, inf) 186 | reviews = map(Review.from_json, inf) 187 | reviews = list(reviews) 188 | return 
reviews -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | biopython 2 | boto3 3 | dataclasses_json 4 | jsonnet 5 | pandas 6 | git+https://github.com/allenai/longformer.git@55c1f5f75d6f33b7ee23af0caf0392a52f8d4f4c 7 | rouge_score 8 | scikit-learn 9 | spacy 10 | torch 11 | tqdm 12 | git+http://github.com/ibeltagy/transformers.git@longformer_encoder_decoder#egg=transformers 13 | https://s3-us-west-2.amazonaws.com/ai2-s2-scispacy/releases/v0.2.5/en_core_sci_sm-0.2.5.tar.gz 14 | #git+https://github.com/ibeltagy/pytorch-lightning.git@v0.8.5_fixes 15 | -------------------------------------------------------------------------------- /scripts/modeling/consistency_scorer.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import itertools 3 | import json 4 | import math 5 | import os 6 | import re 7 | 8 | import numpy as np 9 | import torch 10 | 11 | from sklearn.metrics import classification_report, confusion_matrix 12 | from scipy.spatial.distance import jensenshannon 13 | from ms2.models.evidence_inference_models import initialize_models 14 | from ms2.models.utils import rouge_scores 15 | from ms2.utils import ( 16 | get_tokenizer, 17 | EXTRA_TOKENS, 18 | SEP_TOKEN, 19 | START_INTERVENTION, 20 | END_INTERVENTION, 21 | START_OUTCOME, 22 | END_OUTCOME, 23 | ) 24 | 25 | TOKENS_SPLIT = '|'.join(EXTRA_TOKENS) 26 | 27 | INTERVENTION_RE = START_INTERVENTION + '(.*?)' + END_INTERVENTION 28 | OUTCOME_RE = START_OUTCOME + '(.*?)' + END_OUTCOME 29 | 30 | def ios(preamble): 31 | # we know that the reference abstract is already space tokenized 32 | start_stop_words = EXTRA_TOKENS + ['', ''] 33 | def clean_str(s): 34 | for w in start_stop_words: 35 | s = s.replace(w, '') 36 | return s 37 | outcomes = list(map(clean_str, re.findall(OUTCOME_RE, preamble))) 38 | interventions = list(map(clean_str, re.findall(INTERVENTION_RE, preamble))) 39 | return interventions, outcomes 40 | 41 | def evidence_inference_score(model, evidence_inference_tokenizer, summary, preamble, use_ios): 42 | ret = [] 43 | if use_ios: 44 | interventions, outcomes = ios(preamble) 45 | summary = evidence_inference_tokenizer(summary, return_tensors='pt')['input_ids'] 46 | for i, o in itertools.product(interventions, outcomes): 47 | preamble = i + ' ' + evidence_inference_tokenizer.sep_token + ' ' + o 48 | ico = evidence_inference_tokenizer(preamble, return_tensors='pt')['input_ids'] 49 | classes = model(ico, summary) 50 | classes = torch.softmax(classes, dim=-1).detach().cpu().squeeze().tolist() 51 | significance_distribution = dict(zip(["significantly decreased", "no significant difference", "significantly increased"], classes)) 52 | ret.append(significance_distribution) 53 | else: 54 | preamble = "" 55 | ico = evidence_inference_tokenizer(preamble, return_tensors='pt')['input_ids'] 56 | summary = evidence_inference_tokenizer(summary, return_tensors='pt')['input_ids'] 57 | classes = model(ico, summary) 58 | classes = torch.softmax(classes, dim=-1).detach().cpu().squeeze().tolist() 59 | significance_distribution = dict(zip(["significantly decreased", "no significant difference", "significantly increased"], classes)) 60 | ret.append(significance_distribution) 61 | 62 | return ret 63 | 64 | 65 | def jsd(m1, m2): 66 | keys = list(set(m1.keys()) | set(m2.keys())) 67 | m1 = [m1.get(k, 0) for k in keys] 68 | m2 = [m2.get(k, 0) for k in keys] 69 | return 
jensenshannon(m1, m2, base=2) 70 | 71 | def entailment_score(model, evidence_inference_tokenizer, generated, target, preamble, use_ios): 72 | generated_distributions = evidence_inference_score(model, evidence_inference_tokenizer, generated, preamble, use_ios) 73 | summary_distributions = evidence_inference_score(model, evidence_inference_tokenizer, target, preamble, use_ios) 74 | jsds = [] 75 | for generated_distribution, summary_distribution in zip(generated_distributions, summary_distributions): 76 | jsds.append(jsd(generated_distribution, summary_distribution)) 77 | if len(jsds) == 0: 78 | return None 79 | return np.mean(jsds) 80 | 81 | def f1_score(model, evidence_inference_tokenizer, generateds, targets, preambles, use_ios): 82 | summary_preds = [] 83 | generated_preds = [] 84 | in_doc_classifications = [] 85 | labels = ["significantly decreased", "no significant difference", "significantly increased"] 86 | mapping = {x:i for (i,x) in enumerate(labels)} 87 | for generated, target, preamble in zip(generateds, targets, preambles): 88 | generated_distributions = evidence_inference_score(model, evidence_inference_tokenizer, generated, preamble, use_ios) 89 | summary_distributions = evidence_inference_score(model, evidence_inference_tokenizer, target, preamble, use_ios) 90 | in_doc_generated = [] 91 | in_doc_target = [] 92 | for generated_distribution, summary_distribution in zip(generated_distributions, summary_distributions): 93 | generated_targets = sorted(generated_distribution.items(), key=lambda x: x[1]) 94 | summary_targets = sorted(summary_distribution.items(), key=lambda x: x[1]) 95 | best_summary_target = summary_targets[-1][0] 96 | in_doc_target.append(best_summary_target) 97 | summary_preds.append(best_summary_target) 98 | generated_target = generated_targets[-1][0] 99 | generated_preds.append(generated_target) 100 | in_doc_generated.append(generated_target) 101 | if len(in_doc_generated) == 0: 102 | continue 103 | in_doc_classifications.append( 104 | classification_report( 105 | np.array([mapping[x] for x in in_doc_target]), 106 | np.array([mapping[x] for x in in_doc_generated]), 107 | target_names=labels, 108 | labels=list(range(len(labels))), 109 | output_dict=True, 110 | digits=4 111 | ) 112 | ) 113 | res = classification_report(np.array([mapping[x] for x in summary_preds]), np.array([mapping[x] for x in generated_preds]), target_names=labels, output_dict=True, digits=4) 114 | return res 115 | 116 | 117 | def jsd_uniform(model, evidence_inference_tokenizer, target, preamble, use_ios): 118 | summary_distributions = evidence_inference_score(model, evidence_inference_tokenizer, target, preamble, use_ios) 119 | jsds = [] 120 | # baseline distributions 121 | generated_distribution = { 122 | 'significantly decreased': .134, 123 | 'no significant difference': .570, 124 | 'significantly increased': .296, 125 | } 126 | for summary_distribution in summary_distributions: 127 | jsds.append(jsd(generated_distribution, summary_distribution)) 128 | if len(jsds) == 0: 129 | return None 130 | return np.mean(jsds) 131 | 132 | def entailment_scores(model, evidence_inference_tokenizer, generateds, targets, preambles, use_ios): 133 | f1_scores = f1_score(model, evidence_inference_tokenizer, generateds, targets, preambles, use_ios) 134 | scores = list(map(lambda x: entailment_score(model, evidence_inference_tokenizer, *x, use_ios), zip(generateds, targets, preambles))) 135 | scores = list(filter(lambda x: x is not None, scores)) 136 | uniform_scores = list(map(lambda x: jsd_uniform(model, 
evidence_inference_tokenizer, *x, use_ios), zip(targets, preambles))) 137 | uniform_scores = list(filter(lambda x: x is not None, uniform_scores)) 138 | assert len(scores) > 0 139 | avg = np.mean(scores) 140 | s = np.std(scores) 141 | uniform_score = np.mean(uniform_scores) 142 | return { 143 | 'average': avg, 144 | 'std': s, 145 | 'uniform_preds': uniform_score, 146 | 'f1_score': f1_scores, 147 | } 148 | 149 | def clean(s): 150 | for t in EXTRA_TOKENS + ['', '']: 151 | s = s.replace(t, '') 152 | s = s.replace(' ', ' ') 153 | return s 154 | 155 | def main(): 156 | parser = argparse.ArgumentParser(description='Score model outputs based on a consistency score between target and prediction') 157 | parser.add_argument('--model_outputs', required=True, help='json file of model outputs with "target", "generated", and "preamble" fields') 158 | parser.add_argument('--evidence_inference_dir', required=True, help='Directory containing trained evidence inference models') 159 | parser.add_argument('--evidence_inference_classifier_params', required=True, help='Params to initialize evidence inference models') 160 | parser.add_argument('--unconditioned_classifier', action='store_true', help='Use an unconditioned evidence inference classifier') 161 | parser.add_argument('--output', required=True, help='Output file for scores') 162 | args = parser.parse_args() 163 | 164 | with open(args.model_outputs, 'r') as inf: 165 | outputs = [json.loads(line) for line in inf] 166 | generated = [x['generated'] for x in outputs] 167 | targets = [x['target'] for x in outputs] 168 | preambles = [x['preamble'] for x in outputs] 169 | generated = list(map(clean, generated)) 170 | targets = list(map(clean, targets)) 171 | 172 | # rouge scoring 173 | tokenizer = get_tokenizer('facebook/bart-base') 174 | rouge_results = rouge_scores([[x] for x in generated], [[x] for x in targets], tokenizer, use_aggregator=True) 175 | print('Rouge') 176 | print(rouge_results) 177 | 178 | # evidence inference scoring 179 | with open(args.evidence_inference_classifier_params, 'r') as inf: 180 | params = json.loads(inf.read()) 181 | _, evidence_inference_classifier, _, _, _, evidence_inference_tokenizer = initialize_models(params) 182 | if args.unconditioned_classifier: 183 | classifier_file = os.path.join(args.evidence_inference_dir, 'unconditioned_evidence_classifier', 'unconditioned_evidence_classifier.pt') 184 | else: 185 | classifier_file = os.path.join(args.evidence_inference_dir, 'evidence_classifier', 'evidence_classifier.pt') 186 | #evidence_inference_classifier.load_state_dict(torch.load(classifier_file)) 187 | # pooler parameters are added by default in an older transformers, so we have to ignore that those are uninitialized. 
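# (note: with strict=False, load_state_dict will also silently skip any other missing or unexpected keys, so a checkpoint/architecture mismatch will not raise an error here)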
188 | evidence_inference_classifier.load_state_dict(torch.load(classifier_file), strict=False) 189 | evidence_inference_classifier.cuda() 190 | 191 | entailment_results = entailment_scores(evidence_inference_classifier, evidence_inference_tokenizer, generated, targets, preambles, use_ios=not args.unconditioned_classifier) 192 | print('entailment') 193 | print(entailment_results) 194 | 195 | assert args.output != args.model_outputs 196 | with open(args.output, 'w') as of: 197 | of.write('rouge\n') 198 | of.write(json.dumps(rouge_results)) 199 | of.write('\n\n') 200 | of.write('entailment\n') 201 | of.write(json.dumps(entailment_results)) 202 | of.write('\n') 203 | 204 | 205 | if __name__ == '__main__': 206 | main() -------------------------------------------------------------------------------- /scripts/modeling/decode.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import logging 4 | import torch 5 | 6 | #from ms2.data.review_datasets import ReviewDataset 7 | from ms2.data.review_datasets import ReviewDataset, ToUnflattenedModelInputsFunction 8 | from ms2.models.transformer_summarizer import LightningBartSummarizer 9 | 10 | logging.basicConfig(format='%(asctime)s %(message)s', level=logging.INFO) 11 | 12 | def main(): 13 | parser = argparse.ArgumentParser(description='Summarize a document using a saved checkpoint') 14 | parser.add_argument('--input', required=True, help='Dataset for decoding') 15 | parser.add_argument('--output', required=True, help='Output file') 16 | parser.add_argument('--checkpoint', required=True, help='Saved checkpoint') 17 | LightningBartSummarizer.add_args(parser) 18 | args = parser.parse_args() 19 | model = LightningBartSummarizer(args) 20 | model.load_state_dict(torch.load(args.checkpoint)['state_dict']) 21 | model.cuda() 22 | collate_fn = model.summarizer.collate_fn 23 | dataset = ReviewDataset.from_file(args.input, format_function=ToUnflattenedModelInputsFunction(model.config.pad_token_id)) 24 | logging.info(f'Output file: {args.output}') 25 | with open(args.output, 'w') as output: 26 | assert output is not None 27 | for instance in dataset: 28 | inputs, preambles, targets = collate_fn([instance]) 29 | # defaults to try: https://github.com/huggingface/transformers/blob/v2.10.0/examples/summarization/bart/evaluate_cnn.py#L26-L40 30 | outputs = model.summarizer.generate_summary( 31 | inputs=inputs.cuda(), 32 | preambles=preambles.cuda(), 33 | num_beams=args.num_beams, 34 | length_penalty=args.length_penalty, 35 | max_length=args.max_length, 36 | min_length=args.min_length, 37 | no_repeat_ngram_size=args.no_repeat_ngram_size, 38 | early_stopping=True, 39 | decoder_start_token_id=model.config.bos_token_id, 40 | repetition_penalty=args.repetition_penalty, 41 | temperature=args.temperature, 42 | ) 43 | generated = model.tokenizer.decode(outputs[0].cpu(), skip_special_tokens=False) 44 | target = model.tokenizer.decode(targets.data[0], skip_special_tokens=False) 45 | logging.info('Generated: {}'.format(generated)) 46 | logging.info('Target: {}'.format(target)) 47 | output.write(json.dumps({ 48 | 'preamble': model.tokenizer.decode(preambles.squeeze()), 49 | 'generated': generated, 50 | 'target': target, 51 | })) 52 | output.write('\n') 53 | output.flush() 54 | 55 | if __name__ == '__main__': 56 | main() -------------------------------------------------------------------------------- /scripts/modeling/f1_scorer.py: -------------------------------------------------------------------------------- 
1 | import argparse 2 | import json 3 | 4 | from sklearn.metrics import classification_report, confusion_matrix 5 | 6 | from ms2.utils import ( 7 | EXTRA_TOKENS, 8 | SEP_TOKEN, 9 | ) 10 | 11 | def main(): 12 | parser = argparse.ArgumentParser(description='Perform F1 scoring over textual versions of the output') 13 | parser.add_argument('--input', required=True, help='jsonl file of {generated: ..., target:...}') 14 | parser.add_argument('--output', required=False, help='output file') 15 | args = parser.parse_args() 16 | 17 | def fix_str(s): 18 | for t in EXTRA_TOKENS + [SEP_TOKEN, '', '']: 19 | s = s.replace(t, '') 20 | s = s.replace(' ', ' ') 21 | return s 22 | 23 | generations = [] 24 | targets = [] 25 | 26 | with open(args.input, 'r') as inf: 27 | for line in inf: 28 | content = json.loads(line) 29 | generations.append(fix_str(content['generated'])) 30 | targets.append(fix_str(content['target'])) 31 | 32 | str_map = { 33 | 'significantly decreased': 0, 34 | 'no significant difference': 1, 35 | 'significantly increased': 2, 36 | 'broken generation': 3, 37 | } 38 | 39 | for x in generations: 40 | if x not in str_map: 41 | print(x) 42 | generations = [str_map.get(x, str_map['broken generation']) for x in generations] 43 | targets = [str_map[x] for x in targets] 44 | 45 | scores = classification_report(targets, generations, digits=3, output_dict=True, target_names=list(str_map.keys())[:3]) 46 | print(scores) 47 | confusions = confusion_matrix(targets, generations, normalize='true') 48 | print('confusion matrix') 49 | print(confusions) 50 | if args.output is not None: 51 | assert args.output != args.input 52 | with open(args.output, 'w') as of: 53 | of.write(json.dumps(scores)) 54 | of.write('\n') 55 | of.write(json.dumps(confusions.tolist())) 56 | 57 | if __name__ == '__main__': 58 | main() 59 | -------------------------------------------------------------------------------- /scripts/modeling/select_pubmed_types_of_interest.py: -------------------------------------------------------------------------------- 1 | 2 | MESH_TERMS_TO_TARGET = set([ 3 | 'Systematic Review', 4 | 'Review', 5 | 'Randomized Controlled Trial', 6 | 'Clincal Trial', 7 | 'Controlled Clinical Trial', 8 | 'Clinical Trial, Phase I', 9 | 'Clinical Trial, Phase II', 10 | 'Clinical Trial, Phase III', 11 | 'Clinical Trial, Phase IV', 12 | ]) 13 | 14 | BAD_MESH_TERMS = set([ 15 | 'Randomized Controlled Trial, Veterinary', 16 | 'Clinical Trial, Veterinary', 17 | 'Observational Study, Veterinary', 18 | 'Case Report', 19 | ]) -------------------------------------------------------------------------------- /scripts/modeling/select_reviews.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | 4 | def main(): 5 | parser = argparse.ArgumentParser(description='Select a subset of a reviews file by s2id') 6 | parser.add_argument('--input_reviews', required=True, help='Input reviews file, jsonl formatted') 7 | parser.add_argument('--output_reviews', required=True, help='Output reviews file') 8 | parser.add_argument('--ids', required=True, help='s2ids file') 9 | args = parser.parse_args() 10 | 11 | with open(args.ids, 'r') as ids_file: 12 | ids = set(map(str.strip, ids_file)) 13 | 14 | already_written = set() 15 | skipped = 0 16 | with open(args.output_reviews, 'w') as of: 17 | for input_review in args.input_reviews.split(','): 18 | with open(input_review, 'r') as inf: 19 | for line in inf: 20 | content = json.loads(line) 21 | title = content['title'].lower() 22 | if 
'redact' in title or 'withdraw' in title: 23 | skipped += 1 24 | continue 25 | eyeD = str(content['s2id']) 26 | if eyeD in ids: 27 | if eyeD not in already_written: 28 | of.write(line) 29 | already_written.add(eyeD) 30 | print(f'skipped {skipped} reviews as redacted or withdrawn, wrote {len(already_written)} reviews') 31 | 32 | if __name__ == '__main__': 33 | main() 34 | -------------------------------------------------------------------------------- /scripts/modeling/splits.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | 5 | def main(): 6 | parser = argparse.ArgumentParser(description='Split a reviews file into train/val/test subsets by s2id') 7 | parser.add_argument('--input_reviews', required=True, help='Input reviews file, jsonl formatted') 8 | parser.add_argument('--output_dir', required=True, help='Output directory for the train/val/test jsonl files') 9 | parser.add_argument('--test_ids', required=True, help='s2ids file') 10 | parser.add_argument('--train_ids', required=True, help='s2ids file') 11 | parser.add_argument('--val_ids', required=True, help='s2ids file') 12 | args = parser.parse_args() 13 | 14 | def read_ids(ids_file): 15 | with open(ids_file, 'r') as idf: 16 | ids = set(map(str.strip, idf)) 17 | return ids 18 | train_ids = read_ids(args.train_ids) 19 | val_ids = read_ids(args.val_ids) 20 | test_ids = read_ids(args.test_ids) 21 | 22 | train_file = os.path.join(args.output_dir, 'train.jsonl') 23 | val_file = os.path.join(args.output_dir, 'val.jsonl') 24 | test_file = os.path.join(args.output_dir, 'test.jsonl') 25 | 26 | shared_ids = (train_ids & val_ids) | (train_ids & test_ids) | (val_ids & test_ids) 27 | assert len(shared_ids) == 0, 'splits must be pairwise disjoint' 28 | 29 | with open(train_file, 'w') as train_f, \ 30 | open(val_file, 'w') as val_f, \ 31 | open(test_file, 'w') as test_f: 32 | for input_review in args.input_reviews.split(','): 33 | with open(input_review, 'r') as inf: 34 | for line in inf: 35 | content = json.loads(line) 36 | eyeD = str(content['s2id']) 37 | if eyeD in train_ids: 38 | train_f.write(line) 39 | elif eyeD in val_ids: 40 | val_f.write(line) 41 | elif eyeD in test_ids: 42 | test_f.write(line) 43 | else: 44 | print(f'Unknown id {eyeD}') 45 | 46 | if __name__ == '__main__': 47 | main() 48 | -------------------------------------------------------------------------------- /scripts/modeling/summarizer_input_prep.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import functools 3 | import itertools 4 | import json 5 | import logging 6 | import multiprocessing 7 | import os 8 | import re 9 | import shutil 10 | 11 | from dataclasses import asdict, replace 12 | from typing import List, Optional, Set, Tuple 13 | 14 | import tqdm 15 | 16 | from ms2.data.munge import fields_re, spaces_re 17 | from ms2.utils import ( 18 | get_tokenizer, 19 | Review, 20 | Study, 21 | Reference, 22 | TargetReference, 23 | TargetSummary, 24 | EXTRA_TOKENS, 25 | START_BACKGROUND, 26 | END_BACKGROUND, 27 | START_REFERENCE, 28 | END_REFERENCE, 29 | SEP_TOKEN, 30 | ) 31 | 32 | logging.basicConfig(format='%(asctime)s %(message)s', level=logging.INFO) 33 | 34 | NUM_PROCS = 8 35 | 36 | # fields that begin a review but we aren't attempting to generate 37 | PREAMBLE_FIELDS = set([ 38 | 'BACKGROUND', 39 | 'GOAL', 40 | ]) 41 | # fields that we exclude from both generation and targeting - 42 | # either because they are out of scope, or too detailed, or useless 43 | EXCLUDED_FIELDS = set([ 44 | 'RESULT', 45 | 'METHODS', 46 | 'ETC', 47 | 'DETAILED_FINDINGS',
'RECOMMENDATION', 'FURTHER_STUDY', 'EVIDENCE_QUALITY' 48 | ]) 49 | # fields we want to generate 50 | # eventually will include 'EVIDENCE_QUALITY', potentially 'FURTHER_STUDY' 51 | TARGET_FIELDS = set([ 52 | 'CONCLUSION', 53 | 'EFFECT', 54 | ]) 55 | 56 | def extract_summary_parts(summary_parts: List[Tuple[str, str]]) -> Tuple[str, str, List[str]]: 57 | # TODO(jayd) filter for results-like, preamble-like, conclusions-line 58 | preambles = [] 59 | targets = [] 60 | unknown_fields = [] 61 | 62 | for field, value, _ in summary_parts: 63 | if field in PREAMBLE_FIELDS: 64 | preambles.append(value) 65 | elif field in TARGET_FIELDS: 66 | targets.append(value) 67 | elif field in EXCLUDED_FIELDS: 68 | pass 69 | else: 70 | unknown_fields.append(field) 71 | 72 | return '\n'.join(preambles), '\n'.join(targets), unknown_fields 73 | 74 | def select_reference_from_study(study: Study) -> Optional[Reference]: 75 | if len(study.references) == 1: 76 | return study.references[0] 77 | else: 78 | have_abstract = list(filter(lambda x: x.abstract is not None, study.references)) 79 | have_abstract_and_body = list(filter(lambda x: x.content is not None, have_abstract)) 80 | if len(have_abstract_and_body) > 0: 81 | # TODO is there a better choice? 82 | return have_abstract_and_body[0] 83 | elif len(have_abstract) > 0: 84 | return have_abstract[0] 85 | else: 86 | return None 87 | 88 | def process_reference(reference: Reference) -> Optional[TargetReference]: 89 | if reference.abstract is None: 90 | return None 91 | title = reference.title.strip() 92 | abstract = reference.abstract.strip() 93 | title_abstract = START_REFERENCE + ' ' + title + ' ' + SEP_TOKEN + ' ' + abstract + ' ' + END_REFERENCE 94 | return TargetReference( 95 | title_abstract=title_abstract, 96 | full_text=None, 97 | s2id=reference.s2id, 98 | s2hash=reference.s2hash, 99 | ) 100 | 101 | def process_review(review: Review, max_refs: int, tokenizer) -> Tuple[Optional[TargetSummary], Set[str]]: 102 | input_text = None 103 | target_texts = [] 104 | summaries = [] 105 | if review.structured_abstract is not None and len(review.structured_abstract) > 1: 106 | abs_preamble, abs_target, abs_unknown_fields = extract_summary_parts(review.structured_abstract) 107 | if len(abs_preamble.strip()) == 0: 108 | return None 109 | abs_preamble = START_BACKGROUND + ' ' + abs_preamble + ' ' + END_BACKGROUND 110 | else: 111 | abs_preamble, abs_target, abs_unknown_fields = None, None, [] 112 | if review.structured_summary is not None and len(review.structured_summary) > 1: 113 | sum_preamble, sum_target, sum_unknown_fields = extract_summary_parts(review.structured_summary) 114 | sum_preamble = START_BACKGROUND + ' ' + sum_preamble + ' ' + END_BACKGROUND 115 | else: 116 | sum_preamble, sum_target, sum_unknown_fields = None, None, [] 117 | 118 | unknown_fields = set(abs_unknown_fields) or set(sum_unknown_fields) 119 | if len(unknown_fields) > 0: 120 | logging.info(f'For review {review.docid}, found unknown fields {unknown_fields} in abstract and summary!') 121 | 122 | if abs_target is None and sum_target is None: 123 | return None, unknown_fields 124 | 125 | if abs_target is not None and len(abs_target) > 0: 126 | target_texts.append(abs_target) 127 | if sum_target is not None and len(sum_target) > 0: 128 | target_texts.append(sum_target) 129 | 130 | if abs_preamble is not None: 131 | input_text = abs_preamble 132 | elif sum_preamble is not None: 133 | input_text = sum_preamble 134 | else: 135 | input_text = '' 136 | 137 | if len(input_text.strip()) == 0: 138 | return None, set() 
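# One reference is chosen per included study by select_reference_from_study (preferring references that
# have both an abstract and body text), and process_reference renders each one as
# START_REFERENCE + title + SEP_TOKEN + abstract + END_REFERENCE; references without an abstract are dropped.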
139 | 140 | selected_study_references = map(select_reference_from_study, review.included_studies) 141 | processed_references = map(process_reference, selected_study_references) 142 | references = list(filter(lambda x: x is not None, processed_references)) 143 | 144 | if max_refs is not None and len(references) > max_refs: 145 | logging.info('Truncating review {} references from {} to {}'.format(review.docid, len(references), max_refs)) 146 | references = references[:max_refs] 147 | target_texts = clean_targets(target_texts) 148 | 149 | return TargetSummary( 150 | preface=input_text.strip(), 151 | target_texts=target_texts, 152 | review_id=review.docid, 153 | references=references, 154 | s2id=review.s2id, 155 | s2hash=review.s2hash, 156 | ), unknown_fields 157 | 158 | def clean_targets(target_texts: List[str]) -> List[str]: 159 | # remove fancy markers from the target 160 | for elem in EXTRA_TOKENS: 161 | cleaned_targets = [] 162 | for t in target_texts: 163 | cleaned_targets.append(t.replace(elem, '')) 164 | target_texts = cleaned_targets 165 | # remove standard section markers from the start of the 166 | cleaned_targets = [] 167 | for t in target_texts: 168 | beginning, end = t[:50], t[50:] 169 | beginning = re.sub(fields_re, '', beginning) 170 | beginning = re.sub(spaces_re, ' ', beginning) 171 | cleaned_targets.append(beginning + end) 172 | return cleaned_targets 173 | 174 | def tokenize_target_summary(summary: TargetSummary, tokenizer, max_length: Optional[int]) -> TargetSummary: 175 | end_reference = tokenizer.encode(END_REFERENCE, add_special_tokens=False)[0] 176 | def tokenize_target_reference(target_reference: TargetReference) -> TargetReference: 177 | title_abstract = tokenizer.encode( 178 | target_reference.title_abstract, 179 | add_special_tokens=False, 180 | truncation=max_length is not None, 181 | max_length=max_length 182 | ) 183 | if title_abstract[-1] != end_reference: 184 | title_abstract = title_abstract + [end_reference] 185 | return replace(target_reference, 186 | title_abstract=title_abstract, 187 | full_text=tokenizer.encode( 188 | target_reference.full_text, 189 | add_special_tokens=False, 190 | truncation=max_length is not None, max_length=max_length 191 | ) if target_reference.full_text is not None else None, 192 | ) 193 | 194 | end_background = tokenizer.encode(END_BACKGROUND, add_special_tokens=False)[0] 195 | if len(summary.preface) > 0: 196 | preface = tokenizer.encode(summary.preface, 197 | truncation=max_length is not None, 198 | max_length=max_length, 199 | add_special_tokens=False) 200 | if preface[-1] != end_background: 201 | preface = preface + [end_background] 202 | else: 203 | preface = None 204 | 205 | target_texts = [] 206 | eos_token = tokenizer.encode(tokenizer._eos_token.content, add_special_tokens=False)[0] 207 | for target_text in summary.target_texts: 208 | if len(target_text) == 0: 209 | continue 210 | target_text = tokenizer.encode( 211 | target_text, 212 | add_special_tokens=True, 213 | truncation=max_length is not None, max_length=max_length 214 | ) 215 | if target_text[-1] != eos_token: 216 | target_text = target_text + [eos_token] 217 | target_texts.append(target_text) 218 | 219 | return replace(summary, 220 | preface=preface, 221 | target_texts=target_texts, 222 | references=list(map(tokenize_target_reference, summary.references)), 223 | ) 224 | 225 | def valid_review(summary: TargetSummary) -> bool: 226 | if summary is None: 227 | return False 228 | if len(summary.references) == 0: 229 | return False 230 | if len(summary.target_texts) == 
0: 231 | return False 232 | return True 233 | 234 | def total_reference_lengths(summary: TargetSummary) -> int: 235 | return sum((len(ref.title_abstract) for ref in summary.references)) 236 | 237 | def total_decoding_lengths(summary: TargetSummary) -> int: 238 | preface_length = len(summary.preface) if summary.preface is not None else 0 239 | target_lengths = max(map(len, summary.target_texts)) 240 | return preface_length + target_lengths 241 | 242 | def main(): 243 | parser = argparse.ArgumentParser(description='Convert Reviews into TargetSummary objects') 244 | parser.add_argument('--input', required=True, help='jsonl serialized input reviews') 245 | parser.add_argument('--output', required=True, help='file for jsonl serialized output targets') 246 | parser.add_argument('--tokenizer_save', required=False, default=None, help='Where should we save the tokenizer to?') 247 | parser.add_argument('--tokenizer', required=True, help='tokenizer type, e.g. BART') 248 | parser.add_argument('--max_length', type=int, default=None, required=False, help='truncate sequence lengths?') 249 | parser.add_argument('--max_refs', type=int, default=None, required=False, help='truncate number of included refs?') 250 | args = parser.parse_args() 251 | 252 | # TODO(jayd) assign ids to these elements! 253 | tokenizer = get_tokenizer(args.tokenizer) 254 | if args.tokenizer_save is not None: 255 | logging.info(f'Saving tokenizer with extended vocab to {args.tokenizer_save}') 256 | tokenizer.save_pretrained(args.tokenizer_save) 257 | # workaround for what seems to be a huggingface bug 258 | likely_misnamed_tokenizer_config = os.path.join(args.tokenizer_save, 'tokenizer_config.json') 259 | if os.path.exists(likely_misnamed_tokenizer_config): 260 | shutil.copyfile(likely_misnamed_tokenizer_config, os.path.join(args.tokenizer_save, 'config.json')) 261 | reviews = Review.read_reviews(args.input) 262 | logging.info(f'Loaded {len(reviews)}') 263 | with multiprocessing.Pool(processes=NUM_PROCS) as p: 264 | logging.info('Processing reviews') 265 | processed = p.imap(functools.partial(process_review, max_refs=args.max_refs, tokenizer=tokenizer), reviews, chunksize=1) 266 | processed = filter(lambda x: x is not None, processed) 267 | processed = list(processed) 268 | target_reviews, unknown_fields = zip(*processed) 269 | all_unknown_fields = set(itertools.chain.from_iterable(unknown_fields)) 270 | if len(all_unknown_fields) > 0: 271 | logging.info(f'Unable to process fields {all_unknown_fields}') 272 | non_empty_reviews = list(filter(valid_review, target_reviews)) 273 | logging.info('Tensorizing reviews') 274 | tensorized_reviews = p.imap(functools.partial(tokenize_target_summary, tokenizer=tokenizer, max_length=args.max_length), non_empty_reviews, chunksize=1) 275 | tensorized_reviews = list(tensorized_reviews) 276 | logging.info(f'After processing, a total of {len(non_empty_reviews)} reviews are left, with {len(tensorized_reviews)} reviews for input') 277 | review_target_lengths = list(p.map(total_decoding_lengths, tensorized_reviews)) 278 | review_reference_lengths = list(p.map(total_reference_lengths, tensorized_reviews)) 279 | total_lengths = [sum(x) for x in zip(review_reference_lengths, review_target_lengths)] 280 | min_length = min(total_lengths) 281 | max_length = max(total_lengths) 282 | avg_length = sum(total_lengths) / len(total_lengths) 283 | logging.info(f'Input/Output lengths are a minimum of {min_length} wordpieces long, maximum of {max_length} wordpieces, and average of {avg_length} wordpieces.') 284 | with 
open(args.output, 'w') as of: 285 | for review in tqdm.tqdm(tensorized_reviews): 286 | of.write(json.dumps(asdict(review))) 287 | of.write('\n') 288 | 289 | if __name__ == '__main__': 290 | main() -------------------------------------------------------------------------------- /scripts/modeling/table_to_text_summarizer_input.py: -------------------------------------------------------------------------------- 1 | """Produces instances for the transformer summarizer 2 | 3 | From each review, this extracts all possible I/O tuplets that are *not* in the Review Effect statements and: 4 | * turns every reference into a format like | intervention | outcome | evidence sentence | significance class | 5 | * keeping only tuplets like ^^ that have an evidence sentence score above some threshold 6 | * uses the preamble 7 | * targets the EFFECT statement 8 | 9 | """ 10 | import argparse 11 | import functools 12 | import itertools 13 | import json 14 | import logging 15 | import multiprocessing 16 | 17 | from dataclasses import asdict 18 | from typing import List 19 | 20 | from ms2.utils import ( 21 | get_tokenizer, 22 | Review, 23 | TargetReference, 24 | TargetSummary, 25 | START_BACKGROUND, 26 | END_BACKGROUND, 27 | START_INTERVENTION, 28 | END_INTERVENTION, 29 | START_OUTCOME, 30 | END_OUTCOME, 31 | START_REFERENCE, 32 | END_REFERENCE, 33 | SEP_TOKEN, 34 | START_EVIDENCE, 35 | END_EVIDENCE 36 | ) 37 | 38 | from summarizer_input_prep import ( 39 | clean_targets, 40 | extract_summary_parts, 41 | tokenize_target_summary, 42 | valid_review, 43 | ) 44 | 45 | from tabular_summarizer_input_prep import ( 46 | review_ios, 47 | sig_class, 48 | ) 49 | 50 | NUM_PROCS = 4 51 | INTERVENTION_RE = START_INTERVENTION + '(.*?)' + END_INTERVENTION 52 | OUTCOME_RE = START_OUTCOME + '(.*?)' + END_OUTCOME 53 | 54 | logging.basicConfig(format='%(asctime)s %(message)s', level=logging.INFO) 55 | 56 | def process_review(review: Review, evidence_threshold) -> List[TargetSummary]: 57 | ret = [] 58 | refs = [] 59 | review_io_parts = set(itertools.product(*review_ios(review))) 60 | preface, target, _ = extract_summary_parts(review.structured_abstract) 61 | for ref in review.extract_references(): 62 | if ref.significances is None: 63 | continue 64 | for sig in ref.significances: 65 | if sig.evidence_sentence_score < evidence_threshold: 66 | continue 67 | i, o = sig.intervention, sig.outcome 68 | if (i, o) not in review_io_parts: 69 | continue 70 | #raise ValueError('Impossible!') 71 | clazz = sig_class(sig.classification) 72 | fake_text = '\n'.join([ 73 | START_REFERENCE, 74 | START_INTERVENTION + ' ' + i + ' ' + END_INTERVENTION, 75 | START_OUTCOME + ' ' + o + ' ' + END_OUTCOME, 76 | START_EVIDENCE + ' ' + sig.evidence_sentence + ' ' + END_EVIDENCE, 77 | SEP_TOKEN + ' ' + clazz, 78 | END_REFERENCE, 79 | ]) 80 | refs.append(TargetReference( 81 | title_abstract=fake_text, 82 | full_text=None, 83 | s2id=ref.s2id, 84 | s2hash=ref.s2hash, 85 | )) 86 | 87 | ret.append(TargetSummary( 88 | preface=START_BACKGROUND + ' ' + preface + ' ' + END_BACKGROUND, 89 | target_texts=clean_targets([target]), 90 | review_id=review.docid, 91 | references=refs, 92 | s2id=review.s2id, 93 | s2hash=review.s2hash 94 | )) 95 | return ret 96 | 97 | def main(): 98 | parser = argparse.ArgumentParser(description='Convert Reviews into TargetSummary objects') 99 | parser.add_argument('--input', required=True, help='jsonl serialized input reviews') 100 | parser.add_argument('--output', required=True, help='file for jsonl serialized output targets') 101 | 
parser.add_argument('--evidence_sentence_threshold', default=0.0, type=float, help='Evidence sentence score minimum thresholds') 102 | parser.add_argument('--tokenizer', required=True, help='tokenizer type, e.g. BART') 103 | parser.add_argument('--max_length', type=int, default=None, required=False, help='truncate sequence lengths?') 104 | args = parser.parse_args() 105 | 106 | tokenizer = get_tokenizer(args.tokenizer) 107 | review_count = 0 108 | written = 0 109 | with open(args.input, 'r') as inf, \ 110 | open(args.output, 'w') as of: 111 | for line in inf: 112 | if len(line.strip()) == 0: 113 | continue 114 | review = Review.from_json(line) 115 | review_count += 1 116 | instances = process_review(review, evidence_threshold=args.evidence_sentence_threshold) 117 | instances = list(filter(valid_review, instances)) 118 | tensorized_reviews = map(functools.partial(tokenize_target_summary, tokenizer=tokenizer, max_length=args.max_length), instances) 119 | tensorized_reviews = list(tensorized_reviews) 120 | for review in tensorized_reviews: 121 | of.write(json.dumps(asdict(review))) 122 | of.write('\n') 123 | written += 1 124 | assert written > 0 125 | 126 | 127 | if __name__ == '__main__': 128 | main() -------------------------------------------------------------------------------- /scripts/modeling/tabular_summarizer_input_prep.py: -------------------------------------------------------------------------------- 1 | """Produces instances for the transformer summarizer 2 | 3 | From each review, this extracts all possible I/O tuplets that are *not* in the Review Effect statements and: 4 | * turns every reference into a format like | intervention | outcome | evidence sentence | significance class | 5 | * keeping only tuplets like ^^ that have an evidence sentence score above some threshold 6 | * turns the preamble into a format like | intervention | outcome | 7 | * turns the target into the literal text of the significance class 8 | 9 | """ 10 | import argparse 11 | import functools 12 | import itertools 13 | import json 14 | import logging 15 | import multiprocessing 16 | import re 17 | 18 | from collections import defaultdict 19 | from dataclasses import asdict 20 | from typing import Dict, List, Set, Tuple 21 | 22 | from ms2.utils import ( 23 | get_tokenizer, 24 | Review, 25 | TargetReference, 26 | TargetSummary, 27 | START_BACKGROUND, 28 | END_BACKGROUND, 29 | START_INTERVENTION, 30 | END_INTERVENTION, 31 | START_OUTCOME, 32 | END_OUTCOME, 33 | START_REFERENCE, 34 | END_REFERENCE, 35 | SEP_TOKEN, 36 | START_EVIDENCE, 37 | END_EVIDENCE, 38 | EXTRA_TOKENS 39 | ) 40 | 41 | from summarizer_input_prep import ( 42 | TARGET_FIELDS, 43 | tokenize_target_summary, 44 | valid_review, 45 | ) 46 | 47 | NUM_PROCS = 4 48 | INTERVENTION_RE = START_INTERVENTION + '(.*?)' + END_INTERVENTION 49 | OUTCOME_RE = START_OUTCOME + '(.*?)' + END_OUTCOME 50 | 51 | logging.basicConfig(format='%(asctime)s %(message)s', level=logging.INFO) 52 | 53 | def clean_str(s): 54 | for elem in EXTRA_TOKENS: 55 | s = s.replace(elem, '') 56 | s = s.replace(' ', ' ') 57 | return s 58 | 59 | def extract_non_target_parts(summary_parts: List[Tuple[str, str]]) -> Set[str]: 60 | ret = set() 61 | for field, value, _ in summary_parts: 62 | if field not in TARGET_FIELDS and field not in {'FURTHER_STUDY', 'RECOMMENDATION', 'EVIDENCE_QUALITY', 'DETAILED_FINDINGS', 'RESULT'}: 63 | ret.add(value) 64 | return ret 65 | 66 | def review_ios(review: Review) -> Tuple[Set[str], Set[str]]: 67 | assert review.structured_abstract is not None and 
len(review.structured_abstract) > 0 68 | non_summary_inputs = extract_non_target_parts(review.structured_abstract) 69 | non_summary_inputs = '\n'.join(non_summary_inputs) 70 | intervention_groups = list(map(str.strip, map(clean_str, re.findall(INTERVENTION_RE, non_summary_inputs)))) 71 | outcome_groups = list(map(str.strip, map(clean_str, re.findall(OUTCOME_RE, non_summary_inputs)))) 72 | return intervention_groups, outcome_groups 73 | 74 | def sig_class(dist: Dict[str, float]) -> str: 75 | best_score = float('-inf') 76 | best_class = None 77 | for clazz, score in dist.items(): 78 | if score > best_score: 79 | best_score = score 80 | best_class = clazz 81 | return best_class 82 | 83 | def process_review(review: Review, evidence_threshold) -> List[TargetSummary]: 84 | ret = [] 85 | ref_opinions = defaultdict(list) 86 | review_io_parts = set(itertools.product(*review_ios(review))) 87 | for ref in review.extract_references(): 88 | if ref.significances is None: 89 | continue 90 | for sig in ref.significances: 91 | if sig.evidence_sentence_score < evidence_threshold: 92 | continue 93 | i, o = sig.intervention, sig.outcome 94 | if (i, o) not in review_io_parts: 95 | continue 96 | clazz = sig_class(sig.classification) 97 | fake_text = '\n'.join([ 98 | START_REFERENCE, 99 | START_INTERVENTION + ' ' + i + ' ' + END_INTERVENTION, 100 | START_OUTCOME + ' ' + o + ' ' + END_OUTCOME, 101 | START_EVIDENCE + ' ' + sig.evidence_sentence + ' ' + END_EVIDENCE, 102 | SEP_TOKEN + ' ' + clazz, 103 | END_REFERENCE, 104 | ]) 105 | ref_opinions[(i,o)].append(TargetReference( 106 | title_abstract=fake_text, 107 | s2id=ref.s2id, 108 | s2hash=ref.s2hash, 109 | full_text=None, 110 | )) 111 | 112 | for sig in review.significances: 113 | i, o = sig.intervention, sig.outcome 114 | if (i, o) not in review_io_parts: 115 | continue 116 | clazz = sig_class(sig.classification) 117 | refs = ref_opinions[(i, o)] 118 | if len(refs) > 0: 119 | ret.append(TargetSummary( 120 | preface='\n'.join([ 121 | START_BACKGROUND, 122 | START_INTERVENTION + ' ' + i + ' ' + END_INTERVENTION, 123 | START_OUTCOME + ' ' + o + ' ' + END_OUTCOME, 124 | END_BACKGROUND, 125 | ]), 126 | target_texts=[clazz], 127 | review_id=review.docid + '_int_' + i + '_out_' + o, 128 | references=refs, 129 | s2id=review.s2id, 130 | s2hash=review.s2hash 131 | )) 132 | return ret 133 | 134 | def main(): 135 | parser = argparse.ArgumentParser(description='Convert Reviews into TargetSummary objects') 136 | parser.add_argument('--input', required=True, help='jsonl serialized input reviews') 137 | parser.add_argument('--output', required=True, help='file for jsonl serialized output targets') 138 | parser.add_argument('--evidence_sentence_threshold', default=0.0, type=float, help='Evidence sentence score minimum thresholds') 139 | parser.add_argument('--tokenizer', required=True, help='tokenizer type, e.g. 
BART') 140 | parser.add_argument('--max_length', type=int, default=None, required=False, help='truncate sequence lengths?') 141 | args = parser.parse_args() 142 | 143 | tokenizer = get_tokenizer(args.tokenizer) 144 | review_count = 0 145 | with open(args.input, 'r') as inf, \ 146 | open(args.output, 'w') as of: 147 | for line in inf: 148 | if len(line.strip()) == 0: 149 | continue 150 | review = Review.from_json(line) 151 | instances = process_review(review, evidence_threshold=args.evidence_sentence_threshold) 152 | instances = list(filter(valid_review, instances)) 153 | tensorized_reviews = map(functools.partial(tokenize_target_summary, tokenizer=tokenizer, max_length=args.max_length), instances) 154 | tensorized_reviews = list(tensorized_reviews) 155 | for review in tensorized_reviews: 156 | of.write(json.dumps(asdict(review))) 157 | of.write('\n') 158 | review_count += 1 159 | logging.info(f'Wrote {review_count} instances') 160 | assert review_count > 0 161 | 162 | if __name__ == '__main__': 163 | main() 164 | -------------------------------------------------------------------------------- /scripts/modeling/text_to_table_input_prep.py: -------------------------------------------------------------------------------- 1 | """Produces instances for the transformer summarizer 2 | 3 | Uses textual inputs from the references, 4 | From each review, this extracts all possible I/O tuplets that are *not* in the Review Effect statements and: 5 | * turns every reference into a format like | intervention | outcome | evidence sentence | significance class | 6 | * keeping only tuplets like ^^ that have an evidence sentence score above some threshold 7 | * turns the preamble into a format like | intervention | outcome | 8 | * turns the target into the literal text of the significance class 9 | 10 | """ 11 | import argparse 12 | import functools 13 | import itertools 14 | import json 15 | import logging 16 | import multiprocessing 17 | import re 18 | 19 | from collections import defaultdict 20 | from dataclasses import asdict 21 | from typing import Dict, List, Set, Tuple 22 | 23 | from ms2.utils import ( 24 | get_tokenizer, 25 | Review, 26 | TargetSummary, 27 | START_BACKGROUND, 28 | END_BACKGROUND, 29 | START_INTERVENTION, 30 | END_INTERVENTION, 31 | START_OUTCOME, 32 | END_OUTCOME, 33 | ) 34 | 35 | from summarizer_input_prep import ( 36 | select_reference_from_study, 37 | process_reference, 38 | tokenize_target_summary, 39 | valid_review, 40 | ) 41 | 42 | from tabular_summarizer_input_prep import ( 43 | review_ios, 44 | sig_class, 45 | ) 46 | 47 | NUM_PROCS = 4 48 | INTERVENTION_RE = START_INTERVENTION + '(.*?)' + END_INTERVENTION 49 | OUTCOME_RE = START_OUTCOME + '(.*?)' + END_OUTCOME 50 | 51 | logging.basicConfig(format='%(asctime)s %(message)s', level=logging.INFO) 52 | 53 | def process_review(review: Review) -> List[TargetSummary]: 54 | ret = [] 55 | review_io_parts = set(itertools.product(*review_ios(review))) 56 | 57 | selected_study_references = map(select_reference_from_study, review.included_studies) 58 | processed_references = map(process_reference, selected_study_references) 59 | refs = list(filter(lambda x: x is not None, processed_references)) 60 | 61 | for sig in review.significances: 62 | i, o = sig.intervention, sig.outcome 63 | if (i, o) not in review_io_parts: 64 | continue 65 | clazz = sig_class(sig.classification) 66 | if len(refs) > 0: 67 | ret.append(TargetSummary( 68 | preface='\n'.join([ 69 | START_BACKGROUND, 70 | START_INTERVENTION + ' ' + i + ' ' + END_INTERVENTION, 71 | 
START_OUTCOME + ' ' + o + ' ' + END_OUTCOME, 72 | END_BACKGROUND, 73 | ]), 74 | target_texts=[clazz], 75 | review_id=review.docid + '_int_' + i + '_out_' + o, 76 | references=refs, 77 | s2id=review.s2id, 78 | s2hash=review.s2hash 79 | )) 80 | return ret 81 | 82 | def main(): 83 | parser = argparse.ArgumentParser(description='Convert Reviews into TargetSummary objects') 84 | parser.add_argument('--input', required=True, help='jsonl serialized input reviews') 85 | parser.add_argument('--output', required=True, help='file for jsonl serialized output targets') 86 | parser.add_argument('--tokenizer', required=True, help='tokenizer type, e.g. BART') 87 | parser.add_argument('--max_length', type=int, default=None, required=False, help='truncate sequence lengths?') 88 | args = parser.parse_args() 89 | 90 | tokenizer = get_tokenizer(args.tokenizer) 91 | review_count = 0 92 | written = 0 93 | with open(args.input, 'r') as inf, \ 94 | open(args.output, 'w') as of: 95 | for line in inf: 96 | if len(line.strip()) == 0: 97 | continue 98 | review = Review.from_json(line) 99 | review_count += 1 100 | instances = process_review(review) 101 | instances = list(filter(valid_review, instances)) 102 | tensorized_reviews = map(functools.partial(tokenize_target_summary, tokenizer=tokenizer, max_length=args.max_length), instances) 103 | tensorized_reviews = list(tensorized_reviews) 104 | for review in tensorized_reviews: 105 | of.write(json.dumps(asdict(review))) 106 | of.write('\n') 107 | written += 1 108 | assert written > 0 109 | 110 | if __name__ == '__main__': 111 | main() 112 | 113 | -------------------------------------------------------------------------------- /scripts/run_ms2.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e 4 | set -x 5 | PYTHONPATH=$(readlink -e .):$PYTHONPATH 6 | export PYTHONPATH 7 | PYTHON=$HOME/local/ms2_repro/bin/python 8 | set -o nounset 9 | set -o pipefail 10 | 11 | function ckpt { 12 | local cmd="$1" 13 | local name="$2" 14 | local ckpt_file="$ARTIFACTS/logs/$name.ckpt" 15 | local partial_ckpt_file="$ARTIFACTS/logs/$name.partial" 16 | local log_file_base="$ARTIFACTS/logs/$name" 17 | mkdir -p "$(dirname $ckpt_file)" "$(dirname $log_file_base)" 18 | if [ -e "$partial_ckpt_file" ] ; then 19 | cat "$partial_ckpt_file" >> "$partial_ckpt_file".old 20 | fi 21 | if [ ! 
-e "$ckpt_file" ] ; then 22 | echo "running $name; $cmd" 23 | echo "$cmd" > "$partial_ckpt_file" 24 | if [ -e "${log_file_base}.e" ]; then 25 | mv "${log_file_base}.e" "${log_file_base}.e.old" 26 | fi 27 | if [ -e "${log_file_base}.o" ]; then 28 | mv "${log_file_base}.o" "${log_file_base}.o.old" 29 | fi 30 | # shellcheck disable=SC2086 31 | eval $cmd > >(tee "${log_file_base}.o") 2> >(tee "${log_file_base}.e" >&2) && touch $ckpt_file || (echo "failed $name ; $cmd" ; exit 1) 32 | #else 33 | #echo "already ran '$name'; clear '$ckpt_file' to rerun" 34 | fi 35 | } 36 | 37 | mkdir $HOME/scratch 38 | ARTIFACTS="$HOME/scratch/ms2_repro/" 39 | OUTPUTS="$ARTIFACTS/outputs" 40 | 41 | MAX_LENGTH="--max_length 500" 42 | for subset in training validation testing ; do 43 | cmd="srun -p short -t 16:00:00 --mem 24G $PYTHON scripts/modeling/summarizer_input_prep.py --input $HOME/scratch/ms2_repro/ms2_data/${subset}_reviews.jsonl --output $OUTPUTS/text_to_text/${subset}.jsonl --tokenizer facebook/bart-base $MAX_LENGTH" 44 | ckpt "$cmd" "text_to_text/$subset" 45 | done 46 | 47 | training_reviews_file=$OUTPUTS/text_to_text/training.jsonl 48 | validation_reviews_file=$OUTPUTS/text_to_text/validation.jsonl 49 | testing_reviews_file=$OUTPUTS/text_to_text/testing.jsonl 50 | 51 | training_dir=$OUTPUTS/text_to_text/training/bart-base/ 52 | EPOCHS=8 53 | GRAD_ACCUM=16 54 | MODEL_NAME='facebook/bart-base' 55 | cmd="srun -p frink --gres gpu:1 --mem=64G --cpus-per-task=16 \ 56 | $PYTHON ms2/models/transformer_summarizer.py \ 57 | --train $training_reviews_file \ 58 | --val $validation_reviews_file \ 59 | --training_root $training_dir \ 60 | --epochs=$EPOCHS \ 61 | --grad_accum=$GRAD_ACCUM \ 62 | --fp16 \ 63 | --model_name $MODEL_NAME \ 64 | --max_num_refs 25" 65 | 66 | ckpt "$cmd" "text_to_text/training/bart-base" 67 | 68 | NUM_BEAMS=6 69 | CHECKPOINT=$training_dir/_ckpt_epoch_7.ckpt 70 | INPUT=$validation_reviews_file 71 | OUTPUT=$training_dir/decode/validation 72 | mkdir -p $(dirname $OUTPUT) 73 | cmd="srun -p frink --gres gpu:1 --mem 32G --cpus-per-task=8 \ 74 | $PYTHON \ 75 | scripts/modeling/decode.py \ 76 | --input $INPUT --output $OUTPUT \ 77 | --checkpoint $CHECKPOINT \ 78 | --num_beams=$NUM_BEAMS \ 79 | --model_name $MODEL_NAME" 80 | ckpt "$cmd" "text_to_text/training/bart_base/decode/validation" & 81 | 82 | 83 | INPUT=$training_reviews_file 84 | OUTPUT=$training_dir/decode/training 85 | cmd="srun -p frink --gres gpu:1 --mem 32G --cpus-per-task=8 \ 86 | $PYTHON \ 87 | scripts/modeling/decode.py \ 88 | --input $INPUT --output $OUTPUT \ 89 | --checkpoint $CHECKPOINT \ 90 | --num_beams=$NUM_BEAMS \ 91 | --model_name $MODEL_NAME" 92 | ckpt "$cmd" "text_to_text/training/bart_base/decode/training" & 93 | 94 | 95 | INPUT=$testing_reviews_file 96 | OUTPUT=$training_dir/decode/testing 97 | cmd="srun -p frink --gres gpu:1 --mem 32G --cpus-per-task=8 \ 98 | $PYTHON \ 99 | scripts/modeling/decode.py \ 100 | --input $INPUT --output $OUTPUT \ 101 | --checkpoint $CHECKPOINT \ 102 | --num_beams=$NUM_BEAMS \ 103 | --model_name $MODEL_NAME" 104 | ckpt "$cmd" "text_to_text/training/bart_base/decode/testing" & 105 | 106 | wait 107 | 108 | training_dir=$OUTPUTS/text_to_text/training/longformer_base/ 109 | # longformer, bart-large 110 | MODEL_NAME=$HOME/scratch/ms2_repro/source_models/longformer-encdec-base-16384 111 | cmd="srun -p frink --gres gpu:1 --mem=64G --cpus-per-task=16 \ 112 | $PYTHON ms2/models/transformer_summarizer.py \ 113 | --train $training_reviews_file \ 114 | --val $validation_reviews_file \ 115 | --training_root 
$training_dir \ 116 | --epochs=$EPOCHS \ 117 | --grad_accum=$GRAD_ACCUM \ 118 | --fp16 \ 119 | --model_name $MODEL_NAME \ 120 | --max_num_refs 25" 121 | 122 | ckpt "$cmd" "text_to_text/training/longformer_base" & 123 | 124 | NUM_BEAMS=6 125 | CHECKPOINT=$training_dir/_ckpt_epoch_7.ckpt 126 | INPUT=$validation_reviews_file 127 | OUTPUT=$training_dir/decode/validation 128 | mkdir -p $(dirname $OUTPUT) 129 | cmd="srun -p frink --gres gpu:1 --mem 32G --cpus-per-task=8 \ 130 | $PYTHON \ 131 | scripts/modeling/decode.py \ 132 | --input $INPUT --output $OUTPUT \ 133 | --checkpoint $CHECKPOINT \ 134 | --num_beams=$NUM_BEAMS \ 135 | --model_name $MODEL_NAME" 136 | ckpt "$cmd" "text_to_text/training/longformer_base/decode/validation" & 137 | 138 | 139 | INPUT=$training_reviews_file 140 | OUTPUT=$training_dir/decode/training 141 | cmd="srun -p frink --gres gpu:1 --mem 32G --cpus-per-task=8 \ 142 | $PYTHON \ 143 | scripts/modeling/decode.py \ 144 | --input $INPUT --output $OUTPUT \ 145 | --checkpoint $CHECKPOINT \ 146 | --num_beams=$NUM_BEAMS \ 147 | --model_name $MODEL_NAME" 148 | ckpt "$cmd" "text_to_text/training/longformer_base/decode/training" & 149 | 150 | 151 | INPUT=$testing_reviews_file 152 | OUTPUT=$training_dir/decode/testing 153 | cmd="srun -p frink --gres gpu:1 --mem 32G --cpus-per-task=8 \ 154 | $PYTHON \ 155 | scripts/modeling/decode.py \ 156 | --input $INPUT --output $OUTPUT \ 157 | --checkpoint $CHECKPOINT \ 158 | --num_beams=$NUM_BEAMS \ 159 | --model_name $MODEL_NAME" 160 | ckpt "$cmd" "text_to_text/training/longformer_base/decode/testing" & 161 | 162 | wait 163 | --------------------------------------------------------------------------------