├── .gitignore ├── LICENSE ├── README.md ├── analyze.py ├── conll.py ├── evaluate.py ├── experiments.conf ├── higher_order.py ├── metrics.py ├── model.py ├── predict.py ├── preprocess.py ├── requirements.txt ├── run.py ├── setup_data.sh ├── tensorize.py └── util.py /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | .idea 3 | .pyc 4 | tmp.py 5 | conll-2012 -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2020 Liyan Xu 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # End-to-End Coreference Resolution with Different Higher-Order Inference Methods 2 | 3 | This repository contains the implementation of the paper: [Revealing the Myth of Higher-Order Inference in Coreference Resolution](https://www.aclweb.org/anthology/2020.emnlp-main.686.pdf). 
4 | 5 | ## Architecture 6 | 7 | The basic end-to-end coreference model is a PyTorch re-implementation based on the TensorFlow model, following similar preprocessing (see this [repository](https://github.com/mandarjoshi90/coref)). 8 | 9 | Four higher-order inference (HOI) methods are experimented with: **Attended Antecedent**, **Entity Equalization**, **Span Clustering**, and **Cluster Merging**. All are included here except for Entity Equalization, which is experimented with in the equivalent TensorFlow environment (see this separate [repository](https://github.com/lxucs/coref-ee)). 10 | 11 | **Files**: 12 | * [run.py](run.py): training and evaluation 13 | * [model.py](model.py): the coreference model 14 | * [higher_order.py](higher_order.py): higher-order inference modules 15 | * [predict.py](predict.py): script for prediction on custom input 16 | * [analyze.py](analyze.py): result analysis 17 | * [preprocess.py](preprocess.py): converting CoNLL files to examples 18 | * [tensorize.py](tensorize.py): tensorizing examples 19 | * [conll.py](conll.py), [metrics.py](metrics.py): the same CoNLL-related files as in this [repository](https://github.com/mandarjoshi90/coref) 20 | * [experiments.conf](experiments.conf): different model configurations 21 | 22 | ## Basic Setup 23 | Set up the environment and data for training and evaluation: 24 | * Install Python 3 dependencies: `pip install -r requirements.txt` 25 | * Create a data directory that will contain all data files, models, and log files; set `data_dir = /path/to/data/dir` in [experiments.conf](experiments.conf) 26 | * Prepare the dataset (requires the [OntoNotes 5.0](https://catalog.ldc.upenn.edu/LDC2013T19) corpus): `./setup_data.sh /path/to/ontonotes /path/to/data/dir` 27 | 28 | For SpanBERT, download the pretrained weights from this [repository](https://github.com/facebookresearch/SpanBERT), and rename the directory to `/path/to/data/dir/spanbert_base` or `/path/to/data/dir/spanbert_large` accordingly. 29 | 30 | ## Evaluation 31 | Provided trained models: 32 | * SpanBERT + no HOI: [FILE](https://drive.google.com/file/d/1fjHrRT98XzvNSrzhJnydQBgKW0QodvGK) 33 | * SpanBERT + Attended Antecedent: [FILE](https://drive.google.com/file/d/1qTrTYM2aEocvrO-cq2kcL64NFZmRUn6Z) 34 | * SpanBERT + Span Clustering: [FILE](https://drive.google.com/file/d/1NAXbCbJBPtPYBj3lttPnlioBoLdxjKoc) 35 | * SpanBERT + Cluster Merging: [FILE](https://drive.google.com/file/d/1ZdT9QjIwJxCGZjj7utFQdtjvFlwwdnP2) 36 | * SpanBERT + Entity Equalization: see [repository](https://github.com/lxucs/coref-ee) 37 | 38 | The name of each directory corresponds to a **configuration** in [experiments.conf](experiments.conf). Each directory contains two trained models. 39 | 40 | If you want to use the official evaluator, download and unzip the [CoNLL-2012 scorer](https://drive.google.com/file/d/1UeDIAFFNpJXfSH-PvOvacA60mC-XRDk5) under this directory. 41 | 42 | Evaluate a model on the dev/test set: 43 | * Download the corresponding model directory and unzip it under `data_dir` 44 | * `python evaluate.py [config] [model_id] [gpu_id]` 45 | * e.g. Attended Antecedent: `python evaluate.py train_spanbert_large_ml0_d2 May08_12-38-29_58000 0` 46 | 47 | ## Prediction 48 | Prediction on custom input: see `python predict.py -h` 49 | * Interactive user input: `python predict.py --config_name=[config] --model_identifier=[model_id] --gpu_id=[gpu_id]` 50 | * E.g.
`python predict.py --config_name=train_spanbert_large_ml0_d1 --model_identifier=May10_03-28-49_54000 --gpu_id=0` 51 | * Input from file (a jsonlines file of this [format](https://github.com/mandarjoshi90/coref#batched-prediction-instructions)): `python predict.py --config_name=[config] --model_identifier=[model_id] --gpu_id=[gpu_id] --jsonlines_path=[input_path] --output_path=[output_path]` 52 | ## Training 53 | `python run.py [config] [gpu_id]` 54 | 55 | * [config] can be any **configuration** in [experiments.conf](experiments.conf) 56 | * The log file will be saved at `your_data_dir/[config]/log_XXX.txt` 57 | * Models will be saved at `your_data_dir/[config]/model_XXX.bin` 58 | * TensorBoard is available at `your_data_dir/tensorboard` 59 | 60 | 61 | ## Configurations 62 | Some important configurations in [experiments.conf](experiments.conf): 63 | * `data_dir`: the full path to the directory containing the dataset, models, and log files 64 | * `coref_depth` and `higher_order`: controlling the higher-order inference module 65 | * `bert_pretrained_name_or_path`: the name/path of the pretrained BERT model ([HuggingFace BERT models](https://huggingface.co/transformers/main_classes/model.html#transformers.PreTrainedModel.from_pretrained)) 66 | * `max_training_sentences`: the maximum number of segments to use when a document is too long; for BERT-Large and SpanBERT-Large, set to `3` for a 32GB GPU or `2` for a 24GB GPU 67 | 68 | ## Citation 69 | ``` 70 | @inproceedings{xu-choi-2020-revealing, 71 | title = "Revealing the Myth of Higher-Order Inference in Coreference Resolution", 72 | author = "Xu, Liyan and Choi, Jinho D.", 73 | booktitle = "Proceedings of the 2020 Conference on Empirical Methods in Natural Language Processing (EMNLP)", 74 | month = nov, 75 | year = "2020", 76 | publisher = "Association for Computational Linguistics", 77 | url = "https://www.aclweb.org/anthology/2020.emnlp-main.686", 78 | pages = "8527--8533" 79 | } 80 | ``` 81 | -------------------------------------------------------------------------------- /analyze.py: -------------------------------------------------------------------------------- 1 | from run import Runner 2 | import util 3 | import json 4 | import pickle 5 | from os.path import join 6 | import os 7 | from collections import defaultdict 8 | 9 | singular_pronouns = ['i', 'me', 'my', 'mine', 'myself', 'he', 'him', 'his', 'himself', 'she', 'her', 'hers', 'herself', 'it', 'its', 'itself', 'yourself'] 10 | plural_pronouns = ['they', 'them', 'their', 'theirs', 'themselves', 'we', 'us', 'our', 'ours', 'ourselves', 'yourselves'] 11 | ambiguous_pronouns = ['you', 'your', 'yours'] 12 | valid_pronouns = singular_pronouns + plural_pronouns + ambiguous_pronouns 13 | 14 | 15 | def get_prediction_path(config, config_name, saved_suffix, suffix=''): 16 | dir_analysis = join(config['data_dir'], 'analysis') 17 | os.makedirs(dir_analysis, exist_ok=True) 18 | 19 | name = f'pred_{config_name}_{saved_suffix}{suffix}.bin' 20 | path = join(dir_analysis, name) 21 | return path 22 | 23 | 24 | def get_prediction(config_name, saved_suffix, gpu_id): 25 | runner = Runner(config_name, gpu_id) 26 | conf = runner.config 27 | 28 | path = get_prediction_path(conf, config_name, saved_suffix) 29 | if os.path.exists(path): 30 | # Load if saved 31 | with open(path, 'rb') as f: 32 | prediction = pickle.load(f) 33 | print('Loaded prediction from %s' % path) 34 | else: 35 | # Get prediction 36 | model = runner.initialize_model(saved_suffix) 37 | examples_train, examples_dev, examples_test = runner.data.get_tensor_examples() 38 |
stored_info = runner.data.get_stored_info() 39 | 40 | samples_test = [example[1] for example in examples_test] 41 | predicted_clusters, predicted_spans, predicted_antecedents = runner.predict(model, samples_test) 42 | prediction = (predicted_clusters, predicted_spans, predicted_antecedents) 43 | 44 | # Save 45 | with open(path, 'wb') as f: 46 | pickle.dump(prediction, f) 47 | print('Prediction saved in %s' % path) 48 | 49 | return prediction 50 | 51 | 52 | def get_prediction_wo_hoi(config_name, saved_suffix, gpu_id): 53 | runner = Runner(config_name, gpu_id) 54 | conf = runner.config 55 | 56 | suffix = '_noHOI' 57 | path = get_prediction_path(conf, config_name, saved_suffix, suffix) 58 | if os.path.exists(path): 59 | # Load if saved 60 | with open(path, 'rb') as f: 61 | prediction = pickle.load(f) 62 | print('Loaded prediction from %s' % path) 63 | else: 64 | # Get prediction 65 | model = runner.initialize_model(saved_suffix) 66 | examples_train, examples_dev, examples_test = runner.data.get_tensor_examples() 67 | stored_info = runner.data.get_stored_info() 68 | 69 | # Turn off HOI after model initialization 70 | if '_cm' in config_name: 71 | conf['coref_depth'] = 1 72 | conf['higher_order'] = 'attended_antecedent' 73 | elif '_d2' in config_name or '_sc' in config_name or '_ee' in config_name: 74 | conf['coref_depth'] = 1 75 | 76 | samples_test = [example[1] for example in examples_test] 77 | predicted_clusters, predicted_spans, predicted_antecedents = runner.predict(model, samples_test) 78 | prediction = (predicted_clusters, predicted_spans, predicted_antecedents) 79 | 80 | # Save 81 | with open(path, 'wb') as f: 82 | pickle.dump(prediction, f) 83 | print('Prediction saved in %s' % path) 84 | 85 | return prediction 86 | 87 | 88 | def get_original_samples(config, split='tst'): 89 | samples = [] 90 | paths = { 91 | 'trn': join(config['data_dir'], f'train.english.{config["max_segment_len"]}.jsonlines'), 92 | 'dev': join(config['data_dir'], f'dev.english.{config["max_segment_len"]}.jsonlines'), 93 | 'tst': join(config['data_dir'], f'test.english.{config["max_segment_len"]}.jsonlines') 94 | } 95 | with open(paths[split]) as fin: 96 | for line in fin.readlines(): 97 | data = json.loads(line) 98 | samples.append(data) 99 | return samples 100 | 101 | 102 | def get_gold_to_cluster_id(example_list): 103 | gold_to_cluster_id = [] # 0 means not in cluster 104 | non_anaphoric = [] # Firstly appeared mention in a cluster 105 | for i, example in enumerate(example_list): 106 | gold_to_cluster_id.append(defaultdict(int)) 107 | non_anaphoric.append(set()) 108 | 109 | clusters = example['clusters'] 110 | clusters = [sorted(cluster) for cluster in clusters] # Sort mention 111 | for c_i, c in enumerate(clusters): 112 | non_anaphoric[i].add(tuple(c[0])) 113 | for m in c: 114 | gold_to_cluster_id[i][tuple(m)] = c_i + 1 115 | return gold_to_cluster_id, non_anaphoric 116 | 117 | 118 | def check_singular_plural_cluster(cluster): 119 | """ Cluster with text """ 120 | singular, plural, contain_ambiguous = False, False, False 121 | for m in cluster: 122 | if singular and plural: 123 | break 124 | m = m.lower() 125 | if not singular: 126 | singular = (m in singular_pronouns) 127 | if not plural: 128 | plural = (m in plural_pronouns) 129 | for m in cluster: 130 | m = m.lower() 131 | if m in ambiguous_pronouns: 132 | contain_ambiguous = True 133 | break 134 | return singular, plural, contain_ambiguous 135 | 136 | 137 | def analyze(config_name, saved_suffix, gpu_id): 138 | runner = Runner(config_name, gpu_id) 139 | conf 
= runner.config 140 | 141 | # Get gold clusters 142 | example_list = get_original_samples(conf) 143 | gold_to_cluster_id, non_anaphoric = get_gold_to_cluster_id(example_list) 144 | 145 | # Get prediction 146 | predicted_clusters, predicted_spans, predicted_antecedents = get_prediction(config_name, saved_suffix, gpu_id) 147 | 148 | # Get cluster text 149 | cluster_list = [] 150 | subtoken_list = [] 151 | for i, example in enumerate(example_list): 152 | subtokens = util.flatten(example['sentences']) 153 | subtoken_list.append(subtokens) 154 | cluster_list.append([[' '.join(subtokens[m[0]: m[1] + 1]) for m in c] for c in predicted_clusters[i]]) 155 | 156 | # Get cluster stats 157 | num_clusters, num_singular_clusters, num_plural_clusters, num_mixed_clusters, num_mixed_ambiguous = 0, 0, 0, 0, 0 158 | for clusters in cluster_list: 159 | # print(clusters) 160 | for c in clusters: 161 | singular, plural, contain_ambiguous = check_singular_plural_cluster(c) 162 | num_clusters += 1 163 | if singular and plural: 164 | num_mixed_clusters += 1 165 | if contain_ambiguous: 166 | num_mixed_ambiguous += 1 167 | if singular: 168 | num_singular_clusters += 1 169 | if plural: 170 | num_plural_clusters += 1 171 | 172 | # Get antecedent stats 173 | fl, fn, wl, correct = 0, 0, 0, 0 # False Link, False New, Wrong Link 174 | s_to_p, p_to_s = 0, 0 175 | num_non_gold, num_total_spans = 0, 0 176 | for i, antecedents in enumerate(predicted_antecedents): 177 | antecedents = [(-1, -1) if a == -1 else predicted_spans[i][a] for a in antecedents] 178 | for j, antecedent in enumerate(antecedents): 179 | span = predicted_spans[i][j] 180 | span_cluster_id = gold_to_cluster_id[i][span] 181 | num_total_spans += 1 182 | 183 | if antecedent == (-1, -1): 184 | continue 185 | 186 | # Only look at stats of pronouns 187 | span_text = ' '.join(subtoken_list[i][span[0]: span[1] + 1]).lower() 188 | antecedent_text = ' '.join(subtoken_list[i][antecedent[0]: antecedent[1] + 1]).lower() 189 | if span_text not in valid_pronouns or antecedent_text not in valid_pronouns: 190 | continue 191 | 192 | if span_text in singular_pronouns and antecedent_text in plural_pronouns: 193 | s_to_p += 1 194 | elif span_text in plural_pronouns and antecedent_text in singular_pronouns: 195 | p_to_s += 1 196 | 197 | if span_cluster_id == 0: # Non-gold span 198 | num_non_gold += 1 199 | if antecedent == (-1, -1): 200 | correct += 1 201 | else: 202 | fl += 1 203 | elif span in non_anaphoric[i]: # Non-anaphoric span 204 | if antecedent == (-1, -1): 205 | correct += 1 206 | else: 207 | fl += 1 208 | else: 209 | if antecedent == (-1, -1): 210 | fn += 1 211 | elif span_cluster_id != gold_to_cluster_id[i][antecedent]: 212 | wl += 1 213 | else: 214 | correct += 1 215 | 216 | return num_clusters, num_singular_clusters, num_plural_clusters, num_mixed_clusters, num_mixed_ambiguous, fl, fn, wl, correct, \ 217 | num_non_gold, num_total_spans, s_to_p, p_to_s 218 | 219 | 220 | def analyze2(config_name, saved_suffix, gpu_id): 221 | runner = Runner(config_name, gpu_id) 222 | conf = runner.config 223 | 224 | # Get gold clusters 225 | example_list = get_original_samples(conf) 226 | gold_to_cluster_id, non_anaphoric = get_gold_to_cluster_id(example_list) 227 | 228 | # Get info 229 | named_entities, pronouns = [], [] 230 | for example in example_list: 231 | named_entities.append(util.flatten(example['named_entities'])) 232 | pronouns.append(util.flatten(example['pronouns'])) 233 | 234 | # Get normal prediction 235 | predicted_clusters, predicted_spans, predicted_antecedents = 
get_prediction(config_name, saved_suffix, gpu_id) 236 | # Get prediction turning off HOI 237 | predicted_clusters_nohoi, predicted_spans_nohoi, predicted_antecedents_nohoi = get_prediction_wo_hoi(config_name, saved_suffix, gpu_id) 238 | # predicted_spans and predicted_spans_nohoi should be almost identical 239 | 240 | # Check wrong->correct and correct->wrong links after turning off HOI 241 | f2t, t2f, t2t, f2f = [[],[],[]], [[],[],[]], [[],[],[]], [[],[],[]] 242 | f2t_pct, t2f_pct, t2t_pct, f2f_pct = [], [], [], [] 243 | link_status_wo_hoi = get_link_status(predicted_spans_nohoi, predicted_antecedents_nohoi, gold_to_cluster_id, non_anaphoric) 244 | link_status_w_hoi = get_link_status(predicted_spans, predicted_antecedents, gold_to_cluster_id, non_anaphoric) 245 | for doc_i in range(len(link_status_wo_hoi)): 246 | f2t_doc, t2f_doc, t2t_doc, f2f_doc = [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0] 247 | status_dict_wo_hoi = link_status_wo_hoi[doc_i] 248 | status_dict_w_hoi = link_status_w_hoi[doc_i] 249 | for span, link_wo_hoi in status_dict_wo_hoi.items(): 250 | link_w_hoi = status_dict_w_hoi.get(span, None) 251 | if link_w_hoi is None: 252 | continue # Only look at gold mentions in both prediction 253 | 254 | span_type = identify_span_type(named_entities[doc_i], pronouns[doc_i], span) 255 | 256 | if link_wo_hoi: 257 | if link_w_hoi: 258 | t2t_doc[span_type] += 1 259 | else: 260 | t2f_doc[span_type] += 1 261 | else: 262 | if link_w_hoi: 263 | f2t_doc[span_type] += 1 264 | else: 265 | f2f_doc[span_type] += 1 266 | total_link = sum(f2t_doc) + sum(t2f_doc) + sum(t2t_doc) + sum(f2f_doc) 267 | if total_link == 0: 268 | print('Zero gold mention; should not happen often') 269 | continue 270 | for span_type in range(3): 271 | f2t[span_type].append(f2t_doc[span_type]) 272 | for span_type in range(3): 273 | t2f[span_type].append(t2f_doc[span_type]) 274 | for span_type in range(3): 275 | t2t[span_type].append(t2t_doc[span_type]) 276 | for span_type in range(3): 277 | f2f[span_type].append(f2f_doc[span_type]) 278 | f2t_pct.append(sum(f2t_doc) * 100 / total_link) 279 | t2f_pct.append(sum(t2f_doc) * 100 / total_link) 280 | t2t_pct.append(sum(t2t_doc) * 100 / total_link) 281 | f2f_pct.append(sum(f2f_doc) * 100 / total_link) 282 | 283 | f2t_total, t2f_total, t2t_total, f2f_total = 0, 0, 0, 0 284 | f2t_type_pct, t2f_type_pct, t2t_type_pct, f2f_type_pct = [[], [], []], [[], [], []], [[], [], []], [[], [], []] 285 | for doc_i in range(len(f2t[0])): 286 | f2t_doc_sum = f2t[0][doc_i] + f2t[1][doc_i] + f2t[2][doc_i] 287 | t2f_doc_sum = t2f[0][doc_i] + t2f[1][doc_i] + t2f[2][doc_i] 288 | t2t_doc_sum = t2t[0][doc_i] + t2t[1][doc_i] + t2t[2][doc_i] 289 | f2f_doc_sum = f2f[0][doc_i] + f2f[1][doc_i] + f2f[2][doc_i] 290 | if f2t_doc_sum > 0: 291 | for span_type in range(3): 292 | f2t_type_pct[span_type].append(f2t[span_type][doc_i] * 100 / f2t_doc_sum) 293 | if t2f_doc_sum > 0: 294 | for span_type in range(3): 295 | t2f_type_pct[span_type].append(t2f[span_type][doc_i] * 100 / t2f_doc_sum) 296 | if t2t_doc_sum > 0: 297 | for span_type in range(3): 298 | t2t_type_pct[span_type].append(t2t[span_type][doc_i] * 100 / t2t_doc_sum) 299 | if f2f_doc_sum > 0: 300 | for span_type in range(3): 301 | f2f_type_pct[span_type].append(f2f[span_type][doc_i] * 100 / f2f_doc_sum) 302 | f2t_total += f2t_doc_sum 303 | t2f_total += t2f_doc_sum 304 | t2t_total += t2t_doc_sum 305 | f2f_total += f2f_doc_sum 306 | 307 | return f2t_total, t2f_total, t2t_total, f2f_total,\ 308 | sum(f2t_pct) / len(f2t_pct), sum(t2f_pct) / len(t2f_pct), sum(t2t_pct) 
/ len(t2t_pct), sum(f2f_pct) / len(f2f_pct), \ 309 | mean(f2t_type_pct[0]), mean(f2t_type_pct[1]), mean(f2t_type_pct[2]), \ 310 | mean(t2f_type_pct[0]), mean(t2f_type_pct[1]), mean(t2f_type_pct[2]), \ 311 | mean(t2t_type_pct[0]), mean(t2t_type_pct[1]), mean(t2t_type_pct[2]), \ 312 | mean(f2f_type_pct[0]), mean(f2f_type_pct[1]), mean(f2f_type_pct[2]) 313 | 314 | 315 | def mean(l): 316 | return sum(l) / len(l) 317 | 318 | 319 | def identify_span_type(named_entities_doc, pronouns_doc, span): 320 | """ 1: pronoun; 2: named entity; 0: other(nominal nouns) """ 321 | # Check pronoun 322 | if pronouns_doc[span[0]: span[1] + 1] == ([True] * (span[1] - span[0] + 1)): 323 | return 1 324 | # Check named entity 325 | entity_text = ''.join(named_entities_doc[span[0]: span[1] + 1]) 326 | if entity_text.count('(') == 1 and entity_text.count(')') == 1: 327 | return 2 328 | return 0 329 | 330 | 331 | def get_link_status(predicted_spans, predicted_antecedents, gold_to_cluster_id, non_anaphoric): 332 | """ 333 | :param predicted_spans: from get_prediction() 334 | :param predicted_antecedents: 335 | :param gold_to_cluster_id, non_anaphoric: from get_gold_to_cluster_id() 336 | :return: dict of gold spans indicating wrong(False) or correct(True) link 337 | """ 338 | link_status = [] 339 | for doc_i in range(len(predicted_spans)): 340 | status_dict = {} # Only for gold mentions 341 | spans = predicted_spans[doc_i] 342 | for span_i, antecedent_i in enumerate(predicted_antecedents[doc_i]): 343 | span_cluster_id = gold_to_cluster_id[doc_i][spans[span_i]] 344 | if span_cluster_id == 0: 345 | continue 346 | if antecedent_i == -1: 347 | status_dict[spans[span_i]] = (spans[span_i] in non_anaphoric[doc_i]) 348 | else: 349 | antecedent_cluster_id = gold_to_cluster_id[doc_i][spans[antecedent_i]] 350 | status_dict[spans[span_i]] = (span_cluster_id == antecedent_cluster_id) 351 | link_status.append(status_dict) 352 | return link_status 353 | 354 | 355 | if __name__ == '__main__': 356 | gpu_id = 6 357 | 358 | experiments = [('train_bert_large_ml0_d1', 'May20_10-25-13_65000'), 359 | ('train_bert_large_ml0_d1', 'May21_00-29-00_66000'), 360 | ('train_bert_large_ml0_d1', 'May21_17-04-35_50000'), 361 | ('train_bert_large_ml0_d1', 'May24_03-33-55_58000')] 362 | 363 | results_final = None 364 | for experiment in experiments: 365 | # results = analyze(*experiment, gpu_id=gpu_id) 366 | results = analyze2(*experiment, gpu_id=gpu_id) 367 | if results is None: 368 | continue 369 | 370 | if results_final is None: 371 | results_final = results 372 | else: 373 | results_final = [r + results[i] for i, r in enumerate(results_final)] 374 | 375 | # print('%s_%s: # clusters: %d; # singular clusters: %d; # plural clusters: %d; # mixed clusters: %d; ' 376 | # 'FL %d; FN: %d; WL: %d; CORRECT %d; # gold spans: %d; # total spans: %d' % (*experiment, *results)) 377 | 378 | results_final = [r / len(experiments) for r in results_final] 379 | 380 | # Analyze 381 | # print('Avg: # clusters: %.3f; # singular clusters: %.3f; # plural clusters: %.3f; # mixed clusters: %.3f; # mixed with ambiguous: %.3f; ' 382 | # 'FL %.3f; FN: %.3f; WL: %.3f; CORRECT %.3f; # gold spans: %.3f; # total spans: %.3f; # S to P: %.3f; # P to S: %.3f' % (*results_final,)) 383 | 384 | # Analyze2 385 | print('f2t, t2f, t2t, f2f: %.2f, %.2f, %.2f, %.2f;\t%.2f%%, %.2f%%, %.2f%%, %.2f%%;\n%.2f%%, %.2f%%, %.2f%%\n%.2f%%, %.2f%%, %.2f%%\n%.2f%%, %.2f%%, %.2f%%\n%.2f%%, %.2f%%, %.2f%%' % (*results_final,)) 386 | 
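For reference, a minimal sketch of running the antecedent analysis above for a single trained model. The configuration name and checkpoint suffix are taken from the README's evaluation example; substitute your own model, and note this assumes the prepared jsonlines data and the downloaded checkpoint already exist under `data_dir`:

```python
from analyze import analyze

# Config/checkpoint identifiers from the README's evaluation example; use your own.
results = analyze('train_spanbert_large_ml0_d2', 'May08_12-38-29_58000', gpu_id=0)

# analyze() returns 13 statistics over the test set, in this order:
(num_clusters, num_singular, num_plural, num_mixed, num_mixed_ambiguous,
 fl, fn, wl, correct, num_non_gold, num_total_spans, s_to_p, p_to_s) = results
print(f'False Link: {fl}; False New: {fn}; Wrong Link: {wl}; Correct: {correct}')
```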
-------------------------------------------------------------------------------- /conll.py: -------------------------------------------------------------------------------- 1 | import re 2 | import tempfile 3 | import subprocess 4 | import operator 5 | import collections 6 | import logging 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | BEGIN_DOCUMENT_REGEX = re.compile(r"#begin document \((.*)\); part (\d+)") # First line at each document 11 | COREF_RESULTS_REGEX = re.compile(r".*Coreference: Recall: \([0-9.]+ / [0-9.]+\) ([0-9.]+)%\tPrecision: \([0-9.]+ / [0-9.]+\) ([0-9.]+)%\tF1: ([0-9.]+)%.*", re.DOTALL) 12 | 13 | 14 | def get_doc_key(doc_id, part): 15 | return "{}_{}".format(doc_id, int(part)) 16 | 17 | 18 | def output_conll(input_file, output_file, predictions, subtoken_map): 19 | prediction_map = {} 20 | for doc_key, clusters in predictions.items(): 21 | start_map = collections.defaultdict(list) 22 | end_map = collections.defaultdict(list) 23 | word_map = collections.defaultdict(list) 24 | for cluster_id, mentions in enumerate(clusters): 25 | for start, end in mentions: 26 | start, end = subtoken_map[doc_key][start], subtoken_map[doc_key][end] 27 | if start == end: 28 | word_map[start].append(cluster_id) 29 | else: 30 | start_map[start].append((cluster_id, end)) 31 | end_map[end].append((cluster_id, start)) 32 | for k,v in start_map.items(): 33 | start_map[k] = [cluster_id for cluster_id, end in sorted(v, key=operator.itemgetter(1), reverse=True)] 34 | for k,v in end_map.items(): 35 | end_map[k] = [cluster_id for cluster_id, start in sorted(v, key=operator.itemgetter(1), reverse=True)] 36 | prediction_map[doc_key] = (start_map, end_map, word_map) 37 | 38 | word_index = 0 39 | for line in input_file.readlines(): 40 | row = line.split() 41 | if len(row) == 0: 42 | output_file.write("\n") 43 | elif row[0].startswith("#"): 44 | begin_match = re.match(BEGIN_DOCUMENT_REGEX, line) 45 | if begin_match: 46 | doc_key = get_doc_key(begin_match.group(1), begin_match.group(2)) 47 | start_map, end_map, word_map = prediction_map[doc_key] 48 | word_index = 0 49 | output_file.write(line) 50 | output_file.write("\n") 51 | else: 52 | assert get_doc_key(row[0], row[1]) == doc_key 53 | coref_list = [] 54 | if word_index in end_map: 55 | for cluster_id in end_map[word_index]: 56 | coref_list.append("{})".format(cluster_id)) 57 | if word_index in word_map: 58 | for cluster_id in word_map[word_index]: 59 | coref_list.append("({})".format(cluster_id)) 60 | if word_index in start_map: 61 | for cluster_id in start_map[word_index]: 62 | coref_list.append("({}".format(cluster_id)) 63 | 64 | if len(coref_list) == 0: 65 | row[-1] = "-" 66 | else: 67 | row[-1] = "|".join(coref_list) 68 | 69 | output_file.write(" ".join(row)) 70 | output_file.write("\n") 71 | word_index += 1 72 | 73 | 74 | def official_conll_eval(gold_path, predicted_path, metric, official_stdout=True): 75 | cmd = ["conll-2012/scorer/v8.01/scorer.pl", metric, gold_path, predicted_path, "none"] 76 | process = subprocess.Popen(cmd, stdout=subprocess.PIPE) 77 | stdout, stderr = process.communicate() 78 | process.wait() 79 | 80 | stdout = stdout.decode("utf-8") 81 | if stderr is not None: 82 | logger.error(stderr) 83 | 84 | if official_stdout: 85 | logger.info("Official result for {}".format(metric)) 86 | logger.info(stdout) 87 | 88 | coref_results_match = re.match(COREF_RESULTS_REGEX, stdout) 89 | recall = float(coref_results_match.group(1)) 90 | precision = float(coref_results_match.group(2)) 91 | f1 = float(coref_results_match.group(3)) 92 | 
return {"r": recall, "p": precision, "f": f1} 93 | 94 | 95 | def evaluate_conll(gold_path, predictions, subtoken_maps, official_stdout=True): 96 | with tempfile.NamedTemporaryFile(delete=True, mode="w") as prediction_file: 97 | with open(gold_path, "r") as gold_file: 98 | output_conll(gold_file, prediction_file, predictions, subtoken_maps) 99 | # logger.info("Predicted conll file: {}".format(prediction_file.name)) 100 | results = {m: official_conll_eval(gold_file.name, prediction_file.name, m, official_stdout) for m in ("muc", "bcub", "ceafe") } 101 | return results 102 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | from run import Runner 2 | import sys 3 | 4 | 5 | def evaluate(config_name, gpu_id, saved_suffix): 6 | runner = Runner(config_name, gpu_id) 7 | model = runner.initialize_model(saved_suffix) 8 | 9 | examples_train, examples_dev, examples_test = runner.data.get_tensor_examples() 10 | stored_info = runner.data.get_stored_info() 11 | 12 | # runner.evaluate(model, examples_dev, stored_info, 0, official=True, conll_path=runner.config['conll_eval_path']) # Eval dev 13 | # print('=================================') 14 | runner.evaluate(model, examples_test, stored_info, 0, official=True, conll_path=runner.config['conll_test_path']) # Eval test 15 | 16 | 17 | if __name__ == '__main__': 18 | config_name, saved_suffix, gpu_id = sys.argv[1], sys.argv[2], int(sys.argv[3]) 19 | evaluate(config_name, gpu_id, saved_suffix) 20 | -------------------------------------------------------------------------------- /experiments.conf: -------------------------------------------------------------------------------- 1 | best { 2 | data_dir = /path/to/data/dir # Edit this 3 | 4 | # Computation limits. 5 | max_top_antecedents = 50 6 | max_training_sentences = 5 7 | top_span_ratio = 0.4 8 | max_num_extracted_spans = 3900 9 | max_num_speakers = 20 10 | max_segment_len = 256 11 | 12 | # Learning 13 | bert_learning_rate = 1e-5 14 | task_learning_rate = 2e-4 15 | loss_type = marginalized # {marginalized, hinge} 16 | mention_loss_coef = 0 17 | false_new_delta = 1.5 # For loss_type = hinge 18 | adam_eps = 1e-6 19 | adam_weight_decay = 1e-2 20 | warmup_ratio = 0.1 21 | max_grad_norm = 1 # Set 0 to disable clipping 22 | gradient_accumulation_steps = 1 23 | 24 | # Model hyperparameters. 25 | coref_depth = 1 # when 1: no higher order (except for cluster_merging) 26 | higher_order = attended_antecedent # {attended_antecedent, max_antecedent, entity_equalization, span_clustering, cluster_merging} 27 | coarse_to_fine = true 28 | fine_grained = true 29 | dropout_rate = 0.3 30 | ffnn_size = 1000 31 | ffnn_depth = 1 32 | cluster_ffnn_size = 1000 # For cluster_merging 33 | cluster_reduce = mean # For cluster_merging 34 | easy_cluster_first = false # For cluster_merging 35 | cluster_dloss = false # cluster_merging 36 | num_epochs = 24 37 | feature_emb_size = 20 38 | max_span_width = 30 39 | use_metadata = true 40 | use_features = true 41 | use_segment_distance = true 42 | model_heads = true 43 | use_width_prior = true # For mention score 44 | use_distance_prior = true # For mention-ranking score 45 | 46 | # Other. 
47 | conll_eval_path = ${best.data_dir}/dev.english.v4_gold_conll # gold_conll file for dev 48 | conll_test_path = ${best.data_dir}/test.english.v4_gold_conll # gold_conll file for test 49 | genres = ["bc", "bn", "mz", "nw", "pt", "tc", "wb"] 50 | eval_frequency = 1000 51 | report_frequency = 100 52 | log_root = ${best.data_dir} 53 | } 54 | 55 | bert_base = ${best}{ 56 | num_docs = 2802 57 | bert_learning_rate = 1e-05 58 | task_learning_rate = 2e-4 59 | max_segment_len = 128 60 | ffnn_size = 3000 61 | cluster_ffnn_size = 3000 62 | max_training_sentences = 11 63 | bert_tokenizer_name = bert-base-cased 64 | bert_pretrained_name_or_path = bert-base-cased 65 | } 66 | 67 | train_bert_base = ${bert_base}{ 68 | } 69 | 70 | train_bert_base_ml0_d1 = ${train_bert_base}{ 71 | mention_loss_coef = 0 72 | coref_depth = 1 73 | } 74 | 75 | train_bert_base_ml0_d2 = ${train_bert_base}{ 76 | mention_loss_coef = 0 77 | coref_depth = 2 78 | } 79 | 80 | bert_large = ${best}{ 81 | num_docs = 2802 82 | bert_learning_rate = 1e-05 83 | task_learning_rate = 2e-4 84 | max_segment_len = 384 85 | ffnn_size = 3000 86 | cluster_ffnn_size = 3000 87 | max_training_sentences = 3 88 | bert_tokenizer_name = bert-base-cased 89 | bert_pretrained_name_or_path = bert-large-cased 90 | } 91 | 92 | train_bert_large = ${bert_large}{ 93 | } 94 | 95 | train_bert_large_ml0_d1 = ${train_bert_large}{ 96 | mention_loss_coef = 0 97 | coref_depth = 1 98 | } 99 | 100 | train_bert_large_ml0_d2 = ${train_bert_large}{ 101 | mention_loss_coef = 0 102 | coref_depth = 2 103 | } 104 | 105 | spanbert_base = ${best}{ 106 | num_docs = 2802 107 | bert_learning_rate = 2e-05 108 | task_learning_rate = 0.0001 109 | max_segment_len = 384 110 | ffnn_size = 3000 111 | cluster_ffnn_size = 3000 112 | max_training_sentences = 3 113 | bert_tokenizer_name = bert-base-cased 114 | bert_pretrained_name_or_path = ${best.data_dir}/spanbert_base 115 | } 116 | 117 | train_spanbert_base = ${spanbert_base}{ 118 | } 119 | 120 | debug_spanbert_base = ${train_spanbert_base}{ 121 | } 122 | 123 | train_spanbert_base_ml0_d1 = ${train_spanbert_base}{ 124 | mention_loss_coef = 0 125 | coref_depth = 1 126 | } 127 | 128 | train_spanbert_base_ml0_lr2e-4_d1 = ${train_spanbert_base}{ 129 | mention_loss_coef = 0 130 | task_learning_rate = 2e-4 131 | coref_depth = 1 132 | } 133 | 134 | spanbert_large = ${best}{ 135 | num_docs = 2802 136 | bert_learning_rate = 1e-05 137 | task_learning_rate = 0.0003 138 | max_segment_len = 512 139 | ffnn_size = 3000 140 | cluster_ffnn_size = 3000 141 | max_training_sentences = 3 142 | bert_tokenizer_name = bert-base-cased 143 | bert_pretrained_name_or_path = ${best.data_dir}/spanbert_large 144 | } 145 | 146 | train_spanbert_large = ${spanbert_large}{ 147 | } 148 | 149 | train_spanbert_large_ml0_d1 = ${train_spanbert_large}{ 150 | mention_loss_coef = 0 151 | coref_depth = 1 152 | } 153 | 154 | train_spanbert_large_ml0_lr_d1 = ${train_spanbert_large}{ 155 | mention_loss_coef = 0 156 | bert_learning_rate = 2e-05 157 | task_learning_rate = 2e-4 158 | coref_depth = 1 159 | } 160 | 161 | train_spanbert_large_ml0_d2 = ${train_spanbert_large}{ 162 | mention_loss_coef = 0 163 | coref_depth = 2 164 | } 165 | 166 | train_spanbert_large_ml0_lr_d2 = ${train_spanbert_large}{ 167 | mention_loss_coef = 0 168 | bert_learning_rate = 2e-05 169 | task_learning_rate = 2e-4 170 | coref_depth = 2 171 | } 172 | 173 | train_spanbert_large_ml0_sc = ${train_spanbert_large}{ 174 | mention_loss_coef = 0 175 | coref_depth = 2 176 | higher_order = span_clustering 177 | } 178 | 179 | 
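# A hypothetical example (commented out; not one of the provided configurations):
# new experiments are defined via HOCON inheritance, starting from a base block
# such as ${train_spanbert_large} and overriding only the fields that differ.
#
# my_spanbert_large_sc = ${train_spanbert_large}{
#     mention_loss_coef = 0
#     coref_depth = 2
#     higher_order = span_clustering
# }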
train_spanbert_large_ml0_cm_fn1000 = ${train_spanbert_large}{ 180 | mention_loss_coef = 0 181 | coref_depth = 1 182 | higher_order = cluster_merging 183 | cluster_ffnn_size = 1000 184 | } 185 | 186 | train_spanbert_large_ml0_cm_fn1000_dloss = ${train_spanbert_large}{ 187 | mention_loss_coef = 0 188 | coref_depth = 1 189 | higher_order = cluster_merging 190 | cluster_ffnn_size = 1000 191 | cluster_dloss = true 192 | } 193 | 194 | train_spanbert_large_ml0_cm_fn1000_e1st = ${train_spanbert_large}{ 195 | mention_loss_coef = 0 196 | coref_depth = 1 197 | higher_order = cluster_merging 198 | cluster_ffnn_size = 1000 199 | easy_cluster_first = true 200 | } 201 | 202 | train_spanbert_large_ml0_cm_fn1000_e1st_dloss = ${train_spanbert_large}{ 203 | mention_loss_coef = 0 204 | coref_depth = 1 205 | higher_order = cluster_merging 206 | cluster_ffnn_size = 1000 207 | easy_cluster_first = true 208 | cluster_dloss = true 209 | } 210 | 211 | train_spanbert_large_ml0_cm_fn1000_max = ${train_spanbert_large}{ 212 | mention_loss_coef = 0 213 | coref_depth = 1 214 | higher_order = cluster_merging 215 | cluster_ffnn_size = 1000 216 | cluster_reduce = max 217 | } 218 | 219 | train_spanbert_large_ml0_cm_fn1000_max_dloss = ${train_spanbert_large}{ 220 | mention_loss_coef = 0 221 | coref_depth = 1 222 | higher_order = cluster_merging 223 | cluster_ffnn_size = 1000 224 | cluster_reduce = max 225 | cluster_dloss = true 226 | } 227 | 228 | train_spanbert_large_ml0_cm_fn1000_max_e1st = ${train_spanbert_large}{ 229 | mention_loss_coef = 0 230 | coref_depth = 1 231 | higher_order = cluster_merging 232 | cluster_ffnn_size = 1000 233 | cluster_reduce = max 234 | easy_cluster_first = true 235 | } 236 | 237 | train_spanbert_large_ml0_cm_fn1000_max_e1st_dloss = ${train_spanbert_large}{ 238 | mention_loss_coef = 0 239 | coref_depth = 1 240 | higher_order = cluster_merging 241 | cluster_ffnn_size = 1000 242 | cluster_reduce = max 243 | easy_cluster_first = true 244 | cluster_dloss = true 245 | } 246 | 247 | train_spanbert_large_ml1_d1 = ${train_spanbert_large}{ 248 | mention_loss_coef = 1 249 | coref_depth = 1 250 | } 251 | 252 | -------------------------------------------------------------------------------- /higher_order.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import util 4 | 5 | 6 | def attended_antecedent(top_span_emb, top_antecedent_emb, top_antecedent_scores, device): 7 | num_top_spans = top_span_emb.shape[0] 8 | top_antecedent_weights = torch.cat([torch.zeros(num_top_spans, 1, device=device), top_antecedent_scores], dim=1) 9 | top_antecedent_weights = nn.functional.softmax(top_antecedent_weights, dim=1) 10 | top_antecedent_emb = torch.cat([torch.unsqueeze(top_span_emb, 1), top_antecedent_emb], dim=1) 11 | refined_span_emb = torch.sum(torch.unsqueeze(top_antecedent_weights, 2) * top_antecedent_emb, dim=1) # [num top spans, span emb size] 12 | return refined_span_emb 13 | 14 | 15 | def max_antecedent(top_span_emb, top_antecedent_emb, top_antecedent_scores, device): 16 | num_top_spans = top_span_emb.shape[0] 17 | top_antecedent_weights = torch.cat([torch.zeros(num_top_spans, 1, device=device), top_antecedent_scores], dim=1) 18 | top_antecedent_emb = torch.cat([torch.unsqueeze(top_span_emb, 1), top_antecedent_emb], dim=1) 19 | max_antecedent_idx = torch.argmax(top_antecedent_weights, dim=1, keepdim=True) 20 | refined_span_emb = util.batch_select(top_antecedent_emb, max_antecedent_idx, device=device).squeeze(1) # [num top spans, span 
emb size] 21 | return refined_span_emb 22 | 23 | 24 | def entity_equalization(top_span_emb, top_antecedent_emb, top_antecedent_idx, top_antecedent_scores, device): 25 | # Use TF implementation in another repo 26 | pass 27 | 28 | def span_clustering(top_span_emb, top_antecedent_idx, top_antecedent_scores, span_attn_ffnn, device): 29 | # Get predicted antecedents 30 | num_top_spans, max_top_antecedents = top_antecedent_idx.shape[0], top_antecedent_idx.shape[1] 31 | predicted_antecedents = [] 32 | top_antecedent_scores = torch.cat([torch.zeros(num_top_spans, 1, device=device), top_antecedent_scores], dim=1) 33 | for i, idx in enumerate((torch.argmax(top_antecedent_scores, axis=1) - 1).tolist()): 34 | if idx < 0: 35 | predicted_antecedents.append(-1) 36 | else: 37 | predicted_antecedents.append(top_antecedent_idx[i, idx].item()) 38 | # Get predicted clusters 39 | predicted_clusters = [] 40 | span_to_cluster_id = [-1] * num_top_spans 41 | for i, predicted_idx in enumerate(predicted_antecedents): 42 | if predicted_idx < 0: 43 | continue 44 | assert i > predicted_idx, f'span idx: {i}; antecedent idx: {predicted_idx}' 45 | # Check antecedent's cluster 46 | antecedent_cluster_id = span_to_cluster_id[predicted_idx] 47 | if antecedent_cluster_id == -1: 48 | antecedent_cluster_id = len(predicted_clusters) 49 | predicted_clusters.append([predicted_idx]) 50 | span_to_cluster_id[predicted_idx] = antecedent_cluster_id 51 | # Add mention to cluster 52 | predicted_clusters[antecedent_cluster_id].append(i) 53 | span_to_cluster_id[i] = antecedent_cluster_id 54 | if len(predicted_clusters) == 0: 55 | return top_span_emb 56 | 57 | # Pad clusters 58 | max_cluster_size = max([len(c) for c in predicted_clusters]) 59 | cluster_sizes = [] 60 | for cluster in predicted_clusters: 61 | cluster_sizes.append(len(cluster)) 62 | cluster += [0] * (max_cluster_size - len(cluster)) 63 | predicted_clusters_mask = torch.arange(0, max_cluster_size, device=device).repeat(len(predicted_clusters), 1) 64 | predicted_clusters_mask = predicted_clusters_mask < torch.tensor(cluster_sizes, device=device).unsqueeze(1) # [num clusters, max cluster size] 65 | # Get cluster repr 66 | predicted_clusters = torch.tensor(predicted_clusters, device=device) 67 | cluster_emb = top_span_emb[predicted_clusters] # [num clusters, max cluster size, emb size] 68 | span_attn = torch.squeeze(span_attn_ffnn(cluster_emb), 2) 69 | span_attn += torch.log(predicted_clusters_mask.to(torch.float)) 70 | span_attn = nn.functional.softmax(span_attn, dim=1) 71 | cluster_emb = torch.sum(cluster_emb * torch.unsqueeze(span_attn, 2), dim=1) # [num clusters, emb size] 72 | # Get refined span 73 | refined_span_emb = [] 74 | for i, cluster_idx in enumerate(span_to_cluster_id): 75 | if cluster_idx < 0: 76 | refined_span_emb.append(top_span_emb[i]) 77 | else: 78 | refined_span_emb.append(cluster_emb[cluster_idx]) 79 | refined_span_emb = torch.stack(refined_span_emb, dim=0) 80 | return refined_span_emb 81 | 82 | 83 | def cluster_merging(top_span_emb, top_antecedent_idx, top_antecedent_scores, emb_cluster_size, cluster_score_ffnn, cluster_transform, dropout, device, reduce='mean', easy_cluster_first=False): 84 | num_top_spans, max_top_antecedents = top_antecedent_idx.shape[0], top_antecedent_idx.shape[1] 85 | span_emb_size = top_span_emb.shape[-1] 86 | max_num_clusters = num_top_spans 87 | 88 | span_to_cluster_id = torch.zeros(num_top_spans, dtype=torch.long, device=device) # id 0 as dummy cluster 89 | cluster_emb = torch.zeros(max_num_clusters, span_emb_size, 
dtype=torch.float, device=device) # [max num clusters, emb size] 90 | num_clusters = 1 # dummy cluster 91 | cluster_sizes = torch.ones(max_num_clusters, dtype=torch.long, device=device) 92 | 93 | merge_order = torch.arange(0, num_top_spans) 94 | if easy_cluster_first: 95 | max_antecedent_scores, _ = torch.max(top_antecedent_scores, dim=1) 96 | merge_order = torch.argsort(max_antecedent_scores, descending=True) 97 | cluster_merging_scores = [None] * num_top_spans 98 | 99 | for i in merge_order.tolist(): 100 | # Get cluster scores 101 | antecedent_cluster_idx = span_to_cluster_id[top_antecedent_idx[i]] 102 | antecedent_cluster_emb = cluster_emb[antecedent_cluster_idx] 103 | # antecedent_cluster_emb = dropout(cluster_transform(antecedent_cluster_emb)) 104 | 105 | antecedent_cluster_size = cluster_sizes[antecedent_cluster_idx] 106 | antecedent_cluster_size = util.bucket_distance(antecedent_cluster_size) 107 | cluster_size_emb = dropout(emb_cluster_size(antecedent_cluster_size)) 108 | 109 | span_emb = top_span_emb[i].unsqueeze(0).repeat(max_top_antecedents, 1) 110 | similarity_emb = span_emb * antecedent_cluster_emb 111 | pair_emb = torch.cat([span_emb, antecedent_cluster_emb, similarity_emb, cluster_size_emb], dim=1) # [max top antecedents, pair emb size] 112 | cluster_scores = torch.squeeze(cluster_score_ffnn(pair_emb), 1) 113 | cluster_scores_mask = (antecedent_cluster_idx > 0).to(torch.float) 114 | cluster_scores *= cluster_scores_mask 115 | cluster_merging_scores[i] = cluster_scores 116 | 117 | # Get predicted antecedent 118 | antecedent_scores = top_antecedent_scores[i] + cluster_scores 119 | max_score, max_score_idx = torch.max(antecedent_scores, dim=0) 120 | if max_score < 0: 121 | continue # Dummy antecedent 122 | max_antecedent_idx = top_antecedent_idx[i, max_score_idx] 123 | 124 | if not easy_cluster_first: # Always add span to antecedent's cluster 125 | # Create antecedent cluster if needed 126 | antecedent_cluster_id = span_to_cluster_id[max_antecedent_idx] 127 | if antecedent_cluster_id == 0: 128 | antecedent_cluster_id = num_clusters 129 | span_to_cluster_id[max_antecedent_idx] = antecedent_cluster_id 130 | cluster_emb[antecedent_cluster_id] = top_span_emb[max_antecedent_idx] 131 | num_clusters += 1 132 | # Add span to cluster 133 | span_to_cluster_id[i] = antecedent_cluster_id 134 | _merge_span_to_cluster(cluster_emb, cluster_sizes, antecedent_cluster_id, top_span_emb[i], reduce=reduce) 135 | else: # current span can be in cluster already 136 | antecedent_cluster_id = span_to_cluster_id[max_antecedent_idx] 137 | curr_span_cluster_id = span_to_cluster_id[i] 138 | if antecedent_cluster_id > 0 and curr_span_cluster_id > 0: 139 | # Merge two clusters 140 | span_to_cluster_id[max_antecedent_idx] = curr_span_cluster_id 141 | _merge_clusters(cluster_emb, cluster_sizes, antecedent_cluster_id, curr_span_cluster_id, reduce=reduce) 142 | elif curr_span_cluster_id > 0: 143 | # Merge antecedent to span's cluster 144 | span_to_cluster_id[max_antecedent_idx] = curr_span_cluster_id 145 | _merge_span_to_cluster(cluster_emb, cluster_sizes, curr_span_cluster_id, top_span_emb[max_antecedent_idx], reduce=reduce) 146 | else: 147 | # Create antecedent cluster if needed 148 | if antecedent_cluster_id == 0: 149 | antecedent_cluster_id = num_clusters 150 | span_to_cluster_id[max_antecedent_idx] = antecedent_cluster_id 151 | cluster_emb[antecedent_cluster_id] = top_span_emb[max_antecedent_idx] 152 | num_clusters += 1 153 | # Add span to cluster 154 | span_to_cluster_id[i] = antecedent_cluster_id 155 | 
_merge_span_to_cluster(cluster_emb, cluster_sizes, antecedent_cluster_id, top_span_emb[i], reduce=reduce) 156 | 157 | cluster_merging_scores = torch.stack(cluster_merging_scores, dim=0) 158 | return cluster_merging_scores 159 | 160 | 161 | def _merge_span_to_cluster(cluster_emb, cluster_sizes, cluster_to_merge_id, span_emb, reduce): 162 | cluster_size = cluster_sizes[cluster_to_merge_id].item() 163 | if reduce == 'mean': 164 | cluster_emb[cluster_to_merge_id] = (cluster_emb[cluster_to_merge_id] * cluster_size + span_emb) / (cluster_size + 1) 165 | elif reduce == 'max': 166 | cluster_emb[cluster_to_merge_id], _ = torch.max(torch.stack([cluster_emb[cluster_to_merge_id], span_emb]), dim=0) 167 | else: 168 | raise ValueError('reduce value is invalid: %s' % reduce) 169 | cluster_sizes[cluster_to_merge_id] += 1 170 | 171 | 172 | def _merge_clusters(cluster_emb, cluster_sizes, cluster1_id, cluster2_id, reduce): 173 | """ Merge cluster1 into cluster2 """ 174 | cluster1_size, cluster2_size = cluster_sizes[cluster1_id].item(), cluster_sizes[cluster2_id].item() 175 | if reduce == 'mean': 176 | cluster_emb[cluster2_id] = (cluster_emb[cluster1_id] * cluster1_size + cluster_emb[cluster2_id] * cluster2_size) / (cluster1_size + cluster2_size) 177 | elif reduce == 'max': 178 | cluster_emb[cluster2_id] = torch.max(cluster_emb[cluster1_id], cluster_emb[cluster2_id]) 179 | else: 180 | raise ValueError('reduce value is invalid: %s' % reduce) 181 | cluster_sizes[cluster2_id] += cluster_sizes[cluster1_id] 182 | -------------------------------------------------------------------------------- /metrics.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import numpy as np 6 | from collections import Counter 7 | from scipy.optimize import linear_sum_assignment  # replaces sklearn.utils.linear_assignment_, removed in scikit-learn 0.23 8 | 9 | 10 | 11 | def f1(p_num, p_den, r_num, r_den, beta=1): 12 | p = 0 if p_den == 0 else p_num / float(p_den) 13 | r = 0 if r_den == 0 else r_num / float(r_den) 14 | return 0 if p + r == 0 else (1 + beta * beta) * p * r / (beta * beta * p + r) 15 | 16 | 17 | class CorefEvaluator(object): 18 | def __init__(self): 19 | self.evaluators = [Evaluator(m) for m in (muc, b_cubed, ceafe)] 20 | 21 | def update(self, predicted, gold, mention_to_predicted, mention_to_gold): 22 | for e in self.evaluators: 23 | e.update(predicted, gold, mention_to_predicted, mention_to_gold) 24 | 25 | def get_f1(self): 26 | return sum(e.get_f1() for e in self.evaluators) / len(self.evaluators) 27 | 28 | def get_recall(self): 29 | return sum(e.get_recall() for e in self.evaluators) / len(self.evaluators) 30 | 31 | def get_precision(self): 32 | return sum(e.get_precision() for e in self.evaluators) / len(self.evaluators) 33 | 34 | def get_prf(self): 35 | return self.get_precision(), self.get_recall(), self.get_f1() 36 | 37 | 38 | class Evaluator(object): 39 | def __init__(self, metric, beta=1): 40 | self.p_num = 0 41 | self.p_den = 0 42 | self.r_num = 0 43 | self.r_den = 0 44 | self.metric = metric 45 | self.beta = beta 46 | 47 | def update(self, predicted, gold, mention_to_predicted, mention_to_gold): 48 | if self.metric == ceafe: 49 | pn, pd, rn, rd = self.metric(predicted, gold) 50 | else: 51 | pn, pd = self.metric(predicted, mention_to_gold) 52 | rn, rd = self.metric(gold, mention_to_predicted) 53 | self.p_num += pn 54 | self.p_den += pd 55 | self.r_num += rn 56 |
-------------------------------------------------------------------------------- /metrics.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import numpy as np 6 | from collections import Counter 7 | # sklearn.utils.linear_assignment_ is deprecated (removed in scikit-learn >= 0.23); scipy is used instead 8 | from scipy.optimize import linear_sum_assignment 9 | 10 | 11 | def f1(p_num, p_den, r_num, r_den, beta=1): 12 | p = 0 if p_den == 0 else p_num / float(p_den) 13 | r = 0 if r_den == 0 else r_num / float(r_den) 14 | return 0 if p + r == 0 else (1 + beta * beta) * p * r / (beta * beta * p + r) 15 | 16 | 17 | class CorefEvaluator(object): 18 | def __init__(self): 19 | self.evaluators = [Evaluator(m) for m in (muc, b_cubed, ceafe)] 20 | 21 | def update(self, predicted, gold, mention_to_predicted, mention_to_gold): 22 | for e in self.evaluators: 23 | e.update(predicted, gold, mention_to_predicted, mention_to_gold) 24 | 25 | def get_f1(self): 26 | return sum(e.get_f1() for e in self.evaluators) / len(self.evaluators) 27 | 28 | def get_recall(self): 29 | return sum(e.get_recall() for e in self.evaluators) / len(self.evaluators) 30 | 31 | def get_precision(self): 32 | return sum(e.get_precision() for e in self.evaluators) / len(self.evaluators) 33 | 34 | def get_prf(self): 35 | return self.get_precision(), self.get_recall(), self.get_f1() 36 | 37 | 38 | class Evaluator(object): 39 | def __init__(self, metric, beta=1): 40 | self.p_num = 0 41 | self.p_den = 0 42 | self.r_num = 0 43 | self.r_den = 0 44 | self.metric = metric 45 | self.beta = beta 46 | 47 | def update(self, predicted, gold, mention_to_predicted, mention_to_gold): 48 | if self.metric == ceafe: 49 | pn, pd, rn, rd = self.metric(predicted, gold) 50 | else: 51 | pn, pd = self.metric(predicted, mention_to_gold) 52 | rn, rd = self.metric(gold, mention_to_predicted) 53 | self.p_num += pn 54 | self.p_den += pd 55 | self.r_num += rn 56 | self.r_den += rd 57 | 58 | def get_f1(self): 59 | return f1(self.p_num, self.p_den, self.r_num, self.r_den, beta=self.beta) 60 | 61 | def get_recall(self): 62 | return 0 if self.r_num == 0 else self.r_num / float(self.r_den) 63 | 64 | def get_precision(self): 65 | return 0 if self.p_num == 0 else self.p_num / float(self.p_den) 66 | 67 | def get_prf(self): 68 | return self.get_precision(), self.get_recall(), self.get_f1() 69 | 70 | def get_counts(self): 71 | return self.p_num, self.p_den, self.r_num, self.r_den 72 | 73 | 74 | def evaluate_documents(documents, metric, beta=1): 75 | evaluator = Evaluator(metric, beta=beta) 76 | for document in documents: 77 | evaluator.update(*document) # each document: (predicted, gold, mention_to_predicted, mention_to_gold) 78 | return evaluator.get_precision(), evaluator.get_recall(), evaluator.get_f1() 79 | 80 | 81 | def b_cubed(clusters, mention_to_gold): 82 | num, dem = 0, 0 83 | 84 | for c in clusters: 85 | if len(c) == 1: 86 | continue 87 | 88 | gold_counts = Counter() 89 | correct = 0 90 | for m in c: 91 | if m in mention_to_gold: 92 | gold_counts[tuple(mention_to_gold[m])] += 1 93 | for c2, count in gold_counts.items(): 94 | if len(c2) != 1: 95 | correct += count * count 96 | 97 | num += correct / float(len(c)) 98 | dem += len(c) 99 | 100 | return num, dem 101 | 102 | 103 | def muc(clusters, mention_to_gold): 104 | tp, p = 0, 0 105 | for c in clusters: 106 | p += len(c) - 1 107 | tp += len(c) 108 | linked = set() 109 | for m in c: 110 | if m in mention_to_gold: 111 | linked.add(mention_to_gold[m]) 112 | else: 113 | tp -= 1 114 | tp -= len(linked) 115 | return tp, p 116 | 117 | 118 | def phi4(c1, c2): 119 | return 2 * len([m for m in c1 if m in c2]) / float(len(c1) + len(c2)) 120 | 121 | 122 | def ceafe(clusters, gold_clusters): 123 | clusters = [c for c in clusters if len(c) != 1] 124 | scores = np.zeros((len(gold_clusters), len(clusters))) 125 | for i in range(len(gold_clusters)): 126 | for j in range(len(clusters)): 127 | scores[i, j] = phi4(gold_clusters[i], clusters[j]) 128 | matching = linear_sum_assignment(-scores) # Hungarian matching via scipy 129 | matching = np.transpose(np.asarray(matching)) # [k, 2] pairs, same layout as the old sklearn output 130 | similarity = sum(scores[matching[:, 0], matching[:, 1]]) 131 | 132 | return similarity, len(clusters), similarity, len(gold_clusters) 133 | 134 | 135 | def lea(clusters, mention_to_gold): 136 | num, dem = 0, 0 137 | 138 | for c in clusters: 139 | if len(c) == 1: 140 | continue 141 | 142 | common_links = 0 143 | all_links = len(c) * (len(c) - 1) / 2.0 144 | for i, m in enumerate(c): 145 | if m in mention_to_gold: 146 | for m2 in c[i + 1:]: 147 | if m2 in mention_to_gold and mention_to_gold[m] == mention_to_gold[m2]: 148 | common_links += 1 149 | 150 | num += len(c) * common_links / float(all_links) 151 | dem += len(c) 152 | 153 | return num, dem 154 |
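# --- Hedged usage sketch (an addition, not part of the original file) ---
# CorefEvaluator averages the F1 of MUC, B-cubed, and CEAF-e (the CoNLL
# average). A toy check with mentions as (start, end) token spans; the
# mention_to_* dicts map each mention to its cluster tuple, mirroring how
# update_evaluator() in model.py builds them:
if __name__ == '__main__':
    predicted = [((0, 0), (5, 5)), ((9, 9), (12, 12))]
    gold = [((0, 0), (5, 5), (12, 12))]
    mention_to_predicted = {m: c for c in predicted for m in c}
    mention_to_gold = {m: c for c in gold for m in c}
    evaluator = CorefEvaluator()
    evaluator.update(predicted, gold, mention_to_predicted, mention_to_gold)
    print('avg P/R/F1:', evaluator.get_prf())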
-------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from transformers import BertModel 4 | import util 5 | import logging 6 | from collections.abc import Iterable # collections.Iterable was removed in Python 3.10 7 | import numpy as np 8 | import torch.nn.init as init 9 | import higher_order as ho 10 | 11 | 12 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 13 | datefmt='%m/%d/%Y %H:%M:%S', 14 | level=logging.INFO) 15 | logger = logging.getLogger() 16 | 17 | 18 | class CorefModel(nn.Module): 19 | def __init__(self, config, device, num_genres=None): 20 | super().__init__() 21 | self.config = config 22 | self.device = device 23 | 24 | self.num_genres = num_genres if num_genres else len(config['genres']) 25 | self.max_seg_len = config['max_segment_len'] 26 | self.max_span_width = config['max_span_width'] 27 | assert config['loss_type'] in ['marginalized', 'hinge'] 28 | if config['coref_depth'] > 1 or config['higher_order'] == 'cluster_merging': 29 | assert config['fine_grained'] # Higher-order inference happens in the slow fine-grained scoring 30 | 31 | # Model 32 | self.dropout = nn.Dropout(p=config['dropout_rate']) 33 | self.bert = BertModel.from_pretrained(config['bert_pretrained_name_or_path']) 34 | 35 | self.bert_emb_size = self.bert.config.hidden_size 36 | self.span_emb_size = self.bert_emb_size * 3 37 | if config['use_features']: 38 | self.span_emb_size += config['feature_emb_size'] 39 | self.pair_emb_size = self.span_emb_size * 3 40 | if config['use_metadata']: 41 | self.pair_emb_size += 2 * config['feature_emb_size'] 42 | if config['use_features']: 43 | self.pair_emb_size += config['feature_emb_size'] 44 | if config['use_segment_distance']: 45 | self.pair_emb_size += config['feature_emb_size'] 46 | 47 | self.emb_span_width = self.make_embedding(self.max_span_width) if config['use_features'] else None 48 | self.emb_span_width_prior = self.make_embedding(self.max_span_width) if config['use_width_prior'] else None 49 | self.emb_antecedent_distance_prior = self.make_embedding(10) if config['use_distance_prior'] else None 50 | self.emb_genre = self.make_embedding(self.num_genres) 51 | self.emb_same_speaker = self.make_embedding(2) if config['use_metadata'] else None 52 | self.emb_segment_distance = self.make_embedding(config['max_training_sentences']) if config['use_segment_distance'] else None 53 | self.emb_top_antecedent_distance = self.make_embedding(10) 54 | self.emb_cluster_size = self.make_embedding(10) if config['higher_order'] == 'cluster_merging' else None 55 | 56 | self.mention_token_attn = self.make_ffnn(self.bert_emb_size, 0, output_size=1) if config['model_heads'] else None 57 | self.span_emb_score_ffnn = self.make_ffnn(self.span_emb_size, [config['ffnn_size']] * config['ffnn_depth'], output_size=1) 58 | self.span_width_score_ffnn = self.make_ffnn(config['feature_emb_size'], [config['ffnn_size']] * config['ffnn_depth'], output_size=1) if config['use_width_prior'] else None 59 | self.coarse_bilinear = self.make_ffnn(self.span_emb_size, 0, output_size=self.span_emb_size) 60 | self.antecedent_distance_score_ffnn = self.make_ffnn(config['feature_emb_size'], 0, output_size=1) if config['use_distance_prior'] else None 61 | self.coref_score_ffnn = self.make_ffnn(self.pair_emb_size, [config['ffnn_size']] * config['ffnn_depth'], output_size=1) if config['fine_grained'] else None 62 | 63 | self.gate_ffnn = self.make_ffnn(2 * self.span_emb_size, 0, output_size=self.span_emb_size) if config['coref_depth'] > 1 else None 64 | self.span_attn_ffnn = self.make_ffnn(self.span_emb_size, 0, output_size=1) if config['higher_order'] == 'span_clustering' else None 65 | self.cluster_score_ffnn = self.make_ffnn(3 * self.span_emb_size + config['feature_emb_size'], [config['cluster_ffnn_size']] * config['ffnn_depth'], output_size=1) if config['higher_order'] == 'cluster_merging' else None 66 | 67 | self.update_steps = 0 # Internal use for debug 68 | self.debug = True 69 | 70 | def make_embedding(self, dict_size, std=0.02): 71 | emb = nn.Embedding(dict_size, self.config['feature_emb_size']) 72 | init.normal_(emb.weight, std=std) 73 | return emb 74 | 75 | def make_linear(self, in_features, out_features, bias=True, 
std=0.02): 76 | linear = nn.Linear(in_features, out_features, bias) 77 | init.normal_(linear.weight, std=std) 78 | if bias: 79 | init.zeros_(linear.bias) 80 | return linear 81 | 82 | def make_ffnn(self, feat_size, hidden_size, output_size): 83 | if hidden_size is None or hidden_size == 0 or hidden_size == [] or hidden_size == [0]: 84 | return self.make_linear(feat_size, output_size) 85 | 86 | if not isinstance(hidden_size, Iterable): 87 | hidden_size = [hidden_size] 88 | ffnn = [self.make_linear(feat_size, hidden_size[0]), nn.ReLU(), self.dropout] 89 | for i in range(1, len(hidden_size)): 90 | ffnn += [self.make_linear(hidden_size[i-1], hidden_size[i]), nn.ReLU(), self.dropout] 91 | ffnn.append(self.make_linear(hidden_size[-1], output_size)) 92 | return nn.Sequential(*ffnn) 93 | 94 | def get_params(self, named=False): 95 | bert_based_param, task_param = [], [] 96 | for name, param in self.named_parameters(): 97 | if name.startswith('bert'): 98 | to_add = (name, param) if named else param 99 | bert_based_param.append(to_add) 100 | else: 101 | to_add = (name, param) if named else param 102 | task_param.append(to_add) 103 | return bert_based_param, task_param 104 | 105 | def forward(self, *input): 106 | return self.get_predictions_and_loss(*input) 107 | 108 | def get_predictions_and_loss(self, input_ids, input_mask, speaker_ids, sentence_len, genre, sentence_map, 109 | is_training, gold_starts=None, gold_ends=None, gold_mention_cluster_map=None): 110 | """ Model and input are already on the device """ 111 | device = self.device 112 | conf = self.config 113 | 114 | do_loss = False 115 | if gold_mention_cluster_map is not None: 116 | assert gold_starts is not None 117 | assert gold_ends is not None 118 | do_loss = True 119 | 120 | # Get token emb 121 | mention_doc, _ = self.bert(input_ids, attention_mask=input_mask) # [num seg, num max tokens, emb size] 122 | input_mask = input_mask.to(torch.bool) 123 | mention_doc = mention_doc[input_mask] 124 | speaker_ids = speaker_ids[input_mask] 125 | num_words = mention_doc.shape[0] 126 | 127 | # Get candidate span 128 | sentence_indices = sentence_map # [num tokens] 129 | candidate_starts = torch.unsqueeze(torch.arange(0, num_words, device=device), 1).repeat(1, self.max_span_width) 130 | candidate_ends = candidate_starts + torch.arange(0, self.max_span_width, device=device) 131 | candidate_start_sent_idx = sentence_indices[candidate_starts] 132 | candidate_end_sent_idx = sentence_indices[torch.min(candidate_ends, torch.tensor(num_words - 1, device=device))] 133 | candidate_mask = (candidate_ends < num_words) & (candidate_start_sent_idx == candidate_end_sent_idx) 134 | candidate_starts, candidate_ends = candidate_starts[candidate_mask], candidate_ends[candidate_mask] # [num valid candidates] 135 | num_candidates = candidate_starts.shape[0] 136 | 137 | # Get candidate labels 138 | if do_loss: 139 | same_start = (torch.unsqueeze(gold_starts, 1) == torch.unsqueeze(candidate_starts, 0)) 140 | same_end = (torch.unsqueeze(gold_ends, 1) == torch.unsqueeze(candidate_ends, 0)) 141 | same_span = (same_start & same_end).to(torch.long) 142 | candidate_labels = torch.matmul(torch.unsqueeze(gold_mention_cluster_map, 0).to(torch.float), same_span.to(torch.float)) 143 | candidate_labels = torch.squeeze(candidate_labels.to(torch.long), 0) # [num candidates]; non-gold span has label 0 144 | 145 | # Get span embedding 146 | span_start_emb, span_end_emb = mention_doc[candidate_starts], mention_doc[candidate_ends] 147 | candidate_emb_list = [span_start_emb, span_end_emb] 148 
| if conf['use_features']: 149 | candidate_width_idx = candidate_ends - candidate_starts 150 | candidate_width_emb = self.emb_span_width(candidate_width_idx) 151 | candidate_width_emb = self.dropout(candidate_width_emb) 152 | candidate_emb_list.append(candidate_width_emb) 153 | # Use attended head or avg token 154 | candidate_tokens = torch.unsqueeze(torch.arange(0, num_words, device=device), 0).repeat(num_candidates, 1) 155 | candidate_tokens_mask = (candidate_tokens >= torch.unsqueeze(candidate_starts, 1)) & (candidate_tokens <= torch.unsqueeze(candidate_ends, 1)) 156 | if conf['model_heads']: 157 | token_attn = torch.squeeze(self.mention_token_attn(mention_doc), 1) 158 | else: 159 | token_attn = torch.ones(num_words, dtype=torch.float, device=device) # Use avg if no attention 160 | candidate_tokens_attn_raw = torch.log(candidate_tokens_mask.to(torch.float)) + torch.unsqueeze(token_attn, 0) 161 | candidate_tokens_attn = nn.functional.softmax(candidate_tokens_attn_raw, dim=1) 162 | head_attn_emb = torch.matmul(candidate_tokens_attn, mention_doc) 163 | candidate_emb_list.append(head_attn_emb) 164 | candidate_span_emb = torch.cat(candidate_emb_list, dim=1) # [num candidates, new emb size] 165 | 166 | # Get span score 167 | candidate_mention_scores = torch.squeeze(self.span_emb_score_ffnn(candidate_span_emb), 1) 168 | if conf['use_width_prior']: 169 | width_score = torch.squeeze(self.span_width_score_ffnn(self.emb_span_width_prior.weight), 1) 170 | candidate_width_score = width_score[candidate_width_idx] 171 | candidate_mention_scores += candidate_width_score 172 | 173 | # Extract top spans 174 | candidate_idx_sorted_by_score = torch.argsort(candidate_mention_scores, descending=True).tolist() 175 | candidate_starts_cpu, candidate_ends_cpu = candidate_starts.tolist(), candidate_ends.tolist() 176 | num_top_spans = int(min(conf['max_num_extracted_spans'], conf['top_span_ratio'] * num_words)) 177 | selected_idx_cpu = self._extract_top_spans(candidate_idx_sorted_by_score, candidate_starts_cpu, candidate_ends_cpu, num_top_spans) 178 | assert len(selected_idx_cpu) == num_top_spans 179 | selected_idx = torch.tensor(selected_idx_cpu, device=device) 180 | top_span_starts, top_span_ends = candidate_starts[selected_idx], candidate_ends[selected_idx] 181 | top_span_emb = candidate_span_emb[selected_idx] 182 | top_span_cluster_ids = candidate_labels[selected_idx] if do_loss else None 183 | top_span_mention_scores = candidate_mention_scores[selected_idx] 184 | 185 | # Coarse pruning on each mention's antecedents 186 | max_top_antecedents = min(num_top_spans, conf['max_top_antecedents']) 187 | top_span_range = torch.arange(0, num_top_spans, device=device) 188 | antecedent_offsets = torch.unsqueeze(top_span_range, 1) - torch.unsqueeze(top_span_range, 0) 189 | antecedent_mask = (antecedent_offsets >= 1) 190 | pairwise_mention_score_sum = torch.unsqueeze(top_span_mention_scores, 1) + torch.unsqueeze(top_span_mention_scores, 0) 191 | source_span_emb = self.dropout(self.coarse_bilinear(top_span_emb)) 192 | target_span_emb = self.dropout(torch.transpose(top_span_emb, 0, 1)) 193 | pairwise_coref_scores = torch.matmul(source_span_emb, target_span_emb) 194 | pairwise_fast_scores = pairwise_mention_score_sum + pairwise_coref_scores 195 | pairwise_fast_scores += torch.log(antecedent_mask.to(torch.float)) 196 | if conf['use_distance_prior']: 197 | distance_score = torch.squeeze(self.antecedent_distance_score_ffnn(self.dropout(self.emb_antecedent_distance_prior.weight)), 1) 198 | bucketed_distance = 
util.bucket_distance(antecedent_offsets) 199 | antecedent_distance_score = distance_score[bucketed_distance] 200 | pairwise_fast_scores += antecedent_distance_score 201 | top_pairwise_fast_scores, top_antecedent_idx = torch.topk(pairwise_fast_scores, k=max_top_antecedents) 202 | top_antecedent_mask = util.batch_select(antecedent_mask, top_antecedent_idx, device) # [num top spans, max top antecedents] 203 | top_antecedent_offsets = util.batch_select(antecedent_offsets, top_antecedent_idx, device) 204 | 205 | # Slow mention ranking 206 | if conf['fine_grained']: 207 | same_speaker_emb, genre_emb, seg_distance_emb, top_antecedent_distance_emb = None, None, None, None 208 | if conf['use_metadata']: 209 | top_span_speaker_ids = speaker_ids[top_span_starts] 210 | top_antecedent_speaker_id = top_span_speaker_ids[top_antecedent_idx] 211 | same_speaker = torch.unsqueeze(top_span_speaker_ids, 1) == top_antecedent_speaker_id 212 | same_speaker_emb = self.emb_same_speaker(same_speaker.to(torch.long)) 213 | genre_emb = self.emb_genre(genre) 214 | genre_emb = torch.unsqueeze(torch.unsqueeze(genre_emb, 0), 0).repeat(num_top_spans, max_top_antecedents, 1) 215 | if conf['use_segment_distance']: 216 | num_segs, seg_len = input_ids.shape[0], input_ids.shape[1] 217 | token_seg_ids = torch.arange(0, num_segs, device=device).unsqueeze(1).repeat(1, seg_len) 218 | token_seg_ids = token_seg_ids[input_mask] 219 | top_span_seg_ids = token_seg_ids[top_span_starts] 220 | top_antecedent_seg_ids = token_seg_ids[top_span_starts[top_antecedent_idx]] 221 | top_antecedent_seg_distance = torch.unsqueeze(top_span_seg_ids, 1) - top_antecedent_seg_ids 222 | top_antecedent_seg_distance = torch.clamp(top_antecedent_seg_distance, 0, self.config['max_training_sentences'] - 1) 223 | seg_distance_emb = self.emb_segment_distance(top_antecedent_seg_distance) 224 | if conf['use_features']: # Antecedent distance 225 | top_antecedent_distance = util.bucket_distance(top_antecedent_offsets) 226 | top_antecedent_distance_emb = self.emb_top_antecedent_distance(top_antecedent_distance) 227 | 228 | for depth in range(conf['coref_depth']): 229 | top_antecedent_emb = top_span_emb[top_antecedent_idx] # [num top spans, max top antecedents, emb size] 230 | feature_list = [] 231 | if conf['use_metadata']: # speaker, genre 232 | feature_list.append(same_speaker_emb) 233 | feature_list.append(genre_emb) 234 | if conf['use_segment_distance']: 235 | feature_list.append(seg_distance_emb) 236 | if conf['use_features']: # Antecedent distance 237 | feature_list.append(top_antecedent_distance_emb) 238 | feature_emb = torch.cat(feature_list, dim=2) 239 | feature_emb = self.dropout(feature_emb) 240 | target_emb = torch.unsqueeze(top_span_emb, 1).repeat(1, max_top_antecedents, 1) 241 | similarity_emb = target_emb * top_antecedent_emb 242 | pair_emb = torch.cat([target_emb, top_antecedent_emb, similarity_emb, feature_emb], 2) 243 | top_pairwise_slow_scores = torch.squeeze(self.coref_score_ffnn(pair_emb), 2) 244 | top_pairwise_scores = top_pairwise_slow_scores + top_pairwise_fast_scores 245 | if conf['higher_order'] == 'cluster_merging': 246 | cluster_merging_scores = ho.cluster_merging(top_span_emb, top_antecedent_idx, top_pairwise_scores, self.emb_cluster_size, self.cluster_score_ffnn, None, self.dropout, 247 | device=device, reduce=conf['cluster_reduce'], easy_cluster_first=conf['easy_cluster_first']) 248 | break 249 | elif depth != conf['coref_depth'] - 1: 250 | if conf['higher_order'] == 'attended_antecedent': 251 | refined_span_emb = 
ho.attended_antecedent(top_span_emb, top_antecedent_emb, top_pairwise_scores, device) 252 | elif conf['higher_order'] == 'max_antecedent': 253 | refined_span_emb = ho.max_antecedent(top_span_emb, top_antecedent_emb, top_pairwise_scores, device) 254 | elif conf['higher_order'] == 'entity_equalization': 255 | refined_span_emb = ho.entity_equalization(top_span_emb, top_antecedent_emb, top_antecedent_idx, top_pairwise_scores, device) 256 | elif conf['higher_order'] == 'span_clustering': 257 | refined_span_emb = ho.span_clustering(top_span_emb, top_antecedent_idx, top_pairwise_scores, self.span_attn_ffnn, device) 258 | 259 | gate = self.gate_ffnn(torch.cat([top_span_emb, refined_span_emb], dim=1)) 260 | gate = torch.sigmoid(gate) 261 | top_span_emb = gate * refined_span_emb + (1 - gate) * top_span_emb # [num top spans, span emb size] 262 | else: 263 | top_pairwise_scores = top_pairwise_fast_scores # [num top spans, max top antecedents] 264 | 265 | if not do_loss: 266 | if conf['fine_grained'] and conf['higher_order'] == 'cluster_merging': 267 | top_pairwise_scores += cluster_merging_scores 268 | top_antecedent_scores = torch.cat([torch.zeros(num_top_spans, 1, device=device), top_pairwise_scores], dim=1) # [num top spans, max top antecedents + 1] 269 | return candidate_starts, candidate_ends, candidate_mention_scores, top_span_starts, top_span_ends, top_antecedent_idx, top_antecedent_scores 270 | 271 | # Get gold labels 272 | top_antecedent_cluster_ids = top_span_cluster_ids[top_antecedent_idx] 273 | top_antecedent_cluster_ids += (top_antecedent_mask.to(torch.long) - 1) * 100000 # Mask id on invalid antecedents 274 | same_gold_cluster_indicator = (top_antecedent_cluster_ids == torch.unsqueeze(top_span_cluster_ids, 1)) 275 | non_dummy_indicator = torch.unsqueeze(top_span_cluster_ids > 0, 1) 276 | pairwise_labels = same_gold_cluster_indicator & non_dummy_indicator 277 | dummy_antecedent_labels = torch.logical_not(pairwise_labels.any(dim=1, keepdims=True)) 278 | top_antecedent_gold_labels = torch.cat([dummy_antecedent_labels, pairwise_labels], dim=1) 279 | 280 | # Get loss 281 | top_antecedent_scores = torch.cat([torch.zeros(num_top_spans, 1, device=device), top_pairwise_scores], dim=1) 282 | if conf['loss_type'] == 'marginalized': 283 | log_marginalized_antecedent_scores = torch.logsumexp(top_antecedent_scores + torch.log(top_antecedent_gold_labels.to(torch.float)), dim=1) 284 | log_norm = torch.logsumexp(top_antecedent_scores, dim=1) 285 | loss = torch.sum(log_norm - log_marginalized_antecedent_scores) 286 | elif conf['loss_type'] == 'hinge': 287 | top_antecedent_mask = torch.cat([torch.ones(num_top_spans, 1, dtype=torch.bool, device=device), top_antecedent_mask], dim=1) 288 | top_antecedent_scores += torch.log(top_antecedent_mask.to(torch.float)) 289 | highest_antecedent_scores, highest_antecedent_idx = torch.max(top_antecedent_scores, dim=1) 290 | gold_antecedent_scores = top_antecedent_scores + torch.log(top_antecedent_gold_labels.to(torch.float)) 291 | highest_gold_antecedent_scores, highest_gold_antecedent_idx = torch.max(gold_antecedent_scores, dim=1) 292 | slack_hinge = 1 + highest_antecedent_scores - highest_gold_antecedent_scores 293 | # Calculate delta 294 | highest_antecedent_is_gold = (highest_antecedent_idx == highest_gold_antecedent_idx) 295 | mistake_false_new = (highest_antecedent_idx == 0) & torch.logical_not(dummy_antecedent_labels.squeeze()) 296 | delta = ((3 - conf['false_new_delta']) / 2) * torch.ones(num_top_spans, dtype=torch.float, device=device) 297 | delta -= (1 - 
conf['false_new_delta']) * mistake_false_new.to(torch.float) 298 | delta *= torch.logical_not(highest_antecedent_is_gold).to(torch.float) 299 | loss = torch.sum(slack_hinge * delta) 300 | 301 | # Add mention loss 302 | if conf['mention_loss_coef']: 303 | gold_mention_scores = top_span_mention_scores[top_span_cluster_ids > 0] 304 | non_gold_mention_scores = top_span_mention_scores[top_span_cluster_ids == 0] 305 | loss_mention = -torch.sum(torch.log(torch.sigmoid(gold_mention_scores))) * conf['mention_loss_coef'] 306 | loss_mention += -torch.sum(torch.log(1 - torch.sigmoid(non_gold_mention_scores))) * conf['mention_loss_coef'] 307 | loss += loss_mention 308 | 309 | if conf['higher_order'] == 'cluster_merging': 310 | top_pairwise_scores += cluster_merging_scores 311 | top_antecedent_scores = torch.cat([torch.zeros(num_top_spans, 1, device=device), top_pairwise_scores], dim=1) 312 | log_marginalized_antecedent_scores2 = torch.logsumexp(top_antecedent_scores + torch.log(top_antecedent_gold_labels.to(torch.float)), dim=1) 313 | log_norm2 = torch.logsumexp(top_antecedent_scores, dim=1) # [num top spans] 314 | loss_cm = torch.sum(log_norm2 - log_marginalized_antecedent_scores2) 315 | if conf['cluster_dloss']: 316 | loss += loss_cm 317 | else: 318 | loss = loss_cm 319 | 320 | # Debug 321 | if self.debug: 322 | if self.update_steps % 20 == 0: 323 | logger.info('---------debug step: %d---------' % self.update_steps) 324 | # logger.info('candidates: %d; antecedents: %d' % (num_candidates, max_top_antecedents)) 325 | logger.info('spans/gold: %d/%d; ratio: %.2f' % (num_top_spans, (top_span_cluster_ids > 0).sum(), (top_span_cluster_ids > 0).sum()/num_top_spans)) 326 | if conf['mention_loss_coef']: 327 | logger.info('mention loss: %.4f' % loss_mention) 328 | if conf['loss_type'] == 'marginalized': 329 | logger.info('norm/gold: %.4f/%.4f' % (torch.sum(log_norm), torch.sum(log_marginalized_antecedent_scores))) 330 | else: 331 | logger.info('loss: %.4f' % loss) 332 | self.update_steps += 1 333 | 334 | return [candidate_starts, candidate_ends, candidate_mention_scores, top_span_starts, top_span_ends, top_antecedent_idx, top_antecedent_scores], loss 335 | 336 | def _extract_top_spans(self, candidate_idx_sorted, candidate_starts, candidate_ends, num_top_spans): 337 | """ Keep top non-cross-overlapping candidates ordered by scores; compute on CPU because of loop """ 338 | selected_candidate_idx = [] 339 | start_to_max_end, end_to_min_start = {}, {} 340 | for candidate_idx in candidate_idx_sorted: 341 | if len(selected_candidate_idx) >= num_top_spans: 342 | break 343 | # Perform overlapping check 344 | span_start_idx = candidate_starts[candidate_idx] 345 | span_end_idx = candidate_ends[candidate_idx] 346 | cross_overlap = False 347 | for token_idx in range(span_start_idx, span_end_idx + 1): 348 | max_end = start_to_max_end.get(token_idx, -1) 349 | if token_idx > span_start_idx and max_end > span_end_idx: 350 | cross_overlap = True 351 | break 352 | min_start = end_to_min_start.get(token_idx, -1) 353 | if token_idx < span_end_idx and 0 <= min_start < span_start_idx: 354 | cross_overlap = True 355 | break 356 | if not cross_overlap: 357 | # Pass check; select idx and update dict stats 358 | selected_candidate_idx.append(candidate_idx) 359 | max_end = start_to_max_end.get(span_start_idx, -1) 360 | if span_end_idx > max_end: 361 | start_to_max_end[span_start_idx] = span_end_idx 362 | min_start = end_to_min_start.get(span_end_idx, -1) 363 | if min_start == -1 or span_start_idx < min_start: 364 | 
end_to_min_start[span_end_idx] = span_start_idx 365 | # Sort selected candidates by span idx 366 | selected_candidate_idx = sorted(selected_candidate_idx, key=lambda idx: (candidate_starts[idx], candidate_ends[idx])) 367 | if len(selected_candidate_idx) < num_top_spans: # Padding 368 | selected_candidate_idx += ([selected_candidate_idx[0]] * (num_top_spans - len(selected_candidate_idx))) 369 | return selected_candidate_idx 370 | 371 | def get_predicted_antecedents(self, antecedent_idx, antecedent_scores): 372 | """ CPU list input """ 373 | predicted_antecedents = [] 374 | for i, idx in enumerate(np.argmax(antecedent_scores, axis=1) - 1): 375 | if idx < 0: 376 | predicted_antecedents.append(-1) 377 | else: 378 | predicted_antecedents.append(antecedent_idx[i][idx]) 379 | return predicted_antecedents 380 | 381 | def get_predicted_clusters(self, span_starts, span_ends, antecedent_idx, antecedent_scores): 382 | """ CPU list input """ 383 | # Get predicted antecedents 384 | predicted_antecedents = self.get_predicted_antecedents(antecedent_idx, antecedent_scores) 385 | 386 | # Get predicted clusters 387 | mention_to_cluster_id = {} 388 | predicted_clusters = [] 389 | for i, predicted_idx in enumerate(predicted_antecedents): 390 | if predicted_idx < 0: 391 | continue 392 | assert i > predicted_idx, f'span idx: {i}; antecedent idx: {predicted_idx}' 393 | # Check antecedent's cluster 394 | antecedent = (int(span_starts[predicted_idx]), int(span_ends[predicted_idx])) 395 | antecedent_cluster_id = mention_to_cluster_id.get(antecedent, -1) 396 | if antecedent_cluster_id == -1: 397 | antecedent_cluster_id = len(predicted_clusters) 398 | predicted_clusters.append([antecedent]) 399 | mention_to_cluster_id[antecedent] = antecedent_cluster_id 400 | # Add mention to cluster 401 | mention = (int(span_starts[i]), int(span_ends[i])) 402 | predicted_clusters[antecedent_cluster_id].append(mention) 403 | mention_to_cluster_id[mention] = antecedent_cluster_id 404 | 405 | predicted_clusters = [tuple(c) for c in predicted_clusters] 406 | return predicted_clusters, mention_to_cluster_id, predicted_antecedents 407 | 408 | def update_evaluator(self, span_starts, span_ends, antecedent_idx, antecedent_scores, gold_clusters, evaluator): 409 | predicted_clusters, mention_to_cluster_id, _ = self.get_predicted_clusters(span_starts, span_ends, antecedent_idx, antecedent_scores) 410 | mention_to_predicted = {m: predicted_clusters[cluster_idx] for m, cluster_idx in mention_to_cluster_id.items()} 411 | gold_clusters = [tuple(tuple(m) for m in cluster) for cluster in gold_clusters] 412 | mention_to_gold = {m: cluster for cluster in gold_clusters for m in cluster} 413 | evaluator.update(predicted_clusters, gold_clusters, mention_to_predicted, mention_to_gold) 414 | return predicted_clusters 415 | -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | import json 2 | from spacy.lang.en import English 3 | from preprocess import get_document 4 | import argparse 5 | import util 6 | from tensorize import CorefDataProcessor 7 | from run import Runner 8 | import logging 9 | logging.getLogger().setLevel(logging.CRITICAL) 10 | 11 | 12 | def create_spacy_tokenizer(): 13 | nlp = English() 14 | sentencizer = nlp.create_pipe('sentencizer') 15 | nlp.add_pipe(sentencizer) 16 | return nlp # was missing; the helper silently returned None 17 | 18 | def get_document_from_string(string, seg_len, bert_tokenizer, spacy_tokenizer, genre='nw'): 19 | doc_key = genre # See genres in experiments.conf
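# Each doc_line built below mimics a minimal CoNLL-2012 row: 12 columns
# where column 0 carries the genre/doc id and column 3 carries the word;
# preprocess.get_document() only needs the word column plus blank lines as
# sentence boundaries, so the remaining columns stay as '-' placeholders.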
20 | doc_lines = [] 21 | 22 | # Build doc_lines 23 | for token in spacy_tokenizer(string): 24 | cols = [genre] + ['-'] * 11 25 | cols[3] = token.text 26 | doc_lines.append('\t'.join(cols)) 27 | if token.is_sent_end: 28 | doc_lines.append('\n') 29 | 30 | doc = get_document(doc_key, doc_lines, 'english', seg_len, bert_tokenizer) 31 | return doc 32 | 33 | 34 | if __name__ == '__main__': 35 | parser = argparse.ArgumentParser() 36 | parser.add_argument('--config_name', type=str, required=True, 37 | help='Configuration name in experiments.conf') 38 | parser.add_argument('--model_identifier', type=str, required=True, 39 | help='Model identifier to load') 40 | parser.add_argument('--gpu_id', type=int, default=None, 41 | help='GPU id; CPU by default') 42 | parser.add_argument('--seg_len', type=int, default=512) 43 | parser.add_argument('--jsonlines_path', type=str, default=None, 44 | help='Path to custom input from file; input from console by default') 45 | parser.add_argument('--output_path', type=str, default=None, 46 | help='Path to save output') 47 | args = parser.parse_args() 48 | 49 | runner = Runner(args.config_name, args.gpu_id) 50 | model = runner.initialize_model(args.model_identifier) 51 | data_processor = CorefDataProcessor(runner.config) 52 | 53 | if args.jsonlines_path: 54 | # Input from file 55 | with open(args.jsonlines_path, 'r') as f: 56 | lines = f.readlines() 57 | docs = [json.loads(line) for line in lines] 58 | tensor_examples, stored_info = data_processor.get_tensor_examples_from_custom_input(docs) 59 | predicted_clusters, _, _ = runner.predict(model, tensor_examples) 60 | 61 | if args.output_path: 62 | with open(args.output_path, 'w') as f: 63 | for i, doc in enumerate(docs): 64 | doc['predicted_clusters'] = predicted_clusters[i] 65 | f.write(json.dumps(doc) + '\n') # one JSON document per line, matching the jsonlines input 66 | print(f'Saved prediction in {args.output_path}') 67 | else: 68 | # Interactive input 69 | model.to(model.device) 70 | nlp = English() 71 | nlp.add_pipe(nlp.create_pipe('sentencizer')) 72 | while True: 73 | input_str = input('Input document:') 74 | bert_tokenizer, spacy_tokenizer = data_processor.tokenizer, nlp 75 | doc = get_document_from_string(input_str, args.seg_len, bert_tokenizer, spacy_tokenizer) 76 | tensor_examples, stored_info = data_processor.get_tensor_examples_from_custom_input([doc]) 77 | predicted_clusters, _, _ = runner.predict(model, tensor_examples) 78 | 79 | subtokens = util.flatten(doc['sentences']) 80 | print('---Predicted clusters:') 81 | for cluster in predicted_clusters[0]: 82 | mentions_str = [' '.join(subtokens[m[0]:m[1]+1]) for m in cluster] 83 | mentions_str = [m.replace(' ##', '') for m in mentions_str] 84 | mentions_str = [m.replace('##', '') for m in mentions_str] 85 | print(mentions_str) # Print out strings 86 | # print(cluster) # Print out indices 87 | -------------------------------------------------------------------------------- /preprocess.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import os 4 | import re 5 | import collections 6 | import json 7 | import conll 8 | import util 9 | 10 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 11 | datefmt='%m/%d/%Y %H:%M:%S', level=logging.INFO) 12 | logger = logging.getLogger(__name__) 13 | 14 | 15 | def skip_doc(doc_key): 16 | return False 17 | 18 | 19 | def normalize_word(word, language): 20 | if language == "arabic": 21 | word = word.split("#")[0] # keep text before '#'; the old find()-slice chopped the last char when '#' was absent 22 | if word == "/."
or word == "/?": 23 | return word[1:] 24 | else: 25 | return word 26 | 27 | 28 | def get_sentence_map(segments, sentence_end): 29 | assert len(sentence_end) == sum([len(seg) - 2 for seg in segments]) # of subtokens in all segments 30 | sent_map = [] 31 | sent_idx, subtok_idx = 0, 0 32 | for segment in segments: 33 | sent_map.append(sent_idx) # [CLS] 34 | for i in range(len(segment) - 2): 35 | sent_map.append(sent_idx) 36 | sent_idx += int(sentence_end[subtok_idx]) 37 | subtok_idx += 1 38 | sent_map.append(sent_idx) # [SEP] 39 | return sent_map 40 | 41 | 42 | class DocumentState(object): 43 | def __init__(self, key): 44 | self.doc_key = key 45 | self.tokens = [] 46 | 47 | # Linear list mapped to subtokens without CLS, SEP 48 | self.subtokens = [] 49 | self.subtoken_map = [] 50 | self.token_end = [] 51 | self.sentence_end = [] 52 | self.info = [] # Only non-none for the first subtoken of each word 53 | 54 | # Linear list mapped to subtokens with CLS, SEP 55 | self.sentence_map = [] 56 | 57 | # Segments (mapped to subtokens with CLS, SEP) 58 | self.segments = [] 59 | self.segment_subtoken_map = [] 60 | self.segment_info = [] # Only non-none for the first subtoken of each word 61 | self.speakers = [] 62 | 63 | # Doc-level attributes 64 | self.pronouns = [] 65 | self.clusters = collections.defaultdict(list) # {cluster_id: [(first_subtok_idx, last_subtok_idx) for each mention]} 66 | self.coref_stacks = collections.defaultdict(list) 67 | 68 | def finalize(self): 69 | """ Extract clusters; fill other info e.g. speakers, pronouns """ 70 | # Populate speakers from info 71 | subtoken_idx = 0 72 | for seg_info in self.segment_info: 73 | speakers = [] 74 | for i, subtoken_info in enumerate(seg_info): 75 | if i == 0 or i == len(seg_info) - 1: 76 | speakers.append('[SPL]') 77 | elif subtoken_info is not None: # First subtoken of each word 78 | speakers.append(subtoken_info[9]) 79 | # if subtoken_info[4] == 'PRP': # Uncomment if needed 80 | # self.pronouns.append(subtoken_idx) 81 | else: 82 | speakers.append(speakers[-1]) 83 | subtoken_idx += 1 84 | self.speakers += [speakers] 85 | 86 | # Populate cluster 87 | first_subtoken_idx = 0 # Subtoken idx across segments 88 | subtokens_info = util.flatten(self.segment_info) 89 | while first_subtoken_idx < len(subtokens_info): 90 | subtoken_info = subtokens_info[first_subtoken_idx] 91 | coref = subtoken_info[-2] if subtoken_info is not None else '-' 92 | if coref != '-': 93 | last_subtoken_idx = first_subtoken_idx + subtoken_info[-1] - 1 94 | for part in coref.split('|'): 95 | if part[0] == '(': 96 | if part[-1] == ')': 97 | cluster_id = int(part[1:-1]) 98 | self.clusters[cluster_id].append((first_subtoken_idx, last_subtoken_idx)) 99 | else: 100 | cluster_id = int(part[1:]) 101 | self.coref_stacks[cluster_id].append(first_subtoken_idx) 102 | else: 103 | cluster_id = int(part[:-1]) 104 | start = self.coref_stacks[cluster_id].pop() 105 | self.clusters[cluster_id].append((start, last_subtoken_idx)) 106 | first_subtoken_idx += 1 107 | 108 | # Merge clusters if any clusters have common mentions 109 | merged_clusters = [] 110 | for cluster in self.clusters.values(): 111 | existing = None 112 | for mention in cluster: 113 | for merged_cluster in merged_clusters: 114 | if mention in merged_cluster: 115 | existing = merged_cluster 116 | break 117 | if existing is not None: 118 | break 119 | if existing is not None: 120 | print("Merging clusters (shouldn't happen very often)") 121 | existing.update(cluster) 122 | else: 123 | merged_clusters.append(set(cluster)) 124 | 125 | 
merged_clusters = [list(cluster) for cluster in merged_clusters] 126 | all_mentions = util.flatten(merged_clusters) 127 | sentence_map = get_sentence_map(self.segments, self.sentence_end) 128 | subtoken_map = util.flatten(self.segment_subtoken_map) 129 | 130 | # Sanity check 131 | assert len(all_mentions) == len(set(all_mentions)) # Each mention unique 132 | # Below should have length: # all subtokens with CLS, SEP in all segments 133 | num_all_seg_tokens = len(util.flatten(self.segments)) 134 | assert num_all_seg_tokens == len(util.flatten(self.speakers)) 135 | assert num_all_seg_tokens == len(subtoken_map) 136 | assert num_all_seg_tokens == len(sentence_map) 137 | 138 | return { 139 | "doc_key": self.doc_key, 140 | "tokens": self.tokens, 141 | "sentences": self.segments, 142 | "speakers": self.speakers, 143 | "constituents": [], 144 | "ner": [], 145 | "clusters": merged_clusters, 146 | 'sentence_map': sentence_map, 147 | "subtoken_map": subtoken_map, 148 | 'pronouns': self.pronouns 149 | } 150 | 151 | 152 | def split_into_segments(document_state: DocumentState, max_seg_len, constraints1, constraints2, tokenizer): 153 | """ Split into segments. 154 | Add subtokens, subtoken_map, info for each segment; add CLS, SEP in the segment subtokens 155 | Input document_state: tokens, subtokens, token_end, sentence_end, utterance_end, subtoken_map, info 156 | """ 157 | curr_idx = 0 # Index for subtokens 158 | prev_token_idx = 0 159 | while curr_idx < len(document_state.subtokens): 160 | # Try to split at a sentence end point 161 | end_idx = min(curr_idx + max_seg_len - 1 - 2, len(document_state.subtokens) - 1) # Inclusive 162 | while end_idx >= curr_idx and not constraints1[end_idx]: 163 | end_idx -= 1 164 | if end_idx < curr_idx: 165 | logger.info(f'{document_state.doc_key}: no sentence end found; split at token end') 166 | # If no sentence end point, try to split at token end point 167 | end_idx = min(curr_idx + max_seg_len - 1 - 2, len(document_state.subtokens) - 1) 168 | while end_idx >= curr_idx and not constraints2[end_idx]: 169 | end_idx -= 1 170 | if end_idx < curr_idx: 171 | logger.error('Cannot split valid segment: no sentence end or token end') 172 | 173 | segment = [tokenizer.cls_token] + document_state.subtokens[curr_idx: end_idx + 1] + [tokenizer.sep_token] 174 | document_state.segments.append(segment) 175 | 176 | subtoken_map = document_state.subtoken_map[curr_idx: end_idx + 1] 177 | document_state.segment_subtoken_map.append([prev_token_idx] + subtoken_map + [subtoken_map[-1]]) 178 | 179 | document_state.segment_info.append([None] + document_state.info[curr_idx: end_idx + 1] + [None]) 180 | 181 | curr_idx = end_idx + 1 182 | prev_token_idx = subtoken_map[-1] 183 | 184 | 185 | def get_document(doc_key, doc_lines, language, seg_len, tokenizer): 186 | """ Process raw input to finalized documents """ 187 | document_state = DocumentState(doc_key) 188 | word_idx = -1 189 | 190 | # Build up documents 191 | for line in doc_lines: 192 | row = line.split() # Columns for each token 193 | if len(row) == 0: 194 | document_state.sentence_end[-1] = True 195 | else: 196 | assert len(row) >= 12 197 | word_idx += 1 198 | word = normalize_word(row[3], language) 199 | subtokens = tokenizer.tokenize(word) 200 | document_state.tokens.append(word) 201 | document_state.token_end += [False] * (len(subtokens) - 1) + [True] 202 | for idx, subtoken in enumerate(subtokens): 203 | document_state.subtokens.append(subtoken) 204 | info = None if idx != 0 else (row + [len(subtokens)]) 205 | 
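# 'info' keeps the full CoNLL row (plus the word's subtoken count) only on
# the first subtoken of each word; later subtokens get None, which
# finalize() relies on to locate word-initial positions for speakers and
# coref span boundaries.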
document_state.info.append(info) 206 | document_state.sentence_end.append(False) 207 | document_state.subtoken_map.append(word_idx) 208 | 209 | # Split documents 210 | constraints1 = document_state.sentence_end if language != 'arabic' else document_state.token_end 211 | split_into_segments(document_state, seg_len, constraints1, document_state.token_end, tokenizer) 212 | document = document_state.finalize() 213 | return document 214 | 215 | 216 | def minimize_partition(partition, extension, args, tokenizer): 217 | input_path = os.path.join(args.input_dir, f'{partition}.{args.language}.{extension}') 218 | output_path = os.path.join(args.output_dir, f'{partition}.{args.language}.{args.seg_len}.jsonlines') 219 | doc_count = 0 220 | logger.info(f'Minimizing {input_path}...') 221 | 222 | # Read documents 223 | documents = [] # [(doc_key, lines)] 224 | with open(input_path, 'r') as input_file: 225 | for line in input_file.readlines(): 226 | begin_document_match = re.match(conll.BEGIN_DOCUMENT_REGEX, line) 227 | if begin_document_match: 228 | doc_key = conll.get_doc_key(begin_document_match.group(1), begin_document_match.group(2)) 229 | documents.append((doc_key, [])) 230 | elif line.startswith('#end document'): 231 | continue 232 | else: 233 | documents[-1][1].append(line) 234 | 235 | # Write documents 236 | with open(output_path, 'w') as output_file: 237 | for doc_key, doc_lines in documents: 238 | if skip_doc(doc_key): 239 | continue 240 | document = get_document(doc_key, doc_lines, args.language, args.seg_len, tokenizer) 241 | output_file.write(json.dumps(document)) 242 | output_file.write('\n') 243 | doc_count += 1 244 | logger.info(f'Processed {doc_count} documents to {output_path}') 245 | 246 | 247 | def minimize_language(args): 248 | tokenizer = util.get_tokenizer(args.tokenizer_name) 249 | 250 | minimize_partition('dev', 'v4_gold_conll', args, tokenizer) 251 | minimize_partition('test', 'v4_gold_conll', args, tokenizer) 252 | minimize_partition('train', 'v4_gold_conll', args, tokenizer) 253 | 254 | 255 | if __name__ == '__main__': 256 | parser = argparse.ArgumentParser() 257 | parser.add_argument('--tokenizer_name', type=str, default='bert-base-cased', 258 | help='Name or path of the tokenizer/vocabulary') 259 | parser.add_argument('--input_dir', type=str, required=True, 260 | help='Input directory that contains conll files') 261 | parser.add_argument('--output_dir', type=str, required=True, 262 | help='Output directory') 263 | parser.add_argument('--seg_len', type=int, default=128, 264 | help='Segment length: 128, 256, 384, 512') 265 | parser.add_argument('--language', type=str, default='english', 266 | help='english, chinese, arabic') 267 | # parser.add_argument('--lower_case', action='store_true', 268 | # help='Do lower case on input') 269 | 270 | args = parser.parse_args() 271 | logger.info(args) 272 | os.makedirs(args.output_dir, exist_ok=True) 273 | 274 | minimize_language(args) 275 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.4.0 2 | torchvision==0.5.0 3 | transformers==2.4.1 4 | numpy 5 | scikit-learn==0.22.1 6 | pyhocon 7 | graphviz 8 | tensorboard 9 | spacy -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import random 3 | import numpy as np 4 | import torch 5 | from torch.utils.tensorboard 
import SummaryWriter 6 | from transformers import AdamW 7 | from torch.optim import Adam 8 | from tensorize import CorefDataProcessor 9 | import util 10 | import time 11 | from os.path import join 12 | from metrics import CorefEvaluator 13 | from datetime import datetime 14 | from torch.optim.lr_scheduler import LambdaLR 15 | from model import CorefModel 16 | import conll 17 | import sys 18 | 19 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 20 | datefmt='%m/%d/%Y %H:%M:%S', 21 | level=logging.INFO) 22 | logger = logging.getLogger() 23 | 24 | 25 | class Runner: 26 | def __init__(self, config_name, gpu_id=0, seed=None): 27 | self.name = config_name 28 | self.name_suffix = datetime.now().strftime('%b%d_%H-%M-%S') 29 | self.gpu_id = gpu_id 30 | self.seed = seed 31 | 32 | # Set up config 33 | self.config = util.initialize_config(config_name) 34 | 35 | # Set up logger 36 | log_path = join(self.config['log_dir'], 'log_' + self.name_suffix + '.txt') 37 | logger.addHandler(logging.FileHandler(log_path, 'a')) 38 | logger.info('Log file path: %s' % log_path) 39 | 40 | # Set up seed 41 | if seed: 42 | util.set_seed(seed) 43 | 44 | # Set up device 45 | self.device = torch.device('cpu' if gpu_id is None else f'cuda:{gpu_id}') 46 | 47 | # Set up data 48 | self.data = CorefDataProcessor(self.config) 49 | 50 | def initialize_model(self, saved_suffix=None): 51 | model = CorefModel(self.config, self.device) 52 | if saved_suffix: 53 | self.load_model_checkpoint(model, saved_suffix) 54 | return model 55 | 56 | def train(self, model): 57 | conf = self.config 58 | logger.info(conf) 59 | epochs, grad_accum = conf['num_epochs'], conf['gradient_accumulation_steps'] 60 | 61 | model.to(self.device) 62 | logger.info('Model parameters:') 63 | for name, param in model.named_parameters(): 64 | logger.info('%s: %s' % (name, tuple(param.shape))) 65 | 66 | # Set up tensorboard 67 | tb_path = join(conf['tb_dir'], self.name + '_' + self.name_suffix) 68 | tb_writer = SummaryWriter(tb_path, flush_secs=30) 69 | logger.info('Tensorboard summary path: %s' % tb_path) 70 | 71 | # Set up data 72 | examples_train, examples_dev, examples_test = self.data.get_tensor_examples() 73 | stored_info = self.data.get_stored_info() 74 | 75 | # Set up optimizer and scheduler 76 | total_update_steps = len(examples_train) * epochs // grad_accum 77 | optimizers = self.get_optimizer(model) 78 | schedulers = self.get_scheduler(optimizers, total_update_steps) 79 | 80 | # Get model parameters for grad clipping 81 | bert_param, task_param = model.get_params() 82 | 83 | # Start training 84 | logger.info('*******************Training*******************') 85 | logger.info('Num samples: %d' % len(examples_train)) 86 | logger.info('Num epochs: %d' % epochs) 87 | logger.info('Gradient accumulation steps: %d' % grad_accum) 88 | logger.info('Total update steps: %d' % total_update_steps) 89 | 90 | loss_during_accum = [] # To compute effective loss at each update 91 | loss_during_report = 0.0 # Effective loss during logging step 92 | loss_history = [] # Full history of effective loss; length equals total update steps 93 | max_f1 = 0 94 | start_time = time.time() 95 | model.zero_grad() 96 | for epo in range(epochs): 97 | random.shuffle(examples_train) # Shuffle training set 98 | for doc_key, example in examples_train: 99 | # Forward pass 100 | model.train() 101 | example_gpu = [d.to(self.device) for d in example] 102 | _, loss = model(*example_gpu) 103 | 104 | # Backward; accumulate gradients and clip by grad norm 105 | if 
grad_accum > 1: 106 | loss /= grad_accum 107 | loss.backward() 108 | if conf['max_grad_norm']: 109 | torch.nn.utils.clip_grad_norm_(bert_param, conf['max_grad_norm']) 110 | torch.nn.utils.clip_grad_norm_(task_param, conf['max_grad_norm']) 111 | loss_during_accum.append(loss.item()) 112 | 113 | # Update 114 | if len(loss_during_accum) % grad_accum == 0: 115 | for optimizer in optimizers: 116 | optimizer.step() 117 | model.zero_grad() 118 | for scheduler in schedulers: 119 | scheduler.step() 120 | 121 | # Compute effective loss 122 | effective_loss = np.sum(loss_during_accum).item() 123 | loss_during_accum = [] 124 | loss_during_report += effective_loss 125 | loss_history.append(effective_loss) 126 | 127 | # Report 128 | if len(loss_history) % conf['report_frequency'] == 0: 129 | # Show avg loss during last report interval 130 | avg_loss = loss_during_report / conf['report_frequency'] 131 | loss_during_report = 0.0 132 | end_time = time.time() 133 | logger.info('Step %d: avg loss %.2f; steps/sec %.2f' % 134 | (len(loss_history), avg_loss, conf['report_frequency'] / (end_time - start_time))) 135 | start_time = end_time 136 | 137 | tb_writer.add_scalar('Training_Loss', avg_loss, len(loss_history)) 138 | tb_writer.add_scalar('Learning_Rate_Bert', schedulers[0].get_last_lr()[0], len(loss_history)) 139 | tb_writer.add_scalar('Learning_Rate_Task', schedulers[1].get_last_lr()[-1], len(loss_history)) 140 | 141 | # Evaluate 142 | if len(loss_history) > 0 and len(loss_history) % conf['eval_frequency'] == 0: 143 | f1, _ = self.evaluate(model, examples_dev, stored_info, len(loss_history), official=False, conll_path=self.config['conll_eval_path'], tb_writer=tb_writer) 144 | if f1 > max_f1: 145 | max_f1 = f1 146 | self.save_model_checkpoint(model, len(loss_history)) 147 | logger.info('Eval max f1: %.2f' % max_f1) 148 | start_time = time.time() 149 | 150 | logger.info('**********Finished training**********') 151 | logger.info('Actual update steps: %d' % len(loss_history)) 152 | 153 | # Wrap up 154 | tb_writer.close() 155 | return loss_history 156 | 157 | def evaluate(self, model, tensor_examples, stored_info, step, official=False, conll_path=None, tb_writer=None): 158 | logger.info('Step %d: evaluating on %d samples...' 
% (step, len(tensor_examples))) 159 | model.to(self.device) 160 | evaluator = CorefEvaluator() 161 | doc_to_prediction = {} 162 | 163 | model.eval() 164 | for i, (doc_key, tensor_example) in enumerate(tensor_examples): 165 | gold_clusters = stored_info['gold'][doc_key] 166 | tensor_example = tensor_example[:7] # Strip out gold 167 | example_gpu = [d.to(self.device) for d in tensor_example] 168 | with torch.no_grad(): 169 | _, _, _, span_starts, span_ends, antecedent_idx, antecedent_scores = model(*example_gpu) 170 | span_starts, span_ends = span_starts.tolist(), span_ends.tolist() 171 | antecedent_idx, antecedent_scores = antecedent_idx.tolist(), antecedent_scores.tolist() 172 | predicted_clusters = model.update_evaluator(span_starts, span_ends, antecedent_idx, antecedent_scores, gold_clusters, evaluator) 173 | doc_to_prediction[doc_key] = predicted_clusters 174 | 175 | p, r, f = evaluator.get_prf() 176 | metrics = {'Eval_Avg_Precision': p * 100, 'Eval_Avg_Recall': r * 100, 'Eval_Avg_F1': f * 100} 177 | for name, score in metrics.items(): 178 | logger.info('%s: %.2f' % (name, score)) 179 | if tb_writer: 180 | tb_writer.add_scalar(name, score, step) 181 | 182 | if official: 183 | conll_results = conll.evaluate_conll(conll_path, doc_to_prediction, stored_info['subtoken_maps']) 184 | official_f1 = sum(results["f"] for results in conll_results.values()) / len(conll_results) 185 | logger.info('Official avg F1: %.4f' % official_f1) 186 | 187 | return f * 100, metrics 188 | 189 | def predict(self, model, tensor_examples): 190 | logger.info('Predicting %d samples...' % len(tensor_examples)) 191 | model.to(self.device) 192 | predicted_spans, predicted_antecedents, predicted_clusters = [], [], [] 193 | 194 | model.eval() 195 | for i, (doc_key, tensor_example) in enumerate(tensor_examples): 196 | tensor_example = tensor_example[:7] 197 | example_gpu = [d.to(self.device) for d in tensor_example] 198 | with torch.no_grad(): 199 | _, _, _, span_starts, span_ends, antecedent_idx, antecedent_scores = model(*example_gpu) 200 | span_starts, span_ends = span_starts.tolist(), span_ends.tolist() 201 | antecedent_idx, antecedent_scores = antecedent_idx.tolist(), antecedent_scores.tolist() 202 | clusters, mention_to_cluster_id, antecedents = model.get_predicted_clusters(span_starts, span_ends, antecedent_idx, antecedent_scores) 203 | 204 | spans = [(span_start, span_end) for span_start, span_end in zip(span_starts, span_ends)] 205 | predicted_spans.append(spans) 206 | predicted_antecedents.append(antecedents) 207 | predicted_clusters.append(clusters) 208 | 209 | return predicted_clusters, predicted_spans, predicted_antecedents 210 | 211 | def get_optimizer(self, model): 212 | no_decay = ['bias', 'LayerNorm.weight'] 213 | bert_param, task_param = model.get_params(named=True) 214 | grouped_bert_param = [ 215 | { 216 | 'params': [p for n, p in bert_param if not any(nd in n for nd in no_decay)], 217 | 'lr': self.config['bert_learning_rate'], 218 | 'weight_decay': self.config['adam_weight_decay'] 219 | }, { 220 | 'params': [p for n, p in bert_param if any(nd in n for nd in no_decay)], 221 | 'lr': self.config['bert_learning_rate'], 222 | 'weight_decay': 0.0 223 | } 224 | ] 225 | optimizers = [ 226 | AdamW(grouped_bert_param, lr=self.config['bert_learning_rate'], eps=self.config['adam_eps']), 227 | Adam(model.get_params()[1], lr=self.config['task_learning_rate'], eps=self.config['adam_eps'], weight_decay=0) 228 | ] 229 | return optimizers 230 | # grouped_parameters = [ 231 | # { 232 | # 'params': [p for n, p in 
bert_param if not any(nd in n for nd in no_decay)], 233 | # 'lr': self.config['bert_learning_rate'], 234 | # 'weight_decay': self.config['adam_weight_decay'] 235 | # }, { 236 | # 'params': [p for n, p in bert_param if any(nd in n for nd in no_decay)], 237 | # 'lr': self.config['bert_learning_rate'], 238 | # 'weight_decay': 0.0 239 | # }, { 240 | # 'params': [p for n, p in task_param if not any(nd in n for nd in no_decay)], 241 | # 'lr': self.config['task_learning_rate'], 242 | # 'weight_decay': self.config['adam_weight_decay'] 243 | # }, { 244 | # 'params': [p for n, p in task_param if any(nd in n for nd in no_decay)], 245 | # 'lr': self.config['task_learning_rate'], 246 | # 'weight_decay': 0.0 247 | # } 248 | # ] 249 | # optimizer = AdamW(grouped_parameters, lr=self.config['task_learning_rate'], eps=self.config['adam_eps']) 250 | # return optimizer 251 | 252 | def get_scheduler(self, optimizers, total_update_steps): 253 | # Only warm up bert lr 254 | warmup_steps = int(total_update_steps * self.config['warmup_ratio']) 255 | 256 | def lr_lambda_bert(current_step): 257 | if current_step < warmup_steps: 258 | return float(current_step) / float(max(1, warmup_steps)) 259 | return max( 260 | 0.0, float(total_update_steps - current_step) / float(max(1, total_update_steps - warmup_steps)) 261 | ) 262 | 263 | def lr_lambda_task(current_step): 264 | return max(0.0, float(total_update_steps - current_step) / float(max(1, total_update_steps))) 265 | 266 | schedulers = [ 267 | LambdaLR(optimizers[0], lr_lambda_bert), 268 | LambdaLR(optimizers[1], lr_lambda_task) 269 | ] 270 | return schedulers 271 | # return LambdaLR(optimizer, [lr_lambda_bert, lr_lambda_bert, lr_lambda_task, lr_lambda_task]) 272 | 273 | def save_model_checkpoint(self, model, step): 274 | if step < 30000: 275 | return # Debug setting: skip saving checkpoints during early steps 276 | path_ckpt = join(self.config['log_dir'], f'model_{self.name_suffix}_{step}.bin') 277 | torch.save(model.state_dict(), path_ckpt) 278 | logger.info('Saved model to %s' % path_ckpt) 279 | 280 | def load_model_checkpoint(self, model, suffix): 281 | path_ckpt = join(self.config['log_dir'], f'model_{suffix}.bin') 282 | model.load_state_dict(torch.load(path_ckpt, map_location=torch.device('cpu')), strict=False) 283 | logger.info('Loaded model from %s' % path_ckpt) 284 | 285 | 286 | if __name__ == '__main__': 287 | config_name, gpu_id = sys.argv[1], int(sys.argv[2]) 288 | runner = Runner(config_name, gpu_id) 289 | model = runner.initialize_model() 290 | 291 | runner.train(model) 292 | -------------------------------------------------------------------------------- /setup_data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | ontonotes_path=$1 4 | data_dir=$2 5 | 6 | dlx() { 7 | wget -P $data_dir $1/$2 8 | tar -xvzf $data_dir/$2 -C $data_dir 9 | rm $data_dir/$2 10 | } 11 | 12 | conll_url=http://conll.cemantix.org/2012/download 13 | dlx $conll_url conll-2012-train.v4.tar.gz 14 | dlx $conll_url conll-2012-development.v4.tar.gz 15 | dlx $conll_url/test conll-2012-test-key.tar.gz 16 | dlx $conll_url/test conll-2012-test-official.v9.tar.gz 17 | 18 | dlx $conll_url conll-2012-scripts.v3.tar.gz 19 | # dlx http://conll.cemantix.org/download reference-coreference-scorers.v8.01.tar.gz 20 | 21 | bash $data_dir/conll-2012/v3/scripts/skeleton2conll.sh -D $ontonotes_path/data/files/data $data_dir/conll-2012 22 | 23 | function compile_partition() { 24 | rm -f $data_dir/$2.$5.$3$4 # same path the cat below appends to 25 | cat $data_dir/conll-2012/$3/data/$1/data/$5/annotations/*/*/*/*.$3$4 >> $data_dir/$2.$5.$3$4 26 |
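# compile_partition args: $1 = partition dir under conll-2012 (train/
# development/test), $2 = output name, $3 = version (v4), $4 = extension
# suffix (_gold_conll), $5 = language; it concatenates the partition's
# annotation files into a single $data_dir/$2.$5.$3$4 file.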
} 27 | 28 | function compile_language() { 29 | compile_partition development dev v4 _gold_conll $1 30 | compile_partition train train v4 _gold_conll $1 31 | compile_partition test test v4 _gold_conll $1 32 | } 33 | 34 | compile_language english 35 | #compile_language chinese 36 | #compile_language arabic 37 | 38 | python preprocess.py --input_dir $data_dir --output_dir $data_dir --seg_len 384 39 | python preprocess.py --input_dir $data_dir --output_dir $data_dir --seg_len 512 40 | -------------------------------------------------------------------------------- /tensorize.py: -------------------------------------------------------------------------------- 1 | import util 2 | import numpy as np 3 | import random 4 | import os 5 | from os.path import join 6 | import json 7 | import pickle 8 | import logging 9 | import torch 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | 14 | class CorefDataProcessor: 15 | def __init__(self, config, language='english'): 16 | self.config = config 17 | self.language = language 18 | 19 | self.max_seg_len = config['max_segment_len'] 20 | self.max_training_seg = config['max_training_sentences'] 21 | self.data_dir = config['data_dir'] 22 | 23 | self.tokenizer = util.get_tokenizer(config['bert_tokenizer_name']) 24 | self.tensor_samples, self.stored_info = None, None # For dataset samples; lazy loading 25 | 26 | def get_tensor_examples_from_custom_input(self, samples): 27 | """ For interactive samples; no caching """ 28 | tensorizer = Tensorizer(self.config, self.tokenizer) 29 | tensor_samples = [tensorizer.tensorize_example(sample, False) for sample in samples] 30 | tensor_samples = [(doc_key, self.convert_to_torch_tensor(*tensor)) for doc_key, tensor in tensor_samples] 31 | return tensor_samples, tensorizer.stored_info 32 | 33 | def get_tensor_examples(self): 34 | """ For dataset samples """ 35 | cache_path = self.get_cache_path() 36 | if os.path.exists(cache_path): 37 | # Load cached tensors if exists 38 | with open(cache_path, 'rb') as f: 39 | self.tensor_samples, self.stored_info = pickle.load(f) 40 | logger.info('Loaded tensorized examples from cache') 41 | else: 42 | # Generate tensorized samples 43 | self.tensor_samples = {} 44 | tensorizer = Tensorizer(self.config, self.tokenizer) 45 | paths = { 46 | 'trn': join(self.data_dir, f'train.{self.language}.{self.max_seg_len}.jsonlines'), 47 | 'dev': join(self.data_dir, f'dev.{self.language}.{self.max_seg_len}.jsonlines'), 48 | 'tst': join(self.data_dir, f'test.{self.language}.{self.max_seg_len}.jsonlines') 49 | } 50 | for split, path in paths.items(): 51 | logger.info('Tensorizing examples from %s; results will be cached' % path) 52 | is_training = (split == 'trn') 53 | with open(path, 'r') as f: 54 | samples = [json.loads(line) for line in f.readlines()] 55 | tensor_samples = [tensorizer.tensorize_example(sample, is_training) for sample in samples] 56 | self.tensor_samples[split] = [(doc_key, self.convert_to_torch_tensor(*tensor)) for doc_key, tensor 57 | in tensor_samples] 58 | self.stored_info = tensorizer.stored_info 59 | # Cache tensorized samples 60 | with open(cache_path, 'wb') as f: 61 | pickle.dump((self.tensor_samples, self.stored_info), f) 62 | return self.tensor_samples['trn'], self.tensor_samples['dev'], self.tensor_samples['tst'] 63 | 64 | def get_stored_info(self): 65 | return self.stored_info 66 |
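# convert_to_torch_tensor (below) packs one numpy-tensorized example into the
# 10-field torch tuple consumed by CorefModel.forward(); at prediction time
# the last three gold_* fields are stripped (tensor_example[:7] in run.py).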
/tensorize.py:
--------------------------------------------------------------------------------
1 | import util
2 | import numpy as np
3 | import random
4 | import os
5 | from os.path import join
6 | import json
7 | import pickle
8 | import logging
9 | import torch
10 | 
11 | logger = logging.getLogger(__name__)
12 | 
13 | 
14 | class CorefDataProcessor:
15 |     def __init__(self, config, language='english'):
16 |         self.config = config
17 |         self.language = language
18 | 
19 |         self.max_seg_len = config['max_segment_len']
20 |         self.max_training_seg = config['max_training_sentences']
21 |         self.data_dir = config['data_dir']
22 | 
23 |         self.tokenizer = util.get_tokenizer(config['bert_tokenizer_name'])
24 |         self.tensor_samples, self.stored_info = None, None  # For dataset samples; lazy loading
25 | 
26 |     def get_tensor_examples_from_custom_input(self, samples):
27 |         """ For interactive samples; no caching """
28 |         tensorizer = Tensorizer(self.config, self.tokenizer)
29 |         tensor_samples = [tensorizer.tensorize_example(sample, False) for sample in samples]
30 |         tensor_samples = [(doc_key, self.convert_to_torch_tensor(*tensor)) for doc_key, tensor in tensor_samples]
31 |         return tensor_samples, tensorizer.stored_info
32 | 
33 |     def get_tensor_examples(self):
34 |         """ For dataset samples """
35 |         cache_path = self.get_cache_path()
36 |         if os.path.exists(cache_path):
37 |             # Load cached tensors if they exist
38 |             with open(cache_path, 'rb') as f:
39 |                 self.tensor_samples, self.stored_info = pickle.load(f)
40 |             logger.info('Loaded tensorized examples from cache')
41 |         else:
42 |             # Generate tensorized samples
43 |             self.tensor_samples = {}
44 |             tensorizer = Tensorizer(self.config, self.tokenizer)
45 |             paths = {
46 |                 'trn': join(self.data_dir, f'train.{self.language}.{self.max_seg_len}.jsonlines'),
47 |                 'dev': join(self.data_dir, f'dev.{self.language}.{self.max_seg_len}.jsonlines'),
48 |                 'tst': join(self.data_dir, f'test.{self.language}.{self.max_seg_len}.jsonlines')
49 |             }
50 |             for split, path in paths.items():
51 |                 logger.info('Tensorizing examples from %s; results will be cached' % path)
52 |                 is_training = (split == 'trn')
53 |                 with open(path, 'r') as f:
54 |                     samples = [json.loads(line) for line in f.readlines()]
55 |                 tensor_samples = [tensorizer.tensorize_example(sample, is_training) for sample in samples]
56 |                 self.tensor_samples[split] = [(doc_key, self.convert_to_torch_tensor(*tensor)) for doc_key, tensor
57 |                                               in tensor_samples]
58 |             self.stored_info = tensorizer.stored_info
59 |             # Cache tensorized samples
60 |             with open(cache_path, 'wb') as f:
61 |                 pickle.dump((self.tensor_samples, self.stored_info), f)
62 |         return self.tensor_samples['trn'], self.tensor_samples['dev'], self.tensor_samples['tst']
63 | 
64 |     def get_stored_info(self):
65 |         return self.stored_info
66 | 
67 |     @classmethod
68 |     def convert_to_torch_tensor(cls, input_ids, input_mask, speaker_ids, sentence_len, genre, sentence_map,
69 |                                 is_training, gold_starts, gold_ends, gold_mention_cluster_map):
70 |         input_ids = torch.tensor(input_ids, dtype=torch.long)
71 |         input_mask = torch.tensor(input_mask, dtype=torch.long)
72 |         speaker_ids = torch.tensor(speaker_ids, dtype=torch.long)
73 |         sentence_len = torch.tensor(sentence_len, dtype=torch.long)
74 |         genre = torch.tensor(genre, dtype=torch.long)
75 |         sentence_map = torch.tensor(sentence_map, dtype=torch.long)
76 |         is_training = torch.tensor(is_training, dtype=torch.bool)
77 |         gold_starts = torch.tensor(gold_starts, dtype=torch.long)
78 |         gold_ends = torch.tensor(gold_ends, dtype=torch.long)
79 |         gold_mention_cluster_map = torch.tensor(gold_mention_cluster_map, dtype=torch.long)
80 |         return input_ids, input_mask, speaker_ids, sentence_len, genre, sentence_map, \
81 |             is_training, gold_starts, gold_ends, gold_mention_cluster_map
82 | 
83 |     def get_cache_path(self):
84 |         cache_path = join(self.data_dir, f'cached.tensors.{self.language}.{self.max_seg_len}.{self.max_training_seg}.bin')
85 |         return cache_path
86 | 
87 | 
88 | class Tensorizer:
89 |     def __init__(self, config, tokenizer):
90 |         self.config = config
91 |         self.tokenizer = tokenizer
92 | 
93 |         # Will be used in evaluation
94 |         self.stored_info = {}
95 |         self.stored_info['tokens'] = {}  # {doc_key: ...}
96 |         self.stored_info['subtoken_maps'] = {}  # {doc_key: ...}; mapping back to tokens
97 |         self.stored_info['gold'] = {}  # {doc_key: ...}
98 |         self.stored_info['genre_dict'] = {genre: idx for idx, genre in enumerate(config['genres'])}
99 | 
100 |     def _tensorize_spans(self, spans):
101 |         if len(spans) > 0:
102 |             starts, ends = zip(*spans)
103 |         else:
104 |             starts, ends = [], []
105 |         return np.array(starts), np.array(ends)
106 | 
107 |     def _tensorize_span_w_labels(self, spans, label_dict):
108 |         if len(spans) > 0:
109 |             starts, ends, labels = zip(*spans)
110 |         else:
111 |             starts, ends, labels = [], [], []
112 |         return np.array(starts), np.array(ends), np.array([label_dict[label] for label in labels])
113 | 
114 |     def _get_speaker_dict(self, speakers):
115 |         speaker_dict = {'UNK': 0, '[SPL]': 1}
116 |         for speaker in speakers:
117 |             if len(speaker_dict) > self.config['max_num_speakers']:
118 |                 break  # Cap distinct speakers at max_num_speakers; the rest fall back to 'UNK'
119 |             if speaker not in speaker_dict:
120 |                 speaker_dict[speaker] = len(speaker_dict)
121 |         return speaker_dict
122 | 
123 |     def tensorize_example(self, example, is_training):
124 |         # Mentions and clusters
125 |         clusters = example['clusters']
126 |         gold_mentions = sorted(tuple(mention) for mention in util.flatten(clusters))
127 |         gold_mention_map = {mention: idx for idx, mention in enumerate(gold_mentions)}
128 |         gold_mention_cluster_map = np.zeros(len(gold_mentions))  # 0: no cluster
129 |         for cluster_id, cluster in enumerate(clusters):
130 |             for mention in cluster:
131 |                 gold_mention_cluster_map[gold_mention_map[tuple(mention)]] = cluster_id + 1
132 | 
133 |         # Speakers
134 |         speakers = example['speakers']
135 |         speaker_dict = self._get_speaker_dict(util.flatten(speakers))
136 | 
137 |         # Sentences/segments
138 |         sentences = example['sentences']  # Segments
139 |         sentence_map = example['sentence_map']
140 |         num_words = sum([len(s) for s in sentences])
141 |         max_sentence_len = self.config['max_segment_len']
142 |         sentence_len = np.array([len(s) for s in sentences])
143 | 
144 |         # Bert input
145 |         input_ids, input_mask, speaker_ids = [], [], []
146 |         for idx, (sent_tokens, sent_speakers) in enumerate(zip(sentences, speakers)):
147 |             sent_input_ids = self.tokenizer.convert_tokens_to_ids(sent_tokens)
148 |             sent_input_mask = [1] * len(sent_input_ids)
149 |             sent_speaker_ids = [speaker_dict.get(speaker, speaker_dict['UNK']) for speaker in sent_speakers]
150 |             while len(sent_input_ids) < max_sentence_len:
151 |                 sent_input_ids.append(0)
152 |                 sent_input_mask.append(0)
153 |                 sent_speaker_ids.append(0)
154 |             input_ids.append(sent_input_ids)
155 |             input_mask.append(sent_input_mask)
156 |             speaker_ids.append(sent_speaker_ids)
157 |         input_ids = np.array(input_ids)
158 |         input_mask = np.array(input_mask)
159 |         speaker_ids = np.array(speaker_ids)
160 |         assert num_words == np.sum(input_mask), (num_words, np.sum(input_mask))
161 | 
162 |         # Keep info to store
163 |         doc_key = example['doc_key']
164 |         self.stored_info['subtoken_maps'][doc_key] = example.get('subtoken_map', None)
165 |         self.stored_info['gold'][doc_key] = example['clusters']
166 |         # self.stored_info['tokens'][doc_key] = example['tokens']
167 | 
168 |         # Construct example
169 |         genre = self.stored_info['genre_dict'].get(doc_key[:2], 0)
170 |         gold_starts, gold_ends = self._tensorize_spans(gold_mentions)
171 |         example_tensor = (input_ids, input_mask, speaker_ids, sentence_len, genre, sentence_map, is_training,
172 |                           gold_starts, gold_ends, gold_mention_cluster_map)
173 | 
174 |         if is_training and len(sentences) > self.config['max_training_sentences']:
175 |             return doc_key, self.truncate_example(*example_tensor)
176 |         else:
177 |             return doc_key, example_tensor
178 | 
179 |     def truncate_example(self, input_ids, input_mask, speaker_ids, sentence_len, genre, sentence_map, is_training,
180 |                          gold_starts, gold_ends, gold_mention_cluster_map, sentence_offset=None):
181 |         max_sentences = self.config["max_training_sentences"]
182 |         num_sentences = input_ids.shape[0]
183 |         assert num_sentences > max_sentences
184 | 
185 |         sent_offset = sentence_offset
186 |         if sent_offset is None:
187 |             sent_offset = random.randint(0, num_sentences - max_sentences)
188 |         word_offset = sentence_len[:sent_offset].sum()
189 |         num_words = sentence_len[sent_offset: sent_offset + max_sentences].sum()
190 | 
191 |         input_ids = input_ids[sent_offset: sent_offset + max_sentences, :]
192 |         input_mask = input_mask[sent_offset: sent_offset + max_sentences, :]
193 |         speaker_ids = speaker_ids[sent_offset: sent_offset + max_sentences, :]
194 |         sentence_len = sentence_len[sent_offset: sent_offset + max_sentences]
195 | 
196 |         sentence_map = sentence_map[word_offset: word_offset + num_words]
197 |         gold_spans = (gold_starts < word_offset + num_words) & (gold_ends >= word_offset)
198 |         gold_starts = gold_starts[gold_spans] - word_offset
199 |         gold_ends = gold_ends[gold_spans] - word_offset
200 |         gold_mention_cluster_map = gold_mention_cluster_map[gold_spans]
201 | 
202 |         return input_ids, input_mask, speaker_ids, sentence_len, genre, sentence_map, \
203 |             is_training, gold_starts, gold_ends, gold_mention_cluster_map
204 | 
--------------------------------------------------------------------------------
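Editor's aside: a small worked example (values invented) of the windowing arithmetic in Tensorizer.truncate_example above. A contiguous window of max_training_sentences segments is kept, word indices are shifted by the window's word offset, and gold spans outside the window are dropped.

import numpy as np

sentence_len = np.array([5, 4, 6, 3])      # subtokens per segment
max_sentences, sent_offset = 2, 1          # keep segments 1..2 (word indices 5..14)
word_offset = sentence_len[:sent_offset].sum()                            # 5
num_words = sentence_len[sent_offset: sent_offset + max_sentences].sum()  # 10

gold_starts = np.array([0, 6, 12])
gold_ends = np.array([1, 8, 14])
keep = (gold_starts < word_offset + num_words) & (gold_ends >= word_offset)
print(gold_starts[keep] - word_offset)  # [1 7]
print(gold_ends[keep] - word_offset)    # [3 9]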
/util.py:
--------------------------------------------------------------------------------
1 | from os import makedirs
2 | from os.path import join
3 | import numpy as np
4 | import pyhocon
5 | import logging
6 | import torch
7 | import random
8 | from transformers import BertTokenizer
9 | 
10 | 
11 | logger = logging.getLogger(__name__)
12 | 
13 | 
14 | def flatten(l):
15 |     return [item for sublist in l for item in sublist]
16 | 
17 | 
18 | def get_tokenizer(bert_tokenizer_name):
19 |     return BertTokenizer.from_pretrained(bert_tokenizer_name)
20 | 
21 | 
22 | def initialize_config(config_name):
23 |     logger.info("Running experiment: {}".format(config_name))
24 | 
25 |     config = pyhocon.ConfigFactory.parse_file("experiments.conf")[config_name]
26 |     config['log_dir'] = join(config["log_root"], config_name)
27 |     makedirs(config['log_dir'], exist_ok=True)
28 | 
29 |     config['tb_dir'] = join(config['log_root'], 'tensorboard')
30 |     makedirs(config['tb_dir'], exist_ok=True)
31 | 
32 |     logger.info(pyhocon.HOCONConverter.convert(config, "hocon"))
33 |     return config
34 | 
35 | 
36 | def set_seed(seed, set_gpu=True):
37 |     random.seed(seed)
38 |     np.random.seed(seed)
39 |     torch.manual_seed(seed)
40 |     if set_gpu and torch.cuda.is_available():
41 |         # Necessary for reproducibility; lowers performance
42 |         torch.backends.cudnn.deterministic = True
43 |         torch.backends.cudnn.benchmark = False
44 |         torch.cuda.manual_seed_all(seed)
45 |     logger.info('Random seed is set to %d' % seed)
46 | 
47 | 
48 | def bucket_distance(offsets):
49 |     """ offsets: [num spans1, num spans2] """
50 |     # 10 semi-logscale bins: 0, 1, 2, 3, 4, (5-7)->5, (8-15)->6, (16-31)->7, (32-63)->8, (64+)->9
51 |     logspace_distance = torch.log2(offsets.to(torch.float)).to(torch.long) + 3
52 |     identity_mask = (offsets <= 4).to(torch.long)
53 |     combined_distance = identity_mask * offsets + (1 - identity_mask) * logspace_distance
54 |     combined_distance = torch.clamp(combined_distance, 0, 9)
55 |     return combined_distance
56 | 
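Editor's aside: a quick sanity check of the bucketing above, not part of util.py. Distances 0-4 map to themselves; larger distances fall into log-scale buckets, capped at bucket 9 for distances of 64 and beyond.

import torch
from util import bucket_distance

offsets = torch.tensor([[0, 1, 4, 5, 7, 8, 16, 63, 64, 200]])
print(bucket_distance(offsets))  # tensor([[0, 1, 4, 5, 5, 6, 7, 8, 9, 9]])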
57 | 
58 | def batch_select(tensor, idx, device=torch.device('cpu')):
59 |     """ Do selection per row (first axis). """
60 |     assert tensor.shape[0] == idx.shape[0]  # Same size of first dim
61 |     dim0_size, dim1_size = tensor.shape[0], tensor.shape[1]
62 | 
63 |     tensor = torch.reshape(tensor, [dim0_size * dim1_size, -1])
64 |     idx_offset = torch.unsqueeze(torch.arange(0, dim0_size, device=device) * dim1_size, 1)
65 |     new_idx = idx + idx_offset
66 |     selected = tensor[new_idx]
67 | 
68 |     if tensor.shape[-1] == 1:  # If selected element is scalar, restore original dim
69 |         selected = torch.squeeze(selected, -1)
70 | 
71 |     return selected
72 | 
--------------------------------------------------------------------------------
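Editor's aside: batch_select is a batched gather along the second axis; for each row, it picks the elements at that row's indices. A usage sketch, not part of the repo:

import torch
from util import batch_select

tensor = torch.arange(12).view(3, 4)   # rows [0..3], [4..7], [8..11]
idx = torch.tensor([[0, 3], [1, 1], [2, 0]])
print(batch_select(tensor, idx))       # tensor([[ 0,  3], [ 5,  5], [10,  8]])

Flattening to 2-D and adding a per-row offset turns the batched selection into plain fancy indexing; for this 2-D case, torch.gather(tensor, 1, idx) would give the same result.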