├── LICENSE
├── README.md
├── bert_modelling.py
├── blanc.py
├── conll.py
├── evaluate.py
├── experiments.conf
├── fig
│   ├── illustration.png
│   └── structured_span_selector_cky_chart.pdf
├── greedy_mp.py
├── metrics.py
├── minimize.py
├── model.py
├── outside_mp.py
├── run.py
├── sss.yml
├── tensorize.py
└── util.py
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright Tianyu Liu
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## A Structured Span Selector
2 |
3 | https://arxiv.org/pdf/2205.03977.pdf
4 |
5 | This repository contains the official open-source implementation of our **structured span selector** paper:
6 |
7 | [A Structured Span Selector](https://arxiv.org/abs/2205.03977) (NAACL 2022).
8 | _Tianyu Liu, Yuchen Eleanor Jiang, Ryan Cotterell, and Mrinmaya Sachan_
9 |
10 |
11 | ## Overall idea
12 |
13 | For **span selection** tasks (e.g. coreference resolution, semantic role labelling, and question answering), we learn a latent **context-free grammar** over the spans of interest. This reduces the $O(n^2)$ search space of spans to the $O(n)$ space of nonterminals.
14 |
15 |
16 |
17 |
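As a rough illustration of this reduction (a minimal sketch; the two helper functions below are not part of the codebase):

```python
# A length-n sentence has O(n^2) contiguous spans, but a single (binarized)
# parse of that sentence contains only O(n) constituents.
def num_candidate_spans(n: int) -> int:
    return n * (n + 1) // 2   # all contiguous spans

def num_tree_constituents(n: int) -> int:
    return 2 * n - 1          # nodes of a binary tree with n leaves

for n in (10, 100, 1000):
    print(n, num_candidate_spans(n), num_tree_constituents(n))
```
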
18 | ## Installation
19 |
20 | First of all:
21 | ```bash
22 | git clone https://github.com/lyutyuh/structured-span-selector.git
23 | cd structured-span-selector
24 | ```
25 |
26 | 1. Create a virtual environment with Conda
27 | ```bash
28 | conda env create -f sss.yml
29 | ```
30 |
31 | 2. Activate the new environment
32 | ```bash
33 | conda activate sss
34 | ```
35 |
36 | 3. **Install genbmm with [inside-outside algorithm extension](https://github.com/lyutyuh/genbmm)**
37 | ```bash
38 | pip install git+https://github.com/lyutyuh/genbmm
39 | ```
40 |
41 |
42 | ## Obtaining the CoNLL-2012 data
43 |
44 | Please follow https://github.com/mandarjoshi90/coref, in particular https://github.com/mandarjoshi90/coref/blob/master/setup_training.sh, to obtain the {train, dev, test}.english.v4_gold_conll files. The {train, dev, test} splits contain 2802, 343, and 348 documents, respectively.
45 |
46 | The MD5 values are:
47 | ```bash
48 | md5sum dev.english.v4_gold_conll
49 | >>> bde418ea4bbec119b3a43b43933ec2ae
50 | md5sum test.english.v4_gold_conll
51 | >>> 6e64b649a039b4320ad32780db3abfa1
52 | md5sum train.english.v4_gold_conll
53 | >>> 9f92a664298dc78600fd50813246aa77
54 | ```
55 |
56 | Then, run
57 | ```bash
58 | python minimize.py ./data_dir/ ./data_dir/ false
59 | ```
60 | and get the jsonlines files.
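Each line of these jsonlines files is one document. As a reference, the fields written by `DocumentState.finalize()` in `minimize.py` are sketched below (the example value in the first comment is illustrative, not real data):

```python
record_fields = [
    "doc_key",           # e.g. "bc/cctv/00/cctv_0000_0"
    "sentences",         # subtoken segments, each wrapped as [CLS] <genre token> ... [SEP]
    "speakers",          # per-subtoken speaker ids, aligned with "sentences"
    "constituents",      # [start, end] subtoken spans of gold parse-tree constituents
    "constituent_type",  # nonterminal label of each constituent
    "ner",               # unused, always an empty list
    "clusters",          # gold coreference clusters as lists of [start, end] subtoken spans
    "sentence_map",      # subtoken index -> sentence index
    "subtoken_map",      # subtoken index -> original token index
    "pronouns",          # subtoken indices of PRP tokens
    "chunk_tags",        # NP chunk tags (BILOU scheme; may be empty)
]
```
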
61 |
62 | ## Training
63 |
64 | ```bash
65 | python run.py spanbert_large 0
66 | ```
67 |
68 | ## Evaluating
69 |
70 | ```bash
71 | python evaluate.py spanbert_large <saved_suffix> 0  # arguments: config_name saved_suffix gpu_id (see evaluate.py)
72 | ```
73 |
74 |
75 |
76 | ## Citing
77 |
78 | If you find this repo helpful, please cite the following version of the paper:
79 | ```tex
80 | @inproceedings{liu-etal-2022-structured,
81 | title = "A Structured Span Selector",
82 | author = "Liu, Tianyu and
83 |       Jiang, Yuchen Eleanor  and
84 | Cotterell, Ryan and
85 | Sachan, Mrinmaya",
86 | booktitle = "Proceedings of the 2022 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies",
87 | month = jul,
88 | year = "2022",
89 | address = "Seattle, United States",
90 | publisher = "Association for Computational Linguistics",
91 | url = "https://aclanthology.org/2022.naacl-main.189",
92 | pages = "2629--2641",
93 | }
94 | ```
95 |
--------------------------------------------------------------------------------
/blanc.py:
--------------------------------------------------------------------------------
1 | import math
2 | import typing as ty
3 |
4 | import numpy as np
5 |
6 | def trace(cluster: ty.Set, partition: ty.Iterable[ty.Set]) -> ty.Iterable[ty.Set]:
7 | r"""
8 | Return the partition of `#cluster` induced by `#partition`, that is
9 | ```math
10 | \{C∩A|A∈P\} ∪ \{\{x\}|x∈C∖∪P\}
11 | ```
12 | Where `$C$` is `#cluster` and `$P$` is `#partition`.
13 |
14 | This assume that the elements of `#partition` are indeed pairwise disjoint.
15 | """
16 | remaining = set(cluster)
17 | for a in partition:
18 | common = remaining.intersection(a)
19 | if common:
20 | remaining.difference_update(common)
21 | yield common
22 | for x in sorted(remaining):
23 | yield set((x,))
24 |
25 |
26 | class RemapClusteringsReturn(ty.NamedTuple):
27 | clusterings: ty.Sequence[ty.Sequence[ty.Sequence[int]]]
28 | elts_map: ty.Dict[ty.Hashable, int]
29 |
30 |
31 | def remap_clusterings(
32 | clusterings: ty.Sequence[ty.Sequence[ty.Set[ty.Hashable]]],
33 | ) -> RemapClusteringsReturn:
34 | """Remap clusterings of arbitrary elements to clusterings of integers."""
35 | elts = set(e for clusters in clusterings for c in clusters for e in c)
36 | elts_map = {e: i for i, e in enumerate(elts)}
37 | res = []
38 | for clusters in clusterings:
39 | remapped_clusters = []
40 | for c in clusters:
41 | remapped_c = [elts_map[e] for e in c]
42 | remapped_clusters.append(remapped_c)
43 | res.append(remapped_clusters)
44 | return RemapClusteringsReturn(res, elts_map)
45 |
46 |
47 |
48 | # COMBAK: Check the numeric stability
49 | def blanc(
50 | key: ty.Sequence[ty.Set], response: ty.Sequence[ty.Set], fast=True,
51 | ) -> ty.Tuple[float, float, float]:
52 | r"""
53 |     Return BLANC link statistics for a `#response` clustering given a `#key` clustering. With `fast=True` (the default), the two returned tuples are raw link counts `(true positives, key links, response links)` for coreference and non-coreference links respectively, which `tuple_to_metric` converts into `$(P, R, F)$`; with `fast=False`, they are `$(R, P, F)$` triples.
54 |
55 | ## Notes
56 |
57 | - Mention identifiers have to be comparable
58 | - To ensure the compliance with the reference implementation, the edge cases results are
59 | those from Recasens and Hovy (2011) rather than from the more recent Luo et al. (2014) when
60 | those two disagree. This has an effect for the N-6 testcase, where according to Luo et al.
61 | (2014), BLANC should be `$\frac{0+F_n}{2}$` since `$C_k=∅$` and `$C_r≠∅$`, but according to
62 | Recasens and Hovy (2011), BLANC should be `$F_n$`.
63 | """
64 | if fast:
65 | C_score, N_score = fast_detailed_blanc(key, response)
66 | else:
67 | C_score, N_score = detailed_blanc(key, response)
68 | if C_score is None:
69 | assert N_score is not None # nosec:B101
70 | return N_score
71 | if N_score is None:
72 | assert C_score is not None # nosec:B101
73 | return C_score
74 | return C_score, N_score
75 |
76 |
77 | def links_from_clusters(
78 | clusters: ty.Iterable[ty.Set],
79 | ) -> ty.Tuple[
80 | ty.Set[ty.Tuple[ty.Hashable, ty.Hashable]],
81 | ty.Set[ty.Tuple[ty.Hashable, ty.Hashable]],
82 | ]:
83 | r"""
84 | Return a `(coreference_links, non-coreference_links)` tuple corresponding to a clustering.
85 |
86 | The links are given as sorted couples for uniqueness
87 | """
88 | clusters_lst = [list(c) for c in clusters]
89 | C = set()
90 | N = set()
91 | for i, c in enumerate(clusters_lst[:-1]):
92 | for j, e in enumerate(c[:-1]):
93 | # Since the links are symmetric, we only add the links between `e` and
94 | # the following mentions
95 | for f in c[j + 1 :]:
96 | C.add((e, f) if e <= f else (f, e))
97 | for other in clusters_lst[i + 1 :]:
98 | for e in c:
99 | for f in other:
100 | N.add((e, f) if e <= f else (f, e))
101 | # We missed the coreference links for the last cluster, add them here
102 | last_cluster = clusters_lst[-1]
103 | for j, e in enumerate(last_cluster):
104 | for f in last_cluster[j + 1 :]:
105 | C.add((e, f) if e <= f else (f, e))
106 | return C, N
107 |
108 |
109 | def detailed_blanc(
110 | key: ty.Sequence[ty.Set], response: ty.Sequence[ty.Set]
111 | ) -> ty.Tuple[
112 | ty.Union[ty.Tuple[float, float, float], None],
113 | ty.Union[ty.Tuple[float, float, float], None],
114 | ]:
115 | """Return BLANC `$(R, P, F)$` scores for coreference and non-coreference respectively."""
116 |
117 | # Edge case : a single mention in both `key` and `response` clusters
118 | # in that case, `C_k`, `C_r`, `N_k` and `N_r` are all empty, so we need a separate examination
119 | # of the mentions to know if we are very good or very bad.
120 | if len(key) == len(response) == 1 and len(key[0]) == len(response[0]) == 1:
121 | if key[0] == response[0]:
122 | return ((1.0, 1.0, 1.0), (1.0, 1.0, 1.0))
123 | else:
124 | return ((0.0, 0.0, 0.0), (0.0, 0.0, 0.0))
125 |
126 | C_k, N_k = links_from_clusters(key)
127 | C_r, N_r = links_from_clusters(response)
128 |
129 | tp_c = len(C_k.intersection(C_r))
130 | tp_n = len(N_k.intersection(N_r))
131 | c_k, n_k = len(C_k), len(N_k)
132 | c_r, n_r = len(C_r), len(N_r)
133 |
134 | if not c_k and not c_r:
135 | R_c, P_c, F_c = (1.0, 1.0, 1.0)
136 | elif not c_k or not c_r:
137 | R_c, P_c, F_c = (0.0, 0.0, 0.0)
138 | else:
139 | R_c, P_c = tp_c / c_k, tp_c / c_r
140 | F_c = 2 * tp_c / (c_k + c_r)
141 |
142 | if not n_k and not n_r:
143 | R_n, P_n, F_n = (1.0, 1.0, 1.0)
144 | elif not n_k or not n_r:
145 | R_n, P_n, F_n = (0.0, 0.0, 0.0)
146 | else:
147 | R_n, P_n = tp_n / n_k, tp_n / n_r
148 | F_n = 2 * tp_n / (n_k + n_r)
149 |
150 | # Edge cases
151 | if not c_k:
152 | return (None, (R_n, P_n, F_n))
153 | if not n_k:
154 | return ((R_c, P_c, F_c), None)
155 |
156 | return ((R_c, P_c, F_c), (R_n, P_n, F_n))
157 |
158 |
159 | class AdjacencyReturn(ty.NamedTuple):
160 | """Represents a clustering of integers as an adjacency matrix and a presence mask"""
161 |
162 | adjacency: np.ndarray
163 | presence: np.ndarray
164 |
165 |
166 | def adjacency(clusters: ty.List[ty.List[int]], num_elts: int) -> AdjacencyReturn:
167 |     adjacency = np.zeros((num_elts, num_elts), dtype=bool)
168 |     presence = np.zeros(num_elts, dtype=bool)
169 | # **Note** The nested loop makes the complexity of this `$∑|c|²$` but we are only doing memory
170 | # access, which is really fast, so this is not really an issue. In comparison, doing it by
171 | # computing the Gram matrix one-hot elt-cluster attribution matrix was making `fast_blanc` 3×
172 | # slower than the naïve version.
173 | for c in clusters:
174 |         # Note: don't be clever and use numpy array indexing here, see
175 |         #
176 |         # for why it would be slower. If you want to get C loops here, cythonize it instead (but
177 |         # it's probably not worth it)
178 | for e in c:
179 | presence[e] = True
180 | for f in c:
181 | if f != e:
182 | adjacency[e, f] = True
183 | return AdjacencyReturn(adjacency, presence)
184 |
185 |
186 | def fast_detailed_blanc(
187 | key: ty.Sequence[ty.Set], response: ty.Sequence[ty.Set]
188 | ) -> ty.Tuple[
189 | ty.Union[ty.Tuple[float, float, float], None],
190 | ty.Union[ty.Tuple[float, float, float], None],
191 | ]:
192 | """Return BLANC `$(R, P, F)$` scores for coreference and non-coreference respectively."""
193 |
194 | # Edge case : a single mention in both `key` and `response` clusters
195 | # in that case, `C_k`, `C_r`, `N_k` and `N_r` are all empty, so we need a separate examination
196 | # of the mentions to know if we are very good or very bad.
197 | if len(key) == len(response) == 1 and len(key[0]) == len(response[0]) == 1:
198 | if key[0] == response[0]:
199 | return ((1.0, 1.0, 1.0), (1.0, 1.0, 1.0))
200 | else:
201 | return ((0.0, 0.0, 0.0), (0.0, 0.0, 0.0))
202 |
203 | (key, response), mentions_map = remap_clusterings([key, response])
204 | num_mentions = len(mentions_map)
205 |
206 | key_coref_links, key_presence = adjacency(key, num_mentions)
207 | response_coref_links, response_presence = adjacency(response, num_mentions)
208 |
209 | tp_c = np.logical_and(key_coref_links, response_coref_links).sum() // 2
210 | c_k = key_coref_links.sum() // 2
211 | c_r = response_coref_links.sum() // 2
212 |
213 | # Headache ahead
214 | common_links = np.logical_and(
215 | np.outer(key_presence, key_presence),
216 | np.outer(response_presence, response_presence),
217 | )
218 | # There is no link between a mention and itself
219 | np.fill_diagonal(common_links, False)
220 | tp_n = (
221 | np.logical_and(
222 | common_links,
223 | np.logical_not(np.logical_or(key_coref_links, response_coref_links)),
224 | ).sum()
225 | / 2
226 | )
227 | num_key_mentions = key_presence.sum()
228 | n_k = (num_key_mentions * (num_key_mentions - 1)) // 2 - c_k
229 | num_response_mentions = response_presence.sum()
230 | n_r = (num_response_mentions * (num_response_mentions - 1)) // 2 - c_r
231 |
232 | if not c_k and not c_r:
233 | R_c, P_c, F_c = (1.0, 1.0, 1.0)
234 | elif not c_k or not c_r:
235 | R_c, P_c, F_c = (0.0, 0.0, 0.0)
236 | else:
237 | R_c, P_c = tp_c / c_k, tp_c / c_r
238 | F_c = 2 * tp_c / (c_k + c_r)
239 |
240 | if not n_k and not n_r:
241 | R_n, P_n, F_n = (1.0, 1.0, 1.0)
242 | elif not n_k or not n_r:
243 | R_n, P_n, F_n = (0.0, 0.0, 0.0)
244 | else:
245 | R_n, P_n = tp_n / n_k, tp_n / n_r
246 | F_n = 2 * tp_n / (n_k + n_r)
247 |
248 | # # Edge cases
249 | # if not c_k:
250 | # return (None, (R_n, P_n, F_n))
251 | # if not n_k:
252 | # return ((R_c, P_c, F_c), None)
253 |
254 | return ((tp_c, c_k,c_r), (tp_n, n_k, n_r))
255 | # return ((R_c, P_c, F_c), (R_n, P_n, F_n))
256 |
257 | def tuple_to_metric(c_tuple, n_tuple):
258 | (tp_c, c_k,c_r), (tp_n, n_k, n_r) = c_tuple, n_tuple
259 |
260 | if not c_k and not c_r:
261 | R_c, P_c, F_c = (1.0, 1.0, 1.0)
262 | elif not c_k or not c_r:
263 | R_c, P_c, F_c = (0.0, 0.0, 0.0)
264 | else:
265 | R_c, P_c = tp_c / c_k, tp_c / c_r
266 | F_c = 2 * tp_c / (c_k + c_r)
267 |
268 | if not n_k and not n_r:
269 | R_n, P_n, F_n = (1.0, 1.0, 1.0)
270 | elif not n_k or not n_r:
271 | R_n, P_n, F_n = (0.0, 0.0, 0.0)
272 | else:
273 | R_n, P_n = tp_n / n_k, tp_n / n_r
274 | F_n = 2 * tp_n / (n_k + n_r)
275 | return ((P_c, R_c, F_c), (P_n, R_n, F_n))
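
# Usage sketch (mirrors how metrics.CorefEvaluator uses this module): with
# fast=True, `blanc` returns raw link counts that can be summed over documents
# and converted to (P, R, F) with `tuple_to_metric`. The mention spans below
# are illustrative.
#
#   key = [{(0, 0), (2, 2)}, {(5, 6)}]        # gold clusters of mention spans
#   response = [{(0, 0), (2, 2), (5, 6)}]     # predicted clusters
#   c_tuple, n_tuple = blanc(key, response, fast=True)
#   (P_c, R_c, F_c), (P_n, R_n, F_n) = tuple_to_metric(c_tuple, n_tuple)
#   # coreference links: P = 1/3, R = 1.0, F = 0.5; non-coreference links: all 0.0 here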
--------------------------------------------------------------------------------
/conll.py:
--------------------------------------------------------------------------------
1 | import re
2 | import tempfile
3 | import subprocess
4 | import operator
5 | import collections
6 | import logging
7 |
8 | logger = logging.getLogger(__name__)
9 |
10 | BEGIN_DOCUMENT_REGEX = re.compile(r"#begin document \((.*)\); part (\d+)") # First line at each document
11 | COREF_RESULTS_REGEX = re.compile(r".*Coreference: Recall: \([0-9.]+ / [0-9.]+\) ([0-9.]+)%\tPrecision: \([0-9.]+ / [0-9.]+\) ([0-9.]+)%\tF1: ([0-9.]+)%.*", re.DOTALL)
12 |
13 |
14 | def get_doc_key(doc_id, part):
15 | return "{}_{}".format(doc_id, int(part))
16 |
17 |
18 | def output_conll(input_file, output_file, predictions, subtoken_map):
19 | prediction_map = {}
20 | for doc_key, clusters in predictions.items():
21 | start_map = collections.defaultdict(list)
22 | end_map = collections.defaultdict(list)
23 | word_map = collections.defaultdict(list)
24 | for cluster_id, mentions in enumerate(clusters):
25 | for start, end in mentions:
26 | start, end = subtoken_map[doc_key][start], subtoken_map[doc_key][end]
27 | if start == end:
28 | word_map[start].append(cluster_id)
29 | else:
30 | start_map[start].append((cluster_id, end))
31 | end_map[end].append((cluster_id, start))
32 | for k,v in start_map.items():
33 | start_map[k] = [cluster_id for cluster_id, end in sorted(v, key=operator.itemgetter(1), reverse=True)]
34 | for k,v in end_map.items():
35 | end_map[k] = [cluster_id for cluster_id, start in sorted(v, key=operator.itemgetter(1), reverse=True)]
36 | prediction_map[doc_key] = (start_map, end_map, word_map)
37 |
38 | word_index = 0
39 | for line in input_file.readlines():
40 | row = line.split()
41 | if len(row) == 0:
42 | output_file.write("\n")
43 | elif row[0].startswith("#"):
44 | begin_match = re.match(BEGIN_DOCUMENT_REGEX, line)
45 | if begin_match:
46 | doc_key = get_doc_key(begin_match.group(1), begin_match.group(2))
47 | start_map, end_map, word_map = prediction_map[doc_key]
48 | word_index = 0
49 | output_file.write(line)
50 | output_file.write("\n")
51 | else:
52 | assert get_doc_key(row[0], row[1]) == doc_key
53 | coref_list = []
54 | if word_index in end_map:
55 | for cluster_id in end_map[word_index]:
56 | coref_list.append("{})".format(cluster_id))
57 | if word_index in word_map:
58 | for cluster_id in word_map[word_index]:
59 | coref_list.append("({})".format(cluster_id))
60 | if word_index in start_map:
61 | for cluster_id in start_map[word_index]:
62 | coref_list.append("({}".format(cluster_id))
63 |
64 | if len(coref_list) == 0:
65 | row[-1] = "-"
66 | else:
67 | row[-1] = "|".join(coref_list)
68 |
69 | output_file.write(" ".join(row))
70 | output_file.write("\n")
71 | word_index += 1
72 |
73 |
74 | def official_conll_eval(gold_path, predicted_path, metric, official_stdout=True):
75 | cmd = ["conll-2012/scorer/v8.01/scorer.pl", metric, gold_path, predicted_path, "none"]
76 | process = subprocess.Popen(cmd, stdout=subprocess.PIPE)
77 | stdout, stderr = process.communicate()
78 | process.wait()
79 |
80 | stdout = stdout.decode("utf-8")
81 | if stderr is not None:
82 | logger.error(stderr)
83 |
84 | if official_stdout:
85 | logger.info("Official result for {}".format(metric))
86 | logger.info(stdout)
87 |
88 | coref_results_match = re.match(COREF_RESULTS_REGEX, stdout)
89 | recall = float(coref_results_match.group(1))
90 | precision = float(coref_results_match.group(2))
91 | f1 = float(coref_results_match.group(3))
92 | return {"r": recall, "p": precision, "f": f1}
93 |
94 |
95 | def evaluate_conll(gold_path, predictions, subtoken_maps, official_stdout=True):
96 | with tempfile.NamedTemporaryFile(delete=True, mode="w") as prediction_file:
97 | with open(gold_path, "r") as gold_file:
98 | output_conll(gold_file, prediction_file, predictions, subtoken_maps)
99 | # logger.info("Predicted conll file: {}".format(prediction_file.name))
100 | results = {m: official_conll_eval(gold_file.name, prediction_file.name, m, official_stdout) for m in ("muc", "bcub", "ceafe") }
101 | return results
102 |
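# Usage sketch: `evaluate_conll` writes the predicted clusters into a temporary
# CoNLL file aligned with `gold_path`, then calls the official v8.01 scorer
# (conll-2012/scorer/v8.01/scorer.pl must be available) once per metric. The
# doc_key and spans below are illustrative.
#
#   predictions = {"bc/cctv/00/cctv_0000_0": [[(0, 0), (4, 4)]]}  # doc_key -> clusters of subtoken spans
#   subtoken_maps = {"bc/cctv/00/cctv_0000_0": [0, 1, 2, 3, 4]}   # subtoken index -> token index
#   results = evaluate_conll("dev.english.v4_gold_conll", predictions, subtoken_maps)
#   # results["muc"], results["bcub"], results["ceafe"] -> {"r": ..., "p": ..., "f": ...}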
--------------------------------------------------------------------------------
/evaluate.py:
--------------------------------------------------------------------------------
1 | from run import Runner
2 | import sys
3 | import torch
4 |
5 | def evaluate(config_name, gpu_id, saved_suffix):
6 | runner = Runner(config_name, gpu_id)
7 | model = runner.initialize_model(saved_suffix)
8 |
9 | examples_train, examples_dev, examples_test = runner.data.get_tensor_examples()
10 | stored_info = runner.data.get_stored_info()
11 |
12 | runner.evaluate(
13 | model, examples_dev, stored_info, 0, official=False, conll_path=runner.config['conll_eval_path'], predict=True
14 | ) # Eval dev
15 | # print('=================================')
16 | # runner.evaluate(model, examples_test, stored_info, 0, official=False, conll_path=runner.config['conll_test_path']) # Eval test
17 |
18 |
19 | if __name__ == '__main__':
20 | config_name, saved_suffix, gpu_id = sys.argv[1], sys.argv[2], int(sys.argv[3])
21 | evaluate(config_name, gpu_id, saved_suffix)
22 | print(
23 | torch.cuda.max_memory_allocated(device=gpu_id)
24 | )
25 |
--------------------------------------------------------------------------------
/experiments.conf:
--------------------------------------------------------------------------------
1 | best {
2 | data_dir = ./data_dir/
3 |
4 | # Computation limits.
5 | max_top_antecedents = 50
6 | max_training_sentences = 5
7 | top_span_ratio = 0.4
8 | max_num_extracted_spans = 3900
9 | max_num_speakers = 20
10 | max_segment_len = 256
11 |
12 | dataset = "ontonotes"
13 |
14 | mention_sigmoid = false
15 |
16 | # Learning
17 | bert_learning_rate = 1e-5
18 | task_learning_rate = 2e-4
19 |
20 | adam_eps = 1e-8
21 | adam_weight_decay = 1e-2
22 |
23 | warmup_ratio = 0.1
24 | max_grad_norm = 1 # Set 0 to disable clipping
25 | gradient_accumulation_steps = 1
26 |
27 | # Model hyperparameters.
28 | coref_depth = 1 # when 1: no higher order (except for cluster_merging)
29 | coarse_to_fine = true
30 | fine_grained = true
31 | dropout_rate = 0.3
32 | ffnn_size = 1000
33 | ffnn_depth = 1
34 |
35 | num_epochs = 24
36 | feature_emb_size = 20
37 | max_span_width = 30
38 | use_metadata = true
39 | use_features = true
40 | use_segment_distance = true
41 | model_heads = true
42 | use_width_prior = true # For mention score
43 | use_distance_prior = true # For mention-ranking score
44 |
45 |
46 | # Other.
47 | conll_eval_path = ${best.data_dir}/dev.english.v4_gold_conll # gold_conll file for dev
48 | conll_test_path = ${best.data_dir}/test.english.v4_gold_conll # gold_conll file for test
49 | genres = ["bc", "bn", "mz", "nw", "pt", "tc", "wb"]
50 | eval_frequency = 1000
51 | report_frequency = 100
52 | log_root = ${best.data_dir}
53 |
54 | mention_proposer = outside
55 | }
56 |
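# The configurations below inherit every setting from `best` through the
# `${best}{ ... }` substitution and override only the keys they list. The block
# name (e.g. spanbert_large) is the config_name passed on the command line,
# as in `python run.py spanbert_large 0`.
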
57 | spanbert_base = ${best}{
58 | num_docs = 2802
59 | bert_learning_rate = 2e-05
60 | task_learning_rate = 0.0001
61 | coref_depth = 1
62 | max_segment_len = 384
63 | ffnn_size = 3000
64 | cluster_ffnn_size = 1000
65 | max_training_sentences = 3
66 | neg_sample_rate=0.2
67 |
68 | bert_tokenizer_name = bert-base-cased
69 | bert_pretrained_name_or_path = SpanBERT/spanbert-base-cased
70 |
71 | }
72 |
73 | spanbert_base_greedy = ${best}{
74 | num_docs = 2802
75 | bert_learning_rate = 2e-05
76 | task_learning_rate = 0.0001
77 | coref_depth = 1
78 | max_segment_len = 384
79 | ffnn_size = 3000
80 | cluster_ffnn_size = 1000
81 | max_training_sentences = 3
82 |
83 | bert_tokenizer_name = bert-base-cased
84 | bert_pretrained_name_or_path = SpanBERT/spanbert-base-cased
85 |
86 | mention_proposer = greedy
87 | }
88 |
89 |
90 | spanbert_large = ${best}{
91 | num_docs = 2802
92 | bert_learning_rate = 1e-05
93 | task_learning_rate = 0.0003
94 | max_segment_len = 512
95 | ffnn_size = 3000
96 | cluster_ffnn_size = 3000
97 | max_training_sentences = 3
98 |
99 | neg_sample_rate=0.2
100 |
101 | bert_tokenizer_name = bert-base-cased
102 | bert_pretrained_name_or_path = SpanBERT/spanbert-large-cased
103 |
104 | }
105 |
106 | spanbert_large_greedy = ${best}{
107 | num_docs = 2802
108 | bert_learning_rate = 1e-05
109 | task_learning_rate = 0.0003
110 | max_segment_len = 512
111 | ffnn_size = 3000
112 | cluster_ffnn_size = 3000
113 | max_training_sentences = 3
114 |
115 | bert_tokenizer_name = bert-base-cased
116 | bert_pretrained_name_or_path = SpanBERT/spanbert-large-cased
117 |
118 | mention_proposer = greedy
119 | }
120 |
--------------------------------------------------------------------------------
/fig/illustration.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyutyuh/structured-span-selector/eac7771312622cd98535b4c93bad3ee000957e59/fig/illustration.png
--------------------------------------------------------------------------------
/fig/structured_span_selector_cky_chart.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyutyuh/structured-span-selector/eac7771312622cd98535b4c93bad3ee000957e59/fig/structured_span_selector_cky_chart.pdf
--------------------------------------------------------------------------------
/greedy_mp.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import math
3 | from typing import Any, Dict, List, Tuple
4 | import numpy as np
5 |
6 | import torch
7 | import torch.nn.functional as F
8 | from torch.nn.parallel import DataParallel
9 |
10 | logger = logging.getLogger(__name__)
11 |
12 | class GreedyMentionProposer(torch.nn.Module):
13 |
14 | def __init__(
15 | self,
16 | **kwargs
17 | ) -> None:
18 | super().__init__(**kwargs)
19 |
20 |
21 | def forward(
22 | self,
23 | spans: torch.IntTensor,
24 | span_mention_scores: torch.FloatTensor,
25 | span_mask: torch.FloatTensor,
26 | token_num: torch.IntTensor,
27 | num_spans_to_keep: int,
28 | take_top_spans_per_sentence = False,
29 | flat_span_sent_ids = None,
30 | ratio = 0.4,
31 | ):
32 | if not take_top_spans_per_sentence:
33 | top_span_indices = masked_topk_non_overlap(
34 | span_mention_scores,
35 | span_mask,
36 | num_spans_to_keep,
37 | spans
38 | )
39 | top_spans = spans[top_span_indices]
40 | top_span_scores = span_mention_scores[top_span_indices]
41 | return top_span_scores, top_span_indices, top_spans, 0., None
42 | else:
43 | top_span_indices, top_span_scores, top_spans = [], [], []
44 | prev_sent_id, prev_span_id = 0, 0
45 | for span_id, sent_id in enumerate(flat_span_sent_ids.tolist()):
46 | if sent_id != prev_sent_id:
47 | sent_span_indices = masked_topk_non_overlap(
48 | span_mention_scores[prev_span_id:span_id],
49 | span_mask[prev_span_id:span_id],
50 | int(ratio * (token_num[prev_sent_id])), # [CLS], [SEP]
51 | spans[prev_span_id:span_id],
52 | non_crossing=True,
53 | ) + prev_span_id
54 |
55 | top_span_indices.append(sent_span_indices)
56 | top_span_scores.append(span_mention_scores[sent_span_indices])
57 | top_spans.append(spans[sent_span_indices])
58 |
59 | prev_sent_id, prev_span_id = sent_id, span_id
60 | # last sentence
61 | sent_span_indices = masked_topk_non_overlap(
62 | span_mention_scores[prev_span_id:],
63 | span_mask[prev_span_id:],
64 | int(ratio * (token_num[-1])),
65 | spans[prev_span_id:],
66 | non_crossing=True,
67 | ) + prev_span_id
68 |
69 | top_span_indices.append(sent_span_indices)
70 | top_span_scores.append(span_mention_scores[sent_span_indices])
71 | top_spans.append(spans[sent_span_indices])
72 |
73 | num_top_spans = [x.size(0) for x in top_span_indices]
74 | max_num_top_span = max(num_top_spans)
75 |
76 | top_spans = torch.stack(
77 | [torch.cat([x, x.new_zeros((max_num_top_span-x.size(0), 2))], dim=0) for x in top_spans], dim=0
78 | )
79 | top_span_masks = torch.stack(
80 | [torch.cat([x.new_ones((x.size(0), )), x.new_zeros((max_num_top_span-x.size(0), ))], dim=0) for x in top_span_indices], dim=0
81 | )
82 | top_span_indices = torch.stack(
83 | [torch.cat([x, x.new_zeros((max_num_top_span-x.size(0), ))], dim=0) for x in top_span_indices], dim=0
84 | )
85 |
86 | return top_span_scores, top_span_indices, top_spans, 0., None, top_span_masks
87 |
88 |
89 | def masked_topk_non_overlap(
90 | span_scores,
91 | span_mask,
92 | num_spans_to_keep,
93 | spans,
94 | non_crossing=True
95 | ):
96 |
97 | sorted_scores, sorted_indices = torch.sort(span_scores + span_mask.log(), descending=True)
98 | sorted_indices = sorted_indices.tolist()
99 | spans = spans.tolist()
100 |
101 | if not non_crossing:
102 | selected_candidate_idx = sorted(sorted_indices[:num_spans_to_keep], key=lambda idx: (spans[idx][0], spans[idx][1]))
103 | selected_candidate_idx = span_scores.new_tensor(selected_candidate_idx, dtype=torch.long)
104 | return selected_candidate_idx
105 |
106 | selected_candidate_idx = []
107 | start_to_max_end, end_to_min_start = {}, {}
108 | for candidate_idx in sorted_indices:
109 | if len(selected_candidate_idx) >= num_spans_to_keep:
110 | break
111 | # Perform overlapping check
112 | span_start_idx = spans[candidate_idx][0]
113 | span_end_idx = spans[candidate_idx][1]
114 | cross_overlap = False
115 | for token_idx in range(span_start_idx, span_end_idx + 1):
116 | max_end = start_to_max_end.get(token_idx, -1)
117 | if token_idx > span_start_idx and max_end > span_end_idx:
118 | cross_overlap = True
119 | break
120 | min_start = end_to_min_start.get(token_idx, -1)
121 | if token_idx < span_end_idx and 0 <= min_start < span_start_idx:
122 | cross_overlap = True
123 | break
124 | if not cross_overlap:
125 | # Pass check; select idx and update dict stats
126 | selected_candidate_idx.append(candidate_idx)
127 | max_end = start_to_max_end.get(span_start_idx, -1)
128 | if span_end_idx > max_end:
129 | start_to_max_end[span_start_idx] = span_end_idx
130 | min_start = end_to_min_start.get(span_end_idx, -1)
131 | if min_start == -1 or span_start_idx < min_start:
132 | end_to_min_start[span_end_idx] = span_start_idx
133 | # Sort selected candidates by span idx
134 | selected_candidate_idx = sorted(selected_candidate_idx, key=lambda idx: (spans[idx][0], spans[idx][1]))
135 | selected_candidate_idx = span_scores.new_tensor(selected_candidate_idx, dtype=torch.long)
136 |
137 | return selected_candidate_idx
138 |
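# Usage sketch for masked_topk_non_overlap: keep the highest-scoring spans that
# do not cross each other (nesting is allowed, partial overlap is not) and
# return their indices sorted by (start, end). The toy tensors are illustrative.
#
#   spans = torch.tensor([[0, 2], [1, 3], [1, 2], [4, 5]])
#   scores = torch.tensor([0.9, 0.8, 0.5, 0.7])
#   mask = torch.ones(4)
#   masked_topk_non_overlap(scores, mask, num_spans_to_keep=3, spans=spans)
#   # -> tensor([0, 2, 3]); span [1, 3] is dropped because it crosses [0, 2]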
--------------------------------------------------------------------------------
/metrics.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import numpy as np
6 | from collections import Counter, defaultdict
7 |
8 | # from scipy.optimize import linear_sum_assignment as linear_assignment
9 | # from sklearn.utils.linear_assignment_ import linear_assignment
10 | from scipy.optimize import linear_sum_assignment
11 |
12 | from blanc import blanc, tuple_to_metric
13 |
14 |
15 | def f1(p_num, p_den, r_num, r_den, beta=1):
16 | p = 0 if p_den == 0 else p_num / float(p_den)
17 | r = 0 if r_den == 0 else r_num / float(r_den)
18 | return 0 if p + r == 0 else (1 + beta * beta) * p * r / (beta * beta * p + r)
19 |
20 |
21 | class CorefEvaluator(object):
22 | def __init__(self):
23 | self.evaluators = [Evaluator(m) for m in (muc, b_cubed, ceafe)]
24 | self.all_gm = 1e-6
25 | self.recalled_gm = 1e-6
26 |
27 |         self.all_gm_by_width = defaultdict(int)
28 |         self.recalled_gm_by_width = defaultdict(int)
29 | 
30 |         self.all_gm_by_depth = defaultdict(int)
31 |         self.recalled_gm_by_depth = defaultdict(int)
32 |
33 | self.c_tuple = [0,0,0]
34 | self.n_tuple = [0,0,0]
35 |
36 | def update(
37 | self, predicted, gold, mention_to_predicted, mention_to_gold,
38 | metainfo_gms, recalled_gms # all_gm, recalled_gm
39 | ):
40 | for e in self.evaluators:
41 | e.update(predicted, gold, mention_to_predicted, mention_to_gold)
42 | self.all_gm += len(metainfo_gms)
43 | self.recalled_gm += len(recalled_gms & set(metainfo_gms.keys()))
44 |
45 | for x, v in metainfo_gms.items():
46 | self.all_gm_by_width[v['width']] += 1
47 | self.all_gm_by_depth[v['depth']] += 1
48 | if x in recalled_gms:
49 | self.recalled_gm_by_width[v['width']] += 1
50 | self.recalled_gm_by_depth[v['depth']] += 1
51 |
52 |
53 |
54 | c_tuple, n_tuple = blanc(gold, predicted)
55 | for i in range(3):
56 | self.c_tuple[i] += c_tuple[i]
57 | for i in range(3):
58 | self.n_tuple[i] += n_tuple[i]
59 |
60 | def get_all(self):
61 | all_res = {}
62 | name_dict = {0: "muc", 1: "b_cubed", 2: "ceafe"}
63 | for i, e in enumerate(self.evaluators):
64 | all_res[name_dict[i]+"_f1"] = e.get_f1()
65 | all_res[name_dict[i]+"_p"] = e.get_precision()
66 | all_res[name_dict[i]+"_r"] = e.get_recall()
67 |
68 | return all_res
69 |
70 | def get_f1(self):
71 | for e in self.evaluators:
72 | print("f:", e.get_f1())
73 |
74 | return sum(e.get_f1() for e in self.evaluators) / len(self.evaluators)
75 |
76 | def get_recall(self):
77 | for e in self.evaluators:
78 | print("r:", e.get_recall())
79 | return sum(e.get_recall() for e in self.evaluators) / len(self.evaluators)
80 |
81 | def get_precision(self):
82 | for e in self.evaluators:
83 | print("p:", e.get_precision())
84 | return sum(e.get_precision() for e in self.evaluators) / len(self.evaluators)
85 |
86 | def get_prf(self):
87 | blanc_scores = tuple_to_metric(self.c_tuple, self.n_tuple)
88 | blanc_p, blanc_r, blanc_f = tuple(0.5*(a+b) for (a,b) in zip(*blanc_scores))
89 | print("all_gm", self.all_gm)
90 |
91 | print(self.all_gm_by_width)
92 | print(self.recalled_gm_by_width)
93 |
94 | print(self.all_gm_by_depth)
95 | print(self.recalled_gm_by_depth)
96 |
97 | return self.get_precision(), self.get_recall(), self.get_f1(), self.recalled_gm / self.all_gm, (blanc_p, blanc_r, blanc_f)
98 |
99 |
100 | class Evaluator(object):
101 | def __init__(self, metric, beta=1):
102 | self.p_num = 0
103 | self.p_den = 0
104 | self.r_num = 0
105 | self.r_den = 0
106 | self.metric = metric
107 | self.beta = beta
108 |
109 | def update(self, predicted, gold, mention_to_predicted, mention_to_gold):
110 | if self.metric == ceafe:
111 | pn, pd, rn, rd = self.metric(predicted, gold)
112 | else:
113 | pn, pd = self.metric(predicted, mention_to_gold)
114 | rn, rd = self.metric(gold, mention_to_predicted)
115 | self.p_num += pn
116 | self.p_den += pd
117 | self.r_num += rn
118 | self.r_den += rd
119 |
120 | def get_f1(self):
121 | return f1(self.p_num, self.p_den, self.r_num, self.r_den, beta=self.beta)
122 |
123 | def get_recall(self):
124 | return 0 if self.r_num == 0 else self.r_num / float(self.r_den)
125 |
126 | def get_precision(self):
127 | return 0 if self.p_num == 0 else self.p_num / float(self.p_den)
128 |
129 | def get_prf(self):
130 | return self.get_precision(), self.get_recall(), self.get_f1()
131 |
132 | def get_counts(self):
133 | return self.p_num, self.p_den, self.r_num, self.r_den
134 |
135 |
136 | def evaluate_documents(documents, metric, beta=1):
137 | evaluator = Evaluator(metric, beta=beta)
138 | for document in documents:
139 | evaluator.update(document)
140 | return evaluator.get_precision(), evaluator.get_recall(), evaluator.get_f1()
141 |
142 |
143 | def b_cubed(clusters, mention_to_gold):
144 | num, dem = 0, 0
145 |
146 | for c in clusters:
147 | # if len(c) == 1:
148 | # continue
149 |
150 | gold_counts = Counter()
151 | correct = 0
152 | for m in c:
153 | if m in mention_to_gold:
154 | gold_counts[tuple(mention_to_gold[m])] += 1
155 | for c2, count in gold_counts.items():
156 | # if len(c2) != 1:
157 | correct += count * count
158 |
159 | num += correct / float(len(c))
160 | dem += len(c)
161 |
162 | return num, dem
163 |
164 |
165 | def muc(clusters, mention_to_gold):
166 | tp, p = 0, 0
167 | for c in clusters:
168 | p += len(c) - 1
169 | tp += len(c)
170 | linked = set()
171 | for m in c:
172 | if m in mention_to_gold:
173 | linked.add(mention_to_gold[m])
174 | else:
175 | tp -= 1
176 | tp -= len(linked)
177 | return tp, p
178 |
179 |
180 | def phi4(c1, c2):
181 | return 2 * len([m for m in c1 if m in c2]) / float(len(c1) + len(c2))
182 |
183 |
184 | def ceafe(clusters, gold_clusters):
185 | clusters = [c for c in clusters] # if len(c) != 1]
186 | scores = np.zeros((len(gold_clusters), len(clusters)))
187 | for i in range(len(gold_clusters)):
188 | for j in range(len(clusters)):
189 | scores[i, j] = phi4(gold_clusters[i], clusters[j])
190 | matching = linear_sum_assignment(-scores)
191 | matching = np.transpose(np.asarray(matching))
192 | similarity = sum(scores[matching[:,0], matching[:,1]])
193 | return similarity, len(clusters), similarity, len(gold_clusters)
194 |
195 |
196 | def lea(clusters, mention_to_gold):
197 | num, dem = 0, 0
198 |
199 | for c in clusters:
200 | if len(c) == 1:
201 | continue
202 |
203 | common_links = 0
204 | all_links = len(c) * (len(c) - 1) / 2.0
205 | for i, m in enumerate(c):
206 | if m in mention_to_gold:
207 | for m2 in c[i + 1:]:
208 | if m2 in mention_to_gold and mention_to_gold[m] == mention_to_gold[m2]:
209 | common_links += 1
210 |
211 | num += len(c) * common_links / float(all_links)
212 | dem += len(c)
213 |
214 | return num, dem
215 |
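# Usage sketch for a single metric: muc and b_cubed score a list of clusters
# against a mention-to-cluster mapping, while ceafe aligns predicted and gold
# clusters directly. The toy clusters below are illustrative.
#
#   gold = [((0, 0), (4, 4)), ((6, 7),)]
#   predicted = [((0, 0), (4, 4), (6, 7))]
#   mention_to_gold = {m: c for c in gold for m in c}
#   mention_to_predicted = {m: c for c in predicted for m in c}
#   e = Evaluator(muc)
#   e.update(predicted, gold, mention_to_predicted, mention_to_gold)
#   e.get_prf()   # -> (precision=0.5, recall=1.0, f1=2/3)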
--------------------------------------------------------------------------------
/minimize.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import re
6 | import os
7 | import sys
8 | import json
9 | import tempfile
10 | import subprocess
11 | import collections
12 | from nltk import Tree
13 |
14 | import conll
15 | from transformers import AutoTokenizer
16 |
17 | import copy
18 | import nltk
19 |
20 | SPEAKER_START = '[unused19]'
21 | SPEAKER_END = '[unused73]'
22 | GENRE_DICT = {"bc":"[unused20]", "bn":"[unused21]", "mz":"[unused22]", "nw":"[unused23]", "pt":"[unused24]", "tc":"[unused25]", "wb":"[unused26]"}
23 |
24 | def recur_list_tree(parse_tree, offset=0):
25 | span_to_tag = {}
26 | for node in parse_tree:
27 | if type(node) == type(parse_tree):
28 | span_to_tag.update(recur_list_tree(node, offset))
29 | n_leaves = len(node.leaves())
30 | start, end = offset, offset+n_leaves-1
31 | span_to_tag[(start,end)] = node.label()
32 | offset += n_leaves
33 | else:
34 | offset += 1
35 | pass
36 | return span_to_tag
37 |
38 | def get_chunks_from_tree(parse_tree,chunk_types,offset=0,depth=0):
39 | chunks = {}
40 | root = parse_tree
41 | for node in parse_tree:
42 | if type(node) == type(parse_tree):
43 | n_leaves = len(node.leaves())
44 | start, end = offset, offset+n_leaves-1
45 | if node.label() not in chunk_types:
46 | chunks.update(get_chunks_from_tree(node,chunk_types, offset, depth))
47 | else:
48 | if depth == 0:
49 | chunks[(start,end)] = node.label()
50 | else:
51 | next_level = get_chunks_from_tree(node,chunk_types, offset, depth-1)
52 | if len(next_level) > 0:
53 | chunks.update(next_level)
54 | else:
55 | chunks[(start,end)] = node.label()
56 | offset += n_leaves
57 | else: # leaf node
58 | offset += 1
59 | pass
60 | return chunks
61 |
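# Usage sketch: extract chunks of the requested types from an nltk parse tree as
# a {(start_token, end_token): label} dict with 0-based, inclusive token indices.
#
#   t = Tree.fromstring("(S (NP (DT the) (NN cat)) (VP (VBD sat)))")
#   get_chunks_from_tree(t, ["NP"])   # -> {(0, 1): 'NP'}
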
62 | def keep_coreferable(root):
63 | coreferable_types = ["NP", "VB", "PRP", "NML","CD","NNP"]
64 | processed_types = ["MNP", "V", "PRP", "NML", "NNP"]
65 | result_tree = copy.deepcopy(root)
66 |
67 | coreferable_spans = {}
68 |
69 | def recur(node, parent, parent_span, span, index_in_siblings):
70 | child_left_ind = span[0]
71 | res = []
72 | if True:
73 | for (i, child) in enumerate(node):
74 | if type(child) == nltk.tree.Tree:
75 | child_span = (child_left_ind, child_left_ind+len(child.leaves()))
76 | child_left_ind += len(child.leaves())
77 | processed_node = recur(child, node, span, child_span, i)
78 |
79 | res += processed_node
80 | else:
81 | if node.label() not in ["CD", "NML","NNP","PRP"]:
82 | res += [node]
83 | else:
84 | res += [child]
85 | if node.label() == "NP":
86 | if True: # keeping all NPs parent.label() not in ["NP",]:
87 | node.set_label("MNP")
88 | else:
89 | if node[-1].label() in {"POS"}:
90 | node.set_label("MNP")
91 | pass
92 | elif any([x.label() in ("VP", "CC","HYPH") for x in parent]):
93 | node.set_label("MNP")
94 | pass
95 | elif sum([(x.label() in ("NP", "MNP")) for x in parent]) > 1:
96 | node.set_label("MNP")
97 | pass
98 |
99 | elif node.label().startswith("PRP"):
100 | node.set_label("PRP")
101 | elif node.label() == "CD":
102 | # exclude a lot of CDs
103 | if index_in_siblings > 0 and parent[index_in_siblings-1].label() == "DT":
104 | pass
105 | elif index_in_siblings < len(parent)-1 and parent[index_in_siblings+1].label() in ("NN", "NNS", ):
106 | pass
107 | else:
108 | node.set_label("-CD")
109 | elif node.label() in {"NN", "NNP"}:
110 | # exclude a lot of NNPs
111 | if index_in_siblings < len(parent)-1 and parent[index_in_siblings+1].label() == "POS":
112 | # node.set_label("-NNP")
113 | pass
114 | elif any([(x.label() in ("CC", "HYPH")) for x in parent]):
115 | node.set_label("NNP")
116 | pass
117 | elif index_in_siblings > 0 and parent[index_in_siblings-1].label() == "DT":
118 | pass
119 | elif index_in_siblings < len(parent)-1 and parent[index_in_siblings+1].label() in ("NN", "NNS", ):
120 | pass
121 | elif index_in_siblings == 0:
122 | pass
123 | else:
124 | node.set_label("-NNP")
125 | pass
126 | elif node.label().startswith("VB"):
127 | node.set_label("-V")
128 | pass
129 | if node.label() in processed_types:
130 | coreferable_spans[(span[0], span[1]-1)] = node.label()
131 | res = [nltk.Tree(node.label(), res)]
132 | else:
133 | res = res
134 | return res
135 | result_tree = recur(result_tree, None, None, (0, len(result_tree.leaves())), 0)
136 | result_tree = nltk.Tree("TOP", result_tree)
137 | return result_tree, coreferable_spans
138 |
139 | class DocumentState(object):
140 | def __init__(self, key):
141 | self.doc_key = key
142 | self.sentence_end = []
143 | self.token_end = []
144 | self.tokens = []
145 | self.subtokens = []
146 | self.info = []
147 | self.segments = []
148 | self.subtoken_map = []
149 | self.segment_subtoken_map = []
150 | self.sentence_map = []
151 | self.pronouns = []
152 | self.clusters = collections.defaultdict(list)
153 | self.coref_stacks = collections.defaultdict(list)
154 | self.speakers = []
155 | self.segment_info = []
156 |
157 | self.chunk_tags = [] # corresponding to subtokens
158 | self.constituents = []
159 | self.coreferables = []
160 | self.pos_tags = []
161 |
162 |
163 | def finalize(self):
164 | # finalized: segments, segment_subtoken_map
165 | # populate speakers from info
166 | subtoken_idx = 0
167 | for segment in self.segment_info:
168 | speakers = []
169 | for i, tok_info in enumerate(segment):
170 | if tok_info is None and (i == 0 or i == len(segment) - 1):
171 | speakers.append('[SPL]')
172 | elif tok_info is None:
173 | speakers.append(speakers[-1])
174 | else:
175 | speakers.append(tok_info[9])
176 | if tok_info[4] == 'PRP':
177 | self.pronouns.append(subtoken_idx)
178 | subtoken_idx += 1
179 | self.speakers += [speakers]
180 | # populate sentence map
181 |
182 | # populate clusters
183 | first_subtoken_index = -1
184 | for seg_idx, segment in enumerate(self.segment_info):
185 | speakers = []
186 | for i, tok_info in enumerate(segment):
187 | first_subtoken_index += 1
188 | coref = tok_info[-2] if tok_info is not None else '-'
189 | if coref != "-":
190 | last_subtoken_index = first_subtoken_index + tok_info[-1] - 1
191 | for part in coref.split("|"):
192 | if part[0] == "(":
193 | if part[-1] == ")":
194 | cluster_id = int(part[1:-1])
195 | self.clusters[cluster_id].append((first_subtoken_index, last_subtoken_index))
196 | else:
197 | cluster_id = int(part[1:])
198 | self.coref_stacks[cluster_id].append(first_subtoken_index)
199 | else:
200 | cluster_id = int(part[:-1])
201 | start = self.coref_stacks[cluster_id].pop()
202 | self.clusters[cluster_id].append((start, last_subtoken_index))
203 | # merge clusters
204 | merged_clusters = []
205 | for c1 in self.clusters.values():
206 | existing = None
207 | for m in c1:
208 | for c2 in merged_clusters:
209 | if m in c2:
210 | existing = c2
211 | break
212 | if existing is not None:
213 | break
214 | if existing is not None:
215 | print("Merging clusters (shouldn't happen very often.)")
216 | existing.update(c1)
217 | else:
218 | merged_clusters.append(set(c1))
219 | merged_clusters = [list(c) for c in merged_clusters]
220 |
221 | flattened_sentences = flatten(self.segments)
222 | all_mentions = flatten(merged_clusters)
223 | sentence_map = get_sentence_map(self.segments, self.sentence_end)
224 | subtoken_map = flatten(self.segment_subtoken_map)
225 |
226 |
227 | for cluster in merged_clusters:
228 | for mention in cluster:
229 | if subtoken_map[mention[0]] == subtoken_map[mention[1]]:
230 | if self.pos_tags[subtoken_map[mention[0]]].startswith("V"):
231 | self.coreferables.append(mention)
232 |
233 |
234 | assert len(all_mentions) == len(set(all_mentions))
235 | chunk_tags = flatten(self.chunk_tags)
236 | num_words = len(flattened_sentences)
237 | assert num_words == len(flatten(self.speakers))
238 | # assert num_words == len(chunk_tags), (num_words, len(chunk_tags))
239 | assert num_words == len(subtoken_map), (num_words, len(subtoken_map))
240 | assert num_words == len(sentence_map), (num_words, len(sentence_map))
241 | def mapper(x):
242 | if x == "NP":
243 | return 1
244 | else:
245 | return 2
246 | return {
247 | "doc_key": self.doc_key,
248 | "sentences": self.segments,
249 | "speakers": self.speakers,
250 | "constituents": [x[0] for x in self.constituents], #
251 | "constituent_type": [x[1] for x in self.constituents], #
252 | # "coreferables": self.coreferables,
253 | "ner": [],
254 | "clusters": merged_clusters,
255 | 'sentence_map':sentence_map,
256 | "subtoken_map": subtoken_map,
257 | 'pronouns': self.pronouns,
258 | "chunk_tags": self.chunk_tags
259 | }
260 |
261 |
262 | def normalize_word(word, language):
263 | if language == "arabic":
264 | word = word[:word.find("#")]
265 | if word == "/." or word == "/?":
266 | return word[1:]
267 | else:
268 | return word
269 |
270 | # first try to satisfy constraints1, and if not possible, constraints2.
271 | def split_into_segments(document_state, max_segment_len, constraints1, constraints2):
272 | current = 0
273 | previous_token = 0
274 | final_chunk_tags = []
275 | index_mapping_dict = {}
276 | cur_seg_ind = 1
277 | while current < len(document_state.subtokens):
278 | # -3 for 3 additional special tokens
279 | end = min(current + max_segment_len - 1 - 3, len(document_state.subtokens) - 1)
280 | while end >= current and not constraints1[end]:
281 | end -= 1
282 | if end < current:
283 | end = min(current + max_segment_len - 1 - 3, len(document_state.subtokens) - 1)
284 | while end >= current and not constraints2[end]:
285 | end -= 1
286 | if end < current:
287 | raise Exception("Can't find valid segment")
288 |
289 | for i in range(current, end+1):
290 | index_mapping_dict[i] = i + 3*cur_seg_ind - 1
291 |
292 |
293 | genre = document_state.doc_key[:2]
294 | genre_text = GENRE_DICT[genre]
295 | document_state.tokens.append(genre_text)
296 |
297 | document_state.segments.append(['[CLS]', genre_text] + document_state.subtokens[current:end + 1] + ['[SEP]'])
298 |
299 | subtoken_map = document_state.subtoken_map[current : end + 1]
300 | document_state.segment_subtoken_map.append([previous_token, previous_token] + subtoken_map + [subtoken_map[-1]])
301 | info = document_state.info[current : end + 1]
302 | document_state.segment_info.append([None, None] + info + [None])
303 | current = end + 1
304 | cur_seg_ind += 1
305 | previous_token = subtoken_map[-1]
306 |
307 | document_state.chunk_tags = final_chunk_tags
308 | return index_mapping_dict
309 |
310 | def get_sentence_map(segments, sentence_end):
311 | current = 0
312 | sent_map = []
313 | sent_end_idx = 0
314 | assert len(sentence_end) == sum([len(s)-3 for s in segments])
315 | for segment in segments:
316 | sent_map.append(current)
317 | sent_map.append(current)
318 | for i in range(len(segment) - 3):
319 | sent_map.append(current)
320 | current += int(sentence_end[sent_end_idx])
321 | sent_end_idx += 1
322 | sent_map.append(current)
323 | return sent_map
324 |
325 | def get_document(document_lines, tokenizer, language, segment_len):
326 | document_state = DocumentState(document_lines[0])
327 | word_idx = -1
328 | parse_pieces = []
329 | cur_sent_offset = 0
330 | cur_sent_len = 0
331 |
332 |     current_speaker, added_speaker_head = None, False  # pre-declare the speaker flag used at sentence ends
333 |
334 | for line in document_lines[1]:
335 | row = line.split()
336 | sentence_end = len(row) == 0
337 | if not sentence_end:
338 | assert len(row) >= 12
339 |
340 | if current_speaker is None or current_speaker != row[9]:
341 | added_speaker_head = True
342 | # insert speaker
343 | word_idx += 1
344 | current_speaker = row[9]
345 | speaker_text = tokenizer.tokenize(current_speaker)
346 | parse_pieces.append(f"(PSEUDO {' '.join([SPEAKER_START] + speaker_text + [SPEAKER_END])}")
347 | document_state.tokens.append(current_speaker)
348 | document_state.pos_tags.append("SPEAKER")
349 | for sidx, subtoken in enumerate([SPEAKER_START] + speaker_text + [SPEAKER_END]):
350 | cur_sent_len += 1
351 | document_state.subtokens.append(subtoken)
352 | info = None
353 | document_state.info.append(info)
354 | document_state.sentence_end.append(False)
355 | document_state.subtoken_map.append(word_idx)
356 |
357 | word_idx += 1
358 | word = normalize_word(row[3], language)
359 |
360 | parse_piece = row[5]
361 | pos_tag = row[4]
362 | if pos_tag == "(":
363 | pos_tag = "-LRB-"
364 | if pos_tag == ")":
365 | pos_tag = "-RRB-"
366 |
367 | (left_brackets, right_hand_side) = parse_piece.split("*")
368 | right_brackets = right_hand_side.count(")") * ")"
369 |
370 | subtokens = tokenizer.tokenize(word)
371 | document_state.tokens.append(word)
372 | document_state.pos_tags.append(pos_tag)
373 |
374 | document_state.token_end += ([False] * (len(subtokens) - 1)) + [True]
375 | for sidx, subtoken in enumerate(subtokens):
376 | cur_sent_len += 1
377 | document_state.subtokens.append(subtoken)
378 | info = None if sidx != 0 else (row + [len(subtokens)])
379 | document_state.info.append(info)
380 | document_state.sentence_end.append(False)
381 | document_state.subtoken_map.append(word_idx)
382 | new_core = " ".join(subtokens)
383 |
384 | parse_piece = f"{left_brackets} {new_core} {right_brackets}"
385 | # parse_piece = f"{left_brackets} {new_core} {right_brackets}"
386 | parse_pieces.append(parse_piece)
387 | else:
388 | if added_speaker_head:
389 | parse_pieces.append(")")
390 | added_speaker_head = False
391 |
392 | parse_tree = Tree.fromstring("".join(parse_pieces))
393 | chunk_dict = get_chunks_from_tree(parse_tree, ["NP"])
394 | constituent_dict = recur_list_tree(parse_tree)
395 |
396 | coreferable_spans_dict = keep_coreferable(parse_tree)[1]
397 |
398 | document_state.coreferables += [[x[0]+cur_sent_offset,x[1]+cur_sent_offset] for x,y in coreferable_spans_dict.items()]
399 | document_state.constituents += [[[x[0]+cur_sent_offset,x[1]+cur_sent_offset],y] for x,y in constituent_dict.items()]
400 | sent_chunk_tags = ["O" for i in range(cur_sent_len)]
401 | if pos_tag==".":
402 | sent_chunk_tags[-1] = "."
403 | for (chunk_l, chunk_r), chunk_tag in chunk_dict.items():
404 | chunk_len = chunk_r - chunk_l + 1
405 | if chunk_len == 1:
406 | sent_chunk_tags[chunk_r] = "U-" + chunk_tag
407 | else:
408 | sent_chunk_tags[chunk_l] = "B-" + chunk_tag
409 | sent_chunk_tags[chunk_r] = "L-" + chunk_tag
410 | for i in range(chunk_l + 1, chunk_r):
411 | sent_chunk_tags[i] = "I-" + chunk_tag
412 | document_state.sentence_end[-1] = True
413 | cur_sent_offset += cur_sent_len
414 | cur_sent_len, parse_pieces = 0, []
415 | document_state.chunk_tags += sent_chunk_tags
416 |
417 | constraints1 = document_state.sentence_end if language != 'arabic' else document_state.token_end
418 | index_mapping_dict = split_into_segments(document_state, segment_len, constraints1, document_state.token_end)
419 |
420 | for x in document_state.constituents:
421 | x[0][0] = index_mapping_dict[x[0][0]]
422 | x[0][1] = index_mapping_dict[x[0][1]]
423 | for x in document_state.coreferables:
424 | x[0] = index_mapping_dict[x[0]]
425 | x[1] = index_mapping_dict[x[1]]
426 |
427 | stats["max_sent_len_{}".format(language)] = max(max([len(s) for s in document_state.segments]), stats["max_sent_len_{}".format(language)])
428 | document = document_state.finalize()
429 | return document
430 |
431 | def skip(doc_key):
432 | # if doc_key in ['nw/xinhua/00/chtb_0078_0', 'wb/eng/00/eng_0004_1']: #, 'nw/xinhua/01/chtb_0194_0', 'nw/xinhua/01/chtb_0157_0']:
433 | # return True
434 | return False
435 |
436 | def minimize_partition(name, language, extension, labels, stats, tokenizer, seg_len, input_dir, output_dir):
437 | input_path = "{}/{}.{}.{}".format(input_dir, name, language, extension)
438 | output_path = "{}/{}.{}.{}.jsonlines".format(output_dir, name, language, seg_len)
439 | count = 0
440 | print("Minimizing {}".format(input_path))
441 | documents = []
442 | with open(input_path, "r") as input_file:
443 | for line in input_file.readlines():
444 | begin_document_match = re.match(conll.BEGIN_DOCUMENT_REGEX, line)
445 | if begin_document_match:
446 | doc_key = conll.get_doc_key(begin_document_match.group(1), begin_document_match.group(2))
447 | documents.append((doc_key, []))
448 | elif line.startswith("#end document"):
449 | continue
450 | else:
451 | documents[-1][1].append(line)
452 | num_coreferables, num_words = 0, 0
453 | with open(output_path, "w") as output_file:
454 | for document_lines in documents:
455 | if skip(document_lines[0]):
456 | continue
457 | document = get_document(document_lines, tokenizer, language, seg_len)
458 |
459 | # num_coreferables += len(document["coreferables"])
460 | num_words += len(flatten(document["sentences"]))
461 |
462 | output_file.write(json.dumps(document))
463 | output_file.write("\n")
464 | count += 1
465 | print("Wrote {} documents to {}, with {} words".format(count, output_path, num_words))
466 |
467 | def minimize_language(language, labels, stats, seg_len, input_dir, output_dir, do_lower_case):
468 | tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
469 | # tokenizer = AutoTokenizer.from_pretrained("allenai/longformer-large-4096")
470 |
471 | minimize_partition("dev", language, "v4_gold_conll", labels, stats, tokenizer, seg_len, input_dir, output_dir)
472 | minimize_partition("train", language, "v4_gold_conll", labels, stats, tokenizer, seg_len, input_dir, output_dir)
473 | minimize_partition("test", language, "v4_gold_conll", labels, stats, tokenizer, seg_len, input_dir, output_dir)
474 |
475 |
476 | def flatten(l):
477 | return [item for sublist in l for item in sublist]
478 |
479 | # Usage: python minimize.py ./data_dir/ ./data_dir/ontonotes_speaker_encoding/ false
480 | if __name__ == "__main__":
481 | input_dir = sys.argv[1]
482 | output_dir = sys.argv[2]
483 | do_lower_case = (sys.argv[3].lower() == 'true')
484 |
485 | print("do_lower_case", do_lower_case)
486 | labels = collections.defaultdict(set)
487 | stats = collections.defaultdict(int)
488 |
489 | if not os.path.isdir(output_dir):
490 | os.mkdir(output_dir)
491 | for seg_len in [384, 512]:
492 | minimize_language("english", labels, stats, seg_len, input_dir, output_dir, do_lower_case)
493 | for k, v in labels.items():
494 | print("{} = [{}]".format(k, ", ".join("\"{}\"".format(label) for label in v)))
495 | for k, v in stats.items():
496 | print("{} = {}".format(k, v))
497 |
498 |
--------------------------------------------------------------------------------
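
A minimal sketch (not part of the repository) of inspecting the jsonlines files that
minimize.py writes; the directory below is hypothetical, and the fields are the ones
assembled in DocumentState.finalize():

import json

# one JSON document per line, as written by minimize_partition
with open("./data_dir/out/dev.english.384.jsonlines") as f:
    for line in f:
        doc = json.loads(line)
        print(doc["doc_key"], len(doc["sentences"]), len(doc["clusters"]))
        break
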
/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.init as init
4 | import torch.nn.functional as F
5 |
6 | from bert_modelling import BertModel
7 | from transformers import BertTokenizer
8 |
9 | import logging
10 | import numpy as np
11 | from collections import defaultdict
12 | from collections.abc import Iterable
13 | from outside_mp import CFGMentionProposer
14 | from greedy_mp import GreedyMentionProposer
15 |
16 | from util import logsumexp, log1mexp, batch_select, bucket_distance
17 |
18 |
19 | logging.basicConfig(
20 | format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
21 | datefmt='%m/%d/%Y %H:%M:%S',
22 | level=logging.INFO
23 | )
24 | logger = logging.getLogger()
25 |
26 |
27 | class CorefModel(torch.nn.Module):
28 | tz = BertTokenizer.from_pretrained("bert-base-cased")
29 | def __init__(self, config, device, num_genres=None):
30 | super().__init__()
31 | self.config = config
32 | self.device = device
33 |
34 | self.cls_id = self.tz.convert_tokens_to_ids("[CLS]")
35 | self.sep_id = self.tz.convert_tokens_to_ids("[SEP]")
36 |
37 | self.num_genres = num_genres if num_genres else len(config['genres'])
38 | self.max_seg_len = config['max_segment_len']
39 | self.max_span_width = config['max_span_width']
40 |
41 | # Model
42 | self.dropout = nn.Dropout(p=config['dropout_rate'])
43 | self.bert = BertModel.from_pretrained(config['bert_pretrained_name_or_path'])
44 |
45 | self.bert_emb_size = self.bert.config.hidden_size
46 | self.span_emb_size = self.bert_emb_size * 3
47 | if config['use_features']:
48 | self.span_emb_size += config['feature_emb_size']
49 |
50 | self.pair_emb_size = self.span_emb_size * 3
51 | if config['use_metadata']:
52 | self.pair_emb_size += 2 * config['feature_emb_size']
53 | if config['use_features']:
54 | self.pair_emb_size += config['feature_emb_size']
55 | if config['use_segment_distance']:
56 | self.pair_emb_size += config['feature_emb_size']
57 |
58 | if config['mention_proposer'].lower() == "outside":
59 | self.mention_proposer = CFGMentionProposer(max_span_width=self.max_span_width, neg_sample_rate=config['neg_sample_rate'])
60 | elif config['mention_proposer'].lower() == "greedy":
61 | self.mention_proposer = GreedyMentionProposer()
62 |
63 | self.emb_span_width = self.make_embedding(self.max_span_width) if config['use_features'] else None
64 | self.emb_span_width_prior = self.make_embedding(self.max_span_width) if config['use_width_prior'] else None
65 | self.emb_antecedent_distance_prior = self.make_embedding(10) if config['use_distance_prior'] else None
66 | self.emb_genre = self.make_embedding(self.num_genres)
67 | self.emb_same_speaker = self.make_embedding(2) if config['use_metadata'] else None
68 | self.emb_segment_distance = self.make_embedding(config['max_training_sentences']) if config['use_segment_distance'] else None
69 | self.emb_top_antecedent_distance = self.make_embedding(10)
70 |
71 | self.mention_token_attn = self.make_ffnn(self.bert_emb_size, 0, output_size=1) if config['model_heads'] else None
72 | if type(self.mention_proposer) == CFGMentionProposer:
73 | self.span_emb_score_ffnn = self.make_ffnn(self.span_emb_size, [config['ffnn_size']] * config['ffnn_depth'], output_size=2)
74 | elif type(self.mention_proposer) == GreedyMentionProposer:
75 | self.span_emb_score_ffnn = self.make_ffnn(self.span_emb_size, [config['ffnn_size']] * config['ffnn_depth'], output_size=1)
76 |
77 | self.span_width_score_ffnn = self.make_ffnn(config['feature_emb_size'], [config['ffnn_size']] * config['ffnn_depth'], output_size=1) if config['use_width_prior'] else None
78 | self.coarse_bilinear = self.make_ffnn(self.span_emb_size, 0, output_size=self.span_emb_size)
79 | self.antecedent_distance_score_ffnn = self.make_ffnn(config['feature_emb_size'], 0, output_size=1) if config['use_distance_prior'] else None
80 | self.coref_score_ffnn = self.make_ffnn(self.pair_emb_size, [config['ffnn_size']] * config['ffnn_depth'], output_size=1) if config['fine_grained'] else None
81 |
82 | self.all_words = 0
83 | self.all_pred_men = 0
84 | self.debug = False
85 |
86 |
87 | def make_embedding(self, dict_size, std=0.02):
88 | emb = nn.Embedding(dict_size, self.config['feature_emb_size'])
89 | init.normal_(emb.weight, std=std)
90 | return emb
91 |
92 | def make_linear(self, in_features, out_features, bias=True, std=0.02):
93 | linear = nn.Linear(in_features, out_features, bias)
94 | init.normal_(linear.weight, std=std)
95 | if bias:
96 | init.zeros_(linear.bias)
97 | return linear
98 |
99 | def make_ffnn(self, feat_size, hidden_size, output_size):
100 | if hidden_size is None or hidden_size == 0 or hidden_size == [] or hidden_size == [0]:
101 | return self.make_linear(feat_size, output_size)
102 |
103 | if not isinstance(hidden_size, Iterable):
104 | hidden_size = [hidden_size]
105 | ffnn = [self.make_linear(feat_size, hidden_size[0]), nn.ReLU(), self.dropout]
106 | for i in range(1, len(hidden_size)):
107 | ffnn += [self.make_linear(hidden_size[i-1], hidden_size[i]), nn.ReLU(), self.dropout]
108 | ffnn.append(self.make_linear(hidden_size[-1], output_size))
109 | return nn.Sequential(*ffnn)
110 |
111 | def get_params(self, named=False):
112 | bert_based_param, task_param = [], []
113 | for name, param in self.named_parameters():
114 | if name.startswith('bert') or name.startswith("mention_transformer"):
115 | to_add = (name, param) if named else param
116 | bert_based_param.append(to_add)
117 | else:
118 | to_add = (name, param) if named else param
119 | task_param.append(to_add)
120 | return bert_based_param, task_param
121 |
122 | def forward(self, *input):
123 | mention_doc = self.get_mention_doc(*input)
124 | return self.get_predictions_and_loss(mention_doc, *input)
125 |
126 | def get_flat_span_location_indices(self, spans, sentence_map):
127 | sentence_map = sentence_map.tolist()
128 | spans_list = spans.tolist()
129 | flat_span_location_indices = []
130 |
131 | prev_sent_id = sentence_map[0]
132 | sentence_lengths, cur_sent_len = [], 0
133 | for i in sentence_map:
134 | if prev_sent_id == i:
135 | cur_sent_len += 1
136 | else:
137 | sentence_lengths.append(cur_sent_len)
138 | cur_sent_len = 1
139 | prev_sent_id = i
140 |
141 | sentence_lengths.append(cur_sent_len)
142 | max_sentence_len = max(sentence_lengths)
143 |
144 | sentence_offsets = np.cumsum([0] + sentence_lengths)[:-1]
145 | for (start, end) in spans_list:
146 | sent_id = sentence_map[start] - sentence_map[0]
147 | offset = sentence_offsets[sent_id]
148 |
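    |             # Flatten (sentence, start, end) into a single index so span scores can later be
    |             # scattered with index_copy_. Worked example: with max_sentence_len = 10,
    |             # sent_id = 2, start - offset = 3 and end - offset = 5, flat_id = 2*100 + 3*10 + 5 = 235.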
149 | flat_id = sent_id * (max_sentence_len**2) + (start-offset)*max_sentence_len + (end-offset)
150 | flat_span_location_indices.append(flat_id)
151 |
152 | sentence_lengths = spans.new_tensor(sentence_lengths)
153 | flat_span_location_indices = spans.new_tensor(flat_span_location_indices)
154 | return flat_span_location_indices, sentence_lengths
155 |
156 | def get_mention_doc(self, input_ids, input_mask, speaker_ids, sentence_len, genre, sentence_map,
157 | is_training, gold_starts=None, gold_ends=None, gold_mention_cluster_map=None,
158 | coreferable_starts=None, coreferable_ends=None,
159 | constituent_starts=None, constituent_ends=None, constituent_type=None):
160 |
161 | mention_doc = self.bert(input_ids, attention_mask=input_mask) # [num seg, num max tokens, emb size]
162 | mention_doc = mention_doc["last_hidden_state"]
163 | input_mask = input_mask.bool()
164 | mention_doc = mention_doc[input_mask]
165 | return mention_doc
166 |
167 |
168 | def get_predictions_and_loss(
169 | self, mention_doc, input_ids, input_mask, speaker_ids, sentence_len, genre, sentence_map,
170 | is_training, gold_starts=None, gold_ends=None, gold_mention_cluster_map=None,
171 | coreferable_starts=None, coreferable_ends=None,
172 | constituent_starts=None, constituent_ends=None, constituent_type=None
173 | ):
174 | """ Model and input are already on the device """
175 | device = self.device
176 | conf = self.config
177 |
178 | do_loss = False
179 | if gold_mention_cluster_map is not None:
180 | assert gold_starts is not None
181 | assert gold_ends is not None
182 | do_loss = True
183 |
184 | input_mask = input_mask.bool()
185 | speaker_ids = speaker_ids[input_mask]
186 | num_words = mention_doc.shape[0]
187 |
188 | self.all_words += num_words
189 |
190 | # Get candidate span
191 | sentence_indices = sentence_map # [num tokens]
192 | candidate_starts = torch.unsqueeze(torch.arange(0, num_words, device=device), 1).repeat(1, self.max_span_width)
193 | candidate_ends = candidate_starts + torch.arange(0, self.max_span_width, device=device)
194 | candidate_start_sent_idx = sentence_indices[candidate_starts]
195 | candidate_end_sent_idx = sentence_indices[torch.min(candidate_ends, torch.tensor(num_words - 1, device=device))]
196 | candidate_mask = (candidate_ends < num_words) & (candidate_start_sent_idx == candidate_end_sent_idx)
197 | candidate_mask &= (input_ids[input_mask][candidate_starts] != self.cls_id)
198 | candidate_mask &= (input_ids[input_mask][torch.clamp(candidate_ends, max=num_words-1)] != self.sep_id)
199 |
200 | candidate_starts, candidate_ends = candidate_starts[candidate_mask], candidate_ends[candidate_mask] # [num valid candidates]
201 | num_candidates = candidate_starts.shape[0]
202 |
203 | candidate_labels = None
204 | non_dummy_indicator = None
205 | # Get candidate labels
206 | if do_loss:
207 | same_start = (torch.unsqueeze(gold_starts, 1) == torch.unsqueeze(candidate_starts, 0))
208 | same_end = (torch.unsqueeze(gold_ends, 1) == torch.unsqueeze(candidate_ends, 0))
209 | same_span = (same_start & same_end).long()
210 | candidate_labels = torch.matmul(gold_mention_cluster_map.unsqueeze(0).type_as(mention_doc), same_span.type_as(mention_doc))
211 | candidate_labels = candidate_labels.long().squeeze() # [num candidates]; non-gold span has label 0
212 |
213 |
214 | # Get span embedding
215 | span_start_emb, span_end_emb = mention_doc[candidate_starts], mention_doc[candidate_ends]
216 | # span_start_emb_1, span_end_emb_1 = mention_doc[candidate_starts], mention_doc[candidate_ends+1]
217 | # candidate_emb_list = [span_start_emb, span_end_emb]
218 | candidate_emb_list = [span_start_emb, span_end_emb]
219 | if conf['use_features']:
220 | candidate_width_idx = candidate_ends - candidate_starts
221 | candidate_width_emb = self.emb_span_width(candidate_width_idx)
222 | candidate_width_emb = self.dropout(candidate_width_emb)
223 | candidate_emb_list.append(candidate_width_emb)
224 | # Use attended head or avg token
225 | candidate_tokens = torch.unsqueeze(torch.arange(0, num_words, device=device), 0).repeat(num_candidates, 1)
226 | candidate_tokens_mask = (candidate_tokens >= torch.unsqueeze(candidate_starts, 1)) & (candidate_tokens <= torch.unsqueeze(candidate_ends, 1))
227 | if conf['model_heads']:
228 | token_attn = self.mention_token_attn(mention_doc).squeeze()
229 | else:
230 | token_attn = torch.ones(num_words, dtype=mention_doc.dtype, device=device) # Use avg if no attention
231 | candidate_tokens_attn_raw = candidate_tokens_mask.log() + token_attn.unsqueeze(0)
232 | candidate_tokens_attn = F.softmax(candidate_tokens_attn_raw, dim=1)
233 |
234 | head_attn_emb = torch.matmul(candidate_tokens_attn, mention_doc)
235 | candidate_emb_list.append(head_attn_emb)
236 | candidate_span_emb = torch.cat(candidate_emb_list, dim=-1) # [num candidates, new emb size]
237 |
238 | # Get span scores
239 | candidate_mention_scores_and_parsing = self.span_emb_score_ffnn(candidate_span_emb)
240 |
241 | if type(self.mention_proposer) == CFGMentionProposer:
242 | candidate_mention_scores, candidate_mention_parsing_scores = candidate_mention_scores_and_parsing.split(1, dim=-1)
243 | candidate_mention_scores = candidate_mention_scores.squeeze(1)
244 | elif type(self.mention_proposer) == GreedyMentionProposer:
245 | candidate_mention_scores = candidate_mention_scores_and_parsing.squeeze(-1)
246 | candidate_mention_parsing_scores = candidate_mention_scores
247 |
248 | if conf['use_width_prior']:
249 | width_score = self.span_width_score_ffnn(self.emb_span_width_prior.weight).squeeze(1)
250 | candidate_mention_scores = candidate_mention_scores + width_score[candidate_width_idx]
251 |
252 |
253 | spans = torch.stack([candidate_starts, candidate_ends], dim=-1)
254 | flat_span_location_indices, sentence_lengths = self.get_flat_span_location_indices(
255 | spans, sentence_map
256 | )
257 | num_top_spans = int(min(conf['max_num_extracted_spans'], conf['top_span_ratio'] * num_words))
258 | non_dummy_indicator = (candidate_labels > 0) if candidate_labels is not None else None
259 |
260 | if type(self.mention_proposer) == CFGMentionProposer:
261 | top_span_p_mention, selected_idx, top_spans, mp_loss, _ = self.mention_proposer(
262 | spans,
263 | candidate_mention_parsing_scores,
264 | candidate_mask[candidate_mask],
265 | non_dummy_indicator if non_dummy_indicator is not None else None,
266 | sentence_lengths,
267 | num_top_spans,
268 | flat_span_location_indices,
269 | )
270 | top_span_log_p_mention = top_span_p_mention.log()
271 | top_span_log_p_mention = top_span_log_p_mention[selected_idx]
272 |
273 | elif type(self.mention_proposer) == GreedyMentionProposer:
274 | _, selected_idx, top_spans, mp_loss, _ = self.mention_proposer(
275 | spans,
276 | candidate_mention_parsing_scores,
277 | candidate_mask[candidate_mask],
278 | sentence_lengths,
279 | num_top_spans,
280 | )
281 |
282 | num_top_spans = selected_idx.size(0)
283 |
284 | top_span_starts, top_span_ends = candidate_starts[selected_idx], candidate_ends[selected_idx]
285 | top_span_emb = candidate_span_emb[selected_idx]
286 | top_span_cluster_ids = candidate_labels[selected_idx] if do_loss else None
287 | top_span_mention_scores = candidate_mention_scores[selected_idx]
288 |
289 | # Coarse pruning on each mention's antecedents
290 | max_top_antecedents = min(num_top_spans, conf['max_top_antecedents'])
291 | top_span_range = torch.arange(0, num_top_spans, device=device)
292 | antecedent_offsets = torch.unsqueeze(top_span_range, 1) - torch.unsqueeze(top_span_range, 0)
293 | antecedent_mask = (antecedent_offsets >= 1)
294 | pairwise_mention_score_sum = torch.unsqueeze(top_span_mention_scores, 1) + torch.unsqueeze(top_span_mention_scores, 0)
295 | source_span_emb = self.dropout(self.coarse_bilinear(top_span_emb))
296 | target_span_emb = self.dropout(torch.transpose(top_span_emb, 0, 1))
297 | pairwise_coref_scores = torch.matmul(source_span_emb, target_span_emb)
298 |
299 | pairwise_fast_scores = pairwise_mention_score_sum + pairwise_coref_scores
300 | pairwise_fast_scores += antecedent_mask.type_as(mention_doc).log()
301 | if conf['use_distance_prior']:
302 | distance_score = torch.squeeze(self.antecedent_distance_score_ffnn(self.dropout(self.emb_antecedent_distance_prior.weight)), 1)
303 | bucketed_distance = bucket_distance(antecedent_offsets)
304 | antecedent_distance_score = distance_score[bucketed_distance]
305 | pairwise_fast_scores += antecedent_distance_score
306 | # Slow mention ranking
307 | if conf['fine_grained']:
308 | top_pairwise_fast_scores, top_antecedent_idx = torch.topk(pairwise_fast_scores, k=max_top_antecedents)
309 | top_antecedent_mask = batch_select(antecedent_mask, top_antecedent_idx, device) # [num top spans, max top antecedents]
310 | top_antecedent_offsets = batch_select(antecedent_offsets, top_antecedent_idx, device)
311 |
312 | same_speaker_emb, genre_emb, seg_distance_emb, top_antecedent_distance_emb = None, None, None, None
313 | if conf['use_metadata']:
314 | top_span_speaker_ids = speaker_ids[top_span_starts]
315 | top_antecedent_speaker_id = top_span_speaker_ids[top_antecedent_idx]
316 | same_speaker = torch.unsqueeze(top_span_speaker_ids, 1) == top_antecedent_speaker_id
317 | same_speaker_emb = self.emb_same_speaker(same_speaker.long())
318 | genre_emb = self.emb_genre(genre)
319 | genre_emb = torch.unsqueeze(torch.unsqueeze(genre_emb, 0), 0).repeat(num_top_spans, max_top_antecedents, 1)
320 | if conf['use_segment_distance']:
321 | num_segs, seg_len = input_ids.shape[0], input_ids.shape[1]
322 | token_seg_ids = torch.arange(0, num_segs, device=device).unsqueeze(1).repeat(1, seg_len)
323 | token_seg_ids = token_seg_ids[input_mask]
324 | top_span_seg_ids = token_seg_ids[top_span_starts]
325 | top_antecedent_seg_ids = token_seg_ids[top_span_starts[top_antecedent_idx]]
326 | top_antecedent_seg_distance = torch.unsqueeze(top_span_seg_ids, 1) - top_antecedent_seg_ids
327 | top_antecedent_seg_distance = torch.clamp(top_antecedent_seg_distance, 0, self.config['max_training_sentences'] - 1)
328 | seg_distance_emb = self.emb_segment_distance(top_antecedent_seg_distance)
329 | if conf['use_features']: # Antecedent distance
330 | top_antecedent_distance = bucket_distance(top_antecedent_offsets)
331 | top_antecedent_distance_emb = self.emb_top_antecedent_distance(top_antecedent_distance)
332 |
333 | top_antecedent_emb = top_span_emb[top_antecedent_idx] # [num top spans, max top antecedents, emb size]
334 | feature_list = []
335 | if conf['use_metadata']: # speaker, genre
336 | feature_list.append(same_speaker_emb)
337 | feature_list.append(genre_emb)
338 | if conf['use_segment_distance']:
339 | feature_list.append(seg_distance_emb)
340 | if conf['use_features']: # Antecedent distance
341 | feature_list.append(top_antecedent_distance_emb)
342 | feature_emb = torch.cat(feature_list, dim=2)
343 | feature_emb = self.dropout(feature_emb)
344 | target_emb = torch.unsqueeze(top_span_emb, 1).repeat(1, max_top_antecedents, 1)
345 | # target_parent_emb = torch.unsqueeze(top_span_parent_emb, 1).repeat(1, max_top_antecedents, 1)
346 | similarity_emb = target_emb * top_antecedent_emb
347 | pair_emb = torch.cat([target_emb, top_antecedent_emb, similarity_emb, feature_emb], 2)
348 | top_pairwise_slow_scores = self.coref_score_ffnn(pair_emb).squeeze(2)
349 | # print(pair_emb.size(), mention_doc.size(), pair_emb.size(0) / mention_doc.size()[0])
350 | top_pairwise_scores = top_pairwise_slow_scores + top_pairwise_fast_scores
351 | else:
352 | top_pairwise_fast_scores, top_antecedent_idx = torch.topk(pairwise_fast_scores, k=pairwise_fast_scores.size(0))
353 | top_antecedent_mask = batch_select(antecedent_mask, top_antecedent_idx, device) # [num top spans, max top antecedents]
354 | top_antecedent_offsets = batch_select(antecedent_offsets, top_antecedent_idx, device)
355 |
356 | top_pairwise_scores = top_pairwise_fast_scores # [num top spans, max top antecedents]
357 |
358 | top_antecedent_scores = torch.cat([torch.zeros(num_top_spans, 1, device=device), top_pairwise_scores], dim=1)
359 |
360 |
361 | if not do_loss:
362 | if type(self.mention_proposer) == CFGMentionProposer or self.config["mention_sigmoid"]:
363 | top_antecedent_log_p_mention = top_span_log_p_mention[top_antecedent_idx]
364 | log_norm = logsumexp(top_antecedent_scores, dim=1)
365 | # Shape: (num_spans_to_keep, max_antecedents+1)
366 | log_p_im = top_antecedent_scores - log_norm.unsqueeze(-1) + top_span_log_p_mention.unsqueeze(-1)
367 | # Shape: (num_spans_to_keep)
368 | log_p_em = torch.logaddexp(
369 | log1mexp(top_span_log_p_mention),
370 | top_span_log_p_mention - log_norm + torch.finfo(log_norm.dtype).eps
371 | )
372 | # log probability for inference
373 | log_probs = torch.cat([log_p_em.unsqueeze(-1), log_p_im[:,1:]], dim=-1)
374 |
375 | return candidate_starts, candidate_ends, candidate_mention_parsing_scores, top_span_starts, top_span_ends, top_antecedent_idx, log_probs
376 | elif type(self.mention_proposer) == GreedyMentionProposer:
377 | return candidate_starts, candidate_ends, candidate_mention_parsing_scores, top_span_starts, top_span_ends, top_antecedent_idx, top_antecedent_scores
378 |
379 | log_norm = logsumexp(top_antecedent_scores, dim=1)
380 | if type(self.mention_proposer) == CFGMentionProposer or self.config["mention_sigmoid"]:
381 | top_antecedent_log_p_mention = top_span_log_p_mention[top_antecedent_idx]
382 | # Shape: (num_spans_to_keep, max_antecedents+1)
383 | log_p_im = top_antecedent_scores - log_norm.unsqueeze(-1) + top_span_log_p_mention.unsqueeze(-1)
384 | # Shape: (num_spans_to_keep)
385 | log_p_em = torch.logaddexp(
386 | log1mexp(top_span_log_p_mention) + torch.finfo(log_norm.dtype).eps,
387 | top_span_log_p_mention - log_norm + torch.finfo(log_norm.dtype).eps
388 | )
389 | # log probability for inference
390 | log_probs = torch.cat([log_p_em.unsqueeze(-1), log_p_im[:,1:]], dim=-1)
391 |
392 | # Get gold labels
393 | top_antecedent_cluster_ids = top_span_cluster_ids[top_antecedent_idx]
394 | top_antecedent_cluster_ids += (top_antecedent_mask.long() - 1) * 100000 # Mask id on invalid antecedents
395 | same_gold_cluster_indicator = (top_antecedent_cluster_ids == torch.unsqueeze(top_span_cluster_ids, 1))
396 | non_dummy_indicator = non_dummy_indicator[selected_idx] # (top_span_cluster_ids > 0).squeeze()
397 | # non_dummy_indicator is the coreferable flags
398 | pairwise_labels = same_gold_cluster_indicator & torch.unsqueeze(top_span_cluster_ids > 0, 1)
399 | dummy_antecedent_labels = torch.logical_not(pairwise_labels.any(dim=1, keepdims=True))
400 |
401 | top_antecedent_gold_labels = torch.cat([dummy_antecedent_labels, pairwise_labels], dim=1)
402 |
403 | # Get loss
404 | if type(self.mention_proposer) == CFGMentionProposer or self.config["mention_sigmoid"]:
405 | coref_loss = -logsumexp(log_p_im + top_antecedent_gold_labels.log(), dim=-1) # for mentions
406 | loss = mp_loss + (coref_loss * non_dummy_indicator).sum() + (-log_p_em * torch.logical_not(non_dummy_indicator)).sum()
407 | return [candidate_starts, candidate_ends, candidate_mention_parsing_scores, top_span_starts, top_span_ends, top_antecedent_idx, log_probs], loss
408 | elif type(self.mention_proposer) == GreedyMentionProposer:
409 | log_marginalized_antecedent_scores = logsumexp(top_antecedent_scores + top_antecedent_gold_labels.log(), dim=1)
410 | loss = (log_norm - log_marginalized_antecedent_scores).sum()
411 |
412 | return [candidate_starts, candidate_ends, candidate_mention_parsing_scores, top_span_starts, top_span_ends, top_antecedent_idx, top_antecedent_scores], loss
413 |
414 | def _extract_top_spans(self, candidate_idx_sorted, candidate_starts, candidate_ends, num_top_spans):
415 | """ Keep top non-cross-overlapping candidates ordered by scores; compute on CPU because of loop """
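    |         # E.g. if candidates (0,2), (1,3), (1,2) arrive in score order, (1,3) is rejected
    |         # because it crosses the already-selected (0,2), while the nested (1,2) is kept.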
416 | selected_candidate_idx = []
417 | start_to_max_end, end_to_min_start = {}, {}
418 | for candidate_idx in candidate_idx_sorted:
419 | if len(selected_candidate_idx) >= num_top_spans:
420 | break
421 | # Perform overlapping check
422 | span_start_idx = candidate_starts[candidate_idx]
423 | span_end_idx = candidate_ends[candidate_idx]
424 | cross_overlap = False
425 | for token_idx in range(span_start_idx, span_end_idx + 1):
426 | max_end = start_to_max_end.get(token_idx, -1)
427 | if token_idx > span_start_idx and max_end > span_end_idx:
428 | cross_overlap = True
429 | break
430 | min_start = end_to_min_start.get(token_idx, -1)
431 | if token_idx < span_end_idx and 0 <= min_start < span_start_idx:
432 | cross_overlap = True
433 | break
434 | if not cross_overlap:
435 | # Pass check; select idx and update dict stats
436 | selected_candidate_idx.append(candidate_idx)
437 | max_end = start_to_max_end.get(span_start_idx, -1)
438 | if span_end_idx > max_end:
439 | start_to_max_end[span_start_idx] = span_end_idx
440 | min_start = end_to_min_start.get(span_end_idx, -1)
441 | if min_start == -1 or span_start_idx < min_start:
442 | end_to_min_start[span_end_idx] = span_start_idx
443 | # Sort selected candidates by span idx
444 | selected_candidate_idx = sorted(selected_candidate_idx, key=lambda idx: (candidate_starts[idx], candidate_ends[idx]))
445 | if len(selected_candidate_idx) < num_top_spans: # Padding
446 | selected_candidate_idx += ([selected_candidate_idx[0]] * (num_top_spans - len(selected_candidate_idx)))
447 | return selected_candidate_idx
448 |
449 | def get_predicted_antecedents(self, antecedent_idx, antecedent_scores):
450 | """ CPU list input """
451 | predicted_antecedents = []
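    |         # -1: the dummy antecedent won, so no link is added for this span;
    |         # -2: the argmax points past the candidate antecedent list, and
    |         #     get_predicted_clusters keeps the span as a new singleton cluster.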
452 | for i, idx in enumerate((antecedent_scores.argmax(dim=1) - 1).tolist()):
453 | if idx < 0:
454 | predicted_antecedents.append(-1)
455 | elif idx >= len(antecedent_idx[0]):
456 | predicted_antecedents.append(-2)
457 | else:
458 | predicted_antecedents.append(antecedent_idx[i][idx])
459 | return predicted_antecedents
460 |
461 | def get_predicted_clusters(self, span_starts, span_ends, antecedent_idx, antecedent_scores):
462 | """ CPU list input """
463 | # Get predicted antecedents
464 | predicted_antecedents = self.get_predicted_antecedents(antecedent_idx, antecedent_scores)
465 |
466 | # Get predicted clusters
467 | mention_to_cluster_id = {}
468 | predicted_clusters = []
469 | for i, predicted_idx in enumerate(predicted_antecedents):
470 | if predicted_idx == -1:
471 | continue
472 | elif predicted_idx == -2:
473 | cluster_id = len(predicted_clusters)
474 | predicted_clusters.append([(int(span_starts[i]), int(span_ends[i]))])
475 | mention_to_cluster_id[(int(span_starts[i]), int(span_ends[i]))] = cluster_id
476 | continue
477 | assert i > predicted_idx, f'span idx: {i}; antecedent idx: {predicted_idx}'
478 | # Check antecedent's cluster
479 | antecedent = (int(span_starts[predicted_idx]), int(span_ends[predicted_idx]))
480 | antecedent_cluster_id = mention_to_cluster_id.get(antecedent, -1)
481 |
482 | if antecedent_cluster_id == -1:
483 | antecedent_cluster_id = len(predicted_clusters)
484 | predicted_clusters.append([antecedent])
485 | mention_to_cluster_id[antecedent] = antecedent_cluster_id
486 |
487 | # Add mention to cluster
488 | mention = (int(span_starts[i]), int(span_ends[i]))
489 | predicted_clusters[antecedent_cluster_id].append(mention)
490 | mention_to_cluster_id[mention] = antecedent_cluster_id
491 |
492 | predicted_clusters = [tuple(c) for c in predicted_clusters]
493 | return predicted_clusters, mention_to_cluster_id, predicted_antecedents
494 |
495 | def update_evaluator(self, span_starts, span_ends, antecedent_idx, antecedent_scores, gold_clusters, evaluator):
496 | predicted_clusters, mention_to_cluster_id, _ = self.get_predicted_clusters(span_starts, span_ends, antecedent_idx, antecedent_scores)
497 | mention_to_predicted = {m: predicted_clusters[cluster_idx] for m, cluster_idx in mention_to_cluster_id.items()}
498 | gold_clusters = [tuple(tuple(m) for m in cluster) for cluster in gold_clusters]
499 | mention_to_gold = {m: cluster for cluster in gold_clusters for m in cluster}
500 |
501 | # gold mentions
502 | gms = set([x for cluster in gold_clusters for x in cluster])
503 |         # gather meta information for each gold mention, e.g. nesting depth and width
504 | metainfo_gms = defaultdict(lambda: defaultdict(int))
505 | for x in gms:
506 | metainfo_gms[x]["width"] = x[1] - x[0]
507 | for y in gms:
508 | if y[0] <= x[0] and y[1] >= x[1]:
509 | metainfo_gms[x]["depth"] += 1
510 |
511 |
512 | recalled_gms = set([(int(x), int(y)) for x,y in zip(span_starts, span_ends)])
513 | self.all_pred_men += len(recalled_gms)
514 |
515 | evaluator.update(
516 | predicted_clusters, gold_clusters, mention_to_predicted, mention_to_gold,
517 | metainfo_gms, recalled_gms
518 | )
519 | return predicted_clusters
520 |
521 |
--------------------------------------------------------------------------------
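
The inference/loss branches above mix the structured mention probability with the softmax
over antecedent scores. Below is a minimal, standalone numerical sketch of that mixing
(values are made up, and log1mexp is a naive local stand-in for util.log1mexp):

import torch

def log1mexp(x):
    # naive stand-in: log(1 - exp(x)) for x < 0
    return torch.log1p(-torch.exp(x))

log_p_mention = torch.tensor(0.9).log()        # P(span is a mention)
scores = torch.tensor([0.0, 2.0, -1.0])        # dummy + two antecedent scores
log_norm = torch.logsumexp(scores, dim=0)
log_p_im = scores - log_norm + log_p_mention   # log P(is mention, antecedent j)
log_p_em = torch.logaddexp(log1mexp(log_p_mention), log_p_mention - log_norm)
# the "no coreference" mass and the per-antecedent masses form a distribution
print(torch.cat([log_p_em.view(1), log_p_im[1:]]).exp().sum())  # ~1.0
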
/outside_mp.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import math
3 | from typing import Any, Dict, List, Tuple
4 | import numpy as np
5 |
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | from torch.nn.parallel import DataParallel
10 |
11 | from util import logsumexp, clip_to_01, stripe, masked_topk_non_overlap
12 |
13 | from genbmm import logbmm, logbmminside_rule
14 |
15 | logger = logging.getLogger(__name__)
16 | LARGENUMBER = 1e4
17 |
18 | class CKY(torch.nn.Module):
19 | def __init__(
20 | self,
21 | max_span_width=30,
22 | ):
23 | super().__init__()
24 | self.max_span_width = max_span_width
25 |
26 | def forward(
27 | self,
28 | span_mention_score_matrix: torch.FloatTensor,
29 | sequence_lengths: torch.IntTensor,
30 | ) -> Tuple[torch.FloatTensor]:
31 |
32 | with torch.autograd.enable_grad():
33 | # Enable autograd during inference
34 | # For outside value computation
35 | return self.coolio(span_mention_score_matrix, sequence_lengths)
36 |
37 | def coolio(
38 | self,
39 | span_mention_score_matrix: torch.FloatTensor,
40 | sequence_lengths: torch.IntTensor,
41 | ) -> Tuple[torch.FloatTensor]:
42 | """
43 | Parameters:
44 | span_mention_score_matrix: shape (batch_size, sent_len, sent_len, score_dim)
45 |                 Score of each span being a span of interest. The batch dimension
46 |                 indexes the batch_size sentences of the document, and sent_len is
47 |                 the length of the longest sentence.
48 | sequence_lengths: shape (batch_size, )
49 | The actual length of each sentence.
50 | """
51 | # faster inside-outside
52 | # requiring grad for outside algorithm
53 | # https://www.cs.jhu.edu/~jason/papers/eisner.spnlp16.pdf
54 | span_mention_score_matrix.requires_grad_(True)
55 |
56 | batch_size, _, _, score_dim = span_mention_score_matrix.size()
57 | seq_len = sequence_lengths.max()
58 | # Shape: (batch_size, )
59 | sequence_lengths = sequence_lengths.view(-1)
60 |
61 | rules = span_mention_score_matrix
62 | # distributive law: log(exp(s+r) + exp(s)) == s + log(exp(r) + 1)
63 | log1p_exp_rules = torch.log1p(rules.squeeze(-1).exp())
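    |         # numeric check of the identity: with s = 1.0 and r = 2.0,
    |         # log(exp(1 + 2) + exp(1)) ~= 3.127 == 1.0 + log1p(exp(2.0))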
64 |
65 | zero_rules = (rules.new_ones(seq_len, seq_len).tril(diagonal=-1))*(-LARGENUMBER)
66 | zero_rules = zero_rules.unsqueeze(0).unsqueeze(-1).repeat(batch_size,1,1,1)
67 |
68 | inside_s = torch.cat([rules.clone(), zero_rules], dim=3)
69 | inside_s = inside_s.logsumexp(dim=3)
70 |
71 | for width in range(0, seq_len-1):
72 | # Usage: https://github.com/lyutyuh/genbmm
73 | inside_s = logbmminside_rule(inside_s, log1p_exp_rules, width+1)
74 |
75 | series_batchsize = torch.arange(0, batch_size, dtype=torch.long)
76 | Z = inside_s[series_batchsize, 0, sequence_lengths-1] # (batch_size, )
77 |
78 | marginal = torch.autograd.grad(
79 | Z.sum(),
80 | span_mention_score_matrix,
81 | create_graph=True,
82 | only_inputs=True,
83 | allow_unused=False,
84 | )
85 | marginal = marginal[0].squeeze()
86 | return (Z.view(-1), marginal) # Shape: (batch_size, seq_len, seq_len, )
87 |
88 | def io(
89 | self,
90 | span_mention_score_matrix: torch.FloatTensor,
91 | sequence_lengths: torch.IntTensor,
92 | ) -> Tuple[torch.FloatTensor]:
93 | """
94 | Parameters:
95 | span_mention_score_matrix: shape (batch_size, sent_len, sent_len, score_dim)
96 |                 Score of each span being a span of interest. The batch dimension
97 |                 indexes the batch_size sentences of the document, and sent_len is
98 |                 the length of the longest sentence.
99 | sequence_lengths: shape (batch_size, )
100 | The actual length of each sentence.
101 | """
102 | # inside-outside
103 | span_mention_score_matrix.requires_grad_(True)
104 |
105 | batch_size, _, _, score_dim = span_mention_score_matrix.size()
106 | seq_len = sequence_lengths.max()
107 | # Shape: (batch_size, )
108 | sequence_lengths = sequence_lengths.view(-1)
109 |
110 | # Shape: (seq_len, seq_len, score_dim, batch_size)
111 | span_mention_score_matrix = span_mention_score_matrix.permute(1, 2, 3, 0)
112 |
113 | # There should be another matrix of non-mention span scores, which is full of 0s
114 | # Shape: (seq_len, seq_len, score_dim + 1, batch_size), 2 for mention / non-mention
115 | inside_s = span_mention_score_matrix.new_zeros(seq_len, seq_len, score_dim + 1, batch_size)
116 |
117 | for width in range(0, seq_len):
118 | n = seq_len - width
119 | if width == 0:
120 | inside_s[:,:,:score_dim,:].diagonal(width).copy_(
121 | span_mention_score_matrix.diagonal(width)
122 | )
123 | continue
124 | # [n, width, score_dim + 1, batch_size]
125 | split_1 = stripe(inside_s, n, width)
126 | split_2 = stripe(inside_s, n, width, (1, width), 0)
127 |
128 | # [n, width, batch_size]
129 | inside_s_span = logsumexp(split_1, 2) + logsumexp(split_2, 2)
130 | # [1, batch_size, n]
131 | inside_s_span = logsumexp(inside_s_span, 1, keepdim=True).permute(1, 2, 0)
132 |
133 | if width < self.max_span_width:
134 | inside_s.diagonal(width).copy_(
135 | torch.cat(
136 | [inside_s_span + span_mention_score_matrix.diagonal(width), # mention
137 | inside_s_span], # non-mention
138 | dim=0
139 | )
140 | )
141 | else:
142 | inside_s.diagonal(width).copy_(
143 | torch.cat(
144 | [torch.full_like(span_mention_score_matrix.diagonal(width), -LARGENUMBER), # mention
145 | inside_s_span], # non-mention
146 | dim=0
147 | )
148 | )
149 |
150 | inside_s = inside_s.permute(0,1,3,2) # (seq_len, seq_len, batch_size, 2), 2 for mention / non-mention
151 | series_batchsize = torch.arange(0, batch_size, dtype=torch.long)
152 |
153 | Z = logsumexp(inside_s[0, sequence_lengths-1, series_batchsize], dim=-1) # (batch_size,)
154 |
155 | marginal = torch.autograd.grad(
156 | Z.sum(),
157 | span_mention_score_matrix,
158 | create_graph=True,
159 | only_inputs=True,
160 | allow_unused=False,
161 | )
162 | marginal = marginal[0].squeeze()
163 | return (Z.view(-1), marginal.permute(2,0,1)) # Shape: (batch_size, seq_len, seq_len, )
164 |
165 | @staticmethod
166 | def viterbi(
167 | span_mention_score_matrix: torch.FloatTensor,
168 | sequence_lengths: torch.IntTensor,
169 | ) -> Tuple[torch.FloatTensor]:
170 |
171 | if len(span_mention_score_matrix.size()) == 4:
172 | span_mention_score_matrix, _ = span_mention_score_matrix.max(-1)
173 | # Shape: (seq_len, seq_len, batch_size)
174 | span_mention_score_matrix = span_mention_score_matrix.permute(1, 2, 0)
175 |
176 |
177 | # Shape: (batch_size, )
178 | sequence_lengths = sequence_lengths.view(-1)
179 | # There should be another matrix of non-mention span scores, which is full of 0s
180 | seq_len, _, batch_size = span_mention_score_matrix.size()
181 |
182 | s = span_mention_score_matrix.new_zeros(seq_len, seq_len, 2, batch_size)
183 | p = sequence_lengths.new_zeros(seq_len, seq_len, 2, batch_size) # backtrack
184 |
185 | for width in range(0, seq_len):
186 | n = seq_len - width
187 | span_score = span_mention_score_matrix.diagonal(width)
188 | if width == 0:
189 | s.diagonal(0)[0, :].copy_(span_score)
190 | continue
191 | # [n, width, 2, 1, batch_size]
192 | split1 = stripe(s, n, width, ).unsqueeze(3)
193 | # [n, width, 1, 2, batch_size]
194 | split2 = stripe(s, n, width, (1, width), 0).unsqueeze(2)
195 |
196 | # [n, width, 2, 2, batch_size]
197 | s_span = split1 + split2
198 | # [batch_size, n, 2, width, 2, 2]
199 | s_span = s_span.permute(4, 0, 1, 2, 3).unsqueeze(2).repeat(1,1,2,1,1,1)
200 |
201 | s_span[:,:,0,:,:,:] += span_score.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
202 | # [batch_size, n, 2]
203 | s_span, p_span = s_span.view(batch_size, n, 2, -1).max(-1) # best split
204 | s.diagonal(width).copy_(
205 | s_span.permute(2, 0, 1)
206 | )
207 | starts = p.new_tensor(range(n)).unsqueeze(0).unsqueeze(0)
208 | p.diagonal(width).copy_(
209 | p_span.permute(2, 0, 1) + starts * 4
210 | )
211 |
212 | def backtrack(mat, p, i, j, c):
213 | if j == i:
214 | return p.new_tensor([(i, j, c)], dtype=torch.long)
215 | int_pijc = int(p[i][j][c])
216 | split = int_pijc // 4
217 | ltree = backtrack(mat, p, i, split, (int_pijc // 2) % 2)
218 | rtree = backtrack(mat, p, split+1, j, int_pijc % 2)
219 | return torch.cat([p.new_tensor([(i,j,c)], dtype=torch.long), ltree, rtree], dim=0)
220 |
221 | # (batch_size, seq_len, seq_len, 2)
222 | p = p.permute(3, 0, 1, 2)
223 |
224 | span_mention_score_matrix_cpu = span_mention_score_matrix.cpu()
225 | p_cpu = p.cpu()
226 |
227 | trees = [backtrack(s[:,:,:,i], p_cpu[i], 0, int(sequence_lengths[i]-1),
228 | (int) (s[0,int(sequence_lengths[i]-1),0,i] < s[0,int(sequence_lengths[i]-1),1,i]))
229 | for i in range(batch_size)]
230 |
231 | return trees
232 |
233 | def outside(
234 | self,
235 | inside_s: torch.FloatTensor,
236 | span_mention_score_matrix: torch.FloatTensor,
237 | sequence_lengths: torch.IntTensor,
238 | ):
239 | '''
240 | inside_s: Shape: (seq_len, seq_len, 2, batch_size), 2 for mention / non-mention
241 | span_mention_score_matrix: Shape: (seq_len, seq_len, batch_size)
242 |
243 | Return: outside_s : Shape: (seq_len, seq_len, 2, batch_size), 2 for mention / non-mention
244 | '''
245 | seq_len = sequence_lengths.max()
246 |
247 | _, _, batch_size = span_mention_score_matrix.size()
248 | series_batchsize = torch.arange(0, batch_size, dtype=torch.long)
249 |
250 | # Shape: (seq_len, seq_len, batch_size)
251 | # outside_s = span_mention_score_matrix.new_zeros(seq_len, seq_len, batch_size)
252 | outside_s = span_mention_score_matrix.new_full((seq_len, seq_len, batch_size), fill_value=-LARGENUMBER)
253 |
254 | mask_top = span_mention_score_matrix.new_zeros(seq_len, seq_len, batch_size).bool()
255 | mask_top[0, sequence_lengths-1, series_batchsize] = 1
256 |
257 | for width in range(seq_len-1, -1, -1):
258 | n = seq_len - width
259 | if width == seq_len-1:
260 | continue
261 | outside_s[mask_top] = 0
262 | # [n, n, 1, 2, batch_size]
263 |             split_1 = inside_s[:n-1, :n-1].unsqueeze(2) # using the upper triangular [:n-1, :n-1] of the inside matrix
264 | # [n, n, 2, 1, batch_size]
265 |             split_2 = outside_s[:n-1, width+1:seq_len].unsqueeze(2).unsqueeze(3).repeat(1,1,2,1,1) # using the submatrix [:n-1, width+1:seq_len] of the outside matrix
266 | # [n, n, 1, batch_size]
267 | span_score_submatrix = span_mention_score_matrix[:n-1, width+1:seq_len].unsqueeze(2)
268 | # [n, n, 1, 2, batch_size]
269 |             split_3 = inside_s[width+1:seq_len, width+1:seq_len].unsqueeze(2) # using the upper triangular [width+1:seq_len, width+1:seq_len] of the inside matrix
270 |
271 | # [n, n, 1, 1, 1]
272 | upp_triu_mask = torch.triu(span_score_submatrix.new_ones(n-1,n-1), diagonal=0).view(n-1,n-1,1,1,1)
273 |
274 | # [n, n, 2, 2, batch_size]
275 | # B -> CA, B, C, A \in {0,1}
276 | outside_s_span_1 = (split_1 + split_2)
277 | outside_s_span_1[:,:,0,:,:] += span_score_submatrix
278 |
279 | outside_s_span_1 += (upp_triu_mask*LARGENUMBER - LARGENUMBER) # upp_triu_mask.log() #
280 | # [batch_size, n, n, 2, 2]
281 | outside_s_span_1 = outside_s_span_1.permute(4, 1, 0, 2, 3)
282 | # [batch_size, n]
283 | outside_s_span_1 = logsumexp(outside_s_span_1.reshape(batch_size, n-1, -1), dim=-1)
284 | # outside_s_span_1.logsumexp((1,3,4)) # sum vertical, as right child
285 |
286 | # [n, n, 2, 2, batch_size]
287 | outside_s_span_2 = (split_3 + split_2)
288 | outside_s_span_2[:,:,0,:,:] += span_score_submatrix
289 |
290 | outside_s_span_2 += (upp_triu_mask*LARGENUMBER - LARGENUMBER) # upp_triu_mask.log() #
291 | # [batch_size, n, n, 2, 2]
292 | outside_s_span_2 = outside_s_span_2.permute(4, 0, 1, 2, 3)
293 | # [batch_size, n]
294 | outside_s_span_2 = logsumexp(outside_s_span_2.view(batch_size, n-1, -1), dim=-1) # sum horizontal, as left child
295 |
296 | # shift and sum
297 | outside_s_span_1 = torch.cat([outside_s_span_1.new_tensor([float(-LARGENUMBER)]*batch_size).unsqueeze(-1), outside_s_span_1], dim=-1)
298 | outside_s_span_2 = torch.cat([outside_s_span_2, outside_s_span_2.new_tensor([float(-LARGENUMBER)]*batch_size).unsqueeze(-1)], dim=-1)
299 |
300 | # [batch_size, n, 2]
301 | outside_s_span = torch.stack([outside_s_span_1, outside_s_span_2], dim=-1)
302 |
303 | # [batch_size, n]
304 | outside_s_span = logsumexp(outside_s_span, dim=-1)
305 |
306 | outside_s.diagonal(width).copy_(outside_s_span)
307 |
308 | return outside_s
309 |
310 |
311 | def get_sentence_matrix(
312 | sentence_num,
313 | max_sentence_length,
314 | unidimensional_values,
315 | span_location_indices,
316 | padding_value=0.
317 | ):
318 | total_units = sentence_num * max_sentence_length * max_sentence_length
319 | flat_matrix_by_sentence = unidimensional_values.new_full(
320 | (total_units, unidimensional_values.size(-1)), padding_value
321 | ).index_copy(0, span_location_indices, unidimensional_values.view(-1, unidimensional_values.size(-1)))
322 |
323 | return flat_matrix_by_sentence.view(sentence_num, max_sentence_length, max_sentence_length, unidimensional_values.size(-1))
324 |
325 |
326 | class CFGMentionProposer(torch.nn.Module):
327 | def __init__(
328 | self,
329 | max_span_width=30,
330 | neg_sample_rate=0.2,
331 | **kwargs
332 | ) -> None:
333 | super().__init__(**kwargs)
334 | self.neg_sample_rate = float(neg_sample_rate)
335 | self.cky_module = CKY(max_span_width)
336 |
337 | def forward(
338 | self, # type: ignore
339 | spans: torch.IntTensor,
340 | span_mention_scores: torch.FloatTensor,
341 | span_mask: torch.FloatTensor,
342 | span_labels: torch.IntTensor,
343 | sentence_lengths: torch.IntTensor,
344 | num_spans_to_keep: int,
345 | flat_span_location_indices: torch.IntTensor,
346 | take_top_spans_per_sentence = False,
347 | flat_span_sent_ids = None,
348 | ratio = 0.
349 | ):
350 |         # spans: (num_spans, 2); span_mention_scores: (num_spans, score_dim)
351 |         num_spans = spans.size(0)
352 | span_max_item = spans.max()
353 |
354 | sentence_offsets = torch.cumsum(sentence_lengths.squeeze(), 0)
355 | sentence_offsets = torch.cat(
356 | [sentence_offsets.new_zeros(1, 1), sentence_offsets.view(1, -1)],
357 | dim=-1
358 | )
359 | span_mention_scores = span_mention_scores + (span_mask.unsqueeze(-1) * LARGENUMBER - LARGENUMBER)
360 | max_sentence_length = sentence_lengths.max()
361 | sentence_num = sentence_lengths.size(0)
362 |
363 |         # We calculate the indices of span scores in the matrices directly during data preparation.
364 |         # The indices are 1-d (except for the batch dimension) to facilitate index_copy_.
365 |         # We copy the span scores into score matrices of shape
366 |         # (sentence_num, max_sentence_length, max_sentence_length, score_dim).
367 |
368 | # We will do sentence-level CKY over span scores
369 | # span_mention_scores shape: (batch_size, num_spans, 2)
370 | # the first column scores are for parsing, the second column for linking
371 |
372 | span_score_matrix_by_sentence = get_sentence_matrix(
373 | sentence_num, max_sentence_length, span_mention_scores,
374 | flat_span_location_indices, padding_value=-LARGENUMBER
375 | )
376 | valid_span_flag_matrix_by_sentence = get_sentence_matrix(
377 | sentence_num, max_sentence_length, torch.ones_like(span_mask).unsqueeze(-1),
378 | flat_span_location_indices, padding_value=0
379 | ).squeeze(-1)
380 |
381 |
382 | Z, marginal = self.cky_module(
383 | span_score_matrix_by_sentence, sentence_lengths
384 | )
385 | span_marginal = torch.masked_select(marginal, valid_span_flag_matrix_by_sentence)
386 |
387 | if not take_top_spans_per_sentence:
388 | top_span_indices = masked_topk_non_overlap(
389 | span_marginal,
390 | span_mask,
391 | num_spans_to_keep,
392 | spans
393 | )
394 | span_marginal = clip_to_01(span_marginal)
395 |
396 | top_ind_list = top_span_indices.tolist()
397 | all_ind_list = list(range(0, span_marginal.size(0)))
398 | neg_sample_indices = np.random.choice(
399 | list(set(all_ind_list) - set(top_ind_list)),
400 | int(self.neg_sample_rate * num_spans_to_keep),
401 | False
402 | )
403 | neg_sample_indices = top_span_indices.new_tensor(sorted(neg_sample_indices))
404 | else:
405 | top_span_indices, sentwise_top_span_marginal, top_spans = [], [], []
406 | prev_sent_id, prev_span_id = 0, 0
407 | for span_id, sent_id in enumerate(flat_span_sent_ids.tolist()):
408 | if sent_id != prev_sent_id:
409 | sent_span_indices = masked_topk_non_overlap(
410 | span_marginal[prev_span_id:span_id],
411 | span_mask[prev_span_id:span_id],
412 | int(ratio * sentence_lengths[prev_sent_id]),
413 | spans[prev_span_id:span_id],
414 | non_crossing=False,
415 | ) + prev_span_id
416 |
417 | top_span_indices.append(sent_span_indices)
418 | sentwise_top_span_marginal.append(span_marginal[sent_span_indices])
419 | top_spans.append(spans[sent_span_indices])
420 |
421 | prev_sent_id, prev_span_id = sent_id, span_id
422 | # last sentence
423 | sent_span_indices = masked_topk_non_overlap(
424 | span_marginal[prev_span_id:],
425 | span_mask[prev_span_id:],
426 | int(ratio * sentence_lengths[-1]),
427 | spans[prev_span_id:],
428 | non_crossing=False,
429 | ) + prev_span_id
430 |
431 | top_span_indices.append(sent_span_indices)
432 | sentwise_top_span_marginal.append(span_marginal[sent_span_indices])
433 | top_spans.append(spans[sent_span_indices])
434 |
435 | num_top_spans = [x.size(0) for x in top_span_indices]
436 | max_num_top_span = max(num_top_spans)
437 |
438 | sentwise_top_span_marginal = torch.stack(
439 | [torch.cat([x, x.new_zeros((max_num_top_span-x.size(0), ))], dim=0) for x in sentwise_top_span_marginal], dim=0
440 | )
441 | top_spans = torch.stack(
442 | [torch.cat([x, x.new_zeros((max_num_top_span-x.size(0), 2))], dim=0) for x in top_spans], dim=0
443 | )
444 | top_span_masks = torch.stack(
445 | [torch.cat([spans.new_ones((x, )), spans.new_zeros((max_num_top_span-x, ))], dim=0) for x in num_top_spans], dim=0
446 | )
447 |
448 | flat_top_span_indices = torch.cat(top_span_indices, dim=0)
449 | top_span_indices = torch.stack(
450 | [torch.cat([x, x.new_zeros((max_num_top_span-x.size(0), ))], dim=0) for x in top_span_indices], dim=0
451 | )
452 |
453 | top_ind_list = flat_top_span_indices.tolist()
454 | all_ind_list = list(range(0, span_marginal.size(0)))
455 | neg_sample_indices = np.random.choice(
456 | list(set(all_ind_list) - set(top_ind_list)),
457 | int(self.neg_sample_rate * num_spans_to_keep), False
458 | )
459 | neg_sample_indices = top_span_indices.new_tensor(sorted(neg_sample_indices))
460 | pass # End else
461 |
462 |
463 | if not self.training:
464 | with torch.no_grad():
465 | best_trees = CKY.viterbi(span_score_matrix_by_sentence.detach(), sentence_lengths)
466 |
467 | best_tree_spans = [(x[:,:2]+offset).cuda() for x, offset in zip(best_trees, sentence_offsets.view(-1).cpu())]
468 |
469 | best_tree_tags = torch.cat([x[:,-1] for x in best_trees], dim=-1).cuda()
470 | best_tree_spans = torch.cat(best_tree_spans, dim=0).cuda()
471 | best_tree_span_mask = (best_tree_tags == 0).unsqueeze(-1)
472 | if best_tree_span_mask.sum() > 0:
473 | # top spans per sentence
474 | helper_matrix = span_mask.new_zeros(span_max_item+1, span_max_item+1)
475 | top_spans = torch.masked_select(best_tree_spans, best_tree_span_mask).view(-1, 2)
476 | helper_matrix[top_spans[:,0],top_spans[:,1]] |= torch.tensor(True)
477 | top_span_mask = helper_matrix[spans[:,0],spans[:,1]]
478 | top_span_indices = torch.nonzero(top_span_mask, as_tuple=True)[0]
479 |
480 | if take_top_spans_per_sentence:
481 | sentwise_top_span_indices, sentwise_top_span_marginal, sentwise_top_spans = [], [], []
482 | prev_sent_id, prev_span_id = 0, 0
483 |
484 | for span_id, sent_id in enumerate(flat_span_sent_ids.tolist()):
485 | if sent_id != prev_sent_id:
486 | current_sentence_indices = torch.nonzero(top_span_mask[prev_span_id:span_id], as_tuple=True)[0] # unshifted
487 | sentwise_top_span_indices.append(current_sentence_indices + prev_span_id)
488 | sentwise_top_span_marginal.append(span_marginal[prev_span_id:span_id][current_sentence_indices])
489 | sentwise_top_spans.append(spans[prev_span_id:span_id][current_sentence_indices])
490 |
491 | prev_sent_id, prev_span_id = sent_id, span_id
492 |
493 | current_sentence_indices = torch.nonzero(top_span_mask[prev_span_id:], as_tuple=True)[0] # unshifted
494 | sentwise_top_span_indices.append(current_sentence_indices + prev_span_id)
495 | sentwise_top_span_marginal.append(span_marginal[prev_span_id:][current_sentence_indices])
496 | sentwise_top_spans.append(spans[prev_span_id:][current_sentence_indices])
497 |
498 | num_top_spans = [x.size(0) for x in sentwise_top_span_indices]
499 | max_num_top_span = max(num_top_spans)
500 |
501 | top_spans = torch.stack(
502 | [torch.cat([x, x.new_zeros((max_num_top_span-x.size(0), 2))], dim=0) for x in sentwise_top_spans], dim=0
503 | )
504 | top_span_masks = torch.stack(
505 | [torch.cat([spans.new_ones((x, )), spans.new_zeros((max_num_top_span-x, ))], dim=0) for x in num_top_spans], dim=0
506 | )
507 | top_span_indices = torch.stack(
508 | [torch.cat([x, x.new_zeros((max_num_top_span-x.size(0), ))], dim=0) for x in sentwise_top_span_indices], dim=0
509 | )
510 | sentwise_top_span_marginal = torch.stack(
511 | [torch.cat([x, x.new_zeros((max_num_top_span-x.size(0), ))], dim=0) for x in sentwise_top_span_marginal], dim=0
512 | )
513 | else:
514 |                     logger.info("expected %d spans but got %d from the CKY parse; not using the CKY parse" % (num_spans_to_keep, int(best_tree_span_mask.sum())))
515 | pass
516 |
517 |
518 | if self.training and neg_sample_indices.size(0) > 0:
519 | if take_top_spans_per_sentence:
520 | not_mention_loss = -(1 - span_marginal[neg_sample_indices]).log()
521 | loss = not_mention_loss.mean() * (self.neg_sample_rate * num_spans_to_keep)
522 | else:
523 | non_mention_flag = (span_labels <= 0)
524 | # -log(1 - P(m))
525 | not_mention_loss = -(1 - span_marginal[neg_sample_indices]).log() * non_mention_flag[neg_sample_indices]
526 | loss = not_mention_loss.mean() * (self.neg_sample_rate * num_spans_to_keep)
527 | else:
528 | loss = 0.
529 |
530 | if not take_top_spans_per_sentence:
531 | top_spans = spans[top_span_indices]
532 | return span_marginal, top_span_indices, top_spans, loss, None
533 | else:
534 |
535 | return sentwise_top_span_marginal, top_span_indices, top_spans, loss, None, top_span_masks
536 |
537 |
--------------------------------------------------------------------------------
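
CKY.coolio and CKY.io recover span marginals by differentiating the log-partition Z with
respect to the span scores (inside-outside via backpropagation). A toy sketch of the same
identity, with a plain logsumexp standing in for the inside chart:

import torch

scores = torch.tensor([1.0, 2.0, 0.5], requires_grad=True)  # made-up span scores
Z = torch.logsumexp(scores, dim=0)                           # toy log-partition
(marginals,) = torch.autograd.grad(Z, scores)
print(marginals, torch.softmax(scores, dim=0))               # equal: dZ/ds_i = P(i)
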
/run.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import random
3 | import numpy as np
4 | import torch
5 | from torch.utils.tensorboard import SummaryWriter
6 | from transformers import AdamW
7 | from torch.optim import Adam
8 | from tensorize import CorefDataProcessor
9 | import util
10 | import time
11 | from os.path import join
12 | from metrics import CorefEvaluator
13 | from datetime import datetime
14 | from torch.optim.lr_scheduler import LambdaLR
15 |
16 | import json
17 |
18 | import model as Model
19 |
20 | import conll
21 | import sys
22 |
23 | torch.autograd.set_detect_anomaly(False)
24 | USE_AMP = True
25 |
26 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
27 | datefmt='%m/%d/%Y %H:%M:%S',
28 | level=logging.INFO)
29 | logger = logging.getLogger()
30 |
31 | class Runner:
32 | def __init__(self, config_name, gpu_id=0, seed=None):
33 | self.name = config_name
34 | self.name_suffix = datetime.now().strftime('%b%d_%H-%M-%S')
35 | self.gpu_id = gpu_id
36 | self.seed = seed
37 |
38 | # Set up config
39 | self.config = util.initialize_config(config_name)
40 |
41 | # Set up logger
42 | log_path = join(self.config['log_dir'], 'log_' + self.name_suffix + '.txt')
43 | logger.addHandler(logging.FileHandler(log_path, 'a'))
44 | logger.info('Log file path: %s' % log_path)
45 |
46 | # Set up seed
47 | if seed:
48 | util.set_seed(seed)
49 |
50 | # Set up device
51 | self.device = torch.device('cpu' if gpu_id is None else f'cuda:{gpu_id}')
52 | self.scaler = torch.cuda.amp.GradScaler(init_scale=256.0)
53 |
54 | # Set up data
55 | self.data = CorefDataProcessor(self.config)
56 |
57 | def initialize_model(self, saved_suffix=None):
58 | model = Model.CorefModel(self.config, self.device)
59 | if saved_suffix:
60 | self.load_model_checkpoint(model, saved_suffix)
61 | return model
62 |
63 | def train(self, model):
64 | conf = self.config
65 | logger.info(conf)
66 | epochs, grad_accum = conf['num_epochs'], conf['gradient_accumulation_steps']
67 |
68 | model.to(self.device)
69 | logger.info('Model parameters:')
70 | for name, param in model.named_parameters():
71 | logger.info('%s: %s' % (name, tuple(param.shape)))
72 |
73 | # Set up tensorboard
74 | tb_path = join(conf['tb_dir'], self.name + '_' + self.name_suffix)
75 | tb_writer = SummaryWriter(tb_path, flush_secs=30)
76 | logger.info('Tensorboard summary path: %s' % tb_path)
77 |
78 | # Set up data
79 | examples_train, examples_dev, examples_test = self.data.get_tensor_examples()
80 | stored_info = self.data.get_stored_info()
81 |
82 | # Set up optimizer and scheduler
83 | total_update_steps = len(examples_train) * epochs // grad_accum
84 | optimizers = self.get_optimizer(model)
85 | schedulers = self.get_scheduler(optimizers, total_update_steps)
86 |
87 | # Get model parameters for grad clipping
88 | bert_param, task_param = model.get_params()
89 |
90 | # Start training
91 | logger.info('*******************Training*******************')
92 | logger.info('Num samples: %d' % len(examples_train))
93 | logger.info('Num epochs: %d' % epochs)
94 | logger.info('Gradient accumulation steps: %d' % grad_accum)
95 | logger.info('Total update steps: %d' % total_update_steps)
96 |
97 | loss_during_accum = [] # To compute effective loss at each update
98 | loss_during_report = 0.0 # Effective loss during logging step
99 | loss_history = [] # Full history of effective loss; length equals total update steps
100 | max_f1, max_f1_test = 0, 0
101 | start_time = time.time()
102 | model.zero_grad()
103 | for epo in range(epochs):
104 | print("EPOCH", epo)
105 | random.shuffle(examples_train) # Shuffle training set
106 | for doc_key, example in examples_train:
107 | # Forward pass
108 | model.train()
109 | example_gpu = [d.to(self.device) if d is not None else None for d in example]
110 |
111 | with torch.cuda.amp.autocast(enabled=USE_AMP):
112 | _, loss = model(*example_gpu)
113 |
114 | # Backward; accumulate gradients and clip by grad norm
115 | if grad_accum > 1:
116 | loss /= grad_accum
117 |
118 | if USE_AMP:
119 | scaled_loss = self.scaler.scale(loss)
120 | scaled_loss.backward()
121 | else:
122 | loss.backward()
123 |
124 | loss_during_accum.append(loss.item())
125 |
126 | # Update
127 | if len(loss_during_accum) % grad_accum == 0:
128 | if USE_AMP:
129 | self.scaler.unscale_(optimizers[0])
130 | self.scaler.unscale_(optimizers[1])
131 |
132 | if conf['max_grad_norm']:
133 | norm_bert = torch.nn.utils.clip_grad_norm_(bert_param, conf['max_grad_norm'], error_if_nonfinite=False)
134 | norm_task = torch.nn.utils.clip_grad_norm_(task_param, conf['max_grad_norm'], error_if_nonfinite=False)
135 |
136 | for optimizer in optimizers:
137 | if USE_AMP:
138 | self.scaler.step(optimizer)
139 | else:
140 | optimizer.step()
141 |
142 | if USE_AMP:
143 | self.scaler.update()
144 |
145 | model.zero_grad()
146 | for scheduler in schedulers:
147 | scheduler.step()
148 |
149 | # Compute effective loss
150 | effective_loss = np.sum(loss_during_accum).item()
151 | loss_during_accum = []
152 | loss_during_report += effective_loss
153 | loss_history.append(effective_loss)
154 |
155 | # Report
156 | if len(loss_history) % conf['report_frequency'] == 0:
157 | # Show avg loss during last report interval
158 | avg_loss = loss_during_report / conf['report_frequency']
159 | loss_during_report = 0.0
160 | end_time = time.time()
161 | logger.info('Step %d: avg loss %.2f; steps/sec %.2f' %
162 | (len(loss_history), avg_loss, conf['report_frequency'] / (end_time - start_time)))
163 | start_time = end_time
164 |
165 | tb_writer.add_scalar('Training_Loss', avg_loss, len(loss_history))
166 | tb_writer.add_scalar('Learning_Rate_Bert', schedulers[0].get_last_lr()[0], len(loss_history))
167 | tb_writer.add_scalar('Learning_Rate_Task', schedulers[1].get_last_lr()[-1], len(loss_history))
168 |
169 | # Evaluate
170 | if len(loss_history) > 0 and len(loss_history) % conf['eval_frequency'] == 0:
171 | # Testing Dev
172 | logger.info('Dev')
173 | f1, _ = self.evaluate(model, examples_dev, stored_info, len(loss_history), official=False, conll_path=self.config['conll_eval_path'], tb_writer=tb_writer)
174 | # Testing Test
175 | logger.info('Test')
176 | f1_test = 0.
177 | if f1 > max_f1 or f1_test > max_f1_test:
178 | max_f1 = max(max_f1, f1)
179 | max_f1_test = 0. # max(max_f1_test, f1_test)
180 | self.save_model_checkpoint(model, len(loss_history))
181 |
182 | logger.info('Eval max f1: %.2f' % max_f1)
183 | logger.info('Test max f1: %.2f' % max_f1_test)
184 | start_time = time.time()
185 |
186 | logger.info('**********Finished training**********')
187 | logger.info('Actual update steps: %d' % len(loss_history))
188 |
189 | # Wrap up
190 | tb_writer.close()
191 | return loss_history
192 |
193 | def evaluate(self, model, tensor_examples, stored_info, step, official=False, conll_path=None, tb_writer=None, predict=False):
194 | logger.info('Step %d: evaluating on %d samples...' % (step, len(tensor_examples)))
195 |
196 | model.to(self.device)
197 | evaluator = CorefEvaluator()
198 | doc_to_prediction = {}
199 |
200 | model.eval()
201 |
202 | for i, (doc_key, tensor_example) in enumerate(tensor_examples):
203 | current_json = {}
204 |
205 | gold_clusters = stored_info['gold'][doc_key]
206 | tensor_example = tensor_example[:7] # Strip out gold
207 | example_gpu = [d.to(self.device) for d in tensor_example]
208 |
209 | with torch.no_grad():
210 | returned_tuple = model(*example_gpu)
211 |
212 | if len(returned_tuple) == 10:
213 | _, _, _, span_starts, span_ends, antecedent_idx, antecedent_scores, score_j_i, input_ids, head_cond_score = returned_tuple
214 |
215 | current_json["score_j_i"] = score_j_i.tolist()
216 | current_json["input_ids"] = input_ids.tolist()
217 | current_json["head_cond_score"] = head_cond_score.tolist()
218 |
219 | elif len(returned_tuple) == 7:
220 | _, _, _, span_starts, span_ends, antecedent_idx, antecedent_scores = returned_tuple
221 |
222 | span_starts, span_ends = span_starts.tolist(), span_ends.tolist()
223 | antecedent_idx, antecedent_scores = antecedent_idx, antecedent_scores
224 |
225 | predicted_clusters = model.update_evaluator(span_starts, span_ends, antecedent_idx, antecedent_scores, gold_clusters, evaluator)
226 | doc_to_prediction[doc_key] = predicted_clusters
227 |
228 |
229 | p, r, f, m_recall, (blanc_p, blanc_r, blanc_f) = evaluator.get_prf()
230 | all_metrics = evaluator.get_all()
231 |
232 | metrics = {
233 | 'Eval_Avg_Precision': p * 100, 'Eval_Avg_Recall': r * 100, 'Eval_Avg_F1': f * 100,
234 | "Eval_Men_Recall": m_recall*100,
235 | "Eval_Blanc_Precision": blanc_p * 100, "Eval_Blanc_Recall": blanc_r * 100, "Eval_Blanc_F1": blanc_f * 100
236 | }
237 |
238 | for k,v in all_metrics.items():
239 | logger.info('%s: %.4f'%(k, v))
240 |
241 | for name, score in metrics.items():
242 | logger.info('%s: %.2f' % (name, score))
243 | if tb_writer:
244 | tb_writer.add_scalar(name, score, step)
245 |
246 | if official:
247 | conll_results = conll.evaluate_conll(conll_path, doc_to_prediction, stored_info['subtoken_maps'])
248 | official_f1 = sum(results["f"] for results in conll_results.values()) / len(conll_results)
249 | logger.info('Official avg F1: %.4f' % official_f1)
250 |
251 | return f * 100, metrics
252 |
253 | def predict(self, model, tensor_examples):
254 | logger.info('Predicting %d samples...' % len(tensor_examples))
255 | model.to(self.device)
256 | model.eval()
257 | predicted_spans, predicted_antecedents, predicted_clusters = [], [], []
258 |
259 | for i, (doc_key, tensor_example) in enumerate(tensor_examples):
260 | tensor_example = tensor_example[:7]
261 | example_gpu = [d.to(self.device) for d in tensor_example]
262 | with torch.no_grad():
263 | _, _, _, span_starts, span_ends, antecedent_idx, antecedent_scores = model(*example_gpu)
264 | span_starts, span_ends = span_starts.tolist(), span_ends.tolist()
265 | antecedent_idx, antecedent_scores = antecedent_idx.tolist(), antecedent_scores.tolist()
266 | clusters, mention_to_cluster_id, antecedents = model.get_predicted_clusters(span_starts, span_ends, antecedent_idx, antecedent_scores)
267 |
268 | spans = [(span_start, span_end) for span_start, span_end in zip(span_starts, span_ends)]
269 | predicted_spans.append(spans)
270 | predicted_antecedents.append(antecedents)
271 | predicted_clusters.append(clusters)
272 |
273 | return predicted_clusters, predicted_spans, predicted_antecedents
274 |
275 | def get_optimizer(self, model):
276 | no_decay = ['bias', 'LayerNorm.weight']
277 | bert_param, task_param = model.get_params(named=True)
278 | grouped_bert_param = [
279 | {
280 | 'params': [p for n, p in bert_param if not any(nd in n for nd in no_decay)],
281 | 'lr': self.config['bert_learning_rate'],
282 | 'weight_decay': self.config['adam_weight_decay']
283 | }, {
284 | 'params': [p for n, p in bert_param if any(nd in n for nd in no_decay)],
285 | 'lr': self.config['bert_learning_rate'],
286 | 'weight_decay': 0.0
287 | }
288 | ]
289 | optimizers = [
290 | AdamW(grouped_bert_param, lr=self.config['bert_learning_rate'], eps=self.config['adam_eps']),
291 | Adam(model.get_params()[1], lr=self.config['task_learning_rate'], eps=self.config['adam_eps'], weight_decay=0)
292 | ]
293 | return optimizers
294 |
295 | def get_scheduler(self, optimizers, total_update_steps):
296 | # Only warm up bert lr
297 | warmup_steps = int(total_update_steps * self.config['warmup_ratio'])
298 |
299 |         def lr_lambda_bert(current_step):  # linear warmup to the peak LR, then linear decay to 0
300 | if current_step < warmup_steps:
301 | return float(current_step) / float(max(1, warmup_steps))
302 | return max(
303 | 0.0, float(total_update_steps - current_step) / float(max(1, total_update_steps - warmup_steps))
304 | )
305 |
306 |         def lr_lambda_task(current_step):  # linear decay to 0, no warmup
307 | return max(0.0, float(total_update_steps - current_step) / float(max(1, total_update_steps)))
308 |
309 | schedulers = [
310 | LambdaLR(optimizers[0], lr_lambda_bert),
311 | LambdaLR(optimizers[1], lr_lambda_task)
312 | ]
313 | return schedulers
314 |
315 | def save_model_checkpoint(self, model, step):
316 | path_ckpt = join(self.config['log_dir'], f'model_{self.name_suffix}_{step}.bin')
317 | torch.save(model.state_dict(), path_ckpt)
318 | logger.info('Saved model to %s' % path_ckpt)
319 |
320 | def load_model_checkpoint(self, model, suffix):
321 | path_ckpt = join(self.config['log_dir'], f'model_{suffix}.bin')
322 | model.load_state_dict(torch.load(path_ckpt, map_location=torch.device('cpu')), strict=False)
323 | logger.info('Loaded model from %s' % path_ckpt)
324 |
325 |
326 | if __name__ == '__main__':
327 | config_name, gpu_id = sys.argv[1], int(sys.argv[2])
328 | saved_suffix = sys.argv[3] if len(sys.argv) >= 4 else None
329 | runner = Runner(config_name, gpu_id)
330 | model = runner.initialize_model(saved_suffix)
331 |
332 | runner.train(model)
333 |
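334 | # Usage: python run.py <config_name> <gpu_id> [saved_suffix]
335 | #   <config_name> is a section of experiments.conf (see util.initialize_config);
336 | #   the optional [saved_suffix] loads log_dir/model_<saved_suffix>.bin before training.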
--------------------------------------------------------------------------------
/sss.yml:
--------------------------------------------------------------------------------
1 | name: sss
2 | channels:
3 | - pytorch
4 | - conda-forge
5 | - defaults
6 | dependencies:
7 | - huggingface_hub
8 | - pip
9 | - scipy
10 | - pyhocon
11 | - python=3.8
12 | - pytorch=1.11.0
13 | - tensorboard
14 | - tqdm
15 | - transformers=4.19.3
16 | - numpy
17 | - scikit-learn
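18 | # To create the environment: conda env create -f sss.yml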
--------------------------------------------------------------------------------
/tensorize.py:
--------------------------------------------------------------------------------
1 | import util
2 | import numpy as np
3 | import random
4 | from transformers import AutoTokenizer
5 | import os
6 | from os.path import join
7 | import json
8 | import pickle
9 | import logging
10 | import torch
11 |
12 | logger = logging.getLogger(__name__)
13 |
14 | class CorefDataProcessor:
15 | def __init__(self, config, language='english'):
16 | self.config = config
17 | self.language = language
18 |
19 | self.max_seg_len = config['max_segment_len']
20 | self.max_training_seg = config['max_training_sentences']
21 | self.data_dir = config['data_dir']
22 |
23 | # Get tensorized samples
24 | cache_path = self.get_cache_path()
25 | if os.path.exists(cache_path):
26 | # Load cached tensors if exists
27 | with open(cache_path, 'rb') as f:
28 | self.tensor_samples, self.stored_info = pickle.load(f)
29 | logger.info('Loaded tensorized examples from cache')
30 | else:
31 | # Generate tensorized samples
32 | if self.config["dataset"] == "ontonotes":
33 | self.tensor_samples = {}
34 | tensorizer = Tensorizer(self.config)
35 | paths = {
36 | 'trn': join(self.data_dir, f'train.{language}.{self.max_seg_len}.jsonlines'),
37 | 'dev': join(self.data_dir, f'dev.{language}.{self.max_seg_len}.jsonlines'),
38 | 'tst': join(self.data_dir, f'test.{language}.{self.max_seg_len}.jsonlines')
39 | }
40 | for split, path in paths.items():
41 |                     logger.info('Tensorizing examples from %s (results will be cached)' % path)
42 | is_training = (split == 'trn')
43 | with open(path, 'r') as f:
44 | samples = [json.loads(line) for line in f.readlines()]
45 | tensor_samples = [tensorizer.tensorize_example(sample, is_training) for sample in samples]
46 | print(len(tensor_samples[0]))
47 | self.tensor_samples[split] = [(doc_key, self.convert_to_torch_tensor(*tensor)) for doc_key, tensor in tensor_samples]
48 | self.stored_info = tensorizer.stored_info
49 | # Cache tensorized samples
50 | with open(cache_path, 'wb') as f:
51 | pickle.dump((self.tensor_samples, self.stored_info), f)
52 |
53 |
54 | @classmethod
55 | def convert_to_torch_tensor(cls, input_ids, input_mask, speaker_ids, sentence_len, genre, sentence_map,
56 | is_training, gold_starts, gold_ends, gold_mention_cluster_map,
57 | coreferable_starts, coreferable_ends,
58 | constituent_starts, constituent_ends, constituent_type):
59 |
60 | input_ids = torch.tensor(input_ids, dtype=torch.long)
61 | input_mask = torch.tensor(input_mask, dtype=torch.long)
62 | speaker_ids = torch.tensor(speaker_ids, dtype=torch.long)
63 | sentence_len = torch.tensor(sentence_len, dtype=torch.long)
64 | genre = torch.tensor(genre, dtype=torch.long)
65 | sentence_map = torch.tensor(sentence_map, dtype=torch.long)
66 | is_training = torch.tensor(is_training, dtype=torch.bool)
67 | gold_starts = torch.tensor(gold_starts, dtype=torch.long)
68 | gold_ends = torch.tensor(gold_ends, dtype=torch.long)
69 | gold_mention_cluster_map = torch.tensor(gold_mention_cluster_map, dtype=torch.long)
70 | coreferable_starts = torch.tensor(coreferable_starts, dtype=torch.long) if coreferable_starts is not None else None
71 | coreferable_ends = torch.tensor(coreferable_ends, dtype=torch.long) if coreferable_ends is not None else None
72 |
73 | constituent_starts = torch.tensor(constituent_starts, dtype=torch.long) if constituent_starts is not None else None
74 | constituent_ends = torch.tensor(constituent_ends, dtype=torch.long) if constituent_ends is not None else None
75 | constituent_type = None
76 |
77 | return input_ids, input_mask, speaker_ids, sentence_len, genre, sentence_map, \
78 | is_training, gold_starts, gold_ends, gold_mention_cluster_map, \
79 | coreferable_starts, coreferable_ends, \
80 | constituent_starts, constituent_ends, constituent_type
81 |
82 | def get_tensor_examples(self):
83 | # For each split, return list of tensorized samples to allow variable length input (batch size = 1)
84 | return self.tensor_samples['trn'], self.tensor_samples['dev'], self.tensor_samples['tst']
85 |
86 | def get_stored_info(self):
87 | return self.stored_info
88 |
89 | def get_cache_path(self):
90 | if self.config["dataset"] == "ontonotes":
91 | cache_path = join(self.data_dir, f'cached.tensors.{self.language}.{self.max_seg_len}.{self.max_training_seg}.bin')
92 |
93 | return cache_path
94 |
95 |
96 | class Tensorizer:
97 | def __init__(self, config):
98 | self.config = config
99 | self.tokenizer = AutoTokenizer.from_pretrained(config['bert_tokenizer_name'])
100 |
101 | # Will be used in evaluation
102 | self.stored_info = {}
103 | self.stored_info['tokens'] = {} # {doc_key: ...}
104 | self.stored_info['subtoken_maps'] = {} # {doc_key: ...}; mapping back to tokens
105 | self.stored_info['gold'] = {} # {doc_key: ...}
106 | self.stored_info['genre_dict'] = {genre: idx for idx, genre in enumerate(config['genres'])}
107 | self.stored_info['constituents'] = {}
108 |
109 | def _tensorize_spans(self, spans):
110 | if len(spans) > 0:
111 | starts, ends = zip(*spans)
112 | else:
113 | starts, ends = [], []
114 | return np.array(starts), np.array(ends)
115 |
116 | def _tensorize_span_w_labels(self, spans, label_dict):
117 | if len(spans) > 0:
118 | starts, ends, labels = zip(*spans)
119 | else:
120 | starts, ends, labels = [], [], []
121 | return np.array(starts), np.array(ends), np.array([label_dict[label] for label in labels])
122 |
123 | def _get_speaker_dict(self, speakers):
124 | speaker_dict = {'UNK': 0, '[SPL]': 1}
125 | for speaker in speakers:
126 | if len(speaker_dict) > self.config['max_num_speakers']:
127 |                 pass  # use 'break' here to cap the number of distinct speakers
128 | if speaker not in speaker_dict:
129 | speaker_dict[speaker] = len(speaker_dict)
130 | return speaker_dict
131 |
132 | def tensorize_example(self, example, is_training):
133 | # Mentions and clusters
134 | clusters = example['clusters']
135 | gold_mentions = sorted(tuple(mention) for mention in util.flatten(clusters))
136 |
137 | gold_coreferables = sorted(tuple(mention) for mention in example["coreferables"]) if "coreferables" in example else None
138 | gold_constituents = list(tuple(mention) for mention in example["constituents"]) if "constituents" in example else None
139 | gold_constituent_type = list(example["constituent_type"]) if "constituent_type" in example else None
140 |
141 | gold_mention_map = {mention: idx for idx, mention in enumerate(gold_mentions)}
142 | gold_mention_cluster_map = np.zeros(len(gold_mentions)) # 0: no cluster
143 | for cluster_id, cluster in enumerate(clusters):
144 | for mention in cluster:
145 | gold_mention_cluster_map[gold_mention_map[tuple(mention)]] = cluster_id + 1
146 |
147 | # Speakers
148 | speakers = example['speakers']
149 | speaker_dict = self._get_speaker_dict(util.flatten(speakers))
150 |
151 | # Sentences/segments
152 | sentences = example['sentences'] # Segments
153 | sentence_map = example['sentence_map']
154 | num_words = sum([len(s) for s in sentences])
155 | max_sentence_len = self.config['max_segment_len']
156 | sentence_len = np.array([len(s) for s in sentences])
157 |
158 | # Bert input
159 | input_ids, input_mask, speaker_ids = [], [], []
160 | for idx, (sent_tokens, sent_speakers) in enumerate(zip(sentences, speakers)):
161 | sent_input_ids = self.tokenizer.convert_tokens_to_ids(sent_tokens)
162 | sent_input_mask = [1] * len(sent_input_ids)
163 | sent_speaker_ids = [speaker_dict[speaker] for speaker in sent_speakers]
164 | while len(sent_input_ids) < max_sentence_len:
165 | sent_input_ids.append(0)
166 | sent_input_mask.append(0)
167 | sent_speaker_ids.append(0)
168 | input_ids.append(sent_input_ids)
169 | input_mask.append(sent_input_mask)
170 | speaker_ids.append(sent_speaker_ids)
171 | input_ids = np.array(input_ids)
172 | input_mask = np.array(input_mask)
173 | speaker_ids = np.array(speaker_ids)
174 | assert num_words == np.sum(input_mask), (num_words, np.sum(input_mask))
175 |
176 | # Keep info to store
177 | doc_key = example['doc_key']
178 | self.stored_info['subtoken_maps'][doc_key] = example.get('subtoken_map', None)
179 | self.stored_info['gold'][doc_key] = example['clusters']
180 | # self.stored_info['constituents'][doc_key] = example['constituents']
181 | # self.stored_info['tokens'][doc_key] = example['tokens']
182 |
183 | # Construct example
184 | genre = self.stored_info['genre_dict'].get(doc_key[:2], 0)
185 | gold_starts, gold_ends = self._tensorize_spans(gold_mentions)
186 | coreferable_starts, coreferable_ends = self._tensorize_spans(gold_coreferables) if gold_coreferables is not None else (None, None)
187 | constituent_starts, constituent_ends = self._tensorize_spans(gold_constituents) if gold_constituents is not None else (None, None)
188 | constituent_type = np.array(gold_constituent_type) if gold_constituent_type is not None else None
189 |
190 | example_tensor = (input_ids, input_mask, speaker_ids, sentence_len, genre, sentence_map, is_training,
191 | gold_starts, gold_ends, gold_mention_cluster_map, coreferable_starts, coreferable_ends,
192 | constituent_starts, constituent_ends, constituent_type)
193 | if is_training and len(sentences) > self.config['max_training_sentences']:
194 | return doc_key, self.truncate_example(*example_tensor, max_sentences=self.config['max_training_sentences'])
195 | else:
196 | return doc_key, example_tensor
197 |
198 | def truncate_example(self, input_ids, input_mask, speaker_ids, sentence_len, genre, sentence_map, is_training,
199 | gold_starts, gold_ends, gold_mention_cluster_map, coreferable_starts, coreferable_ends,
200 | constituent_starts, constituent_ends, constituent_type,
201 | max_sentences, sentence_offset=None):
202 | num_sentences = input_ids.shape[0]
203 | assert num_sentences > max_sentences
204 |
205 | sent_offset = sentence_offset
206 | if sent_offset is None:
207 | sent_offset = random.randint(0, num_sentences - max_sentences)
208 | word_offset = sentence_len[:sent_offset].sum()
209 | num_words = sentence_len[sent_offset: sent_offset + max_sentences].sum()
210 |
211 | input_ids = input_ids[sent_offset: sent_offset + max_sentences, :]
212 | input_mask = input_mask[sent_offset: sent_offset + max_sentences, :]
213 | speaker_ids = speaker_ids[sent_offset: sent_offset + max_sentences, :]
214 | sentence_len = sentence_len[sent_offset: sent_offset + max_sentences]
215 |
216 | sentence_map = sentence_map[word_offset: word_offset + num_words]
217 |
218 | gold_spans = (gold_starts < word_offset + num_words) & (gold_ends >= word_offset)
219 | gold_starts = gold_starts[gold_spans] - word_offset
220 | gold_ends = gold_ends[gold_spans] - word_offset
221 | gold_mention_cluster_map = gold_mention_cluster_map[gold_spans]
222 |
223 | coreferable_flags = (coreferable_starts < word_offset + num_words) & (coreferable_ends >= word_offset) if coreferable_starts is not None else None
224 | coreferable_starts = coreferable_starts[coreferable_flags] - word_offset if coreferable_starts is not None else None
225 | coreferable_ends = coreferable_ends[coreferable_flags] - word_offset if coreferable_starts is not None else None
226 |
227 | constituent_flags = (constituent_starts < word_offset + num_words) & (constituent_ends >= word_offset) if constituent_starts is not None else None
228 | constituent_starts = constituent_starts[constituent_flags] - word_offset if constituent_starts is not None else None
229 | constituent_ends = constituent_ends[constituent_flags] - word_offset if constituent_starts is not None else None
230 | constituent_type = constituent_type[constituent_flags] if constituent_type is not None else None
231 |
232 | return input_ids, input_mask, speaker_ids, sentence_len, genre, sentence_map, \
233 | is_training, gold_starts, gold_ends, gold_mention_cluster_map, coreferable_starts, coreferable_ends, \
234 | constituent_starts, constituent_ends, constituent_type
235 |
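236 | # Note: tensorize_example / convert_to_torch_tensor produce a 15-tuple per document;
237 | # run.Runner.evaluate keeps only the first 7 fields (input_ids .. is_training) at inference time.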
--------------------------------------------------------------------------------
/util.py:
--------------------------------------------------------------------------------
1 | from os import makedirs
2 | from os.path import join
3 | import numpy as np
4 | import pyhocon
5 | import logging
6 | import torch
7 | import random
8 |
9 | logger = logging.getLogger(__name__)
10 |
11 |
12 | def flatten(l):
13 | return [item for sublist in l for item in sublist]
14 |
15 |
16 | def initialize_config(config_name, task="coref"):
17 | logger.info("Running {} experiment: {}".format(task, config_name))
18 |
19 | if task == "coref":
20 | config = pyhocon.ConfigFactory.parse_file("experiments.conf")[config_name]
21 | elif task == "srl":
22 | config = pyhocon.ConfigFactory.parse_file("experiments_srl.conf")[config_name]
23 |
24 | config['log_dir'] = join(config["log_root"], config_name)
25 | makedirs(config['log_dir'], exist_ok=True)
26 |
27 | config['tb_dir'] = join(config['log_root'], 'tensorboard')
28 | makedirs(config['tb_dir'], exist_ok=True)
29 |
30 | logger.info(pyhocon.HOCONConverter.convert(config, "hocon"))
31 | return config
32 |
33 |
34 | def set_seed(seed, set_gpu=True):
35 | random.seed(seed)
36 | np.random.seed(seed)
37 | torch.manual_seed(seed)
38 | if set_gpu and torch.cuda.is_available():
39 | # Necessary for reproducibility; lower performance
40 | torch.backends.cudnn.deterministic = True
41 | torch.backends.cudnn.benchmark = False
42 | torch.cuda.manual_seed_all(seed)
43 | logger.info('Random seed is set to %d' % seed)
44 |
45 |
46 | def bucket_distance(offsets):
47 | """ offsets: [num spans1, num spans2] """
48 |     # 10 semi-logscale bins: 0, 1, 2, 3, 4, (5-7)->5, (8-15)->6, (16-31)->7, (32-63)->8, (64+)->9
49 | logspace_distance = torch.log2(offsets.to(torch.float)).to(torch.long) + 3
50 | identity_mask = (offsets <= 4).to(torch.long)
51 | combined_distance = identity_mask * offsets + (1 - identity_mask) * logspace_distance
52 | combined_distance = torch.clamp(combined_distance, 0, 9)
53 | return combined_distance
54 |
55 |
56 | def batch_select(tensor, idx, device=torch.device('cpu')):
57 | """ Do selection per row (first axis). """
58 | assert tensor.shape[0] == idx.shape[0] # Same size of first dim
59 | dim0_size, dim1_size = tensor.shape[0], tensor.shape[1]
60 |
61 | tensor = torch.reshape(tensor, [dim0_size * dim1_size, -1])
62 | idx_offset = torch.unsqueeze(torch.arange(0, dim0_size, device=device) * dim1_size, 1)
63 | new_idx = idx + idx_offset
64 | selected = tensor[new_idx]
65 |
66 | if tensor.shape[-1] == 1: # If selected element is scalar, restore original dim
67 | selected = torch.squeeze(selected, -1)
68 |
69 | return selected
70 |
71 |
72 | def batch_add(tensor, idx, val, device=torch.device('cpu')):
73 | """ Do addition per row (first axis). """
74 | assert tensor.shape[0] == idx.shape[0] # Same size of first dim
75 | dim0_size, dim1_size = tensor.shape[0], tensor.shape[1]
76 |
77 | tensor = torch.reshape(tensor, [dim0_size * dim1_size, -1])
78 | idx_offset = torch.unsqueeze(torch.arange(0, dim0_size, device=device) * dim1_size, 1)
79 | new_idx = idx + idx_offset
80 |
81 | val = val.reshape(val.size(0) * val.size(1), -1)
82 |
83 | res = tensor.index_add(0, new_idx.view(-1), val).reshape([dim0_size, dim1_size, -1])
84 |
85 | if tensor.shape[-1] == 1: # If selected element is scalar, restore original dim
86 | res = res.squeeze(-1)
87 |
88 | return res
89 |
90 |
91 | def sample_subset(w, k, t=0.1):
92 | '''
93 | Args:
94 | w (Tensor): Float Tensor of weights for each element. In gumbel mode
95 | these are interpreted as log probabilities
96 | k (int): number of elements in the subset sample
97 | t (float): temperature of the softmax
98 | '''
99 | wg = gumbel_perturb(w)
100 | return continuous_topk(wg, k, t)
101 |
102 |
103 | def logsumexp(tensor: torch.Tensor, dim: int = -1, keepdim: bool = False) -> torch.Tensor:
104 | max_score, _ = tensor.max(dim, keepdim=keepdim)
105 | if keepdim:
106 | stable_vec = tensor - max_score
107 | else:
108 | stable_vec = tensor - max_score.unsqueeze(dim)
109 |
110 | return max_score + stable_vec.logsumexp(dim, keepdim=keepdim) # (stable_vec.exp().sum(dim, keepdim=keepdim)).log()
111 |
112 |
113 | def clip_to_01(x: torch.Tensor):  # nudge x into (0, 1); the corrections are detached, so gradients pass through x unchanged
114 | eps = torch.finfo(x.dtype).eps
115 | tiny = torch.finfo(x.dtype).tiny
116 |
117 | x_val = x.detach()
118 |
119 | cond_greater = (x_val > (1.0 - eps))
120 | diff_greater = (x_val - 1.0 + eps)
121 |
122 | cond_less = (x < tiny)
123 | diff_less = (tiny - x_val)
124 |
125 | x -= diff_greater * cond_greater
126 | x += diff_less * cond_less
127 |
128 | return x
129 |
130 |
131 | def log1mexp(x):  # numerically stable log(1 - exp(x)) for x <= 0
132 | x -= torch.finfo(x.dtype).eps
133 | return torch.where(x > -0.693, (-torch.expm1(x)).log(), torch.log1p(-(x.exp())))
134 |
135 |
136 | def gumbel_perturb(w):  # add Gumbel(0, 1) noise to the (log-)weights
137 |     uniform_01 = torch.distributions.uniform.Uniform(1e-6, 1.0)
138 |     # sample some gumbels
139 |     u = uniform_01.sample(w.size()).to(w.device)
140 | g = -torch.log(-torch.log(u))
141 | w = w + g
142 | return w
143 |
144 |
145 | def continuous_topk(w, k, t):  # relaxed top-k: k successive softmaxes, masking out already-selected mass
146 | khot_list = []
147 | onehot_approx = torch.zeros_like(w)
148 |
149 | for i in range(k):
150 | khot_mask = torch.clamp(1.0 - onehot_approx, min=1e-6)
151 | w += torch.log(khot_mask)
152 |         onehot_approx = torch.softmax(w / t, dim=-1)
153 | khot_list.append(onehot_approx)
154 |
155 | return torch.stack(khot_list, dim=0)
156 |
157 |
158 | def stripe(x, n, w, offset=(0, 0), horizontal=1):
159 | r"""
160 | Returns a diagonal stripe of the tensor.
161 | Args:
162 | x (~torch.Tensor): the input tensor with 2 or more dims.
163 | n (int): the length of the stripe.
164 | w (int): the width of the stripe.
165 | offset (tuple): the offset of the first two dims.
166 |         horizontal (int): 1 to return a horizontal stripe; 0 for a vertical one.
167 | Returns:
168 | a diagonal stripe of the tensor.
169 | Examples:
170 | >>> x = torch.arange(25).view(5, 5)
171 | >>> x
172 | tensor([[ 0, 1, 2, 3, 4],
173 | [ 5, 6, 7, 8, 9],
174 | [10, 11, 12, 13, 14],
175 | [15, 16, 17, 18, 19],
176 | [20, 21, 22, 23, 24]])
177 | >>> stripe(x, 2, 3)
178 | tensor([[0, 1, 2],
179 | [6, 7, 8]])
180 | >>> stripe(x, 2, 3, (1, 1))
181 | tensor([[ 6, 7, 8],
182 | [12, 13, 14]])
183 | >>> stripe(x, 2, 3, (1, 1), 0)
184 | tensor([[ 6, 11, 16],
185 | [12, 17, 22]])
186 | """
187 | x, seq_len = x.contiguous(), x.size(1)
188 | stride, numel = list(x.stride()), x[0, 0].numel()
189 | stride[0] = (seq_len + 1) * numel
190 | stride[1] = (1 if horizontal == 1 else seq_len) * numel
191 |
192 | return x.as_strided(
193 | size=(n, w, *x.shape[2:]),
194 | stride=stride,
195 | storage_offset=(offset[0]*seq_len+offset[1])*numel
196 | )
197 |
198 |
199 | def masked_topk_non_overlap(
200 | span_scores,
201 | span_mask,
202 | num_spans_to_keep,
203 | spans,
204 | non_crossing=True
205 | ):
206 | sorted_scores, sorted_indices = torch.sort(span_scores + span_mask.log(), descending=True)
207 | sorted_indices = sorted_indices.tolist()
208 | spans = spans.tolist()
209 |
210 | if not non_crossing:
211 | selected_candidate_idx = sorted(sorted_indices[:num_spans_to_keep], key=lambda idx: (spans[idx][0], spans[idx][1]))
212 | selected_candidate_idx = span_scores.new_tensor(selected_candidate_idx, dtype=torch.long)
213 | return selected_candidate_idx
214 |
215 | selected_candidate_idx = []
216 | start_to_max_end, end_to_min_start = {}, {}
217 | for candidate_idx in sorted_indices:
218 | if len(selected_candidate_idx) >= num_spans_to_keep:
219 | break
220 | # Perform overlapping check
221 | span_start_idx = spans[candidate_idx][0]
222 | span_end_idx = spans[candidate_idx][1]
223 | cross_overlap = False
224 | for token_idx in range(span_start_idx, span_end_idx + 1):
225 | max_end = start_to_max_end.get(token_idx, -1)
226 | if token_idx > span_start_idx and max_end > span_end_idx:
227 | cross_overlap = True
228 | break
229 | min_start = end_to_min_start.get(token_idx, -1)
230 | if token_idx < span_end_idx and 0 <= min_start < span_start_idx:
231 | cross_overlap = True
232 | break
233 | if not cross_overlap:
234 | # Pass check; select idx and update dict stats
235 | selected_candidate_idx.append(candidate_idx)
236 | max_end = start_to_max_end.get(span_start_idx, -1)
237 | if span_end_idx > max_end:
238 | start_to_max_end[span_start_idx] = span_end_idx
239 | min_start = end_to_min_start.get(span_end_idx, -1)
240 | if min_start == -1 or span_start_idx < min_start:
241 | end_to_min_start[span_end_idx] = span_start_idx
242 | # Sort selected candidates by span idx
243 | selected_candidate_idx = sorted(selected_candidate_idx, key=lambda idx: (spans[idx][0], spans[idx][1]))
244 | selected_candidate_idx = span_scores.new_tensor(selected_candidate_idx, dtype=torch.long)
245 |
246 | return selected_candidate_idx
247 |
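248 | # Example: with non_crossing=True, masked_topk_non_overlap keeps the highest-scoring
249 | # spans that do not cross (partially overlap) an already-kept span, e.g.
250 | #   scores = torch.tensor([0.9, 0.8, 0.1]); mask = torch.ones(3)
251 | #   spans  = torch.tensor([[0, 2], [1, 3], [4, 5]])        # (1, 3) crosses (0, 2)
252 | #   masked_topk_non_overlap(scores, mask, 2, spans)        # -> tensor([0, 2])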
--------------------------------------------------------------------------------