├── Aligner.py
├── LICENSE
├── README.md
├── annotation2MRPC_Aligner.py
├── createSubDatasets.py
├── docSum2MRPC_Aligner.py
├── environment_superPAL.yml
├── filterContained.py
├── finalAlignmentPred.py
├── main_predict.py
├── manual_datasets
│   ├── dev_DUC_index_only.csv
│   ├── dev_MN_only.csv
│   ├── restore_alignments.py
│   ├── test_DUC_index_only.csv
│   ├── test_MN_only.csv
│   └── train_full_details_no_oies_fixed_no_duplications_only_index.csv
├── run_glue.py
├── supervised_oie_wrapper
│   ├── __init__.py
│   ├── format_oie.py
│   └── run_oie.py
└── utils.py

/Aligner.py:
--------------------------------------------------------------------------------
1 | from utils import *
2 | 
3 | 
4 | class Aligner(object):
5 |     """
6 |     finds document-summary alignment pair candidates
7 | 
8 |     """
9 | 
10 |     def __init__(self, data_path='.', mode='dev',
11 |                  log_file='results/dev_log.txt', metric_precompute=True, output_file='./prediction_dev.csv',
12 |                  database='duc2004,duc2007,MultiNews'):
13 | 
14 |         self.data_path = data_path
15 |         self.mode = mode
16 |         self.log_file = log_file
17 |         self.metric_precompute = metric_precompute
18 |         self.output_file = output_file
19 |         if ',' in database:
20 |             self.database = database.split(',')
21 |         else:
22 |             self.database = [database]
23 | 
24 | 
25 |         self.summ_sents = []
26 |         self.doc_sents = []
27 |         self.final_alignments = []
28 | 
29 | 
30 |         # set up logger
31 |         logging.basicConfig(
32 |             level=logging.INFO,
33 |             format="%(asctime)s [%(threadName)-12.12s] [%(levelname)-5.5s] %(message)s",
34 |             handlers=[
35 |                 logging.FileHandler(f"{self.log_file}"),
36 |                 logging.StreamHandler()
37 |             ])
38 | 
39 |         if self.metric_precompute:
40 |             self.metrics_data = {}
41 |         else:
42 |             metric_precompute_path = f'data/final_data/metric_compute_{self.mode}_oie.json'
43 |             if os.path.isfile(metric_precompute_path):
44 |                 with open(metric_precompute_path, 'r') as f:
45 |                     self.metrics_data = json.load(f)
46 |             else:
47 |                 logging.warning("No metric_precompute data was found")
48 |                 self.metrics_data = None
49 | 
50 | 
51 |         # self.scus_list = []
52 |         # self.scu_sent_pairs = []
53 |         self.docSentsOIE = True  # whether OIE needs to be generated for doc sents; depends on the aligner.
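    # Note on the metric cache (illustrative sketch, not in the original source;
    # the scores below are made-up values): calc_metric_precompute() fills
    # self.metrics_data with one entry per (summary SCU, document sentence) pair,
    # keyed by hashhex(summaryFile + scuText + documentFile + docSentText), e.g.
    #   self.metrics_data[hashhex(key)] = {'rouge1_p': 0.31, 'bert_p': 0.87, 'ent': 0.02}
    # metric_filter() later looks up the same keys instead of recomputing
    # ROUGE-1, BERTScore and entailment.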
54 | 
55 | 
56 | 
57 | 
58 |     def read_and_split(self, dataset, sfile):
59 |         ### process the summary
60 |         logging.info(f"Evaluating the Alignment for Summary: {sfile}")
61 |         sfile_basename = os.path.basename(sfile)
62 |         if dataset in ['duc2007', 'duc2004']:
63 |             doc_name = sfile_basename.split('.')[0]
64 |         elif dataset == 'MultiNews':
65 |             doc_name = "MultiNews_" + sfile_basename.split('_')[0]
66 |         else:
67 |             doc_name = sfile_basename.split('.')[0]
68 | 
69 |         summary = read_generic_file(sfile)
70 |         s_sents = []
71 |         # for line in summary:
72 |         #     s_sents.extend(tokenize.sent_tokenize(line))
73 |         s_sents = tokenize.sent_tokenize(" ".join(summary))
74 |         self.summ_sents = []
75 |         idx_start = 0
76 |         for sent in s_sents:
77 |             self.summ_sents.append({'summaryFile': sfile_basename, 'scuSentCharIdx': idx_start,
78 |                                     'scuSentence': sent, 'database': dataset, 'topic': doc_name})
79 |             idx_start = idx_start + len(sent) + 1  # 1 for the space character
80 | 
81 |         ## process all the document files
82 |         doc_files = glob.glob(f"{self.data_path}/{doc_name}/*")
83 | 
84 |         logging.info("The following documents were found for this topic:")
85 |         logging.info("\n".join(doc_files))
86 |         self.doc_sents = []
87 |         for df in doc_files:
88 |             doc_id = os.path.basename(df)
89 |             document = read_generic_file(df)
90 |             dsents = []
91 |             # for line in document:
92 |             #     dsents.extend(tokenize.sent_tokenize(line))
93 |             dsents = tokenize.sent_tokenize(" ".join(document))
94 |             idx_start = 0
95 |             for dsent in dsents:
96 |                 if dsent != "...":  # this is an exception
97 |                     self.doc_sents.append({'documentFile': doc_id, 'docSentCharIdx': idx_start,
98 |                                            'docSentText': dsent})
99 | 
100 |                 idx_start = idx_start + len(dsent) + 1  # 1 for the space character between sentences
101 | 
102 | 
103 |     def calc_metric_precompute(self):
104 |         scus = []
105 |         # for s in self.summ_sents:
106 |         #     scus.extend(generate_scu(s, max_scus=100))
107 |         scus.extend(generate_scu_oie_multiSent(self.summ_sents, doc_summ='summ'))
108 |         refs = []
109 |         cands = []
110 |         ids = []
111 |         for s in scus:
112 |             refs.extend([x['docSentText'] for x in self.doc_sents])
113 |             cands.extend([s['scuText'] for _ in range(len(self.doc_sents))])
114 |             ids.extend([s['summaryFile'] + s['scuText'] + x['documentFile'] + x['docSentText'] for x in self.doc_sents])
115 |         rouge1_p, bert_p, ent = calculate_metric_scores(cands, refs)
116 | 
117 |         for idx, key in enumerate(ids):
118 |             self.metrics_data[hashhex(key)] = {'rouge1_p': rouge1_p[idx], 'bert_p': bert_p[idx], 'ent': ent[idx]}
119 | 
120 | 
121 | 
122 | 
123 | 
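    # Offset serialization note (descriptive comment, not in the original
    # source): each offsets field is a list of (start, end) character spans;
    # save_predictions() below joins a span as "start, end" and joins
    # discontinuous spans with ';', e.g. "950, 1022;1039, 1068" (the same
    # format used by the csv files under manual_datasets/).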
124 |     def save_predictions(self):
125 |         if self.metric_precompute:
126 |             with open(f'data/final_data/metric_compute_{self.mode}_oie.json', 'w') as f:
127 |                 json.dump(self.metrics_data, f)
128 | 
129 |         else:
130 |             ## save the predictions into a csv file
131 |             with open(os.path.join(self.output_file, 'dev.csv'), 'w') as f:
132 |                 csvwriter = csv.writer(f, delimiter=',')
133 |                 header = ['database', 'topic', 'summaryFile', 'scuSentCharIdx', 'scuSentence', 'scuOffsets', 'scuText',
134 |                           'documentFile', 'docSentCharIdx', 'docSentText', 'docSpanOffsets', 'summarySpanOffsets',
135 |                           'docSpanText', 'summarySpanText']
136 |                 csvwriter.writerow(header)
137 |                 for ind, row in enumerate(self.final_alignments):
138 |                     data = []
139 |                     # from lists to string format for csv
140 |                     row['scuOffsets'] = ';'.join(', '.join(map(str, offset)) for offset in row['scuOffsets'])
141 |                     row['docSpanOffsets'] = ';'.join(', '.join(map(str, offset)) for offset in row['docSpanOffsets'])
142 |                     row['summarySpanOffsets'] = ';'.join(', '.join(map(str, offset)) for offset in row['summarySpanOffsets'])
143 |                     for key in header:
144 |                         if type(row[key]) is tuple:
145 |                             data.append(f"{row[key][0]}, {row[key][1]}")
146 |                         else:
147 |                             data.append(row[key])
148 |                     csvwriter.writerow(data)
149 | 
150 | 
151 | 
152 | 
153 | 
154 | 
155 |     def add_scu_doc_span_pairs(self, scu, cand_doc_sents):
156 |         return
157 | 
158 |     def scu_span_aligner(self):
159 |         """ Module which aligns an SCU with a sentence
160 |         in the document, given a summary and document
161 |         """
162 |         ## generate SCUs
163 |         scus = []
164 |         scus.extend(generate_scu_oie_multiSent(self.summ_sents, doc_summ='summ'))
165 | 
166 |         if self.docSentsOIE:
167 |             doc_spans = []
168 |             doc_spans.extend(generate_scu_oie_multiSent(self.doc_sents, doc_summ='doc'))
169 | 
170 | 
171 |         ## create candidate pool for sentences in
172 |         ## the document for each scu
173 |         for scu in scus:
174 |             if self.docSentsOIE:
175 |                 self.add_scu_doc_span_pairs(scu, doc_spans)
176 | 
177 |     def metric_filter(self, scu, use_precompute_metrics=True):
178 |         """ this module finds the candidate sentences
179 |         that are close to the given scu using metric-based
180 |         filtering
181 |         """
182 |         refs = [x['docSentText'] for x in self.doc_sents]
183 |         cands = [scu['scuText'] for _ in range(len(self.doc_sents))]
184 |         ids = [scu['summaryFile'] + scu['scuText'] + x['documentFile'] + x['docSentText'] for x in self.doc_sents]
185 |         if use_precompute_metrics:
186 |             # global metrics_data
187 |             rouge1_p = []
188 |             bert_p = []
189 |             ent = []
190 |             for idx, key in enumerate(ids):
191 |                 rouge1_p.append(self.metrics_data[hashhex(key)]['rouge1_p'])
192 |                 bert_p.append(self.metrics_data[hashhex(key)]['bert_p'])
193 |                 ent.append(self.metrics_data[hashhex(key)]['ent'])
194 |         else:
195 |             rouge1_p, bert_p, ent = calculate_metric_scores(cands, refs)
196 |         cands = []
197 |         scores = []
198 |         for ind in range(len(refs)):
199 |             preds = 0
200 |             if rouge1_p[ind] > 0.2:  # rouge1-p
201 |                 preds = 1
202 |             if bert_p[ind] > 0.88:  # BERT-p
203 |                 preds = 1
204 |             if ent[ind] < 0.001:
205 |                 preds = 0
206 |             if rouge1_p[ind] < 0.2:
207 |                 preds = 0
208 |             if rouge1_p[ind] > 0.25 and bert_p[ind] < 0.85:
209 |                 preds = 0
210 |             if preds == 1:
211 |                 tmp = copy.deepcopy(self.doc_sents[ind])
212 |                 tmp['score'] = rouge1_p[ind] * bert_p[ind] * ent[ind]
213 |                 cands.append(tmp)
214 | 
215 |         cands = sorted(cands, key=lambda x: x['score'], reverse=True)
216 | 
217 |         return cands[:3]
218 | 
219 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity.
For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 
193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SuperPAL
2 | 
3 | Data, Code and Model for the paper "[Summary-Source Proposition-level Alignment: Task, Datasets and Supervised Baseline](https://aclanthology.org/2021.conll-1.25.pdf)".
4 | 
5 | If you find the code useful, please cite the following paper.
6 | ```
7 | @inproceedings{ernst-etal-2021-summary,
8 |     title = "Summary-Source Proposition-level Alignment: Task, Datasets and Supervised Baseline",
9 |     author = "Ernst, Ori and Shapira, Ori and Pasunuru, Ramakanth and Lepioshkin, Michael and Goldberger, Jacob and Bansal, Mohit and Dagan, Ido", booktitle = "Proceedings of the 25th Conference on Computational Natural Language Learning", month = nov, year = "2021", publisher = "Association for Computational Linguistics", url = "https://aclanthology.org/2021.conll-1.25", pages = "310--322",}
10 | ```
11 | 
12 | You can use our huggingface model or check out our demo [here](https://huggingface.co/biu-nlp/superpal).
13 | 
14 | 
15 | The `run_glue.py` script was forked from [huggingface](https://github.com/huggingface/transformers) v2.5.1 and edited for our purposes.
16 | 
17 | The `supervised_oie_wrapper` directory is a wrapper around AllenNLP's (v0.9.0) pretrained Open IE model, implemented by Gabriel Stanovsky. It was forked from [here](https://github.com/gabrielStanovsky/supervised_oie_wrapper) and edited for our purposes.
18 | 
19 | In this repository we used python-3.6. Please refer to `environment_superPAL.yml` for other requirements.
20 | 
21 | 
22 | ## Manual Datasets ##
23 | 
24 | All manual datasets are under the `manual_datasets` directory, including the crowdsourced dev and test sets and the Pyramid-based train set.
25 | 
26 | As the DUC-based datasets are restricted by the LDC agreement, we provide here only the character indices of the propositions and sentences.
27 | 
28 | To restore the text alignments, please use:
29 | ```
30 | python manual_datasets/restore_alignments.py -indx_csv_path -documents_path -summaries_path -output_file
31 | ```
32 | If you have any issue regarding the DUC alignment regeneration, please contact us via email.
33 | 
34 | 
35 | MultiNews alignments are released in full.
36 | 
37 | 
38 | 
39 | ## Data generation ##
40 | 
41 | Predicted alignments of MultiNews and CNN/DailyMail train and val datasets can be found [here](https://drive.google.com/drive/folders/1JnRrdbENzBLpbae5ZIKmil1fuZhm2toc?usp=sharing).
42 | 
43 | ## Alignment model ##
44 | To apply the alignment model to your own data, follow these steps:
45 | 1. Download the trained model [here](https://drive.google.com/drive/folders/1kTaZQVxUm-RWbF71QpOue5xDuV7-IP2i?usp=sharing).
46 | 
47 | 2. Run
48 | ```
49 | python main_predict.py -data_path -output_path -alignment_model_path
50 | ```
51 | `<data_path>` should contain the following structure, where a summary and its related document directory share the same name (an example invocation is shown after the listing):
52 | 
53 | - <data_path>
54 |   - summaries
55 |     - A.txt
56 |     - B.txt
57 |     - ...
58 |   - A
59 |     - doc_A1
60 |     - doc_A2
61 |     - ...
62 |   - B
63 |     - doc_B1
64 |     - doc_B2
65 |     - ...
66 | 
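For illustration, a concrete invocation might look like this (all paths here are hypothetical):
```
python main_predict.py -data_path ./my_data -output_path ./out -alignment_model_path ./superPAL_model
```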
67 | 3. It will create two files in `<output_path>`:
68 | 
69 |    - 'dev.tsv' - contains all alignment candidate pairs.
70 | 
71 |    - a '.csv' file - contains all predicted aligned pairs with their classification score.
72 | 
73 | 4. To use the alignment model with your own data that has different properties, you can inherit from the docSum2MRPC_Aligner class and override the relevant functions.
74 | 
--------------------------------------------------------------------------------
/annotation2MRPC_Aligner.py:
--------------------------------------------------------------------------------
1 | from Aligner import Aligner
2 | from utils import *
3 | import pandas as pd
4 | 
5 | 
6 | 
7 | class annotation2MRPC_Aligner(Aligner):
8 | 
9 | 
10 |     """
11 |     gets gold crowdsourced alignments in a csv format
12 |     and converts them to a .tsv file that fits the huggingface 'transformers' MRPC format (a classification paraphrasing model)
13 | 
14 |     """
15 | 
16 | 
17 |     def __init__(self, data_path='.', mode='dev',
18 |                  log_file='results/dev_log.txt', metric_precompute=True, output_file='./dev.tsv',
19 |                  database='duc2004,duc2007,MultiNews'):
20 |         super().__init__(data_path=data_path, mode=mode,
21 |                          log_file=log_file, metric_precompute=metric_precompute, output_file=output_file,
22 |                          database=database)
23 |         self.filter_data = False
24 |         self.use_stored_alignment_database = False
25 |         self.alignment_database_list = []
26 |         self.docSentsOIE = True
27 |         self.alignment_database = pd.DataFrame(columns=['Quality', '#1 ID', '#2 ID', '#1 String', '#2 String', 'database', 'topic',
28 |                                                         'summaryFile', 'scuSentCharIdx', 'scuSentence', 'documentFile', 'docSentCharIdx',
29 |                                                         'docSentText', 'docSpanOffsets', 'summarySpanOffsets', 'docSpanText', 'summarySpanText'])
30 | 
31 | 
32 | 
33 | 
34 |     def add_scu_doc_span_pairs(self, scu, doc_spans):
35 | 
36 |         scu_offset_str = offset_list2str(scu['scuOffsets'])
37 |         id_scu = scu['topic'] + '_' + scu_offset_str
38 | 
39 |         for doc_span in doc_spans:
40 |             doc_offset_str = offset_list2str(doc_span['docScuOffsets'])
41 |             id_doc_sent = scu['topic'] + '_' + doc_span['documentFile'] + '_' + doc_offset_str
42 |             label = 0  # label = 0 for all; positive samples' labels are updated later
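            # Illustrative note (not in the original source; field values below
            # are invented): each appended row becomes one line of the
            # MRPC-style dev.tsv, e.g.
            #   Quality=0, #1 ID='D0718_950, 1022', #2 ID='D0718_APW19990630.0028_1930, 2033',
            #   #1 String=<summary span text>, #2 String=<document span text>,
            # followed by the bookkeeping columns (database, topic, files,
            # sentence indices and span offsets).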
43 | 
44 | 
45 |             self.alignment_database_list.append([label, id_scu, id_doc_sent,
46 |                                                  scu['scuText'],
47 |                                                  doc_span['docScuText'], scu['database'],
48 |                                                  scu['topic'], scu['summaryFile'],
49 |                                                  scu['scuSentCharIdx'],
50 |                                                  scu['scuSentence'],
51 |                                                  doc_span['documentFile'],
52 |                                                  doc_span['docSentCharIdx'],
53 |                                                  doc_span['docSentText'],
54 |                                                  offset_list2str(
55 |                                                      doc_span['docScuOffsets']),
56 |                                                  offset_list2str(scu['scuOffsets']),
57 |                                                  doc_span['docScuText'], scu['scuText']])
58 | 
59 | 
60 |     def metric_filter(self, scu):
61 |         if self.filter_data:
62 |             return super().metric_filter(scu)
63 |         return self.doc_sents
64 | 
65 |     def scu_span_aligner(self):
66 |         if self.use_stored_alignment_database:
67 |             if self.mode == 'dev':
68 |                 self.alignment_database = pd.read_pickle("./span_alignment_database_dev.pkl")
69 |             else:
70 |                 self.alignment_database = pd.read_pickle("./span_alignment_database_test.pkl")
71 |         else:
72 |             super().scu_span_aligner()
73 | 
74 | 
75 | 
76 | 
77 | 
78 | 
79 | 
80 | 
81 |     def update_positive_labels(self):
82 |         if self.mode == 'dev':
83 |             self.annotation_file = pd.read_csv('SCUdataGenerator/finalAlignmentDataset_dev_cleaned_wo_duplications.csv')
84 |         else:
85 |             self.annotation_file = pd.read_csv('SCUdataGenerator/finalAlignmentDataset_test_cleaned_wo_duplications.csv')
86 | 
87 |         for index, row in self.annotation_file.iterrows():
88 |             # row = self.annotation_file.sample().iloc[0]  # random row for debug.
89 |             documentFile = row['documentFile']
90 |             topic = row['topic']
91 |             summarySpanOffsets = offset_str2list(row['summarySpanOffsets'])
92 |             docSpanOffsets = offset_str2list(row['docSpanOffsets'])
93 |             cands_df = self.alignment_database[(self.alignment_database['documentFile'] == documentFile).values &
94 |                                                (self.alignment_database['topic'] == topic).values]
95 |             scu_cands_offset = np.unique(cands_df['summarySpanOffsets'])
96 |             doc_cands_offset = np.unique(cands_df['docSpanOffsets'])
97 |             self.updateAlignment(summarySpanOffsets, scu_cands_offset, docSpanOffsets, doc_cands_offset, documentFile, topic)
98 | 
99 |             # print(row['summarySpanText'])
100 |             # print(row['scuText'])
101 |             # DEBUG_print_k_max_match(summarySpanOffsets, scu_cands_offset, documentFile, 3, self.alignment_database, topic)
102 | 
103 | 
104 | 
105 | 
106 | 
107 |     def save_predictions(self):
108 |         if self.metric_precompute:
109 |             super().save_predictions()
110 | 
111 |         self.alignment_database = pd.DataFrame(self.alignment_database_list,
112 |                                                columns=['Quality', '#1 ID', '#2 ID', '#1 String', '#2 String',
113 |                                                         'database', 'topic',
114 |                                                         'summaryFile', 'scuSentCharIdx', 'scuSentence', 'documentFile',
115 |                                                         'docSentCharIdx',
116 |                                                         'docSentText', 'docSpanOffsets', 'summarySpanOffsets',
117 |                                                         'docSpanText', 'summarySpanText'])
118 |         # self.alignment_database.to_pickle("./span_alignment_database_test_filtered.pkl")
119 |         self.update_positive_labels()
120 |         self.alignment_database.to_csv(os.path.join(self.output_file, 'dev.tsv'), index=False, sep='\t')
121 | 
122 | 
123 | 
124 | 
125 | 
126 |     def updateAlignment(self, summarySpanOffsets, scu_cands_offset, docSpanOffsets, doc_cands_offset, documentFile, topic):
127 |         INTERSECTION_RATIO_THRESH = 0.25
128 | 
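        # Matching sketch (assumption: intersectionOverUnion comes from utils.py,
        # which is not shown here). Offsets are lists of (start, end) character
        # spans, and the score is presumably standard span IoU, roughly:
        #   def intersectionOverUnion(a, b):
        #       inter = sum(max(0, min(e1, e2) - max(s1, s2)) for s1, e1 in a for s2, e2 in b)
        #       union = sum(e - s for s, e in a) + sum(e - s for s, e in b) - inter
        #       return inter / union if union else 0.0
        # A candidate is treated as a match when its IoU with the annotated
        # span exceeds INTERSECTION_RATIO_THRESH.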
129 |         summary_match_arr = np.array(
130 |             [intersectionOverUnion(summarySpanOffsets, offset_str2list(scu_cand_offset)) for scu_cand_offset in scu_cands_offset])
131 | 
132 |         matches_summary_scus = np.argwhere(summary_match_arr > INTERSECTION_RATIO_THRESH)  # [[np.argmax(summary_match_arr)]]
133 | 
134 |         doc_match_arr = np.array(
135 |             [intersectionOverUnion(docSpanOffsets, offset_str2list(doc_cand_offset)) for doc_cand_offset in doc_cands_offset])
136 | 
137 |         matches_doc_spans = np.argwhere(doc_match_arr > INTERSECTION_RATIO_THRESH)  # [[np.argmax(doc_match_arr)]]
138 | 
139 | 
140 | 
141 |         for summ_cand_idx in matches_summary_scus:
142 |             for doc_cand_idx in matches_doc_spans:
143 |                 matched_row = self.alignment_database[(self.alignment_database['documentFile'] == documentFile).values &
144 |                                                       (self.alignment_database['topic'] == topic).values &
145 |                                                       (self.alignment_database['summarySpanOffsets'] ==
146 |                                                        scu_cands_offset[summ_cand_idx[0]]).values &
147 |                                                       (self.alignment_database['docSpanOffsets'] ==
148 |                                                        doc_cands_offset[doc_cand_idx[0]]).values]
149 | 
150 |                 if len(matched_row.index) > 0:
151 |                     assert len(matched_row.index) == 1
152 |                     self.alignment_database.loc[matched_row.index[0], 'Quality'] = 1
153 | 
154 | 
155 | 
156 | 
157 | 
158 | def DEBUG_print_k_max_match(summarySpanOffsets, scu_cands_offset, documentFile, k, alignment_database, topic):
159 |     match_arr = np.array([intersectionOverUnion(summarySpanOffsets, offset_str2list(scu_cand_offset)) for scu_cand_offset in scu_cands_offset])
160 | 
161 |     max_index = match_arr.argsort()[-k:][::-1]
162 | 
163 |     for i in range(k):
164 |         string = alignment_database[(alignment_database['documentFile'] == documentFile).values &
165 |                                     (alignment_database['topic'] == topic).values &
166 |                                     (alignment_database['summarySpanOffsets'] ==
167 |                                      scu_cands_offset[max_index[i]]).values].iloc[0]['#1 String']
168 |         score = match_arr[max_index[i]]
169 | 
170 |         print(score, string)
171 | 
172 | 
173 | 
174 | 
--------------------------------------------------------------------------------
/createSubDatasets.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import numpy as np
3 | import json
4 | import argparse
5 | from os.path import join
6 | 
7 | 
8 | def extract_salienceSpans(alignments):
9 | 
10 |     print('number of alignments: ', len(alignments))
11 | 
12 |     doc_alignments = alignments[['topic', 'documentFile', 'docSentCharIdx',
13 |                                  'docSentText', 'docSpanOffsets', 'docSpanText']]
14 |     doc_alignments.to_csv(join(args.out_dir_path, "salience.csv"), index=False)
15 | 
16 | 
17 |     print('number of salient IUs: ', len(doc_alignments.drop_duplicates()))
18 | 
19 | 
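# Output shape of extract_clusters() below, illustratively (all values are
# invented):
#   {"data": [{"topic": "D0718", "clusters": [
#       {"title": "<summary span text>", "clusterID": 1,
#        "scuSentCharIdx": "950", "scuSentence": "...",
#        "<offsets label>": "950, 1022", "spans": [
#            {"documentFile": "...", "docSentCharIdx": "...", "docSentText": "...",
#             "docSpanOffsets": "...", "docSpanText": "..."}]}]}]}
# Each cluster groups the document spans aligned to one summary span.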
20 | def extract_clusters(alignments):
21 |     if 'scuText' in list(alignments.columns):  # if annotation data
22 |         summSpansLabel = 'scuText'
23 |         summSpansOffsetsLabel = 'scuOffsets'
24 |     elif 'summarySpanOieText' in list(alignments.columns):  # if train data that uses openIE
25 |         summSpansLabel = 'summarySpanOieText'
26 |         summSpansOffsetsLabel = 'summarySpanOieOffsets'
27 |     else:
28 |         summSpansLabel = 'summarySpanText'
29 |         summSpansOffsetsLabel = 'summarySpanOffsets'
30 | 
31 |     clusters_num = 0
32 |     clusters_dict = {'data': []}
33 |     scu2clusterIdx = {}
34 |     alignmentsPerClusterList = []
35 |     for topic in set(alignments['topic'].values):
36 |         df_topic = alignments[alignments['topic'] == topic]
37 |         clusters_dict['data'].append({'topic': str(topic), 'clusters': []})
38 |         for scuText in set(df_topic[summSpansLabel].values):
39 |             clusters_num += 1
40 |             df_topic_scu = df_topic[df_topic[summSpansLabel] == scuText]
41 |             alignmentsPerClusterList.append(len(df_topic_scu))
42 |             scu2clusterIdx[scuText] = clusters_num
43 |             clusters_dict['data'][-1]['clusters'].append({'title': scuText, 'clusterID': clusters_num,
44 |                                                           'scuSentCharIdx': str(df_topic_scu.iloc[0]['scuSentCharIdx']),
45 |                                                           'scuSentence': df_topic_scu.iloc[0]['scuSentence'],
46 |                                                           summSpansOffsetsLabel: df_topic_scu.iloc[0][summSpansOffsetsLabel], 'spans': []})
47 | 
48 |             for index, row in df_topic_scu.iterrows():
49 |                 clusters_dict['data'][-1]['clusters'][-1]['spans'].append({'documentFile': str(row['documentFile']),
50 |                                                                            'docSentCharIdx': str(row['docSentCharIdx']),
51 |                                                                            'docSentText': row['docSentText'],
52 |                                                                            'docSpanOffsets': row['docSpanOffsets'],
53 |                                                                            'docSpanText': row['docSpanText']})
54 | 
55 | 
56 | 
57 |     print('clusters number: ', clusters_num)
58 |     print('Num of alignments per cluster: ', np.mean(alignmentsPerClusterList), '(', np.std(alignmentsPerClusterList), ')')
59 | 
60 |     with open(join(args.out_dir_path, "clustering.json"), "w") as outfile:
61 |         json.dump(clusters_dict, outfile)
62 | 
63 |     return clusters_dict, scu2clusterIdx
64 | 
65 | 
66 | 
67 | 
68 | def extract_textPlanning(alignments, scu2clusterIdx):
69 |     if 'scuText' in list(alignments.columns):  # if annotation data
70 |         summSpansLabel = 'scuText'
71 |     elif 'summarySpanOieText' in list(alignments.columns):  # if train data that uses openIE
72 |         summSpansLabel = 'summarySpanOieText'
73 |     else:
74 |         summSpansLabel = 'summarySpanText'
75 |     sentGeneration_dict = {'data': []}
76 |     sentences_num = 0
77 |     clustersPerSentenceList = []
78 |     for topic in set(alignments['topic'].values):
79 |         df_topic = alignments[alignments['topic'] == topic]
80 |         sentGeneration_dict['data'].append({'topic': str(topic), 'sentences': []})
81 |         scuSentenceList = list(set(zip(df_topic['scuSentence'].values, df_topic['scuSentCharIdx'].values)))
82 |         scuSentenceList.sort(key=lambda x: x[1])
83 |         sent_order = 0
84 |         for scuSentence, scuSentCharIdx in scuSentenceList:
85 |             df_topic_sent = df_topic[df_topic['scuSentence'] == scuSentence]
86 |             cluster_idx_list = [scu2clusterIdx[scuText] for scuText in set(df_topic_sent[summSpansLabel].values)]
87 |             clustersPerSentenceList.append(len(cluster_idx_list))
88 |             sentGeneration_dict['data'][-1]['sentences'].append({'clusters': cluster_idx_list, 'sentence': scuSentence, 'order': sent_order})
89 |             sent_order += 1
90 |             sentences_num += 1
91 | 
92 |     print('Num of sentence generation samples: ', sentences_num)
93 |     print('Num of clusters per sentence: ', np.mean(clustersPerSentenceList), '(', np.std(clustersPerSentenceList), ')')
94 |     with open(join(args.out_dir_path, "generation.json"), "w") as outfile:
95 |         json.dump(sentGeneration_dict, outfile)
96 | 
97 | 
98 | 
99 | 
100 | parser = argparse.ArgumentParser()
101 | parser.add_argument('-alignments_path', type=str, required=True)
102 | parser.add_argument('-out_dir_path', type=str, default='.')
103 | args = parser.parse_args()
104 | 
105 | if __name__ == "__main__":
106 | 
107 | 
108 |     alignments = pd.read_csv(args.alignments_path)
109 |     extract_salienceSpans(alignments)
110 |     clusters_dict, scu2clusterIdx = extract_clusters(alignments)
111 |     extract_textPlanning(alignments, scu2clusterIdx)
112 | 
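# Usage sketch (the paths are hypothetical): given a SuperPAL alignments csv,
# this script writes salience.csv, clustering.json and generation.json into
# -out_dir_path:
#   python createSubDatasets.py -alignments_path ./alignments.csv -out_dir_path ./subdatasets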
--------------------------------------------------------------------------------
/docSum2MRPC_Aligner.py:
--------------------------------------------------------------------------------
1 | from annotation2MRPC_Aligner import annotation2MRPC_Aligner
2 | from utils import *
3 | import pandas as pd
4 | 
5 | 
6 | 
7 | class docSum2MRPC_Aligner(annotation2MRPC_Aligner):
8 |     """
9 |     finds document-summary alignment pair candidates
10 |     and writes them to a .tsv file that fits the huggingface 'transformers' MRPC format (a classification paraphrasing model)
11 | 
12 |     """
13 | 
14 | 
15 |     def read_and_split(self, dataset, sfile):
16 |         ### process the summary
17 |         logging.info(f"Evaluating the Alignment for Summary: {sfile}")
18 |         sfile_basename = os.path.basename(sfile)
19 |         doc_name = sfile_basename.split('.')[0]  # .lower()+'t'
20 | 
21 |         summary = read_generic_file(sfile)
22 |         s_sents = tokenize.sent_tokenize(" ".join(summary))
23 |         self.summ_sents = []
24 |         idx_start = 0
25 |         for sent in s_sents:
26 |             self.summ_sents.append({'summaryFile': sfile_basename, 'scuSentCharIdx': idx_start,
27 |                                     'scuSentence': sent, 'database': dataset, 'topic': doc_name})
28 |             idx_start = idx_start + len(sent) + 1  # 1 for the space character
29 | 
30 |         ## process all the document files
31 |         doc_files = glob.glob(f"{self.data_path}/{doc_name}/*")
32 | 
33 |         logging.info("The following documents were found for this topic:")
34 |         logging.info("\n".join(doc_files))
35 |         self.doc_sents = []
36 |         for df in doc_files:
37 |             doc_id = os.path.basename(df)
38 |             document = read_generic_file(df)
39 |             dsents = tokenize.sent_tokenize(" ".join(document))
40 |             idx_start = 0
41 |             for dsent in dsents:
42 |                 if dsent != "...":  # this is an exception
43 |                     self.doc_sents.append({'documentFile': doc_id, 'docSentCharIdx': idx_start,
44 |                                            'docSentText': dsent})
45 | 
46 |                 idx_start = idx_start + len(dsent) + 1  # 1 for the space character between sentences
47 | 
48 | 
49 | 
50 | 
51 |     def save_predictions(self):
52 |         # if self.metric_precompute:
53 |         #     super().save_predictions()
54 | 
55 |         self.alignment_database = pd.DataFrame(self.alignment_database_list,
56 |                                                columns=['Quality', '#1 ID', '#2 ID', '#1 String', '#2 String',
57 |                                                         'database', 'topic',
58 |                                                         'summaryFile', 'scuSentCharIdx', 'scuSentence', 'documentFile',
59 |                                                         'docSentCharIdx',
60 |                                                         'docSentText', 'docSpanOffsets', 'summarySpanOffsets',
61 |                                                         'docSpanText', 'summarySpanText'])
62 | 
63 |         self.alignment_database.to_csv(os.path.join(self.output_file, 'dev.tsv'), index=False, sep='\t')
64 | 
65 | 
66 | 
67 | 
68 | 
69 | 
--------------------------------------------------------------------------------
/environment_superPAL.yml:
--------------------------------------------------------------------------------
1 | name: superPAL 2 | channels: 3 | - pytorch 4 | - defaults 5 | - conda-forge 6 | dependencies: 7 | - blas=1.0 8 | - ca-certificates=2021.10.26 9 | - certifi=2021.10.8 10 | - cpuonly=2.0 11 | # - cudatoolkit=11.0.3 12 | - freetype=2.10.4 13 | - intel-openmp=2021.4.0 14 | - jpeg=9d 15 | - libpng=1.6.37 16 | - libtiff=4.2.0 17 | - libuv=1.40.0 18 | - libwebp=1.2.0 19 | - lz4-c=1.9.3 20 | - mkl=2021.4.0 21 | - mkl-service=2.4.0 22 | - mkl_fft=1.3.1 23 | - mkl_random=1.2.2 24 | - numpy=1.21.2 25 | - numpy-base=1.21.2 26 | - olefile=0.46 27 | - openssl=1.1.1l 28 | - pillow=8.4.0 29 | - pip=21.2.4 30 | - python=3.7.11 31 | - pytorch-mutex=1.0 32 | - setuptools=58.0.4 33 | - six=1.16.0 34 | - sqlite=3.37.0 35 | - tk=8.6.11 36 | - typing_extensions=3.10.0.2 37 | - wheel=0.37.0 38 | - xz=5.2.5 39 | - zlib=1.2.11 40 | - zstd=1.4.9 41 | - pip: 42 | - aiohttp==3.8.1 43 | - aiosignal==1.2.0 44 | - alabaster==0.7.12 45 | - allennlp==0.9.0 46 | #- allennlp-models==2.8.0 47 | - argcomplete==1.12.3 48 | - argon2-cffi==21.3.0 49 | - argon2-cffi-bindings==21.2.0 50 | - async-timeout==4.0.2 51 | - asynctest==0.13.0 52 | - atomicwrites==1.4.0 53 | - attrs==21.4.0 54 | - babel==2.9.1 55 | - backcall==0.2.0 56 | - backports-csv==1.0.7 57 | - base58==2.1.1 58 | - beautifulsoup4==4.10.0 59 | - bert-score==0.2.2 60 | - bleach==4.1.0 61 | - blis==0.2.4 62 | - boto3==1.20.26 63 | - botocore==1.23.26 64 | - cached-path==0.3.2 65 | - cached-property==1.5.2 66 
| - cachetools==4.2.4 67 | - catalogue==2.0.6 68 | - cffi==1.15.0 69 | - chardet==4.0.0 70 | - charset-normalizer==2.0.9 71 | - cheroot==8.5.2 72 | - cherrypy==18.6.1 73 | - click==8.0.3 74 | - colorama==0.4.4 75 | - configparser==5.2.0 76 | - conllu==1.3.1 77 | - cryptography==36.0.1 78 | - cycler==0.11.0 79 | - cymem==2.0.6 80 | - datasets==1.17.0 81 | - debugpy==1.5.1 82 | - decorator==5.1.0 83 | - defusedxml==0.7.1 84 | - dill==0.3.4 85 | - docker-pycreds==0.4.0 86 | - docopt==0.6.2 87 | - docutils==0.17.1 88 | - editdistance==0.6.0 89 | - entrypoints==0.3 90 | - fairscale==0.4.0 91 | - feedparser==6.0.8 92 | - filelock==3.3.2 93 | - flaky==3.7.0 94 | - flask==2.0.2 95 | - flask-cors==3.0.10 96 | - fonttools==4.28.5 97 | - frozenlist==1.2.0 98 | - fsspec==2021.11.1 99 | - ftfy==6.0.3 100 | - future==0.18.2 101 | - gevent==21.12.0 102 | - gitdb==4.0.9 103 | - gitpython==3.1.24 104 | - google-api-core==2.3.2 105 | - google-auth==2.3.3 106 | - google-cloud-core==2.2.1 107 | - google-cloud-storage==1.43.0 108 | - google-crc32c==1.3.0 109 | - google-resumable-media==2.1.0 110 | - googleapis-common-protos==1.54.0 111 | - greenlet==1.1.2 112 | - h5py==3.6.0 113 | - huggingface-hub==0.1.2 114 | - idna==3.3 115 | - imagesize==1.3.0 116 | - importlib-metadata==4.10.0 117 | - importlib-resources==5.4.0 118 | - iniconfig==1.1.1 119 | - ipykernel==6.6.0 120 | - ipython==7.30.1 121 | - ipython-genutils==0.2.0 122 | - ipywidgets==7.6.5 123 | - iso-639==0.4.5 124 | - itsdangerous==2.0.1 125 | - jaraco-classes==3.2.1 126 | - jaraco-collections==3.4.0 127 | - jaraco-functools==3.5.0 128 | - jaraco-text==3.6.0 129 | - jedi==0.18.1 130 | - jinja2==3.0.3 131 | - jmespath==0.10.0 132 | - joblib==1.1.0 133 | - jsonnet==0.14.0 134 | - jsonpickle==2.0.0 135 | - jsonschema==4.3.2 136 | - jupyter==1.0.0 137 | - jupyter-client==7.1.0 138 | - jupyter-console==6.4.0 139 | - jupyter-core==4.9.1 140 | - jupyterlab-pygments==0.1.2 141 | - jupyterlab-widgets==1.0.2 142 | - kiwisolver==1.3.2 143 | - lmdb==1.2.1 144 | - lxml==4.7.1 145 | - markupsafe==2.0.1 146 | - matplotlib==3.5.1 147 | - matplotlib-inline==0.1.3 148 | - mistune==0.8.4 149 | - more-itertools==8.12.0 150 | - multidict==5.2.0 151 | - multiprocess==0.70.12.2 152 | - munch==2.5.0 153 | - murmurhash==1.0.6 154 | - nbclient==0.5.9 155 | - nbconvert==6.3.0 156 | - nbformat==5.1.3 157 | - nest-asyncio==1.5.4 158 | - nltk==3.6.3 159 | - notebook==6.4.6 160 | - numpydoc==1.1.0 161 | - overrides==3.1.0 162 | - packaging==21.3 163 | - pandas==1.3.5 164 | - pandocfilters==1.5.0 165 | - parsimonious==0.8.1 166 | - parso==0.8.3 167 | - pathtools==0.1.2 168 | - pathy==0.6.1 169 | - patternfork-nosql==3.6 170 | - pdfminer-six==20211012 171 | - pickleshare==0.7.5 172 | - plac==0.9.6 173 | - pluggy==1.0.0 174 | - portend==3.1.0 175 | - preshed==2.0.1 176 | - prometheus-client==0.12.0 177 | - promise==2.3 178 | - prompt-toolkit==3.0.24 179 | - protobuf==3.19.1 180 | - psutil==5.8.0 181 | - py==1.11.0 182 | - py-rouge==1.1 183 | - pyarrow==6.0.1 184 | - pyasn1==0.4.8 185 | - pyasn1-modules==0.2.8 186 | - pycparser==2.21 187 | - pydantic==1.8.2 188 | - pygments==2.10.0 189 | - pyparsing==3.0.6 190 | - pyrsistent==0.18.0 191 | - pytest==6.2.5 192 | - python-dateutil==2.8.2 193 | - python-docx==0.8.11 194 | - pytorch-pretrained-bert==0.6.2 195 | - pytorch-transformers==1.1.0 196 | - pytz==2021.3 197 | - pyyaml==6.0 198 | - pyzmq==22.3.0 199 | - qtconsole==5.2.2 200 | - qtpy==2.0.0 201 | - regex==2021.11.10 202 | - requests==2.26.0 203 | - responses==0.16.0 204 | - 
rouge==1.0.1 205 | - rsa==4.8 206 | - s3transfer==0.5.0 207 | - sacremoses==0.0.46 208 | - scikit-learn==1.0.2 209 | - scipy==1.7.3 210 | - send2trash==1.8.0 211 | - sentencepiece==0.1.96 212 | - sentry-sdk==1.5.1 213 | - sgmllib3k==1.0.0 214 | - shortuuid==1.0.8 215 | - smart-open==5.2.1 216 | - smmap==5.0.0 217 | - snowballstemmer==2.2.0 218 | - soupsieve==2.3.1 219 | - spacy==2.1.9 220 | - spacy-legacy==3.0.8 221 | - sphinx==4.3.2 222 | - sphinxcontrib-applehelp==1.0.2 223 | - sphinxcontrib-devhelp==1.0.2 224 | - sphinxcontrib-htmlhelp==2.0.0 225 | - sphinxcontrib-jsmath==1.0.1 226 | - sphinxcontrib-qthelp==1.0.3 227 | - sphinxcontrib-serializinghtml==1.1.5 228 | - sqlitedict==1.7.0 229 | - sqlparse==0.4.2 230 | - srsly==1.0.5 231 | - subprocess32==3.5.4 232 | - tempora==4.1.2 233 | - tensorboardx==2.4.1 234 | - termcolor==1.1.0 235 | - terminado==0.12.1 236 | - testpath==0.5.0 237 | - thinc==7.0.8 238 | - threadpoolctl==3.0.0 239 | - tokenizers==0.5.2 240 | - toml==0.10.2 241 | - tornado==6.1 242 | - tqdm==4.62.3 243 | - traitlets==5.1.1 244 | - transformers==2.5.1 245 | - typer==0.4.0 246 | - unidecode==1.3.2 247 | - urllib3==1.26.7 248 | - wandb==0.12.9 249 | - wasabi==0.9.0 250 | - wcwidth==0.2.5 251 | - webencodings==0.5.1 252 | - werkzeug==2.0.2 253 | - widgetsnbextension==3.5.2 254 | - word2number==1.1 255 | - xxhash==2.0.2 256 | - yarl==1.7.2 257 | - yaspin==2.1.0 258 | - zc-lockfile==2.0 259 | - zipp==3.6.0 260 | - zope-event==4.5.0 261 | - zope-interface==5.4.0 262 | prefix: C:\Users\Anaconda3\envs\superPAL
--------------------------------------------------------------------------------
/filterContained.py:
--------------------------------------------------------------------------------
1 | 
2 | 
3 | 
4 | def checkPairContained(containedCandidateOffset_list, containOffset_list):
5 |     containedList = []
6 |     for containedCandidate in containedCandidateOffset_list:
7 |         contained = False
8 |         for offset in containOffset_list:
9 |             contained_start, contained_end = containedCandidate
10 |             start, end = offset
11 |             if contained_start >= start and contained_end <= end:
12 |                 contained = True
13 |         containedList.append(contained)
14 | 
15 |     notContained = not all(containedList)  # False only if every candidate span is contained
16 |     return notContained
17 | 
18 | 
19 | 
20 | def checkContained(scuOffsetDict, sentenceText, sentenceOffset=0):
21 |     notContainedDict = {}
22 |     for containedCandidate, containedCandidateOffset_list in scuOffsetDict.items():
23 |         notContainedList = []
24 |         for contain, containOffset_list in scuOffsetDict.items():
25 |             if contain == containedCandidate:
26 |                 continue
27 | 
28 |             # if one of the SCUs is the full sentence, don't filter the other SCUs.
29 |             full_sent_scu = True if containOffset_list[0][0] - sentenceOffset == 0 and\
30 |                 containOffset_list[0][1] - sentenceOffset > 0.95*(len(sentenceText) - 1) else False
31 |             if full_sent_scu:
32 |                 continue
33 |             notContained = checkPairContained(containedCandidateOffset_list, containOffset_list)
34 |             notContainedList.append(notContained)
35 |             # if not notContained:
36 |             #     print(containedCandidate)
37 |             #     print(contain)
38 | 
39 |         notContainedDict[containedCandidate] = all(notContainedList)
40 | 
41 |     return notContainedDict
42 | 
43 | 
--------------------------------------------------------------------------------
/finalAlignmentPred.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import os
3 | import sys
4 | import numpy as np
5 | import pickle
6 | 
7 | 
8 | def calc_final_alignments(csv_path, model_path, preds, preds_prob):
9 | 
10 |     df = pd.read_csv(os.path.join(csv_path, 'dev.tsv'), sep='\t')
11 |     allAlignments = df  # [preds==1]
12 |     allAlignments = allAlignments[['database', 'topic', 'summaryFile', 'scuSentCharIdx', 'scuSentence',
13 |                                    'documentFile', 'docSentCharIdx', 'docSentText', 'docSpanOffsets',
14 |                                    'summarySpanOffsets', 'docSpanText', 'summarySpanText', 'Quality']]
15 |     allAlignments['pred_prob'] = preds_prob  # [preds==1]
16 |     pred_file_name = csv_path[:-1].split('/')[-1] + '_' + model_path[:-1].split('/')[-1] + '_negative' + '.csv'
17 |     pred_out_path = os.path.join(csv_path, pred_file_name)
18 |     allAlignments.to_csv(pred_out_path, index=False)
19 | 
20 | 
21 | 
22 |     positiveAlignments = df[preds==1]
23 |     positiveAlignments = positiveAlignments[['database', 'topic', 'summaryFile', 'scuSentCharIdx', 'scuSentence',
24 |                                              'documentFile', 'docSentCharIdx', 'docSentText', 'docSpanOffsets',
25 |                                              'summarySpanOffsets', 'docSpanText', 'summarySpanText', 'Quality']]
26 |     positiveAlignments['pred_prob'] = preds_prob[preds==1]
27 |     pred_file_name = csv_path[:-1].split('/')[-1] + '_' + model_path[:-1].split('/')[-1] + '.csv'
28 |     pred_out_path = os.path.join(csv_path, pred_file_name)
29 |     positiveAlignments.to_csv(pred_out_path, index=False)
30 | 
31 | 
32 | 
33 | 
34 | def calc_alignment_sim_mat(csv_path, model_path, preds_prob):
35 |     OUT_PATH = os.path.join(csv_path, 'sim_mats')
36 |     if not os.path.exists(OUT_PATH):
37 |         os.makedirs(OUT_PATH)
38 | 
39 |     df = pd.read_csv(os.path.join(csv_path, 'dev.tsv'), sep='\t')
40 |     spans_num = int(np.sqrt(len(df)))
41 |     sim_mat = np.zeros((spans_num, spans_num))
42 |     sim_mat_idx = df[['sim_mat_idx']]
43 |     for sim_idx, prob in zip(sim_mat_idx.values, preds_prob):
44 |         sim_idx = sim_idx[0].split(',')
45 |         sim_mat[int(sim_idx[0]), int(sim_idx[1])] = prob
46 |     pred_file_name = 'SupAligner' + '_' + model_path[:-1].split('/')[-1] + '_' + df['topic'].iloc[0] + '.pickle'
47 |     pred_out_path = os.path.join(OUT_PATH, pred_file_name)
48 | 
49 |     with open(pred_out_path, 'wb') as handle:
50 |         pickle.dump(sim_mat, handle)
51 | 
52 | 
53 | 
54 | 
--------------------------------------------------------------------------------
/main_predict.py:
--------------------------------------------------------------------------------
1 | from utils import *
2 | from docSum2MRPC_Aligner import docSum2MRPC_Aligner
3 | 
4 | import run_glue
5 | 
6 | import contextlib
7 | @contextlib.contextmanager
8 | def redirect_argv(cmd):
9 |     sys._argv = sys.argv[:]
10 |     sys.argv = str(cmd).split()
11 |     yield
12 |     sys.argv = sys._argv
13 | 
14 | 
15 | 
16 | 
17 | 
18 | 
19 | parser = argparse.ArgumentParser()
20 | 
parser.add_argument('-data_path', type=str, required=True) 21 | parser.add_argument('-mode', type=str, default='dev') 22 | parser.add_argument('-log_file', type=str, default='results/dev_log.txt') 23 | parser.add_argument('-output_path', type=str, required=True) 24 | parser.add_argument('-alignment_model_path', type=str, required=True) 25 | parser.add_argument('-database', type=str, default='None') 26 | args = parser.parse_args() 27 | 28 | 29 | 30 | 31 | aligner = docSum2MRPC_Aligner(data_path=args.data_path, mode=args.mode, 32 | log_file=args.log_file, output_file = args.output_path, 33 | database=args.database) 34 | logging.info(f'output_file_path: {args.output_path}') 35 | 36 | summary_files = glob.glob(f"{args.data_path}/summaries/*") 37 | for sfile in summary_files: 38 | print ('Starting with summary {}'.format(sfile)) 39 | aligner.read_and_split(args.database, sfile) 40 | aligner.scu_span_aligner() 41 | aligner.save_predictions() 42 | with redirect_argv('python --model_type roberta --model_name_or_path roberta-large-mnli --task_name MRPC --do_eval' 43 | f' --calc_final_alignments --weight_decay 0.1 --data_dir {args.output_path}' 44 | ' --max_seq_length 128 --per_gpu_train_batch_size 16 --per_gpu_eval_batch_size 16 --learning_rate 2e-6' 45 | ' --logging_steps 500 --num_train_epochs 2.0 --evaluate_during_training --overwrite_cache' 46 | f' --output_dir {args.alignment_model_path}'): 47 | run_glue.main() 48 | 49 | 50 | 51 | 52 | 53 | -------------------------------------------------------------------------------- /manual_datasets/dev_DUC_index_only.csv: -------------------------------------------------------------------------------- 1 | database,topic,summaryFile,scuSentCharIdx,scuOffsets,documentFile,docSentCharIdx,docSpanOffsets,summarySpanOffsets 2 | duc2007,D0718,D0718.M.250.D.A,950,"950, 1022;1039, 1068",APW19990630.0028,1930,"1930, 2033","950, 1021;1039, 1068" 3 | duc2007,D0718,D0718.M.250.D.A,950,"953, 972;1069, 1141",APW19990630.0028,2096,"2109, 2281","1039, 1141" 4 | duc2007,D0718,D0718.M.250.D.A,950,"953, 972;1069, 1141",APW19990701.0333,1197,"1210, 1382","1003, 1141" 5 | duc2007,D0712,D0712.M.250.C.C,449,"449, 543",XIE19960313.0144,1584,"1584, 1623;1658, 1775","449, 543" 6 | duc2007,D0712,D0712.M.250.C.C,449,"449, 474;544, 604",XIE19960313.0144,1584,"1584, 1623;1726, 1775","449, 474;548, 604" 7 | duc2007,D0712,D0712.M.250.C.C,0,"0, 92",APW19990122.0181,1221,"1221, 1302","0, 92" 8 | duc2007,D0712,D0712.M.250.C.C,0,"93, 155",APW19981010.0301,607,"702, 751","101, 155" 9 | duc2007,D0712,D0712.M.250.C.C,0,"0, 92",XIE19960313.0144,712,"712, 871","0, 92" 10 | duc2007,D0712,D0712.M.250.C.C,0,"93, 155",XIE19980330.0017,0,"184, 212","93, 155" 11 | duc2007,D0712,D0712.M.250.C.C,0,"0, 92",APW19981010.0301,1819,"1856, 1895","9, 92" 12 | duc2007,D0712,D0712.M.250.C.C,0,"93, 155",XIE19980330.0017,437,"540, 636","101, 155" 13 | duc2007,D0712,D0712.M.250.C.C,0,"93, 155",XIE19960313.0144,3798,"3920, 3962","101, 155" 14 | duc2007,D0712,D0712.M.250.C.C,0,"93, 155",NYT19980923.0400,1026,"1026, 1121","93, 155" 15 | duc2007,D0712,D0712.M.250.C.C,0,"93, 155",APW19990122.0181,1221,"1365, 1413","101, 155" 16 | duc2007,D0712,D0712.M.250.C.C,0,"0, 92",APW19981010.0301,607,"607, 692","0, 92" 17 | duc2007,D0712,D0712.M.250.C.C,0,"0, 92",XIE19980330.0017,437,"437, 535","0, 92" 18 | duc2007,D0712,D0712.M.250.C.C,0,"0, 92",NYT19980923.0400,1026,"1075, 1086;1126, 1206","0, 92" 19 | duc2007,D0712,D0712.M.250.C.C,0,"0, 92",APW19990122.0181,143,"143, 150;208, 251","0, 92" 20 | 
duc2007,D0719,D0719.M.250.E.D,1673,"1673, 1726",XIE19971230.0110,321,"353, 409","1673, 1717" 21 | duc2004,D30001,D30001.M.100.T.A.html,303,"303, 354",APW19981027.0491,0,"45, 90","303, 354" 22 | duc2004,D30001,D30001.M.100.T.A.html,303,"355, 413",APW19981026.0220,763,"763, 787;909, 958","361, 413" 23 | duc2004,D30001,D30001.M.100.T.A.html,303,"355, 413",APW19981026.0220,959,"959, 1105","361, 413" 24 | duc2004,D30001,D30001.M.100.T.A.html,303,"303, 354",APW19981027.0491,158,"198, 229;249, 396","303, 354" 25 | duc2004,D30001,D30001.M.100.T.A.html,303,"355, 413",APW19981026.0220,0,"21, 123","361, 413" 26 | duc2007,D0712,D0712.M.250.C.C,920,"1019, 1041",XIE19960313.0144,2843,"2843, 2853;2905, 2922","988, 997;1019, 1041" 27 | duc2007,D0712,D0712.M.250.C.C,920,"1019, 1041",XIE19960313.0144,2432,"2532, 2584","1019, 1041" 28 | duc2007,D0712,D0712.M.250.C.C,920,"1019, 1041",XIE19960313.0144,272,"272, 278;311, 332","988, 997;1019, 1041" 29 | duc2007,D0712,D0712.M.250.C.C,920,"1019, 1041",XIE19960313.0144,392,"480, 490;560, 594","988, 997;1019, 1041" 30 | duc2004,D30001,D30001.M.100.T.A.html,414,"430, 476",APW19981118.0276,1873,"1884, 2081","430, 544" 31 | duc2004,D30001,D30001.M.100.T.A.html,414,"430, 476",APW19981113.0251,2256,"2337, 2399","430, 471" 32 | duc2004,D30001,D30001.M.100.T.A.html,414,"430, 476",APW19981124.0267,0,"0, 21;33, 51;121, 177","430, 506" 33 | duc2004,D30001,D30001.M.100.T.A.html,414,"510, 544",APW19981124.0267,571,"617, 681;689, 743","430, 484;507, 520" 34 | duc2004,D30001,D30001.M.100.T.A.html,414,"430, 476",APW19981113.0251,0,"21, 38;78, 194","430, 506" 35 | duc2007,D0719,D0719.M.250.E.D,82,"82, 164",XIE19980204.0022,182,"252, 342","137, 164" 36 | duc2004,D31009,D31009.M.100.T.B.html,157,"157, 225",NYT19981202.0315,1221,"1221, 1262","157, 163;190, 225" 37 | duc2004,D31009,D31009.M.100.T.B.html,157,"157, 225",APW19981219.0504,297,"297, 303;324, 355","157, 163;190, 225" 38 | duc2004,D31009,D31009.M.100.T.B.html,157,"157, 225",APW19981202.0880,0,"0, 53;64, 91","157, 225" 39 | duc2004,D31009,D31009.M.100.T.B.html,157,"157, 225",APW19981203.0322,157,"157, 203","157, 163;190, 225" 40 | duc2004,D31009,D31009.M.100.T.B.html,157,"157, 225",NYT19981202.0315,0,"106, 111;175, 215","157, 163;190, 225" 41 | duc2004,D31009,D31009.M.100.T.B.html,157,"157, 225",APW19981219.0504,837,"837, 849;859, 889","157, 163;190, 225" 42 | duc2004,D31009,D31009.M.100.T.B.html,157,"157, 225",NYT19981202.0315,216,"216, 252;279, 342","157, 225" 43 | duc2007,D0719,D0719.M.250.E.D,791,"871, 940",XIE19970214.0294,243,"243, 451","875, 940" 44 | duc2007,D0719,D0719.M.250.E.D,791,"871, 940",XIE19970214.0294,0,"30, 45;124, 229","791, 797;803, 809;875, 886;920, 940" 45 | duc2007,D0719,D0719.M.250.E.D,791,"791, 870",XIE19970214.0294,0,"30, 118","791, 797;803, 869" 46 | duc2007,D0719,D0719.M.250.E.D,791,"871, 940",XIE19970214.0294,890,"890, 902;977, 1029","791, 797;803, 809;875, 915" 47 | duc2004,D30001,D30001.M.100.T.A.html,248,"248, 302",APW19981022.0269,347,"347, 465;523, 540","248, 302" 48 | duc2004,D30001,D30001.M.100.T.A.html,248,"248, 302",APW19981022.0269,0,"0, 97","248, 302" 49 | duc2007,D0718,D0718.M.250.D.A,1229,"1229, 1325",NYT20000215.0074,142,"164, 296","1246, 1325" 50 | duc2007,D0718,D0718.M.250.D.A,1229,"1229, 1325",NYT20000215.0074,0,"16, 94","1246, 1325" 51 | duc2007,D0718,D0718.M.250.D.A,0,"0, 54;72, 94;112, 124",APW19990701.0333,1031,"1080, 1116","0, 79;112, 125" 52 | duc2007,D0718,D0718.M.250.D.A,0,"0, 54;72, 108",APW19990701.0333,1495,"1495, 1545","0, 54" 53 | duc2007,D0719,D0719.M.250.E.D,0,"0, 
81",APW19981127.0178,1123,"1123, 1233","0, 81" 54 | duc2007,D0719,D0719.M.250.E.D,0,"0, 81",APW19981127.0178,109,"162, 234","0, 51" 55 | duc2007,D0719,D0719.M.250.E.D,0,"0, 81",APW19981208.0421,1270,"1388, 1460","0, 51" 56 | duc2007,D0712,D0712.M.250.C.C,605,"655, 723",APW19990122.0181,1416,"1501, 1563","656, 723" 57 | duc2007,D0712,D0712.M.250.C.C,605,"655, 723",APW19981010.0293,533,"569, 665","656, 723" 58 | duc2007,D0712,D0712.M.250.C.C,605,"605, 654",XIE19960313.0144,1157,"1166, 1185;1276, 1332","613, 654" 59 | duc2007,D0712,D0712.M.250.C.C,605,"655, 723",APW19981010.0301,851,"851, 899;937, 1001","656, 723" 60 | duc2007,D0712,D0712.M.250.C.C,605,"655, 723",APW19981010.0301,1700,"1711, 1818","656, 723" 61 | duc2007,D0712,D0712.M.250.C.C,605,"655, 723",XIE19960313.0144,2432,"2432, 2523","656, 723" 62 | duc2007,D0719,D0719.M.250.E.D,165,"165, 211",XIE19980204.0022,398,"398, 482","165, 211" 63 | duc2004,D31009,D31009.M.100.T.B.html,408,"442, 499",APW19981203.0322,157,"157, 203","429, 499" 64 | duc2004,D31009,D31009.M.100.T.B.html,408,"442, 499",APW19981202.0880,0,"0, 53;56, 99;102, 142","429, 499" 65 | duc2004,D31009,D31009.M.100.T.B.html,408,"442, 499",NYT19981202.0315,0,"66, 92;106, 173;188, 215","408, 499" 66 | duc2004,D31009,D31009.M.100.T.B.html,408,"408, 465",NYT19981202.0315,0,"66, 92;106, 215","408, 499" 67 | duc2007,D0712,D0712.M.250.C.C,1403,"1481, 1518",XIE19960313.0144,219,"219, 271","1481, 1546" 68 | duc2007,D0712,D0712.M.250.C.C,1403,"1519, 1546",XIE19960313.0144,2843,"2843, 2853;2904, 2922","1481, 1546" 69 | duc2007,D0712,D0712.M.250.C.C,1403,"1481, 1518",NYT19980923.0400,1560,"1674, 1724","1481, 1518" 70 | duc2007,D0712,D0712.M.250.C.C,1403,"1519, 1546",XIE19960313.0144,392,"392, 437;464, 473;480, 490;560, 580","1421, 1546" 71 | duc2007,D0712,D0712.M.250.C.C,1403,"1481, 1518",APW19980926.0674,161,"167, 226","1481, 1518" 72 | duc2007,D0712,D0712.M.250.C.C,724,"724, 807",NYT19980923.0400,775,"900, 989","724, 807" 73 | duc2004,D31009,D31009.M.100.T.B.html,294,"347, 407",APW19981219.0504,0,"65, 175","347, 407" 74 | duc2004,D31009,D31009.M.100.T.B.html,294,"294, 346",APW19981219.0504,615,"615, 670","294, 345" 75 | duc2004,D31009,D31009.M.100.T.B.html,294,"347, 407",APW19981219.0504,176,"229, 280","347, 407" 76 | duc2007,D0719,D0719.M.250.E.D,982,"1124, 1147",XIE19970214.0294,243,"316, 387","1010, 1016;1067, 1106;1129, 1147" 77 | duc2007,D0719,D0719.M.250.E.D,982,"1104, 1124",XIE19970214.0294,2033,"2046, 2070;2106, 2161","1010, 1016;1066, 1124" 78 | duc2007,D0719,D0719.M.250.E.D,388,"388, 443",XIE19960628.0198,161,"161, 256;290, 321","388, 443" 79 | duc2007,D0719,D0719.M.250.E.D,388,"388, 443",XIE19960628.0198,1173,"1173, 1246","388, 443" 80 | duc2007,D0719,D0719.M.250.E.D,388,"388, 443",XIE19960628.0198,0,"34, 158","388, 443" 81 | duc2007,D0712,D0712.M.250.C.C,1042,"1042, 1149",APW19981010.0301,1700,"1711, 1818","1042, 1149" 82 | duc2007,D0712,D0712.M.250.C.C,1042,"1042, 1149",APW19981010.0293,533,"544, 665","1042, 1149" 83 | duc2007,D0712,D0712.M.250.C.C,1042,"1042, 1149",APW19990122.0181,1416,"1428, 1450;1501, 1563","1042, 1077;1094, 1149" 84 | duc2007,D0718,D0718.M.250.D.A,497,"497, 539;553, 565;575, 600",NYT19990607.0306,2731,"2731, 2806","497, 600" 85 | duc2007,D0718,D0718.M.250.D.A,497,"578, 629",NYT19990607.0306,2521,"2541, 2544;2551, 2598","578, 630" 86 | duc2007,D0718,D0718.M.250.D.A,497,"497, 551;575, 600",NYT19990607.0306,2731,"2731, 2806","497, 630" 87 | duc2007,D0718,D0718.M.250.D.A,497,"578, 629",NYT19990607.0306,0,"89, 179","578, 630" 88 | 
duc2007,D0718,D0718.M.250.D.A,410,"410, 443",NYT19990112.0449,4097,"4104, 4165","418, 443" 89 | duc2007,D0718,D0718.M.250.D.A,410,"444, 496",NYT19990112.0449,543,"578, 587;621, 674","418, 427;469, 496" 90 | duc2007,D0718,D0718.M.250.D.A,410,"410, 443",NYT19990112.0449,2587,"2587, 2621","418, 443" 91 | duc2007,D0719,D0719.M.250.E.D,444,"444, 484;505, 514",XIE19960628.0198,0,"6, 12;34, 92","471, 515" 92 | duc2007,D0719,D0719.M.250.E.D,444,"471, 484;516, 525;560, 579",APW19981208.0421,837,"884, 950","444, 455;471, 504;516, 579" 93 | duc2007,D0719,D0719.M.250.E.D,444,"489, 504;516, 555",APW19981208.0421,837,"884, 950","471, 504;516, 555" 94 | duc2007,D0719,D0719.M.250.E.D,444,"444, 484;505, 514",XIE19960628.0198,161,"161, 183;257, 289","471, 514" 95 | duc2007,D0719,D0719.M.250.E.D,444,"489, 504;516, 525;560, 579",APW19981208.0421,837,"884, 950","471, 504;516, 579" 96 | duc2007,D0719,D0719.M.250.E.D,444,"444, 469;489, 514",APW19981208.0421,404,"417, 449","471, 514" 97 | duc2007,D0719,D0719.M.250.E.D,444,"489, 504;516, 555",APW19981208.0421,0,"13, 30;116, 208","471, 504;516, 555" 98 | duc2007,D0719,D0719.M.250.E.D,444,"444, 484;505, 514",APW19981208.0421,209,"209, 277","471, 514" 99 | duc2007,D0719,D0719.M.250.E.D,444,"471, 484;516, 555",APW19981208.0421,0,"13, 31;61, 79;116, 208","471, 504;516, 555" 100 | duc2007,D0719,D0719.M.250.E.D,444,"444, 469;489, 514",APW19981208.0421,0,"13, 79","471, 514" 101 | duc2007,D0719,D0719.M.250.E.D,1509,"1567, 1584",APW19981208.0421,1021,"1042, 1188","1509, 1585" 102 | duc2007,D0719,D0719.M.250.E.D,1509,"1567, 1584",APW19981208.0421,642,"642, 696;705, 793","1509, 1537;1570, 1584" 103 | duc2004,D30001,D30001.M.100.T.A.html,87,"159, 226",APW19981026.0220,149,"218, 333","159, 226" 104 | duc2004,D30001,D30001.M.100.T.A.html,87,"159, 226",APW19981022.0269,1904,"2006, 2093","159, 226" 105 | duc2004,D30001,D30001.M.100.T.A.html,87,"159, 202;228, 247",APW19981026.0220,2170,"2170, 2279","159, 202;228, 247" 106 | duc2004,D30001,D30001.M.100.T.A.html,87,"87, 152",APW19981016.0240,690,"690, 829","87, 152" 107 | duc2004,D30001,D30001.M.100.T.A.html,87,"87, 152",APW19981016.0240,0,"0, 24;35, 101","87, 152" 108 | duc2004,D30001,D30001.M.100.T.A.html,87,"159, 202;228, 247",APW19981016.0240,397,"397, 455;604, 656","159, 202;228, 247" 109 | duc2004,D30001,D30001.M.100.T.A.html,87,"159, 202;228, 247",APW19981016.0240,0,"44, 101","159, 202;228, 247" 110 | duc2004,D30001,D30001.M.100.T.A.html,87,"159, 202;228, 247",APW19981022.0269,1904,"2006, 2161","159, 202;228, 247" 111 | duc2004,D31009,D31009.M.100.T.B.html,0,"0, 156",APW19981202.0880,2878,"2905, 2995","0, 30;85, 156" 112 | duc2004,D31009,D31009.M.100.T.B.html,0,"0, 156",APW19981219.0504,465,"465, 488;536, 614","85, 156" 113 | duc2004,D31009,D31009.M.100.T.B.html,0,"0, 156",APW19981219.0504,297,"381, 464","0, 79" 114 | duc2004,D31009,D31009.M.100.T.B.html,0,"0, 156",APW19981111.0309,0,"22, 49;105, 222","0, 30;80, 156" 115 | duc2004,D31009,D31009.M.100.T.B.html,0,"0, 156",APW19981119.0529,410,"410, 548","85, 156" 116 | duc2004,D30005,D30005.M.100.T.A.html,218,"302, 376",NYT19981202.0428,2886,"2896, 2974","231, 239;318, 376" 117 | duc2007,D0718,D0718.M.250.D.A,126,"126, 242",NYT19980605.0223,844,"867, 999","126, 141;150, 312" 118 | duc2007,D0718,D0718.M.250.D.A,126,"126, 242",NYT19980605.0223,1432,"1461, 1556","126, 141;150, 242" 119 | duc2004,D30001,D30001.M.100.T.A.html,545,"597, 657",APW19981118.0276,710,"842, 930","597, 657" 120 | duc2004,D30001,D30001.M.100.T.A.html,545,"597, 657",APW19981118.0276,0,"17, 91","597, 657" 121 | 
duc2007,D0719,D0719.M.250.E.D,322,"322, 387",APW19981208.0421,209,"209, 229;324, 386","322, 387"
122 | duc2004,D30005,D30005.M.100.T.A.html,487,"577, 655",APW19981111.0288,1227,"1243, 1305","577, 655"
123 | duc2004,D30005,D30005.M.100.T.A.html,487,"577, 655",APW19981111.0288,2033,"2035, 2064","626, 655"
124 | duc2004,D30005,D30005.M.100.T.A.html,487,"577, 655",APW19981120.0887,621,"621, 676","626, 655"
125 | duc2007,D0718,D0718.M.250.D.A,1142,"1142, 1228",NYT19991119.0370,1957,"2004, 2084","1142, 1228"
126 | duc2004,D31009,D31009.M.100.T.B.html,226,"226, 293",APW19981219.0504,837,"837, 929;952, 1000","226, 293"
127 | duc2004,D31009,D31009.M.100.T.B.html,226,"226, 293",APW19981209.0696,324,"324, 361;396, 518","226, 293"
128 | duc2004,D30005,D30005.M.100.T.A.html,377,"377, 486",NYT19981202.0428,3016,"3048, 3226","377, 413;440, 486"
129 | duc2004,D30005,D30005.M.100.T.A.html,377,"377, 486",NYT19981202.0428,3747,"3747, 3922","377, 414;440, 486"
130 | duc2004,D30005,D30005.M.100.T.A.html,377,"377, 486",NYT19981202.0428,0,"124, 256","377, 413;440, 486"
131 | duc2004,D30005,D30005.M.100.T.A.html,377,"377, 486",NYT19981202.0428,866,"866, 970","377, 413;440, 486"
132 | duc2004,D30005,D30005.M.100.T.A.html,377,"377, 486",APW19981003.0517,843,"861, 948","377, 413;440, 486"
133 | duc2007,D0712,D0712.M.250.C.C,1150,"1150, 1189",APW19981010.0301,0,"20, 133","1150, 1225"
134 | duc2007,D0712,D0712.M.250.C.C,1150,"1190, 1225",APW19981010.0301,851,"851, 899;1006, 1072","1152, 1225"
135 | duc2004,D30005,D30005.M.100.T.A.html,0,"0, 50",NYT19981201.0444,1400,"1496, 1575","0, 50"
136 | duc2004,D30005,D30005.M.100.T.A.html,0,"0, 50",NYT19981113.0410,2406,"2428, 2432;2439, 2484","6, 50"
137 | duc2004,D30005,D30005.M.100.T.A.html,0,"0, 50",NYT19981113.0410,3007,"3101, 3150","6, 50"
138 | duc2004,D30005,D30005.M.100.T.A.html,0,"0, 50",NYT19981202.0428,0,"57, 97","6, 50"
139 | duc2004,D30005,D30005.M.100.T.A.html,0,"0, 50",APW19981120.0887,0,"33, 80","6, 50"
140 | duc2004,D30005,D30005.M.100.T.A.html,0,"0, 50",NYT19981201.0444,465,"576, 618","6, 50"
141 | duc2004,D30005,D30005.M.100.T.A.html,0,"0, 50",NYT19981202.0428,2657,"2689, 2708","6, 50"
142 | duc2004,D30005,D30005.M.100.T.A.html,0,"0, 50",APW19981111.0288,231,"248, 251;259, 300","6, 50"
143 | duc2004,D30005,D30005.M.100.T.A.html,0,"51, 217",NYT19981113.0410,3415,"3469, 3593","51, 217"
144 | duc2004,D30005,D30005.M.100.T.A.html,0,"0, 50",NYT19981201.0444,3901,"4025, 4052","0, 35"
145 | duc2004,D30005,D30005.M.100.T.A.html,0,"0, 50",NYT19981202.0428,677,"729, 764","6, 50"
146 | duc2004,D30005,D30005.M.100.T.A.html,0,"51, 217",APW19981111.0288,375,"375, 494","51, 180"
147 | duc2004,D30005,D30005.M.100.T.A.html,0,"0, 50",APW19981129.0665,0,"142, 168","6, 50"
148 | duc2004,D30005,D30005.M.100.T.A.html,0,"0, 50",APW19981003.0517,1148,"1255, 1319","6, 50"
149 |
--------------------------------------------------------------------------------
/manual_datasets/restore_alignments.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import argparse
3 | from utils import read_generic_file, offset_str2list
4 | import os
5 | from nltk import sent_tokenize
6 | import numpy as np
7 | 
8 | 
9 | def clean_duc_documents(doc_text_lines):
10 | #removes special tokens from DUC documents
11 | 
12 | readLines = False
13 | textExtracts = []
14 | docId = ''
15 | datetime = ''
16 | for line in doc_text_lines:
17 | lineStripped = line.strip()
18 | if not readLines:
19 | if lineStripped.startswith('<DOCNO>'):
20 | docId = lineStripped[7:-8].strip() # example: APW19980818.0980
21 | elif lineStripped.startswith('<DATE_TIME>'):
22 | datetime = lineStripped[11:-12].strip() # example: 08/18/1998 15:32:00
23 | elif lineStripped.startswith('<TEXT>'):
24 | readLines = True
25 | else:
26 | if lineStripped.startswith('</TEXT>') or lineStripped.startswith('</BODY>'):
27 | break
28 | elif lineStripped.startswith('<P>'):
29 | continue
30 | elif lineStripped.startswith('</P>
'): 31 | textExtracts.append('\n\n') # skip line for new paragraph 32 | continue 33 | else: 34 | textExtracts.append(lineStripped) 35 | allText = ' '.join(textExtracts) 36 | return allText 37 | 38 | def add_doc_sent_in_file_idx(alignments, data_path): 39 | doc_sent_idx = np.zeros(len(alignments), dtype=int) 40 | 41 | alignments['original_idx'] = range(len(alignments)) 42 | for topic_dir in os.listdir(data_path): 43 | print(topic_dir) 44 | if topic_dir == 'summaries': 45 | continue 46 | 47 | topic_files = os.listdir(os.path.join(data_path, topic_dir)) 48 | for file_idx, file in enumerate(topic_files): 49 | alignments_file = alignments[alignments['documentFile']==file] 50 | text = read_generic_file(os.path.join(data_path, topic_dir, file)) 51 | document = " ".join(text) 52 | doc_sents = sent_tokenize(document) 53 | doc_sent_char_idx = 0 54 | for sent_idx, doc_sent in enumerate(doc_sents): 55 | alignments_topic_file_sent_original_idx = (alignments_file['original_idx'][alignments_file['docSentCharIdx'] == doc_sent_char_idx]).values 56 | doc_sent_idx[alignments_topic_file_sent_original_idx] = sent_idx 57 | doc_sent_char_idx += len(doc_sent) + 1 # 1 for space between sents 58 | 59 | alignments['inFile_doc_sentIdx'] = doc_sent_idx 60 | return alignments 61 | 62 | 63 | def add_summ_sent_in_file_idx(alignments, data_path): 64 | doc_sent_idx = np.zeros(len(alignments), dtype=int) 65 | 66 | alignments['original_idx'] = range(len(alignments)) 67 | for summaryFile in os.listdir(data_path): 68 | print(summaryFile) 69 | 70 | alignments_file = alignments[alignments['summaryFile']==summaryFile] 71 | text = read_generic_file(os.path.join(data_path, summaryFile)) 72 | document = " ".join(text) 73 | doc_sents = sent_tokenize(document) 74 | doc_sent_char_idx = 0 75 | for sent_idx, doc_sent in enumerate(doc_sents): 76 | alignments_topic_file_sent_original_idx = (alignments_file['original_idx'][alignments_file['scuSentCharIdx'] == doc_sent_char_idx]).values 77 | doc_sent_idx[alignments_topic_file_sent_original_idx] = sent_idx 78 | doc_sent_char_idx += len(doc_sent) + 1 # 1 for space between sents 79 | 80 | alignments['inFile_summ_sentIdx'] = doc_sent_idx 81 | return alignments 82 | 83 | def add_sentence(text, indx_csv_summaryFile, indx_csv, mode='summ'): 84 | if mode=='summ': 85 | KEY_SENT = 'scuSentence' 86 | KEY_SENT_IDX = 'inFile_summ_sentIdx' 87 | else: 88 | KEY_SENT = 'docSentText' 89 | KEY_SENT_IDX = 'inFile_doc_sentIdx' 90 | 91 | sents = sent_tokenize(text) 92 | idx2sent = {idx: sent for idx, sent in enumerate(sents)} 93 | indx_csv_summaryFile[KEY_SENT] = indx_csv_summaryFile[KEY_SENT_IDX].apply(lambda x: idx2sent[x]) 94 | summaryFile_index_list = indx_csv_summaryFile.index.to_list() 95 | indx_csv[KEY_SENT].loc[summaryFile_index_list] = indx_csv_summaryFile[KEY_SENT].to_list() 96 | 97 | return indx_csv 98 | 99 | 100 | def add_span(text, indx_csv_summaryFile, indx_csv, mode='summ'): 101 | def read_span(text, offset_list): 102 | span = text[offset_list[0][0]: offset_list[0][1]] 103 | for offset_pair in offset_list[1:]: 104 | span += '...'+text[offset_pair[0]: offset_pair[1]] 105 | return span 106 | 107 | 108 | 109 | 110 | if mode == 'summ': 111 | KEY_SPAN = 'summarySpanText' 112 | KEY_OFFSETS = 'summarySpanOffsets' 113 | else: 114 | KEY_SPAN = 'docSpanText' 115 | KEY_OFFSETS = 'docSpanOffsets' 116 | 117 | 118 | 119 | 120 | indx_csv_summaryFile[KEY_OFFSETS] = indx_csv_summaryFile[KEY_OFFSETS].apply(offset_str2list) 121 | indx_csv_summaryFile[KEY_SPAN] = indx_csv_summaryFile[KEY_OFFSETS].apply(lambda x: 
read_span(text, x)) 122 | 123 | summaryFile_index_list = indx_csv_summaryFile.index.to_list() 124 | indx_csv[KEY_SPAN].loc[summaryFile_index_list] = indx_csv_summaryFile[KEY_SPAN].to_list() 125 | 126 | return indx_csv 127 | 128 | 129 | parser = argparse.ArgumentParser() 130 | parser.add_argument('-indx_csv_path', type=str, required=True) 131 | parser.add_argument('-documents_path', type=str, required=True) 132 | parser.add_argument('-summaries_path', type=str, required=True) 133 | 134 | parser.add_argument('-output_file', type=str, required=True) 135 | 136 | args = parser.parse_args() 137 | 138 | if __name__ == "__main__": 139 | indx_csv = pd.read_csv(args.indx_csv_path) 140 | indx_csv = indx_csv[indx_csv['database']=='duc2004'] 141 | 142 | #initialize columns 143 | indx_csv['scuSentence'] = None 144 | indx_csv['summarySpanText'] = None 145 | indx_csv['docSentText'] = None 146 | indx_csv['docSpanText'] = None 147 | 148 | 149 | #handle summaries 150 | indx_csv = add_summ_sent_in_file_idx(indx_csv, args.summaries_path) 151 | for summaryFile in indx_csv['summaryFile'].drop_duplicates(): 152 | indx_csv_summaryFile = indx_csv[indx_csv['summaryFile']==summaryFile] 153 | 154 | 155 | 156 | summary = ' '.join(read_generic_file(os.path.join(args.summaries_path, summaryFile))) 157 | 158 | indx_csv = add_sentence(summary, indx_csv_summaryFile, indx_csv, mode='summ') 159 | indx_csv = add_span(summary, indx_csv_summaryFile, indx_csv, mode='summ') 160 | 161 | 162 | # handle documents 163 | indx_csv = add_doc_sent_in_file_idx(indx_csv, args.documents_path) 164 | for topic in indx_csv['topic'].drop_duplicates(): 165 | indx_csv_topic = indx_csv[indx_csv['topic']==topic] 166 | for documentFile in indx_csv_topic['documentFile'].drop_duplicates(): 167 | indx_csv_documentFile = indx_csv_topic[indx_csv_topic['documentFile'] == documentFile] 168 | 169 | doc = read_generic_file(os.path.join(args.documents_path,topic.lower(), documentFile)) 170 | doc = clean_duc_documents(doc) 171 | 172 | indx_csv = add_sentence(doc, indx_csv_documentFile, indx_csv, mode='doc') 173 | indx_csv = add_span(doc, indx_csv_documentFile, indx_csv, mode='doc') 174 | 175 | 176 | 177 | 178 | 179 | 180 | indx_csv.to_csv(args.output_file) 181 | -------------------------------------------------------------------------------- /manual_datasets/test_DUC_index_only.csv: -------------------------------------------------------------------------------- 1 | database,topic,summaryFile,scuSentCharIdx,scuOffsets,documentFile,docSentCharIdx,docSpanOffsets,summarySpanOffsets 2 | duc2007,D0702,D0702.M.250.A.A,321,"458, 542",XIE19990817.0162,0,"31, 132","458, 542" 3 | duc2007,D0702,D0702.M.250.A.A,321,"458, 542",XIE19990817.0162,189,"225, 364;379, 389","458, 542" 4 | duc2007,D0701,D0701.M.250.A.A,0,"0, 76",APW20000831.0201,317,"324, 386","0, 68" 5 | duc2007,D0701,D0701.M.250.A.A,0,"0, 11;81, 142",NYT19990304.0376,9193,"9193, 9197;9232, 9305","0, 11;81, 142" 6 | duc2007,D0701,D0701.M.250.A.A,0,"0, 76",NYT19990304.0376,1034,"1040, 1091","0, 76" 7 | duc2007,D0701,D0701.M.250.A.A,0,"0, 76",NYT20000828.0403,310,"328, 378","0, 68" 8 | duc2007,D0701,D0701.M.250.A.A,0,"0, 11;81, 142",APW20000915.0189,1263,"1263, 1334","0, 11;81, 142" 9 | duc2007,D0701,D0701.M.250.A.A,749,"749, 849",NYT20000828.0403,5710,"5737, 5778;5806, 5826","749, 849" 10 | duc2007,D0701,D0701.M.250.A.A,749,"749, 849",NYT19990304.0376,1034,"1056, 1083;1098, 1145","749, 849" 11 | duc2007,D0701,D0701.M.250.A.A,749,"749, 849",XIE19980304.0061,1381,"1381, 1412;1468, 1498;1526, 1551","749, 849" 12 | 
duc2007,D0701,D0701.M.250.A.A,749,"749, 849",NYT19990304.0376,5058,"5080, 5182","749, 849" 13 | duc2007,D0707,D0707.M.250.B.B,777,"777, 831",XIE19980101.0003,487,"507, 543","794, 831" 14 | duc2007,D0707,D0707.M.250.B.B,777,"777, 831",APW19981209.0493,402,"460, 519","794, 831" 15 | duc2007,D0707,D0707.M.250.B.B,777,"833, 874",APW19981105.0450,206,"271, 359","794, 817;837, 874" 16 | duc2007,D0707,D0707.M.250.B.B,777,"833, 874",APW19981105.0450,982,"1037, 1047;1093, 1145","794, 816;837, 874" 17 | duc2007,D0707,D0707.M.250.B.B,777,"833, 874",APW19981105.0450,0,"69, 136","794, 816;837, 874" 18 | duc2007,D0707,D0707.M.250.B.B,777,"777, 831",APW19981105.0450,1148,"1161, 1221","794, 831" 19 | duc2004,D30002,D30002.M.100.T.A.html,579,"579, 655",APW19981104.0539,168,"206, 304","579, 655" 20 | duc2004,D30002,D30002.M.100.T.A.html,579,"579, 655",APW19981104.0539,0,"0, 34","579, 655" 21 | duc2004,D30017,D30017.M.100.T.A.html,598,"598, 656",APW19981124.0251,170,"230, 360","598, 656" 22 | duc2004,D30017,D30017.M.100.T.A.html,598,"598, 656",APW19981124.0251,0,"0, 113","598, 656" 23 | duc2004,D30017,D30017.M.100.T.A.html,220,"235, 241;255, 281",NYT19981209.0451,9291,"9300, 9377","220, 281" 24 | duc2004,D30017,D30017.M.100.T.A.html,220,"242, 281",APW19981119.0262,1068,"1068, 1101;1123, 1186","220, 281" 25 | duc2004,D30017,D30017.M.100.T.A.html,220,"242, 281",APW19981110.0240,1536,"1536, 1569;1591, 1654","220, 281" 26 | duc2004,D30017,D30017.M.100.T.A.html,220,"235, 241;255, 281",APW19981119.0262,1068,"1123, 1186","220, 281" 27 | duc2004,D30017,D30017.M.100.T.A.html,220,"297, 342",NYT19981209.0451,10422,"10422, 10454;10494, 10603","270, 342" 28 | duc2004,D30017,D30017.M.100.T.A.html,220,"220, 234;255, 281",APW19981221.0189,948,"948, 1112","220, 281" 29 | duc2004,D30017,D30017.M.100.T.A.html,220,"220, 234;255, 281",APW19981118.0898,185,"185, 308","220, 281" 30 | duc2004,D30017,D30017.M.100.T.A.html,220,"297, 342",NYT19981114.0099,1746,"1746, 1887","270, 342" 31 | duc2004,D30017,D30017.M.100.T.A.html,220,"235, 241;255, 281",APW19981118.0898,185,"185, 336","220, 281" 32 | duc2007,D0707,D0707.M.250.B.B,875,"875, 965",APW19981209.0493,402,"460, 565","875, 963" 33 | duc2007,D0707,D0707.M.250.B.B,875,"875, 965",XIE19980101.0003,487,"507, 638","875, 963" 34 | duc2007,D0707,D0707.M.250.B.B,875,"965, 1066",APW19981105.0450,1556,"1556, 1633","875, 928;965, 1002" 35 | duc2007,D0707,D0707.M.250.B.B,875,"875, 965",APW19981105.0450,1378,"1378, 1461","875, 963" 36 | duc2007,D0707,D0707.M.250.B.B,875,"875, 965",APW19981105.0450,1222,"1224, 1265;1324, 1345","875, 963" 37 | duc2004,D30003,D30003.M.100.T.A.html,223,"229, 271",NYT19981026.0292,788,"822, 961","229, 271" 38 | duc2004,D30003,D30003.M.100.T.A.html,223,"229, 271",APW19981019.0098,0,"0, 31","229, 271" 39 | duc2004,D30017,D30017.M.100.T.A.html,343,"384, 458",NYT19981209.0451,5461,"5489, 5530","408, 458" 40 | duc2004,D30017,D30017.M.100.T.A.html,343,"343, 380",NYT19981209.0451,4454,"4636, 4748","343, 380" 41 | duc2004,D30017,D30017.M.100.T.A.html,343,"384, 458",NYT19981209.0451,3094,"3094, 3159","385, 458" 42 | duc2004,D30017,D30017.M.100.T.A.html,343,"384, 458",APW19981118.0898,0,"91, 173","385, 458" 43 | duc2004,D30017,D30017.M.100.T.A.html,343,"384, 458",APW19981124.0251,1075,"1090, 1173","385, 458" 44 | duc2004,D30017,D30017.M.100.T.A.html,343,"384, 458",NYT19981209.0451,536,"687, 768","385, 458" 45 | duc2004,D30017,D30017.M.100.T.A.html,343,"384, 458",APW19981110.0240,679,"896, 1006","385, 458" 46 | duc2004,D30017,D30017.M.100.T.A.html,343,"384, 
458",NYT19981114.0099,2233,"2325, 2391;2474, 2529","385, 458" 47 | duc2004,D30003,D30003.M.100.T.A.html,539,"539, 615",APW19981024.0192,835,"853, 878;936, 965","579, 615" 48 | duc2004,D30003,D30003.M.100.T.A.html,539,"539, 615",APW19981023.1166,1641,"1641, 1763","539, 615" 49 | duc2004,D30003,D30003.M.100.T.A.html,539,"539, 615",APW19981024.0192,0,"0, 117","539, 615" 50 | duc2007,D0701,D0701.M.250.A.A,850,"850, 922",NYT19990304.0376,3975,"3975, 4017","850, 906" 51 | duc2007,D0701,D0701.M.250.A.A,850,"850, 922",XIE19980304.0061,1381,"1381, 1412;1468, 1551","850, 906" 52 | duc2007,D0707,D0707.M.250.B.B,0,"0, 60",APW19981105.0450,0,"65, 136","8, 60" 53 | duc2007,D0707,D0707.M.250.B.B,61,"61, 127",XIE19980101.0003,487,"507, 591","75, 127" 54 | duc2007,D0707,D0707.M.250.B.B,61,"61, 127",APW19981209.0493,402,"460, 519","75, 127" 55 | duc2007,D0707,D0707.M.250.B.B,61,"61, 127",APW19981105.0450,1148,"1182, 1221","75, 127" 56 | duc2004,D30003,D30003.M.100.T.A.html,380,"380, 437",APW19981022.1132,0,"0, 97","380, 437" 57 | duc2004,D30003,D30003.M.100.T.A.html,380,"380, 437",APW19981024.0192,1191,"1191, 1285","380, 437" 58 | duc2007,D0701,D0701.M.250.A.A,964,"1117, 1188",NYT19990304.0376,4386,"4417, 4480","1122, 1188" 59 | duc2007,D0701,D0701.M.250.A.A,964,"1253, 1324",NYT20000828.0403,6569,"6569, 6601;6651, 6686;6722, 6741","1253, 1324" 60 | duc2007,D0701,D0701.M.250.A.A,964,"1189, 1252",NYT20000828.0403,6281,"6281, 6340;6401, 6424","1189, 1252" 61 | duc2007,D0701,D0701.M.250.A.A,964,"1117, 1188",NYT19990304.0376,2412,"2431, 2472;2486, 2516","1122, 1188" 62 | duc2007,D0701,D0701.M.250.A.A,964,"1024, 1076",NYT19990304.0376,1034,"1093, 1197","964, 1076" 63 | duc2007,D0701,D0701.M.250.A.A,964,"1189, 1252",NYT19990304.0376,2412,"2431, 2472;2526, 2563","1194, 1252" 64 | duc2007,D0701,D0701.M.250.A.A,964,"1117, 1188",NYT20000828.0403,5979,"5979, 6023;6093, 6150","1117, 1188" 65 | duc2007,D0701,D0701.M.250.A.A,964,"1024, 1076",NYT19990304.0376,5058,"5080, 5182","978, 1075" 66 | duc2007,D0701,D0701.M.250.A.A,964,"1117, 1188",NYT19990304.0376,4786,"4811, 4822;4900, 4931","1122, 1188" 67 | duc2007,D0701,D0701.M.250.A.A,143,"273, 345",NYT20000828.0403,5710,"5737, 5778;5811, 5826","143, 167;278, 345" 68 | duc2007,D0701,D0701.M.250.A.A,143,"165, 193",NYT19980715.0137,5347,"5347, 5351;5380, 5435","143, 193" 69 | duc2007,D0701,D0701.M.250.A.A,143,"165, 193",XIE19980304.0061,1552,"1566, 1637","143, 193" 70 | duc2007,D0701,D0701.M.250.A.A,143,"254, 274",NYT19990304.0376,11003,"11003, 11066","143, 167;255, 273" 71 | duc2007,D0701,D0701.M.250.A.A,143,"273, 345",NYT20000828.0403,0,"23, 92;121, 134","143, 167;278, 345" 72 | duc2007,D0701,D0701.M.250.A.A,143,"273, 345",XIE19980304.0061,1381,"1381, 1412;1468, 1499;1526, 1551","143, 167;278, 345" 73 | duc2004,D30002,D30002.M.100.T.A.html,0,"0, 96",APW19981028.1120,0,"0, 15;82, 126","53, 96" 74 | duc2004,D30002,D30002.M.100.T.A.html,0,"0, 96",APW19981027.0241,3642,"3642, 3680","0, 35" 75 | duc2004,D30002,D30002.M.100.T.A.html,0,"0, 96",APW19981027.0241,471,"471, 556","0, 26;53, 96" 76 | duc2007,D0707,D0707.M.250.B.B,1421,"1497, 1579",NYT20000114.0010,1295,"1295, 1419","1497, 1579" 77 | duc2007,D0707,D0707.M.250.B.B,1421,"1421, 1495",NYT20000114.0010,0,"62, 162","1437, 1495" 78 | duc2007,D0702,D0702.M.250.A.A,908,"951, 1090",APW20000616.0049,567,"586, 677","952, 1090" 79 | duc2007,D0701,D0701.M.250.A.A,640,"640, 748",NYT19990304.0376,11286,"11286, 11378","640, 748" 80 | duc2004,D30002,D30002.M.100.T.A.html,279,"279, 319",APW19981101.0843,472,"472, 501;532, 567","279, 319" 81 
| duc2007,D0707,D0707.M.250.B.B,318,"335, 356;383, 417",APW20000923.0107,781,"781, 841","335, 417" 82 | duc2007,D0707,D0707.M.250.B.B,318,"318, 381",XIE19980101.0003,487,"507, 591","335, 381" 83 | duc2007,D0707,D0707.M.250.B.B,318,"335, 356;383, 417",APW20000923.0107,630,"642, 712","390, 417" 84 | duc2007,D0707,D0707.M.250.B.B,318,"335, 356;383, 417",APW19981105.0450,1222,"1222, 1377","383, 417" 85 | duc2007,D0707,D0707.M.250.B.B,318,"318, 381",APW19981105.0450,457,"457, 546","318, 381" 86 | duc2007,D0707,D0707.M.250.B.B,318,"335, 356;383, 398;422, 447",XIE19980101.0003,487,"507, 619;643, 669","335, 398;422, 449" 87 | duc2007,D0707,D0707.M.250.B.B,318,"335, 356;383, 398;422, 447",APW19981209.0493,402,"460, 539;570, 603","335, 398;422, 449" 88 | duc2007,D0707,D0707.M.250.B.B,318,"335, 356;383, 417",XIE19980101.0003,487,"507, 638","335, 417" 89 | duc2007,D0707,D0707.M.250.B.B,318,"335, 356;383, 417",APW19981209.0493,402,"460, 565","335, 417" 90 | duc2004,D30003,D30003.M.100.T.A.html,180,"180, 222",APW19981022.1132,996,"1047, 1092","180, 222" 91 | duc2004,D30003,D30003.M.100.T.A.html,180,"180, 222",APW19981019.0098,1818,"1818, 1870","180, 222" 92 | duc2004,D30003,D30003.M.100.T.A.html,180,"180, 222",APW19981020.0241,1277,"1277, 1322","180, 222" 93 | duc2004,D30003,D30003.M.100.T.A.html,180,"180, 222",APW19981022.1132,0,"0, 77","180, 222" 94 | duc2004,D30003,D30003.M.100.T.A.html,180,"180, 222",APW19981019.0098,174,"174, 226","180, 222" 95 | duc2004,D30017,D30017.M.100.T.A.html,102,"102, 219",APW19981110.0240,1967,"1967, 1996;2053, 2154","102, 219" 96 | duc2004,D30017,D30017.M.100.T.A.html,102,"102, 219",APW19981221.0189,948,"1017, 1138","110, 219" 97 | duc2007,D0701,D0701.M.250.A.A,486,"525, 598",NYT19990304.0376,11003,"11003, 11103","486, 598" 98 | duc2007,D0701,D0701.M.250.A.A,486,"604, 639",NYT19990304.0376,11104,"11104, 11114;11177, 11212","494, 524;604, 639" 99 | duc2007,D0701,D0701.M.250.A.A,486,"525, 598",NYT19990304.0376,11213,"11213, 11259","494, 598" 100 | duc2007,D0702,D0702.M.250.A.A,632,"769, 788",APW19990704.0094,165,"165, 186;254, 278","703, 787" 101 | duc2007,D0702,D0702.M.250.A.A,632,"769, 788",APW20000616.0049,0,"75, 193","703, 787" 102 | duc2007,D0702,D0702.M.250.A.A,632,"700, 768",NYT19991123.0372,1901,"1958, 2116","703, 775;791, 859" 103 | duc2007,D0702,D0702.M.250.A.A,632,"700, 768",APW20000616.0049,264,"264, 399","703, 787" 104 | duc2007,D0702,D0702.M.250.A.A,632,"769, 788",APW20000616.0049,1798,"1798, 1904","703, 787" 105 | duc2007,D0707,D0707.M.250.B.B,1254,"1379, 1419",APW19981105.0450,1378,"1378, 1461","1379, 1420" 106 | duc2007,D0707,D0707.M.250.B.B,1254,"1379, 1419",APW19981209.0493,402,"520, 565","1379, 1420" 107 | duc2007,D0707,D0707.M.250.B.B,1254,"1379, 1419",APW20000923.0107,630,"642, 712","1386, 1420" 108 | duc2007,D0707,D0707.M.250.B.B,1254,"1379, 1419",APW20000923.0107,781,"781, 841","1379, 1420" 109 | duc2007,D0707,D0707.M.250.B.B,1254,"1305, 1337",APW19981209.0493,402,"460, 519","1305, 1337" 110 | duc2007,D0707,D0707.M.250.B.B,1254,"1379, 1419",XIE19961026.0044,654,"654, 718","1379, 1420" 111 | duc2007,D0707,D0707.M.250.B.B,1254,"1379, 1419",APW19981105.0450,1222,"1222, 1265;1328, 1377","1379, 1420" 112 | duc2007,D0707,D0707.M.250.B.B,1254,"1271, 1303",APW19981105.0450,0,"65, 136","1271, 1303" 113 | duc2007,D0707,D0707.M.250.B.B,1254,"1379, 1419",XIE19980101.0003,487,"593, 638","1379, 1420" 114 | duc2004,D30017,D30017.M.100.T.A.html,459,"459, 597",NYT19981209.0451,1092,"1189, 1294","485, 597" 115 | duc2004,D30017,D30017.M.100.T.A.html,459,"459, 
597",NYT19981209.0451,536,"661, 768","459, 597" 116 | duc2004,D30017,D30017.M.100.T.A.html,459,"459, 597",NYT19981114.0099,2233,"2368, 2529","459, 597" 117 | duc2004,D30017,D30017.M.100.T.A.html,459,"459, 597",NYT19981209.0451,3094,"3094, 3197","459, 597" 118 | duc2004,D30017,D30017.M.100.T.A.html,459,"459, 597",APW19981124.0251,1075,"1123, 1220","459, 597" 119 | duc2004,D30017,D30017.M.100.T.A.html,459,"459, 597",APW19981124.0251,170,"207, 360","485, 597" 120 | duc2004,D30017,D30017.M.100.T.A.html,459,"459, 597",NYT19981209.0451,0,"0, 288","459, 597" 121 | duc2004,D30017,D30017.M.100.T.A.html,459,"459, 597",NYT19981209.0451,769,"769, 795;869, 1091","459, 597" 122 | duc2004,D30017,D30017.M.100.T.A.html,459,"459, 597",APW19981110.0240,679,"679, 769;896, 1006","459, 597" 123 | duc2004,D30017,D30017.M.100.T.A.html,459,"459, 597",APW19981124.0251,921,"921, 1074","459, 597" 124 | duc2004,D30003,D30003.M.100.T.A.html,0,"77, 144",APW19981019.0098,439,"534, 665","77, 144" 125 | duc2004,D30003,D30003.M.100.T.A.html,0,"0, 94",APW19981024.0192,550,"550, 616","0, 94" 126 | duc2004,D30003,D30003.M.100.T.A.html,0,"77, 144",APW19981019.0098,4507,"4507, 4525;4536, 4639","77, 144" 127 | duc2004,D30003,D30003.M.100.T.A.html,0,"77, 144",APW19981024.0192,689,"689, 725;779, 834","77, 144" 128 | duc2004,D30003,D30003.M.100.T.A.html,0,"77, 144",NYT19981026.0292,1533,"1600, 1734","77, 144" 129 | duc2004,D30003,D30003.M.100.T.A.html,0,"0, 94",APW19981019.0098,439,"439, 471;519, 555","17, 94" 130 | duc2004,D30003,D30003.M.100.T.A.html,0,"0, 94",APW19981023.1166,331,"331, 426","0, 94" 131 | duc2004,D30003,D30003.M.100.T.A.html,0,"0, 94",APW19981022.1132,363,"363, 446","0, 94" 132 | duc2004,D30003,D30003.M.100.T.A.html,0,"0, 94",APW19981020.0241,703,"719, 782","17, 94" 133 | duc2004,D30003,D30003.M.100.T.A.html,0,"0, 94",NYT19981026.0292,3839,"3868, 3938","0, 73" 134 | duc2004,D30002,D30002.M.100.T.A.html,428,"428, 577",APW19981103.0526,0,"25, 43;55, 125","428, 578" 135 | duc2004,D30002,D30002.M.100.T.A.html,428,"428, 577",APW19981106.0869,970,"970, 995;1092, 1159","428, 578" 136 | duc2004,D30002,D30002.M.100.T.A.html,428,"428, 577",APW19981027.0241,2241,"2241, 2336","479, 485;530, 535;550, 564" 137 | duc2007,D0702,D0702.M.250.A.A,0,"0, 113",XIE19990817.0162,0,"31, 186","0, 6;48, 113" 138 | duc2007,D0702,D0702.M.250.A.A,0,"0, 113",APW20000616.0049,889,"889, 902;959, 1029","0, 113" 139 | duc2004,D30017,D30017.M.100.T.A.html,0,"0, 101",APW19981119.0262,1068,"1068, 1079;1102, 1186","0, 101" 140 | duc2004,D30017,D30017.M.100.T.A.html,0,"0, 101",APW19981110.0240,192,"268, 311","0, 92" 141 | duc2004,D30017,D30017.M.100.T.A.html,0,"0, 101",NYT19981209.0451,0,"0, 45;101, 152","0, 101" 142 | duc2004,D30017,D30017.M.100.T.A.html,0,"0, 101",APW19981118.0898,185,"271, 336","0, 92" 143 | duc2004,D30017,D30017.M.100.T.A.html,0,"0, 101",NYT19981114.0099,0,"161, 237","0, 92" 144 | duc2004,D30017,D30017.M.100.T.A.html,0,"0, 101",APW19981110.0240,1536,"1536, 1547;1575, 1654","0, 101" 145 | duc2004,D30017,D30017.M.100.T.A.html,0,"0, 101",APW19981110.0240,0,"0, 67","0, 92" 146 | duc2004,D30017,D30017.M.100.T.A.html,0,"0, 101",APW19981118.0898,0,"0, 84","0, 92" 147 | duc2004,D30017,D30017.M.100.T.A.html,0,"0, 101",APW19981221.0189,948,"948, 1015","0, 92" 148 | duc2007,D0707,D0707.M.250.B.B,450,"450, 653",APW19981105.0450,1222,"1222, 1377","462, 527;598, 653" 149 | duc2007,D0707,D0707.M.250.B.B,450,"450, 653",APW19981209.0493,402,"402, 603","450, 527;567, 653" 150 | duc2007,D0707,D0707.M.250.B.B,450,"450, 653",XIE19980101.0003,487,"487, 
670","450, 527;567, 653" 151 | duc2007,D0707,D0707.M.250.B.B,1067,"1097, 1150",APW19981105.0450,607,"677, 759","1085, 1150" 152 | -------------------------------------------------------------------------------- /run_glue.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ Finetuning the library models for sequence classification on GLUE (Bert, XLM, XLNet, RoBERTa, Albert, XLM-RoBERTa).""" 17 | 18 | 19 | import argparse 20 | import glob 21 | import json 22 | import logging 23 | import os 24 | import random 25 | 26 | import numpy as np 27 | import torch 28 | from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset 29 | from torch.utils.data.distributed import DistributedSampler 30 | from tqdm import tqdm, trange 31 | from scipy.special import softmax 32 | 33 | from transformers import ( 34 | WEIGHTS_NAME, 35 | AdamW, 36 | AlbertConfig, 37 | AlbertForSequenceClassification, 38 | AlbertTokenizer, 39 | BertConfig, 40 | BertForSequenceClassification, 41 | BertTokenizer, 42 | DistilBertConfig, 43 | DistilBertForSequenceClassification, 44 | DistilBertTokenizer, 45 | FlaubertConfig, 46 | FlaubertForSequenceClassification, 47 | FlaubertTokenizer, 48 | RobertaConfig, 49 | RobertaForSequenceClassification, 50 | RobertaTokenizer, 51 | XLMConfig, 52 | XLMForSequenceClassification, 53 | XLMRobertaConfig, 54 | XLMRobertaForSequenceClassification, 55 | XLMRobertaTokenizer, 56 | XLMTokenizer, 57 | XLNetConfig, 58 | XLNetForSequenceClassification, 59 | XLNetTokenizer, 60 | get_linear_schedule_with_warmup, 61 | ) 62 | from transformers import glue_compute_metrics as compute_metrics 63 | from transformers import glue_convert_examples_to_features as convert_examples_to_features 64 | from transformers import glue_output_modes as output_modes 65 | from transformers import glue_processors as processors 66 | import finalAlignmentPred 67 | 68 | 69 | from transformers.modeling_roberta import RobertaClassificationHead 70 | 71 | 72 | try: 73 | from torch.utils.tensorboard import SummaryWriter 74 | except ImportError: 75 | from tensorboardX import SummaryWriter 76 | 77 | 78 | logger = logging.getLogger(__name__) 79 | 80 | ALL_MODELS = sum( 81 | ( 82 | tuple(conf.pretrained_config_archive_map.keys()) 83 | for conf in ( 84 | BertConfig, 85 | XLNetConfig, 86 | XLMConfig, 87 | RobertaConfig, 88 | DistilBertConfig, 89 | AlbertConfig, 90 | XLMRobertaConfig, 91 | FlaubertConfig, 92 | ) 93 | ), 94 | (), 95 | ) 96 | 97 | MODEL_CLASSES = { 98 | "bert": (BertConfig, BertForSequenceClassification, BertTokenizer), 99 | "xlnet": (XLNetConfig, XLNetForSequenceClassification, XLNetTokenizer), 100 | "xlm": (XLMConfig, XLMForSequenceClassification, XLMTokenizer), 101 | "roberta": (RobertaConfig, 
RobertaForSequenceClassification, RobertaTokenizer), 102 | "distilbert": (DistilBertConfig, DistilBertForSequenceClassification, DistilBertTokenizer), 103 | "albert": (AlbertConfig, AlbertForSequenceClassification, AlbertTokenizer), 104 | "xlmroberta": (XLMRobertaConfig, XLMRobertaForSequenceClassification, XLMRobertaTokenizer), 105 | "flaubert": (FlaubertConfig, FlaubertForSequenceClassification, FlaubertTokenizer), 106 | } 107 | 108 | 109 | def set_seed(args): 110 | random.seed(args.seed) 111 | np.random.seed(args.seed) 112 | torch.manual_seed(args.seed) 113 | if args.n_gpu > 0: 114 | torch.cuda.manual_seed_all(args.seed) 115 | 116 | 117 | def train(args, train_dataset, model, tokenizer): 118 | """ Train the model """ 119 | if args.local_rank in [-1, 0]: 120 | tb_writer = SummaryWriter() 121 | 122 | args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu) 123 | train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset) 124 | train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size) 125 | 126 | if args.max_steps > 0: 127 | t_total = args.max_steps 128 | args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1 129 | else: 130 | t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs 131 | 132 | # Prepare optimizer and schedule (linear warmup and decay) 133 | no_decay = ["bias", "LayerNorm.weight"] 134 | optimizer_grouped_parameters = [ 135 | { 136 | "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 137 | "weight_decay": args.weight_decay, 138 | }, 139 | {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0}, 140 | ] 141 | 142 | optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon) 143 | scheduler = get_linear_schedule_with_warmup( 144 | optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total 145 | ) 146 | 147 | # Check if saved optimizer or scheduler states exist 148 | if os.path.isfile(os.path.join(args.model_name_or_path, "optimizer.pt")) and os.path.isfile( 149 | os.path.join(args.model_name_or_path, "scheduler.pt") 150 | ): 151 | # Load in optimizer and scheduler states 152 | optimizer.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "optimizer.pt"))) 153 | scheduler.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "scheduler.pt"))) 154 | 155 | if args.fp16: 156 | try: 157 | from apex import amp 158 | except ImportError: 159 | raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") 160 | model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level) 161 | 162 | # multi-gpu training (should be after apex fp16 initialization) 163 | if args.n_gpu > 1: 164 | model = torch.nn.DataParallel(model) 165 | 166 | # Distributed training (should be after apex fp16 initialization) 167 | if args.local_rank != -1: 168 | model = torch.nn.parallel.DistributedDataParallel( 169 | model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True, 170 | ) 171 | 172 | # Train! 
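# A worked example of the "Total train batch size" computed below (hypothetical
# settings, not the script's defaults): per_gpu_train_batch_size=8 and n_gpu=2
# give args.train_batch_size = 8 * 2 = 16; with gradient_accumulation_steps=4
# and a single process (local_rank == -1, world size 1), one optimizer step
# therefore covers 16 * 4 * 1 = 64 examples.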
173 | logger.info("***** Running training *****") 174 | logger.info(" Num examples = %d", len(train_dataset)) 175 | logger.info(" Num Epochs = %d", args.num_train_epochs) 176 | logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size) 177 | logger.info( 178 | " Total train batch size (w. parallel, distributed & accumulation) = %d", 179 | args.train_batch_size 180 | * args.gradient_accumulation_steps 181 | * (torch.distributed.get_world_size() if args.local_rank != -1 else 1), 182 | ) 183 | logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps) 184 | logger.info(" Total optimization steps = %d", t_total) 185 | 186 | global_step = 0 187 | epochs_trained = 0 188 | steps_trained_in_current_epoch = 0 189 | # Check if continuing training from a checkpoint 190 | if os.path.exists(args.model_name_or_path): 191 | # set global_step to gobal_step of last saved checkpoint from model path 192 | global_step = int(args.model_name_or_path.split("-")[-1].split("/")[0]) 193 | epochs_trained = global_step // (len(train_dataloader) // args.gradient_accumulation_steps) 194 | steps_trained_in_current_epoch = global_step % (len(train_dataloader) // args.gradient_accumulation_steps) 195 | 196 | logger.info(" Continuing training from checkpoint, will skip to saved global_step") 197 | logger.info(" Continuing training from epoch %d", epochs_trained) 198 | logger.info(" Continuing training from global step %d", global_step) 199 | logger.info(" Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch) 200 | 201 | tr_loss, logging_loss = 0.0, 0.0 202 | model.zero_grad() 203 | train_iterator = trange( 204 | epochs_trained, int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0], 205 | ) 206 | set_seed(args) # Added here for reproductibility 207 | for _ in train_iterator: 208 | epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0]) 209 | for step, batch in enumerate(epoch_iterator): 210 | 211 | # Skip past any already trained steps if resuming training 212 | if steps_trained_in_current_epoch > 0: 213 | steps_trained_in_current_epoch -= 1 214 | continue 215 | 216 | model.train() 217 | batch = tuple(t.to(args.device) for t in batch) 218 | inputs = {"input_ids": batch[0], "attention_mask": batch[1], "labels": batch[3]} 219 | if args.model_type != "distilbert": 220 | inputs["token_type_ids"] = ( 221 | batch[2] if args.model_type in ["bert", "xlnet", "albert"] else None 222 | ) # XLM, DistilBERT, RoBERTa, and XLM-RoBERTa don't use segment_ids 223 | outputs = model(**inputs) 224 | loss = outputs[0] # model outputs are always tuple in transformers (see doc) 225 | 226 | if args.n_gpu > 1: 227 | loss = loss.mean() # mean() to average on multi-gpu parallel training 228 | if args.gradient_accumulation_steps > 1: 229 | loss = loss / args.gradient_accumulation_steps 230 | 231 | if args.fp16: 232 | with amp.scale_loss(loss, optimizer) as scaled_loss: 233 | scaled_loss.backward() 234 | else: 235 | loss.backward() 236 | 237 | tr_loss += loss.item() 238 | if (step + 1) % args.gradient_accumulation_steps == 0: 239 | if args.fp16: 240 | torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm) 241 | else: 242 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) 243 | 244 | optimizer.step() 245 | scheduler.step() # Update learning rate schedule 246 | model.zero_grad() 247 | global_step += 1 248 | 249 | if args.local_rank in [-1, 0] and 
args.logging_steps > 0 and global_step % args.logging_steps == 0: 250 | logs = {} 251 | if ( 252 | args.local_rank == -1 and args.evaluate_during_training 253 | ): # Only evaluate when single GPU otherwise metrics may not average well 254 | results = evaluate(args, model, tokenizer) 255 | for key, value in results.items(): 256 | eval_key = "eval_{}".format(key) 257 | logs[eval_key] = value 258 | 259 | loss_scalar = (tr_loss - logging_loss) / args.logging_steps 260 | learning_rate_scalar = scheduler.get_lr()[0] 261 | logs["learning_rate"] = learning_rate_scalar 262 | logs["loss"] = loss_scalar 263 | logging_loss = tr_loss 264 | 265 | for key, value in logs.items(): 266 | tb_writer.add_scalar(key, value, global_step) 267 | print(json.dumps({**logs, **{"step": global_step}})) 268 | 269 | if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0: 270 | # Save model checkpoint 271 | output_dir = os.path.join(args.output_dir, "checkpoint-{}".format(global_step)) 272 | if not os.path.exists(output_dir): 273 | os.makedirs(output_dir) 274 | model_to_save = ( 275 | model.module if hasattr(model, "module") else model 276 | ) # Take care of distributed/parallel training 277 | model_to_save.save_pretrained(output_dir) 278 | tokenizer.save_pretrained(output_dir) 279 | 280 | torch.save(args, os.path.join(output_dir, "training_args.bin")) 281 | logger.info("Saving model checkpoint to %s", output_dir) 282 | 283 | torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt")) 284 | torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt")) 285 | logger.info("Saving optimizer and scheduler states to %s", output_dir) 286 | 287 | if args.max_steps > 0 and global_step > args.max_steps: 288 | epoch_iterator.close() 289 | break 290 | if args.max_steps > 0 and global_step > args.max_steps: 291 | train_iterator.close() 292 | break 293 | 294 | if args.local_rank in [-1, 0]: 295 | tb_writer.close() 296 | 297 | return global_step, tr_loss / global_step 298 | 299 | 300 | def evaluate(args, model, tokenizer, prefix="", finalEval=False): 301 | # Loop to handle MNLI double evaluation (matched, mis-matched) 302 | eval_task_names = ("mnli", "mnli-mm") if args.task_name == "mnli" else (args.task_name,) 303 | eval_outputs_dirs = (args.output_dir, args.output_dir + "-MM") if args.task_name == "mnli" else (args.output_dir,) 304 | 305 | results = {} 306 | for eval_task, eval_output_dir in zip(eval_task_names, eval_outputs_dirs): 307 | eval_dataset = load_and_cache_examples(args, eval_task, tokenizer, evaluate=True) 308 | 309 | if not os.path.exists(eval_output_dir) and args.local_rank in [-1, 0]: 310 | os.makedirs(eval_output_dir) 311 | 312 | args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu) 313 | # Note that DistributedSampler samples randomly 314 | eval_sampler = SequentialSampler(eval_dataset) 315 | eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size) 316 | 317 | # multi-gpu eval 318 | if args.n_gpu > 1 and not isinstance(model, torch.nn.DataParallel): 319 | model = torch.nn.DataParallel(model) 320 | 321 | # Eval! 
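# Sketch of what happens to the accumulated logits below (hypothetical
# numbers): for a binary alignment task, raw logits (-1.2, 2.3) pass through
# softmax(preds, axis=1) to give roughly (0.03, 0.97); preds_prob keeps
# column 1, i.e. P(aligned) ~ 0.97, which calc_final_alignments /
# calc_alignment_sim_mat then consume as the alignment score.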
322 | logger.info("***** Running evaluation {} *****".format(prefix)) 323 | logger.info(" Num examples = %d", len(eval_dataset)) 324 | logger.info(" Batch size = %d", args.eval_batch_size) 325 | eval_loss = 0.0 326 | nb_eval_steps = 0 327 | preds = None 328 | out_label_ids = None 329 | for batch in tqdm(eval_dataloader, desc="Evaluating"): 330 | model.eval() 331 | batch = tuple(t.to(args.device) for t in batch) 332 | 333 | with torch.no_grad(): 334 | inputs = {"input_ids": batch[0], "attention_mask": batch[1], "labels": batch[3]} 335 | if args.model_type != "distilbert": 336 | inputs["token_type_ids"] = ( 337 | batch[2] if args.model_type in ["bert", "xlnet", "albert"] else None 338 | ) # XLM, DistilBERT, RoBERTa, and XLM-RoBERTa don't use segment_ids 339 | outputs = model(**inputs) 340 | tmp_eval_loss, logits = outputs[:2] 341 | 342 | eval_loss += tmp_eval_loss.mean().item() 343 | nb_eval_steps += 1 344 | if preds is None: 345 | preds = logits.detach().cpu().numpy() 346 | out_label_ids = inputs["labels"].detach().cpu().numpy() 347 | else: 348 | preds = np.append(preds, logits.detach().cpu().numpy(), axis=0) 349 | out_label_ids = np.append(out_label_ids, inputs["labels"].detach().cpu().numpy(), axis=0) 350 | 351 | eval_loss = eval_loss / nb_eval_steps 352 | if args.output_mode == "classification": 353 | preds_prob = softmax(preds,axis=1)[:,1]#np.max(softmax(preds,axis=1), axis=1) 354 | preds = np.argmax(preds, axis=1) 355 | elif args.output_mode == "regression": 356 | preds = np.squeeze(preds) 357 | result = compute_metrics(eval_task, preds, out_label_ids) 358 | result['loss'] = eval_loss 359 | results.update(result) 360 | 361 | if args.calc_final_alignments and finalEval: 362 | finalAlignmentPred.calc_final_alignments(args.data_dir,args.output_dir, preds, preds_prob) 363 | 364 | if args.calc_alignment_sim_mat and finalEval: 365 | finalAlignmentPred.calc_alignment_sim_mat(args.data_dir,args.output_dir, preds_prob) 366 | 367 | 368 | output_eval_file = os.path.join(eval_output_dir, prefix, "eval_results.txt") 369 | with open(output_eval_file, "w") as writer: 370 | logger.info("***** Eval results {} *****".format(prefix)) 371 | for key in sorted(result.keys()): 372 | logger.info(" %s = %s", key, str(result[key])) 373 | writer.write("%s = %s\n" % (key, str(result[key]))) 374 | 375 | return results 376 | 377 | 378 | def load_and_cache_examples(args, task, tokenizer, evaluate=False): 379 | if args.local_rank not in [-1, 0] and not evaluate: 380 | torch.distributed.barrier() # Make sure only the first process in distributed training process the dataset, and the others will use the cache 381 | 382 | processor = processors[task]() 383 | output_mode = output_modes[task] 384 | # Load data features from cache or dataset file 385 | cached_features_file = os.path.join( 386 | args.data_dir, 387 | "cached_{}_{}_{}_{}".format( 388 | "dev" if evaluate else "train", 389 | list(filter(None, args.model_name_or_path.split("/"))).pop(), 390 | str(args.max_seq_length), 391 | str(task), 392 | ), 393 | ) 394 | if os.path.exists(cached_features_file) and not args.overwrite_cache: 395 | logger.info("Loading features from cached file %s", cached_features_file) 396 | features = torch.load(cached_features_file) 397 | else: 398 | logger.info("Creating features from dataset file at %s", args.data_dir) 399 | label_list = processor.get_labels() 400 | if task in ["mnli", "mnli-mm"] and args.model_type in ["roberta", "xlmroberta"]: 401 | # HACK(label indices are swapped in RoBERTa pretrained model) 402 | label_list[1], 
label_list[2] = label_list[2], label_list[1] 403 | examples = ( 404 | processor.get_dev_examples(args.data_dir) if evaluate else processor.get_train_examples(args.data_dir) 405 | ) 406 | features = convert_examples_to_features( 407 | examples, 408 | tokenizer, 409 | label_list=label_list, 410 | max_length=args.max_seq_length, 411 | output_mode=output_mode, 412 | pad_on_left=bool(args.model_type in ["xlnet"]), # pad on the left for xlnet 413 | pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0], 414 | pad_token_segment_id=4 if args.model_type in ["xlnet"] else 0, 415 | ) 416 | if args.local_rank in [-1, 0]: 417 | logger.info("Saving features into cached file %s", cached_features_file) 418 | torch.save(features, cached_features_file) 419 | 420 | if args.local_rank == 0 and not evaluate: 421 | torch.distributed.barrier() # Make sure only the first process in distributed training process the dataset, and the others will use the cache 422 | 423 | # Convert to Tensors and build dataset 424 | all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long) 425 | all_attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long) 426 | all_token_type_ids = torch.tensor([f.token_type_ids for f in features], dtype=torch.long) 427 | if output_mode == "classification": 428 | all_labels = torch.tensor([f.label for f in features], dtype=torch.long) 429 | elif output_mode == "regression": 430 | all_labels = torch.tensor([f.label for f in features], dtype=torch.float) 431 | 432 | dataset = TensorDataset(all_input_ids, all_attention_mask, all_token_type_ids, all_labels) 433 | return dataset 434 | 435 | 436 | def main(): 437 | parser = argparse.ArgumentParser() 438 | 439 | # Required parameters 440 | parser.add_argument( 441 | "--data_dir", 442 | default=None, 443 | type=str, 444 | required=True, 445 | help="The input data dir. Should contain the .tsv files (or other data files) for the task.", 446 | ) 447 | parser.add_argument( 448 | "--model_type", 449 | default=None, 450 | type=str, 451 | required=True, 452 | help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()), 453 | ) 454 | parser.add_argument( 455 | "--model_name_or_path", 456 | default=None, 457 | type=str, 458 | required=True, 459 | help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS), 460 | ) 461 | parser.add_argument( 462 | "--task_name", 463 | default=None, 464 | type=str, 465 | required=True, 466 | help="The name of the task to train selected in the list: " + ", ".join(processors.keys()), 467 | ) 468 | parser.add_argument( 469 | "--output_dir", 470 | default=None, 471 | type=str, 472 | required=True, 473 | help="The output directory where the model predictions and checkpoints will be written.", 474 | ) 475 | 476 | # Other parameters 477 | parser.add_argument( 478 | "--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name", 479 | ) 480 | parser.add_argument( 481 | "--tokenizer_name", 482 | default="", 483 | type=str, 484 | help="Pretrained tokenizer name or path if not the same as model_name", 485 | ) 486 | parser.add_argument( 487 | "--cache_dir", 488 | default="", 489 | type=str, 490 | help="Where do you want to store the pre-trained models downloaded from s3", 491 | ) 492 | parser.add_argument( 493 | "--max_seq_length", 494 | default=128, 495 | type=int, 496 | help="The maximum total input sequence length after tokenization. 
Sequences longer " 497 | "than this will be truncated, sequences shorter will be padded.", 498 | ) 499 | parser.add_argument("--do_train", action="store_true", help="Whether to run training.") 500 | parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.") 501 | parser.add_argument( 502 | "--evaluate_during_training", action="store_true", help="Run evaluation during training at each logging step.", 503 | ) 504 | parser.add_argument( 505 | "--do_lower_case", action="store_true", help="Set this flag if you are using an uncased model.", 506 | ) 507 | 508 | parser.add_argument( 509 | "--per_gpu_train_batch_size", default=8, type=int, help="Batch size per GPU/CPU for training.", 510 | ) 511 | parser.add_argument( 512 | "--per_gpu_eval_batch_size", default=8, type=int, help="Batch size per GPU/CPU for evaluation.", 513 | ) 514 | parser.add_argument( 515 | "--gradient_accumulation_steps", 516 | type=int, 517 | default=1, 518 | help="Number of updates steps to accumulate before performing a backward/update pass.", 519 | ) 520 | parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.") 521 | parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.") 522 | parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.") 523 | parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") 524 | parser.add_argument( 525 | "--num_train_epochs", default=3.0, type=float, help="Total number of training epochs to perform.", 526 | ) 527 | parser.add_argument( 528 | "--max_steps", 529 | default=-1, 530 | type=int, 531 | help="If > 0: set total number of training steps to perform. Override num_train_epochs.", 532 | ) 533 | parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.") 534 | 535 | parser.add_argument("--logging_steps", type=int, default=500, help="Log every X updates steps.") 536 | parser.add_argument("--save_steps", type=int, default=500, help="Save checkpoint every X updates steps.") 537 | parser.add_argument( 538 | "--eval_all_checkpoints", 539 | action="store_true", 540 | help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number", 541 | ) 542 | parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available") 543 | parser.add_argument( 544 | "--overwrite_output_dir", action="store_true", help="Overwrite the content of the output directory", 545 | ) 546 | parser.add_argument( 547 | "--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets", 548 | ) 549 | parser.add_argument("--seed", type=int, default=42, help="random seed for initialization") 550 | 551 | parser.add_argument( 552 | "--fp16", 553 | action="store_true", 554 | help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit", 555 | ) 556 | parser.add_argument( 557 | "--fp16_opt_level", 558 | type=str, 559 | default="O1", 560 | help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']." 
561 | "See details at https://nvidia.github.io/apex/amp.html", 562 | ) 563 | parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") 564 | parser.add_argument("--server_ip", type=str, default="", help="For distant debugging.") 565 | parser.add_argument("--server_port", type=str, default="", help="For distant debugging.") 566 | 567 | parser.add_argument( 568 | "--calc_final_alignments", action="store_true", help="Set this flag if you want to calculate final aligments.", 569 | ) 570 | 571 | parser.add_argument( 572 | "--calc_alignment_sim_mat", action="store_true", help="Set this flag if you want to calculate alignment_sim_mat.", 573 | ) 574 | 575 | args = parser.parse_args() 576 | 577 | if ( 578 | os.path.exists(args.output_dir) 579 | and os.listdir(args.output_dir) 580 | and args.do_train 581 | and not args.overwrite_output_dir 582 | ): 583 | raise ValueError( 584 | "Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format( 585 | args.output_dir 586 | ) 587 | ) 588 | 589 | # Setup distant debugging if needed 590 | if args.server_ip and args.server_port: 591 | # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script 592 | import ptvsd 593 | 594 | print("Waiting for debugger attach") 595 | ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True) 596 | ptvsd.wait_for_attach() 597 | 598 | # Setup CUDA, GPU & distributed training 599 | if args.local_rank == -1 or args.no_cuda: 600 | device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") 601 | args.n_gpu = torch.cuda.device_count() 602 | else: # Initializes the distributed backend which will take care of sychronizing nodes/GPUs 603 | torch.cuda.set_device(args.local_rank) 604 | device = torch.device("cuda", args.local_rank) 605 | torch.distributed.init_process_group(backend="nccl") 606 | args.n_gpu = 1 607 | args.device = device 608 | 609 | # Setup logging 610 | logging.basicConfig( 611 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 612 | datefmt="%m/%d/%Y %H:%M:%S", 613 | level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN, 614 | ) 615 | logger.warning( 616 | "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s", 617 | args.local_rank, 618 | device, 619 | args.n_gpu, 620 | bool(args.local_rank != -1), 621 | args.fp16, 622 | ) 623 | 624 | # Set seed 625 | set_seed(args) 626 | 627 | # Prepare GLUE task 628 | args.task_name = args.task_name.lower() 629 | if args.task_name not in processors: 630 | raise ValueError("Task not found: %s" % (args.task_name)) 631 | processor = processors[args.task_name]() 632 | args.output_mode = output_modes[args.task_name] 633 | label_list = processor.get_labels() 634 | num_labels = len(label_list) 635 | 636 | # Load pretrained model and tokenizer 637 | if args.local_rank not in [-1, 0]: 638 | torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab 639 | 640 | args.model_type = args.model_type.lower() 641 | config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type] 642 | if args.model_name_or_path != 'roberta-large-mnli': 643 | config = config_class.from_pretrained( 644 | args.config_name if args.config_name else args.model_name_or_path, 645 | num_labels=num_labels, 646 | finetuning_task=args.task_name, 647 | cache_dir=args.cache_dir if args.cache_dir else None, 648 | ) 649 
649 | tokenizer = tokenizer_class.from_pretrained(
650 | args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
651 | do_lower_case=args.do_lower_case,
652 | cache_dir=args.cache_dir if args.cache_dir else None,
653 | )
654 | model = model_class.from_pretrained(
655 | args.model_name_or_path,
656 | from_tf=bool(".ckpt" in args.model_name_or_path),
657 | config=config,
658 | cache_dir=args.cache_dir if args.cache_dir else None,
659 | )
660 | 
661 | 
662 | 
663 | 
664 | 
665 | # special num_labels handling for the 'roberta-large-mnli' checkpoint
666 | if args.model_name_or_path == 'roberta-large-mnli':
667 | num_labels_old = config_class.from_pretrained(args.model_name_or_path)._num_labels
668 | config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name_or_path,
669 | num_labels=num_labels_old,
670 | finetuning_task=args.task_name,
671 | cache_dir=args.cache_dir if args.cache_dir else None)
672 | tokenizer = tokenizer_class.from_pretrained(
673 | args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
674 | do_lower_case=args.do_lower_case,
675 | cache_dir=args.cache_dir if args.cache_dir else None)
676 | if num_labels != num_labels_old:
677 | config.num_labels = num_labels_old
678 | model = model_class.from_pretrained(args.model_name_or_path,
679 | from_tf=bool('.ckpt' in args.model_name_or_path),
680 | config=config,
681 | cache_dir=args.cache_dir if args.cache_dir else None)
682 | config.num_labels = num_labels
683 | logger.info('Reinitializing model classifier layer...')
684 | model.num_labels = num_labels
685 | model.classifier = RobertaClassificationHead(config)
686 | 
687 | else:
688 | model = model_class.from_pretrained(args.model_name_or_path,
689 | from_tf=bool('.ckpt' in args.model_name_or_path),
690 | config=config,
691 | cache_dir=args.cache_dir if args.cache_dir else None)
692 | 
693 | 
694 | 
695 | # # num_added_toks = tokenizer.add_tokens(['[START]', '[END]'])
696 | # num_added_toks = tokenizer.add_special_tokens({'additional_special_tokens':['[START]', '[END]']})
697 | # # num_added_toks = tokenizer.add_special_tokens({'additional_special_tokens': ['[CONTEXT]']})
698 | # print('We have added', num_added_toks, 'tokens')
699 | # model.resize_token_embeddings(len(tokenizer))
700 | 
701 | 
702 | 
703 | 
704 | 
705 | 
706 | if args.local_rank == 0:
707 | torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab
708 | 
709 | model.to(args.device)
710 | 
711 | logger.info("Training/evaluation parameters %s", args)
712 | 
713 | # Training
714 | if args.do_train:
715 | train_dataset = load_and_cache_examples(args, args.task_name, tokenizer, evaluate=False)
716 | global_step, tr_loss = train(args, train_dataset, model, tokenizer)
717 | logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
718 | 
719 | # Saving best-practices: if you use default names for the model, you can reload it using from_pretrained()
720 | if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
721 | # Create output directory if needed
722 | if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
723 | os.makedirs(args.output_dir)
724 | 
725 | logger.info("Saving model checkpoint to %s", args.output_dir)
726 | # Save a trained model, configuration and tokenizer using `save_pretrained()`.
727 |         # They can then be reloaded using `from_pretrained()`
728 |         model_to_save = (
729 |             model.module if hasattr(model, "module") else model
730 |         )  # Take care of distributed/parallel training
731 |         model_to_save.save_pretrained(args.output_dir)
732 |         tokenizer.save_pretrained(args.output_dir)
733 | 
734 |         # Good practice: save your training arguments together with the trained model
735 |         torch.save(args, os.path.join(args.output_dir, "training_args.bin"))
736 | 
737 |         # Load a trained model and vocabulary that you have fine-tuned
738 |         model = model_class.from_pretrained(args.output_dir)
739 |         tokenizer = tokenizer_class.from_pretrained(args.output_dir)
740 |         model.to(args.device)
741 | 
742 |     # Evaluation
743 |     results = {}
744 |     if args.do_eval and args.local_rank in [-1, 0]:
745 |         tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
746 |         checkpoints = [args.output_dir]
747 |         if args.eval_all_checkpoints:
748 |             checkpoints = list(
749 |                 os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + "/**/" + WEIGHTS_NAME, recursive=True))
750 |             )
751 |             logging.getLogger("transformers.modeling_utils").setLevel(logging.WARN)  # Reduce logging
752 |         logger.info("Evaluate the following checkpoints: %s", checkpoints)
753 |         for checkpoint in checkpoints:
754 |             global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
755 |             prefix = checkpoint.split("/")[-1] if checkpoint.find("checkpoint") != -1 else ""
756 | 
757 |             model = model_class.from_pretrained(checkpoint)
758 |             model.to(args.device)
759 |             result = evaluate(args, model, tokenizer, prefix=prefix, finalEval=True)
760 |             result = dict((k + "_{}".format(global_step), v) for k, v in result.items())
761 |             results.update(result)
762 | 
763 |     return results
764 | 
765 | 
766 | if __name__ == "__main__":
767 |     main()
768 | 
--------------------------------------------------------------------------------
/supervised_oie_wrapper/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/oriern/SuperPAL/cb91393edb5af5dc73618378fa864ee2659bab45/supervised_oie_wrapper/__init__.py
--------------------------------------------------------------------------------
/supervised_oie_wrapper/format_oie.py:
--------------------------------------------------------------------------------
1 | """ Usage:
2 |     --in=INPUT_FILE --out=OUTPUT_FILE [--debug]
3 | """
4 | # External imports
5 | import logging
6 | from pprint import pprint
7 | from pprint import pformat
8 | from docopt import docopt
9 | import json
10 | from collections import defaultdict
11 | from tqdm import tqdm
12 | from allennlp.pretrained import open_information_extraction_stanovsky_2018
13 | from allennlp.predictors.open_information_extraction import consolidate_predictions
14 | from allennlp.predictors.open_information_extraction import join_mwp
15 | from allennlp.predictors.open_information_extraction import make_oie_string
16 | from allennlp.predictors.open_information_extraction import get_predicate_text
17 | 
18 | # Local imports
19 | 
20 | #=----
21 | 
22 | class Mock_token:
23 |     """
24 |     Spacy token imitation
25 |     """
26 |     def __init__(self, tok_str):
27 |         self.text = tok_str
28 | 
29 |     def __str__(self):
30 |         return self.text
31 | 
32 | def get_oie_frame(tokens, tags) -> dict:
33 |     """
34 |     Given a list of tokens and the corresponding BIO tags (one per token),
35 |     groups the tokens by their role label and returns the frame as a dict
36 |     mapping each role (e.g., "V", "ARG0", "ARG1") to its list of tokens.
37 | """ 38 | frame = defaultdict(list) 39 | chunk = [] 40 | words = [token.text for token in tokens] 41 | 42 | for (token, tag) in zip(words, tags): 43 | if tag.startswith("I-") or tag.startswith("B-"): 44 | frame[tag[2:]].append(token) 45 | 46 | return dict(frame) 47 | 48 | 49 | def get_frame_str(oie_frame) -> str: 50 | """ 51 | Convert and oie frame dictionary to string. 52 | """ 53 | dummy_dict = dict([(k if k != "V" else "ARG01", v) 54 | for (k, v) in oie_frame.items()]) 55 | 56 | sorted_roles = sorted(dummy_dict) 57 | 58 | frame_str = [] 59 | for role in sorted_roles: 60 | if role == "ARG01": 61 | role = "V" 62 | arg = " ".join(oie_frame[role]) 63 | frame_str.append(f"{role}:{arg}") 64 | 65 | return "\t".join(frame_str) 66 | 67 | 68 | def format_extractions(sent_tokens, sent_predictions): 69 | """ 70 | Convert token-level raw predictions to clean extractions. 71 | """ 72 | # Consolidate predictions 73 | if not (len(set(map(len, sent_predictions))) == 1): 74 | raise AssertionError 75 | assert len(sent_tokens) == len(sent_predictions[0]) 76 | sent_str = " ".join(map(str, sent_tokens)) 77 | 78 | pred_dict = consolidate_predictions(sent_predictions, sent_tokens) 79 | 80 | # Build and return output dictionary 81 | results = [] 82 | all_tags = [] 83 | results_dict = {'verbs':[], 'words': [str(token) for token in sent_tokens]} 84 | 85 | for tags in pred_dict.values(): 86 | # Join multi-word predicates 87 | tags = join_mwp(tags) 88 | all_tags.append(tags) 89 | 90 | # Create description text 91 | oie_frame = get_oie_frame(sent_tokens, tags) 92 | 93 | # Add a predicate prediction to outputs. 94 | results.append("\t".join([sent_str, get_frame_str(oie_frame)])) 95 | results_dict['verbs'].append({'tags': tags}) 96 | 97 | return results, all_tags, results_dict -------------------------------------------------------------------------------- /supervised_oie_wrapper/run_oie.py: -------------------------------------------------------------------------------- 1 | """ Usage: 2 | --in=INPUT_FILE --batch-size=BATCH-SIZE --out=OUTPUT_FILE [--cuda-device=CUDA_DEVICE] [--debug] 3 | """ 4 | # External imports 5 | import logging 6 | from pprint import pprint 7 | from pprint import pformat 8 | from docopt import docopt 9 | import json 10 | import pdb 11 | from tqdm import tqdm 12 | from allennlp.pretrained import open_information_extraction_stanovsky_2018 13 | from collections import defaultdict 14 | from operator import itemgetter 15 | import functools 16 | import operator 17 | import torch 18 | 19 | model_oie = open_information_extraction_stanovsky_2018() 20 | model_oie._model.cuda(0) 21 | 22 | 23 | # Local imports 24 | from supervised_oie_wrapper.format_oie import format_extractions, Mock_token 25 | #=----- 26 | 27 | def chunks(l, n): 28 | """ 29 | Yield successive n-sized chunks from l. 30 | """ 31 | for i in range(0, len(l), n): 32 | yield l[i:i + n] 33 | 34 | def create_instances(model, sent): 35 | """ 36 | Convert a sentence into a list of instances. 
37 | """ 38 | sent_tokens = model._tokenizer.tokenize(sent) 39 | 40 | # Find all verbs in the input sentence 41 | pred_ids = [i for (i, t) in enumerate(sent_tokens) 42 | if t.pos_ == "VERB" or t.pos_ == "AUX"] 43 | 44 | # Create instances 45 | instances = [{"sentence": sent_tokens, 46 | "predicate_index": pred_id} 47 | for pred_id in pred_ids] 48 | 49 | return instances 50 | 51 | def get_confidence(model, tag_per_token, class_probs): 52 | """ 53 | Get the confidence of a given model in a token list, using the class probabilities 54 | associated with this prediction. 55 | """ 56 | token_indexes = [model._model.vocab.get_token_index(tag, namespace = "labels") for tag in tag_per_token] 57 | 58 | # Get probability per tag 59 | probs = [class_prob[token_index] for token_index, class_prob in zip(token_indexes, class_probs)] 60 | 61 | # Combine (product) 62 | prod_prob = functools.reduce(operator.mul, probs) 63 | 64 | return prod_prob 65 | 66 | def run_oie(lines, batch_size=64, cuda_device=-1, debug=False): 67 | """ 68 | Run the OIE model and process the output. 69 | """ 70 | 71 | if debug: 72 | logging.basicConfig(level = logging.DEBUG) 73 | else: 74 | logging.basicConfig(level = logging.INFO) 75 | 76 | # Init OIE 77 | model = model_oie 78 | # model = open_information_extraction_stanovsky_2018() 79 | # 80 | # 81 | # # Move model to gpu, if requested 82 | # if cuda_device >= 0: 83 | # model._model.cuda(cuda_device) 84 | 85 | 86 | # process sentences 87 | logging.info("Processing sentences") 88 | oie_lines = [] 89 | oie_lines_dict = [] 90 | for chunk in tqdm(chunks(lines, batch_size)): 91 | oie_inputs = [] 92 | sentTokensList = [] 93 | for sent_idx ,sent in enumerate(chunk): 94 | # if len(sent) > 20000: #if sentence is too long for memory (probably garbage sentence) 95 | # sent = '' 96 | pred_instance = create_instances(model, sent) 97 | oie_inputs.extend(pred_instance) 98 | 99 | sentTokensList.append(" ".join([str(token) for token in model._tokenizer.tokenize(sent)])) 100 | 101 | 102 | # Run oie on sents 103 | sent_preds = [] 104 | if oie_inputs: 105 | sent_preds = model.predict_batch_json(oie_inputs) 106 | 107 | # Collect outputs in batches 108 | predictions_by_sent = defaultdict(list) 109 | for outputs in sent_preds: 110 | sent_tokens = outputs["words"] 111 | tags = outputs["tags"] 112 | sent_str = " ".join(sent_tokens) 113 | assert(len(sent_tokens) == len(tags)) 114 | predictions_by_sent[sent_str].append((outputs["tags"], outputs["class_probabilities"])) 115 | 116 | 117 | # Create extractions by sentence 118 | for sent_tokens in sentTokensList: 119 | if sent_tokens not in predictions_by_sent: # handle sentences without predicate 120 | oie_lines.extend([None]) 121 | oie_lines_dict.extend([None]) 122 | continue 123 | predictions_for_sent = predictions_by_sent[sent_tokens] 124 | raw_tags = list(map(itemgetter(0), predictions_for_sent)) 125 | class_probs = list(map(itemgetter(1), predictions_for_sent)) 126 | 127 | # Compute confidence per extraction 128 | confs = [get_confidence(model, tag_per_token, class_prob) 129 | for tag_per_token, class_prob in zip(raw_tags, class_probs)] 130 | 131 | extractions, tags, results_dict = format_extractions([Mock_token(tok) for tok in sent_tokens.split(" ")], raw_tags) 132 | 133 | oie_lines.extend([extraction + f"\t{conf}" for extraction, conf in zip(extractions, confs)]) 134 | oie_lines_dict.extend([results_dict]) 135 | 136 | logging.info("DONE") 137 | return oie_lines, oie_lines_dict 138 | 139 | if __name__ == "__main__": 140 | # Parse command line arguments 
141 |     args = docopt(__doc__)
142 |     inp_fn = args["--in"]
143 |     batch_size = int(args["--batch-size"])
144 |     out_fn = args["--out"]
145 |     cuda_device = int(args["--cuda-device"]) if (args["--cuda-device"] is not None) \
146 |                   else -1
147 |     debug = args["--debug"]
148 | 
149 |     lines = [line.strip()
150 |              for line in open(inp_fn, encoding="utf8")]
151 | 
152 |     oie_lines, _ = run_oie(lines, batch_size, cuda_device, debug)
153 | 
154 |     # Write to file, skipping sentences without a predicate (run_oie returns None for those)
155 |     logging.info(f"Writing output to {out_fn}")
156 |     with open(out_fn, "w", encoding="utf8") as fout:
157 |         fout.write("\n".join(line for line in oie_lines if line is not None))
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import sys
2 | #sys.path.append('/home/nlp/ernstor1/rouge/SummEval_referenceSubsets/code_score_extraction/')
3 | 
4 | from allennlp.predictors.predictor import Predictor
5 | import csv
6 | import argparse
7 | import subprocess
8 | from nltk import tokenize
9 | from nltk.parse import CoreNLPParser
10 | from rouge import Rouge
11 | from bert_score import score
12 | import requests
13 | # import ipdb
14 | import ast
15 | import glob
16 | import os
17 | import logging
18 | import copy
19 | import hashlib
20 | import json
21 | from supervised_oie_wrapper.run_oie import run_oie
22 | # import createRougeDataset
23 | # import calculateRouge
24 | import numpy as np
25 | import shutil
26 | from filterContained import *
27 | from tqdm import tqdm
28 | from itertools import chain
29 | from collections import defaultdict
30 | 
31 | 
32 | def str2bool(v):
33 |     return v.lower() in ('true')
34 | 
35 | 
36 | def hashhex(s):
37 |     """Returns a hexadecimal formatted SHA1 hash of the input string."""
38 |     s = s.encode('utf-8')
39 |     h = hashlib.sha1()
40 |     h.update(s)
41 |     return h.hexdigest()
42 | 
43 | 
44 | # metrics_data = {}
45 | #
46 | predictor = Predictor.from_path("https://s3-us-west-2.amazonaws.com/allennlp/models/bert-base-srl-2019.06.17.tar.gz")
47 | nlp_parser = CoreNLPParser()  # (url='http://nlp3.cs.unc.edu:9000')
48 | # rouge = Rouge()
49 | #
50 | # DATASETS = ['duc2004', 'duc2007', 'MultiNews']
51 | 
52 | 
53 | 
54 | def read_csv_data(csv_file):
55 |     """ Reader to parse the csv file"""
56 |     data = []
57 |     with open(csv_file, encoding='utf-8', errors='ignore') as f:
58 |         csv_reader = csv.reader(f, delimiter=',')
59 |         for ind, row in enumerate(csv_reader):
60 |             if ind == 0:
61 |                 header = row
62 |             else:
63 |                 data.append(row)
64 |     return header, data
65 | 
66 | 
67 | def read_generic_file(filepath):
68 |     """ reads any generic text file into
69 |     list containing one line as element
70 |     """
71 |     text = []
72 |     with open(filepath, 'r') as f:
73 |         for line in f.read().splitlines():
74 |             text.append(line.strip())
75 |     return text
76 | 
77 | 
78 | def calculate_metric_scores(cands, refs):
79 |     """ calculate Rouge-1 precision, Bert precision
80 |     and Entailment Scores
81 |     """
82 |     # calculate rouge-1 precision
83 |     rouge = Rouge()
84 |     rouge1_p = []
85 |     for r, c in tqdm(zip(refs, cands)):
86 |         r = " ".join(list(nlp_parser.tokenize(r))).lower()
87 |         c = " ".join(list(nlp_parser.tokenize(c))).lower()
88 |         scores = rouge.get_scores(c, r)[0]
89 |         rouge1_p.append(round(scores['rouge-1']['p'], 4))
90 |     # calculate bert precision
91 |     P, _, _ = score(cands, refs, lang='en', verbose=True)
92 |     P = [round(x, 4) for x in P.tolist()]
93 |     ## calculate entailment score
94 |     url = 'http://localhost:5003/roberta_mnli_classifier'  # 'http://nlp1.cs.unc.edu:5003/roberta_mnli_classifier'
95 |     mnli_data = []
96 |     for p, h in zip(refs, cands):
97 |         mnli_data.append({'premise': p, 'hypo': h})
98 |     r = requests.post(url, json=mnli_data)
99 |     results = r.json()
100 |     ent_scores = []
101 |     for ind, d in enumerate(results):
102 |         ent_scores.append(float(d['entailment']))
103 | 
104 |     return rouge1_p, P, ent_scores
105 | 
106 | 
107 | 
108 | 
109 | 
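For reference, a minimal usage sketch of calculate_metric_scores (not part of the original file): it assumes a CoreNLP server is reachable by CoreNLPParser and that the roberta_mnli_classifier service above is running on localhost:5003; the example strings are illustrative.

    # hedged sketch -- requires the CoreNLP and MNLI services mentioned above
    cands = ['The senate passed the bill.']
    refs = ['The bill was passed by the senate on Tuesday.']
    rouge1_p, bert_p, ent = calculate_metric_scores(cands, refs)
    # one score per (candidate, reference) pair:
    # rouge1_p[0] -> ROUGE-1 precision, bert_p[0] -> BERTScore precision,
    # ent[0] -> MNLI entailment probability of the candidate given the reference
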
110 | def generate_scu(sentence, max_scus=5):
111 |     """ Given a summary sentence, retrieve SCUs """
112 | 
113 |     srl = predictor.predict(sentence=sentence['scuSentence'])
114 |     # ipdb.set_trace()
115 |     scus = srl['verbs']
116 |     scu_list = []
117 |     tokens = srl['words']
118 |     for scu in scus:
119 |         tags = scu['tags']
120 |         words = []
121 |         if not ("B-ARG1" in tags or "B-ARG2" in tags or "B-ARG0" in tags):
122 |             continue
123 |         scu_start_offset = None
124 |         for ind, tag in enumerate(tags):
125 |             # if "ARG0" in tag or "ARG1" in tag or "V" in tag:
126 |             if "O" not in tag:
127 |                 if scu_start_offset is None:
128 |                     if ind == 0:
129 |                         scu_start_offset = sentence['scuSentCharIdx'] + ind
130 |                     else:
131 |                         scu_start_offset = sentence['scuSentCharIdx'] + len(" ".join(tokens[:ind]))
132 |                 else:
133 |                     scu_end_offset = sentence['scuSentCharIdx'] + len(" ".join(tokens[:ind + 1]))
134 |                 words.append(tokens[ind])
135 | 
136 |         if len(words) <= 4:
137 |             continue
138 |         tmp = copy.deepcopy(sentence)
139 |         tmp['scuText'] = " ".join(words)
140 |         tmp['scuOffsets'] = (scu_start_offset, scu_end_offset)
141 |         scu_list.append(tmp)
142 |     # select the best SCU
143 |     # sort SCUs based on their length and select middle one
144 |     scu_list = sorted(scu_list, key=lambda x: len(x['scuText'].split()), reverse=True)
145 |     # print(f"Best SCU:::{scu_list[int(len(scu_list)/2)]}")
146 |     # return scu_list[int(len(scu_list)/2)]
147 |     return scu_list[:max_scus]
148 | 
149 | 
150 | def generate_scu_oie(sentence, max_scus=5, doc_summ='summ'):
151 |     """ Given a sentence, retrieve SCUs (OIE spans) """
152 | 
153 |     if doc_summ == 'summ':
154 |         KEY_sent = 'scuSentence'
155 |         KEY_sent_char_idx = 'scuSentCharIdx'
156 |         KEY_scu_text = 'scuText'
157 |         KEY_scu_offset = 'scuOffsets'
158 |     else:
159 |         KEY_sent = 'docSentText'
160 |         KEY_sent_char_idx = 'docSentCharIdx'
161 |         KEY_scu_text = 'docScuText'
162 |         KEY_scu_offset = 'docScuOffsets'
163 | 
164 |     _, oie = run_oie([sentence[KEY_sent]])
165 | 
166 |     # ipdb.set_trace()
167 |     if not oie or oie[0] is None:  # no extraction (run_oie returns [None] for predicate-less sentences)
168 |         return []
169 |     else:
170 |         oie = oie[0]
171 |     scus = oie['verbs']
172 |     scu_list = []
173 |     tokens = oie['words']
174 |     for scu in scus:
175 |         tags = scu['tags']
176 |         words = []
177 |         if not ("B-ARG1" in tags or "B-ARG2" in tags or "B-ARG0" in tags):
178 |             continue
179 |         scu_start_offset = None
180 |         for ind, tag in enumerate(tags):
181 |             # if "ARG0" in tag or "ARG1" in tag or "V" in tag:
182 |             if "O" not in tag:
183 |                 if scu_start_offset is None:
184 |                     if ind == 0:
185 |                         scu_start_offset = sentence[KEY_sent_char_idx] + ind
186 |                     else:
187 |                         scu_start_offset = sentence[KEY_sent_char_idx] + len(" ".join(tokens[:ind]))
188 |                 # always update the end offset, so single-token SCUs also get a valid end
189 |                 scu_end_offset = sentence[KEY_sent_char_idx] + len(" ".join(tokens[:ind + 1]))
190 |                 words.append(tokens[ind])
191 | 
192 |         # if len(words) <= 3:
193 |         #     continue
194 |         tmp = copy.deepcopy(sentence)
195 |         tmp[KEY_scu_text] = " ".join(words)
196 |         tmp[KEY_scu_offset] = (scu_start_offset, scu_end_offset)
197 |         scu_list.append(tmp)
198 |     # select the best SCU
199 |     # sort SCUs based on their length and select middle one
200 |     scu_list = sorted(scu_list, key=lambda x: len(x[KEY_scu_text].split()), reverse=True)
201 |     # print(f"Best SCU:::{scu_list[int(len(scu_list)/2)]}")
202 |     # return scu_list[int(len(scu_list)/2)]
203 |     return scu_list[:max_scus]
204 | 
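As an aside, a hedged sketch of how the extractor above can be called; the field values are invented for illustration, and the call assumes the AllenNLP OIE model loaded at import time.

    # hypothetical input dict -- only the two fields read by generate_scu_oie
    sent = {'scuSentence': 'The storm hit the coast and thousands were evacuated.',
            'scuSentCharIdx': 0}
    scus = generate_scu_oie(sent, max_scus=5, doc_summ='summ')
    # each returned dict is a copy of `sent` with 'scuText' (the OIE span text)
    # and 'scuOffsets' (start/end character offsets) filled in
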
205 | def generate_scu_oie_multiSent(sentences, doc_summ='summ'):
206 |     """ Given a list of sentence dicts, retrieve SCUs (OIEs)
207 | 
208 |     The input should be a list of dictionaries with the following fields:
209 |     'scuSentence'      # sentence text
210 |     'scuSentCharIdx'   # character offset of the beginning of the sentence w.r.t. the beginning of the document
211 |     'scuText'          # The OIE text would be written here.
212 |     'scuOffsets'       # The character offset of the OIE w.r.t. the beginning of the document would be written here
213 | 
214 |     """
215 | 
216 |     if doc_summ == 'summ':
217 |         KEY_sent = 'scuSentence'
218 |         KEY_sent_char_idx = 'scuSentCharIdx'
219 |         KEY_scu_text = 'scuText'
220 |         KEY_scu_offset = 'scuOffsets'
221 |     else:
222 |         KEY_sent = 'docSentText'
223 |         KEY_sent_char_idx = 'docSentCharIdx'
224 |         KEY_scu_text = 'docScuText'
225 |         KEY_scu_offset = 'docScuOffsets'
226 | 
227 |     _, oies = run_oie([sentence[KEY_sent] for sentence in sentences], cuda_device=0)
228 |     # adaptation for srl
229 |     # oies = []
230 |     # for sentence in sentences:
231 |     #     oies.append(predictor.predict(sentence=sentence[KEY_sent]))
232 | 
233 | 
234 |     scu_list = []
235 |     assert(len(sentences) == len(oies))
236 |     for sentence, oie in zip(sentences, oies):
237 |         sentence[KEY_sent] = sentence[KEY_sent].replace(u'\u00a0', ' ')
238 |         # ipdb.set_trace()
239 |         if not oie:  # if list is empty
240 |             continue
241 | 
242 |         # if sentence[KEY_sent] == 'Johnson\'s new TV show, ``The Magic Hour,\'\' is just one aspect of a busy life: -- HIS HEALTH: While by no means cured, he owes the appearance of remarkable health to a Spartan lifestyle and modern medicine.':
243 |         #     print('here')
244 |         scus = oie['verbs']
245 |         in_sentence_scu_dict = {}
246 |         tokens = oie['words']
247 |         for scu in scus:
248 |             tags = scu['tags']
249 |             words = []
250 |             if not ("B-ARG1" in tags or "B-ARG2" in tags or "B-ARG0" in tags):
251 |                 continue
252 |             sub_scu_offsets = []
253 |             scu_start_offset = None
254 |             offset = 0
255 |             initialSpace = 0
256 |             while sentence[KEY_sent][offset + initialSpace] == ' ':
257 |                 initialSpace += 1  ## add space if exists, so 'offset' would start from next token and not from space
258 |             offset += initialSpace
259 |             for ind, tag in enumerate(tags):
260 |                 # if "ARG0" in tag or "ARG1" in tag or "V" in tag:
261 |                 assert (sentence[KEY_sent][offset] == tokens[ind][0])
262 |                 if "O" not in tag:
263 |                     if scu_start_offset is None:
264 |                         scu_start_offset = sentence[KEY_sent_char_idx] + offset
265 | 
266 |                         assert(sentence[KEY_sent][offset] == tokens[ind][0])
267 | 
268 |                     words.append(tokens[ind])
269 |                 else:  # if sub-scu is finished (we get 'O' tag)
270 |                     if scu_start_offset is not None:
271 |                         spaceBeforeToken = 0
272 |                         while sentence[KEY_sent][offset - 1 - spaceBeforeToken] == ' ':
273 |                             spaceBeforeToken += 1  ## add space if exists
274 |                         if sentence[KEY_sent][offset] == '.' or sentence[KEY_sent][offset] == '?':
275 |                             dotAfter = 1 + spaceAfterToken
276 |                             dotTest = 1
277 |                         else:
278 |                             dotAfter = 0
279 |                             dotTest = 0
280 |                         scu_end_offset = sentence[KEY_sent_char_idx] + offset - spaceBeforeToken + dotAfter
281 | 
282 |                         if dotTest:
283 |                             assert (sentence[KEY_sent][offset - spaceBeforeToken + dotAfter - 1] == tokens[ind - 1 + dotTest][0])  # check only the dot, the start of the token
284 |                         else:
285 |                             assert (sentence[KEY_sent][offset - spaceBeforeToken + dotAfter - 1] == tokens[ind - 1 + dotTest][-1])  # check end of token
286 |                         sub_scu_offsets.append([scu_start_offset, scu_end_offset])
287 |                         scu_start_offset = None
288 | 
289 | 
290 |                 ## update offset
291 | 
292 |                 offset += len(tokens[ind])
293 |                 if ind < len(tags) - 1:  # if not last token
294 |                     spaceAfterToken = 0
295 |                     while sentence[KEY_sent][offset + spaceAfterToken] == ' ':
296 |                         spaceAfterToken += 1  ## add space after token if exists, so 'offset' would start from next token and not from space
297 |                     offset += spaceAfterToken
298 | 
299 |             if scu_start_offset is not None:  # end of sentence
300 |                 scu_end_offset = sentence[KEY_sent_char_idx] + offset
301 |                 sub_scu_offsets.append([scu_start_offset, scu_end_offset])
302 |                 scu_start_offset = None
303 | 
304 | 
305 | 
306 |             # if len(words) <= 3:
307 |             #     continue
308 |             scuText = "...".join([sentence[KEY_sent][strt_end_indx[0] - sentence[KEY_sent_char_idx]:strt_end_indx[1] - sentence[KEY_sent_char_idx]] for strt_end_indx in sub_scu_offsets])
309 |             # assert(scuText == " ".join([sentence[KEY_sent][strt_end_indx[0]:strt_end_indx[1]] for strt_end_indx in sub_scu_offsets]))
310 |             in_sentence_scu_dict[scuText] = sub_scu_offsets
311 | 
312 |         notContainedDict = checkContained(in_sentence_scu_dict, sentence[KEY_sent], sentence[KEY_sent_char_idx])
313 | 
314 | 
315 |         for scuText, binaryNotContained in notContainedDict.items():
316 |             scu_offsets = in_sentence_scu_dict[scuText]
317 |             if binaryNotContained:
318 |                 tmp = copy.deepcopy(sentence)
319 |                 tmp[KEY_scu_text] = scuText
320 |                 tmp[KEY_scu_offset] = scu_offsets
321 |                 scu_list.append(tmp)
322 |     # select the best SCU
323 |     # sort SCUs based on their length and select middle one
324 |     # scu_list = sorted(scu_list, key=lambda x: len(x[KEY_scu_text].split()), reverse=True)
325 |     # print(f"Best SCU:::{scu_list[int(len(scu_list)/2)]}")
326 |     # return scu_list[int(len(scu_list)/2)]
327 |     return scu_list
328 | 
329 | 
330 | 
331 | 
332 | def word_aligner(sent1, sent2):
333 |     """ wrapper which calls the monolingual
334 |     word aligner and gives the alignment scores between
335 |     sent1 and sent2
336 |     """
337 |     ## tokenize
338 |     sent1_tok = " ".join(list(nlp_parser.tokenize(sent1)))
339 |     sent2_tok = " ".join(list(nlp_parser.tokenize(sent2)))
340 | 
341 |     ## create a subprocess to call the word aligner
342 |     process = subprocess.Popen(['python2', 'predict_align.py', '--s1', sent1_tok,
343 |                                 '--s2', sent2_tok], stdout=subprocess.PIPE,
344 |                                cwd='/ssd-playpen/home/ram/monolingual-word-aligner')
345 |     output, error = process.communicate()
346 |     ## parse the output
347 |     output = output.decode('utf-8')
348 |     output = output.split('\n')
349 |     return ast.literal_eval(output[0]), ast.literal_eval(output[1]), sent1_tok, sent2_tok
350 | 
351 | 
352 | 
353 | def write_doc_scus(doc_sents, doc_sent_dir):
354 | 
355 | 
356 |     if not os.path.exists(doc_sent_dir):
357 |         os.makedirs(doc_sent_dir)
358 |     for sent_idx, sentence in enumerate(doc_sents):
359 |         html_path = os.path.join(doc_sent_dir, 'D061.M.250.J.' + str(sent_idx) + '.html')
360 |         with open(html_path, 'w') as f:
361 |             f.write(sentence)
362 | 
363 |     return len(doc_sents)
364 | 
365 | def write_summ_scus(summ_sents, summ_sent_dir):
366 |     for sent_idx, sentence in enumerate(summ_sents):
367 |         sent_dir = os.path.join(summ_sent_dir, str(sent_idx))
368 |         if not os.path.exists(sent_dir):
369 |             os.makedirs(sent_dir)
370 |         html_path = os.path.join(sent_dir, 'D061.M.250.J.A' + '.html')
371 |         with open(html_path, 'w') as f:
372 |             f.write(sentence)
373 |     return len(summ_sents)
374 | 
375 | 
376 | 
377 | 
378 | # def calc_rouge_mat(summ_scus, doc_scus):
379 | #     DOC_SENT_DIR = '/home/nlp/ernstor1/tmp/doc_sent_dir'
380 | #     SUMM_SENT_DIR = '/home/nlp/ernstor1/tmp/summ_sent_dir'
381 | #
382 | #     num_doc_scus = write_doc_scus(doc_scus, DOC_SENT_DIR)
383 | #     num_summ_scus = write_summ_scus(summ_scus, SUMM_SENT_DIR)
384 | #     rouge_mat = np.zeros((num_summ_scus, num_doc_scus))
385 | #
386 | #     for summ_dir in os.listdir(SUMM_SENT_DIR):
387 | #         INPUTS = [(calculateRouge.COMPARE_SAME_LEN, os.path.join(SUMM_SENT_DIR, summ_dir), DOC_SENT_DIR,
388 | #                    None, None, calculateRouge.REMOVE_STOP_WORDS)]
389 | #
390 | #         compareType, refFolder, sysFolder, outputPath, ducVersion, stopWordsRemoval = INPUTS[0]
391 | #
392 | #         # get the different options:
393 | #         taskNames, systemNames, summaryLengths = calculateRouge.getComparisonOptions(sysFolder, refFolder)
394 | #         # get ROUGE scores:
395 | #         allData = calculateRouge.runRougeCombinations(compareType, sysFolder, refFolder, systemNames,
396 | #                                                       summaryLengths,
397 | #                                                       ducVersion, stopWordsRemoval)
398 | #         # calculate R1,R2,RL average
399 | #         rouge_vec = createRougeDataset.extractRouge(allData, systemNames, summaryLengths)
400 | #         rouge_mat[int(summ_dir), :] = rouge_vec
401 | #
402 | #
403 | #     # remove tmp dirs
404 | #     shutil.rmtree(DOC_SENT_DIR)
405 | #     shutil.rmtree(SUMM_SENT_DIR)
406 | #
407 | #     return rouge_mat
408 | 
409 | def saveSCUsToCsv(scus, outputFilePath):
410 |     # Outputs the selected SCUs to the output CSV path specified
411 |     with open(outputFilePath, mode='w', newline='') as outFile:
412 |         csvWriter = csv.writer(outFile, delimiter=',', quotechar='"', quoting=csv.QUOTE_MINIMAL)
413 |         csvWriter.writerow(['db', 'topic', 'sentCharIdx', 'sentence', 'offsets', 'scu'])
414 |         # print in order of hitID (so that the sentence order is kept in the output CSV):
415 |         for scu in scus:
416 |             db = scu['database']
417 |             topic = scu['summaryFile']
418 |             sentCharIdx = scu['scuSentCharIdx']
419 |             # sentId = annoPerHIT[hitId]['sentCharIdx']
420 |             sentence = scu['scuSentence']
421 |             scu_text = scu['scuText']
422 |             offsetsStr = ';'.join(
423 |                 ', '.join(map(str, offset)) for offset in [scu['scuOffsets']])
424 |             csvWriter.writerow([db, topic, sentCharIdx, sentence, offsetsStr, scu_text])
425 | 
426 | 
427 | def saveSCU_SentFilteredPairsToCsv(scu_sent_pairs, outputCsvFilepath):
428 |     # output fields:
429 |     # db, topic, summaryFile, scuSentCharIdx, scuSentence, scuOffsets, documentFile, docSentCharIdx, scuText, docSentText, isAligned
430 | 
431 |     with open(outputCsvFilepath, 'w', newline='') as fOut:
432 |         csvWriter = csv.writer(fOut, delimiter=',', quotechar='"', quoting=csv.QUOTE_MINIMAL)
433 |         csvWriter.writerow(
434 |             ['db', 'topic', 'summaryFile', 'scuSentCharIdx', 'scuSentence', 'scuOffsets', 'documentFile', 'docSentCharIdx',
435 |              'scuText', 'docSentText', 'isAligned'])
436 | 
437 |         for scu, doc_sent in scu_sent_pairs:
438 |             db = scu['database']
439 |             topic = scu['topic']
440 |             summaryFile = scu['summaryFile']
441 |             scuSentCharIdx = scu['scuSentCharIdx']
442 |             # sentId = annoPerHIT[hitId]['sentCharIdx']
443 |             scuSentence = scu['scuSentence']
444 |             scuText = scu['scuText']
445 |             scuOffsets = ';'.join(
446 |                 ', '.join(map(str, offset)) for offset in [scu['scuOffsets']])
447 | 
448 |             documentFile = doc_sent['documentFile']
449 |             docSentCharIdx = doc_sent['docSentCharIdx']
450 |             docSentText = doc_sent['docSentText']
451 |             answer = 1
452 | 
453 |             csvWriter.writerow([db, topic, summaryFile, scuSentCharIdx, scuSentence, scuOffsets, documentFile,
454 |                                 docSentCharIdx, scuText, docSentText, answer])
455 | 
456 | 
457 | def intersectionOverUnion(offset1, offset2):
458 |     ranges1 = [range(marking[0], marking[1]) for marking in offset1]
459 |     ranges1 = set(chain(*ranges1))
460 |     ranges2 = [range(marking[0], marking[1]) for marking in offset2]
461 |     ranges2 = set(chain(*ranges2))
462 |     return len(ranges1 & ranges2) / len(ranges1 | ranges2)
463 | 
464 | 
465 | 
466 | def Union(offsets, sentOffsets):
467 |     ranges_tmp = set([])
468 |     for offset, sentOffset in zip(offsets, sentOffsets):
469 |         offset = offset_str2list(offset)
470 |         offset = offset_decreaseSentOffset(sentOffset, offset)
471 |         ranges = [range(marking[0], marking[1]) for marking in offset]
472 |         ranges = set(chain(*ranges))
473 |         ranges_tmp = ranges_tmp | ranges
474 |     return ranges_tmp
475 | 
476 | 
477 | 
478 | 
479 | 
480 | 
481 | def offset_str2list(offset):
482 |     return [[int(start_end) for start_end in pair.split(',')] for pair in offset.split(';')]
483 | 
484 | def offset_list2str(lst):
485 |     return ';'.join(', '.join(map(str, offset)) for offset in lst)
486 | 
487 | def offset_decreaseSentOffset(sentOffset, scu_offsets):
488 |     return [[start_end[0] - sentOffset, start_end[1] - sentOffset] for start_end in scu_offsets]
489 | 
490 | def chunks_new(lst, n):
491 |     """Yield successive n-sized chunks from lst."""
492 |     for i in range(0, len(lst), n):
493 |         yield lst[i:i + n]
494 | 
495 | 
496 | 
497 | 
498 | 
--------------------------------------------------------------------------------
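As a closing note on the offset helpers above: spans are encoded throughout the CSV files as 'start, end' pairs joined by ';', and the conversion functions round-trip that encoding. A small self-contained example with illustrative values:

    spans = offset_str2list('10, 25;40, 52')   # -> [[10, 25], [40, 52]]
    assert offset_list2str(spans) == '10, 25;40, 52'
    # intersectionOverUnion expands each [start, end) span into a set of
    # character positions, so identical span lists yield an IoU of 1.0
    assert intersectionOverUnion(spans, spans) == 1.0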