├── LICENSE ├── README.md ├── bert_modelling.py ├── blanc.py ├── conll.py ├── evaluate.py ├── experiments.conf ├── fig ├── illustration.png └── structured_span_selector_cky_chart.pdf ├── greedy_mp.py ├── metrics.py ├── minimize.py ├── model.py ├── outside_mp.py ├── run.py ├── sss.yml ├── tensorize.py └── util.py /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [Tianyu Liu] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## A Structured Span Selector 2 | 3 | https://arxiv.org/pdf/2205.03977.pdf 4 | 5 | This repository contains the open-sourced official implementation of our **structured span selector** paper: 6 | 7 | [A Structured Span Selector](https://arxiv.org/abs/2205.03977) (NAACL 2022). 
8 | _Tianyu Liu, Yuchen Eleanor Jiang, Ryan Cotterell, and Mrinmaya Sachan_ 9 | 10 | 11 | ## Overall idea 12 | 13 | For all **span selection** tasks (e.g. coreference resolution, semantic role labelling, question answering), we learn the latent **context-free grammar** of the spans of interest. The search space of spans $O(n^2)$ is reduced to the space of nonterminals $O(n)$. 14 | 15 | 16 | 17 | 18 | ## Installation 19 | 20 | First of all: 21 | ```bash 22 | git clone https://github.com/lyutyuh/structured-span-selector.git 23 | cd structured-span-selector 24 | ``` 25 | 26 | 1. Create a virtual environment with Conda 27 | ```bash 28 | conda env create -f sss.yml 29 | ``` 30 | 31 | 2. Activate the new environment 32 | ```bash 33 | conda activate sss 34 | ``` 35 | 36 | 3. **Install genbmm with [inside-outside algorithm extension](https://github.com/lyutyuh/genbmm)** 37 | ```bash 38 | pip install git+https://github.com/lyutyuh/genbmm 39 | ``` 40 | 41 | 42 | ## Obtaining the CoNLL-2012 data 43 | 44 | Please follow https://github.com/mandarjoshi90/coref and especially https://github.com/mandarjoshi90/coref/blob/master/setup_training.sh to obtain the {train, dev, test}.english.v4_gold_conll. There are 2802, 343, 348 documents in the {train, dev, test} datasets respectively. 45 | 46 | The MD5 values are: 47 | ```bash 48 | md5sum dev.english.v4_gold_conll 49 | >>> bde418ea4bbec119b3a43b43933ec2ae 50 | md5sum test.english.v4_gold_conll 51 | >>> 6e64b649a039b4320ad32780db3abfa1 52 | md5sum train.english.v4_gold_conll 53 | >>> 9f92a664298dc78600fd50813246aa77 54 | ``` 55 | 56 | Then, run 57 | ```bash 58 | python minimize.py ./data_dir/ ./data_dir/ false 59 | ``` 60 | and get the jsonlines files. 61 | 62 | ## Training 63 | 64 | ```bash 65 | python run.py spanbert_large 0 66 | ``` 67 | 68 | ## Evaluating 69 | 70 | ```bash 71 | python evaluate.py spanbert_large 0 72 | ``` 73 | 74 | 75 | 76 | ## Citing 77 | 78 | If you find this repo helpful, please cite the following version of the paper: 79 | ```tex 80 | @inproceedings{liu-etal-2022-structured, 81 | title = "A Structured Span Selector", 82 | author = "Liu, Tianyu and 83 | Jiang, Yuchen and 84 | Cotterell, Ryan and 85 | Sachan, Mrinmaya", 86 | booktitle = "Proceedings of the 2022 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies", 87 | month = jul, 88 | year = "2022", 89 | address = "Seattle, United States", 90 | publisher = "Association for Computational Linguistics", 91 | url = "https://aclanthology.org/2022.naacl-main.189", 92 | pages = "2629--2641", 93 | } 94 | ``` 95 | -------------------------------------------------------------------------------- /blanc.py: -------------------------------------------------------------------------------- 1 | import math 2 | import typing as ty 3 | 4 | import numpy as np 5 | 6 | def trace(cluster: ty.Set, partition: ty.Iterable[ty.Set]) -> ty.Iterable[ty.Set]: 7 | r""" 8 | Return the partition of `#cluster` induced by `#partition`, that is 9 | ```math 10 | \{C∩A|A∈P\} ∪ \{\{x\}|x∈C∖∪P\} 11 | ``` 12 | Where `$C$` is `#cluster` and `$P$` is `#partition`. 13 | 14 | This assume that the elements of `#partition` are indeed pairwise disjoint. 
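    For example, `trace({1, 2, 3}, [{1, 4}, {2, 5}])` yields `{1}`, `{2}` and
    finally the singleton `{3}`, the element of `#cluster` that no set in
    `#partition` covers.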
15 | """ 16 | remaining = set(cluster) 17 | for a in partition: 18 | common = remaining.intersection(a) 19 | if common: 20 | remaining.difference_update(common) 21 | yield common 22 | for x in sorted(remaining): 23 | yield set((x,)) 24 | 25 | 26 | class RemapClusteringsReturn(ty.NamedTuple): 27 | clusterings: ty.Sequence[ty.Sequence[ty.Sequence[int]]] 28 | elts_map: ty.Dict[ty.Hashable, int] 29 | 30 | 31 | def remap_clusterings( 32 | clusterings: ty.Sequence[ty.Sequence[ty.Set[ty.Hashable]]], 33 | ) -> RemapClusteringsReturn: 34 | """Remap clusterings of arbitrary elements to clusterings of integers.""" 35 | elts = set(e for clusters in clusterings for c in clusters for e in c) 36 | elts_map = {e: i for i, e in enumerate(elts)} 37 | res = [] 38 | for clusters in clusterings: 39 | remapped_clusters = [] 40 | for c in clusters: 41 | remapped_c = [elts_map[e] for e in c] 42 | remapped_clusters.append(remapped_c) 43 | res.append(remapped_clusters) 44 | return RemapClusteringsReturn(res, elts_map) 45 | 46 | 47 | 48 | # COMBAK: Check the numeric stability 49 | def blanc( 50 | key: ty.Sequence[ty.Set], response: ty.Sequence[ty.Set], fast=True, 51 | ) -> ty.Tuple[float, float, float]: 52 | r""" 53 | Return the BLANC `$(R, P, F)$` scores for a `#response` clustering given a `#key` clustering. 54 | 55 | ## Notes 56 | 57 | - Mention identifiers have to be comparable 58 | - To ensure the compliance with the reference implementation, the edge cases results are 59 | those from Recasens and Hovy (2011) rather than from the more recent Luo et al. (2014) when 60 | those two disagree. This has an effect for the N-6 testcase, where according to Luo et al. 61 | (2014), BLANC should be `$\frac{0+F_n}{2}$` since `$C_k=∅$` and `$C_r≠∅$`, but according to 62 | Recasens and Hovy (2011), BLANC should be `$F_n$`. 63 | """ 64 | if fast: 65 | C_score, N_score = fast_detailed_blanc(key, response) 66 | else: 67 | C_score, N_score = detailed_blanc(key, response) 68 | if C_score is None: 69 | assert N_score is not None # nosec:B101 70 | return N_score 71 | if N_score is None: 72 | assert C_score is not None # nosec:B101 73 | return C_score 74 | return C_score, N_score 75 | 76 | 77 | def links_from_clusters( 78 | clusters: ty.Iterable[ty.Set], 79 | ) -> ty.Tuple[ 80 | ty.Set[ty.Tuple[ty.Hashable, ty.Hashable]], 81 | ty.Set[ty.Tuple[ty.Hashable, ty.Hashable]], 82 | ]: 83 | r""" 84 | Return a `(coreference_links, non-coreference_links)` tuple corresponding to a clustering. 
85 | 86 | The links are given as sorted couples for uniqueness 87 | """ 88 | clusters_lst = [list(c) for c in clusters] 89 | C = set() 90 | N = set() 91 | for i, c in enumerate(clusters_lst[:-1]): 92 | for j, e in enumerate(c[:-1]): 93 | # Since the links are symmetric, we only add the links between `e` and 94 | # the following mentions 95 | for f in c[j + 1 :]: 96 | C.add((e, f) if e <= f else (f, e)) 97 | for other in clusters_lst[i + 1 :]: 98 | for e in c: 99 | for f in other: 100 | N.add((e, f) if e <= f else (f, e)) 101 | # We missed the coreference links for the last cluster, add them here 102 | last_cluster = clusters_lst[-1] 103 | for j, e in enumerate(last_cluster): 104 | for f in last_cluster[j + 1 :]: 105 | C.add((e, f) if e <= f else (f, e)) 106 | return C, N 107 | 108 | 109 | def detailed_blanc( 110 | key: ty.Sequence[ty.Set], response: ty.Sequence[ty.Set] 111 | ) -> ty.Tuple[ 112 | ty.Union[ty.Tuple[float, float, float], None], 113 | ty.Union[ty.Tuple[float, float, float], None], 114 | ]: 115 | """Return BLANC `$(R, P, F)$` scores for coreference and non-coreference respectively.""" 116 | 117 | # Edge case : a single mention in both `key` and `response` clusters 118 | # in that case, `C_k`, `C_r`, `N_k` and `N_r` are all empty, so we need a separate examination 119 | # of the mentions to know if we are very good or very bad. 120 | if len(key) == len(response) == 1 and len(key[0]) == len(response[0]) == 1: 121 | if key[0] == response[0]: 122 | return ((1.0, 1.0, 1.0), (1.0, 1.0, 1.0)) 123 | else: 124 | return ((0.0, 0.0, 0.0), (0.0, 0.0, 0.0)) 125 | 126 | C_k, N_k = links_from_clusters(key) 127 | C_r, N_r = links_from_clusters(response) 128 | 129 | tp_c = len(C_k.intersection(C_r)) 130 | tp_n = len(N_k.intersection(N_r)) 131 | c_k, n_k = len(C_k), len(N_k) 132 | c_r, n_r = len(C_r), len(N_r) 133 | 134 | if not c_k and not c_r: 135 | R_c, P_c, F_c = (1.0, 1.0, 1.0) 136 | elif not c_k or not c_r: 137 | R_c, P_c, F_c = (0.0, 0.0, 0.0) 138 | else: 139 | R_c, P_c = tp_c / c_k, tp_c / c_r 140 | F_c = 2 * tp_c / (c_k + c_r) 141 | 142 | if not n_k and not n_r: 143 | R_n, P_n, F_n = (1.0, 1.0, 1.0) 144 | elif not n_k or not n_r: 145 | R_n, P_n, F_n = (0.0, 0.0, 0.0) 146 | else: 147 | R_n, P_n = tp_n / n_k, tp_n / n_r 148 | F_n = 2 * tp_n / (n_k + n_r) 149 | 150 | # Edge cases 151 | if not c_k: 152 | return (None, (R_n, P_n, F_n)) 153 | if not n_k: 154 | return ((R_c, P_c, F_c), None) 155 | 156 | return ((R_c, P_c, F_c), (R_n, P_n, F_n)) 157 | 158 | 159 | class AdjacencyReturn(ty.NamedTuple): 160 | """Represents a clustering of integers as an adjacency matrix and a presence mask""" 161 | 162 | adjacency: np.ndarray 163 | presence: np.ndarray 164 | 165 | 166 | def adjacency(clusters: ty.List[ty.List[int]], num_elts: int) -> AdjacencyReturn: 167 | adjacency = np.zeros((num_elts, num_elts), dtype=np.bool) 168 | presence = np.zeros(num_elts, dtype=np.bool) 169 | # **Note** The nested loop makes the complexity of this `$∑|c|²$` but we are only doing memory 170 | # access, which is really fast, so this is not really an issue. In comparison, doing it by 171 | # computing the Gram matrix one-hot elt-cluster attribution matrix was making `fast_blanc` 3× 172 | # slower than the naïve version. 173 | for c in clusters: 174 | # Note: don't be clever and use numpy array indicing here, see 175 | # 176 | # for why it would be slower. 
If you want to get C loops here, cythonize it instead (nut 177 | # it's probably not worth it) 178 | for e in c: 179 | presence[e] = True 180 | for f in c: 181 | if f != e: 182 | adjacency[e, f] = True 183 | return AdjacencyReturn(adjacency, presence) 184 | 185 | 186 | def fast_detailed_blanc( 187 | key: ty.Sequence[ty.Set], response: ty.Sequence[ty.Set] 188 | ) -> ty.Tuple[ 189 | ty.Union[ty.Tuple[float, float, float], None], 190 | ty.Union[ty.Tuple[float, float, float], None], 191 | ]: 192 | """Return BLANC `$(R, P, F)$` scores for coreference and non-coreference respectively.""" 193 | 194 | # Edge case : a single mention in both `key` and `response` clusters 195 | # in that case, `C_k`, `C_r`, `N_k` and `N_r` are all empty, so we need a separate examination 196 | # of the mentions to know if we are very good or very bad. 197 | if len(key) == len(response) == 1 and len(key[0]) == len(response[0]) == 1: 198 | if key[0] == response[0]: 199 | return ((1.0, 1.0, 1.0), (1.0, 1.0, 1.0)) 200 | else: 201 | return ((0.0, 0.0, 0.0), (0.0, 0.0, 0.0)) 202 | 203 | (key, response), mentions_map = remap_clusterings([key, response]) 204 | num_mentions = len(mentions_map) 205 | 206 | key_coref_links, key_presence = adjacency(key, num_mentions) 207 | response_coref_links, response_presence = adjacency(response, num_mentions) 208 | 209 | tp_c = np.logical_and(key_coref_links, response_coref_links).sum() // 2 210 | c_k = key_coref_links.sum() // 2 211 | c_r = response_coref_links.sum() // 2 212 | 213 | # Headache ahead 214 | common_links = np.logical_and( 215 | np.outer(key_presence, key_presence), 216 | np.outer(response_presence, response_presence), 217 | ) 218 | # There is no link between a mention and itself 219 | np.fill_diagonal(common_links, False) 220 | tp_n = ( 221 | np.logical_and( 222 | common_links, 223 | np.logical_not(np.logical_or(key_coref_links, response_coref_links)), 224 | ).sum() 225 | / 2 226 | ) 227 | num_key_mentions = key_presence.sum() 228 | n_k = (num_key_mentions * (num_key_mentions - 1)) // 2 - c_k 229 | num_response_mentions = response_presence.sum() 230 | n_r = (num_response_mentions * (num_response_mentions - 1)) // 2 - c_r 231 | 232 | if not c_k and not c_r: 233 | R_c, P_c, F_c = (1.0, 1.0, 1.0) 234 | elif not c_k or not c_r: 235 | R_c, P_c, F_c = (0.0, 0.0, 0.0) 236 | else: 237 | R_c, P_c = tp_c / c_k, tp_c / c_r 238 | F_c = 2 * tp_c / (c_k + c_r) 239 | 240 | if not n_k and not n_r: 241 | R_n, P_n, F_n = (1.0, 1.0, 1.0) 242 | elif not n_k or not n_r: 243 | R_n, P_n, F_n = (0.0, 0.0, 0.0) 244 | else: 245 | R_n, P_n = tp_n / n_k, tp_n / n_r 246 | F_n = 2 * tp_n / (n_k + n_r) 247 | 248 | # # Edge cases 249 | # if not c_k: 250 | # return (None, (R_n, P_n, F_n)) 251 | # if not n_k: 252 | # return ((R_c, P_c, F_c), None) 253 | 254 | return ((tp_c, c_k,c_r), (tp_n, n_k, n_r)) 255 | # return ((R_c, P_c, F_c), (R_n, P_n, F_n)) 256 | 257 | def tuple_to_metric(c_tuple, n_tuple): 258 | (tp_c, c_k,c_r), (tp_n, n_k, n_r) = c_tuple, n_tuple 259 | 260 | if not c_k and not c_r: 261 | R_c, P_c, F_c = (1.0, 1.0, 1.0) 262 | elif not c_k or not c_r: 263 | R_c, P_c, F_c = (0.0, 0.0, 0.0) 264 | else: 265 | R_c, P_c = tp_c / c_k, tp_c / c_r 266 | F_c = 2 * tp_c / (c_k + c_r) 267 | 268 | if not n_k and not n_r: 269 | R_n, P_n, F_n = (1.0, 1.0, 1.0) 270 | elif not n_k or not n_r: 271 | R_n, P_n, F_n = (0.0, 0.0, 0.0) 272 | else: 273 | R_n, P_n = tp_n / n_k, tp_n / n_r 274 | F_n = 2 * tp_n / (n_k + n_r) 275 | return ((P_c, R_c, F_c), (P_n, R_n, F_n)) 
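As a quick illustration of how these helpers fit together (mirroring their use in `metrics.py`), the sketch below scores a toy key/response pair. The clusterings are invented for illustration; with the default `fast=True`, `blanc` returns raw link counts rather than scores, and `tuple_to_metric` converts (possibly accumulated) counts into precision/recall/F1. Note that `adjacency` uses the `np.bool` alias, so this assumes a NumPy release that still provides it.

```python
# Toy BLANC computation (illustrative example, not part of the original module).
from blanc import blanc, tuple_to_metric

key = [{0, 1}, {2, 3}]         # gold clustering: two entities with two mentions each
response = [{0, 1}, {2}, {3}]  # system output: the (2, 3) coreference link is missed

# With fast=True (the default), blanc returns raw counts:
#   (tp_c, |C_k|, |C_r|) for coreference links, (tp_n, |N_k|, |N_r|) for non-coreference.
c_counts, n_counts = blanc(key, response)   # here: (1, 2, 1) and (4.0, 4, 5)

# Convert counts into ((P_c, R_c, F_c), (P_n, R_n, F_n)).
coref_prf, non_coref_prf = tuple_to_metric(c_counts, n_counts)
print(coref_prf)       # (1.0, 0.5, ~0.667): every predicted coref link is correct, half recalled
print(non_coref_prf)   # (0.8, 1.0, ~0.889)

# metrics.CorefEvaluator.get_prf averages the two sides into the final BLANC P/R/F: ~(0.9, 0.75, 0.78)
blanc_p, blanc_r, blanc_f = [0.5 * (a + b) for a, b in zip(coref_prf, non_coref_prf)]
```

Counts are accumulated over all documents before conversion (as `CorefEvaluator.update` does), so corpus-level BLANC is computed from pooled link counts rather than by averaging per-document scores.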
-------------------------------------------------------------------------------- /conll.py: -------------------------------------------------------------------------------- 1 | import re 2 | import tempfile 3 | import subprocess 4 | import operator 5 | import collections 6 | import logging 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | BEGIN_DOCUMENT_REGEX = re.compile(r"#begin document \((.*)\); part (\d+)") # First line at each document 11 | COREF_RESULTS_REGEX = re.compile(r".*Coreference: Recall: \([0-9.]+ / [0-9.]+\) ([0-9.]+)%\tPrecision: \([0-9.]+ / [0-9.]+\) ([0-9.]+)%\tF1: ([0-9.]+)%.*", re.DOTALL) 12 | 13 | 14 | def get_doc_key(doc_id, part): 15 | return "{}_{}".format(doc_id, int(part)) 16 | 17 | 18 | def output_conll(input_file, output_file, predictions, subtoken_map): 19 | prediction_map = {} 20 | for doc_key, clusters in predictions.items(): 21 | start_map = collections.defaultdict(list) 22 | end_map = collections.defaultdict(list) 23 | word_map = collections.defaultdict(list) 24 | for cluster_id, mentions in enumerate(clusters): 25 | for start, end in mentions: 26 | start, end = subtoken_map[doc_key][start], subtoken_map[doc_key][end] 27 | if start == end: 28 | word_map[start].append(cluster_id) 29 | else: 30 | start_map[start].append((cluster_id, end)) 31 | end_map[end].append((cluster_id, start)) 32 | for k,v in start_map.items(): 33 | start_map[k] = [cluster_id for cluster_id, end in sorted(v, key=operator.itemgetter(1), reverse=True)] 34 | for k,v in end_map.items(): 35 | end_map[k] = [cluster_id for cluster_id, start in sorted(v, key=operator.itemgetter(1), reverse=True)] 36 | prediction_map[doc_key] = (start_map, end_map, word_map) 37 | 38 | word_index = 0 39 | for line in input_file.readlines(): 40 | row = line.split() 41 | if len(row) == 0: 42 | output_file.write("\n") 43 | elif row[0].startswith("#"): 44 | begin_match = re.match(BEGIN_DOCUMENT_REGEX, line) 45 | if begin_match: 46 | doc_key = get_doc_key(begin_match.group(1), begin_match.group(2)) 47 | start_map, end_map, word_map = prediction_map[doc_key] 48 | word_index = 0 49 | output_file.write(line) 50 | output_file.write("\n") 51 | else: 52 | assert get_doc_key(row[0], row[1]) == doc_key 53 | coref_list = [] 54 | if word_index in end_map: 55 | for cluster_id in end_map[word_index]: 56 | coref_list.append("{})".format(cluster_id)) 57 | if word_index in word_map: 58 | for cluster_id in word_map[word_index]: 59 | coref_list.append("({})".format(cluster_id)) 60 | if word_index in start_map: 61 | for cluster_id in start_map[word_index]: 62 | coref_list.append("({}".format(cluster_id)) 63 | 64 | if len(coref_list) == 0: 65 | row[-1] = "-" 66 | else: 67 | row[-1] = "|".join(coref_list) 68 | 69 | output_file.write(" ".join(row)) 70 | output_file.write("\n") 71 | word_index += 1 72 | 73 | 74 | def official_conll_eval(gold_path, predicted_path, metric, official_stdout=True): 75 | cmd = ["conll-2012/scorer/v8.01/scorer.pl", metric, gold_path, predicted_path, "none"] 76 | process = subprocess.Popen(cmd, stdout=subprocess.PIPE) 77 | stdout, stderr = process.communicate() 78 | process.wait() 79 | 80 | stdout = stdout.decode("utf-8") 81 | if stderr is not None: 82 | logger.error(stderr) 83 | 84 | if official_stdout: 85 | logger.info("Official result for {}".format(metric)) 86 | logger.info(stdout) 87 | 88 | coref_results_match = re.match(COREF_RESULTS_REGEX, stdout) 89 | recall = float(coref_results_match.group(1)) 90 | precision = float(coref_results_match.group(2)) 91 | f1 = float(coref_results_match.group(3)) 92 | 
return {"r": recall, "p": precision, "f": f1} 93 | 94 | 95 | def evaluate_conll(gold_path, predictions, subtoken_maps, official_stdout=True): 96 | with tempfile.NamedTemporaryFile(delete=True, mode="w") as prediction_file: 97 | with open(gold_path, "r") as gold_file: 98 | output_conll(gold_file, prediction_file, predictions, subtoken_maps) 99 | # logger.info("Predicted conll file: {}".format(prediction_file.name)) 100 | results = {m: official_conll_eval(gold_file.name, prediction_file.name, m, official_stdout) for m in ("muc", "bcub", "ceafe") } 101 | return results 102 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | from run import Runner 2 | import sys 3 | import torch 4 | 5 | def evaluate(config_name, gpu_id, saved_suffix): 6 | runner = Runner(config_name, gpu_id) 7 | model = runner.initialize_model(saved_suffix) 8 | 9 | examples_train, examples_dev, examples_test = runner.data.get_tensor_examples() 10 | stored_info = runner.data.get_stored_info() 11 | 12 | runner.evaluate( 13 | model, examples_dev, stored_info, 0, official=False, conll_path=runner.config['conll_eval_path'], predict=True 14 | ) # Eval dev 15 | # print('=================================') 16 | # runner.evaluate(model, examples_test, stored_info, 0, official=False, conll_path=runner.config['conll_test_path']) # Eval test 17 | 18 | 19 | if __name__ == '__main__': 20 | config_name, saved_suffix, gpu_id = sys.argv[1], sys.argv[2], int(sys.argv[3]) 21 | evaluate(config_name, gpu_id, saved_suffix) 22 | print( 23 | torch.cuda.max_memory_allocated(device=gpu_id) 24 | ) 25 | -------------------------------------------------------------------------------- /experiments.conf: -------------------------------------------------------------------------------- 1 | best { 2 | data_dir = ./data_dir/ 3 | 4 | # Computation limits. 5 | max_top_antecedents = 50 6 | max_training_sentences = 5 7 | top_span_ratio = 0.4 8 | max_num_extracted_spans = 3900 9 | max_num_speakers = 20 10 | max_segment_len = 256 11 | 12 | dataset = "ontonotes" 13 | 14 | mention_sigmoid = false 15 | 16 | # Learning 17 | bert_learning_rate = 1e-5 18 | task_learning_rate = 2e-4 19 | 20 | adam_eps = 1e-8 21 | adam_weight_decay = 1e-2 22 | 23 | warmup_ratio = 0.1 24 | max_grad_norm = 1 # Set 0 to disable clipping 25 | gradient_accumulation_steps = 1 26 | 27 | # Model hyperparameters. 28 | coref_depth = 1 # when 1: no higher order (except for cluster_merging) 29 | coarse_to_fine = true 30 | fine_grained = true 31 | dropout_rate = 0.3 32 | ffnn_size = 1000 33 | ffnn_depth = 1 34 | 35 | num_epochs = 24 36 | feature_emb_size = 20 37 | max_span_width = 30 38 | use_metadata = true 39 | use_features = true 40 | use_segment_distance = true 41 | model_heads = true 42 | use_width_prior = true # For mention score 43 | use_distance_prior = true # For mention-ranking score 44 | 45 | 46 | # Other. 
47 | conll_eval_path = ${best.data_dir}/dev.english.v4_gold_conll # gold_conll file for dev 48 | conll_test_path = ${best.data_dir}/test.english.v4_gold_conll # gold_conll file for test 49 | genres = ["bc", "bn", "mz", "nw", "pt", "tc", "wb"] 50 | eval_frequency = 1000 51 | report_frequency = 100 52 | log_root = ${best.data_dir} 53 | 54 | mention_proposer = outside 55 | } 56 | 57 | spanbert_base = ${best}{ 58 | num_docs = 2802 59 | bert_learning_rate = 2e-05 60 | task_learning_rate = 0.0001 61 | coref_depth = 1 62 | max_segment_len = 384 63 | ffnn_size = 3000 64 | cluster_ffnn_size = 1000 65 | max_training_sentences = 3 66 | neg_sample_rate=0.2 67 | 68 | bert_tokenizer_name = bert-base-cased 69 | bert_pretrained_name_or_path = SpanBERT/spanbert-base-cased 70 | 71 | } 72 | 73 | spanbert_base_greedy = ${best}{ 74 | num_docs = 2802 75 | bert_learning_rate = 2e-05 76 | task_learning_rate = 0.0001 77 | coref_depth = 1 78 | max_segment_len = 384 79 | ffnn_size = 3000 80 | cluster_ffnn_size = 1000 81 | max_training_sentences = 3 82 | 83 | bert_tokenizer_name = bert-base-cased 84 | bert_pretrained_name_or_path = SpanBERT/spanbert-base-cased 85 | 86 | mention_proposer = greedy 87 | } 88 | 89 | 90 | spanbert_large = ${best}{ 91 | num_docs = 2802 92 | bert_learning_rate = 1e-05 93 | task_learning_rate = 0.0003 94 | max_segment_len = 512 95 | ffnn_size = 3000 96 | cluster_ffnn_size = 3000 97 | max_training_sentences = 3 98 | 99 | neg_sample_rate=0.2 100 | 101 | bert_tokenizer_name = bert-base-cased 102 | bert_pretrained_name_or_path = SpanBERT/spanbert-large-cased 103 | 104 | } 105 | 106 | spanbert_large_greedy = ${best}{ 107 | num_docs = 2802 108 | bert_learning_rate = 1e-05 109 | task_learning_rate = 0.0003 110 | max_segment_len = 512 111 | ffnn_size = 3000 112 | cluster_ffnn_size = 3000 113 | max_training_sentences = 3 114 | 115 | bert_tokenizer_name = bert-base-cased 116 | bert_pretrained_name_or_path = SpanBERT/spanbert-large-cased 117 | 118 | mention_proposer = greedy 119 | } 120 | -------------------------------------------------------------------------------- /fig/illustration.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lyutyuh/structured-span-selector/eac7771312622cd98535b4c93bad3ee000957e59/fig/illustration.png -------------------------------------------------------------------------------- /fig/structured_span_selector_cky_chart.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lyutyuh/structured-span-selector/eac7771312622cd98535b4c93bad3ee000957e59/fig/structured_span_selector_cky_chart.pdf -------------------------------------------------------------------------------- /greedy_mp.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import math 3 | from typing import Any, Dict, List, Tuple 4 | import numpy as np 5 | 6 | import torch 7 | import torch.nn.functional as F 8 | from torch.nn.parallel import DataParallel 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | class GreedyMentionProposer(torch.nn.Module): 13 | 14 | def __init__( 15 | self, 16 | **kwargs 17 | ) -> None: 18 | super().__init__(**kwargs) 19 | 20 | 21 | def forward( 22 | self, 23 | spans: torch.IntTensor, 24 | span_mention_scores: torch.FloatTensor, 25 | span_mask: torch.FloatTensor, 26 | token_num: torch.IntTensor, 27 | num_spans_to_keep: int, 28 | take_top_spans_per_sentence = False, 29 | flat_span_sent_ids = None, 30 
| ratio = 0.4, 31 | ): 32 | if not take_top_spans_per_sentence: 33 | top_span_indices = masked_topk_non_overlap( 34 | span_mention_scores, 35 | span_mask, 36 | num_spans_to_keep, 37 | spans 38 | ) 39 | top_spans = spans[top_span_indices] 40 | top_span_scores = span_mention_scores[top_span_indices] 41 | return top_span_scores, top_span_indices, top_spans, 0., None 42 | else: 43 | top_span_indices, top_span_scores, top_spans = [], [], [] 44 | prev_sent_id, prev_span_id = 0, 0 45 | for span_id, sent_id in enumerate(flat_span_sent_ids.tolist()): 46 | if sent_id != prev_sent_id: 47 | sent_span_indices = masked_topk_non_overlap( 48 | span_mention_scores[prev_span_id:span_id], 49 | span_mask[prev_span_id:span_id], 50 | int(ratio * (token_num[prev_sent_id])), # [CLS], [SEP] 51 | spans[prev_span_id:span_id], 52 | non_crossing=True, 53 | ) + prev_span_id 54 | 55 | top_span_indices.append(sent_span_indices) 56 | top_span_scores.append(span_mention_scores[sent_span_indices]) 57 | top_spans.append(spans[sent_span_indices]) 58 | 59 | prev_sent_id, prev_span_id = sent_id, span_id 60 | # last sentence 61 | sent_span_indices = masked_topk_non_overlap( 62 | span_mention_scores[prev_span_id:], 63 | span_mask[prev_span_id:], 64 | int(ratio * (token_num[-1])), 65 | spans[prev_span_id:], 66 | non_crossing=True, 67 | ) + prev_span_id 68 | 69 | top_span_indices.append(sent_span_indices) 70 | top_span_scores.append(span_mention_scores[sent_span_indices]) 71 | top_spans.append(spans[sent_span_indices]) 72 | 73 | num_top_spans = [x.size(0) for x in top_span_indices] 74 | max_num_top_span = max(num_top_spans) 75 | 76 | top_spans = torch.stack( 77 | [torch.cat([x, x.new_zeros((max_num_top_span-x.size(0), 2))], dim=0) for x in top_spans], dim=0 78 | ) 79 | top_span_masks = torch.stack( 80 | [torch.cat([x.new_ones((x.size(0), )), x.new_zeros((max_num_top_span-x.size(0), ))], dim=0) for x in top_span_indices], dim=0 81 | ) 82 | top_span_indices = torch.stack( 83 | [torch.cat([x, x.new_zeros((max_num_top_span-x.size(0), ))], dim=0) for x in top_span_indices], dim=0 84 | ) 85 | 86 | return top_span_scores, top_span_indices, top_spans, 0., None, top_span_masks 87 | 88 | 89 | def masked_topk_non_overlap( 90 | span_scores, 91 | span_mask, 92 | num_spans_to_keep, 93 | spans, 94 | non_crossing=True 95 | ): 96 | 97 | sorted_scores, sorted_indices = torch.sort(span_scores + span_mask.log(), descending=True) 98 | sorted_indices = sorted_indices.tolist() 99 | spans = spans.tolist() 100 | 101 | if not non_crossing: 102 | selected_candidate_idx = sorted(sorted_indices[:num_spans_to_keep], key=lambda idx: (spans[idx][0], spans[idx][1])) 103 | selected_candidate_idx = span_scores.new_tensor(selected_candidate_idx, dtype=torch.long) 104 | return selected_candidate_idx 105 | 106 | selected_candidate_idx = [] 107 | start_to_max_end, end_to_min_start = {}, {} 108 | for candidate_idx in sorted_indices: 109 | if len(selected_candidate_idx) >= num_spans_to_keep: 110 | break 111 | # Perform overlapping check 112 | span_start_idx = spans[candidate_idx][0] 113 | span_end_idx = spans[candidate_idx][1] 114 | cross_overlap = False 115 | for token_idx in range(span_start_idx, span_end_idx + 1): 116 | max_end = start_to_max_end.get(token_idx, -1) 117 | if token_idx > span_start_idx and max_end > span_end_idx: 118 | cross_overlap = True 119 | break 120 | min_start = end_to_min_start.get(token_idx, -1) 121 | if token_idx < span_end_idx and 0 <= min_start < span_start_idx: 122 | cross_overlap = True 123 | break 124 | if not cross_overlap: 125 | # Pass 
check; select idx and update dict stats 126 | selected_candidate_idx.append(candidate_idx) 127 | max_end = start_to_max_end.get(span_start_idx, -1) 128 | if span_end_idx > max_end: 129 | start_to_max_end[span_start_idx] = span_end_idx 130 | min_start = end_to_min_start.get(span_end_idx, -1) 131 | if min_start == -1 or span_start_idx < min_start: 132 | end_to_min_start[span_end_idx] = span_start_idx 133 | # Sort selected candidates by span idx 134 | selected_candidate_idx = sorted(selected_candidate_idx, key=lambda idx: (spans[idx][0], spans[idx][1])) 135 | selected_candidate_idx = span_scores.new_tensor(selected_candidate_idx, dtype=torch.long) 136 | 137 | return selected_candidate_idx 138 | -------------------------------------------------------------------------------- /metrics.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import numpy as np 6 | from collections import Counter, defaultdict 7 | 8 | # from scipy.optimize import linear_sum_assignment as linear_assignment 9 | # from sklearn.utils.linear_assignment_ import linear_assignment 10 | from scipy.optimize import linear_sum_assignment 11 | 12 | from blanc import blanc, tuple_to_metric 13 | 14 | 15 | def f1(p_num, p_den, r_num, r_den, beta=1): 16 | p = 0 if p_den == 0 else p_num / float(p_den) 17 | r = 0 if r_den == 0 else r_num / float(r_den) 18 | return 0 if p + r == 0 else (1 + beta * beta) * p * r / (beta * beta * p + r) 19 | 20 | 21 | class CorefEvaluator(object): 22 | def __init__(self): 23 | self.evaluators = [Evaluator(m) for m in (muc, b_cubed, ceafe)] 24 | self.all_gm = 1e-6 25 | self.recalled_gm = 1e-6 26 | 27 | self.all_gm_by_width = defaultdict((int)) 28 | self.recalled_gm_by_width = defaultdict((int)) 29 | 30 | self.all_gm_by_depth = defaultdict((int)) 31 | self.recalled_gm_by_depth = defaultdict((int)) 32 | 33 | self.c_tuple = [0,0,0] 34 | self.n_tuple = [0,0,0] 35 | 36 | def update( 37 | self, predicted, gold, mention_to_predicted, mention_to_gold, 38 | metainfo_gms, recalled_gms # all_gm, recalled_gm 39 | ): 40 | for e in self.evaluators: 41 | e.update(predicted, gold, mention_to_predicted, mention_to_gold) 42 | self.all_gm += len(metainfo_gms) 43 | self.recalled_gm += len(recalled_gms & set(metainfo_gms.keys())) 44 | 45 | for x, v in metainfo_gms.items(): 46 | self.all_gm_by_width[v['width']] += 1 47 | self.all_gm_by_depth[v['depth']] += 1 48 | if x in recalled_gms: 49 | self.recalled_gm_by_width[v['width']] += 1 50 | self.recalled_gm_by_depth[v['depth']] += 1 51 | 52 | 53 | 54 | c_tuple, n_tuple = blanc(gold, predicted) 55 | for i in range(3): 56 | self.c_tuple[i] += c_tuple[i] 57 | for i in range(3): 58 | self.n_tuple[i] += n_tuple[i] 59 | 60 | def get_all(self): 61 | all_res = {} 62 | name_dict = {0: "muc", 1: "b_cubed", 2: "ceafe"} 63 | for i, e in enumerate(self.evaluators): 64 | all_res[name_dict[i]+"_f1"] = e.get_f1() 65 | all_res[name_dict[i]+"_p"] = e.get_precision() 66 | all_res[name_dict[i]+"_r"] = e.get_recall() 67 | 68 | return all_res 69 | 70 | def get_f1(self): 71 | for e in self.evaluators: 72 | print("f:", e.get_f1()) 73 | 74 | return sum(e.get_f1() for e in self.evaluators) / len(self.evaluators) 75 | 76 | def get_recall(self): 77 | for e in self.evaluators: 78 | print("r:", e.get_recall()) 79 | return sum(e.get_recall() for e in self.evaluators) / len(self.evaluators) 80 | 81 | def get_precision(self): 82 | for e in self.evaluators: 
83 | print("p:", e.get_precision()) 84 | return sum(e.get_precision() for e in self.evaluators) / len(self.evaluators) 85 | 86 | def get_prf(self): 87 | blanc_scores = tuple_to_metric(self.c_tuple, self.n_tuple) 88 | blanc_p, blanc_r, blanc_f = tuple(0.5*(a+b) for (a,b) in zip(*blanc_scores)) 89 | print("all_gm", self.all_gm) 90 | 91 | print(self.all_gm_by_width) 92 | print(self.recalled_gm_by_width) 93 | 94 | print(self.all_gm_by_depth) 95 | print(self.recalled_gm_by_depth) 96 | 97 | return self.get_precision(), self.get_recall(), self.get_f1(), self.recalled_gm / self.all_gm, (blanc_p, blanc_r, blanc_f) 98 | 99 | 100 | class Evaluator(object): 101 | def __init__(self, metric, beta=1): 102 | self.p_num = 0 103 | self.p_den = 0 104 | self.r_num = 0 105 | self.r_den = 0 106 | self.metric = metric 107 | self.beta = beta 108 | 109 | def update(self, predicted, gold, mention_to_predicted, mention_to_gold): 110 | if self.metric == ceafe: 111 | pn, pd, rn, rd = self.metric(predicted, gold) 112 | else: 113 | pn, pd = self.metric(predicted, mention_to_gold) 114 | rn, rd = self.metric(gold, mention_to_predicted) 115 | self.p_num += pn 116 | self.p_den += pd 117 | self.r_num += rn 118 | self.r_den += rd 119 | 120 | def get_f1(self): 121 | return f1(self.p_num, self.p_den, self.r_num, self.r_den, beta=self.beta) 122 | 123 | def get_recall(self): 124 | return 0 if self.r_num == 0 else self.r_num / float(self.r_den) 125 | 126 | def get_precision(self): 127 | return 0 if self.p_num == 0 else self.p_num / float(self.p_den) 128 | 129 | def get_prf(self): 130 | return self.get_precision(), self.get_recall(), self.get_f1() 131 | 132 | def get_counts(self): 133 | return self.p_num, self.p_den, self.r_num, self.r_den 134 | 135 | 136 | def evaluate_documents(documents, metric, beta=1): 137 | evaluator = Evaluator(metric, beta=beta) 138 | for document in documents: 139 | evaluator.update(document) 140 | return evaluator.get_precision(), evaluator.get_recall(), evaluator.get_f1() 141 | 142 | 143 | def b_cubed(clusters, mention_to_gold): 144 | num, dem = 0, 0 145 | 146 | for c in clusters: 147 | # if len(c) == 1: 148 | # continue 149 | 150 | gold_counts = Counter() 151 | correct = 0 152 | for m in c: 153 | if m in mention_to_gold: 154 | gold_counts[tuple(mention_to_gold[m])] += 1 155 | for c2, count in gold_counts.items(): 156 | # if len(c2) != 1: 157 | correct += count * count 158 | 159 | num += correct / float(len(c)) 160 | dem += len(c) 161 | 162 | return num, dem 163 | 164 | 165 | def muc(clusters, mention_to_gold): 166 | tp, p = 0, 0 167 | for c in clusters: 168 | p += len(c) - 1 169 | tp += len(c) 170 | linked = set() 171 | for m in c: 172 | if m in mention_to_gold: 173 | linked.add(mention_to_gold[m]) 174 | else: 175 | tp -= 1 176 | tp -= len(linked) 177 | return tp, p 178 | 179 | 180 | def phi4(c1, c2): 181 | return 2 * len([m for m in c1 if m in c2]) / float(len(c1) + len(c2)) 182 | 183 | 184 | def ceafe(clusters, gold_clusters): 185 | clusters = [c for c in clusters] # if len(c) != 1] 186 | scores = np.zeros((len(gold_clusters), len(clusters))) 187 | for i in range(len(gold_clusters)): 188 | for j in range(len(clusters)): 189 | scores[i, j] = phi4(gold_clusters[i], clusters[j]) 190 | matching = linear_sum_assignment(-scores) 191 | matching = np.transpose(np.asarray(matching)) 192 | similarity = sum(scores[matching[:,0], matching[:,1]]) 193 | return similarity, len(clusters), similarity, len(gold_clusters) 194 | 195 | 196 | def lea(clusters, mention_to_gold): 197 | num, dem = 0, 0 198 | 199 | for c in 
clusters: 200 | if len(c) == 1: 201 | continue 202 | 203 | common_links = 0 204 | all_links = len(c) * (len(c) - 1) / 2.0 205 | for i, m in enumerate(c): 206 | if m in mention_to_gold: 207 | for m2 in c[i + 1:]: 208 | if m2 in mention_to_gold and mention_to_gold[m] == mention_to_gold[m2]: 209 | common_links += 1 210 | 211 | num += len(c) * common_links / float(all_links) 212 | dem += len(c) 213 | 214 | return num, dem 215 | -------------------------------------------------------------------------------- /minimize.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import re 6 | import os 7 | import sys 8 | import json 9 | import tempfile 10 | import subprocess 11 | import collections 12 | from nltk import Tree 13 | 14 | import conll 15 | from transformers import AutoTokenizer 16 | 17 | import copy 18 | import nltk 19 | 20 | SPEAKER_START = '[unused19]' 21 | SPEAKER_END = '[unused73]' 22 | GENRE_DICT = {"bc":"[unused20]", "bn":"[unused21]", "mz":"[unused22]", "nw":"[unused23]", "pt":"[unused24]", "tc":"[unused25]", "wb":"[unused26]"} 23 | 24 | def recur_list_tree(parse_tree, offset=0): 25 | span_to_tag = {} 26 | for node in parse_tree: 27 | if type(node) == type(parse_tree): 28 | span_to_tag.update(recur_list_tree(node, offset)) 29 | n_leaves = len(node.leaves()) 30 | start, end = offset, offset+n_leaves-1 31 | span_to_tag[(start,end)] = node.label() 32 | offset += n_leaves 33 | else: 34 | offset += 1 35 | pass 36 | return span_to_tag 37 | 38 | def get_chunks_from_tree(parse_tree,chunk_types,offset=0,depth=0): 39 | chunks = {} 40 | root = parse_tree 41 | for node in parse_tree: 42 | if type(node) == type(parse_tree): 43 | n_leaves = len(node.leaves()) 44 | start, end = offset, offset+n_leaves-1 45 | if node.label() not in chunk_types: 46 | chunks.update(get_chunks_from_tree(node,chunk_types, offset, depth)) 47 | else: 48 | if depth == 0: 49 | chunks[(start,end)] = node.label() 50 | else: 51 | next_level = get_chunks_from_tree(node,chunk_types, offset, depth-1) 52 | if len(next_level) > 0: 53 | chunks.update(next_level) 54 | else: 55 | chunks[(start,end)] = node.label() 56 | offset += n_leaves 57 | else: # leaf node 58 | offset += 1 59 | pass 60 | return chunks 61 | 62 | def keep_coreferable(root): 63 | coreferable_types = ["NP", "VB", "PRP", "NML","CD","NNP"] 64 | processed_types = ["MNP", "V", "PRP", "NML", "NNP"] 65 | result_tree = copy.deepcopy(root) 66 | 67 | coreferable_spans = {} 68 | 69 | def recur(node, parent, parent_span, span, index_in_siblings): 70 | child_left_ind = span[0] 71 | res = [] 72 | if True: 73 | for (i, child) in enumerate(node): 74 | if type(child) == nltk.tree.Tree: 75 | child_span = (child_left_ind, child_left_ind+len(child.leaves())) 76 | child_left_ind += len(child.leaves()) 77 | processed_node = recur(child, node, span, child_span, i) 78 | 79 | res += processed_node 80 | else: 81 | if node.label() not in ["CD", "NML","NNP","PRP"]: 82 | res += [node] 83 | else: 84 | res += [child] 85 | if node.label() == "NP": 86 | if True: # keeping all NPs parent.label() not in ["NP",]: 87 | node.set_label("MNP") 88 | else: 89 | if node[-1].label() in {"POS"}: 90 | node.set_label("MNP") 91 | pass 92 | elif any([x.label() in ("VP", "CC","HYPH") for x in parent]): 93 | node.set_label("MNP") 94 | pass 95 | elif sum([(x.label() in ("NP", "MNP")) for x in parent]) > 1: 96 | node.set_label("MNP") 97 | pass 98 | 99 | elif 
node.label().startswith("PRP"): 100 | node.set_label("PRP") 101 | elif node.label() == "CD": 102 | # exclude a lot of CDs 103 | if index_in_siblings > 0 and parent[index_in_siblings-1].label() == "DT": 104 | pass 105 | elif index_in_siblings < len(parent)-1 and parent[index_in_siblings+1].label() in ("NN", "NNS", ): 106 | pass 107 | else: 108 | node.set_label("-CD") 109 | elif node.label() in {"NN", "NNP"}: 110 | # exclude a lot of NNPs 111 | if index_in_siblings < len(parent)-1 and parent[index_in_siblings+1].label() == "POS": 112 | # node.set_label("-NNP") 113 | pass 114 | elif any([(x.label() in ("CC", "HYPH")) for x in parent]): 115 | node.set_label("NNP") 116 | pass 117 | elif index_in_siblings > 0 and parent[index_in_siblings-1].label() == "DT": 118 | pass 119 | elif index_in_siblings < len(parent)-1 and parent[index_in_siblings+1].label() in ("NN", "NNS", ): 120 | pass 121 | elif index_in_siblings == 0: 122 | pass 123 | else: 124 | node.set_label("-NNP") 125 | pass 126 | elif node.label().startswith("VB"): 127 | node.set_label("-V") 128 | pass 129 | if node.label() in processed_types: 130 | coreferable_spans[(span[0], span[1]-1)] = node.label() 131 | res = [nltk.Tree(node.label(), res)] 132 | else: 133 | res = res 134 | return res 135 | result_tree = recur(result_tree, None, None, (0, len(result_tree.leaves())), 0) 136 | result_tree = nltk.Tree("TOP", result_tree) 137 | return result_tree, coreferable_spans 138 | 139 | class DocumentState(object): 140 | def __init__(self, key): 141 | self.doc_key = key 142 | self.sentence_end = [] 143 | self.token_end = [] 144 | self.tokens = [] 145 | self.subtokens = [] 146 | self.info = [] 147 | self.segments = [] 148 | self.subtoken_map = [] 149 | self.segment_subtoken_map = [] 150 | self.sentence_map = [] 151 | self.pronouns = [] 152 | self.clusters = collections.defaultdict(list) 153 | self.coref_stacks = collections.defaultdict(list) 154 | self.speakers = [] 155 | self.segment_info = [] 156 | 157 | self.chunk_tags = [] # corresponding to subtokens 158 | self.constituents = [] 159 | self.coreferables = [] 160 | self.pos_tags = [] 161 | 162 | 163 | def finalize(self): 164 | # finalized: segments, segment_subtoken_map 165 | # populate speakers from info 166 | subtoken_idx = 0 167 | for segment in self.segment_info: 168 | speakers = [] 169 | for i, tok_info in enumerate(segment): 170 | if tok_info is None and (i == 0 or i == len(segment) - 1): 171 | speakers.append('[SPL]') 172 | elif tok_info is None: 173 | speakers.append(speakers[-1]) 174 | else: 175 | speakers.append(tok_info[9]) 176 | if tok_info[4] == 'PRP': 177 | self.pronouns.append(subtoken_idx) 178 | subtoken_idx += 1 179 | self.speakers += [speakers] 180 | # populate sentence map 181 | 182 | # populate clusters 183 | first_subtoken_index = -1 184 | for seg_idx, segment in enumerate(self.segment_info): 185 | speakers = [] 186 | for i, tok_info in enumerate(segment): 187 | first_subtoken_index += 1 188 | coref = tok_info[-2] if tok_info is not None else '-' 189 | if coref != "-": 190 | last_subtoken_index = first_subtoken_index + tok_info[-1] - 1 191 | for part in coref.split("|"): 192 | if part[0] == "(": 193 | if part[-1] == ")": 194 | cluster_id = int(part[1:-1]) 195 | self.clusters[cluster_id].append((first_subtoken_index, last_subtoken_index)) 196 | else: 197 | cluster_id = int(part[1:]) 198 | self.coref_stacks[cluster_id].append(first_subtoken_index) 199 | else: 200 | cluster_id = int(part[:-1]) 201 | start = self.coref_stacks[cluster_id].pop() 202 | 
self.clusters[cluster_id].append((start, last_subtoken_index)) 203 | # merge clusters 204 | merged_clusters = [] 205 | for c1 in self.clusters.values(): 206 | existing = None 207 | for m in c1: 208 | for c2 in merged_clusters: 209 | if m in c2: 210 | existing = c2 211 | break 212 | if existing is not None: 213 | break 214 | if existing is not None: 215 | print("Merging clusters (shouldn't happen very often.)") 216 | existing.update(c1) 217 | else: 218 | merged_clusters.append(set(c1)) 219 | merged_clusters = [list(c) for c in merged_clusters] 220 | 221 | flattened_sentences = flatten(self.segments) 222 | all_mentions = flatten(merged_clusters) 223 | sentence_map = get_sentence_map(self.segments, self.sentence_end) 224 | subtoken_map = flatten(self.segment_subtoken_map) 225 | 226 | 227 | for cluster in merged_clusters: 228 | for mention in cluster: 229 | if subtoken_map[mention[0]] == subtoken_map[mention[1]]: 230 | if self.pos_tags[subtoken_map[mention[0]]].startswith("V"): 231 | self.coreferables.append(mention) 232 | 233 | 234 | assert len(all_mentions) == len(set(all_mentions)) 235 | chunk_tags = flatten(self.chunk_tags) 236 | num_words = len(flattened_sentences) 237 | assert num_words == len(flatten(self.speakers)) 238 | # assert num_words == len(chunk_tags), (num_words, len(chunk_tags)) 239 | assert num_words == len(subtoken_map), (num_words, len(subtoken_map)) 240 | assert num_words == len(sentence_map), (num_words, len(sentence_map)) 241 | def mapper(x): 242 | if x == "NP": 243 | return 1 244 | else: 245 | return 2 246 | return { 247 | "doc_key": self.doc_key, 248 | "sentences": self.segments, 249 | "speakers": self.speakers, 250 | "constituents": [x[0] for x in self.constituents], # 251 | "constituent_type": [x[1] for x in self.constituents], # 252 | # "coreferables": self.coreferables, 253 | "ner": [], 254 | "clusters": merged_clusters, 255 | 'sentence_map':sentence_map, 256 | "subtoken_map": subtoken_map, 257 | 'pronouns': self.pronouns, 258 | "chunk_tags": self.chunk_tags 259 | } 260 | 261 | 262 | def normalize_word(word, language): 263 | if language == "arabic": 264 | word = word[:word.find("#")] 265 | if word == "/." or word == "/?": 266 | return word[1:] 267 | else: 268 | return word 269 | 270 | # first try to satisfy constraints1, and if not possible, constraints2. 
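# In get_document below, constraints1 marks sentence ends (token ends for Arabic) and
# constraints2 marks token ends, so a segment is cut at a sentence boundary when possible
# and at a word boundary otherwise.
# Each segment is emitted as [CLS] <genre token> <subtokens> [SEP]; the "- 3" below reserves
# those three special positions, and index_mapping_dict shifts every original subtoken index
# by 3 * segment_index - 1 (segments counted from 1) to account for them.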
271 | def split_into_segments(document_state, max_segment_len, constraints1, constraints2): 272 | current = 0 273 | previous_token = 0 274 | final_chunk_tags = [] 275 | index_mapping_dict = {} 276 | cur_seg_ind = 1 277 | while current < len(document_state.subtokens): 278 | # -3 for 3 additional special tokens 279 | end = min(current + max_segment_len - 1 - 3, len(document_state.subtokens) - 1) 280 | while end >= current and not constraints1[end]: 281 | end -= 1 282 | if end < current: 283 | end = min(current + max_segment_len - 1 - 3, len(document_state.subtokens) - 1) 284 | while end >= current and not constraints2[end]: 285 | end -= 1 286 | if end < current: 287 | raise Exception("Can't find valid segment") 288 | 289 | for i in range(current, end+1): 290 | index_mapping_dict[i] = i + 3*cur_seg_ind - 1 291 | 292 | 293 | genre = document_state.doc_key[:2] 294 | genre_text = GENRE_DICT[genre] 295 | document_state.tokens.append(genre_text) 296 | 297 | document_state.segments.append(['[CLS]', genre_text] + document_state.subtokens[current:end + 1] + ['[SEP]']) 298 | 299 | subtoken_map = document_state.subtoken_map[current : end + 1] 300 | document_state.segment_subtoken_map.append([previous_token, previous_token] + subtoken_map + [subtoken_map[-1]]) 301 | info = document_state.info[current : end + 1] 302 | document_state.segment_info.append([None, None] + info + [None]) 303 | current = end + 1 304 | cur_seg_ind += 1 305 | previous_token = subtoken_map[-1] 306 | 307 | document_state.chunk_tags = final_chunk_tags 308 | return index_mapping_dict 309 | 310 | def get_sentence_map(segments, sentence_end): 311 | current = 0 312 | sent_map = [] 313 | sent_end_idx = 0 314 | assert len(sentence_end) == sum([len(s)-3 for s in segments]) 315 | for segment in segments: 316 | sent_map.append(current) 317 | sent_map.append(current) 318 | for i in range(len(segment) - 3): 319 | sent_map.append(current) 320 | current += int(sentence_end[sent_end_idx]) 321 | sent_end_idx += 1 322 | sent_map.append(current) 323 | return sent_map 324 | 325 | def get_document(document_lines, tokenizer, language, segment_len): 326 | document_state = DocumentState(document_lines[0]) 327 | word_idx = -1 328 | parse_pieces = [] 329 | cur_sent_offset = 0 330 | cur_sent_len = 0 331 | 332 | current_speaker = None 333 | 334 | for line in document_lines[1]: 335 | row = line.split() 336 | sentence_end = len(row) == 0 337 | if not sentence_end: 338 | assert len(row) >= 12 339 | 340 | if current_speaker is None or current_speaker != row[9]: 341 | added_speaker_head = True 342 | # insert speaker 343 | word_idx += 1 344 | current_speaker = row[9] 345 | speaker_text = tokenizer.tokenize(current_speaker) 346 | parse_pieces.append(f"(PSEUDO {' '.join([SPEAKER_START] + speaker_text + [SPEAKER_END])}") 347 | document_state.tokens.append(current_speaker) 348 | document_state.pos_tags.append("SPEAKER") 349 | for sidx, subtoken in enumerate([SPEAKER_START] + speaker_text + [SPEAKER_END]): 350 | cur_sent_len += 1 351 | document_state.subtokens.append(subtoken) 352 | info = None 353 | document_state.info.append(info) 354 | document_state.sentence_end.append(False) 355 | document_state.subtoken_map.append(word_idx) 356 | 357 | word_idx += 1 358 | word = normalize_word(row[3], language) 359 | 360 | parse_piece = row[5] 361 | pos_tag = row[4] 362 | if pos_tag == "(": 363 | pos_tag = "-LRB-" 364 | if pos_tag == ")": 365 | pos_tag = "-RRB-" 366 | 367 | (left_brackets, right_hand_side) = parse_piece.split("*") 368 | right_brackets = right_hand_side.count(")") 
* ")" 369 | 370 | subtokens = tokenizer.tokenize(word) 371 | document_state.tokens.append(word) 372 | document_state.pos_tags.append(pos_tag) 373 | 374 | document_state.token_end += ([False] * (len(subtokens) - 1)) + [True] 375 | for sidx, subtoken in enumerate(subtokens): 376 | cur_sent_len += 1 377 | document_state.subtokens.append(subtoken) 378 | info = None if sidx != 0 else (row + [len(subtokens)]) 379 | document_state.info.append(info) 380 | document_state.sentence_end.append(False) 381 | document_state.subtoken_map.append(word_idx) 382 | new_core = " ".join(subtokens) 383 | 384 | parse_piece = f"{left_brackets} {new_core} {right_brackets}" 385 | # parse_piece = f"{left_brackets} {new_core} {right_brackets}" 386 | parse_pieces.append(parse_piece) 387 | else: 388 | if added_speaker_head: 389 | parse_pieces.append(")") 390 | added_speaker_head = False 391 | 392 | parse_tree = Tree.fromstring("".join(parse_pieces)) 393 | chunk_dict = get_chunks_from_tree(parse_tree, ["NP"]) 394 | constituent_dict = recur_list_tree(parse_tree) 395 | 396 | coreferable_spans_dict = keep_coreferable(parse_tree)[1] 397 | 398 | document_state.coreferables += [[x[0]+cur_sent_offset,x[1]+cur_sent_offset] for x,y in coreferable_spans_dict.items()] 399 | document_state.constituents += [[[x[0]+cur_sent_offset,x[1]+cur_sent_offset],y] for x,y in constituent_dict.items()] 400 | sent_chunk_tags = ["O" for i in range(cur_sent_len)] 401 | if pos_tag==".": 402 | sent_chunk_tags[-1] = "." 403 | for (chunk_l, chunk_r), chunk_tag in chunk_dict.items(): 404 | chunk_len = chunk_r - chunk_l + 1 405 | if chunk_len == 1: 406 | sent_chunk_tags[chunk_r] = "U-" + chunk_tag 407 | else: 408 | sent_chunk_tags[chunk_l] = "B-" + chunk_tag 409 | sent_chunk_tags[chunk_r] = "L-" + chunk_tag 410 | for i in range(chunk_l + 1, chunk_r): 411 | sent_chunk_tags[i] = "I-" + chunk_tag 412 | document_state.sentence_end[-1] = True 413 | cur_sent_offset += cur_sent_len 414 | cur_sent_len, parse_pieces = 0, [] 415 | document_state.chunk_tags += sent_chunk_tags 416 | 417 | constraints1 = document_state.sentence_end if language != 'arabic' else document_state.token_end 418 | index_mapping_dict = split_into_segments(document_state, segment_len, constraints1, document_state.token_end) 419 | 420 | for x in document_state.constituents: 421 | x[0][0] = index_mapping_dict[x[0][0]] 422 | x[0][1] = index_mapping_dict[x[0][1]] 423 | for x in document_state.coreferables: 424 | x[0] = index_mapping_dict[x[0]] 425 | x[1] = index_mapping_dict[x[1]] 426 | 427 | stats["max_sent_len_{}".format(language)] = max(max([len(s) for s in document_state.segments]), stats["max_sent_len_{}".format(language)]) 428 | document = document_state.finalize() 429 | return document 430 | 431 | def skip(doc_key): 432 | # if doc_key in ['nw/xinhua/00/chtb_0078_0', 'wb/eng/00/eng_0004_1']: #, 'nw/xinhua/01/chtb_0194_0', 'nw/xinhua/01/chtb_0157_0']: 433 | # return True 434 | return False 435 | 436 | def minimize_partition(name, language, extension, labels, stats, tokenizer, seg_len, input_dir, output_dir): 437 | input_path = "{}/{}.{}.{}".format(input_dir, name, language, extension) 438 | output_path = "{}/{}.{}.{}.jsonlines".format(output_dir, name, language, seg_len) 439 | count = 0 440 | print("Minimizing {}".format(input_path)) 441 | documents = [] 442 | with open(input_path, "r") as input_file: 443 | for line in input_file.readlines(): 444 | begin_document_match = re.match(conll.BEGIN_DOCUMENT_REGEX, line) 445 | if begin_document_match: 446 | doc_key = 
conll.get_doc_key(begin_document_match.group(1), begin_document_match.group(2)) 447 | documents.append((doc_key, [])) 448 | elif line.startswith("#end document"): 449 | continue 450 | else: 451 | documents[-1][1].append(line) 452 | num_coreferables, num_words = 0, 0 453 | with open(output_path, "w") as output_file: 454 | for document_lines in documents: 455 | if skip(document_lines[0]): 456 | continue 457 | document = get_document(document_lines, tokenizer, language, seg_len) 458 | 459 | # num_coreferables += len(document["coreferables"]) 460 | num_words += len(flatten(document["sentences"])) 461 | 462 | output_file.write(json.dumps(document)) 463 | output_file.write("\n") 464 | count += 1 465 | print("Wrote {} documents to {}, with {} words".format(count, output_path, num_words)) 466 | 467 | def minimize_language(language, labels, stats, seg_len, input_dir, output_dir, do_lower_case): 468 | tokenizer = AutoTokenizer.from_pretrained("bert-base-cased") 469 | # tokenizer = AutoTokenizer.from_pretrained("allenai/longformer-large-4096") 470 | 471 | minimize_partition("dev", language, "v4_gold_conll", labels, stats, tokenizer, seg_len, input_dir, output_dir) 472 | minimize_partition("train", language, "v4_gold_conll", labels, stats, tokenizer, seg_len, input_dir, output_dir) 473 | minimize_partition("test", language, "v4_gold_conll", labels, stats, tokenizer, seg_len, input_dir, output_dir) 474 | 475 | 476 | def flatten(l): 477 | return [item for sublist in l for item in sublist] 478 | 479 | # Usage: python minimize.py ./data_dir/ ./data_dir/ontonotes_speaker_encoding/ false 480 | if __name__ == "__main__": 481 | input_dir = sys.argv[1] 482 | output_dir = sys.argv[2] 483 | do_lower_case = (sys.argv[3].lower() == 'true') 484 | 485 | print("do_lower_case", do_lower_case) 486 | labels = collections.defaultdict(set) 487 | stats = collections.defaultdict(int) 488 | 489 | if not os.path.isdir(output_dir): 490 | os.mkdir(output_dir) 491 | for seg_len in [384, 512]: 492 | minimize_language("english", labels, stats, seg_len, input_dir, output_dir, do_lower_case) 493 | for k, v in labels.items(): 494 | print("{} = [{}]".format(k, ", ".join("\"{}\"".format(label) for label in v))) 495 | for k, v in stats.items(): 496 | print("{} = {}".format(k, v)) 497 | 498 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.init as init 4 | import torch.nn.functional as F 5 | 6 | from bert_modelling import BertModel 7 | from transformers import BertTokenizer 8 | 9 | import logging 10 | import numpy as np 11 | from collections import Iterable, defaultdict 12 | 13 | from outside_mp import CFGMentionProposer 14 | from greedy_mp import GreedyMentionProposer 15 | 16 | from util import logsumexp, log1mexp, batch_select, bucket_distance 17 | 18 | 19 | logging.basicConfig( 20 | format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 21 | datefmt='%m/%d/%Y %H:%M:%S', 22 | level=logging.INFO 23 | ) 24 | logger = logging.getLogger() 25 | 26 | 27 | class CorefModel(torch.nn.Module): 28 | tz = BertTokenizer.from_pretrained("bert-base-cased") 29 | def __init__(self, config, device, num_genres=None): 30 | super().__init__() 31 | self.config = config 32 | self.device = device 33 | 34 | self.cls_id = self.tz.convert_tokens_to_ids("[CLS]") 35 | self.sep_id = self.tz.convert_tokens_to_ids("[SEP]") 36 | 37 | self.num_genres = num_genres if num_genres 
else len(config['genres']) 38 | self.max_seg_len = config['max_segment_len'] 39 | self.max_span_width = config['max_span_width'] 40 | 41 | # Model 42 | self.dropout = nn.Dropout(p=config['dropout_rate']) 43 | self.bert = BertModel.from_pretrained(config['bert_pretrained_name_or_path']) 44 | 45 | self.bert_emb_size = self.bert.config.hidden_size 46 | self.span_emb_size = self.bert_emb_size * 3 47 | if config['use_features']: 48 | self.span_emb_size += config['feature_emb_size'] 49 | 50 | self.pair_emb_size = self.span_emb_size * 3 51 | if config['use_metadata']: 52 | self.pair_emb_size += 2 * config['feature_emb_size'] 53 | if config['use_features']: 54 | self.pair_emb_size += config['feature_emb_size'] 55 | if config['use_segment_distance']: 56 | self.pair_emb_size += config['feature_emb_size'] 57 | 58 | if config['mention_proposer'].lower() == "outside": 59 | self.mention_proposer = CFGMentionProposer(max_span_width=self.max_span_width, neg_sample_rate=config['neg_sample_rate']) 60 | elif config['mention_proposer'].lower() == "greedy": 61 | self.mention_proposer = GreedyMentionProposer() 62 | 63 | self.emb_span_width = self.make_embedding(self.max_span_width) if config['use_features'] else None 64 | self.emb_span_width_prior = self.make_embedding(self.max_span_width) if config['use_width_prior'] else None 65 | self.emb_antecedent_distance_prior = self.make_embedding(10) if config['use_distance_prior'] else None 66 | self.emb_genre = self.make_embedding(self.num_genres) 67 | self.emb_same_speaker = self.make_embedding(2) if config['use_metadata'] else None 68 | self.emb_segment_distance = self.make_embedding(config['max_training_sentences']) if config['use_segment_distance'] else None 69 | self.emb_top_antecedent_distance = self.make_embedding(10) 70 | 71 | self.mention_token_attn = self.make_ffnn(self.bert_emb_size, 0, output_size=1) if config['model_heads'] else None 72 | if type(self.mention_proposer) == CFGMentionProposer: 73 | self.span_emb_score_ffnn = self.make_ffnn(self.span_emb_size, [config['ffnn_size']] * config['ffnn_depth'], output_size=2) 74 | elif type(self.mention_proposer) == GreedyMentionProposer: 75 | self.span_emb_score_ffnn = self.make_ffnn(self.span_emb_size, [config['ffnn_size']] * config['ffnn_depth'], output_size=1) 76 | 77 | self.span_width_score_ffnn = self.make_ffnn(config['feature_emb_size'], [config['ffnn_size']] * config['ffnn_depth'], output_size=1) if config['use_width_prior'] else None 78 | self.coarse_bilinear = self.make_ffnn(self.span_emb_size, 0, output_size=self.span_emb_size) 79 | self.antecedent_distance_score_ffnn = self.make_ffnn(config['feature_emb_size'], 0, output_size=1) if config['use_distance_prior'] else None 80 | self.coref_score_ffnn = self.make_ffnn(self.pair_emb_size, [config['ffnn_size']] * config['ffnn_depth'], output_size=1) if config['fine_grained'] else None 81 | 82 | self.all_words = 0 83 | self.all_pred_men = 0 84 | self.debug = False 85 | 86 | 87 | def make_embedding(self, dict_size, std=0.02): 88 | emb = nn.Embedding(dict_size, self.config['feature_emb_size']) 89 | init.normal_(emb.weight, std=std) 90 | return emb 91 | 92 | def make_linear(self, in_features, out_features, bias=True, std=0.02): 93 | linear = nn.Linear(in_features, out_features, bias) 94 | init.normal_(linear.weight, std=std) 95 | if bias: 96 | init.zeros_(linear.bias) 97 | return linear 98 | 99 | def make_ffnn(self, feat_size, hidden_size, output_size): 100 | if hidden_size is None or hidden_size == 0 or hidden_size == [] or hidden_size == [0]: 101 | return 
self.make_linear(feat_size, output_size) 102 | 103 | if not isinstance(hidden_size, Iterable): 104 | hidden_size = [hidden_size] 105 | ffnn = [self.make_linear(feat_size, hidden_size[0]), nn.ReLU(), self.dropout] 106 | for i in range(1, len(hidden_size)): 107 | ffnn += [self.make_linear(hidden_size[i-1], hidden_size[i]), nn.ReLU(), self.dropout] 108 | ffnn.append(self.make_linear(hidden_size[-1], output_size)) 109 | return nn.Sequential(*ffnn) 110 | 111 | def get_params(self, named=False): 112 | bert_based_param, task_param = [], [] 113 | for name, param in self.named_parameters(): 114 | if name.startswith('bert') or name.startswith("mention_transformer"): 115 | to_add = (name, param) if named else param 116 | bert_based_param.append(to_add) 117 | else: 118 | to_add = (name, param) if named else param 119 | task_param.append(to_add) 120 | return bert_based_param, task_param 121 | 122 | def forward(self, *input): 123 | mention_doc = self.get_mention_doc(*input) 124 | return self.get_predictions_and_loss(mention_doc, *input) 125 | 126 | def get_flat_span_location_indices(self, spans, sentence_map): 127 | sentence_map = sentence_map.tolist() 128 | spans_list = spans.tolist() 129 | flat_span_location_indices = [] 130 | 131 | prev_sent_id = sentence_map[0] 132 | sentence_lengths, cur_sent_len = [], 0 133 | for i in sentence_map: 134 | if prev_sent_id == i: 135 | cur_sent_len += 1 136 | else: 137 | sentence_lengths.append(cur_sent_len) 138 | cur_sent_len = 1 139 | prev_sent_id = i 140 | 141 | sentence_lengths.append(cur_sent_len) 142 | max_sentence_len = max(sentence_lengths) 143 | 144 | sentence_offsets = np.cumsum([0] + sentence_lengths)[:-1] 145 | for (start, end) in spans_list: 146 | sent_id = sentence_map[start] - sentence_map[0] 147 | offset = sentence_offsets[sent_id] 148 | 149 | flat_id = sent_id * (max_sentence_len**2) + (start-offset)*max_sentence_len + (end-offset) 150 | flat_span_location_indices.append(flat_id) 151 | 152 | sentence_lengths = spans.new_tensor(sentence_lengths) 153 | flat_span_location_indices = spans.new_tensor(flat_span_location_indices) 154 | return flat_span_location_indices, sentence_lengths 155 | 156 | def get_mention_doc(self, input_ids, input_mask, speaker_ids, sentence_len, genre, sentence_map, 157 | is_training, gold_starts=None, gold_ends=None, gold_mention_cluster_map=None, 158 | coreferable_starts=None, coreferable_ends=None, 159 | constituent_starts=None, constituent_ends=None, constituent_type=None): 160 | 161 | mention_doc = self.bert(input_ids, attention_mask=input_mask) # [num seg, num max tokens, emb size] 162 | mention_doc = mention_doc["last_hidden_state"] 163 | input_mask = input_mask.bool() 164 | mention_doc = mention_doc[input_mask] 165 | return mention_doc 166 | 167 | 168 | def get_predictions_and_loss( 169 | self, mention_doc, input_ids, input_mask, speaker_ids, sentence_len, genre, sentence_map, 170 | is_training, gold_starts=None, gold_ends=None, gold_mention_cluster_map=None, 171 | coreferable_starts=None, coreferable_ends=None, 172 | constituent_starts=None, constituent_ends=None, constituent_type=None 173 | ): 174 | """ Model and input are already on the device """ 175 | device = self.device 176 | conf = self.config 177 | 178 | do_loss = False 179 | if gold_mention_cluster_map is not None: 180 | assert gold_starts is not None 181 | assert gold_ends is not None 182 | do_loss = True 183 | 184 | input_mask = input_mask.bool() 185 | speaker_ids = speaker_ids[input_mask] 186 | num_words = mention_doc.shape[0] 187 | 188 | self.all_words += 
num_words 189 | 190 | # Get candidate span 191 | sentence_indices = sentence_map # [num tokens] 192 | candidate_starts = torch.unsqueeze(torch.arange(0, num_words, device=device), 1).repeat(1, self.max_span_width) 193 | candidate_ends = candidate_starts + torch.arange(0, self.max_span_width, device=device) 194 | candidate_start_sent_idx = sentence_indices[candidate_starts] 195 | candidate_end_sent_idx = sentence_indices[torch.min(candidate_ends, torch.tensor(num_words - 1, device=device))] 196 | candidate_mask = (candidate_ends < num_words) & (candidate_start_sent_idx == candidate_end_sent_idx) 197 | candidate_mask &= (input_ids[input_mask][candidate_starts] != self.cls_id) 198 | candidate_mask &= (input_ids[input_mask][torch.clamp(candidate_ends, max=num_words-1)] != self.sep_id) 199 | 200 | candidate_starts, candidate_ends = candidate_starts[candidate_mask], candidate_ends[candidate_mask] # [num valid candidates] 201 | num_candidates = candidate_starts.shape[0] 202 | 203 | candidate_labels = None 204 | non_dummy_indicator = None 205 | # Get candidate labels 206 | if do_loss: 207 | same_start = (torch.unsqueeze(gold_starts, 1) == torch.unsqueeze(candidate_starts, 0)) 208 | same_end = (torch.unsqueeze(gold_ends, 1) == torch.unsqueeze(candidate_ends, 0)) 209 | same_span = (same_start & same_end).long() 210 | candidate_labels = torch.matmul(gold_mention_cluster_map.unsqueeze(0).type_as(mention_doc), same_span.type_as(mention_doc)) 211 | candidate_labels = candidate_labels.long().squeeze() # [num candidates]; non-gold span has label 0 212 | 213 | 214 | # Get span embedding 215 | span_start_emb, span_end_emb = mention_doc[candidate_starts], mention_doc[candidate_ends] 216 | # span_start_emb_1, span_end_emb_1 = mention_doc[candidate_starts], mention_doc[candidate_ends+1] 217 | # candidate_emb_list = [span_start_emb, span_end_emb] 218 | candidate_emb_list = [span_start_emb, span_end_emb] 219 | if conf['use_features']: 220 | candidate_width_idx = candidate_ends - candidate_starts 221 | candidate_width_emb = self.emb_span_width(candidate_width_idx) 222 | candidate_width_emb = self.dropout(candidate_width_emb) 223 | candidate_emb_list.append(candidate_width_emb) 224 | # Use attended head or avg token 225 | candidate_tokens = torch.unsqueeze(torch.arange(0, num_words, device=device), 0).repeat(num_candidates, 1) 226 | candidate_tokens_mask = (candidate_tokens >= torch.unsqueeze(candidate_starts, 1)) & (candidate_tokens <= torch.unsqueeze(candidate_ends, 1)) 227 | if conf['model_heads']: 228 | token_attn = self.mention_token_attn(mention_doc).squeeze() 229 | else: 230 | token_attn = torch.ones(num_words, dtype=mention_doc.dtype, device=device) # Use avg if no attention 231 | candidate_tokens_attn_raw = candidate_tokens_mask.log() + token_attn.unsqueeze(0) 232 | candidate_tokens_attn = F.softmax(candidate_tokens_attn_raw, dim=1) 233 | 234 | head_attn_emb = torch.matmul(candidate_tokens_attn, mention_doc) 235 | candidate_emb_list.append(head_attn_emb) 236 | candidate_span_emb = torch.cat(candidate_emb_list, dim=-1) # [num candidates, new emb size] 237 | 238 | # Get span scores 239 | candidate_mention_scores_and_parsing = self.span_emb_score_ffnn(candidate_span_emb) 240 | 241 | if type(self.mention_proposer) == CFGMentionProposer: 242 | candidate_mention_scores, candidate_mention_parsing_scores = candidate_mention_scores_and_parsing.split(1, dim=-1) 243 | candidate_mention_scores = candidate_mention_scores.squeeze(1) 244 | elif type(self.mention_proposer) == GreedyMentionProposer: 245 | 
candidate_mention_scores = candidate_mention_scores_and_parsing.squeeze(-1) 246 | candidate_mention_parsing_scores = candidate_mention_scores 247 | 248 | if conf['use_width_prior']: 249 | width_score = self.span_width_score_ffnn(self.emb_span_width_prior.weight).squeeze(1) 250 | candidate_mention_scores = candidate_mention_scores + width_score[candidate_width_idx] 251 | 252 | 253 | spans = torch.stack([candidate_starts, candidate_ends], dim=-1) 254 | flat_span_location_indices, sentence_lengths = self.get_flat_span_location_indices( 255 | spans, sentence_map 256 | ) 257 | num_top_spans = int(min(conf['max_num_extracted_spans'], conf['top_span_ratio'] * num_words)) 258 | non_dummy_indicator = (candidate_labels > 0) if candidate_labels is not None else None 259 | 260 | if type(self.mention_proposer) == CFGMentionProposer: 261 | top_span_p_mention, selected_idx, top_spans, mp_loss, _ = self.mention_proposer( 262 | spans, 263 | candidate_mention_parsing_scores, 264 | candidate_mask[candidate_mask], 265 | non_dummy_indicator if non_dummy_indicator is not None else None, 266 | sentence_lengths, 267 | num_top_spans, 268 | flat_span_location_indices, 269 | ) 270 | top_span_log_p_mention = top_span_p_mention.log() 271 | top_span_log_p_mention = top_span_log_p_mention[selected_idx] 272 | 273 | elif type(self.mention_proposer) == GreedyMentionProposer: 274 | _, selected_idx, top_spans, mp_loss, _ = self.mention_proposer( 275 | spans, 276 | candidate_mention_parsing_scores, 277 | candidate_mask[candidate_mask], 278 | sentence_lengths, 279 | num_top_spans, 280 | ) 281 | 282 | num_top_spans = selected_idx.size(0) 283 | 284 | top_span_starts, top_span_ends = candidate_starts[selected_idx], candidate_ends[selected_idx] 285 | top_span_emb = candidate_span_emb[selected_idx] 286 | top_span_cluster_ids = candidate_labels[selected_idx] if do_loss else None 287 | top_span_mention_scores = candidate_mention_scores[selected_idx] 288 | 289 | # Coarse pruning on each mention's antecedents 290 | max_top_antecedents = min(num_top_spans, conf['max_top_antecedents']) 291 | top_span_range = torch.arange(0, num_top_spans, device=device) 292 | antecedent_offsets = torch.unsqueeze(top_span_range, 1) - torch.unsqueeze(top_span_range, 0) 293 | antecedent_mask = (antecedent_offsets >= 1) 294 | pairwise_mention_score_sum = torch.unsqueeze(top_span_mention_scores, 1) + torch.unsqueeze(top_span_mention_scores, 0) 295 | source_span_emb = self.dropout(self.coarse_bilinear(top_span_emb)) 296 | target_span_emb = self.dropout(torch.transpose(top_span_emb, 0, 1)) 297 | pairwise_coref_scores = torch.matmul(source_span_emb, target_span_emb) 298 | 299 | pairwise_fast_scores = pairwise_mention_score_sum + pairwise_coref_scores 300 | pairwise_fast_scores += antecedent_mask.type_as(mention_doc).log() 301 | if conf['use_distance_prior']: 302 | distance_score = torch.squeeze(self.antecedent_distance_score_ffnn(self.dropout(self.emb_antecedent_distance_prior.weight)), 1) 303 | bucketed_distance = bucket_distance(antecedent_offsets) 304 | antecedent_distance_score = distance_score[bucketed_distance] 305 | pairwise_fast_scores += antecedent_distance_score 306 | # Slow mention ranking 307 | if conf['fine_grained']: 308 | top_pairwise_fast_scores, top_antecedent_idx = torch.topk(pairwise_fast_scores, k=max_top_antecedents) 309 | top_antecedent_mask = batch_select(antecedent_mask, top_antecedent_idx, device) # [num top spans, max top antecedents] 310 | top_antecedent_offsets = batch_select(antecedent_offsets, top_antecedent_idx, device) 311 | 312 | 
same_speaker_emb, genre_emb, seg_distance_emb, top_antecedent_distance_emb = None, None, None, None 313 | if conf['use_metadata']: 314 | top_span_speaker_ids = speaker_ids[top_span_starts] 315 | top_antecedent_speaker_id = top_span_speaker_ids[top_antecedent_idx] 316 | same_speaker = torch.unsqueeze(top_span_speaker_ids, 1) == top_antecedent_speaker_id 317 | same_speaker_emb = self.emb_same_speaker(same_speaker.long()) 318 | genre_emb = self.emb_genre(genre) 319 | genre_emb = torch.unsqueeze(torch.unsqueeze(genre_emb, 0), 0).repeat(num_top_spans, max_top_antecedents, 1) 320 | if conf['use_segment_distance']: 321 | num_segs, seg_len = input_ids.shape[0], input_ids.shape[1] 322 | token_seg_ids = torch.arange(0, num_segs, device=device).unsqueeze(1).repeat(1, seg_len) 323 | token_seg_ids = token_seg_ids[input_mask] 324 | top_span_seg_ids = token_seg_ids[top_span_starts] 325 | top_antecedent_seg_ids = token_seg_ids[top_span_starts[top_antecedent_idx]] 326 | top_antecedent_seg_distance = torch.unsqueeze(top_span_seg_ids, 1) - top_antecedent_seg_ids 327 | top_antecedent_seg_distance = torch.clamp(top_antecedent_seg_distance, 0, self.config['max_training_sentences'] - 1) 328 | seg_distance_emb = self.emb_segment_distance(top_antecedent_seg_distance) 329 | if conf['use_features']: # Antecedent distance 330 | top_antecedent_distance = bucket_distance(top_antecedent_offsets) 331 | top_antecedent_distance_emb = self.emb_top_antecedent_distance(top_antecedent_distance) 332 | 333 | top_antecedent_emb = top_span_emb[top_antecedent_idx] # [num top spans, max top antecedents, emb size] 334 | feature_list = [] 335 | if conf['use_metadata']: # speaker, genre 336 | feature_list.append(same_speaker_emb) 337 | feature_list.append(genre_emb) 338 | if conf['use_segment_distance']: 339 | feature_list.append(seg_distance_emb) 340 | if conf['use_features']: # Antecedent distance 341 | feature_list.append(top_antecedent_distance_emb) 342 | feature_emb = torch.cat(feature_list, dim=2) 343 | feature_emb = self.dropout(feature_emb) 344 | target_emb = torch.unsqueeze(top_span_emb, 1).repeat(1, max_top_antecedents, 1) 345 | # target_parent_emb = torch.unsqueeze(top_span_parent_emb, 1).repeat(1, max_top_antecedents, 1) 346 | similarity_emb = target_emb * top_antecedent_emb 347 | pair_emb = torch.cat([target_emb, top_antecedent_emb, similarity_emb, feature_emb], 2) 348 | top_pairwise_slow_scores = self.coref_score_ffnn(pair_emb).squeeze(2) 349 | # print(pair_emb.size(), mention_doc.size(), pair_emb.size(0) / mention_doc.size()[0]) 350 | top_pairwise_scores = top_pairwise_slow_scores + top_pairwise_fast_scores 351 | else: 352 | top_pairwise_fast_scores, top_antecedent_idx = torch.topk(pairwise_fast_scores, k=pairwise_fast_scores.size(0)) 353 | top_antecedent_mask = batch_select(antecedent_mask, top_antecedent_idx, device) # [num top spans, max top antecedents] 354 | top_antecedent_offsets = batch_select(antecedent_offsets, top_antecedent_idx, device) 355 | 356 | top_pairwise_scores = top_pairwise_fast_scores # [num top spans, max top antecedents] 357 | 358 | top_antecedent_scores = torch.cat([torch.zeros(num_top_spans, 1, device=device), top_pairwise_scores], dim=1) 359 | 360 | 361 | if not do_loss: 362 | if type(self.mention_proposer) == CFGMentionProposer or self.config["mention_sigmoid"]: 363 | top_antecedent_log_p_mention = top_span_log_p_mention[top_antecedent_idx] 364 | log_norm = logsumexp(top_antecedent_scores, dim=1) 365 | # Shape: (num_spans_to_keep, max_antecedents+1) 366 | log_p_im = top_antecedent_scores - 
log_norm.unsqueeze(-1) + top_span_log_p_mention.unsqueeze(-1) 367 | # Shape: (num_spans_to_keep) 368 | log_p_em = torch.logaddexp( 369 | log1mexp(top_span_log_p_mention), 370 | top_span_log_p_mention - log_norm + torch.finfo(log_norm.dtype).eps 371 | ) 372 | # log probability for inference 373 | log_probs = torch.cat([log_p_em.unsqueeze(-1), log_p_im[:,1:]], dim=-1) 374 | 375 | return candidate_starts, candidate_ends, candidate_mention_parsing_scores, top_span_starts, top_span_ends, top_antecedent_idx, log_probs 376 | elif type(self.mention_proposer) == GreedyMentionProposer: 377 | return candidate_starts, candidate_ends, candidate_mention_parsing_scores, top_span_starts, top_span_ends, top_antecedent_idx, top_antecedent_scores 378 | 379 | log_norm = logsumexp(top_antecedent_scores, dim=1) 380 | if type(self.mention_proposer) == CFGMentionProposer or self.config["mention_sigmoid"]: 381 | top_antecedent_log_p_mention = top_span_log_p_mention[top_antecedent_idx] 382 | # Shape: (num_spans_to_keep, max_antecedents+1) 383 | log_p_im = top_antecedent_scores - log_norm.unsqueeze(-1) + top_span_log_p_mention.unsqueeze(-1) 384 | # Shape: (num_spans_to_keep) 385 | log_p_em = torch.logaddexp( 386 | log1mexp(top_span_log_p_mention) + torch.finfo(log_norm.dtype).eps, 387 | top_span_log_p_mention - log_norm + torch.finfo(log_norm.dtype).eps 388 | ) 389 | # log probability for inference 390 | log_probs = torch.cat([log_p_em.unsqueeze(-1), log_p_im[:,1:]], dim=-1) 391 | 392 | # Get gold labels 393 | top_antecedent_cluster_ids = top_span_cluster_ids[top_antecedent_idx] 394 | top_antecedent_cluster_ids += (top_antecedent_mask.long() - 1) * 100000 # Mask id on invalid antecedents 395 | same_gold_cluster_indicator = (top_antecedent_cluster_ids == torch.unsqueeze(top_span_cluster_ids, 1)) 396 | non_dummy_indicator = non_dummy_indicator[selected_idx] # (top_span_cluster_ids > 0).squeeze() 397 | # non_dummy_indicator is the coreferable flags 398 | pairwise_labels = same_gold_cluster_indicator & torch.unsqueeze(top_span_cluster_ids > 0, 1) 399 | dummy_antecedent_labels = torch.logical_not(pairwise_labels.any(dim=1, keepdims=True)) 400 | 401 | top_antecedent_gold_labels = torch.cat([dummy_antecedent_labels, pairwise_labels], dim=1) 402 | 403 | # Get loss 404 | if type(self.mention_proposer) == CFGMentionProposer or self.config["mention_sigmoid"]: 405 | coref_loss = -logsumexp(log_p_im + top_antecedent_gold_labels.log(), dim=-1) # for mentions 406 | loss = mp_loss + (coref_loss * non_dummy_indicator).sum() + (-log_p_em * torch.logical_not(non_dummy_indicator)).sum() 407 | return [candidate_starts, candidate_ends, candidate_mention_parsing_scores, top_span_starts, top_span_ends, top_antecedent_idx, log_probs], loss 408 | elif type(self.mention_proposer) == GreedyMentionProposer: 409 | log_marginalized_antecedent_scores = logsumexp(top_antecedent_scores + top_antecedent_gold_labels.log(), dim=1) 410 | loss = (log_norm - log_marginalized_antecedent_scores).sum() 411 | 412 | return [candidate_starts, candidate_ends, candidate_mention_parsing_scores, top_span_starts, top_span_ends, top_antecedent_idx, top_antecedent_scores], loss 413 | 414 | def _extract_top_spans(self, candidate_idx_sorted, candidate_starts, candidate_ends, num_top_spans): 415 | """ Keep top non-cross-overlapping candidates ordered by scores; compute on CPU because of loop """ 416 | selected_candidate_idx = [] 417 | start_to_max_end, end_to_min_start = {}, {} 418 | for candidate_idx in candidate_idx_sorted: 419 | if len(selected_candidate_idx) >= 
num_top_spans: 420 | break 421 | # Perform overlapping check 422 | span_start_idx = candidate_starts[candidate_idx] 423 | span_end_idx = candidate_ends[candidate_idx] 424 | cross_overlap = False 425 | for token_idx in range(span_start_idx, span_end_idx + 1): 426 | max_end = start_to_max_end.get(token_idx, -1) 427 | if token_idx > span_start_idx and max_end > span_end_idx: 428 | cross_overlap = True 429 | break 430 | min_start = end_to_min_start.get(token_idx, -1) 431 | if token_idx < span_end_idx and 0 <= min_start < span_start_idx: 432 | cross_overlap = True 433 | break 434 | if not cross_overlap: 435 | # Pass check; select idx and update dict stats 436 | selected_candidate_idx.append(candidate_idx) 437 | max_end = start_to_max_end.get(span_start_idx, -1) 438 | if span_end_idx > max_end: 439 | start_to_max_end[span_start_idx] = span_end_idx 440 | min_start = end_to_min_start.get(span_end_idx, -1) 441 | if min_start == -1 or span_start_idx < min_start: 442 | end_to_min_start[span_end_idx] = span_start_idx 443 | # Sort selected candidates by span idx 444 | selected_candidate_idx = sorted(selected_candidate_idx, key=lambda idx: (candidate_starts[idx], candidate_ends[idx])) 445 | if len(selected_candidate_idx) < num_top_spans: # Padding 446 | selected_candidate_idx += ([selected_candidate_idx[0]] * (num_top_spans - len(selected_candidate_idx))) 447 | return selected_candidate_idx 448 | 449 | def get_predicted_antecedents(self, antecedent_idx, antecedent_scores): 450 | """ CPU list input """ 451 | predicted_antecedents = [] 452 | for i, idx in enumerate((antecedent_scores.argmax(dim=1) - 1).tolist()): 453 | if idx < 0: 454 | predicted_antecedents.append(-1) 455 | elif idx >= len(antecedent_idx[0]): 456 | predicted_antecedents.append(-2) 457 | else: 458 | predicted_antecedents.append(antecedent_idx[i][idx]) 459 | return predicted_antecedents 460 | 461 | def get_predicted_clusters(self, span_starts, span_ends, antecedent_idx, antecedent_scores): 462 | """ CPU list input """ 463 | # Get predicted antecedents 464 | predicted_antecedents = self.get_predicted_antecedents(antecedent_idx, antecedent_scores) 465 | 466 | # Get predicted clusters 467 | mention_to_cluster_id = {} 468 | predicted_clusters = [] 469 | for i, predicted_idx in enumerate(predicted_antecedents): 470 | if predicted_idx == -1: 471 | continue 472 | elif predicted_idx == -2: 473 | cluster_id = len(predicted_clusters) 474 | predicted_clusters.append([(int(span_starts[i]), int(span_ends[i]))]) 475 | mention_to_cluster_id[(int(span_starts[i]), int(span_ends[i]))] = cluster_id 476 | continue 477 | assert i > predicted_idx, f'span idx: {i}; antecedent idx: {predicted_idx}' 478 | # Check antecedent's cluster 479 | antecedent = (int(span_starts[predicted_idx]), int(span_ends[predicted_idx])) 480 | antecedent_cluster_id = mention_to_cluster_id.get(antecedent, -1) 481 | 482 | if antecedent_cluster_id == -1: 483 | antecedent_cluster_id = len(predicted_clusters) 484 | predicted_clusters.append([antecedent]) 485 | mention_to_cluster_id[antecedent] = antecedent_cluster_id 486 | 487 | # Add mention to cluster 488 | mention = (int(span_starts[i]), int(span_ends[i])) 489 | predicted_clusters[antecedent_cluster_id].append(mention) 490 | mention_to_cluster_id[mention] = antecedent_cluster_id 491 | 492 | predicted_clusters = [tuple(c) for c in predicted_clusters] 493 | return predicted_clusters, mention_to_cluster_id, predicted_antecedents 494 | 495 | def update_evaluator(self, span_starts, span_ends, antecedent_idx, antecedent_scores, gold_clusters, 
evaluator): 496 | predicted_clusters, mention_to_cluster_id, _ = self.get_predicted_clusters(span_starts, span_ends, antecedent_idx, antecedent_scores) 497 | mention_to_predicted = {m: predicted_clusters[cluster_idx] for m, cluster_idx in mention_to_cluster_id.items()} 498 | gold_clusters = [tuple(tuple(m) for m in cluster) for cluster in gold_clusters] 499 | mention_to_gold = {m: cluster for cluster in gold_clusters for m in cluster} 500 | 501 | # gold mentions 502 | gms = set([x for cluster in gold_clusters for x in cluster]) 503 | # getting meta informations, e.g. nested depth, width 504 | metainfo_gms = defaultdict(lambda: defaultdict(int)) 505 | for x in gms: 506 | metainfo_gms[x]["width"] = x[1] - x[0] 507 | for y in gms: 508 | if y[0] <= x[0] and y[1] >= x[1]: 509 | metainfo_gms[x]["depth"] += 1 510 | 511 | 512 | recalled_gms = set([(int(x), int(y)) for x,y in zip(span_starts, span_ends)]) 513 | self.all_pred_men += len(recalled_gms) 514 | 515 | evaluator.update( 516 | predicted_clusters, gold_clusters, mention_to_predicted, mention_to_gold, 517 | metainfo_gms, recalled_gms 518 | ) 519 | return predicted_clusters 520 | 521 | -------------------------------------------------------------------------------- /outside_mp.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import math 3 | from typing import Any, Dict, List, Tuple 4 | import numpy as np 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from torch.nn.parallel import DataParallel 10 | 11 | from util import logsumexp, clip_to_01, stripe, masked_topk_non_overlap 12 | 13 | from genbmm import logbmm, logbmminside_rule 14 | 15 | logger = logging.getLogger(__name__) 16 | LARGENUMBER = 1e4 17 | 18 | class CKY(torch.nn.Module): 19 | def __init__( 20 | self, 21 | max_span_width=30, 22 | ): 23 | super().__init__() 24 | self.max_span_width = max_span_width 25 | 26 | def forward( 27 | self, 28 | span_mention_score_matrix: torch.FloatTensor, 29 | sequence_lengths: torch.IntTensor, 30 | ) -> Tuple[torch.FloatTensor]: 31 | 32 | with torch.autograd.enable_grad(): 33 | # Enable autograd during inference 34 | # For outside value computation 35 | return self.coolio(span_mention_score_matrix, sequence_lengths) 36 | 37 | def coolio( 38 | self, 39 | span_mention_score_matrix: torch.FloatTensor, 40 | sequence_lengths: torch.IntTensor, 41 | ) -> Tuple[torch.FloatTensor]: 42 | """ 43 | Parameters: 44 | span_mention_score_matrix: shape (batch_size, sent_len, sent_len, score_dim) 45 | Score of each span being a span of interest. There are batch_size number 46 | of sentences in this document. And the maximum length of sentence is 47 | sent_len. 48 | sequence_lengths: shape (batch_size, ) 49 | The actual length of each sentence. 
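        Returns:
            Z: shape (batch_size, )
                The log-partition for each sentence.
            marginal: shape (batch_size, sent_len, sent_len)
                Marginal probability of each span being a mention, obtained as the
                gradient of Z with respect to the span scores (inside-outside via
                autograd, see the reference below).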
50 | """ 51 | # faster inside-outside 52 | # requiring grad for outside algorithm 53 | # https://www.cs.jhu.edu/~jason/papers/eisner.spnlp16.pdf 54 | span_mention_score_matrix.requires_grad_(True) 55 | 56 | batch_size, _, _, score_dim = span_mention_score_matrix.size() 57 | seq_len = sequence_lengths.max() 58 | # Shape: (batch_size, ) 59 | sequence_lengths = sequence_lengths.view(-1) 60 | 61 | rules = span_mention_score_matrix 62 | # distributive law: log(exp(s+r) + exp(s)) == s + log(exp(r) + 1) 63 | log1p_exp_rules = torch.log1p(rules.squeeze(-1).exp()) 64 | 65 | zero_rules = (rules.new_ones(seq_len, seq_len).tril(diagonal=-1))*(-LARGENUMBER) 66 | zero_rules = zero_rules.unsqueeze(0).unsqueeze(-1).repeat(batch_size,1,1,1) 67 | 68 | inside_s = torch.cat([rules.clone(), zero_rules], dim=3) 69 | inside_s = inside_s.logsumexp(dim=3) 70 | 71 | for width in range(0, seq_len-1): 72 | # Usage: https://github.com/lyutyuh/genbmm 73 | inside_s = logbmminside_rule(inside_s, log1p_exp_rules, width+1) 74 | 75 | series_batchsize = torch.arange(0, batch_size, dtype=torch.long) 76 | Z = inside_s[series_batchsize, 0, sequence_lengths-1] # (batch_size, ) 77 | 78 | marginal = torch.autograd.grad( 79 | Z.sum(), 80 | span_mention_score_matrix, 81 | create_graph=True, 82 | only_inputs=True, 83 | allow_unused=False, 84 | ) 85 | marginal = marginal[0].squeeze() 86 | return (Z.view(-1), marginal) # Shape: (batch_size, seq_len, seq_len, ) 87 | 88 | def io( 89 | self, 90 | span_mention_score_matrix: torch.FloatTensor, 91 | sequence_lengths: torch.IntTensor, 92 | ) -> Tuple[torch.FloatTensor]: 93 | """ 94 | Parameters: 95 | span_mention_score_matrix: shape (batch_size, sent_len, sent_len, score_dim) 96 | Score of each span being a span of interest. There are batch_size number 97 | of sentences in this document. And the maximum length of sentence is 98 | sent_len. 99 | sequence_lengths: shape (batch_size, ) 100 | The actual length of each sentence. 
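        Returns:
            (Z, marginal): as in `coolio` above; marginal has shape
            (batch_size, sent_len, sent_len).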
101 | """ 102 | # inside-outside 103 | span_mention_score_matrix.requires_grad_(True) 104 | 105 | batch_size, _, _, score_dim = span_mention_score_matrix.size() 106 | seq_len = sequence_lengths.max() 107 | # Shape: (batch_size, ) 108 | sequence_lengths = sequence_lengths.view(-1) 109 | 110 | # Shape: (seq_len, seq_len, score_dim, batch_size) 111 | span_mention_score_matrix = span_mention_score_matrix.permute(1, 2, 3, 0) 112 | 113 | # There should be another matrix of non-mention span scores, which is full of 0s 114 | # Shape: (seq_len, seq_len, score_dim + 1, batch_size), 2 for mention / non-mention 115 | inside_s = span_mention_score_matrix.new_zeros(seq_len, seq_len, score_dim + 1, batch_size) 116 | 117 | for width in range(0, seq_len): 118 | n = seq_len - width 119 | if width == 0: 120 | inside_s[:,:,:score_dim,:].diagonal(width).copy_( 121 | span_mention_score_matrix.diagonal(width) 122 | ) 123 | continue 124 | # [n, width, score_dim + 1, batch_size] 125 | split_1 = stripe(inside_s, n, width) 126 | split_2 = stripe(inside_s, n, width, (1, width), 0) 127 | 128 | # [n, width, batch_size] 129 | inside_s_span = logsumexp(split_1, 2) + logsumexp(split_2, 2) 130 | # [1, batch_size, n] 131 | inside_s_span = logsumexp(inside_s_span, 1, keepdim=True).permute(1, 2, 0) 132 | 133 | if width < self.max_span_width: 134 | inside_s.diagonal(width).copy_( 135 | torch.cat( 136 | [inside_s_span + span_mention_score_matrix.diagonal(width), # mention 137 | inside_s_span], # non-mention 138 | dim=0 139 | ) 140 | ) 141 | else: 142 | inside_s.diagonal(width).copy_( 143 | torch.cat( 144 | [torch.full_like(span_mention_score_matrix.diagonal(width), -LARGENUMBER), # mention 145 | inside_s_span], # non-mention 146 | dim=0 147 | ) 148 | ) 149 | 150 | inside_s = inside_s.permute(0,1,3,2) # (seq_len, seq_len, batch_size, 2), 2 for mention / non-mention 151 | series_batchsize = torch.arange(0, batch_size, dtype=torch.long) 152 | 153 | Z = logsumexp(inside_s[0, sequence_lengths-1, series_batchsize], dim=-1) # (batch_size,) 154 | 155 | marginal = torch.autograd.grad( 156 | Z.sum(), 157 | span_mention_score_matrix, 158 | create_graph=True, 159 | only_inputs=True, 160 | allow_unused=False, 161 | ) 162 | marginal = marginal[0].squeeze() 163 | return (Z.view(-1), marginal.permute(2,0,1)) # Shape: (batch_size, seq_len, seq_len, ) 164 | 165 | @staticmethod 166 | def viterbi( 167 | span_mention_score_matrix: torch.FloatTensor, 168 | sequence_lengths: torch.IntTensor, 169 | ) -> Tuple[torch.FloatTensor]: 170 | 171 | if len(span_mention_score_matrix.size()) == 4: 172 | span_mention_score_matrix, _ = span_mention_score_matrix.max(-1) 173 | # Shape: (seq_len, seq_len, batch_size) 174 | span_mention_score_matrix = span_mention_score_matrix.permute(1, 2, 0) 175 | 176 | 177 | # Shape: (batch_size, ) 178 | sequence_lengths = sequence_lengths.view(-1) 179 | # There should be another matrix of non-mention span scores, which is full of 0s 180 | seq_len, _, batch_size = span_mention_score_matrix.size() 181 | 182 | s = span_mention_score_matrix.new_zeros(seq_len, seq_len, 2, batch_size) 183 | p = sequence_lengths.new_zeros(seq_len, seq_len, 2, batch_size) # backtrack 184 | 185 | for width in range(0, seq_len): 186 | n = seq_len - width 187 | span_score = span_mention_score_matrix.diagonal(width) 188 | if width == 0: 189 | s.diagonal(0)[0, :].copy_(span_score) 190 | continue 191 | # [n, width, 2, 1, batch_size] 192 | split1 = stripe(s, n, width, ).unsqueeze(3) 193 | # [n, width, 1, 2, batch_size] 194 | split2 = stripe(s, n, width, (1, 
width), 0).unsqueeze(2) 195 | 196 | # [n, width, 2, 2, batch_size] 197 | s_span = split1 + split2 198 | # [batch_size, n, 2, width, 2, 2] 199 | s_span = s_span.permute(4, 0, 1, 2, 3).unsqueeze(2).repeat(1,1,2,1,1,1) 200 | 201 | s_span[:,:,0,:,:,:] += span_score.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) 202 | # [batch_size, n, 2] 203 | s_span, p_span = s_span.view(batch_size, n, 2, -1).max(-1) # best split 204 | s.diagonal(width).copy_( 205 | s_span.permute(2, 0, 1) 206 | ) 207 | starts = p.new_tensor(range(n)).unsqueeze(0).unsqueeze(0) 208 | p.diagonal(width).copy_( 209 | p_span.permute(2, 0, 1) + starts * 4 210 | ) 211 | 212 | def backtrack(mat, p, i, j, c): 213 | if j == i: 214 | return p.new_tensor([(i, j, c)], dtype=torch.long) 215 | int_pijc = int(p[i][j][c]) 216 | split = int_pijc // 4 217 | ltree = backtrack(mat, p, i, split, (int_pijc // 2) % 2) 218 | rtree = backtrack(mat, p, split+1, j, int_pijc % 2) 219 | return torch.cat([p.new_tensor([(i,j,c)], dtype=torch.long), ltree, rtree], dim=0) 220 | 221 | # (batch_size, seq_len, seq_len, 2) 222 | p = p.permute(3, 0, 1, 2) 223 | 224 | span_mention_score_matrix_cpu = span_mention_score_matrix.cpu() 225 | p_cpu = p.cpu() 226 | 227 | trees = [backtrack(s[:,:,:,i], p_cpu[i], 0, int(sequence_lengths[i]-1), 228 | (int) (s[0,int(sequence_lengths[i]-1),0,i] < s[0,int(sequence_lengths[i]-1),1,i])) 229 | for i in range(batch_size)] 230 | 231 | return trees 232 | 233 | def outside( 234 | self, 235 | inside_s: torch.FloatTensor, 236 | span_mention_score_matrix: torch.FloatTensor, 237 | sequence_lengths: torch.IntTensor, 238 | ): 239 | ''' 240 | inside_s: Shape: (seq_len, seq_len, 2, batch_size), 2 for mention / non-mention 241 | span_mention_score_matrix: Shape: (seq_len, seq_len, batch_size) 242 | 243 | Return: outside_s : Shape: (seq_len, seq_len, 2, batch_size), 2 for mention / non-mention 244 | ''' 245 | seq_len = sequence_lengths.max() 246 | 247 | _, _, batch_size = span_mention_score_matrix.size() 248 | series_batchsize = torch.arange(0, batch_size, dtype=torch.long) 249 | 250 | # Shape: (seq_len, seq_len, batch_size) 251 | # outside_s = span_mention_score_matrix.new_zeros(seq_len, seq_len, batch_size) 252 | outside_s = span_mention_score_matrix.new_full((seq_len, seq_len, batch_size), fill_value=-LARGENUMBER) 253 | 254 | mask_top = span_mention_score_matrix.new_zeros(seq_len, seq_len, batch_size).bool() 255 | mask_top[0, sequence_lengths-1, series_batchsize] = 1 256 | 257 | for width in range(seq_len-1, -1, -1): 258 | n = seq_len - width 259 | if width == seq_len-1: 260 | continue 261 | outside_s[mask_top] = 0 262 | # [n, n, 1, 2, batch_size] 263 | split_1 = inside_s[:n-1,:n-1].unsqueeze(2) # using the upper triangular [:n, :n] of the inside matrix 264 | # [n, n, 2, 1, batch_size] 265 | split_2 = outside_s[:n-1, width+1:seq_len].unsqueeze(2).unsqueeze(3).repeat(1,1,2,1,1) # using a submatrix of the outside matrix [:n, width:seq_len] 266 | # [n, n, 1, batch_size] 267 | span_score_submatrix = span_mention_score_matrix[:n-1, width+1:seq_len].unsqueeze(2) 268 | # [n, n, 1, 2, batch_size] 269 | split_3 = inside_s[width+1:seq_len, width+1:seq_len].unsqueeze(2) # using the upper triangular [width:seq_len, width:seq_len] of the inside matrix 270 | 271 | # [n, n, 1, 1, 1] 272 | upp_triu_mask = torch.triu(span_score_submatrix.new_ones(n-1,n-1), diagonal=0).view(n-1,n-1,1,1,1) 273 | 274 | # [n, n, 2, 2, batch_size] 275 | # B -> CA, B, C, A \in {0,1} 276 | outside_s_span_1 = (split_1 + split_2) 277 | outside_s_span_1[:,:,0,:,:] += span_score_submatrix 
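            # Index 0 along dim 2 enumerates the case where the parent span is a
            # mention (same labelling convention as the inside pass), so the parent's
            # span score is added only in that branch; index 1 (non-mention) adds no
            # rule score. The same applies to outside_s_span_2 below.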
278 | 279 | outside_s_span_1 += (upp_triu_mask*LARGENUMBER - LARGENUMBER) # upp_triu_mask.log() # 280 | # [batch_size, n, n, 2, 2] 281 | outside_s_span_1 = outside_s_span_1.permute(4, 1, 0, 2, 3) 282 | # [batch_size, n] 283 | outside_s_span_1 = logsumexp(outside_s_span_1.reshape(batch_size, n-1, -1), dim=-1) 284 | # outside_s_span_1.logsumexp((1,3,4)) # sum vertical, as right child 285 | 286 | # [n, n, 2, 2, batch_size] 287 | outside_s_span_2 = (split_3 + split_2) 288 | outside_s_span_2[:,:,0,:,:] += span_score_submatrix 289 | 290 | outside_s_span_2 += (upp_triu_mask*LARGENUMBER - LARGENUMBER) # upp_triu_mask.log() # 291 | # [batch_size, n, n, 2, 2] 292 | outside_s_span_2 = outside_s_span_2.permute(4, 0, 1, 2, 3) 293 | # [batch_size, n] 294 | outside_s_span_2 = logsumexp(outside_s_span_2.view(batch_size, n-1, -1), dim=-1) # sum horizontal, as left child 295 | 296 | # shift and sum 297 | outside_s_span_1 = torch.cat([outside_s_span_1.new_tensor([float(-LARGENUMBER)]*batch_size).unsqueeze(-1), outside_s_span_1], dim=-1) 298 | outside_s_span_2 = torch.cat([outside_s_span_2, outside_s_span_2.new_tensor([float(-LARGENUMBER)]*batch_size).unsqueeze(-1)], dim=-1) 299 | 300 | # [batch_size, n, 2] 301 | outside_s_span = torch.stack([outside_s_span_1, outside_s_span_2], dim=-1) 302 | 303 | # [batch_size, n] 304 | outside_s_span = logsumexp(outside_s_span, dim=-1) 305 | 306 | outside_s.diagonal(width).copy_(outside_s_span) 307 | 308 | return outside_s 309 | 310 | 311 | def get_sentence_matrix( 312 | sentence_num, 313 | max_sentence_length, 314 | unidimensional_values, 315 | span_location_indices, 316 | padding_value=0. 317 | ): 318 | total_units = sentence_num * max_sentence_length * max_sentence_length 319 | flat_matrix_by_sentence = unidimensional_values.new_full( 320 | (total_units, unidimensional_values.size(-1)), padding_value 321 | ).index_copy(0, span_location_indices, unidimensional_values.view(-1, unidimensional_values.size(-1))) 322 | 323 | return flat_matrix_by_sentence.view(sentence_num, max_sentence_length, max_sentence_length, unidimensional_values.size(-1)) 324 | 325 | 326 | class CFGMentionProposer(torch.nn.Module): 327 | def __init__( 328 | self, 329 | max_span_width=30, 330 | neg_sample_rate=0.2, 331 | **kwargs 332 | ) -> None: 333 | super().__init__(**kwargs) 334 | self.neg_sample_rate = float(neg_sample_rate) 335 | self.cky_module = CKY(max_span_width) 336 | 337 | def forward( 338 | self, # type: ignore 339 | spans: torch.IntTensor, 340 | span_mention_scores: torch.FloatTensor, 341 | span_mask: torch.FloatTensor, 342 | span_labels: torch.IntTensor, 343 | sentence_lengths: torch.IntTensor, 344 | num_spans_to_keep: int, 345 | flat_span_location_indices: torch.IntTensor, 346 | take_top_spans_per_sentence = False, 347 | flat_span_sent_ids = None, 348 | ratio = 0. 349 | ): 350 | # Shape: (batch_size, document_length, embedding_size) 351 | num_spans = spans.size(1) 352 | span_max_item = spans.max() 353 | 354 | sentence_offsets = torch.cumsum(sentence_lengths.squeeze(), 0) 355 | sentence_offsets = torch.cat( 356 | [sentence_offsets.new_zeros(1, 1), sentence_offsets.view(1, -1)], 357 | dim=-1 358 | ) 359 | span_mention_scores = span_mention_scores + (span_mask.unsqueeze(-1) * LARGENUMBER - LARGENUMBER) 360 | max_sentence_length = sentence_lengths.max() 361 | sentence_num = sentence_lengths.size(0) 362 | 363 | # We directly calculate indices of span scores in the matrices during data preparation. 
364 | # The indices is 1-d (except the batch dimension) to facilitate index_copy_ 365 | # We copy the span scores into (batch_size, sentence_num, max_sentence_length, max_sentence_length) 366 | # shaped score matrices 367 | 368 | # We will do sentence-level CKY over span scores 369 | # span_mention_scores shape: (batch_size, num_spans, 2) 370 | # the first column scores are for parsing, the second column for linking 371 | 372 | span_score_matrix_by_sentence = get_sentence_matrix( 373 | sentence_num, max_sentence_length, span_mention_scores, 374 | flat_span_location_indices, padding_value=-LARGENUMBER 375 | ) 376 | valid_span_flag_matrix_by_sentence = get_sentence_matrix( 377 | sentence_num, max_sentence_length, torch.ones_like(span_mask).unsqueeze(-1), 378 | flat_span_location_indices, padding_value=0 379 | ).squeeze(-1) 380 | 381 | 382 | Z, marginal = self.cky_module( 383 | span_score_matrix_by_sentence, sentence_lengths 384 | ) 385 | span_marginal = torch.masked_select(marginal, valid_span_flag_matrix_by_sentence) 386 | 387 | if not take_top_spans_per_sentence: 388 | top_span_indices = masked_topk_non_overlap( 389 | span_marginal, 390 | span_mask, 391 | num_spans_to_keep, 392 | spans 393 | ) 394 | span_marginal = clip_to_01(span_marginal) 395 | 396 | top_ind_list = top_span_indices.tolist() 397 | all_ind_list = list(range(0, span_marginal.size(0))) 398 | neg_sample_indices = np.random.choice( 399 | list(set(all_ind_list) - set(top_ind_list)), 400 | int(self.neg_sample_rate * num_spans_to_keep), 401 | False 402 | ) 403 | neg_sample_indices = top_span_indices.new_tensor(sorted(neg_sample_indices)) 404 | else: 405 | top_span_indices, sentwise_top_span_marginal, top_spans = [], [], [] 406 | prev_sent_id, prev_span_id = 0, 0 407 | for span_id, sent_id in enumerate(flat_span_sent_ids.tolist()): 408 | if sent_id != prev_sent_id: 409 | sent_span_indices = masked_topk_non_overlap( 410 | span_marginal[prev_span_id:span_id], 411 | span_mask[prev_span_id:span_id], 412 | int(ratio * sentence_lengths[prev_sent_id]), 413 | spans[prev_span_id:span_id], 414 | non_crossing=False, 415 | ) + prev_span_id 416 | 417 | top_span_indices.append(sent_span_indices) 418 | sentwise_top_span_marginal.append(span_marginal[sent_span_indices]) 419 | top_spans.append(spans[sent_span_indices]) 420 | 421 | prev_sent_id, prev_span_id = sent_id, span_id 422 | # last sentence 423 | sent_span_indices = masked_topk_non_overlap( 424 | span_marginal[prev_span_id:], 425 | span_mask[prev_span_id:], 426 | int(ratio * sentence_lengths[-1]), 427 | spans[prev_span_id:], 428 | non_crossing=False, 429 | ) + prev_span_id 430 | 431 | top_span_indices.append(sent_span_indices) 432 | sentwise_top_span_marginal.append(span_marginal[sent_span_indices]) 433 | top_spans.append(spans[sent_span_indices]) 434 | 435 | num_top_spans = [x.size(0) for x in top_span_indices] 436 | max_num_top_span = max(num_top_spans) 437 | 438 | sentwise_top_span_marginal = torch.stack( 439 | [torch.cat([x, x.new_zeros((max_num_top_span-x.size(0), ))], dim=0) for x in sentwise_top_span_marginal], dim=0 440 | ) 441 | top_spans = torch.stack( 442 | [torch.cat([x, x.new_zeros((max_num_top_span-x.size(0), 2))], dim=0) for x in top_spans], dim=0 443 | ) 444 | top_span_masks = torch.stack( 445 | [torch.cat([spans.new_ones((x, )), spans.new_zeros((max_num_top_span-x, ))], dim=0) for x in num_top_spans], dim=0 446 | ) 447 | 448 | flat_top_span_indices = torch.cat(top_span_indices, dim=0) 449 | top_span_indices = torch.stack( 450 | [torch.cat([x, 
x.new_zeros((max_num_top_span-x.size(0), ))], dim=0) for x in top_span_indices], dim=0 451 | ) 452 | 453 | top_ind_list = flat_top_span_indices.tolist() 454 | all_ind_list = list(range(0, span_marginal.size(0))) 455 | neg_sample_indices = np.random.choice( 456 | list(set(all_ind_list) - set(top_ind_list)), 457 | int(self.neg_sample_rate * num_spans_to_keep), False 458 | ) 459 | neg_sample_indices = top_span_indices.new_tensor(sorted(neg_sample_indices)) 460 | pass # End else 461 | 462 | 463 | if not self.training: 464 | with torch.no_grad(): 465 | best_trees = CKY.viterbi(span_score_matrix_by_sentence.detach(), sentence_lengths) 466 | 467 | best_tree_spans = [(x[:,:2]+offset).cuda() for x, offset in zip(best_trees, sentence_offsets.view(-1).cpu())] 468 | 469 | best_tree_tags = torch.cat([x[:,-1] for x in best_trees], dim=-1).cuda() 470 | best_tree_spans = torch.cat(best_tree_spans, dim=0).cuda() 471 | best_tree_span_mask = (best_tree_tags == 0).unsqueeze(-1) 472 | if best_tree_span_mask.sum() > 0: 473 | # top spans per sentence 474 | helper_matrix = span_mask.new_zeros(span_max_item+1, span_max_item+1) 475 | top_spans = torch.masked_select(best_tree_spans, best_tree_span_mask).view(-1, 2) 476 | helper_matrix[top_spans[:,0],top_spans[:,1]] |= torch.tensor(True) 477 | top_span_mask = helper_matrix[spans[:,0],spans[:,1]] 478 | top_span_indices = torch.nonzero(top_span_mask, as_tuple=True)[0] 479 | 480 | if take_top_spans_per_sentence: 481 | sentwise_top_span_indices, sentwise_top_span_marginal, sentwise_top_spans = [], [], [] 482 | prev_sent_id, prev_span_id = 0, 0 483 | 484 | for span_id, sent_id in enumerate(flat_span_sent_ids.tolist()): 485 | if sent_id != prev_sent_id: 486 | current_sentence_indices = torch.nonzero(top_span_mask[prev_span_id:span_id], as_tuple=True)[0] # unshifted 487 | sentwise_top_span_indices.append(current_sentence_indices + prev_span_id) 488 | sentwise_top_span_marginal.append(span_marginal[prev_span_id:span_id][current_sentence_indices]) 489 | sentwise_top_spans.append(spans[prev_span_id:span_id][current_sentence_indices]) 490 | 491 | prev_sent_id, prev_span_id = sent_id, span_id 492 | 493 | current_sentence_indices = torch.nonzero(top_span_mask[prev_span_id:], as_tuple=True)[0] # unshifted 494 | sentwise_top_span_indices.append(current_sentence_indices + prev_span_id) 495 | sentwise_top_span_marginal.append(span_marginal[prev_span_id:][current_sentence_indices]) 496 | sentwise_top_spans.append(spans[prev_span_id:][current_sentence_indices]) 497 | 498 | num_top_spans = [x.size(0) for x in sentwise_top_span_indices] 499 | max_num_top_span = max(num_top_spans) 500 | 501 | top_spans = torch.stack( 502 | [torch.cat([x, x.new_zeros((max_num_top_span-x.size(0), 2))], dim=0) for x in sentwise_top_spans], dim=0 503 | ) 504 | top_span_masks = torch.stack( 505 | [torch.cat([spans.new_ones((x, )), spans.new_zeros((max_num_top_span-x, ))], dim=0) for x in num_top_spans], dim=0 506 | ) 507 | top_span_indices = torch.stack( 508 | [torch.cat([x, x.new_zeros((max_num_top_span-x.size(0), ))], dim=0) for x in sentwise_top_span_indices], dim=0 509 | ) 510 | sentwise_top_span_marginal = torch.stack( 511 | [torch.cat([x, x.new_zeros((max_num_top_span-x.size(0), ))], dim=0) for x in sentwise_top_span_marginal], dim=0 512 | ) 513 | else: 514 | logger.info("expected %d but %d in CKY parse, not using CKY parse" % (num_spans_to_keep, int(best_tree_span_mask.sum()))) 515 | pass 516 | 517 | 518 | if self.training and neg_sample_indices.size(0) > 0: 519 | if take_top_spans_per_sentence: 520 | 
not_mention_loss = -(1 - span_marginal[neg_sample_indices]).log() 521 | loss = not_mention_loss.mean() * (self.neg_sample_rate * num_spans_to_keep) 522 | else: 523 | non_mention_flag = (span_labels <= 0) 524 | # -log(1 - P(m)) 525 | not_mention_loss = -(1 - span_marginal[neg_sample_indices]).log() * non_mention_flag[neg_sample_indices] 526 | loss = not_mention_loss.mean() * (self.neg_sample_rate * num_spans_to_keep) 527 | else: 528 | loss = 0. 529 | 530 | if not take_top_spans_per_sentence: 531 | top_spans = spans[top_span_indices] 532 | return span_marginal, top_span_indices, top_spans, loss, None 533 | else: 534 | 535 | return sentwise_top_span_marginal, top_span_indices, top_spans, loss, None, top_span_masks 536 | 537 | -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import random 3 | import numpy as np 4 | import torch 5 | from torch.utils.tensorboard import SummaryWriter 6 | from transformers import AdamW 7 | from torch.optim import Adam 8 | from tensorize import CorefDataProcessor 9 | import util 10 | import time 11 | from os.path import join 12 | from metrics import CorefEvaluator 13 | from datetime import datetime 14 | from torch.optim.lr_scheduler import LambdaLR 15 | 16 | import json 17 | 18 | import model as Model 19 | 20 | import conll 21 | import sys 22 | 23 | torch.autograd.set_detect_anomaly(False) 24 | USE_AMP = True 25 | 26 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 27 | datefmt='%m/%d/%Y %H:%M:%S', 28 | level=logging.INFO) 29 | logger = logging.getLogger() 30 | 31 | class Runner: 32 | def __init__(self, config_name, gpu_id=0, seed=None): 33 | self.name = config_name 34 | self.name_suffix = datetime.now().strftime('%b%d_%H-%M-%S') 35 | self.gpu_id = gpu_id 36 | self.seed = seed 37 | 38 | # Set up config 39 | self.config = util.initialize_config(config_name) 40 | 41 | # Set up logger 42 | log_path = join(self.config['log_dir'], 'log_' + self.name_suffix + '.txt') 43 | logger.addHandler(logging.FileHandler(log_path, 'a')) 44 | logger.info('Log file path: %s' % log_path) 45 | 46 | # Set up seed 47 | if seed: 48 | util.set_seed(seed) 49 | 50 | # Set up device 51 | self.device = torch.device('cpu' if gpu_id is None else f'cuda:{gpu_id}') 52 | self.scaler = torch.cuda.amp.GradScaler(init_scale=256.0) 53 | 54 | # Set up data 55 | self.data = CorefDataProcessor(self.config) 56 | 57 | def initialize_model(self, saved_suffix=None): 58 | model = Model.CorefModel(self.config, self.device) 59 | if saved_suffix: 60 | self.load_model_checkpoint(model, saved_suffix) 61 | return model 62 | 63 | def train(self, model): 64 | conf = self.config 65 | logger.info(conf) 66 | epochs, grad_accum = conf['num_epochs'], conf['gradient_accumulation_steps'] 67 | 68 | model.to(self.device) 69 | logger.info('Model parameters:') 70 | for name, param in model.named_parameters(): 71 | logger.info('%s: %s' % (name, tuple(param.shape))) 72 | 73 | # Set up tensorboard 74 | tb_path = join(conf['tb_dir'], self.name + '_' + self.name_suffix) 75 | tb_writer = SummaryWriter(tb_path, flush_secs=30) 76 | logger.info('Tensorboard summary path: %s' % tb_path) 77 | 78 | # Set up data 79 | examples_train, examples_dev, examples_test = self.data.get_tensor_examples() 80 | stored_info = self.data.get_stored_info() 81 | 82 | # Set up optimizer and scheduler 83 | total_update_steps = len(examples_train) * epochs // grad_accum 84 | 
optimizers = self.get_optimizer(model) 85 | schedulers = self.get_scheduler(optimizers, total_update_steps) 86 | 87 | # Get model parameters for grad clipping 88 | bert_param, task_param = model.get_params() 89 | 90 | # Start training 91 | logger.info('*******************Training*******************') 92 | logger.info('Num samples: %d' % len(examples_train)) 93 | logger.info('Num epochs: %d' % epochs) 94 | logger.info('Gradient accumulation steps: %d' % grad_accum) 95 | logger.info('Total update steps: %d' % total_update_steps) 96 | 97 | loss_during_accum = [] # To compute effective loss at each update 98 | loss_during_report = 0.0 # Effective loss during logging step 99 | loss_history = [] # Full history of effective loss; length equals total update steps 100 | max_f1, max_f1_test = 0, 0 101 | start_time = time.time() 102 | model.zero_grad() 103 | for epo in range(epochs): 104 | print("EPOCH", epo) 105 | random.shuffle(examples_train) # Shuffle training set 106 | for doc_key, example in examples_train: 107 | # Forward pass 108 | model.train() 109 | example_gpu = [d.to(self.device) if d is not None else None for d in example] 110 | 111 | with torch.cuda.amp.autocast(enabled=USE_AMP): 112 | _, loss = model(*example_gpu) 113 | 114 | # Backward; accumulate gradients and clip by grad norm 115 | if grad_accum > 1: 116 | loss /= grad_accum 117 | 118 | if USE_AMP: 119 | scaled_loss = self.scaler.scale(loss) 120 | scaled_loss.backward() 121 | else: 122 | loss.backward() 123 | 124 | loss_during_accum.append(loss.item()) 125 | 126 | # Update 127 | if len(loss_during_accum) % grad_accum == 0: 128 | if USE_AMP: 129 | self.scaler.unscale_(optimizers[0]) 130 | self.scaler.unscale_(optimizers[1]) 131 | 132 | if conf['max_grad_norm']: 133 | norm_bert = torch.nn.utils.clip_grad_norm_(bert_param, conf['max_grad_norm'], error_if_nonfinite=False) 134 | norm_task = torch.nn.utils.clip_grad_norm_(task_param, conf['max_grad_norm'], error_if_nonfinite=False) 135 | 136 | for optimizer in optimizers: 137 | if USE_AMP: 138 | self.scaler.step(optimizer) 139 | else: 140 | optimizer.step() 141 | 142 | if USE_AMP: 143 | self.scaler.update() 144 | 145 | model.zero_grad() 146 | for scheduler in schedulers: 147 | scheduler.step() 148 | 149 | # Compute effective loss 150 | effective_loss = np.sum(loss_during_accum).item() 151 | loss_during_accum = [] 152 | loss_during_report += effective_loss 153 | loss_history.append(effective_loss) 154 | 155 | # Report 156 | if len(loss_history) % conf['report_frequency'] == 0: 157 | # Show avg loss during last report interval 158 | avg_loss = loss_during_report / conf['report_frequency'] 159 | loss_during_report = 0.0 160 | end_time = time.time() 161 | logger.info('Step %d: avg loss %.2f; steps/sec %.2f' % 162 | (len(loss_history), avg_loss, conf['report_frequency'] / (end_time - start_time))) 163 | start_time = end_time 164 | 165 | tb_writer.add_scalar('Training_Loss', avg_loss, len(loss_history)) 166 | tb_writer.add_scalar('Learning_Rate_Bert', schedulers[0].get_last_lr()[0], len(loss_history)) 167 | tb_writer.add_scalar('Learning_Rate_Task', schedulers[1].get_last_lr()[-1], len(loss_history)) 168 | 169 | # Evaluate 170 | if len(loss_history) > 0 and len(loss_history) % conf['eval_frequency'] == 0: 171 | # Testing Dev 172 | logger.info('Dev') 173 | f1, _ = self.evaluate(model, examples_dev, stored_info, len(loss_history), official=False, conll_path=self.config['conll_eval_path'], tb_writer=tb_writer) 174 | # Testing Test 175 | logger.info('Test') 176 | f1_test = 0. 
177 | if f1 > max_f1 or f1_test > max_f1_test: 178 | max_f1 = max(max_f1, f1) 179 | max_f1_test = 0. # max(max_f1_test, f1_test) 180 | self.save_model_checkpoint(model, len(loss_history)) 181 | 182 | logger.info('Eval max f1: %.2f' % max_f1) 183 | logger.info('Test max f1: %.2f' % max_f1_test) 184 | start_time = time.time() 185 | 186 | logger.info('**********Finished training**********') 187 | logger.info('Actual update steps: %d' % len(loss_history)) 188 | 189 | # Wrap up 190 | tb_writer.close() 191 | return loss_history 192 | 193 | def evaluate(self, model, tensor_examples, stored_info, step, official=False, conll_path=None, tb_writer=None, predict=False): 194 | logger.info('Step %d: evaluating on %d samples...' % (step, len(tensor_examples))) 195 | 196 | model.to(self.device) 197 | evaluator = CorefEvaluator() 198 | doc_to_prediction = {} 199 | 200 | model.eval() 201 | 202 | for i, (doc_key, tensor_example) in enumerate(tensor_examples): 203 | current_json = {} 204 | 205 | gold_clusters = stored_info['gold'][doc_key] 206 | tensor_example = tensor_example[:7] # Strip out gold 207 | example_gpu = [d.to(self.device) for d in tensor_example] 208 | 209 | with torch.no_grad(): 210 | returned_tuple = model(*example_gpu) 211 | 212 | if len(returned_tuple) == 10: 213 | _, _, _, span_starts, span_ends, antecedent_idx, antecedent_scores, score_j_i, input_ids, head_cond_score = returned_tuple 214 | 215 | current_json["score_j_i"] = score_j_i.tolist() 216 | current_json["input_ids"] = input_ids.tolist() 217 | current_json["head_cond_score"] = head_cond_score.tolist() 218 | 219 | elif len(returned_tuple) == 7: 220 | _, _, _, span_starts, span_ends, antecedent_idx, antecedent_scores = returned_tuple 221 | 222 | span_starts, span_ends = span_starts.tolist(), span_ends.tolist() 223 | antecedent_idx, antecedent_scores = antecedent_idx, antecedent_scores 224 | 225 | predicted_clusters = model.update_evaluator(span_starts, span_ends, antecedent_idx, antecedent_scores, gold_clusters, evaluator) 226 | doc_to_prediction[doc_key] = predicted_clusters 227 | 228 | 229 | p, r, f, m_recall, (blanc_p, blanc_r, blanc_f) = evaluator.get_prf() 230 | all_metrics = evaluator.get_all() 231 | 232 | metrics = { 233 | 'Eval_Avg_Precision': p * 100, 'Eval_Avg_Recall': r * 100, 'Eval_Avg_F1': f * 100, 234 | "Eval_Men_Recall": m_recall*100, 235 | "Eval_Blanc_Precision": blanc_p * 100, "Eval_Blanc_Recall": blanc_r * 100, "Eval_Blanc_F1": blanc_f * 100 236 | } 237 | 238 | for k,v in all_metrics.items(): 239 | logger.info('%s: %.4f'%(k, v)) 240 | 241 | for name, score in metrics.items(): 242 | logger.info('%s: %.2f' % (name, score)) 243 | if tb_writer: 244 | tb_writer.add_scalar(name, score, step) 245 | 246 | if official: 247 | conll_results = conll.evaluate_conll(conll_path, doc_to_prediction, stored_info['subtoken_maps']) 248 | official_f1 = sum(results["f"] for results in conll_results.values()) / len(conll_results) 249 | logger.info('Official avg F1: %.4f' % official_f1) 250 | 251 | return f * 100, metrics 252 | 253 | def predict(self, model, tensor_examples): 254 | logger.info('Predicting %d samples...' 
% len(tensor_examples)) 255 | model.to(self.device) 256 | model.eval() 257 | predicted_spans, predicted_antecedents, predicted_clusters = [], [], [] 258 | 259 | for i, (doc_key, tensor_example) in enumerate(tensor_examples): 260 | tensor_example = tensor_example[:7] 261 | example_gpu = [d.to(self.device) for d in tensor_example] 262 | with torch.no_grad(): 263 | _, _, _, span_starts, span_ends, antecedent_idx, antecedent_scores = model(*example_gpu) 264 | span_starts, span_ends = span_starts.tolist(), span_ends.tolist() 265 | antecedent_idx, antecedent_scores = antecedent_idx.tolist(), antecedent_scores.tolist() 266 | clusters, mention_to_cluster_id, antecedents = model.get_predicted_clusters(span_starts, span_ends, antecedent_idx, antecedent_scores) 267 | 268 | spans = [(span_start, span_end) for span_start, span_end in zip(span_starts, span_ends)] 269 | predicted_spans.append(spans) 270 | predicted_antecedents.append(antecedents) 271 | predicted_clusters.append(clusters) 272 | 273 | return predicted_clusters, predicted_spans, predicted_antecedents 274 | 275 | def get_optimizer(self, model): 276 | no_decay = ['bias', 'LayerNorm.weight'] 277 | bert_param, task_param = model.get_params(named=True) 278 | grouped_bert_param = [ 279 | { 280 | 'params': [p for n, p in bert_param if not any(nd in n for nd in no_decay)], 281 | 'lr': self.config['bert_learning_rate'], 282 | 'weight_decay': self.config['adam_weight_decay'] 283 | }, { 284 | 'params': [p for n, p in bert_param if any(nd in n for nd in no_decay)], 285 | 'lr': self.config['bert_learning_rate'], 286 | 'weight_decay': 0.0 287 | } 288 | ] 289 | optimizers = [ 290 | AdamW(grouped_bert_param, lr=self.config['bert_learning_rate'], eps=self.config['adam_eps']), 291 | Adam(model.get_params()[1], lr=self.config['task_learning_rate'], eps=self.config['adam_eps'], weight_decay=0) 292 | ] 293 | return optimizers 294 | 295 | def get_scheduler(self, optimizers, total_update_steps): 296 | # Only warm up bert lr 297 | warmup_steps = int(total_update_steps * self.config['warmup_ratio']) 298 | 299 | def lr_lambda_bert(current_step): 300 | if current_step < warmup_steps: 301 | return float(current_step) / float(max(1, warmup_steps)) 302 | return max( 303 | 0.0, float(total_update_steps - current_step) / float(max(1, total_update_steps - warmup_steps)) 304 | ) 305 | 306 | def lr_lambda_task(current_step): 307 | return max(0.0, float(total_update_steps - current_step) / float(max(1, total_update_steps))) 308 | 309 | schedulers = [ 310 | LambdaLR(optimizers[0], lr_lambda_bert), 311 | LambdaLR(optimizers[1], lr_lambda_task) 312 | ] 313 | return schedulers 314 | 315 | def save_model_checkpoint(self, model, step): 316 | path_ckpt = join(self.config['log_dir'], f'model_{self.name_suffix}_{step}.bin') 317 | torch.save(model.state_dict(), path_ckpt) 318 | logger.info('Saved model to %s' % path_ckpt) 319 | 320 | def load_model_checkpoint(self, model, suffix): 321 | path_ckpt = join(self.config['log_dir'], f'model_{suffix}.bin') 322 | model.load_state_dict(torch.load(path_ckpt, map_location=torch.device('cpu')), strict=False) 323 | logger.info('Loaded model from %s' % path_ckpt) 324 | 325 | 326 | if __name__ == '__main__': 327 | config_name, gpu_id = sys.argv[1], int(sys.argv[2]) 328 | saved_suffix = sys.argv[3] if len(sys.argv) >= 4 else None 329 | runner = Runner(config_name, gpu_id) 330 | model = runner.initialize_model(saved_suffix) 331 | 332 | runner.train(model) 333 | -------------------------------------------------------------------------------- /sss.yml: 
-------------------------------------------------------------------------------- 1 | name: sss 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - defaults 6 | dependencies: 7 | - huggingface_hub 8 | - pip 9 | - scipy 10 | - pyhocon 11 | - python=3.8 12 | - pytorch=1.11.0 13 | - tensorboard 14 | - tqdm 15 | - transformers=4.19.3 16 | - numpy 17 | - scikit-learn -------------------------------------------------------------------------------- /tensorize.py: -------------------------------------------------------------------------------- 1 | import util 2 | import numpy as np 3 | import random 4 | from transformers import AutoTokenizer 5 | import os 6 | from os.path import join 7 | import json 8 | import pickle 9 | import logging 10 | import torch 11 | 12 | logger = logging.getLogger(__name__) 13 | 14 | class CorefDataProcessor: 15 | def __init__(self, config, language='english'): 16 | self.config = config 17 | self.language = language 18 | 19 | self.max_seg_len = config['max_segment_len'] 20 | self.max_training_seg = config['max_training_sentences'] 21 | self.data_dir = config['data_dir'] 22 | 23 | # Get tensorized samples 24 | cache_path = self.get_cache_path() 25 | if os.path.exists(cache_path): 26 | # Load cached tensors if exists 27 | with open(cache_path, 'rb') as f: 28 | self.tensor_samples, self.stored_info = pickle.load(f) 29 | logger.info('Loaded tensorized examples from cache') 30 | else: 31 | # Generate tensorized samples 32 | if self.config["dataset"] == "ontonotes": 33 | self.tensor_samples = {} 34 | tensorizer = Tensorizer(self.config) 35 | paths = { 36 | 'trn': join(self.data_dir, f'train.{language}.{self.max_seg_len}.jsonlines'), 37 | 'dev': join(self.data_dir, f'dev.{language}.{self.max_seg_len}.jsonlines'), 38 | 'tst': join(self.data_dir, f'test.{language}.{self.max_seg_len}.jsonlines') 39 | } 40 | for split, path in paths.items(): 41 | logger.info('Tensorizing examples from %s; results will be cached)' % path) 42 | is_training = (split == 'trn') 43 | with open(path, 'r') as f: 44 | samples = [json.loads(line) for line in f.readlines()] 45 | tensor_samples = [tensorizer.tensorize_example(sample, is_training) for sample in samples] 46 | print(len(tensor_samples[0])) 47 | self.tensor_samples[split] = [(doc_key, self.convert_to_torch_tensor(*tensor)) for doc_key, tensor in tensor_samples] 48 | self.stored_info = tensorizer.stored_info 49 | # Cache tensorized samples 50 | with open(cache_path, 'wb') as f: 51 | pickle.dump((self.tensor_samples, self.stored_info), f) 52 | 53 | 54 | @classmethod 55 | def convert_to_torch_tensor(cls, input_ids, input_mask, speaker_ids, sentence_len, genre, sentence_map, 56 | is_training, gold_starts, gold_ends, gold_mention_cluster_map, 57 | coreferable_starts, coreferable_ends, 58 | constituent_starts, constituent_ends, constituent_type): 59 | 60 | input_ids = torch.tensor(input_ids, dtype=torch.long) 61 | input_mask = torch.tensor(input_mask, dtype=torch.long) 62 | speaker_ids = torch.tensor(speaker_ids, dtype=torch.long) 63 | sentence_len = torch.tensor(sentence_len, dtype=torch.long) 64 | genre = torch.tensor(genre, dtype=torch.long) 65 | sentence_map = torch.tensor(sentence_map, dtype=torch.long) 66 | is_training = torch.tensor(is_training, dtype=torch.bool) 67 | gold_starts = torch.tensor(gold_starts, dtype=torch.long) 68 | gold_ends = torch.tensor(gold_ends, dtype=torch.long) 69 | gold_mention_cluster_map = torch.tensor(gold_mention_cluster_map, dtype=torch.long) 70 | coreferable_starts = torch.tensor(coreferable_starts, dtype=torch.long) if 
coreferable_starts is not None else None 71 | coreferable_ends = torch.tensor(coreferable_ends, dtype=torch.long) if coreferable_ends is not None else None 72 | 73 | constituent_starts = torch.tensor(constituent_starts, dtype=torch.long) if constituent_starts is not None else None 74 | constituent_ends = torch.tensor(constituent_ends, dtype=torch.long) if constituent_ends is not None else None 75 | constituent_type = None 76 | 77 | return input_ids, input_mask, speaker_ids, sentence_len, genre, sentence_map, \ 78 | is_training, gold_starts, gold_ends, gold_mention_cluster_map, \ 79 | coreferable_starts, coreferable_ends, \ 80 | constituent_starts, constituent_ends, constituent_type 81 | 82 | def get_tensor_examples(self): 83 | # For each split, return list of tensorized samples to allow variable length input (batch size = 1) 84 | return self.tensor_samples['trn'], self.tensor_samples['dev'], self.tensor_samples['tst'] 85 | 86 | def get_stored_info(self): 87 | return self.stored_info 88 | 89 | def get_cache_path(self): 90 | if self.config["dataset"] == "ontonotes": 91 | cache_path = join(self.data_dir, f'cached.tensors.{self.language}.{self.max_seg_len}.{self.max_training_seg}.bin') 92 | 93 | return cache_path 94 | 95 | 96 | class Tensorizer: 97 | def __init__(self, config): 98 | self.config = config 99 | self.tokenizer = AutoTokenizer.from_pretrained(config['bert_tokenizer_name']) 100 | 101 | # Will be used in evaluation 102 | self.stored_info = {} 103 | self.stored_info['tokens'] = {} # {doc_key: ...} 104 | self.stored_info['subtoken_maps'] = {} # {doc_key: ...}; mapping back to tokens 105 | self.stored_info['gold'] = {} # {doc_key: ...} 106 | self.stored_info['genre_dict'] = {genre: idx for idx, genre in enumerate(config['genres'])} 107 | self.stored_info['constituents'] = {} 108 | 109 | def _tensorize_spans(self, spans): 110 | if len(spans) > 0: 111 | starts, ends = zip(*spans) 112 | else: 113 | starts, ends = [], [] 114 | return np.array(starts), np.array(ends) 115 | 116 | def _tensorize_span_w_labels(self, spans, label_dict): 117 | if len(spans) > 0: 118 | starts, ends, labels = zip(*spans) 119 | else: 120 | starts, ends, labels = [], [], [] 121 | return np.array(starts), np.array(ends), np.array([label_dict[label] for label in labels]) 122 | 123 | def _get_speaker_dict(self, speakers): 124 | speaker_dict = {'UNK': 0, '[SPL]': 1} 125 | for speaker in speakers: 126 | if len(speaker_dict) > self.config['max_num_speakers']: 127 | pass # 'break' to limit # speakers 128 | if speaker not in speaker_dict: 129 | speaker_dict[speaker] = len(speaker_dict) 130 | return speaker_dict 131 | 132 | def tensorize_example(self, example, is_training): 133 | # Mentions and clusters 134 | clusters = example['clusters'] 135 | gold_mentions = sorted(tuple(mention) for mention in util.flatten(clusters)) 136 | 137 | gold_coreferables = sorted(tuple(mention) for mention in example["coreferables"]) if "coreferables" in example else None 138 | gold_constituents = list(tuple(mention) for mention in example["constituents"]) if "constituents" in example else None 139 | gold_constituent_type = list(example["constituent_type"]) if "constituent_type" in example else None 140 | 141 | gold_mention_map = {mention: idx for idx, mention in enumerate(gold_mentions)} 142 | gold_mention_cluster_map = np.zeros(len(gold_mentions)) # 0: no cluster 143 | for cluster_id, cluster in enumerate(clusters): 144 | for mention in cluster: 145 | gold_mention_cluster_map[gold_mention_map[tuple(mention)]] = cluster_id + 1 146 | 147 | # 
Speakers 148 | speakers = example['speakers'] 149 | speaker_dict = self._get_speaker_dict(util.flatten(speakers)) 150 | 151 | # Sentences/segments 152 | sentences = example['sentences'] # Segments 153 | sentence_map = example['sentence_map'] 154 | num_words = sum([len(s) for s in sentences]) 155 | max_sentence_len = self.config['max_segment_len'] 156 | sentence_len = np.array([len(s) for s in sentences]) 157 | 158 | # Bert input 159 | input_ids, input_mask, speaker_ids = [], [], [] 160 | for idx, (sent_tokens, sent_speakers) in enumerate(zip(sentences, speakers)): 161 | sent_input_ids = self.tokenizer.convert_tokens_to_ids(sent_tokens) 162 | sent_input_mask = [1] * len(sent_input_ids) 163 | sent_speaker_ids = [speaker_dict[speaker] for speaker in sent_speakers] 164 | while len(sent_input_ids) < max_sentence_len: 165 | sent_input_ids.append(0) 166 | sent_input_mask.append(0) 167 | sent_speaker_ids.append(0) 168 | input_ids.append(sent_input_ids) 169 | input_mask.append(sent_input_mask) 170 | speaker_ids.append(sent_speaker_ids) 171 | input_ids = np.array(input_ids) 172 | input_mask = np.array(input_mask) 173 | speaker_ids = np.array(speaker_ids) 174 | assert num_words == np.sum(input_mask), (num_words, np.sum(input_mask)) 175 | 176 | # Keep info to store 177 | doc_key = example['doc_key'] 178 | self.stored_info['subtoken_maps'][doc_key] = example.get('subtoken_map', None) 179 | self.stored_info['gold'][doc_key] = example['clusters'] 180 | # self.stored_info['constituents'][doc_key] = example['constituents'] 181 | # self.stored_info['tokens'][doc_key] = example['tokens'] 182 | 183 | # Construct example 184 | genre = self.stored_info['genre_dict'].get(doc_key[:2], 0) 185 | gold_starts, gold_ends = self._tensorize_spans(gold_mentions) 186 | coreferable_starts, coreferable_ends = self._tensorize_spans(gold_coreferables) if gold_coreferables is not None else (None, None) 187 | constituent_starts, constituent_ends = self._tensorize_spans(gold_constituents) if gold_constituents is not None else (None, None) 188 | constituent_type = np.array(gold_constituent_type) if gold_constituent_type is not None else None 189 | 190 | example_tensor = (input_ids, input_mask, speaker_ids, sentence_len, genre, sentence_map, is_training, 191 | gold_starts, gold_ends, gold_mention_cluster_map, coreferable_starts, coreferable_ends, 192 | constituent_starts, constituent_ends, constituent_type) 193 | if is_training and len(sentences) > self.config['max_training_sentences']: 194 | return doc_key, self.truncate_example(*example_tensor, max_sentences=self.config['max_training_sentences']) 195 | else: 196 | return doc_key, example_tensor 197 | 198 | def truncate_example(self, input_ids, input_mask, speaker_ids, sentence_len, genre, sentence_map, is_training, 199 | gold_starts, gold_ends, gold_mention_cluster_map, coreferable_starts, coreferable_ends, 200 | constituent_starts, constituent_ends, constituent_type, 201 | max_sentences, sentence_offset=None): 202 | num_sentences = input_ids.shape[0] 203 | assert num_sentences > max_sentences 204 | 205 | sent_offset = sentence_offset 206 | if sent_offset is None: 207 | sent_offset = random.randint(0, num_sentences - max_sentences) 208 | word_offset = sentence_len[:sent_offset].sum() 209 | num_words = sentence_len[sent_offset: sent_offset + max_sentences].sum() 210 | 211 | input_ids = input_ids[sent_offset: sent_offset + max_sentences, :] 212 | input_mask = input_mask[sent_offset: sent_offset + max_sentences, :] 213 | speaker_ids = speaker_ids[sent_offset: sent_offset + 
max_sentences, :] 214 | sentence_len = sentence_len[sent_offset: sent_offset + max_sentences] 215 | 216 | sentence_map = sentence_map[word_offset: word_offset + num_words] 217 | 218 | gold_spans = (gold_starts < word_offset + num_words) & (gold_ends >= word_offset) 219 | gold_starts = gold_starts[gold_spans] - word_offset 220 | gold_ends = gold_ends[gold_spans] - word_offset 221 | gold_mention_cluster_map = gold_mention_cluster_map[gold_spans] 222 | 223 | coreferable_flags = (coreferable_starts < word_offset + num_words) & (coreferable_ends >= word_offset) if coreferable_starts is not None else None 224 | coreferable_starts = coreferable_starts[coreferable_flags] - word_offset if coreferable_starts is not None else None 225 | coreferable_ends = coreferable_ends[coreferable_flags] - word_offset if coreferable_starts is not None else None 226 | 227 | constituent_flags = (constituent_starts < word_offset + num_words) & (constituent_ends >= word_offset) if constituent_starts is not None else None 228 | constituent_starts = constituent_starts[constituent_flags] - word_offset if constituent_starts is not None else None 229 | constituent_ends = constituent_ends[constituent_flags] - word_offset if constituent_starts is not None else None 230 | constituent_type = constituent_type[constituent_flags] if constituent_type is not None else None 231 | 232 | return input_ids, input_mask, speaker_ids, sentence_len, genre, sentence_map, \ 233 | is_training, gold_starts, gold_ends, gold_mention_cluster_map, coreferable_starts, coreferable_ends, \ 234 | constituent_starts, constituent_ends, constituent_type 235 | -------------------------------------------------------------------------------- /util.py: -------------------------------------------------------------------------------- 1 | from os import makedirs 2 | from os.path import join 3 | import numpy as np 4 | import pyhocon 5 | import logging 6 | import torch 7 | import random 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | 12 | def flatten(l): 13 | return [item for sublist in l for item in sublist] 14 | 15 | 16 | def initialize_config(config_name, task="coref"): 17 | logger.info("Running {} experiment: {}".format(task, config_name)) 18 | 19 | if task == "coref": 20 | config = pyhocon.ConfigFactory.parse_file("experiments.conf")[config_name] 21 | elif task == "srl": 22 | config = pyhocon.ConfigFactory.parse_file("experiments_srl.conf")[config_name] 23 | 24 | config['log_dir'] = join(config["log_root"], config_name) 25 | makedirs(config['log_dir'], exist_ok=True) 26 | 27 | config['tb_dir'] = join(config['log_root'], 'tensorboard') 28 | makedirs(config['tb_dir'], exist_ok=True) 29 | 30 | logger.info(pyhocon.HOCONConverter.convert(config, "hocon")) 31 | return config 32 | 33 | 34 | def set_seed(seed, set_gpu=True): 35 | random.seed(seed) 36 | np.random.seed(seed) 37 | torch.manual_seed(seed) 38 | if set_gpu and torch.cuda.is_available(): 39 | # Necessary for reproducibility; lower performance 40 | torch.backends.cudnn.deterministic = True 41 | torch.backends.cudnn.benchmark = False 42 | torch.cuda.manual_seed_all(seed) 43 | logger.info('Random seed is set to %d' % seed) 44 | 45 | 46 | def bucket_distance(offsets): 47 | """ offsets: [num spans1, num spans2] """ 48 | # 10 semi-logscale bin: 0, 1, 2, 3, 4, (5-7)->5, (8-15)->6, (16-31)->7, (32-63)->8, (64+)->9 49 | logspace_distance = torch.log2(offsets.to(torch.float)).to(torch.long) + 3 50 | identity_mask = (offsets <= 4).to(torch.long) 51 | combined_distance = identity_mask * offsets + (1 - 
identity_mask) * logspace_distance 52 | combined_distance = torch.clamp(combined_distance, 0, 9) 53 | return combined_distance 54 | 55 | 56 | def batch_select(tensor, idx, device=torch.device('cpu')): 57 | """ Do selection per row (first axis). """ 58 | assert tensor.shape[0] == idx.shape[0] # Same size of first dim 59 | dim0_size, dim1_size = tensor.shape[0], tensor.shape[1] 60 | 61 | tensor = torch.reshape(tensor, [dim0_size * dim1_size, -1]) 62 | idx_offset = torch.unsqueeze(torch.arange(0, dim0_size, device=device) * dim1_size, 1) 63 | new_idx = idx + idx_offset 64 | selected = tensor[new_idx] 65 | 66 | if tensor.shape[-1] == 1: # If selected element is scalar, restore original dim 67 | selected = torch.squeeze(selected, -1) 68 | 69 | return selected 70 | 71 | 72 | def batch_add(tensor, idx, val, device=torch.device('cpu')): 73 | """ Do addition per row (first axis). """ 74 | assert tensor.shape[0] == idx.shape[0] # Same size of first dim 75 | dim0_size, dim1_size = tensor.shape[0], tensor.shape[1] 76 | 77 | tensor = torch.reshape(tensor, [dim0_size * dim1_size, -1]) 78 | idx_offset = torch.unsqueeze(torch.arange(0, dim0_size, device=device) * dim1_size, 1) 79 | new_idx = idx + idx_offset 80 | 81 | val = val.reshape(val.size(0) * val.size(1), -1) 82 | 83 | res = tensor.index_add(0, new_idx.view(-1), val).reshape([dim0_size, dim1_size, -1]) 84 | 85 | if tensor.shape[-1] == 1: # If selected element is scalar, restore original dim 86 | res = res.squeeze(-1) 87 | 88 | return res 89 | 90 | 91 | def sample_subset(w, k, t=0.1): 92 | ''' 93 | Args: 94 | w (Tensor): Float Tensor of weights for each element. In gumbel mode 95 | these are interpreted as log probabilities 96 | k (int): number of elements in the subset sample 97 | t (float): temperature of the softmax 98 | ''' 99 | wg = gumbel_perturb(w) 100 | return continuous_topk(wg, k, t) 101 | 102 | 103 | def logsumexp(tensor: torch.Tensor, dim: int = -1, keepdim: bool = False) -> torch.Tensor: 104 | max_score, _ = tensor.max(dim, keepdim=keepdim) 105 | if keepdim: 106 | stable_vec = tensor - max_score 107 | else: 108 | stable_vec = tensor - max_score.unsqueeze(dim) 109 | 110 | return max_score + stable_vec.logsumexp(dim, keepdim=keepdim) # (stable_vec.exp().sum(dim, keepdim=keepdim)).log() 111 | 112 | 113 | def clip_to_01(x: torch.Tensor): 114 | eps = torch.finfo(x.dtype).eps 115 | tiny = torch.finfo(x.dtype).tiny 116 | 117 | x_val = x.detach() 118 | 119 | cond_greater = (x_val > (1.0 - eps)) 120 | diff_greater = (x_val - 1.0 + eps) 121 | 122 | cond_less = (x < tiny) 123 | diff_less = (tiny - x_val) 124 | 125 | x -= diff_greater * cond_greater 126 | x += diff_less * cond_less 127 | 128 | return x 129 | 130 | 131 | def log1mexp(x): 132 | x -= torch.finfo(x.dtype).eps 133 | return torch.where(x > -0.693, (-torch.expm1(x)).log(), torch.log1p(-(x.exp()))) 134 | 135 | 136 | def gumbel_perturb(w): 137 | uniform_01 = torch.distributions.uniform.Uniform(1e-6, 1.0) 138 | # sample some gumbels 139 | u = uniform_01.sample(w.size()).cuda() 140 | g = -torch.log(-torch.log(u)) 141 | w = w + g 142 | return w 143 | 144 | 145 | def continuous_topk(w, k, t): 146 | khot_list = [] 147 | onehot_approx = torch.zeros_like(w) 148 | 149 | for i in range(k): 150 | khot_mask = torch.clamp(1.0 - onehot_approx, min=1e-6) 151 | w += torch.log(khot_mask) 152 | onehot_approx = F.softmax(w / t, dim=-1) 153 | khot_list.append(onehot_approx) 154 | 155 | return torch.stack(khot_list, dim=0) 156 | 157 | 158 | def stripe(x, n, w, offset=(0, 0), horizontal=1): 159 | r""" 160 | 
Returns a diagonal stripe of the tensor. 161 | Args: 162 | x (~torch.Tensor): the input tensor with 2 or more dims. 163 | n (int): the length of the stripe. 164 | w (int): the width of the stripe. 165 | offset (tuple): the offset of the first two dims. 166 | dim (int): 1 if returns a horizontal stripe; 0 otherwise. 167 | Returns: 168 | a diagonal stripe of the tensor. 169 | Examples: 170 | >>> x = torch.arange(25).view(5, 5) 171 | >>> x 172 | tensor([[ 0, 1, 2, 3, 4], 173 | [ 5, 6, 7, 8, 9], 174 | [10, 11, 12, 13, 14], 175 | [15, 16, 17, 18, 19], 176 | [20, 21, 22, 23, 24]]) 177 | >>> stripe(x, 2, 3) 178 | tensor([[0, 1, 2], 179 | [6, 7, 8]]) 180 | >>> stripe(x, 2, 3, (1, 1)) 181 | tensor([[ 6, 7, 8], 182 | [12, 13, 14]]) 183 | >>> stripe(x, 2, 3, (1, 1), 0) 184 | tensor([[ 6, 11, 16], 185 | [12, 17, 22]]) 186 | """ 187 | x, seq_len = x.contiguous(), x.size(1) 188 | stride, numel = list(x.stride()), x[0, 0].numel() 189 | stride[0] = (seq_len + 1) * numel 190 | stride[1] = (1 if horizontal == 1 else seq_len) * numel 191 | 192 | return x.as_strided( 193 | size=(n, w, *x.shape[2:]), 194 | stride=stride, 195 | storage_offset=(offset[0]*seq_len+offset[1])*numel 196 | ) 197 | 198 | 199 | def masked_topk_non_overlap( 200 | span_scores, 201 | span_mask, 202 | num_spans_to_keep, 203 | spans, 204 | non_crossing=True 205 | ): 206 | sorted_scores, sorted_indices = torch.sort(span_scores + span_mask.log(), descending=True) 207 | sorted_indices = sorted_indices.tolist() 208 | spans = spans.tolist() 209 | 210 | if not non_crossing: 211 | selected_candidate_idx = sorted(sorted_indices[:num_spans_to_keep], key=lambda idx: (spans[idx][0], spans[idx][1])) 212 | selected_candidate_idx = span_scores.new_tensor(selected_candidate_idx, dtype=torch.long) 213 | return selected_candidate_idx 214 | 215 | selected_candidate_idx = [] 216 | start_to_max_end, end_to_min_start = {}, {} 217 | for candidate_idx in sorted_indices: 218 | if len(selected_candidate_idx) >= num_spans_to_keep: 219 | break 220 | # Perform overlapping check 221 | span_start_idx = spans[candidate_idx][0] 222 | span_end_idx = spans[candidate_idx][1] 223 | cross_overlap = False 224 | for token_idx in range(span_start_idx, span_end_idx + 1): 225 | max_end = start_to_max_end.get(token_idx, -1) 226 | if token_idx > span_start_idx and max_end > span_end_idx: 227 | cross_overlap = True 228 | break 229 | min_start = end_to_min_start.get(token_idx, -1) 230 | if token_idx < span_end_idx and 0 <= min_start < span_start_idx: 231 | cross_overlap = True 232 | break 233 | if not cross_overlap: 234 | # Pass check; select idx and update dict stats 235 | selected_candidate_idx.append(candidate_idx) 236 | max_end = start_to_max_end.get(span_start_idx, -1) 237 | if span_end_idx > max_end: 238 | start_to_max_end[span_start_idx] = span_end_idx 239 | min_start = end_to_min_start.get(span_end_idx, -1) 240 | if min_start == -1 or span_start_idx < min_start: 241 | end_to_min_start[span_end_idx] = span_start_idx 242 | # Sort selected candidates by span idx 243 | selected_candidate_idx = sorted(selected_candidate_idx, key=lambda idx: (spans[idx][0], spans[idx][1])) 244 | selected_candidate_idx = span_scores.new_tensor(selected_candidate_idx, dtype=torch.long) 245 | 246 | return selected_candidate_idx 247 | --------------------------------------------------------------------------------
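Note (illustrative, not part of the repository): per the `__main__` block in run.py, training is launched as `python run.py <config_name> <gpu_id> [saved_suffix]`, where `<config_name>` selects an entry in experiments.conf; the conda environment described in sss.yml can typically be created with `conda env create -f sss.yml`. Below is a minimal, hypothetical sketch of how two of the span utilities in util.py behave in isolation. It assumes torch and pyhocon are installed and the repository root is on the Python path; the candidate spans and scores are invented purely for illustration, and the expected outputs follow from the function definitions and the stripe docstring above.

import torch
import util  # the util.py shipped in this repository

# stripe: view a diagonal band of a chart-shaped tensor without copying
# (the kind of access pattern used by CKY-style chart code).
x = torch.arange(25).view(5, 5)
band = util.stripe(x, n=2, w=3, offset=(1, 1))
print(band)  # tensor([[ 6,  7,  8], [12, 13, 14]]), as in the docstring example

# masked_topk_non_overlap: greedily keep the highest-scoring candidate spans,
# rejecting any span that crosses an already selected one (nesting is allowed).
spans = torch.tensor([[0, 2], [1, 3], [4, 5], [0, 5]])  # hypothetical (start, end) candidates
scores = torch.tensor([2.0, 1.5, 1.0, 0.5])             # hypothetical mention scores
mask = torch.ones(4)                                     # all candidates are valid
keep = util.masked_topk_non_overlap(scores, mask, num_spans_to_keep=3, spans=spans)
print(keep)  # tensor([0, 3, 2]): (1, 3) is dropped because it crosses (0, 2)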