├── LICENSE
├── README.md
├── bin
│   ├── blinkify.py
│   ├── eval_bleu.py
│   ├── inspect_.py
│   ├── patch_legacy_checkpoint.py
│   ├── predict_amrs.py
│   ├── predict_amrs_from_plaintext.py
│   ├── predict_sentences.py
│   └── train.py
├── configs
│   └── config.yaml
├── data
│   └── vocab
│       ├── additions.txt
│       ├── predicates.txt
│       └── recategorizations.txt
├── docs
│   ├── appendix.pdf
│   ├── camera-ready.pdf
│   └── preprint.pdf
├── requirements.txt
├── sample.txt
├── setup.py
└── spring_amr
    ├── IO.py
    ├── __init__.py
    ├── dataset.py
    ├── entities.py
    ├── evaluation.py
    ├── linearization.py
    ├── modeling_bart.py
    ├── optim.py
    ├── penman.py
    ├── postprocessing.py
    ├── tokenization_bart.py
    └── utils.py
/LICENSE:
--------------------------------------------------------------------------------
1 | =======================================================================
2 |
3 | Attribution-NonCommercial-ShareAlike 4.0 International
4 |
5 | =======================================================================
6 |
7 | Creative Commons Corporation ("Creative Commons") is not a law firm and
8 | does not provide legal services or legal advice. Distribution of
9 | Creative Commons public licenses does not create a lawyer-client or
10 | other relationship. Creative Commons makes its licenses and related
11 | information available on an "as-is" basis. Creative Commons gives no
12 | warranties regarding its licenses, any material licensed under their
13 | terms and conditions, or any related information. Creative Commons
14 | disclaims all liability for damages resulting from their use to the
15 | fullest extent possible.
16 |
17 | Using Creative Commons Public Licenses
18 |
19 | Creative Commons public licenses provide a standard set of terms and
20 | conditions that creators and other rights holders may use to share
21 | original works of authorship and other material subject to copyright
22 | and certain other rights specified in the public license below. The
23 | following considerations are for informational purposes only, are not
24 | exhaustive, and do not form part of our licenses.
25 |
26 | Considerations for licensors: Our public licenses are
27 | intended for use by those authorized to give the public
28 | permission to use material in ways otherwise restricted by
29 | copyright and certain other rights. Our licenses are
30 | irrevocable. Licensors should read and understand the terms
31 | and conditions of the license they choose before applying it.
32 | Licensors should also secure all rights necessary before
33 | applying our licenses so that the public can reuse the
34 | material as expected. Licensors should clearly mark any
35 | material not subject to the license. This includes other CC-
36 | licensed material, or material used under an exception or
37 | limitation to copyright. More considerations for licensors:
38 | wiki.creativecommons.org/Considerations_for_licensors
39 |
40 | Considerations for the public: By using one of our public
41 | licenses, a licensor grants the public permission to use the
42 | licensed material under specified terms and conditions. If
43 | the licensor's permission is not necessary for any reason--for
44 | example, because of any applicable exception or limitation to
45 | copyright--then that use is not regulated by the license. Our
46 | licenses grant only permissions under copyright and certain
47 | other rights that a licensor has authority to grant. Use of
48 | the licensed material may still be restricted for other
49 | reasons, including because others have copyright or other
50 | rights in the material. A licensor may make special requests,
51 | such as asking that all changes be marked or described.
52 | Although not required by our licenses, you are encouraged to
53 | respect those requests where reasonable. More considerations
54 | for the public:
55 | wiki.creativecommons.org/Considerations_for_licensees
56 |
57 | =======================================================================
58 |
59 | Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International
60 | Public License
61 |
62 | By exercising the Licensed Rights (defined below), You accept and agree
63 | to be bound by the terms and conditions of this Creative Commons
64 | Attribution-NonCommercial-ShareAlike 4.0 International Public License
65 | ("Public License"). To the extent this Public License may be
66 | interpreted as a contract, You are granted the Licensed Rights in
67 | consideration of Your acceptance of these terms and conditions, and the
68 | Licensor grants You such rights in consideration of benefits the
69 | Licensor receives from making the Licensed Material available under
70 | these terms and conditions.
71 |
72 |
73 | Section 1 -- Definitions.
74 |
75 | a. Adapted Material means material subject to Copyright and Similar
76 | Rights that is derived from or based upon the Licensed Material
77 | and in which the Licensed Material is translated, altered,
78 | arranged, transformed, or otherwise modified in a manner requiring
79 | permission under the Copyright and Similar Rights held by the
80 | Licensor. For purposes of this Public License, where the Licensed
81 | Material is a musical work, performance, or sound recording,
82 | Adapted Material is always produced where the Licensed Material is
83 | synched in timed relation with a moving image.
84 |
85 | b. Adapter's License means the license You apply to Your Copyright
86 | and Similar Rights in Your contributions to Adapted Material in
87 | accordance with the terms and conditions of this Public License.
88 |
89 | c. BY-NC-SA Compatible License means a license listed at
90 | creativecommons.org/compatiblelicenses, approved by Creative
91 | Commons as essentially the equivalent of this Public License.
92 |
93 | d. Copyright and Similar Rights means copyright and/or similar rights
94 | closely related to copyright including, without limitation,
95 | performance, broadcast, sound recording, and Sui Generis Database
96 | Rights, without regard to how the rights are labeled or
97 | categorized. For purposes of this Public License, the rights
98 | specified in Section 2(b)(1)-(2) are not Copyright and Similar
99 | Rights.
100 |
101 | e. Effective Technological Measures means those measures that, in the
102 | absence of proper authority, may not be circumvented under laws
103 | fulfilling obligations under Article 11 of the WIPO Copyright
104 | Treaty adopted on December 20, 1996, and/or similar international
105 | agreements.
106 |
107 | f. Exceptions and Limitations means fair use, fair dealing, and/or
108 | any other exception or limitation to Copyright and Similar Rights
109 | that applies to Your use of the Licensed Material.
110 |
111 | g. License Elements means the license attributes listed in the name
112 | of a Creative Commons Public License. The License Elements of this
113 | Public License are Attribution, NonCommercial, and ShareAlike.
114 |
115 | h. Licensed Material means the artistic or literary work, database,
116 | or other material to which the Licensor applied this Public
117 | License.
118 |
119 | i. Licensed Rights means the rights granted to You subject to the
120 | terms and conditions of this Public License, which are limited to
121 | all Copyright and Similar Rights that apply to Your use of the
122 | Licensed Material and that the Licensor has authority to license.
123 |
124 | j. Licensor means the individual(s) or entity(ies) granting rights
125 | under this Public License.
126 |
127 | k. NonCommercial means not primarily intended for or directed towards
128 | commercial advantage or monetary compensation. For purposes of
129 | this Public License, the exchange of the Licensed Material for
130 | other material subject to Copyright and Similar Rights by digital
131 | file-sharing or similar means is NonCommercial provided there is
132 | no payment of monetary compensation in connection with the
133 | exchange.
134 |
135 | l. Share means to provide material to the public by any means or
136 | process that requires permission under the Licensed Rights, such
137 | as reproduction, public display, public performance, distribution,
138 | dissemination, communication, or importation, and to make material
139 | available to the public including in ways that members of the
140 | public may access the material from a place and at a time
141 | individually chosen by them.
142 |
143 | m. Sui Generis Database Rights means rights other than copyright
144 | resulting from Directive 96/9/EC of the European Parliament and of
145 | the Council of 11 March 1996 on the legal protection of databases,
146 | as amended and/or succeeded, as well as other essentially
147 | equivalent rights anywhere in the world.
148 |
149 | n. You means the individual or entity exercising the Licensed Rights
150 | under this Public License. Your has a corresponding meaning.
151 |
152 |
153 | Section 2 -- Scope.
154 |
155 | a. License grant.
156 |
157 | 1. Subject to the terms and conditions of this Public License,
158 | the Licensor hereby grants You a worldwide, royalty-free,
159 | non-sublicensable, non-exclusive, irrevocable license to
160 | exercise the Licensed Rights in the Licensed Material to:
161 |
162 | a. reproduce and Share the Licensed Material, in whole or
163 | in part, for NonCommercial purposes only; and
164 |
165 | b. produce, reproduce, and Share Adapted Material for
166 | NonCommercial purposes only.
167 |
168 | 2. Exceptions and Limitations. For the avoidance of doubt, where
169 | Exceptions and Limitations apply to Your use, this Public
170 | License does not apply, and You do not need to comply with
171 | its terms and conditions.
172 |
173 | 3. Term. The term of this Public License is specified in Section
174 | 6(a).
175 |
176 | 4. Media and formats; technical modifications allowed. The
177 | Licensor authorizes You to exercise the Licensed Rights in
178 | all media and formats whether now known or hereafter created,
179 | and to make technical modifications necessary to do so. The
180 | Licensor waives and/or agrees not to assert any right or
181 | authority to forbid You from making technical modifications
182 | necessary to exercise the Licensed Rights, including
183 | technical modifications necessary to circumvent Effective
184 | Technological Measures. For purposes of this Public License,
185 | simply making modifications authorized by this Section 2(a)
186 | (4) never produces Adapted Material.
187 |
188 | 5. Downstream recipients.
189 |
190 | a. Offer from the Licensor -- Licensed Material. Every
191 | recipient of the Licensed Material automatically
192 | receives an offer from the Licensor to exercise the
193 | Licensed Rights under the terms and conditions of this
194 | Public License.
195 |
196 | b. Additional offer from the Licensor -- Adapted Material.
197 | Every recipient of Adapted Material from You
198 | automatically receives an offer from the Licensor to
199 | exercise the Licensed Rights in the Adapted Material
200 | under the conditions of the Adapter's License You apply.
201 |
202 | c. No downstream restrictions. You may not offer or impose
203 | any additional or different terms or conditions on, or
204 | apply any Effective Technological Measures to, the
205 | Licensed Material if doing so restricts exercise of the
206 | Licensed Rights by any recipient of the Licensed
207 | Material.
208 |
209 | 6. No endorsement. Nothing in this Public License constitutes or
210 | may be construed as permission to assert or imply that You
211 | are, or that Your use of the Licensed Material is, connected
212 | with, or sponsored, endorsed, or granted official status by,
213 | the Licensor or others designated to receive attribution as
214 | provided in Section 3(a)(1)(A)(i).
215 |
216 | b. Other rights.
217 |
218 | 1. Moral rights, such as the right of integrity, are not
219 | licensed under this Public License, nor are publicity,
220 | privacy, and/or other similar personality rights; however, to
221 | the extent possible, the Licensor waives and/or agrees not to
222 | assert any such rights held by the Licensor to the limited
223 | extent necessary to allow You to exercise the Licensed
224 | Rights, but not otherwise.
225 |
226 | 2. Patent and trademark rights are not licensed under this
227 | Public License.
228 |
229 | 3. To the extent possible, the Licensor waives any right to
230 | collect royalties from You for the exercise of the Licensed
231 | Rights, whether directly or through a collecting society
232 | under any voluntary or waivable statutory or compulsory
233 | licensing scheme. In all other cases the Licensor expressly
234 | reserves any right to collect such royalties, including when
235 | the Licensed Material is used other than for NonCommercial
236 | purposes.
237 |
238 |
239 | Section 3 -- License Conditions.
240 |
241 | Your exercise of the Licensed Rights is expressly made subject to the
242 | following conditions.
243 |
244 | a. Attribution.
245 |
246 | 1. If You Share the Licensed Material (including in modified
247 | form), You must:
248 |
249 | a. retain the following if it is supplied by the Licensor
250 | with the Licensed Material:
251 |
252 | i. identification of the creator(s) of the Licensed
253 | Material and any others designated to receive
254 | attribution, in any reasonable manner requested by
255 | the Licensor (including by pseudonym if
256 | designated);
257 |
258 | ii. a copyright notice;
259 |
260 | iii. a notice that refers to this Public License;
261 |
262 | iv. a notice that refers to the disclaimer of
263 | warranties;
264 |
265 | v. a URI or hyperlink to the Licensed Material to the
266 | extent reasonably practicable;
267 |
268 | b. indicate if You modified the Licensed Material and
269 | retain an indication of any previous modifications; and
270 |
271 | c. indicate the Licensed Material is licensed under this
272 | Public License, and include the text of, or the URI or
273 | hyperlink to, this Public License.
274 |
275 | 2. You may satisfy the conditions in Section 3(a)(1) in any
276 | reasonable manner based on the medium, means, and context in
277 | which You Share the Licensed Material. For example, it may be
278 | reasonable to satisfy the conditions by providing a URI or
279 | hyperlink to a resource that includes the required
280 | information.
281 | 3. If requested by the Licensor, You must remove any of the
282 | information required by Section 3(a)(1)(A) to the extent
283 | reasonably practicable.
284 |
285 | b. ShareAlike.
286 |
287 | In addition to the conditions in Section 3(a), if You Share
288 | Adapted Material You produce, the following conditions also apply.
289 |
290 | 1. The Adapter's License You apply must be a Creative Commons
291 | license with the same License Elements, this version or
292 | later, or a BY-NC-SA Compatible License.
293 |
294 | 2. You must include the text of, or the URI or hyperlink to, the
295 | Adapter's License You apply. You may satisfy this condition
296 | in any reasonable manner based on the medium, means, and
297 | context in which You Share Adapted Material.
298 |
299 | 3. You may not offer or impose any additional or different terms
300 | or conditions on, or apply any Effective Technological
301 | Measures to, Adapted Material that restrict exercise of the
302 | rights granted under the Adapter's License You apply.
303 |
304 |
305 | Section 4 -- Sui Generis Database Rights.
306 |
307 | Where the Licensed Rights include Sui Generis Database Rights that
308 | apply to Your use of the Licensed Material:
309 |
310 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right
311 | to extract, reuse, reproduce, and Share all or a substantial
312 | portion of the contents of the database for NonCommercial purposes
313 | only;
314 |
315 | b. if You include all or a substantial portion of the database
316 | contents in a database in which You have Sui Generis Database
317 | Rights, then the database in which You have Sui Generis Database
318 | Rights (but not its individual contents) is Adapted Material,
319 | including for purposes of Section 3(b); and
320 |
321 | c. You must comply with the conditions in Section 3(a) if You Share
322 | all or a substantial portion of the contents of the database.
323 |
324 | For the avoidance of doubt, this Section 4 supplements and does not
325 | replace Your obligations under this Public License where the Licensed
326 | Rights include other Copyright and Similar Rights.
327 |
328 |
329 | Section 5 -- Disclaimer of Warranties and Limitation of Liability.
330 |
331 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
332 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
333 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
334 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
335 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
336 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
337 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
338 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
339 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
340 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
341 |
342 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
343 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
344 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
345 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
346 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
347 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
348 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
349 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
350 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
351 |
352 | c. The disclaimer of warranties and limitation of liability provided
353 | above shall be interpreted in a manner that, to the extent
354 | possible, most closely approximates an absolute disclaimer and
355 | waiver of all liability.
356 |
357 |
358 | Section 6 -- Term and Termination.
359 |
360 | a. This Public License applies for the term of the Copyright and
361 | Similar Rights licensed here. However, if You fail to comply with
362 | this Public License, then Your rights under this Public License
363 | terminate automatically.
364 |
365 | b. Where Your right to use the Licensed Material has terminated under
366 | Section 6(a), it reinstates:
367 |
368 | 1. automatically as of the date the violation is cured, provided
369 | it is cured within 30 days of Your discovery of the
370 | violation; or
371 |
372 | 2. upon express reinstatement by the Licensor.
373 |
374 | For the avoidance of doubt, this Section 6(b) does not affect any
375 | right the Licensor may have to seek remedies for Your violations
376 | of this Public License.
377 |
378 | c. For the avoidance of doubt, the Licensor may also offer the
379 | Licensed Material under separate terms or conditions or stop
380 | distributing the Licensed Material at any time; however, doing so
381 | will not terminate this Public License.
382 |
383 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
384 | License.
385 |
386 |
387 | Section 7 -- Other Terms and Conditions.
388 |
389 | a. The Licensor shall not be bound by any additional or different
390 | terms or conditions communicated by You unless expressly agreed.
391 |
392 | b. Any arrangements, understandings, or agreements regarding the
393 | Licensed Material not stated herein are separate from and
394 | independent of the terms and conditions of this Public License.
395 |
396 |
397 | Section 8 -- Interpretation.
398 |
399 | a. For the avoidance of doubt, this Public License does not, and
400 | shall not be interpreted to, reduce, limit, restrict, or impose
401 | conditions on any use of the Licensed Material that could lawfully
402 | be made without permission under this Public License.
403 |
404 | b. To the extent possible, if any provision of this Public License is
405 | deemed unenforceable, it shall be automatically reformed to the
406 | minimum extent necessary to make it enforceable. If the provision
407 | cannot be reformed, it shall be severed from this Public License
408 | without affecting the enforceability of the remaining terms and
409 | conditions.
410 |
411 | c. No term or condition of this Public License will be waived and no
412 | failure to comply consented to unless expressly agreed to by the
413 | Licensor.
414 |
415 | d. Nothing in this Public License constitutes or may be interpreted
416 | as a limitation upon, or waiver of, any privileges and immunities
417 | that apply to the Licensor or You, including from the legal
418 | processes of any jurisdiction or authority.
419 |
420 | =======================================================================
421 |
422 | Creative Commons is not a party to its public
423 | licenses. Notwithstanding, Creative Commons may elect to apply one of
424 | its public licenses to material it publishes and in those instances
425 | will be considered the “Licensor.” The text of the Creative Commons
426 | public licenses is dedicated to the public domain under the CC0 Public
427 | Domain Dedication. Except for the limited purpose of indicating that
428 | material is shared under a Creative Commons public license or as
429 | otherwise permitted by the Creative Commons policies published at
430 | creativecommons.org/policies, Creative Commons does not authorize the
431 | use of the trademark "Creative Commons" or any other trademark or logo
432 | of Creative Commons without its prior written consent including,
433 | without limitation, in connection with any unauthorized modifications
434 | to any of its public licenses or any other arrangements,
435 | understandings, or agreements concerning use of licensed material. For
436 | the avoidance of doubt, this paragraph does not form part of the
437 | public licenses.
438 |
439 | Creative Commons may be contacted at creativecommons.org.
440 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SPRING
2 |
3 |
4 | [PwC SOTA: AMR Parsing on LDC2017T10](https://paperswithcode.com/sota/amr-parsing-on-ldc2017t10?p=one-spring-to-rule-them-both-symmetric-amr)
5 |
6 | [PwC SOTA: AMR Parsing on LDC2020T02](https://paperswithcode.com/sota/amr-parsing-on-ldc2020t02?p=one-spring-to-rule-them-both-symmetric-amr)
7 |
8 | [PwC SOTA: AMR-to-Text Generation on LDC2017T10](https://paperswithcode.com/sota/amr-to-text-generation-on-ldc2017t10?p=one-spring-to-rule-them-both-symmetric-amr)
9 |
10 | [PwC SOTA: AMR-to-Text Generation on LDC2020T02](https://paperswithcode.com/sota/amr-to-text-generation-on-ldc2020t02?p=one-spring-to-rule-them-both-symmetric-amr)
11 |
12 | This is the repo for [SPRING (*Symmetric ParsIng aNd Generation*)](https://ojs.aaai.org/index.php/AAAI/article/view/17489), a novel approach to semantic parsing and generation, presented at AAAI 2021.
13 |
14 | With SPRING you can perform both state-of-the-art Text-to-AMR parsing and AMR-to-Text generation without many cumbersome external components.
15 | If you use the code, please reference this work in your paper:
16 |
17 | ```
18 | @inproceedings{bevilacqua-etal-2021-one,
19 | title = {One {SPRING} to Rule Them Both: {S}ymmetric {AMR} Semantic Parsing and Generation without a Complex Pipeline},
20 | author = {Bevilacqua, Michele and Blloshmi, Rexhina and Navigli, Roberto},
21 | booktitle = {Proceedings of AAAI},
22 | year = {2021}
23 | }
24 | ```
25 |
26 | ## Pretrained Checkpoints
27 |
28 | Here we release our best SPRING models, which are based on the DFS linearization.
29 |
30 | ### Text-to-AMR Parsing
31 | - Model trained on the AMR 2.0 training set: [AMR2.parsing-1.0.tar.bz2](http://nlp.uniroma1.it/AMR/AMR2.parsing-1.0.tar.bz2)
32 |
33 | - Model trained on the AMR 3.0 training set: [AMR3.parsing-1.0.tar.bz2](http://nlp.uniroma1.it/AMR/AMR3.parsing-1.0.tar.bz2)
34 |
35 | ### AMR-to-Text Generation
36 | - Model trained on the AMR 2.0 training set: [AMR2.generation-1.0.tar.bz2](http://nlp.uniroma1.it/AMR/AMR2.generation-1.0.tar.bz2)
37 |
38 | - Model trained on the AMR 3.0 training set: [AMR3.generation-1.0.tar.bz2](http://nlp.uniroma1.it/AMR/AMR3.generation-1.0.tar.bz2)
39 |
40 |
41 | If you need the checkpoints of other experiments in the paper, please send us an email.
42 |
43 | ## Installation
44 | ```shell script
45 | cd spring
46 | pip install -r requirements.txt
47 | pip install -e .
48 | ```
49 |
50 | The code only works with `transformers` < 3.0 because of a breaking change in positional embeddings.
51 | The code works fine with `torch` 1.5. We recommend using a fresh `conda` environment.
52 |
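As a quick sanity check of the environment (a minimal sketch; the version constraints are the ones noted above):

```python
# Verify the pinned dependency constraints before training or evaluating.
import torch
import transformers

assert transformers.__version__ < '3', 'SPRING requires transformers < 3.0'
print('torch', torch.__version__, '| transformers', transformers.__version__)
```
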
53 | ## Train
54 | Modify `config.yaml` in `configs`. Instructions are provided as comments within the file. Also see the [appendix](docs/appendix.pdf).
55 |
56 | ### Text-to-AMR
57 | ```shell script
58 | python bin/train.py --config configs/config.yaml --direction amr
59 | ```
60 | Results are written to `runs/`.
61 |
62 | ### AMR-to-Text
63 | ```shell script
64 | python bin/train.py --config configs/config.yaml --direction text
65 | ```
66 | Results are written to `runs/`.
67 |
68 | ## Evaluate
69 | ### Text-to-AMR
70 | ```shell script
71 | python bin/predict_amrs.py \
72 | --datasets /data/amrs/split/test/*.txt \
73 | --gold-path data/tmp/amr2.0/gold.amr.txt \
74 | --pred-path data/tmp/amr2.0/pred.amr.txt \
75 |     --checkpoint runs/<checkpoint>.pt \
76 | --beam-size 5 \
77 | --batch-size 500 \
78 | --device cuda \
79 | --penman-linearization --use-pointer-tokens
80 | ```
81 | `gold.amr.txt` and `pred.amr.txt` will contain, respectively, the concatenated gold graphs and the predicted graphs.
82 |
83 | To reproduce our paper's results, you will also need to run the [BLINK](https://github.com/facebookresearch/BLINK)
84 | entity linking system on the prediction file (`data/tmp/amr2.0/pred.amr.txt` in the previous code snippet).
85 | To do so, you will need to install BLINK and download their models:
86 | ```shell script
87 | git clone https://github.com/facebookresearch/BLINK.git
88 | cd BLINK
89 | pip install -r requirements.txt
90 | sh download_blink_models.sh
91 | cd models
92 | wget http://dl.fbaipublicfiles.com/BLINK//faiss_flat_index.pkl
93 | cd ../..
94 | ```
95 | Then, you will be able to launch the `blinkify.py` script:
96 | ```shell script
97 | python bin/blinkify.py \
98 | --datasets data/tmp/amr2.0/pred.amr.txt \
99 | --out data/tmp/amr2.0/pred.amr.blinkified.txt \
100 | --device cuda \
101 | --blink-models-dir BLINK/models
102 | ```
103 | To have comparable Smatch scores, you will also need to use the scripts available at https://github.com/mdtux89/amr-evaluation, which provide
104 | results that are around 0.3 Smatch points lower than those returned by `bin/predict_amrs.py`.
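
For a quick check from within Python, the repo's own helper can also be used (a minimal sketch; the paths are those from the parsing command above):

```python
# Compare predictions against gold with the repo's Smatch wrapper.
from spring_amr.evaluation import compute_smatch

score = compute_smatch('data/tmp/amr2.0/gold.amr.txt', 'data/tmp/amr2.0/pred.amr.txt')
print(f'Smatch: {score:.3f}')
```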
105 |
106 | ### AMR-to-Text
107 | ```shell script
108 | python bin/predict_sentences.py \
109 | --datasets /data/amrs/split/test/*.txt \
110 | --gold-path data/tmp/amr2.0/gold.text.txt \
111 | --pred-path data/tmp/amr2.0/pred.text.txt \
112 |     --checkpoint runs/<checkpoint>.pt \
113 | --beam-size 5 \
114 | --batch-size 500 \
115 | --device cuda \
116 | --penman-linearization --use-pointer-tokens
117 | ```
118 | `gold.text.txt` and `pred.text.txt` will contain, respectively, the concatenated gold sentences and the predicted sentences.
119 | For BLEU, chrF++, and METEOR scores to be comparable, you will need to tokenize both gold and predictions using the [JAMR tokenizer](https://github.com/redpony/cdec/blob/master/corpus/tokenize-anything.sh).
120 | To compute BLEU and chrF++, please use `bin/eval_bleu.py`. For METEOR, use https://www.cs.cmu.edu/~alavie/METEOR/ .
121 | For BLEURT, do not tokenize; run the evaluation with [BLEURT](https://github.com/google-research/bleurt). Also see the [appendix](docs/appendix.pdf).
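
The following sketch mirrors what `bin/eval_bleu.py` computes with `sacrebleu` (assuming the sacrebleu version pinned in `requirements.txt`; the file paths are placeholders):

```python
import sacrebleu

# JAMR-tokenized, lowercased hypotheses and references.
hyp = [line.strip().lower() for line in open('data/tmp/amr2.0/pred.text.txt')]
ref = [line.strip().lower() for line in open('data/tmp/amr2.0/gold.text.txt')]

bleu = sacrebleu.corpus_bleu(hyp, [ref], smooth_value=0.01, force=True,
                             use_effective_order=False, lowercase=True)
print('BLEU {:.2f}'.format(bleu.score))
chrf = sacrebleu.corpus_chrf(hyp, ref, order=sacrebleu.CHRF_ORDER,
                             beta=sacrebleu.CHRF_BETA, remove_whitespace=True)
print('chrF++ {:.2f}'.format(chrf.score * 100))
```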
122 |
123 | ## Linearizations
124 | The previously shown commands assume the use of the DFS-based linearization. To use BFS or PENMAN, uncomment the relevant lines in `configs/config.yaml` (for training). As for the evaluation scripts, replace `--penman-linearization --use-pointer-tokens` with `--use-pointer-tokens` for BFS, or with `--penman-linearization` for PENMAN.
125 |
126 | ## License
127 | This project is released under the CC-BY-NC-SA 4.0 license (see `LICENSE`). If you use SPRING, please include a link to this repo.
128 |
129 | ## Acknowledgements
130 | The authors gratefully acknowledge the support of the [ERC Consolidator Grant MOUSSE](http://mousse-project.org) No. 726487 and the [ELEXIS project](https://elex.is/) No. 731015 under the European Union’s Horizon 2020 research and innovation programme.
131 |
132 | This work was supported in part by the MIUR under the grant "Dipartimenti di eccellenza 2018-2022" of the Department of Computer Science of the Sapienza University of Rome.
133 |
--------------------------------------------------------------------------------
/bin/blinkify.py:
--------------------------------------------------------------------------------
1 | import blink.main_dense as main_dense
2 | from logging import getLogger
3 | from penman import Triple, Graph
4 | from spring_amr.evaluation import write_predictions
5 | from spring_amr.tokenization_bart import AMRBartTokenizer
6 | import json
7 | from pathlib import Path
8 | from spring_amr.IO import read_raw_amr_data
9 | from spring_amr.entities import read_entities
10 |
11 | if __name__ == '__main__':
12 |
13 | import argparse
14 |
15 | parser = argparse.ArgumentParser()
16 | parser.add_argument('--datasets', nargs='+', required=True)
17 | parser.add_argument('--blink-models-dir', type=str, required=True)
18 | parser.add_argument('--out', type=str, required=True)
19 | parser.add_argument('--device', type=str, default='cuda',
20 |                         help="Device. 'cpu', 'cuda', 'cuda:<n>'.")
21 | parser.add_argument('--all', action='store_true')
22 | parser.add_argument('--fast', action='store_true')
23 | args = parser.parse_args()
24 |
25 | graphs = read_raw_amr_data(args.datasets)
26 | sentences = [g.metadata['snt'] for g in graphs]
27 | for_blink = []
28 | sample_id = 0
29 |
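    # Build one BLINK mention sample per named entity: join its :name ops
    # into a surface string, locate it in the sentence, and record the
    # mention together with its left/right context.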
30 | for sent, (i, with_wikis, name_to_entity, name_to_ops) in zip(sentences, read_entities(sentences, graphs, just_tagged=not args.all)):
31 | for name, parent in name_to_entity.items():
32 | nt, wiki = with_wikis[parent]
33 | ops_triples = name_to_ops[name]
34 | ops_triples = sorted(ops_triples, key=lambda t: t[1])
35 | ops_triples = [t[2].strip('"') for t in ops_triples]
36 | string = ' '.join(ops_triples)
37 | found = string.lower() in sent.lower()
38 | if found:
39 | left = sent.lower().find(string.lower())
40 | right = left + len(string)
41 |
42 | sample = {
43 | "id": sample_id,
44 | "label": "unknown",
45 | "label_id": -1,
46 | "context_left": sent[:left].strip().lower(),
47 | "mention": string.lower(),
48 | "context_right": sent[right:].strip().lower(),
49 | "graph_n": i,
50 | "triple_n": nt,
51 | }
52 | sample_id += 1
53 | for_blink.append(sample)
54 |
55 | main_dense.logger = logger = getLogger('BLINK')
56 |     models_path = args.blink_models_dir.rstrip('/') + '/'  # where the BLINK models are stored; trailing slash needed for the concatenations below
57 |
58 | config = {
59 | "test_entities": None,
60 | "test_mentions": None,
61 | "interactive": False,
62 | "biencoder_model": models_path+"biencoder_wiki_large.bin",
63 | "biencoder_config": models_path+"biencoder_wiki_large.json",
64 | "entity_catalogue": models_path+"entity.jsonl",
65 | "entity_encoding": models_path+"all_entities_large.t7",
66 | "crossencoder_model": models_path+"crossencoder_wiki_large.bin",
67 | "crossencoder_config": models_path+"crossencoder_wiki_large.json",
68 | "top_k": 10,
69 | "show_url": False,
70 | "fast": args.fast, # set this to be true if speed is a concern
71 | "output_path": models_path+"logs/", # logging directory
72 |         "faiss_index": None,  # or "flat"
73 | "index_path": models_path+"faiss_flat_index.pkl",
74 | }
75 |
76 | args_blink = argparse.Namespace(**config)
77 | models = main_dense.load_models(args_blink, logger=logger)
78 | _, _, _, _, _, predictions, scores, = main_dense.run(args_blink, logger, *models, test_data=for_blink, device=args.device)
79 |
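    # Patch each graph with the top BLINK prediction, skipping
    # "List of ..." disambiguation pages; '-' marks entities left unlinked.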
80 | for s, pp in zip(for_blink, predictions):
81 | pp = [p for p in pp if not p.startswith('List of')]
82 | p = f'"{pp[0]}"' if pp else '-'
83 | p = p.replace(' ', '_')
84 | graph_n = s['graph_n']
85 | triple_n = s['triple_n']
86 | triples = [g for g in graphs[graph_n].triples]
87 | n, rel, w = triples[triple_n]
88 | triples[triple_n] = Triple(n, rel, p)
89 | g = Graph(triples)
90 | g.metadata = graphs[graph_n].metadata
91 | graphs[graph_n] = g
92 |
93 |
94 | write_predictions(args.out, AMRBartTokenizer, graphs)
95 |
--------------------------------------------------------------------------------
/bin/eval_bleu.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import argparse
3 | from typing import Iterable, Optional
4 | import sacrebleu
5 | import re
6 |
7 |
8 | def argument_parser():
9 |
10 |     parser = argparse.ArgumentParser(description='Evaluate AMR-to-Text output with BLEU and chrF++')
11 | # Multiple input parameters
12 | parser.add_argument(
13 | "--in-tokens",
14 | help="input tokens",
15 | required=True,
16 | type=str
17 | )
18 | parser.add_argument(
19 | "--in-reference-tokens",
20 |         help="reference tokens to compute metric",
21 | type=str
22 | )
23 | args = parser.parse_args()
24 |
25 | return args
26 |
27 |
28 | def tokenize_sentence(text, debug=False):
29 | text = re.sub(r"('ll|n't|'m|'s|'d|'re)", r" \1", text)
30 | text = re.sub(r"(\s+)", r" ", text)
31 | return text
32 |
33 |
34 | def raw_corpus_bleu(hypothesis: Iterable[str], reference: Iterable[str],
35 | offset: Optional[float] = 0.01) -> float:
36 | bleu = sacrebleu.corpus_bleu(hypothesis, reference, smooth_value=offset,
37 | force=True, use_effective_order=False,
38 | lowercase=True)
39 | return bleu.score
40 |
41 |
42 | def raw_corpus_chrf(hypotheses: Iterable[str],
43 | references: Iterable[str]) -> float:
44 | return sacrebleu.corpus_chrf(hypotheses, references,
45 | order=sacrebleu.CHRF_ORDER,
46 | beta=sacrebleu.CHRF_BETA,
47 | remove_whitespace=True)
48 |
49 | def read_tokens(in_tokens_file):
50 | with open(in_tokens_file) as fid:
51 | lines = fid.readlines()
52 | return lines
53 |
54 |
55 | if __name__ == '__main__':
56 |
57 |     # Argument handling
58 | args = argument_parser()
59 |
60 | # read files
61 | ref = read_tokens(args.in_reference_tokens)
62 | hyp = read_tokens(args.in_tokens)
63 |
64 |     # Lowercase the references
65 | for i in range(len(ref)):
66 | ref[i] = ref[i].lower()
67 |
68 |     # Lowercase and tokenize the hypotheses
69 |     for i in range(len(hyp)):
70 |         if '</s>' in hyp[i]:  # the original separator literal was lost; '</s>' is assumed here
71 |             hyp[i] = hyp[i].split('</s>')[-1]
72 |         hyp[i] = tokenize_sentence(hyp[i].lower())
73 |
74 | # results
75 |
76 | bleu = raw_corpus_bleu(hyp, [ref])
77 | print('BLEU {:.2f}'.format(bleu))
78 | chrFpp = raw_corpus_chrf(hyp, ref).score * 100
79 | print('chrF++ {:.2f}'.format(chrFpp))
--------------------------------------------------------------------------------
/bin/inspect_.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import penman
3 | from spring_amr.utils import instantiate_model_and_tokenizer
4 |
5 | if __name__ == '__main__':
6 |
7 | from argparse import ArgumentParser
8 | parser = ArgumentParser()
9 | parser.add_argument('--checkpoint', type=str, required=True)
10 | parser.add_argument('--beam-size', type=int, default=1)
11 | parser.add_argument('--device', type=str, default='cpu')
12 | parser.add_argument('--penman-linearization', action='store_true',
13 | help="Predict using PENMAN linearization instead of ours.")
14 | parser.add_argument('--use-pointer-tokens', action='store_true')
15 | parser.add_argument('--restore-name-ops', action='store_true')
16 | args = parser.parse_args()
17 |
18 | device = torch.device(args.device)
19 | model, tokenizer = instantiate_model_and_tokenizer(
20 | name='facebook/bart-large',
21 | checkpoint=args.checkpoint,
22 | dropout=0., attention_dropout=0.,
23 | penman_linearization=args.penman_linearization,
24 | use_pointer_tokens=args.use_pointer_tokens,
25 | )
26 | model.eval().to(device)
27 |
28 | while True:
29 | sentence = [input('Sentence to parse:\n')]
30 | x, extra = tokenizer.batch_encode_sentences(sentence, device)
31 | with torch.no_grad():
32 | out = model.generate(**x, max_length=1024, decoder_start_token_id=0, num_beams=args.beam_size)
33 | out = out[0].tolist()
34 | graph, status, (lin, backr) = tokenizer.decode_amr(out, restore_name_ops=args.restore_name_ops)
35 | print('-' * 5)
36 | print('Status:', status)
37 | print('-' * 5)
38 | print('Graph:')
39 | print(penman.encode(graph))
40 | print('-' * 5)
41 | print('Linearization:')
42 | print(lin)
43 | print('\n')
44 |
--------------------------------------------------------------------------------
/bin/patch_legacy_checkpoint.py:
--------------------------------------------------------------------------------
1 | if __name__ == '__main__':
2 |
3 | from argparse import ArgumentParser
4 | import torch
5 |
6 | parser = ArgumentParser()
7 | parser.add_argument('legacy_checkpoint')
8 | parser.add_argument('patched_checkpoint')
9 |
10 |
11 | args = parser.parse_args()
12 |
13 | to_remove = []
14 |
15 | fixed = False
16 | w = torch.load(args.legacy_checkpoint, map_location='cpu')
17 | for name in w['model']:
18 | if 'backreferences' in name:
19 | fixed = True
20 | to_remove.append(name)
21 | print('Deleting parameters:', name)
22 |
23 | if not fixed:
24 | print('The checkpoint was fine as it was!')
25 | else:
26 | for name in to_remove:
27 | del w['model'][name]
28 | torch.save(w, args.patched_checkpoint)
29 |
--------------------------------------------------------------------------------
/bin/predict_amrs.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | import penman
4 | import torch
5 |
6 | from spring_amr import ROOT
7 | from spring_amr.evaluation import predict_amrs, compute_smatch
8 | from spring_amr.penman import encode
9 | from spring_amr.utils import instantiate_loader, instantiate_model_and_tokenizer
10 |
11 | if __name__ == '__main__':
12 |
13 | from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
14 |
15 | parser = ArgumentParser(
16 | description="Script to predict AMR graphs given sentences. LDC format as input.",
17 | formatter_class=ArgumentDefaultsHelpFormatter,
18 | )
19 | parser.add_argument('--datasets', type=str, required=True, nargs='+',
20 | help="Required. One or more glob patterns to use to load amr files.")
21 | parser.add_argument('--checkpoint', type=str, required=True,
22 | help="Required. Checkpoint to restore.")
23 | parser.add_argument('--model', type=str, default='facebook/bart-large',
24 | help="Model config to use to load the model class.")
25 | parser.add_argument('--beam-size', type=int, default=1,
26 | help="Beam size.")
27 | parser.add_argument('--batch-size', type=int, default=1000,
28 | help="Batch size (as number of linearized graph tokens per batch).")
29 | parser.add_argument('--device', type=str, default='cuda',
30 |                         help="Device. 'cpu', 'cuda', 'cuda:<n>'.")
31 | parser.add_argument('--pred-path', type=Path, default=ROOT / 'data/tmp/inf-pred.txt',
32 | help="Where to write predictions.")
33 | parser.add_argument('--gold-path', type=Path, default=ROOT / 'data/tmp/inf-gold.txt',
34 | help="Where to write the gold file.")
35 | parser.add_argument('--use-recategorization', action='store_true',
36 | help="Predict using Zhang recategorization on top of our linearization (requires recategorized sentences in input).")
37 | parser.add_argument('--penman-linearization', action='store_true',
38 | help="Predict using PENMAN linearization instead of ours.")
39 | parser.add_argument('--use-pointer-tokens', action='store_true')
40 | parser.add_argument('--raw-graph', action='store_true')
41 | parser.add_argument('--restore-name-ops', action='store_true')
42 | parser.add_argument('--return-all', action='store_true')
43 |
44 | args = parser.parse_args()
45 |
46 | device = torch.device(args.device)
47 | model, tokenizer = instantiate_model_and_tokenizer(
48 | args.model,
49 | dropout=0.,
50 | attention_dropout=0.,
51 | penman_linearization=args.penman_linearization,
52 | use_pointer_tokens=args.use_pointer_tokens,
53 | raw_graph=args.raw_graph,
54 | )
55 | model.amr_mode = True
56 | model.load_state_dict(torch.load(args.checkpoint, map_location='cpu')['model'])
57 | model.to(device)
58 |
59 | gold_path = args.gold_path
60 | pred_path = args.pred_path
61 | loader = instantiate_loader(
62 | args.datasets,
63 | tokenizer,
64 | batch_size=args.batch_size,
65 | evaluation=True, out=gold_path,
66 | use_recategorization=args.use_recategorization,
67 | )
68 | loader.device = device
69 |
70 | graphs = predict_amrs(
71 | loader,
72 | model,
73 | tokenizer,
74 | beam_size=args.beam_size,
75 | restore_name_ops=args.restore_name_ops,
76 | return_all=args.return_all,
77 | )
78 | if args.return_all:
79 | graphs = [g for gg in graphs for g in gg]
80 |
81 | pieces = [encode(g) for g in graphs]
82 | pred_path.write_text('\n\n'.join(pieces))
83 |
84 | if not args.return_all:
85 | score = compute_smatch(gold_path, pred_path)
86 | print(f'Smatch: {score:.3f}')
87 |
--------------------------------------------------------------------------------
/bin/predict_amrs_from_plaintext.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | import penman
4 | import torch
5 | from tqdm import tqdm
6 |
7 | from spring_amr.penman import encode
8 | from spring_amr.utils import instantiate_model_and_tokenizer
9 |
10 | def read_file_in_batches(path, batch_size=1000, max_length=100):
11 |
12 | data = []
13 | idx = 0
14 | for line in Path(path).read_text().strip().splitlines():
15 | line = line.strip()
16 | if not line:
17 | continue
18 | n = len(line.split())
19 | if n > max_length:
20 | continue
21 | data.append((idx, line, n))
22 | idx += 1
23 |
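    # Token-budget batching: sort sentences by length (descending) and emit
    # a batch whenever adding the next one would push max_len * n_items past
    # `batch_size`; over-long sentences are yielded as singleton batches.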
24 | def _iterator(data):
25 |
26 | data = sorted(data, key=lambda x: x[2], reverse=True)
27 |
28 | maxn = 0
29 | batch = []
30 |
31 | for sample in data:
32 | idx, line, n = sample
33 | if n > batch_size:
34 | if batch:
35 | yield batch
36 | maxn = 0
37 | batch = []
38 | yield [sample]
39 | else:
40 | curr_batch_size = maxn * len(batch)
41 | cand_batch_size = max(maxn, n) * (len(batch) + 1)
42 |
43 | if 0 < curr_batch_size <= batch_size and cand_batch_size > batch_size:
44 | yield batch
45 | maxn = 0
46 | batch = []
47 | maxn = max(maxn, n)
48 | batch.append(sample)
49 |
50 | if batch:
51 | yield batch
52 |
53 | return _iterator(data), len(data)
54 |
55 | if __name__ == '__main__':
56 |
57 | from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
58 |
59 | parser = ArgumentParser(
60 | description="Script to predict AMR graphs given sentences. LDC format as input.",
61 | formatter_class=ArgumentDefaultsHelpFormatter,
62 | )
63 | parser.add_argument('--texts', type=str, required=True, nargs='+',
64 | help="Required. One or more files containing \\n-separated sentences.")
65 | parser.add_argument('--checkpoint', type=str, required=True,
66 | help="Required. Checkpoint to restore.")
67 | parser.add_argument('--model', type=str, default='facebook/bart-large',
68 | help="Model config to use to load the model class.")
69 | parser.add_argument('--beam-size', type=int, default=1,
70 | help="Beam size.")
71 | parser.add_argument('--batch-size', type=int, default=1000,
72 | help="Batch size (as number of linearized graph tokens per batch).")
73 | parser.add_argument('--penman-linearization', action='store_true',
74 | help="Predict using PENMAN linearization instead of ours.")
75 | parser.add_argument('--use-pointer-tokens', action='store_true')
76 | parser.add_argument('--restore-name-ops', action='store_true')
77 | parser.add_argument('--device', type=str, default='cuda',
78 |                         help="Device. 'cpu', 'cuda', 'cuda:<n>'.")
79 | parser.add_argument('--only-ok', action='store_true')
80 | args = parser.parse_args()
81 |
82 | device = torch.device(args.device)
83 | model, tokenizer = instantiate_model_and_tokenizer(
84 | args.model,
85 | dropout=0.,
86 | attention_dropout=0,
87 | penman_linearization=args.penman_linearization,
88 | use_pointer_tokens=args.use_pointer_tokens,
89 | )
90 | model.load_state_dict(torch.load(args.checkpoint, map_location='cpu')['model'])
91 | model.to(device)
92 | model.eval()
93 |
94 | for path in tqdm(args.texts, desc='Files:'):
95 |
96 | iterator, nsent = read_file_in_batches(path, args.batch_size)
97 |
98 | with tqdm(desc=path, total=nsent) as bar:
99 | for batch in iterator:
100 | if not batch:
101 | continue
102 | ids, sentences, _ = zip(*batch)
103 | x, _ = tokenizer.batch_encode_sentences(sentences, device=device)
104 | with torch.no_grad():
105 | model.amr_mode = True
106 | out = model.generate(**x, max_length=512, decoder_start_token_id=0, num_beams=args.beam_size)
107 |
108 | bgraphs = []
109 | for idx, sent, tokk in zip(ids, sentences, out):
110 | graph, status, (lin, backr) = tokenizer.decode_amr(tokk.tolist(), restore_name_ops=args.restore_name_ops)
111 | if args.only_ok and ('OK' not in str(status)):
112 | continue
113 | graph.metadata['status'] = str(status)
114 | graph.metadata['source'] = path
115 | graph.metadata['nsent'] = str(idx)
116 | graph.metadata['snt'] = sent
117 | bgraphs.append((idx, graph))
118 |
119 | for i, g in bgraphs:
120 | print(encode(g))
121 | print()
122 |
123 | # if bgraphs and args.reverse:
124 | # bgraphs = [x[1] for x in bgraphs]
125 | # x, _ = tokenizer.batch_encode_graphs(bgraphs, device)
126 | # x = torch.cat([x['decoder_input_ids'], x['lm_labels'][:, -1:]], 1)
127 | # att = torch.ones_like(x)
128 | # att[att == tokenizer.pad_token_id] = 0
129 | # x = {
130 | # 'input_ids': x,
131 | # #'attention_mask': att,
132 | # }
133 | # with torch.no_grad():
134 | # model.amr_mode = False
135 | # out = model.generate(**x, max_length=1024, decoder_start_token_id=0, num_beams=args.beam_size)
136 | #
137 | # for graph, tokk in zip(bgraphs, out):
138 | # tokk = [t for t in tokk.tolist() if t > 2]
139 | # graph.metadata['snt-pred'] = tokenizer.decode(tokk).strip()
140 | bar.update(len(sentences))
141 |
142 | exit(0)
143 |
144 |     # NOTE: unreachable (after exit(0) above); `results` is never populated.
145 |     # ids, graphs = zip(*sorted(results, key=lambda x: x[0]))
146 |     # for g in graphs:
147 |     #     print(encode(g))
148 |     #     print()
149 |
--------------------------------------------------------------------------------
/bin/predict_sentences.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | import penman
4 | import torch
5 |
6 | from spring_amr import ROOT
7 | from spring_amr.evaluation import predict_amrs, compute_smatch, predict_sentences, compute_bleu
8 | from spring_amr.penman import encode
9 | from spring_amr.utils import instantiate_loader, instantiate_model_and_tokenizer
10 |
11 | if __name__ == '__main__':
12 |
13 | from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
14 |
15 | parser = ArgumentParser(
16 | description="Script to predict AMR graphs given sentences. LDC format as input.",
17 | formatter_class=ArgumentDefaultsHelpFormatter,
18 | )
19 | parser.add_argument('--datasets', type=str, required=True, nargs='+',
20 | help="Required. One or more glob patterns to use to load amr files.")
21 | parser.add_argument('--checkpoint', type=str, required=True,
22 | help="Required. Checkpoint to restore.")
23 | parser.add_argument('--model', type=str, default='facebook/bart-large',
24 | help="Model config to use to load the model class.")
25 | parser.add_argument('--beam-size', type=int, default=1,
26 | help="Beam size.")
27 | parser.add_argument('--batch-size', type=int, default=1000,
28 | help="Batch size (as number of linearized graph tokens per batch).")
29 | parser.add_argument('--device', type=str, default='cuda',
30 |                         help="Device. 'cpu', 'cuda', 'cuda:<n>'.")
31 | parser.add_argument('--pred-path', type=Path, default=ROOT / 'data/tmp/inf-pred-sentences.txt',
32 | help="Where to write predictions.")
33 | parser.add_argument('--gold-path', type=Path, default=ROOT / 'data/tmp/inf-gold-sentences.txt',
34 | help="Where to write the gold file.")
35 | parser.add_argument('--add-to-graph-file', action='store_true')
36 | parser.add_argument('--use-reverse-decoder', action='store_true')
37 | parser.add_argument('--deinvert', action='store_true')
38 | parser.add_argument('--penman-linearization', action='store_true',
39 | help="Predict using PENMAN linearization instead of ours.")
40 | parser.add_argument('--collapse-name-ops', action='store_true')
41 | parser.add_argument('--use-pointer-tokens', action='store_true')
42 | parser.add_argument('--raw-graph', action='store_true')
43 | parser.add_argument('--return-all', action='store_true')
44 | args = parser.parse_args()
45 |
46 | device = torch.device(args.device)
47 | model, tokenizer = instantiate_model_and_tokenizer(
48 | args.model,
49 | dropout=0.,
50 | attention_dropout=0.,
51 | penman_linearization=args.penman_linearization,
52 | use_pointer_tokens=args.use_pointer_tokens,
53 | collapse_name_ops=args.collapse_name_ops,
54 | init_reverse=args.use_reverse_decoder,
55 | raw_graph=args.raw_graph,
56 | )
57 | model.load_state_dict(torch.load(args.checkpoint, map_location='cpu')['model'])
58 | model.to(device)
59 | model.rev.amr_mode = False
60 |
61 | loader = instantiate_loader(
62 | args.datasets,
63 | tokenizer,
64 | batch_size=args.batch_size,
65 | evaluation=True, out='/tmp/a.txt',
66 | dereify=args.deinvert)
67 | loader.device = device
68 |
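    # Either write predictions back into the graph file as `snt-pred`
    # metadata, or dump gold and predicted sentences to separate files
    # and report corpus BLEU.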
69 | pred_sentences = predict_sentences(loader, model.rev, tokenizer, beam_size=args.beam_size, return_all=args.return_all)
70 | if args.add_to_graph_file:
71 | graphs = loader.dataset.graphs
72 | for ss, g in zip(pred_sentences, graphs):
73 | if args.return_all:
74 | g.metadata['snt-pred'] = '\t\t'.join(ss)
75 | else:
76 | g.metadata['snt-pred'] = ss
77 | args.pred_path.write_text('\n\n'.join([encode(g) for g in graphs]))
78 | else:
79 | if args.return_all:
80 | pred_sentences = [s for ss in pred_sentences for s in ss]
81 | args.gold_path.write_text('\n'.join(loader.dataset.sentences))
82 | args.pred_path.write_text('\n'.join(pred_sentences))
83 | if not args.return_all:
84 | score = compute_bleu(loader.dataset.sentences, pred_sentences)
85 | print(f'BLEU: {score.score:.2f}')
86 |
--------------------------------------------------------------------------------
/bin/train.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | import torch
4 | try:
5 | from torch.cuda.amp import autocast
6 | autocast_available = True
7 | except ImportError:
8 | class autocast:
9 | def __init__(self, enabled=True): pass
10 | def __enter__(self): return self
11 | def __exit__(self, exc_type, exc_value, exc_traceback): pass
12 | autocast_available = False
13 |
14 | from torch.cuda.amp.grad_scaler import GradScaler
15 | import transformers
16 |
17 | from spring_amr import ROOT
18 | from spring_amr.dataset import reverse_direction
19 | from spring_amr.optim import RAdam
20 | from spring_amr.evaluation import write_predictions, compute_smatch, predict_amrs, predict_sentences, compute_bleu
21 | from spring_amr.utils import instantiate_model_and_tokenizer, instantiate_loader
22 |
23 | from ignite.engine import Engine, Events
24 | from ignite.metrics import RunningAverage
25 | from ignite.handlers import ModelCheckpoint, global_step_from_engine
26 |
27 | def do_train(checkpoint=None, direction='amr', split_both_decoder=False, fp16=False):
28 |
29 | assert direction in ('amr', 'text', 'both')
30 |
31 | model, tokenizer = instantiate_model_and_tokenizer(
32 | config['model'],
33 | checkpoint=checkpoint,
34 | additional_tokens_smart_init=config['smart_init'],
35 | dropout=config['dropout'],
36 | attention_dropout=config['attention_dropout'],
37 | from_pretrained=config['warm_start'],
38 | init_reverse=split_both_decoder,
39 | penman_linearization=config['penman_linearization'],
40 | collapse_name_ops=config['collapse_name_ops'],
41 | use_pointer_tokens=config['use_pointer_tokens'],
42 | raw_graph=config.get('raw_graph', False)
43 | )
44 |
45 | print(model)
46 | print(model.config)
47 |
48 | if checkpoint is not None:
49 | print(f'Checkpoint restored ({checkpoint})!')
50 |
51 | if direction == 'both' and split_both_decoder:
52 | params_dir_enc = list(model.model.encoder.parameters())
53 | params_dir_enc_check = {id(p) for p in params_dir_enc}
54 | params_dir_dec = set()
55 | params_dir_dec |= {p for p in model.model.decoder.parameters() if id(p) not in params_dir_enc_check}
56 | params_dir_dec |= {p for p in model.rev.model.decoder.parameters() if id(p) not in params_dir_enc_check}
57 | params_dir_dec = list(params_dir_dec)
58 | optimizer = RAdam(
59 | [{'params': params_dir_enc, 'lr': config['learning_rate']},
60 | {'params': params_dir_dec, 'lr': config['learning_rate'] * 2},],
61 | weight_decay=config['weight_decay'])
62 | else:
63 | optimizer = RAdam(
64 | model.parameters(),
65 | lr=config['learning_rate'],
66 | weight_decay=config['weight_decay'])
67 | if checkpoint is not None:
68 | optimizer.load_state_dict(torch.load(checkpoint)['optimizer'])
69 |
70 | if config['scheduler'] == 'cosine':
71 | scheduler = transformers.get_cosine_schedule_with_warmup(
72 | optimizer,
73 | num_warmup_steps=config['warmup_steps'],
74 | num_training_steps=config['training_steps'])
75 | elif config['scheduler'] == 'constant':
76 | scheduler = transformers.get_constant_schedule_with_warmup(
77 | optimizer,
78 | num_warmup_steps=config['warmup_steps'])
79 | else:
80 | raise ValueError
81 |
82 | scaler = GradScaler(enabled=fp16)
83 |
84 | train_loader = instantiate_loader(
85 | config['train'],
86 | tokenizer,
87 | batch_size=config['batch_size'],
88 | evaluation=False,
89 | use_recategorization=config['use_recategorization'],
90 | remove_longer_than=config['remove_longer_than'],
91 | remove_wiki=config['remove_wiki'],
92 | dereify=config['dereify'],
93 | )
94 |
95 | dev_gold_path = ROOT / 'data/tmp/dev-gold.txt'
96 | dev_pred_path = ROOT / 'data/tmp/dev-pred.txt'
97 | dev_loader = instantiate_loader(
98 | config['dev'],
99 | tokenizer,
100 | batch_size=config['batch_size'],
101 | evaluation=True, out=dev_gold_path,
102 | use_recategorization=config['use_recategorization'],
103 | remove_wiki=config['remove_wiki'],
104 | dereify=config['dereify'],
105 | )
106 |
107 | if direction == 'amr':
108 |
109 | def train_step(engine, batch):
110 | model.train()
111 | x, y, extra = batch
112 | model.amr_mode = True
113 | with autocast(enabled=fp16):
114 | loss, *_ = model(**x, **y)
115 | scaler.scale((loss / config['accum_steps'])).backward()
116 | return loss.item()
117 |
118 | @torch.no_grad()
119 | def eval_step(engine, batch):
120 | model.eval()
121 | x, y, extra = batch
122 | model.amr_mode = True
123 | loss, *_ = model(**x, **y)
124 | return loss.item()
125 |
126 | elif direction == 'text':
127 |
128 | def train_step(engine, batch):
129 | model.train()
130 | x, y, extra = batch
131 | x, y = reverse_direction(x, y)
132 | model.rev.amr_mode = False
133 | with autocast(enabled=fp16):
134 | loss, *_ = model.rev(**x, **y)
135 | scaler.scale((loss / config['accum_steps'])).backward()
136 | return loss.item()
137 |
138 | @torch.no_grad()
139 | def eval_step(engine, batch):
140 | model.eval()
141 | x, y, extra = batch
142 | x, y = reverse_direction(x, y)
143 | model.rev.amr_mode = False
144 |             loss, *_ = model.rev(**x, **y)  # use the reverse (text) decoder, as in train_step
145 | return loss.item()
146 |
147 | elif direction == 'both':
148 |
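        # Bidirectional step: backpropagate both directions on the same
        # batch, halving each loss so the combined gradient matches the
        # single-direction scale.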
149 | def train_step(engine, batch):
150 | model.train()
151 | x, y, extra = batch
152 | model.amr_mode = True
153 | with autocast(enabled=fp16):
154 | loss1, *_ = model(**x, **y)
155 | scaler.scale((loss1 / config['accum_steps'] * 0.5)).backward()
156 | loss1 = loss1.item()
157 | x, y = reverse_direction(x, y)
158 | model.rev.amr_mode = False
159 | with autocast(enabled=fp16):
160 | loss2, *_ = model.rev(**x, **y)
161 | scaler.scale((loss2 / config['accum_steps'] * 0.5)).backward()
162 | return loss1, loss2.item()
163 |
164 | @torch.no_grad()
165 | def eval_step(engine, batch):
166 | model.eval()
167 | x, y, extra = batch
168 | model.amr_mode = True
169 | loss1, *_ = model(**x, **y)
170 | x, y = reverse_direction(x, y)
171 | model.rev.amr_mode = False
172 | loss2, *_ = model.rev(**x, **y)
173 | return loss1.item(), loss2.item()
174 |
175 | else:
176 | raise ValueError
177 |
178 | trainer = Engine(train_step)
179 | evaluator = Engine(eval_step)
180 |
181 | @trainer.on(Events.STARTED)
182 | def update(engine):
183 | print('training started!')
184 |
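    # Gradient-accumulation boundary: every `accum_steps` iterations (and at
    # epoch end) unscale for clipping, step the optimizer via the GradScaler,
    # reset gradients, and advance the LR scheduler.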
185 | @trainer.on(Events.EPOCH_COMPLETED)
186 | @trainer.on(Events.ITERATION_COMPLETED(every=config['accum_steps']))
187 | def update(engine):
188 | scaler.unscale_(optimizer)
189 | torch.nn.utils.clip_grad_norm_(model.parameters(), config['grad_norm'])
190 | scaler.step(optimizer)
191 | scaler.update()
192 | optimizer.zero_grad()
193 | scheduler.step()
194 |
195 | @trainer.on(Events.EPOCH_COMPLETED)
196 | def log_trn_loss(engine):
197 | log_msg = f"training epoch: {engine.state.epoch}"
198 | if direction in ('amr', 'both'):
199 | log_msg += f" | loss_amr: {engine.state.metrics['trn_amr_loss']:.3f}"
200 | if direction in ('text', 'both'):
201 | log_msg += f" | loss_text: {engine.state.metrics['trn_text_loss']:.3f}"
202 | print(log_msg)
203 |
204 | @trainer.on(Events.EPOCH_COMPLETED)
205 | def run_dev_eval(engine):
206 | dev_loader.batch_size = config['batch_size']
207 | dev_loader.device = next(model.parameters()).device
208 | evaluator.run(dev_loader)
209 |
210 | if not config['best_loss']:
211 | if direction in ('amr', 'both'):
212 | @evaluator.on(Events.EPOCH_COMPLETED)
213 | def smatch_eval(engine):
214 | device = next(model.parameters()).device
215 | dev_loader.device = device
216 | graphs = predict_amrs(dev_loader, model, tokenizer, restore_name_ops=config['collapse_name_ops'])
217 | write_predictions(dev_pred_path, tokenizer, graphs)
218 | try:
219 | smatch = compute_smatch(dev_gold_path, dev_pred_path)
220 | except Exception:
221 | smatch = 0.
222 | engine.state.metrics['dev_smatch'] = smatch
223 |
224 | if direction in ('text', 'both'):
225 | @evaluator.on(Events.EPOCH_COMPLETED)
226 | def bleu_eval(engine):
227 | device = next(model.parameters()).device
228 | dev_loader.device = device
229 | pred_sentences = predict_sentences(dev_loader, model.rev, tokenizer, beam_size=config['beam_size'])
230 | bleu = compute_bleu(dev_loader.dataset.sentences, pred_sentences)
231 | engine.state.metrics['dev_bleu'] = bleu.score
232 |
233 | @evaluator.on(Events.EPOCH_COMPLETED)
234 | def log_dev_loss(engine):
235 | log_msg = f"dev epoch: {trainer.state.epoch}"
236 | if direction in ('amr', 'both'):
237 | log_msg += f" | loss_amr: {engine.state.metrics['dev_amr_loss']:.3f}"
238 | if not config['best_loss']:
239 | log_msg += f" | smatch: {engine.state.metrics['dev_smatch']:.3f}"
240 | if direction in ('text', 'both'):
241 | log_msg += f" | loss_text: {engine.state.metrics['dev_text_loss']:.3f}"
242 | if not config['best_loss']:
243 | log_msg += f" | bleu: {engine.state.metrics['dev_bleu']:.3f}"
244 | print(log_msg)
245 |
246 | if direction == 'amr':
247 | RunningAverage(output_transform=lambda out: out).attach(trainer, 'trn_amr_loss')
248 | RunningAverage(output_transform=lambda out: out).attach(evaluator, 'dev_amr_loss')
249 | elif direction == 'text':
250 | RunningAverage(output_transform=lambda out: out).attach(trainer, 'trn_text_loss')
251 | RunningAverage(output_transform=lambda out: out).attach(evaluator, 'dev_text_loss')
252 | elif direction == 'both':
253 | RunningAverage(output_transform=lambda out: out[0]).attach(trainer, 'trn_amr_loss')
254 | RunningAverage(output_transform=lambda out: out[1]).attach(trainer, 'trn_text_loss')
255 | RunningAverage(output_transform=lambda out: out[0]).attach(evaluator, 'dev_amr_loss')
256 | RunningAverage(output_transform=lambda out: out[1]).attach(evaluator, 'dev_text_loss')
257 |
258 |
259 | if config['log_wandb']:
260 | from ignite.contrib.handlers.wandb_logger import WandBLogger
261 | wandb_logger = WandBLogger(init=False)
262 |
263 | if direction == 'amr':
264 | wandb_logger.attach_output_handler(
265 | trainer,
266 | event_name=Events.ITERATION_COMPLETED,
267 | tag="iterations/trn_amr_loss",
268 | output_transform=lambda loss: loss
269 | )
270 | elif direction == 'text':
271 | wandb_logger.attach_output_handler(
272 | trainer,
273 | event_name=Events.ITERATION_COMPLETED,
274 | tag="iterations/trn_text_loss",
275 | output_transform=lambda loss: loss
276 | )
277 | elif direction == 'both':
278 | wandb_logger.attach_output_handler(
279 | trainer,
280 | event_name=Events.ITERATION_COMPLETED,
281 | tag="iterations/trn_amr_loss",
282 | output_transform=lambda loss: loss[0]
283 | )
284 | wandb_logger.attach_output_handler(
285 | trainer,
286 | event_name=Events.ITERATION_COMPLETED,
287 | tag="iterations/trn_text_loss",
288 | output_transform=lambda loss: loss[1]
289 | )
290 |
291 | if direction == 'amr':
292 | metric_names_trn = ['trn_amr_loss']
293 | metric_names_dev = ['dev_amr_loss']
294 | if not config['best_loss']:
295 | metric_names_dev.append('dev_smatch')
296 | elif direction == 'text':
297 | metric_names_trn = ['trn_text_loss']
298 | metric_names_dev = ['dev_text_loss']
299 | if not config['best_loss']:
300 | metric_names_dev.append('dev_bleu')
301 | elif direction == 'both':
302 | metric_names_trn = ['trn_amr_loss', 'trn_text_loss']
303 | metric_names_dev = ['dev_amr_loss', 'dev_text_loss']
304 | if not config['best_loss']:
305 | metric_names_dev.extend(['dev_smatch', 'dev_bleu'])
306 |
307 | wandb_logger.attach_output_handler(
308 | trainer,
309 | event_name=Events.EPOCH_COMPLETED,
310 | tag="epochs",
311 | metric_names=metric_names_trn,
312 | global_step_transform=lambda *_: trainer.state.iteration,
313 | )
314 |
315 | wandb_logger.attach_output_handler(
316 | evaluator,
317 | event_name=Events.EPOCH_COMPLETED,
318 | tag="epochs",
319 | metric_names=metric_names_dev,
320 | global_step_transform=lambda *_: trainer.state.iteration,
321 | )
322 |
323 | @trainer.on(Events.ITERATION_COMPLETED)
324 | def wandb_log_lr(engine):
325 | wandb.log({'lr': scheduler.get_last_lr()[0]}, step=engine.state.iteration)
326 |
327 | if config['save_checkpoints']:
328 |
329 | if direction in ('amr', 'both'):
330 | if config['best_loss']:
331 | prefix = 'best-loss-amr'
332 | score_function = lambda x: 1 / evaluator.state.metrics['dev_amr_loss']
333 | else:
334 | prefix = 'best-smatch'
335 | score_function = lambda x: evaluator.state.metrics['dev_smatch']
336 | else:
337 | if config['best_loss']:
338 | prefix = 'best-loss-text'
339 | score_function = lambda x: 1 / evaluator.state.metrics['dev_text_loss']
340 | else:
341 | prefix = 'best-bleu'
342 | score_function = lambda x: evaluator.state.metrics['dev_bleu']
343 |
344 | to_save = {'model': model, 'optimizer': optimizer}
345 | if config['log_wandb']:
346 | where_checkpoints = str(wandb_logger.run.dir)
347 | else:
348 | root = ROOT/'runs'
349 | try:
350 | root.mkdir()
351 | except FileExistsError:
352 | pass
353 | where_checkpoints = root/str(len(list(root.iterdir())))
354 | try:
355 | where_checkpoints.mkdir()
356 | except FileExistsError:
357 | pass
358 | where_checkpoints = str(where_checkpoints)
359 |
360 | print(where_checkpoints)
361 | handler = ModelCheckpoint(
362 | where_checkpoints,
363 | prefix,
364 | n_saved=1,
365 | create_dir=True,
366 | score_function=score_function,
367 | global_step_transform=global_step_from_engine(trainer),
368 | )
369 | evaluator.add_event_handler(Events.EPOCH_COMPLETED, handler, to_save)
370 |
371 | model.cuda()
372 | device = next(model.parameters()).device
373 | train_loader.device = device
374 | trainer.run(train_loader, max_epochs=config['max_epochs'])
375 |
376 | if __name__ == '__main__':
377 |
378 | from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
379 | import yaml
380 |
381 | import wandb
382 |
383 | parser = ArgumentParser(
384 | description="Trainer script",
385 | formatter_class=ArgumentDefaultsHelpFormatter,
386 | )
387 | parser.add_argument('--direction', type=str, default='amr', choices=['amr', 'text', 'both'],
388 | help='Train a unidirectional (amr, text) or a bidirectional (both) model.')
389 | parser.add_argument('--split-both-decoder', action='store_true')
390 | parser.add_argument('--config', type=Path, default=ROOT/'configs/config.yaml',
391 | help='Path to the YAML config with the hyperparameters.')
392 | parser.add_argument('--checkpoint', type=str,
393 | help='Warm-start from a previous fine-tuned checkpoint.')
394 | parser.add_argument('--fp16', action='store_true')
395 | args, unknown = parser.parse_known_args()
396 |
397 | if args.fp16 and not autocast_available:
398 | raise ValueError('You\'ll need a newer PyTorch version to enable fp16 training.')
399 |
400 | with args.config.open() as y:
401 | config = yaml.load(y, Loader=yaml.FullLoader)
402 |
403 | if config['log_wandb']:
404 | wandb.init(
405 | entity="SOME-RUNS",
406 | project="SOME-PROJECT",
407 | config=config,
408 | dir=str(ROOT / 'runs/'))
409 | config = wandb.config
410 |
411 | print(config)
412 |
413 | if args.checkpoint:
414 | checkpoint = args.checkpoint
415 | else:
416 | checkpoint = None
417 |
418 | do_train(
419 | checkpoint=checkpoint,
420 | direction=args.direction,
421 | split_both_decoder=args.split_both_decoder,
422 | fp16=args.fp16,
423 | )
--------------------------------------------------------------------------------
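The handlers above split one logical training step across several callbacks, which can obscure the underlying recipe. A minimal self-contained sketch of the same fp16 + gradient-accumulation pattern (hypothetical `model`, `loader`, `optimizer`, and constants; not part of the repo):

    import torch
    from torch.cuda.amp import GradScaler, autocast

    def train_with_accumulation(model, loader, optimizer, accum_steps=10, fp16=True, grad_norm=2.5):
        scaler = GradScaler(enabled=fp16)
        model.train()
        for step, (x, y) in enumerate(loader, start=1):
            with autocast(enabled=fp16):
                loss = model(**x, **y)[0]
            # normalize by accum_steps so the accumulated gradient averages over the window
            scaler.scale(loss / accum_steps).backward()
            if step % accum_steps == 0:
                scaler.unscale_(optimizer)  # unscale before clipping, as in the trainer above
                torch.nn.utils.clip_grad_norm_(model.parameters(), grad_norm)
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()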
/configs/config.yaml:
--------------------------------------------------------------------------------
1 | name: baseline+smart_init
2 | model: facebook/bart-large
3 |
4 | # <--------------
5 | # Linearizations
6 | # Comment out the DFS block and uncomment the relevant block below to use a different linearization scheme
7 |
8 | # DFS
9 | penman_linearization: True
10 | use_pointer_tokens: True
11 | raw_graph: False
12 |
13 | # BFS
14 | # penman_linearization: False
15 | # use_pointer_tokens: True
16 | # raw_graph: False
17 |
18 | # PENMAN
19 | # penman_linearization: True
20 | # use_pointer_tokens: False
21 | # raw_graph: False
22 |
23 | # BART baseline
24 | # penman_linearization: True
25 | # use_pointer_tokens: False
26 | # raw_graph: True
27 |
28 | remove_wiki: False
29 | dereify: False
30 | collapse_name_ops: False
31 |
32 | # Hparams
33 | batch_size: 500  # token budget per batch, not a sentence count (see AMRDatasetTokenBatcherAndLoader)
34 | beam_size: 1
35 | dropout: 0.25
36 | attention_dropout: 0.0
37 | smart_init: True
38 | accum_steps: 10
39 | warmup_steps: 1
40 | training_steps: 250000
41 | weight_decay: 0.004
42 | grad_norm: 2.5
43 | scheduler: constant
44 | learning_rate: 0.00005
45 | max_epochs: 30
46 | save_checkpoints: True
47 | log_wandb: False
48 | warm_start: True
49 | use_recategorization: False
50 | best_loss: False
51 | remove_longer_than: 1024
52 |
53 | # <------------------
54 | # Data: replace DATA below with the root of your AMR 2/3 release folder
55 | train: DATA/data/amrs/split/training/*.txt
56 | dev: DATA/data/amrs/split/dev/*.txt
57 | test: DATA/data/amrs/split/test/*.txt
58 |
--------------------------------------------------------------------------------
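A quick way to inspect the hyperparameters that train.py will receive, mirroring the loading code in its `__main__` block:

    import yaml

    with open('configs/config.yaml') as f:
        config = yaml.load(f, Loader=yaml.FullLoader)

    # batch_size is a token budget per batch (see AMRDatasetTokenBatcherAndLoader in spring_amr/dataset.py)
    print(config['model'], config['batch_size'], config['scheduler'])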
/data/vocab/additions.txt:
--------------------------------------------------------------------------------
1 | date-entity
2 | government-organization
3 | temporal-quantity
4 | amr-unknown
5 | multi-sentence
6 | political-party
7 | :compared-to
8 | monetary-quantity
9 | ordinal-entity
10 | religious-group
11 | percentage-entity
12 | world-region
13 | :consist
14 | url-entity
15 | political-movement
16 | et-cetera
17 | at-least
18 | mass-quantity
19 | have-org-role-91
20 | have-rel-role-91
21 | include-91
22 | have-concession-91
23 | have-condition-91
24 | be-located-at-91
25 | rate-entity-91
26 | instead-of-91
27 | hyperlink-91
28 | request-confirmation-91
29 | have-purpose-91
30 | be-temporally-at-91
31 | regardless-91
32 | have-polarity-91
33 | byline-91
34 | have-manner-91
35 | have-part-91
36 | have-quant-91
37 | publication-91
38 | be-from-91
39 | have-mod-91
40 | have-frequency-91
41 | score-on-scale-91
42 | have-li-91
43 | be-compared-to-91
44 | be-destined-for-91
45 | course-91
46 | have-subevent-91
47 | street-address-91
48 | have-extent-91
49 | statistical-test-91
50 | have-instrument-91
51 | have-name-91
52 | be-polite-91
53 | -00
54 | -01
55 | -02
56 | -03
57 | -04
58 | -05
59 | -06
60 | -07
61 | -08
62 | -09
63 | -10
64 | -11
65 | -12
66 | -13
67 | -14
68 | -15
69 | -16
70 | -17
71 | -18
72 | -19
73 | -20
74 | -21
75 | -22
76 | -23
77 | -24
78 | -25
79 | -26
80 | -27
81 | -28
82 | -29
83 | -30
84 | -31
85 | -32
86 | -33
87 | -34
88 | -35
89 | -36
90 | -37
91 | -38
92 | -39
93 | -40
94 | -41
95 | -42
96 | -43
97 | -44
98 | -45
99 | -46
100 | -47
101 | -48
102 | -49
103 | -50
104 | -51
105 | -52
106 | -53
107 | -54
108 | -55
109 | -56
110 | -57
111 | -58
112 | -59
113 | -60
114 | -61
115 | -62
116 | -63
117 | -64
118 | -65
119 | -66
120 | -67
121 | -68
122 | -69
123 | -70
124 | -71
125 | -72
126 | -73
127 | -74
128 | -75
129 | -76
130 | -77
131 | -78
132 | -79
133 | -80
134 | -81
135 | -82
136 | -83
137 | -84
138 | -85
139 | -86
140 | -87
141 | -88
142 | -89
143 | -90
144 | -91
145 | -92
146 | -93
147 | -94
148 | -95
149 | -96
150 | -97
151 | -98
152 | -of
153 | :op1
154 | :op2
155 | :op3
156 | :op4
157 | :op5
158 | :ARG0
159 | :ARG1
160 | :ARG2
161 | :ARG3
162 | :ARG4
163 | :ARG5
164 | :ARG6
165 | :ARG7
166 | :ARG8
167 | :ARG9
168 | :ARG10
169 | :ARG11
170 | :ARG12
171 | :ARG13
172 | :ARG14
173 | :ARG15
174 | :ARG16
175 | :ARG17
176 | :ARG18
177 | :ARG19
178 | :ARG20
179 | :accompanier
180 | :age
181 | :beneficiary
182 | :calendar
183 | :cause
184 | :century
185 | :concession
186 | :condition
187 | :conj-as-if
188 | :consist-of
189 | :cost
190 | :day
191 | :dayperiod
192 | :decade
193 | :degree
194 | :destination
195 | :direction
196 | :domain
197 | :duration
198 | :employed-by
199 | :era
200 | :example
201 | :extent
202 | :frequency
203 | :instrument
204 | :li
205 | :location
206 | :manner
207 | :meaning
208 | :medium
209 | :mod
210 | :mode
211 | :month
212 | :name
213 | :ord
214 | :part
215 | :path
216 | :polarity
217 | :polite
218 | :poss
219 | :purpose
220 | :quant
221 | :quarter
222 | :range
223 | :relation
224 | :role
225 | :scale
226 | :season
227 | :source
228 | :subevent
229 | :subset
230 | :superset
231 | :time
232 | :timezone
233 | :topic
234 | :unit
235 | :value
236 | :weekday
237 | :wiki
238 | :year
239 | :year2
240 | :snt0
241 | :snt1
242 | :snt2
243 | :snt3
244 | :snt4
245 | :snt5
246 |
--------------------------------------------------------------------------------
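additions.txt lists AMR-specific symbols (frames such as have-org-role-91, sense suffixes -00 through -98, and role labels such as :ARG0) that are missing from BART's subword vocabulary. As a rough illustration of how such a list can extend a pretrained tokenizer (this is not the repo's own logic, which lives in spring_amr/tokenization_bart.py):

    from transformers import BartTokenizer

    tok = BartTokenizer.from_pretrained('facebook/bart-large')
    with open('data/vocab/additions.txt') as f:
        new_tokens = [line.strip() for line in f if line.strip()]
    tok.add_tokens(new_tokens)  # the model's embedding matrix must then be resized accordingly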
/data/vocab/recategorizations.txt:
--------------------------------------------------------------------------------
1 | PERSON
2 | COUNTRY
3 | QUANTITY
4 | ORGANIZATION
5 | DATE_ATTRS
6 | NATIONALITY
7 | LOCATION
8 | ENTITY
9 | CITY
10 | MISC
11 | ORDINAL_ENTITY
12 | IDEOLOGY
13 | RELIGION
14 | STATE_OR_PROVINCE
15 | URL
16 | CAUSE_OF_DEATH
17 | O
18 | TITLE
19 | DATE
20 | NUMBER
21 | HANDLE
22 | SCORE_ENTITY
23 | DURATION
24 | ORDINAL
25 | MONEY
26 | SET
27 | CRIMINAL_CHARGE
28 | _1
29 | _2
30 | _3
31 | _4
32 | _5
33 | _6
34 | _7
35 | _8
36 | _9
37 | _10
38 | _11
39 | _12
40 | _13
41 | _14
42 | _15
--------------------------------------------------------------------------------
/docs/appendix.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SapienzaNLP/spring/39079940d028ba0dde4c1af60432be49f67d76f8/docs/appendix.pdf
--------------------------------------------------------------------------------
/docs/camera-ready.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SapienzaNLP/spring/39079940d028ba0dde4c1af60432be49f67d76f8/docs/camera-ready.pdf
--------------------------------------------------------------------------------
/docs/preprint.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SapienzaNLP/spring/39079940d028ba0dde4c1af60432be49f67d76f8/docs/preprint.pdf
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | cached_property
2 | networkx
3 | penman>=1.1.0
4 | pytorch-ignite
5 | regex
6 | sacrebleu
7 | smatch
8 | transformers==2.11.0
9 | wandb
10 | PyYAML>=5.1
--------------------------------------------------------------------------------
/sample.txt:
--------------------------------------------------------------------------------
1 | # ::status ParsedStatus.OK
2 | # ::source sample.txt
3 | # ::nsent 6
4 | # ::snt In a shocking finding, scientist discovered a herd of unicorns living in a remote, previously unexplored valley, in the Andes Mountains.
5 | # ::snt-pred Scientists were shocked to discover a herd of unicorns living in a remote valley inaccessible in the Andes Mountains.
6 | (z0 / discover-01
7 | :ARG0 (z1 / scientist)
8 | :ARG1 (z2 / herd
9 | :consist-of (z3 / unicorn)
10 | :ARG0-of (z4 / live-01
11 | :location (z5 / valley
12 | :mod (z6 / remote)
13 | :ARG1-of (z7 / explore-01
14 | :polarity -
15 | :time (z8 / previous))
16 | :location (z9 / mountain
17 | :wiki "Andes"
18 | :name (z10 / name
19 | :op1 "Andes"
20 | :op2 "Mountains")))))
21 | :ARG0-of (z11 / shock-01))
22 |
23 | # ::status ParsedStatus.OK
24 | # ::source sample.txt
25 | # ::nsent 5
26 | # ::snt Emily loves mint chocolate cake, but she requires that it be paired with mini chocolate chips, so I threw some of those in between the layers.
27 | # ::snt-pred Emily loves chocolate cake, but it requires to be paired with mini chocolate chips, so I threw some of them in between the layers.
28 | (z0 / love-01
29 | :ARG0 (z1 / person
30 | :wiki -
31 | :name (z2 / name
32 | :op1 "Emily"))
33 | :ARG1 (z3 / cake
34 | :consist-of (z4 / chocolate
35 | :mod (z5 / mint)))
36 | :concession-of (z6 / require-01
37 | :ARG0 z1
38 | :ARG1 (z7 / pair-01
39 | :ARG1 z3
40 | :ARG2 (z8 / chip
41 | :consist-of (z9 / chocolate
42 | :mod (z10 / mini)))))
43 | :ARG0-of (z11 / cause-01
44 | :ARG1 (z12 / throw-01
45 | :ARG0 (z13 / i)
46 | :ARG1 (z14 / some
47 | :ARG1-of (z15 / include-91
48 | :ARG2 z3))
49 | :ARG2 (z16 / between
50 | :op1 (z17 / layer)))))
51 |
52 | # ::status ParsedStatus.OK
53 | # ::source sample.txt
54 | # ::nsent 7
55 | # ::snt Prehistoric man sketched an incredible array of prehistoric beasts on the rough limestone walls of a cave in modern day France 36,000 years ago.
56 | # ::snt-pred 36,000 years ago, prehistoric men drew an incredible array of prehistoric beasts on a rough limestone wall of a cave in modern-day France.
57 | (z0 / draw-01
58 | :ARG0 (z1 / man
59 | :mod (z2 / prehistoric))
60 | :ARG1 (z3 / array
61 | :mod (z4 / incredible)
62 | :consist-of (z5 / beast
63 | :mod (z6 / prehistoric)))
64 | :location (z7 / wall
65 | :consist-of (z8 / limestone)
66 | :ARG1-of (z9 / rough-04)
67 | :part-of (z10 / cave
68 | :location (z11 / country
69 | :wiki "France"
70 | :name (z12 / name
71 | :op1 "France")
72 | :time (z13 / day
73 | :ARG1-of (z14 / modern-02)))))
74 | :time (z15 / before
75 | :op1 (z16 / now)
76 | :quant (z17 / temporal-quantity
77 | :quant 36000
78 | :unit (z18 / year))))
79 |
80 | # ::status ParsedStatus.OK
81 | # ::source sample.txt
82 | # ::nsent 3
83 | # ::snt Corporal Michael P. Goeldin was an unskilled laborer from Ireland when he enlisted in Company A in November 1860.
84 | # ::snt-pred When Michael P. Goeldin enlisted in Company A in November, 1860, he was an Irish labourer with no skills.
85 | (z0 / person
86 | :ARG0-of (z1 / labor-01
87 | :manner (z2 / skill
88 | :polarity -))
89 | :domain (z3 / person
90 | :wiki -
91 | :name (z4 / name
92 | :op1 "Michael"
93 | :op2 "P."
94 | :op3 "Goeldin")
95 | :ARG0-of (z5 / have-org-role-91
96 | :ARG2 (z6 / corporal)))
97 | :mod (z7 / country
98 | :wiki "Ireland"
99 | :name (z8 / name
100 | :op1 "Ireland"))
101 | :time (z9 / enlist-01
102 | :ARG1 z3
103 | :ARG2 (z10 / military
104 | :wiki -
105 | :name (z11 / name
106 | :op1 "Company"
107 | :op2 "A"))
108 | :time (z12 / date-entity
109 | :year 1860
110 | :month 11)))
111 |
112 | # ::status ParsedStatus.OK
113 | # ::source sample.txt
114 | # ::nsent 0
115 | # ::snt This pairing was the first outfit I thought of when I bought the shoes.
116 | # ::snt-pred This pair is the first outfit I thought of when I bought shoes.
117 | (z0 / outfit
118 | :ord (z1 / ordinal-entity
119 | :value 1)
120 | :ARG1-of (z2 / think-01
121 | :ARG0 (z3 / i)
122 | :time (z4 / buy-01
123 | :ARG0 z3
124 | :ARG1 (z5 / shoe)))
125 | :domain (z6 / pair-01
126 | :mod (z7 / this)))
127 |
128 | # ::status ParsedStatus.OK
129 | # ::source sample.txt
130 | # ::nsent 2
131 | # ::snt The pink ghost’s AI is designed to ”feel” opposite of the red ghost’s behavior.
132 | # ::snt-pred The artificial system of the pink ghosts was designed to feel the opposite of the way the red ghosts behaved.
133 | (z0 / design-01
134 | :ARG1 (z1 / system
135 | :mod (z2 / artificial)
136 | :poss (z3 / ghost
137 | :ARG1-of (z4 / pink-04)))
138 | :ARG3 (z5 / feel-01
139 | :ARG0 z1
140 | :ARG1 (z6 / opposite-01
141 | :ARG2 (z7 / behave-01
142 | :ARG0 (z8 / ghost
143 | :mod (z9 / red))))))
144 |
145 | # ::status ParsedStatus.OK
146 | # ::source sample.txt
147 | # ::nsent 4
148 | # ::snt Xresources can be an absolute pain (they were for me).
149 | # ::snt-pred The x-resoures could absolutely cause pain to me.
150 | (z0 / possible-01
151 | :ARG1 (z1 / pain-01
152 | :ARG0 (z2 / resource
153 | :mod (z3 / xresources))
154 | :mod (z4 / absolute)
155 | :ARG1-of (z5 / cause-01
156 | :ARG0 (z6 / they
157 | :beneficiary (z7 / i)))))
--------------------------------------------------------------------------------
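Each block in sample.txt pairs an input sentence (# ::snt) with the sentence generated back from the predicted graph (# ::snt-pred), followed by the predicted AMR in PENMAN notation. The metadata can be read back with the penman library:

    import penman

    graphs = penman.load('sample.txt')
    print(graphs[0].metadata['snt'])       # original input sentence
    print(graphs[0].metadata['snt-pred'])  # round-trip generation from the predicted AMR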
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 |
3 | setup(
4 | name='spring_amr',
5 | version='1.0',
6 | packages=['spring_amr'],
7 | url='https://github.com/SapienzaNLP/spring',
8 | license='CC BY-NC-SA 4.0',
9 | author='Michele Bevilacqua, Rexhina Blloshmi and Roberto Navigli',
10 | author_email='{bevilacqua,blloshmi,navigli}@di.uniroma1.it',
11 | description='Parse sentences into AMR graphs and generate sentences from AMR graphs without breaking a sweat!'
12 | )
13 |
--------------------------------------------------------------------------------
/spring_amr/IO.py:
--------------------------------------------------------------------------------
1 | import glob
2 | from typing import List, Union
3 | from pathlib import Path
4 | from spring_amr.penman import load as pm_load
5 |
6 | def read_raw_amr_data(
7 | paths: List[Union[str, Path]],
8 | use_recategorization=False,
9 | dereify=True,
10 | remove_wiki=False,
11 | ):
12 | assert paths
13 |
14 | if isinstance(paths, (str, Path)):  # a bare string is iterable, so test for it explicitly
15 | paths = [paths]
16 |
17 | graphs = []
18 | for path_ in paths:
19 | for path in glob.glob(str(path_)):
20 | path = Path(path)
21 | graphs.extend(pm_load(path, dereify=dereify, remove_wiki=remove_wiki))
22 |
23 | assert graphs
24 |
25 | if use_recategorization:
26 | for g in graphs:
27 | metadata = g.metadata
28 | metadata['snt_orig'] = metadata['snt']
29 | tokens = eval(metadata['tokens'])  # the tokens field holds a Python-style list literal; assumes trusted data
30 | metadata['snt'] = ' '.join([t for t in tokens if not ((t.startswith('-L') or t.startswith('-R')) and t.endswith('-'))])
31 |
32 | return graphs
--------------------------------------------------------------------------------
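A usage sketch for read_raw_amr_data, assuming an AMR release unpacked under DATA as in configs/config.yaml:

    from spring_amr.IO import read_raw_amr_data

    graphs = read_raw_amr_data(['DATA/data/amrs/split/dev/*.txt'], dereify=True, remove_wiki=False)
    print(len(graphs), graphs[0].metadata['snt'])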
/spring_amr/__init__.py:
--------------------------------------------------------------------------------
1 | __version__ = "0.0.1"
2 |
3 | from pathlib import Path
4 |
5 | ROOT = Path(__file__).parent.parent
6 |
--------------------------------------------------------------------------------
/spring_amr/dataset.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import random
3 | import torch
4 | from cached_property import cached_property
5 | from torch.utils.data import Dataset
6 | from spring_amr.IO import read_raw_amr_data
7 |
8 | def reverse_direction(x, y, pad_token_id=1):
9 | input_ids = torch.cat([y['decoder_input_ids'], y['lm_labels'][:, -1:]], 1)
10 | attention_mask = torch.ones_like(input_ids)
11 | attention_mask[input_ids == pad_token_id] = 0
12 | decoder_input_ids = x['input_ids'][:,:-1]
13 | lm_labels = x['input_ids'][:,1:]
14 | x = {'input_ids': input_ids, 'attention_mask': attention_mask}
15 | y = {'decoder_input_ids': decoder_input_ids, 'lm_labels': lm_labels}
16 | return x, y
17 |
18 | class AMRDataset(Dataset):
19 |
20 | def __init__(
21 | self,
22 | paths,
23 | tokenizer,
24 | device=torch.device('cpu'),
25 | use_recategorization=False,
26 | remove_longer_than=None,
27 | remove_wiki=False,
28 | dereify=True,
29 | ):
30 | self.paths = paths
31 | self.tokenizer = tokenizer
32 | self.device = device
33 | graphs = read_raw_amr_data(paths, use_recategorization, remove_wiki=remove_wiki, dereify=dereify)
34 | self.graphs = []
35 | self.sentences = []
36 | self.linearized = []
37 | self.linearized_extra = []
38 | self.remove_longer_than = remove_longer_than
39 | for g in graphs:
40 | l, e = self.tokenizer.linearize(g)
41 |
42 | try:
43 | self.tokenizer.batch_encode_sentences([g.metadata['snt']])
44 | except Exception:
45 | logging.warning('Invalid sentence!')
46 | continue
47 |
48 | if remove_longer_than and len(l) > remove_longer_than:
49 | continue
50 | if len(l) > 1024:
51 | logging.warning('Sequence longer than 1024 included. BART does not support it!')
52 |
53 | self.sentences.append(g.metadata['snt'])
54 | self.graphs.append(g)
55 | self.linearized.append(l)
56 | self.linearized_extra.append(e)
57 |
58 | def __len__(self):
59 | return len(self.sentences)
60 |
61 | def __getitem__(self, idx):
62 | sample = {}
63 | sample['id'] = idx
64 | sample['sentences'] = self.sentences[idx]
65 | if self.linearized is not None:
66 | sample['linearized_graphs_ids'] = self.linearized[idx]
67 | sample.update(self.linearized_extra[idx])
68 | return sample
69 |
70 | def size(self, sample):
71 | return len(sample['linearized_graphs_ids'])
72 |
73 | def collate_fn(self, samples, device=torch.device('cpu')):
74 | x = [s['sentences'] for s in samples]
75 | x, extra = self.tokenizer.batch_encode_sentences(x, device=device)
76 | if 'linearized_graphs_ids' in samples[0]:
77 | y = [s['linearized_graphs_ids'] for s in samples]
78 | y, extra_y = self.tokenizer.batch_encode_graphs_from_linearized(y, samples, device=device)
79 | extra.update(extra_y)
80 | else:
81 | y = None
82 | extra['ids'] = [s['id'] for s in samples]
83 | return x, y, extra
84 |
85 | class AMRDatasetTokenBatcherAndLoader:
86 |
87 | def __init__(self, dataset, batch_size=800, device=torch.device('cpu'), shuffle=False, sort=False):
88 | assert not (shuffle and sort)
89 | self.batch_size = batch_size
90 | self.tokenizer = dataset.tokenizer
91 | self.dataset = dataset
92 | self.device = device
93 | self.shuffle = shuffle
94 | self.sort = sort
95 |
96 | def __iter__(self):
97 | it = self.sampler()
98 | it = ([[self.dataset[s] for s in b] for b in it])
99 | it = (self.dataset.collate_fn(b, device=self.device) for b in it)
100 | return it
101 |
102 | @cached_property
103 | def sort_ids(self):
104 | lengths = [len(s.split()) for s in self.dataset.sentences]
105 | ids, _ = zip(*sorted(enumerate(lengths), reverse=True))
106 | ids = list(ids)
107 | return ids
108 |
109 | def sampler(self):
110 | ids = list(range(len(self.dataset)))[::-1]
111 |
112 | if self.shuffle:
113 | random.shuffle(ids)
114 | if self.sort:
115 | ids = self.sort_ids.copy()
116 |
117 | batch_longest = 0
118 | batch_nexamps = 0
119 | batch_ntokens = 0
120 | batch_ids = []
121 |
122 | def discharge():
123 | nonlocal batch_longest
124 | nonlocal batch_nexamps
125 | nonlocal batch_ntokens
126 | ret = batch_ids.copy()
127 | batch_longest *= 0
128 | batch_nexamps *= 0
129 | batch_ntokens *= 0
130 | batch_ids[:] = []
131 | return ret
132 |
133 | while ids:
134 | idx = ids.pop()
135 | size = self.dataset.size(self.dataset[idx])
136 | cand_batch_ntokens = max(size, batch_longest) * (batch_nexamps + 1)
137 | if cand_batch_ntokens > self.batch_size and batch_ids:
138 | yield discharge()
139 | batch_longest = max(batch_longest, size)
140 | batch_nexamps += 1
141 | batch_ntokens = batch_longest * batch_nexamps
142 | batch_ids.append(idx)
143 |
144 | if len(batch_ids) == 1 and batch_ntokens > self.batch_size:
145 | yield discharge()
146 |
147 | if batch_ids:
148 | yield discharge()
149 |
--------------------------------------------------------------------------------
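The sampler above batches by a token budget rather than by a fixed number of examples: a batch is closed as soon as max_sequence_length * (n_examples + 1) would exceed batch_size. A simplified standalone sketch of that rule (it omits the oversize-singleton case the real sampler also handles):

    def token_budget_batches(lengths, budget):
        batch, longest = [], 0
        for i, size in enumerate(lengths):
            if batch and max(longest, size) * (len(batch) + 1) > budget:
                yield batch
                batch, longest = [], 0
            batch.append(i)
            longest = max(longest, size)
        if batch:
            yield batch

    # the long fourth item would blow the 30-token budget, so it starts a new batch
    assert list(token_budget_batches([10, 10, 10, 25], budget=30)) == [[0, 1, 2], [3]]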
/spring_amr/entities.py:
--------------------------------------------------------------------------------
1 | from collections import defaultdict
2 |
3 | def read_entities(sentences, graphs, just_tagged=True):
4 |
5 | for i, (s, g) in enumerate(zip(sentences, graphs)):
6 |
7 | with_wikis = {}
8 | name_to_entity = {}
9 | name_to_ops = defaultdict(list)
10 |
11 | for nt, t in enumerate(g.triples):
12 | n1, rel, n2 = t
13 |
14 | if n2 == '-' and just_tagged:
15 | continue
16 |
17 | if rel == ':wiki':
18 | with_wikis[n1] = (nt, n2)
19 |
20 | for t in g.triples:
21 | n1, rel, n2 = t
22 | if (n1 in with_wikis) and (rel == ':name'):
23 | name_to_entity[n2] = n1
24 |
25 | for nt, t in enumerate(g.triples):
26 | n1, rel, n2 = t
27 | if (n1 in name_to_entity) and rel.startswith(':op'):
28 | name_to_ops[n1].append(t)
29 |
30 | yield (i, with_wikis, name_to_entity, name_to_ops)
--------------------------------------------------------------------------------
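read_entities pairs each graph with its wikified entities: for every variable that carries a :wiki triple it records the associated :name variable and that name's :opN triples. A usage sketch, where sentences and graphs are assumed to be parallel lists (e.g., AMRDataset.sentences and AMRDataset.graphs):

    from spring_amr.entities import read_entities

    for i, with_wikis, name_to_entity, name_to_ops in read_entities(sentences, graphs):
        # with_wikis: entity variable -> (triple index, wiki value)
        # name_to_ops: name variable -> list of :opN triples spelling out the mention
        print(i, len(with_wikis))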
/spring_amr/evaluation.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | from pathlib import Path
3 |
4 | import penman
5 | from sacrebleu import corpus_bleu
6 | import torch
7 | from tqdm import tqdm
8 | import smatch
9 |
10 | from spring_amr.dataset import reverse_direction
11 |
12 | def predict_amrs(
13 | loader, model, tokenizer, beam_size=1, tokens=None, restore_name_ops=False, return_all=False):
14 |
15 | shuffle_orig = loader.shuffle
16 | sort_orig = loader.sort
17 |
18 | loader.shuffle = False
19 | loader.sort = True
20 |
21 | total = len(loader.dataset)
22 | model.eval()
23 | model.amr_mode = True
24 |
25 | if tokens is None:
26 | ids = []
27 | tokens = []
28 | with tqdm(total=total) as bar:
29 | for x, y, extra in loader:
30 | ii = extra['ids']
31 | ids.extend(ii)
32 | with torch.no_grad():
33 | out = model.generate(
34 | **x,
35 | max_length=1024,
36 | decoder_start_token_id=0,
37 | num_beams=beam_size,
38 | num_return_sequences=beam_size)
39 | nseq = len(ii)
40 | for i1 in range(0, out.size(0), beam_size):
41 | tokens_same_source = []
42 | tokens.append(tokens_same_source)
43 | for i2 in range(i1, i1+beam_size):
44 | tokk = out[i2].tolist()
45 | tokens_same_source.append(tokk)
46 | bar.update(nseq)
47 | # reorder
48 | tokens = [tokens[i] for i in ids]
49 | tokens = [t for tt in tokens for t in tt]
50 |
51 | graphs = []
52 | for i1 in range(0, len(tokens), beam_size):
53 | graphs_same_source = []
54 | graphs.append(graphs_same_source)
55 | for i2 in range(i1, i1+beam_size):
56 | tokk = tokens[i2]
57 | graph, status, (lin, backr) = tokenizer.decode_amr(tokk, restore_name_ops=restore_name_ops)
58 | graph.status = status
59 | graph.nodes = lin
60 | graph.backreferences = backr
61 | graph.tokens = tokk
62 | graphs_same_source.append(graph)
63 | graphs_same_source[:] = tuple(zip(*sorted(enumerate(graphs_same_source), key=lambda x: (x[1].status.value, x[0]))))[1]
64 |
65 | for gps, gg in zip(graphs, loader.dataset.graphs):
66 | for gp in gps:
67 | metadata = gg.metadata.copy()
68 | metadata['annotator'] = 'bart-amr'
69 | metadata['date'] = str(datetime.datetime.now())
70 | if 'save-date' in metadata:
71 | del metadata['save-date']
72 | gp.metadata = metadata
73 |
74 | loader.shuffle = shuffle_orig
75 | loader.sort = sort_orig
76 |
77 | if not return_all:
78 | graphs = [gg[0] for gg in graphs]
79 |
80 | return graphs
81 |
82 | def predict_sentences(loader, model, tokenizer, beam_size=1, tokens=None, return_all=False):
83 |
84 | shuffle_orig = loader.shuffle
85 | sort_orig = loader.sort
86 |
87 | loader.shuffle = False
88 | loader.sort = True
89 |
90 | total = len(loader.dataset)
91 | model.eval()
92 | model.amr_mode = False
93 |
94 | if tokens is None:
95 | ids = []
96 | tokens = []
97 | with tqdm(total=total) as bar:
98 | for x, y, extra in loader:
99 | ids.extend(extra['ids'])
100 | x, y = reverse_direction(x, y)
101 | x['input_ids'] = x['input_ids'][:, :1024]
102 | x['attention_mask'] = x['attention_mask'][:, :1024]
103 | with torch.no_grad():
104 | out = model.generate(
105 | **x,
106 | max_length=350,
107 | decoder_start_token_id=0,
108 | num_beams=beam_size,
109 | num_return_sequences=beam_size)
110 | for i1 in range(0, len(out), beam_size):
111 | tokens_same_source = []
112 | tokens.append(tokens_same_source)
113 | for i2 in range(i1, i1+beam_size):
114 | tokk = out[i2]
115 | tokk = [t for t in tokk.tolist() if t > 2]
116 | tokens_same_source.append(tokk)
117 | bar.update(out.size(0) // beam_size)
118 | # reorder
119 | tokens = [tokens[i] for i in ids]
120 |
121 | sentences = []
122 | for tokens_same_source in tokens:
123 | if return_all:
124 | sentences.append([tokenizer.decode(tokk).strip() for tokk in tokens_same_source])
125 | else:
126 | sentences.append(tokenizer.decode(tokens_same_source[0]).strip())
127 |
128 | loader.shuffle = shuffle_orig
129 | loader.sort = sort_orig
130 |
131 | return sentences
132 |
133 | def write_predictions(predictions_path, tokenizer, graphs):
134 | pieces = [penman.encode(g) for g in graphs]
135 | Path(predictions_path).write_text('\n\n'.join(pieces).replace(tokenizer.INIT, ''))
136 | return predictions_path
137 |
138 | def compute_smatch(test_path, predictions_path):
139 | with Path(predictions_path).open() as p, Path(test_path).open() as g:
140 | score = next(smatch.score_amr_pairs(p, g))
141 | return score[2]
142 |
143 | def compute_bleu(gold_sentences, pred_sentences):
144 | return corpus_bleu(pred_sentences, [gold_sentences])
145 |
--------------------------------------------------------------------------------
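compute_smatch returns the F1 component (score[2]) of the first precision/recall/F1 triple that smatch yields, and compute_bleu wraps sacrebleu's corpus_bleu with the predictions as the first argument. A sketch with hypothetical paths and sentence lists:

    from spring_amr.evaluation import compute_smatch, compute_bleu

    f1 = compute_smatch('data/tmp/dev-gold.txt', 'data/tmp/dev-pred.txt')  # float in [0, 1]
    bleu = compute_bleu(gold_sentences, pred_sentences)                    # sacrebleu BLEU object
    print(f'smatch {f1:.3f} | bleu {bleu.score:.2f}')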
/spring_amr/linearization.py:
--------------------------------------------------------------------------------
1 | import abc
2 | import itertools
3 | from collections import deque, defaultdict
4 | import re
5 | from typing import List, Optional, Dict, Any, Set, TypeVar
6 |
7 | from cached_property import cached_property
8 | from dataclasses import dataclass
9 | import networkx as nx
10 | import penman
11 |
12 | @dataclass
13 | class SemanticGraph:
14 |
15 | nodes_var: List[str]
16 | """
17 | List of linearized nodes, with special tokens.
18 | """
19 | edges: Optional[List[str]]
20 | """
21 | List of linearized edges, with special tokens.
22 | """
23 | backreferences: List[int]
24 | """
25 | List of backpointers to handle rentrancies and cycles.
26 | """
27 | var2instance: Dict[str, str]
28 | """
29 | Dict from var ids to 'lemmatized' readable strings qualifying the node (collapsing the :instance edge for AMR).
30 | """
31 | extra: Dict[str, Any]
32 | """
33 | Holds extra stuff that might be useful, e.g. alignments, NER, EL.
34 | """
35 |
36 | @cached_property
37 | def variables(self) -> Set[str]:
38 | """Set of variables in this semantic graph"""
39 | variables = {v for v in self.nodes_var if not v.startswith('<')}
40 | return variables
41 |
42 | @property
43 | def resolved_nodes_var(self) -> List[str]:
44 | return [self.nodes_var[b] for b in self.backreferences]
45 |
46 | @cached_property
47 | def nodes(self) -> List[str]:
48 | """Linearized nodes with varids replaced by instances"""
49 | return [self.var2instance.get(node, node) for node in self.nodes_var]
50 |
51 | @property
52 | def resolved_nodes(self) -> List[str]:
53 | return [self.nodes[b] for b in self.backreferences]
54 |
55 | def src_occurrence(self, var: str) -> int:
56 | pass
57 |
58 |
59 | class BaseLinearizer(metaclass=abc.ABCMeta):
60 |
61 | @abc.abstractmethod
62 | def linearize(self, *args, **kwargs) -> SemanticGraph:
63 | pass
64 |
65 | class AMRTokens:
66 |
67 | START, END = '<', '>'
68 | _TEMPL = START + '{}' + END
69 |
70 | BOS_N = _TEMPL.format('s')
71 | EOS_N = _TEMPL.format('/s')
72 | START_N = _TEMPL.format('start')
73 | STOP_N = _TEMPL.format('stop')
74 | PNTR_N = _TEMPL.format('pointer')
75 |
76 | LIT_START = _TEMPL.format('lit')
77 | LIT_END = _TEMPL.format('/lit')
78 |
79 | BACKR_SRC_N = _TEMPL.format('backr:src:XXX')
80 | BACKR_TRG_N = _TEMPL.format('backr:trg:XXX')
81 |
82 | BOS_E = _TEMPL.format('s')
83 | EOS_E = _TEMPL.format('/s')
84 | START_E = _TEMPL.format('start')
85 | STOP_E = _TEMPL.format('stop')
86 |
87 | _FIXED_SPECIAL_TOKENS_N = {
88 | BOS_N, EOS_N, START_N, STOP_N}
89 | _FIXED_SPECIAL_TOKENS_E = {
90 | BOS_E, EOS_E, START_E, STOP_E}
91 | _FIXED_SPECIAL_TOKENS = _FIXED_SPECIAL_TOKENS_N | _FIXED_SPECIAL_TOKENS_E
92 |
93 | # match and read backreferences
94 | _re_BACKR_SRC_N = re.compile(BACKR_SRC_N.replace('XXX', r'([0-9]+)'))
95 | _re_BACKR_TRG_N = re.compile(BACKR_TRG_N.replace('XXX', r'([0-9]+)'))
96 |
97 | @classmethod
98 | def is_node(cls, string: str) -> bool:
99 | if isinstance(string, str) and string.startswith(':'):
100 | return False
101 | elif string in cls._FIXED_SPECIAL_TOKENS_E:
102 | return False
103 | return True
104 |
105 | @classmethod
106 | def read_backr(cls, string: str) -> Optional:
107 | m_src = cls._re_BACKR_SRC_N.search(string)
108 | if m_src is not None:
109 | return m_src
110 | m_trg = cls._re_BACKR_TRG_N.search(string)
111 | if m_trg is not None:
112 | return m_trg
113 | return None
114 |
115 |
116 | T = TypeVar('T')
117 |
118 |
119 | def index_default(
120 | item: T, list_: List[T],
121 | start: Optional[int] = None,
122 | stop: Optional[int] = None,
123 | default: Optional[int] = None
124 | ):
125 | if start is None:
126 | start = 0
127 | if stop is None:
128 | stop = len(list_)
129 | return next((i for i, x in enumerate(list_[start:stop], start=start) if x == item), default)
130 |
131 | class AMRLinearizer(BaseLinearizer):
132 |
133 | def __init__(
134 | self,
135 | use_pointer_tokens: bool = True,
136 | collapse_name_ops: bool = False,
137 | ):
138 | self.collapse_name_ops = collapse_name_ops
139 | self.interleave_edges = False
140 | self.use_pointer_tokens = use_pointer_tokens
141 |
142 | def _collapse_name_ops(self, amr):
143 | # identify name triples
144 | name_vars = {}
145 | for i, (v1, rel, v2) in enumerate(amr.triples):
146 | if rel == ':instance' and v2 == 'name':
147 | name_vars[v1] = 1
148 |
149 | # check if they have ops
150 | name_vars_to_ops = defaultdict(list)
151 | for i, (v1, rel, v2) in enumerate(amr.triples):
152 | if v1 in name_vars and rel.startswith(':op'):
153 | name_vars_to_ops[v1].append((i, rel, v2.strip('"')))
154 |
155 | triples = amr.triples.copy()
156 | for nv, ops in name_vars_to_ops.items():
157 | ops = sorted(ops, key=lambda x: int(x[1][3:]))
158 | idx, _, lits = zip(*ops)
159 | for i in idx:
160 | triples[i] = None
161 | lit = '"' + '_'.join(lits) + '"'
162 | triples[min(idx)] = penman.Triple(nv, ':op1', lit)
163 |
164 | triples = [t for t in triples if t is not None]
165 | amr_ = penman.Graph(triples)
166 | amr_.metadata = amr.metadata
167 | return amr_
168 |
169 |
170 | def linearize(self, amr: penman.Graph) -> SemanticGraph:
171 | if self.collapse_name_ops:
172 | amr = self._collapse_name_ops(amr)
173 | linearized = self._linearize(amr)
174 | linearized = self._interleave(linearized)
175 | if self.use_pointer_tokens:
176 | linearized = self._add_pointer_tokens(linearized)
177 | return linearized
178 |
179 | def _linearize(self, amr: penman.Graph) -> SemanticGraph:
180 | variables = set(amr.variables())
181 | variables = {'var:' + v for v in variables}
182 | var2instance = {}
183 |
184 | graph = nx.MultiDiGraph()
185 |
186 | triples2order = {k: i for i, k in enumerate(amr.triples)}
187 |
188 | for triple in amr.triples:
189 | var, rel, instance = triple
190 | order = triples2order[triple]
191 | if rel != ':instance':
192 | continue
193 | for expansion_candidate in itertools.chain(range(order - 1, -1, -1), range(order + 1, len(amr.triples))):
194 | if var == amr.triples[expansion_candidate][2]:
195 | expansion = expansion_candidate
196 | break
197 | else:
198 | expansion = 0
199 | var = 'var:' + var
200 | var2instance[var] = instance
201 | graph.add_node(var, instance=instance, order=order, expansion=expansion)
202 |
203 | for triple in amr.edges():
204 | var1, rel, var2 = triple
205 | order = triples2order[triple]
206 | if rel == ':instance':
207 | continue
208 | var1 = 'var:' + var1
209 | var2 = 'var:' + var2
210 | graph.add_edge(var1, var2, rel=rel, order=order)
211 |
212 | for triple in amr.attributes():
213 | var, rel, attr = triple
214 | order = triples2order[triple]
215 | if rel == ':instance':
216 | continue
217 | var = 'var:' + var
218 | graph.add_edge(var, attr, rel=rel, order=order)
219 |
220 | # nodes that are not reachable from the root (e.g. because of reification)
221 | # will be present in the not_explored queue
222 | # undirected_graph = graph.to_undirected()
223 | # print(amr.variables())
224 | not_explored = deque(sorted(variables, key=lambda x: nx.get_node_attributes(graph, 'order')[x]))
225 | # (
226 | # len(nx.shortest_path(undirected_graph, 'var:' + amr.top, x)),
227 | # -graph.out_degree(x),
228 | # )
229 |
230 | first_index = {}
231 | explored = set()
232 | added_to_queue = set()
233 | nodes_visit = [AMRTokens.BOS_N]
234 | edges_visit = [AMRTokens.BOS_E]
235 | backreferences = [0]
236 | queue = deque()
237 | queue.append('var:' + amr.top)
238 |
239 | while queue or not_explored:
240 |
241 | if queue:
242 | node1 = queue.popleft()
243 | else:
244 | node1 = not_explored.popleft()
245 | if node1 in added_to_queue:
246 | continue
247 | if not list(graph.successors(node1)):
248 | continue
249 |
250 | if node1 in variables:
251 | if node1 in explored:
252 | continue
253 | if node1 in first_index:
254 | nodes_visit.append(AMRTokens.BACKR_TRG_N)
255 | backreferences.append(first_index[node1])
256 | else:
257 | backreferences.append(len(nodes_visit))
258 | first_index[node1] = len(nodes_visit)
259 | nodes_visit.append(node1)
260 | edges_visit.append(AMRTokens.START_E)
261 |
262 | successors = []
263 | for node2 in graph.successors(node1):
264 | for edge_data in graph.get_edge_data(node1, node2).values():
265 | rel = edge_data['rel']
266 | order = edge_data['order']
267 | successors.append((order, rel, node2))
268 | successors = sorted(successors)
269 |
270 | for order, rel, node2 in successors:
271 | edges_visit.append(rel)
272 |
273 | # node2 is a variable
274 | if node2 in variables:
275 | # ... which was mentioned before
276 | if node2 in first_index:
277 | nodes_visit.append(AMRTokens.BACKR_TRG_N)
278 | backreferences.append(first_index[node2])
279 |
280 | # .. which is mentioned for the first time
281 | else:
282 | backreferences.append(len(nodes_visit))
283 | first_index[node2] = len(nodes_visit)
284 | nodes_visit.append(node2)
285 |
286 | # 1) not already in Q
287 | # 2) has children
288 | # 3) the edge right before its expansion has been encountered
289 | if (node2 not in added_to_queue) and list(graph.successors(node2)) and (nx.get_node_attributes(graph, 'expansion')[node2] <= order):
290 | queue.append(node2)
291 | added_to_queue.add(node2)
292 |
293 | # node2 is a constant
294 | else:
295 | backreferences.append(len(nodes_visit))
296 | nodes_visit.append(node2)
297 |
298 | backreferences.append(len(nodes_visit))
299 | nodes_visit.append(AMRTokens.STOP_N)
300 | edges_visit.append(AMRTokens.STOP_E)
301 | explored.add(node1)
302 |
303 | else:
304 | backreferences.append(len(nodes_visit))
305 | nodes_visit.append(node1)
306 | explored.add(node1)
307 |
308 | backreferences.append(len(nodes_visit))
309 | nodes_visit.append(AMRTokens.EOS_N)
310 | edges_visit.append(AMRTokens.EOS_E)
311 | assert len(nodes_visit) == len(edges_visit) == len(backreferences)
312 | return SemanticGraph(
313 | nodes_visit,
314 | edges_visit,
315 | backreferences,
316 | var2instance,
317 | extra={'graph': graph, 'amr': amr}
318 | )
319 |
320 | def _interleave(self, graph: SemanticGraph) -> SemanticGraph:
321 |
322 | new_backreferences_map = []
323 | new_nodes = []
324 | new_edges = None
325 | new_backreferences = []
326 |
327 | # to isolate sublist to the stop token
328 | start_i = 1
329 | end_i = index_default(AMRTokens.STOP_N, graph.nodes_var, start_i, -1, -1)
330 |
331 | def add_node(node, backr = None):
332 | old_n_node = len(new_backreferences_map)
333 | new_n_node = len(new_nodes)
334 |
335 | if backr is None:
336 | backr = old_n_node
337 |
338 | new_backreferences_map.append(new_n_node)
339 | new_nodes.append(node)
340 | if old_n_node == backr:
341 | new_backreferences.append(new_n_node)
342 | else:
343 | new_backreferences.append(new_backreferences_map[backr])
344 |
345 | def add_edge(edge):
346 | new_nodes.append(edge)
347 | new_backreferences.append(len(new_backreferences))
348 |
349 | add_node(AMRTokens.BOS_N)
350 |
351 | while end_i > -1:
352 |
353 | # src node
354 | add_node(graph.nodes_var[start_i], graph.backreferences[start_i])
355 |
356 | # edges and trg nodes, interleaved
357 | nodes = graph.nodes_var[start_i+1:end_i]
358 | edges = graph.edges[start_i+1:end_i]
359 | backr = graph.backreferences[start_i+1:end_i]
360 | for n, e, b in zip(nodes, edges, backr):
361 | add_edge(e)
362 | add_node(n, b)
363 |
364 | # stop
365 | add_node(graph.nodes_var[end_i], graph.backreferences[end_i])
366 |
367 | start_i = end_i + 1
368 | end_i = index_default(AMRTokens.STOP_N, graph.nodes_var, start_i, -1, -1)
369 |
370 | add_node(AMRTokens.EOS_N)
371 |
372 | new_graph = SemanticGraph(
373 | new_nodes,
374 | None,
375 | new_backreferences,
376 | graph.var2instance,
377 | extra=graph.extra,
378 | )
379 | return new_graph
380 |
381 | def _add_pointer_tokens(self, graph: SemanticGraph) -> SemanticGraph:
382 | new_nodes = []
383 | var2pointer = {}
384 | for node, backr in zip(graph.nodes_var, graph.backreferences):
385 |
386 | if node == AMRTokens.BACKR_TRG_N:
387 | node = graph.nodes_var[backr]
388 | pointer = var2pointer[node]
389 | new_nodes.append(pointer)
390 | elif node in graph.var2instance:
391 | pointer = var2pointer.setdefault(node, f"<pointer:{len(var2pointer)}>")
392 | new_nodes.append(pointer)
393 | new_nodes.append(node)
394 | else:
395 | new_nodes.append(node)
396 |
397 | new_backreferences = list(range(len(new_nodes)))
398 | new_graph = SemanticGraph(
399 | new_nodes,
400 | None,
401 | new_backreferences,
402 | graph.var2instance,
403 | extra=graph.extra,
404 | )
405 | return new_graph
--------------------------------------------------------------------------------
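With use_pointer_tokens enabled, each variable is rendered as a <pointer:k> token followed by its concept, and a re-entrant variable repeats its pointer token instead of introducing a new node. An illustrative round trip (output shown approximately):

    import penman
    from spring_amr.linearization import AMRLinearizer

    g = penman.decode('(b / bark-01 :ARG0 (d / dog))')
    lin = AMRLinearizer(use_pointer_tokens=True).linearize(g)
    print(lin.nodes)
    # roughly: ['<s>', '<pointer:0>', 'bark-01', ':ARG0', '<pointer:1>', 'dog', '<stop>', '</s>']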
/spring_amr/optim.py:
--------------------------------------------------------------------------------
1 | # taken from
2 |
3 | import math
4 | import torch
5 | from torch.optim.optimizer import Optimizer, required
6 |
7 |
8 | class RAdam(Optimizer):
9 |
10 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, degenerated_to_sgd=True):
11 | if not 0.0 <= lr:
12 | raise ValueError("Invalid learning rate: {}".format(lr))
13 | if not 0.0 <= eps:
14 | raise ValueError("Invalid epsilon value: {}".format(eps))
15 | if not 0.0 <= betas[0] < 1.0:
16 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
17 | if not 0.0 <= betas[1] < 1.0:
18 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
19 |
20 | self.degenerated_to_sgd = degenerated_to_sgd
21 | if isinstance(params, (list, tuple)) and len(params) > 0 and isinstance(params[0], dict):
22 | for param in params:
23 | if 'betas' in param and (param['betas'][0] != betas[0] or param['betas'][1] != betas[1]):
24 | param['buffer'] = [[None, None, None] for _ in range(10)]
25 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay,
26 | buffer=[[None, None, None] for _ in range(10)])
27 | super(RAdam, self).__init__(params, defaults)
28 |
29 | def __setstate__(self, state):
30 | super(RAdam, self).__setstate__(state)
31 |
32 | def step(self, closure=None):
33 |
34 | loss = None
35 | if closure is not None:
36 | loss = closure()
37 |
38 | for group in self.param_groups:
39 |
40 | for p in group['params']:
41 | if p.grad is None:
42 | continue
43 | grad = p.grad.data.float()
44 | if grad.is_sparse:
45 | raise RuntimeError('RAdam does not support sparse gradients')
46 |
47 | p_data_fp32 = p.data.float()
48 |
49 | state = self.state[p]
50 |
51 | if len(state) == 0:
52 | state['step'] = 0
53 | state['exp_avg'] = torch.zeros_like(p_data_fp32)
54 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
55 | else:
56 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
57 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)
58 |
59 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
60 | beta1, beta2 = group['betas']
61 |
62 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
63 | exp_avg.mul_(beta1).add_(1 - beta1, grad)
64 |
65 | state['step'] += 1
66 | buffered = group['buffer'][int(state['step'] % 10)]
67 | if state['step'] == buffered[0]:
68 | N_sma, step_size = buffered[1], buffered[2]
69 | else:
70 | buffered[0] = state['step']
71 | beta2_t = beta2 ** state['step']
72 | N_sma_max = 2 / (1 - beta2) - 1
73 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
74 | buffered[1] = N_sma
75 |
76 | # more conservative since it's an approximated value
77 | if N_sma >= 5:
78 | step_size = math.sqrt(
79 | (1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (
80 | N_sma_max - 2)) / (1 - beta1 ** state['step'])
81 | elif self.degenerated_to_sgd:
82 | step_size = 1.0 / (1 - beta1 ** state['step'])
83 | else:
84 | step_size = -1
85 | buffered[2] = step_size
86 |
87 | # more conservative since it's an approximated value
88 | if N_sma >= 5:
89 | if group['weight_decay'] != 0:
90 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)
91 | denom = exp_avg_sq.sqrt().add_(group['eps'])
92 | p_data_fp32.addcdiv_(-step_size * group['lr'], exp_avg, denom)
93 | p.data.copy_(p_data_fp32)
94 | elif step_size > 0:
95 | if group['weight_decay'] != 0:
96 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)
97 | p_data_fp32.add_(-step_size * group['lr'], exp_avg)
98 | p.data.copy_(p_data_fp32)
99 |
100 | return loss
--------------------------------------------------------------------------------
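RAdam rectifies Adam's adaptive step size while the variance estimate is still unreliable (small N_sma), optionally degenerating to SGD-style updates early in training. A construction sketch using the values from configs/config.yaml (model is a placeholder):

    from spring_amr.optim import RAdam

    optimizer = RAdam(model.parameters(), lr=0.00005, weight_decay=0.004)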
/spring_amr/penman.py:
--------------------------------------------------------------------------------
1 | from penman import load as load_, Graph, Triple
2 | from penman import loads as loads_
3 | from penman import encode as encode_
4 | from penman.model import Model
5 | from penman.models.noop import NoOpModel
6 | from penman.models import amr
7 |
8 | op_model = Model()
9 | noop_model = NoOpModel()
10 | amr_model = amr.model
11 | DEFAULT = op_model
12 |
13 | def _get_model(dereify):
14 | if dereify is None:
15 | return DEFAULT
16 | elif dereify:
17 | return op_model
18 | else:
19 | return noop_model
23 |
24 | def _remove_wiki(graph):
25 | metadata = graph.metadata
26 | triples = []
27 | for t in graph.triples:
28 | v1, rel, v2 = t
29 | if rel == ':wiki':
30 | t = Triple(v1, rel, '+')
31 | triples.append(t)
32 | graph = Graph(triples)
33 | graph.metadata = metadata
34 | return graph
35 |
36 | def load(source, dereify=None, remove_wiki=False):
37 | model = _get_model(dereify)
38 | out = load_(source=source, model=model)
39 | if remove_wiki:
40 | for i in range(len(out)):
41 | out[i] = _remove_wiki(out[i])
42 | return out
43 |
44 | def loads(string, dereify=None, remove_wiki=False):
45 | model = _get_model(dereify)
46 | out = loads_(string=string, model=model)
47 | if remove_wiki:
48 | for i in range(len(out)):
49 | out[i] = _remove_wiki(out[i])
50 | return out
51 |
52 | def encode(g, top=None, indent=-1, compact=False):
53 | model = amr_model
54 | return encode_(g=g, top=top, indent=indent, compact=compact, model=model)
--------------------------------------------------------------------------------
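This module pins the penman models so that loading can optionally dereify (op_model) or keep triples as-is (noop_model), while encoding always uses the AMR model. A round-trip sketch:

    from spring_amr.penman import loads, encode

    g = loads('(w / want-01 :ARG0 (b / boy))')[0]
    print(encode(g))  # re-encoded with the AMR model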
/spring_amr/postprocessing.py:
--------------------------------------------------------------------------------
1 | from collections import defaultdict, Counter
2 | import enum
3 | import re
4 |
5 | import networkx as nx
6 | import penman
7 |
8 | from spring_amr.penman import encode
9 |
10 | from spring_amr.linearization import AMRTokens
11 |
12 | BACKOFF = penman.Graph([
13 | penman.Triple('d2', ':instance', 'dog'),
14 | penman.Triple('b1', ':instance', 'bark-01'),
15 | penman.Triple('b1', ':ARG0', 'd2'),])
16 |
17 | def token_processing(tok):
18 | if tok is None:
19 | return None
20 | elif tok.isdigit():
21 | try:
22 | return eval(tok)
23 | except:
24 | return tok
25 | elif tok.startswith('"') and (not tok.endswith('"')):
26 | return tok + '"'
27 | elif tok.endswith('"') and (not tok.startswith('"')):
28 | return '"' + tok
29 | else:
30 | return tok
31 |
32 | def decode_into_node_and_backreferences(subtoken_ids, tokenizer):
33 | rex_arg = re.compile(f"^{tokenizer.INIT}(op|snt|conj|prep)")
34 | rex_spc = re.compile(r"<(s|/s|lit|/lit|stop|unk|pad|mask)>")
35 |
36 | # get strings
37 | subtokens = [tokenizer.decoder.get(t) for t in subtoken_ids]
38 | # fix backreferences
39 | subtoken_backreferences = [max(t - len(tokenizer.encoder), -1) for t in subtoken_ids]
40 | # strip padding
41 | subtokens, subtoken_backreferences = zip(
42 | *[(s, b) for s, b in zip(subtokens, subtoken_backreferences) if s != (tokenizer.INIT + '<pad>')])
43 |
44 | # subword collapse
45 | tokens = []
46 | backreferences = []
47 | subword_to_token_map = {}
48 | current_token_i = 0
49 | for subw_i, (subw_backr, subtok) in enumerate(zip(subtoken_backreferences, subtokens)):
50 | subword_to_token_map[subw_i] = current_token_i
51 |
52 | # if empty you cannot do anything but add a new word
53 | if not tokens:
54 | tokens.append(subtok.lstrip(tokenizer.INIT))
55 | backreferences.append(-1)
56 | current_token_i += 1
57 |
58 | # backref can't be splitted
59 | elif subw_backr > -1:
60 | tokens.append(None)
61 | backreferences.append(subword_to_token_map[subw_backr])
62 | current_token_i += 1
63 |
64 | # after a special token release
65 | elif isinstance(tokens[-1], str) and rex_spc.match(tokens[-1]):
66 | tokens.append(subtok.lstrip(tokenizer.INIT))
67 | backreferences.append(-1)
68 | current_token_i += 1
69 |
70 | # after a subtoken ':' (which should be followed by the rest of the edge) ignore tokenizer.INIT
71 | # TODO: this is an ugly patch due to the fact that BART tokenizer splits after ':'
72 | elif (tokens[-1] == ':') and rex_arg.match(subtok):
73 | tokens[-1] = tokens[-1] + subtok[1:]
74 |
75 | # leading tokenizer.INIT
76 | elif subtok.startswith(tokenizer.INIT):
77 | tokens.append(subtok.lstrip(tokenizer.INIT))
78 | backreferences.append(-1)
79 | current_token_i += 1
80 |
81 | # very ugly patch for some cases in which tokenizer.INIT is not in the following token to the edge
82 | elif isinstance(tokens[-1], str) and tokens[-1].startswith(':') and tokens[-1][-1].isdigit() and (subtok != '-of'):
83 | tokens.append(subtok.lstrip(tokenizer.INIT))
84 | backreferences.append(-1)
85 | current_token_i += 1
86 |
87 | # in any other case attach to the previous
88 | else:
89 | tokens[-1] = tokens[-1] + subtok
90 |
91 | # strip INIT and fix byte-level
92 | tokens = [tokenizer.convert_tokens_to_string(list(t)).lstrip() if isinstance(t, str) else t for t in tokens]
93 | # tokens = [t.replace(tokenizer.INIT, '') if isinstance(t, str) else t for t in tokens]
94 |
95 | # unks are substituted with thing
96 | tokens = [t if t != '<unk>' else 'thing' for t in tokens]
97 |
98 | old_tokens = tokens
99 | old_backreferences = backreferences
100 |
101 | # Barack Obama -> "Barack Obama"
102 | tokens = []
103 | backreferences = []
104 | token_to_token_map = {}
105 | start_search = 0
106 | removed = 0
107 | while True:
108 | try:
109 |
110 | lit_start = old_tokens.index('<lit>', start_search)
111 | token_addition = old_tokens[start_search:lit_start]
112 | for i, t in enumerate(token_addition, start=start_search):
113 | token_to_token_map[i] = i - removed
114 | tokens += token_addition
115 |
116 | backreferences_addition = [token_to_token_map[b] if b > -1 else -1 for b in
117 | old_backreferences[start_search:lit_start]]
118 | backreferences += backreferences_addition
119 |
120 | lit_end = min(lit_start + 2, len(old_tokens) - 1)
121 |
122 | while lit_end < len(old_tokens):
123 | old_tok = old_tokens[lit_end]
124 |
125 | if isinstance(old_tok, str) and (
126 | (old_tok.startswith(':') and len(old_tok) > 3) or (old_tok == '<stop>')):
127 | res_tok = old_tokens[lit_start + 1:lit_end]
128 | for i in range(lit_start, lit_end):
129 | token_to_token_map[i] = len(tokens)
130 |
131 | # Remove possible wrong None
132 | res = old_tokens[lit_start+1:lit_end]
133 | res = [str(r) for r in res if r is not None]
134 | res = '"' + '_'.join(res) + '"'
135 |
136 | removed += len(res_tok)
137 | start_search = lit_end
138 | tokens += [res, old_tok]
139 | backreferences += [-1, -1]
140 | break
141 |
142 | elif old_tok == '</lit>':
143 | res_tok = old_tokens[lit_start + 1:lit_end]
144 | for i in range(lit_start, lit_end + 1):
145 | token_to_token_map[i] = len(tokens)
146 |
147 | # Remove possible wrong None
148 | res = old_tokens[lit_start+1:lit_end]
149 | res = [str(r) for r in res if r is not None]
150 | res = '"' + '_'.join(res) + '"'
151 |
152 | removed += len(res_tok) + 1
153 | start_search = lit_end + 1
154 | tokens.append(res)
155 | backreferences.append(-1)
156 | break
157 |
158 | else:
159 | lit_end += 1
160 | start_search = lit_end
161 |
162 | except ValueError:
163 | token_addition = old_tokens[start_search:]
164 | for i, t in enumerate(token_addition, start=start_search):
165 | token_to_token_map[i] = i - removed
166 | backreferences_addition = [token_to_token_map[b] if b > -1 else b for b in
167 | old_backreferences[start_search:]]
168 | tokens += token_addition
169 | backreferences += backreferences_addition
170 | break
171 |
172 | tokens = [token_processing(t) for t in tokens]
173 |
174 | shift = 1
175 |     if tokens[1] == '<s>':
176 | shift = 2
177 |
178 | tokens = tokens[shift:]
179 | backreferences = [b if b == -1 else b - shift for b in backreferences[shift:]]
180 |
181 |     if tokens[-1] == '</s>':
182 | tokens.pop()
183 | backreferences.pop()
184 |
185 | return tokens, backreferences
186 |
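# Worked example of the literal-merging pass above (hand-traced): subtokens
# decoded as
#   ['person', ':name', '<lit>', 'Barack', 'Obama', '</lit>']
# collapse into
#   ['person', ':name', '"Barack_Obama"']
# with a fresh -1 backreference for the merged literal.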
187 |
188 | def index_of(element, iterable, default=None, start=None, end=None):
189 | if not callable(element):
190 | def check(x):
191 | return element == x
192 | else:
193 | check = element
194 | if start is None:
195 | start = 0
196 | if end is None:
197 | end = len(iterable)
198 | item = start
199 | while item < end:
200 | if check(iterable[item]):
201 | return item
202 | item += 1
203 | return default
204 |
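# Usage sketch (doctest-style; `element` may be a plain value or a predicate):
#   >>> index_of('b', ['a', 'b', 'c'])
#   1
#   >>> index_of(lambda x: x > 2, [1, 2, 3, 4])
#   2
#   >>> index_of('z', ['a', 'b'], default=-1)
#   -1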
205 |
206 | def separate_edges_nodes(edges_nodes_slice, *other):
207 | is_arg = lambda x: isinstance(x, str) and x.startswith(':')
208 | start = 0
209 | edges = []
210 | nodes = []
211 | l = len(edges_nodes_slice)
212 | while start < l:
213 | edge_index = index_of(
214 | is_arg,
215 | edges_nodes_slice,
216 | start=start)
217 | if edge_index is None or edge_index == (l - 1):
218 | break
219 | if is_arg(edges_nodes_slice[edge_index + 1]):
220 | start = edge_index + 1
221 | continue
222 | edges.append(edge_index)
223 | nodes.append(edge_index + 1)
224 | start = edge_index + 2
225 | ret = []
226 | for oth in other:
227 | edges_oth = [oth[i] for i in edges]
228 | nodes_oth = [oth[i] for i in nodes]
229 | ret.append((edges_oth, nodes_oth))
230 | return ret
231 |
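# Hand-worked example: for a slice [':ARG0', 'dog', ':mod', 'big'] passed as
# both the slice and a single `other` sequence, each edge is paired with the
# token right after it, giving [([':ARG0', ':mod'], ['dog', 'big'])]; of two
# consecutive edges, the first (unpaired) one is dropped.
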
232 | def _split_name_ops(graph):
233 | # identify name triples
234 | name_vars = {}
235 | for i, (v1, rel, v2) in enumerate(graph.triples):
236 | if rel == ':instance' and v2 == 'name':
237 | name_vars[v1] = 1
238 |
239 | # check if they have ops
240 | name_vars_to_ops = defaultdict(list)
241 | for i, (v1, rel, v2) in enumerate(graph.triples):
242 | if v1 in name_vars and rel.startswith(':op'):
243 | name_vars_to_ops[v1].append((i, rel, v2.strip('"')))
244 |
245 | triples = graph.triples.copy()
246 | for nv, ops in name_vars_to_ops.items():
247 | ops = sorted(ops, key=lambda x: int(x[1][3:]))
248 | idx, _, lits = zip(*ops)
249 | for i in idx:
250 | triples[i] = None
251 |
252 | lits = ['"' + l + '"' for lit in lits for l in lit.split('_')]
253 |
254 | tt = []
255 | for i, l in enumerate(lits, start=1):
256 | rel = ':op' + str(i)
257 | tt.append(penman.Triple(nv, rel, l))
258 |
259 | triples[min(idx)] = tt
260 |
261 | triples = [t if isinstance(t, list) else [t] for t in triples if t is not None]
262 | triples = [t for tt in triples for t in tt]
263 |
264 | graph_ = penman.Graph(triples)
265 | graph_.metadata = graph.metadata
266 | return graph_
267 |
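# Hand-worked example: a collapsed name node such as
#   (n / name :op1 "Barack_Obama")
# is split back into
#   (n / name :op1 "Barack" :op2 "Obama")
# with the new :op triples inserted where the first original one stood.
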
268 | def _reconstruct_graph_from_nodes(nodes, backreferences):
269 | triples = []
270 | triples_added = set()
271 |
272 | variable2index = {}
273 | index2variable = {}
274 | start_index = 0
275 |
276 | cnt = defaultdict(Counter)
277 |
278 | while start_index < len(nodes):
279 |         stop_index = index_of('<stop>', nodes, default=len(nodes) + 1, start=start_index)
280 | old_start_index = start_index
281 | start_index = stop_index + 1
282 |
283 | src_node, src_backr = nodes[old_start_index], backreferences[old_start_index]
284 |
285 |         if src_node == '<pad>':
286 | continue
287 |
288 | trg_nodes_edges = nodes[old_start_index:stop_index]
289 | trg_nodes_edges_backr = backreferences[old_start_index:stop_index]
290 | trg_nodes_edges_indices = list(range(old_start_index, stop_index))
291 |
292 | if isinstance(src_node, str):
293 |             if src_node in ('<s>', '</s>', '<stop>'):
294 | continue
295 | elif ('/' in src_node) or (':' in src_node) or ('(' in src_node) or (')' in src_node):
296 | src_node = 'thing'
297 |
298 | if src_node is not None:
299 | src_node = str(src_node)
300 | src_var = src_node[0].lower()
301 |             if src_var not in 'abcdefghijklmnopqrstuvwxyz':
302 | src_var = 'x'
303 | #src_var = f'{src_var}_{len(variable2index)}'
304 | src_var = f'{src_var}{len(variable2index)}'
305 | src_var_i = old_start_index
306 | variable2index[src_var] = src_var_i
307 | index2variable[src_var_i] = src_var
308 | triple = penman.Triple(src_var, ':instance', src_node)
309 | if triple not in triples_added:
310 | triples.append(triple)
311 | triples_added.add(triple)
312 | else:
313 | if src_backr in index2variable:
314 | src_var = index2variable[src_backr]
315 | # more resilient logic here
316 | (trg_edges, trg_nodes), (_, trg_nodes_backr), (_, trg_nodes_indices) = \
317 | separate_edges_nodes(
318 | trg_nodes_edges,
319 | trg_nodes_edges,
320 | trg_nodes_edges_backr,
321 | trg_nodes_edges_indices)
322 |
323 | for n, e, nb, ni in zip(trg_nodes, trg_edges, trg_nodes_backr, trg_nodes_indices):
324 |
325 | if isinstance(n, str) and n.startswith(':'):
326 | continue
327 | if isinstance(n, str) and n.startswith('<') and n.endswith('>'):
328 | continue
329 | if e == ':li':
330 | pass
331 | elif len(e) < 4 or (not e.startswith(':')):
332 | continue
333 |
334 | # same edge more than once
335 | num = cnt[src_var][e]
336 | # num = 0
337 | if num:
338 |
339 | if e.startswith(':op') or e.startswith(':snt'):
340 | continue
341 | #elif e.startswith(':ARG'):
342 | # continue
343 | elif num > 3:
344 | continue
345 |
346 | if n is None:
347 | if nb not in index2variable:
348 | continue
349 | trg_var = index2variable[nb]
350 | trg = trg_var
351 | elif e == ':mode':
352 | trg = n
353 | elif (not isinstance(n, str)) or re.match(r"^[+-]?\d+\.?\d*$", n) or (n == '-') or (n == '+'):
354 | trg = str(n)
355 | elif (n.startswith('"') and n.endswith('"') and len(n) > 2):
356 | trg = '"' + n.replace('"', '') + '"'
357 | elif ('/' in n) or (':' in n) or ('(' in n) or (')' in n) or ('=' in n):
358 | trg = f'"{n}"'
359 | elif n == '"':
360 | continue
361 | elif (n.startswith('"') and (not n.endswith('"'))) or (not n.startswith('"') and (n.endswith('"'))) or ('"' in n):
362 | trg = '"' + n.replace('"', '') + '"'
363 | else:
364 | trg_var = n[0].lower()
365 | if trg_var not in 'abcdefghijklmnopqrstuvwxyz':
366 | trg_var = 'x'
367 | #trg_var = f'{trg_var}_{len(variable2index)}'
368 | trg_var = f'{trg_var}{len(variable2index)}'
369 | trg_var_i = ni
370 | variable2index[trg_var] = trg_var_i
371 | index2variable[trg_var_i] = trg_var
372 | triple = penman.Triple(trg_var, ':instance', n)
373 | if triple not in triples_added:
374 | triples.append(triple)
375 | triples_added.add(triple)
376 | trg = trg_var
377 |
378 | triple = penman.Triple(src_var, e, trg)
379 | if triple not in triples_added:
380 | triples.append(triple)
381 | triples_added.add(triple)
382 |
383 | cnt[src_var][e] += 1
384 |
385 | return penman.Graph(triples)
386 |
387 | def build_graph(nodes, backreferences, restore_name_ops=False):
388 | graph = _reconstruct_graph_from_nodes(nodes, backreferences)
389 | if restore_name_ops:
390 | graph = _split_name_ops(graph)
391 | return graph
392 |
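# Minimal sketch of the expected node-level input (segments are delimited by
# '<stop>'; the values are illustrative):
#   nodes          = ['dog', ':ARG0-of', 'bark-01', '<stop>']
#   backreferences = [-1, -1, -1, -1]
# build_graph(nodes, backreferences) then yields a graph equivalent to
#   (d0 / dog :ARG0-of (b1 / bark-01))
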
393 | class ParsedStatus(enum.Enum):
394 | OK = 0
395 | FIXED = 1
396 | BACKOFF = 2
397 |
398 | def connect_graph_if_not_connected(graph):
399 |
400 | try:
401 | encoded = encode(graph)
402 | return graph, ParsedStatus.OK
403 |     except Exception:
404 | pass
405 |
406 | nxgraph = nx.MultiGraph()
407 | variables = graph.variables()
408 | for v1, _, v2 in graph.triples:
409 | if v1 in variables and v2 in variables:
410 | nxgraph.add_edge(v1, v2)
411 | elif v1 in variables:
412 | nxgraph.add_edge(v1, v1)
413 |
414 | triples = graph.triples.copy()
415 | new_triples = []
416 | addition = f'a{len(variables) + 1}'
417 | triples.append(penman.Triple(addition, ':instance', 'and'))
418 | for i, conn_set in enumerate(nx.connected_components(nxgraph), start=1):
419 | edge = f':op{i}'
420 | conn_set = sorted(conn_set, key=lambda x: int(x[1:]))
421 | conn_set = [c for c in conn_set if c in variables]
422 | node = conn_set[0]
423 | new_triples.append(penman.Triple(addition, edge, node))
424 | triples = new_triples + triples
425 | metadata = graph.metadata
426 | graph = penman.Graph(triples)
427 | graph.metadata.update(metadata)
428 | encode(graph)
429 |
430 | return graph, ParsedStatus.FIXED
431 |
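# Sketch of the repair: for a disconnected prediction with components rooted
# in z0 and z1, the function prepends an artificial coordination, roughly
#   (a3 / and :op1 z0 :op2 z1)
# so that penman can re-encode the whole graph as a single tree.
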
432 | def restore_backreferences_from_pointers(nodes):
433 | new_nodes, new_backreferences = [], []
434 | prev_pointer = None
435 | pointer2i = {}
436 | for n in nodes:
437 |         is_pointer = isinstance(n, str) and n.startswith('<pointer:')
438 |
439 | if not is_pointer:
440 | if prev_pointer is not None:
441 | if prev_pointer in pointer2i:
442 | new_nodes.append(None)
443 | new_backreferences.append(pointer2i[prev_pointer])
444 | new_nodes.append(n)
445 | new_backreferences.append(-1)
446 |
447 | else:
448 | pointer2i[prev_pointer] = len(new_nodes)
449 | new_nodes.append(n)
450 | new_backreferences.append(-1)
451 | else:
452 | new_nodes.append(n)
453 | new_backreferences.append(-1)
454 |
455 | prev_pointer = None
456 | else:
457 | prev_pointer = n
458 | return new_nodes, new_backreferences
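# Hand-traced example, assuming '<pointer:N>' tokens as introduced above:
#   ['<pointer:0>', 'dog', ':ARG0-of', '<pointer:0>', '<stop>']
# becomes
#   nodes          = ['dog', ':ARG0-of', None, '<stop>']
#   backreferences = [-1, -1, 0, -1]
# i.e. the second occurrence of <pointer:0> is replaced by a backreference to
# index 0, where the pointer was first bound.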
--------------------------------------------------------------------------------
/spring_amr/tokenization_bart.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import sys
3 | from pathlib import Path
4 |
5 | import penman
6 | import regex as re
7 | import torch
8 | from transformers import BartTokenizer
9 |
10 | from spring_amr import ROOT, postprocessing
11 | from spring_amr.linearization import AMRTokens, AMRLinearizer
12 | from spring_amr.penman import encode
13 |
14 |
15 | class AMRBartTokenizer(BartTokenizer):
16 |
17 | INIT = 'Ġ'
18 |
19 | ADDITIONAL = [
20 | AMRTokens.PNTR_N,
21 | AMRTokens.STOP_N,
22 | AMRTokens.LIT_START,
23 | AMRTokens.LIT_END,
24 | AMRTokens.BACKR_SRC_N,
25 | AMRTokens.BACKR_TRG_N,]
26 |
27 | def __init__(self, *args, use_pointer_tokens=False, collapse_name_ops=False, **kwargs):
28 | super().__init__(*args, **kwargs)
29 | self.patterns = re.compile(
30 | r""" ?<[a-z]+:?\d*>| ?:[^\s]+|'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
31 | self.linearizer = AMRLinearizer(use_pointer_tokens=use_pointer_tokens, collapse_name_ops=collapse_name_ops)
32 | self.use_pointer_tokens = use_pointer_tokens
33 | self.collapse_name_ops = collapse_name_ops
34 | self.recategorizations = set()
35 | self.modified = 0
36 |
37 | @classmethod
38 | def from_pretrained(cls, pretrained_model_path, pred_min=5, *args, **kwargs):
39 | inst = super().from_pretrained(pretrained_model_path, *args, **kwargs)
40 | inst.init_amr_vocabulary(pred_min=pred_min)
41 | return inst
42 |
43 | def init_amr_vocabulary(self, pred_min=5):
44 |         for tok in [self.bos_token, self.eos_token, self.pad_token, '<mask>', '<unk>']:
45 | ntok = self.INIT + tok
46 | i = self.encoder[tok]
47 | self.decoder[i] = ntok
48 | del self.encoder[tok]
49 | self.encoder[ntok] = i
50 |
51 | tokens = []
52 | for line in Path(ROOT/'data/vocab/predicates.txt').read_text().strip().splitlines():
53 | tok, count = line.split()
54 | if int(count) >= pred_min:
55 | tokens.append(tok)
56 |
57 | for tok in Path(ROOT/'data/vocab/additions.txt').read_text().strip().splitlines():
58 | tokens.append(tok)
59 |
60 | for tok in Path(ROOT/'data/vocab/recategorizations.txt').read_text().strip().splitlines():
61 | if not tok.startswith('_'):
62 | self.recategorizations.add(tok)
63 | tokens.append(tok)
64 |
65 | if self.use_pointer_tokens:
66 | for cnt in range(512):
67 |                 tokens.append(f'<pointer:{cnt}>')
68 |
69 | tokens += self.ADDITIONAL
70 | tokens = [self.INIT + t if t[0] not in ('_', '-') else t for t in tokens]
71 | tokens = [t for t in tokens if t not in self.encoder]
72 | self.old_enc_size = old_enc_size = len(self.encoder)
73 |         for i, t in enumerate(tokens, start=old_enc_size):
74 | self.encoder[t] = i
75 |
76 | self.encoder = {k: i for i, (k,v) in enumerate(sorted(self.encoder.items(), key=lambda x: x[1]))}
77 | self.decoder = {v: k for k, v in sorted(self.encoder.items(), key=lambda x: x[1])}
78 | self.modified = len(tokens)
79 |
80 |         self.bos_token = self.INIT + '<s>'
81 |         self.pad_token = self.INIT + '<pad>'
82 |         self.eos_token = self.INIT + '</s>'
83 |         self.unk_token = self.INIT + '<unk>'
84 |
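    # Net effect (exact counts depend on the vocab files and pred_min): the
    # plain BART vocabulary is extended with frequent AMR predicates, the
    # addition and recategorization lists, the special tokens above and,
    # optionally, 512 '<pointer:N>' tokens; every new entry is INIT-prefixed
    # unless it starts with '_' or '-'.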
85 | def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
86 | output = [self.bos_token_id] + token_ids_0 + [self.eos_token_id]
87 | if token_ids_1 is None:
88 | return output
89 | return output + [self.eos_token_id] + token_ids_1 + [self.eos_token_id]
90 |
91 | def _tokenize(self, text):
92 | """ Tokenize a string. Modified in order to handle sentences with recategorization pointers"""
93 | bpe_tokens = []
94 | for tok_span in text.lstrip().split(' '):
95 | tok_span = tok_span.strip()
96 | recats = tok_span.rsplit('_', 1)
97 | if len(recats) == 2 and recats[0] in self.recategorizations and ('_' + recats[1]) in self.encoder:
98 | bpe_tokens.extend([self.INIT + recats[0], '_' + recats[1]])
99 | else:
100 | for token in re.findall(self.pat, ' ' + tok_span):
101 | token = "".join(
102 | self.byte_encoder[b] for b in token.encode("utf-8")
103 |                     ) # Maps all our bytes to unicode strings, avoiding control tokens of the BPE (spaces in our case)
104 | bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(" "))
105 |
106 | return bpe_tokens
107 |
108 | def _tok_bpe(self, token, add_space=True):
109 | # if add_space:
110 | # token = ' ' + token.lstrip()
111 | tokk = []
112 | tok = token.strip()
113 | recats = tok.rsplit('_', 1)
114 | if len(recats) == 2 and recats[0] in self.recategorizations and ('_' + recats[1]) in self.encoder:
115 | tokk.extend([self.INIT + recats[0], '_' + recats[1]])
116 | else:
117 | for tok in self.patterns.findall(' ' + token):
118 | tok = "".join(
119 | self.byte_encoder[b] for b in tok.encode("utf-8"))
120 | toks = self.bpe(tok).split(' ')
121 | tokk.extend(toks)
122 | return tokk
123 |
124 | def _get_nodes_and_backreferences(self, graph):
125 | lin = self.linearizer.linearize(graph)
126 | linearized_nodes, backreferences = lin.nodes, lin.backreferences
127 | return linearized_nodes, backreferences
128 |
129 | def tokenize_amr(self, graph):
130 | linearized_nodes, backreferences = self._get_nodes_and_backreferences(graph)
131 |
132 | bpe_tokens = []
133 | bpe_backreferences = []
134 | counter = 0
135 |
136 | for i, (backr, tokk) in enumerate(zip(backreferences, linearized_nodes)):
137 | is_in_enc = self.INIT + tokk in self.encoder
138 | is_rel = tokk.startswith(':') and len(tokk) > 1
139 | is_spc = tokk.startswith('<') and tokk.endswith('>')
140 | is_of = tokk.startswith(':') and tokk.endswith('-of')
141 | is_frame = re.match(r'.+-\d\d', tokk) is not None
142 |
143 | if tokk.startswith('"') and tokk.endswith('"'):
144 | tokk = tokk[1:-1].replace('_', ' ')
145 | bpe_toks = [self.INIT + AMRTokens.LIT_START]
146 | bpe_toks += self._tok_bpe(tokk, add_space=True)
147 | bpe_toks.append(self.INIT + AMRTokens.LIT_END)
148 |
149 | elif (is_rel or is_spc or is_frame or is_of):
150 | if is_in_enc:
151 | bpe_toks = [self.INIT + tokk]
152 | elif is_frame:
153 | bpe_toks = self._tok_bpe(tokk[:-3], add_space=True) + [tokk[-3:]]
154 | elif is_of:
155 | rel = tokk[:-3]
156 | if self.INIT + rel in self.encoder:
157 | bpe_toks = [self.INIT + rel, '-of']
158 | else:
159 | bpe_toks = [self.INIT + ':'] + self._tok_bpe(rel[1:], add_space=True) + ['-of']
160 | elif is_rel:
161 | bpe_toks = [self.INIT + ':'] + self._tok_bpe(tokk[1:], add_space=True)
162 | else:
163 | raise
164 |
165 | else:
166 | if is_in_enc:
167 | bpe_toks = [self.INIT + tokk]
168 | else:
169 | bpe_toks = self._tok_bpe(tokk, add_space=True)
170 |
171 | bpe_tokens.append(bpe_toks)
172 |
173 | if i == backr:
174 | bpe_backr = list(range(counter, counter + len(bpe_toks)))
175 | counter += len(bpe_toks)
176 | bpe_backreferences.append(bpe_backr)
177 | else:
178 | bpe_backreferences.append(bpe_backreferences[backr][0:1])
179 | counter += 1
180 | bpe_tokens = [b for bb in bpe_tokens for b in bb]
181 | bpe_token_ids = [self.encoder.get(b, self.unk_token_id) for b in bpe_tokens]
182 | bpe_backreferences = [b for bb in bpe_backreferences for b in bb]
183 | return bpe_tokens, bpe_token_ids, bpe_backreferences
184 |
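    # Illustration of the branches above (which branch fires depends on the
    # extended vocabulary): a frame like 'want-01' that is not a single vocab
    # entry becomes the BPE pieces of 'want' plus the literal '-01' suffix;
    # ':ARG0-of' becomes [INIT + ':ARG0', '-of'] when ':ARG0' is in the
    # encoder; a quoted literal '"Barack_Obama"' is emitted between the
    # <lit> ... </lit> markers with underscores restored to spaces.
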
185 | def batch_encode_sentences(self, sentences, device=torch.device('cpu')):
186 | sentences = [s for s in sentences]
187 | extra = {'sentences': sentences}
188 | batch = super().batch_encode_plus(sentences, return_tensors='pt', pad_to_max_length=True)
189 | batch = {k: v.to(device) for k, v in batch.items()}
190 | return batch, extra
191 |
192 | def linearize(self, graph):
193 | shift = len(self.encoder)
194 | tokens, token_ids, backreferences = self.tokenize_amr(graph)
195 | extra = {'linearized_graphs': tokens, 'graphs': graph}
196 | token_uni_ids = \
197 | [idx if i == b else b + shift for i, (idx, b) in enumerate(zip(token_ids, backreferences))]
198 |         if tokens[-1] != (self.INIT + AMRTokens.EOS_N):
199 | tokens.append(self.INIT + AMRTokens.EOS_N)
200 | token_ids.append(self.eos_token_id)
201 | token_uni_ids.append(self.eos_token_id)
202 | backreferences.append(len(backreferences))
203 | return token_uni_ids, extra
204 |
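    # Sketch of the uni-id scheme: a position that introduces new material
    # keeps its vocabulary id, while a position whose backreference points to
    # target position b is encoded as b + len(self.encoder), keeping pointer
    # ids disjoint from ordinary token ids.
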
205 | def batch_encode_graphs(self, graphs, device=torch.device('cpu')):
206 | linearized, extras = zip(*[self.linearize(g) for g in graphs])
207 | return self.batch_encode_graphs_from_linearized(linearized, extras, device=device)
208 |
209 | def batch_encode_graphs_from_linearized(self, linearized, extras=None, device=torch.device('cpu')):
210 | if extras is not None:
211 | batch_extra = {'linearized_graphs': [], 'graphs': []}
212 | for extra in extras:
213 | batch_extra['graphs'].append(extra['graphs'])
214 | batch_extra['linearized_graphs'].append(extra['linearized_graphs'])
215 | else:
216 | batch_extra = {}
217 | maxlen = 0
218 | batch = []
219 | for token_uni_ids in linearized:
220 | maxlen = max(len(token_uni_ids), maxlen)
221 | batch.append(token_uni_ids)
222 | batch = [x + [self.pad_token_id] * (maxlen - len(x)) for x in batch]
223 | batch = torch.tensor(batch).to(device)
224 | batch = {'decoder_input_ids': batch[:, :-1], 'lm_labels': batch[:, 1:]}
225 | return batch, batch_extra
226 |
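    # The slicing above is standard teacher forcing: decoder_input_ids drops
    # the final position and lm_labels drops the first, so position t of the
    # input is trained to predict position t+1 of the padded sequence.
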
227 | def decode_amr(self, tokens, restore_name_ops=False):
228 | try:
229 | nodes, backreferences = postprocessing.decode_into_node_and_backreferences(tokens, self)
230 | except Exception as e:
231 | print('Decoding failure:', file=sys.stderr)
232 | print(e, file=sys.stderr)
233 | return postprocessing.BACKOFF, postprocessing.ParsedStatus.BACKOFF, (None, None)
234 | if self.use_pointer_tokens:
235 | nodes, backreferences = postprocessing.restore_backreferences_from_pointers(nodes)
236 | try:
237 | graph_ = graph = postprocessing.build_graph(nodes, backreferences, restore_name_ops=restore_name_ops)
238 | except Exception as e:
239 | print('Building failure:', file=sys.stderr)
240 | print(nodes, file=sys.stderr)
241 | print(backreferences, file=sys.stderr)
242 | print(e, file=sys.stderr)
243 | return postprocessing.BACKOFF, postprocessing.ParsedStatus.BACKOFF, (None, None)
244 | try:
245 | graph, status = postprocessing.connect_graph_if_not_connected(graph)
246 | if status == postprocessing.ParsedStatus.BACKOFF:
247 |                 print('Reconnection 1 failure:', file=sys.stderr)
248 | print(nodes, file=sys.stderr)
249 | print(backreferences, file=sys.stderr)
250 | print(graph_, file=sys.stderr)
251 | return graph, status, (nodes, backreferences)
252 | except Exception as e:
253 |             print('Reconnection 2 failure:', file=sys.stderr)
254 | print(e, file=sys.stderr)
255 | print(nodes, file=sys.stderr)
256 | print(backreferences, file=sys.stderr)
257 | print(graph_, file=sys.stderr)
258 | return postprocessing.BACKOFF, postprocessing.ParsedStatus.BACKOFF, (nodes, backreferences)
259 |
260 | class PENMANBartTokenizer(AMRBartTokenizer):
261 |
262 | def __init__(self, *args, raw_graph=False, **kwargs):
263 | super().__init__(*args, **kwargs)
264 | self.linearizer = None
265 | self.remove_pars = False
266 | self.raw_graph = raw_graph
267 |
268 | def _tokenize_encoded_graph(self, encoded):
269 | linearized = re.sub(r"(\".+?\")", r' \1 ', encoded)
270 | pieces = []
271 | for piece in linearized.split():
272 | if piece.startswith('"') and piece.endswith('"'):
273 | pieces.append(piece)
274 | else:
275 | piece = piece.replace('(', ' ( ')
276 | piece = piece.replace(')', ' ) ')
277 | piece = piece.replace(':', ' :')
278 | piece = piece.replace('/', ' / ')
279 | piece = piece.strip()
280 | pieces.append(piece)
281 | linearized = re.sub(r'\s+', ' ', ' '.join(pieces)).strip()
282 | linearized_nodes = [AMRTokens.BOS_N] + linearized.split(' ')
283 | return linearized_nodes
284 |
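    # Hand-worked example: '(z0 / dog :ARG0-of (z1 / bark-01))' is split into
    #   [AMRTokens.BOS_N, '(', 'z0', '/', 'dog', ':ARG0-of',
    #    '(', 'z1', '/', 'bark-01', ')', ')']
    # while a quoted literal such as '"New_York"' survives as a single piece.
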
285 | def tokenize_amr(self, graph):
286 | if self.raw_graph:
287 | graph_ = copy.deepcopy(graph)
288 | graph_.metadata = {}
289 | linearized = penman.encode(graph_)
290 | linearized = re.sub(r"\s+", ' ', linearized)
291 | bpe_tokens = [self.bos_token] + self._tokenize(linearized)[:1022]
292 | bpe_token_ids = [self.encoder.get(b, self.unk_token_id) for b in bpe_tokens]
293 | bpe_backreferences = list(range(len(bpe_token_ids)))
294 | return bpe_tokens, bpe_token_ids, bpe_backreferences
295 | else:
296 | return super().tokenize_amr(graph)
297 |
298 | def _get_nodes_and_backreferences(self, graph):
299 | graph_ = copy.deepcopy(graph)
300 | graph_.metadata = {}
301 | linearized = penman.encode(graph_)
302 | linearized_nodes = self._tokenize_encoded_graph(linearized)
303 |
304 | if self.use_pointer_tokens:
305 | remap = {}
306 | for i in range(1, len(linearized_nodes)):
307 | nxt = linearized_nodes[i]
308 | lst = linearized_nodes[i-1]
309 | if nxt == '/':
310 |                     remap[lst] = f'<pointer:{len(remap)}>'
311 | i = 1
312 | linearized_nodes_ = [linearized_nodes[0]]
313 | while i < (len(linearized_nodes)):
314 | nxt = linearized_nodes[i]
315 | lst = linearized_nodes_[-1]
316 | if nxt in remap:
317 | if lst == '(' and linearized_nodes[i+1] == '/':
318 | nxt = remap[nxt]
319 | i += 1
320 | elif lst.startswith(':'):
321 | nxt = remap[nxt]
322 | linearized_nodes_.append(nxt)
323 | i += 1
324 | linearized_nodes = linearized_nodes_
325 | if self.remove_pars:
326 | linearized_nodes = [n for n in linearized_nodes if n != '(']
327 | backreferences = list(range(len(linearized_nodes)))
328 | return linearized_nodes, backreferences
329 |
330 | def _classify(self, node):
331 | if not isinstance(node, str):
332 | return "CONST"
333 | elif node == 'i':
334 | return "I"
335 | elif re.match(r'^[a-z]\d*$', node) is not None:
336 | return "VAR"
337 | elif node[0].isdigit():
338 | return "CONST"
339 | elif node.startswith('"') and node.endswith('"'):
340 | return "CONST"
341 | elif node in ('+', '-'):
342 | return "CONST"
343 | elif node == ':mode':
344 | return 'MODE'
345 | elif node.startswith(':'):
346 | return "EDGE"
347 | elif node in ['/', '(', ')']:
348 | return node
349 | elif node[0].isalpha():
350 | for char in (',', ':', '/', '(', ')', '.', '!', '?', '\\'):
351 | if char in node:
352 | return "CONST"
353 | return "INST"
354 | else:
355 | return 'CONST'
356 |
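    # Classification sketch (hand-checked against the branches above):
    #   _classify('z12') -> 'VAR'      _classify(':ARG0') -> 'EDGE'
    #   _classify('dog') -> 'INST'     _classify('"NY"')  -> 'CONST'
    #   _classify('+')   -> 'CONST'    _classify(42)      -> 'CONST'
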
357 | def _fix_and_make_graph(self, nodes):
358 |
359 | nodes_ = []
360 | for n in nodes:
361 | if isinstance(n, str):
362 |                 if n.startswith('<') and n.endswith('>') and (not n.startswith('<pointer:')):
379 | if e != len(nxt) -1:
380 | pst = nxt[e+1:]
381 | nxt = nxt[:e+1]
382 | nodes_.append(nxt)
383 | if pst is not None:
384 | nodes_.append(pst)
385 | else:
386 | nodes_.append(nxt)
387 | i += 1
388 | nodes = nodes_
389 |
390 | i = 1
391 | nodes_ = [nodes[0]]
392 | while i < len(nodes):
393 | nxt = nodes[i]
394 |                 if isinstance(nxt, str) and nxt.startswith('<pointer:'):
569 |                 if open_count == close_count and open_count > 0:
570 | line = line[:i].strip()
571 | break
572 | old_line = line
573 | while True:
574 | open_count = len(re.findall(r'\(', line))
575 | close_count = len(re.findall(r'\)', line))
576 | if open_count > close_count:
577 | line += ')' * (open_count - close_count)
578 | elif close_count > open_count:
579 | for i in range(close_count - open_count):
580 | line = line.rstrip(')')
581 | line = line.rstrip(' ')
582 | if old_line == line:
583 | break
584 | old_line = line
585 | """
586 |
587 | graph = penman.decode(linearized + ' ')
588 | triples = []
589 | newvars = 2000
590 | for triple in graph.triples:
591 | x, rel, y = triple
592 | if x is None:
593 | pass
594 | elif rel == ':instance' and y is None:
595 | triples.append(penman.Triple(x, rel, 'thing'))
596 | elif y is None:
597 | var = f'z{newvars}'
598 | newvars += 1
599 | triples.append(penman.Triple(x, rel, var))
600 | triples.append(penman.Triple(var, ':instance', 'thing'))
601 | else:
602 | triples.append(triple)
603 | graph = penman.Graph(triples)
604 | linearized = encode(graph)
605 |
606 | def fix_text(linearized=linearized):
607 | n = 0
608 | def _repl1(match):
609 | nonlocal n
610 | out = match.group(1) + match.group(2) + str(3000 + n) + ' / ' + match.group(2) + match.group(3)
611 | n += 1
612 | return out
613 | linearized = re.sub(r'(\(\s?)([a-z])([^\/:\)]+[:\)])', _repl1, linearized,
614 | flags=re.IGNORECASE | re.MULTILINE)
615 |
616 | def _repl2(match):
617 | return match.group(1)
618 | linearized = re.sub(r'(\(\s*[a-z][\d+]\s*\/\s*[^\s\)\(:\/]+\s*)((?:/\s*[^\s\)\(:\/]+\s*)+)', _repl2,
619 | linearized,
620 | flags=re.IGNORECASE | re.MULTILINE)
621 |
622 | # adds a ':' to args w/o it
623 | linearized = re.sub(r'([^:])(ARG)', r'\1 :\2', linearized)
624 |
625 | # removes edges with no node
626 | # linearized = re.sub(r':[^\s\)\(:\/]+?\s*\)', ')', linearized, flags=re.MULTILINE)
627 |
628 | return linearized
629 |
630 | linearized = fix_text(linearized)
631 |
632 | g = penman.decode(linearized)
633 | return g
634 |
635 | def decode_amr(self, tokens, restore_name_ops=None):
636 | try:
637 | if self.raw_graph:
638 | nodes = self._tokenize_encoded_graph(self.decode(tokens))
639 | backreferences = list(range(len(nodes)))
640 | else:
641 | nodes, backreferences = postprocessing.decode_into_node_and_backreferences(tokens, self)
642 | nodes_ = nodes
643 | except Exception as e:
644 | print('Decoding failure:', file=sys.stderr)
645 | print(e, file=sys.stderr)
646 | return postprocessing.BACKOFF, postprocessing.ParsedStatus.BACKOFF, (None, None)
647 | try:
648 | graph_ = graph = self._fix_and_make_graph(nodes)
649 | if self.collapse_name_ops:
650 | graph_ = graph = postprocessing._split_name_ops(graph)
651 | except Exception as e:
652 | print('Building failure:', file=sys.stderr)
653 | print(nodes, file=sys.stderr)
654 | print(backreferences, file=sys.stderr)
655 | print(e, file=sys.stderr)
656 | return postprocessing.BACKOFF, postprocessing.ParsedStatus.BACKOFF, (None, None)
657 | try:
658 | graph, status = postprocessing.connect_graph_if_not_connected(graph)
659 | if status == postprocessing.ParsedStatus.BACKOFF:
660 |                 print('Reconnection 1 failure:', file=sys.stderr)
661 | print(nodes, file=sys.stderr)
662 | print(backreferences, file=sys.stderr)
663 | print(graph_, file=sys.stderr)
664 | return graph, status, (nodes_, backreferences)
665 | except Exception as e:
666 |             print('Reconnection 2 failure:', file=sys.stderr)
667 | print(e, file=sys.stderr)
668 | print(nodes, file=sys.stderr)
669 | print(backreferences, file=sys.stderr)
670 | print(graph_, file=sys.stderr)
671 | return postprocessing.BACKOFF, postprocessing.ParsedStatus.BACKOFF, (nodes_, backreferences)
672 |
--------------------------------------------------------------------------------
/spring_amr/utils.py:
--------------------------------------------------------------------------------
1 | from glob import glob
2 | from pathlib import Path
3 |
4 | import torch
5 | from transformers import AutoConfig
6 |
7 | from spring_amr.dataset import AMRDataset, AMRDatasetTokenBatcherAndLoader
8 | from spring_amr.modeling_bart import AMRBartForConditionalGeneration
9 | from spring_amr.tokenization_bart import AMRBartTokenizer, PENMANBartTokenizer
10 |
11 |
12 | def instantiate_model_and_tokenizer(
13 | name=None,
14 | checkpoint=None,
15 | additional_tokens_smart_init=True,
16 | dropout = 0.15,
17 | attention_dropout = 0.15,
18 | from_pretrained = True,
19 | init_reverse = False,
20 | collapse_name_ops = False,
21 | penman_linearization = False,
22 | use_pointer_tokens = False,
23 | raw_graph = False,
24 | ):
25 | if raw_graph:
26 | assert penman_linearization
27 |
28 | skip_relations = False
29 |
30 | if name is None:
31 | name = 'facebook/bart-large'
32 |
33 | if name == 'facebook/bart-base':
34 | tokenizer_name = 'facebook/bart-large'
35 | else:
36 | tokenizer_name = name
37 |
38 | config = AutoConfig.from_pretrained(name)
39 | config.output_past = False
40 | config.no_repeat_ngram_size = 0
41 | config.prefix = " "
42 | config.output_attentions = True
43 | config.dropout = dropout
44 | config.attention_dropout = attention_dropout
45 |
46 | if penman_linearization:
47 | tokenizer = PENMANBartTokenizer.from_pretrained(
48 | tokenizer_name,
49 | collapse_name_ops=collapse_name_ops,
50 | use_pointer_tokens=use_pointer_tokens,
51 | raw_graph=raw_graph,
52 | config=config,
53 | )
54 | else:
55 | tokenizer = AMRBartTokenizer.from_pretrained(
56 | tokenizer_name,
57 | collapse_name_ops=collapse_name_ops,
58 | use_pointer_tokens=use_pointer_tokens,
59 | config=config,
60 | )
61 |
62 | if from_pretrained:
63 | model = AMRBartForConditionalGeneration.from_pretrained(name, config=config)
64 | else:
65 | model = AMRBartForConditionalGeneration(config)
66 |
67 | model.resize_token_embeddings(len(tokenizer.encoder))
68 |
69 | if additional_tokens_smart_init:
70 | modified = 0
71 | for tok, idx in tokenizer.encoder.items():
72 | tok = tok.lstrip(tokenizer.INIT)
73 |
74 | if idx < tokenizer.old_enc_size:
75 | continue
76 |
77 |             elif tok.startswith('<pointer:'):
78 | tok_split = ['pointer', str(tok.split(':')[1].strip('>'))]
79 |
80 | elif tok.startswith('<'):
81 | continue
82 |
83 | elif tok.startswith(':'):
84 |
85 | if skip_relations:
86 | continue
87 |
88 | elif tok.startswith(':op'):
89 | tok_split = ['relation', 'operator', str(int(tok[3:]))]
90 |
91 | elif tok.startswith(':snt'):
92 | tok_split = ['relation', 'sentence', str(int(tok[4:]))]
93 |
94 | elif tok.startswith(':ARG'):
95 | tok_split = ['relation', 'argument', str(int(tok[4:]))]
96 |
97 | else:
98 | tok_split = ['relation'] + tok.lstrip(':').split('-')
99 |
100 | else:
101 | tok_split = tok.split('-')
102 |
103 | tok_split_ = tok_split
104 | tok_split = []
105 | for s in tok_split_:
106 |             s_ = tokenizer.INIT + s
107 | if s_ in tokenizer.encoder:
108 | tok_split.append(s_)
109 | else:
110 | tok_split.extend(tokenizer._tok_bpe(s))
111 |
112 | vecs = []
113 | for s in tok_split:
114 | idx_split = tokenizer.encoder.get(s, -1)
115 | if idx_split > -1:
116 | vec_split = model.model.shared.weight.data[idx_split].clone()
117 | vecs.append(vec_split)
118 |
119 | if vecs:
120 | vec = torch.stack(vecs, 0).mean(0)
121 | noise = torch.empty_like(vec)
122 | noise.uniform_(-0.1, +0.1)
123 | model.model.shared.weight.data[idx] = vec + noise
124 | modified += 1
125 |
126 | if init_reverse:
127 | model.init_reverse_model()
128 |
129 | if checkpoint is not None:
130 | model.load_state_dict(torch.load(checkpoint, map_location='cpu')['model'])
131 |
132 | return model, tokenizer
133 |
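# Typical call (paths and flags here are illustrative, not mandated by this
# module):
#   model, tokenizer = instantiate_model_and_tokenizer(
#       'facebook/bart-large',
#       checkpoint='checkpoint.pt',
#       penman_linearization=True,
#       use_pointer_tokens=True,
#   )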
134 |
135 | def instantiate_loader(
136 | glob_pattn,
137 | tokenizer,
138 | batch_size=500,
139 | evaluation=True,
140 | out=None,
141 | use_recategorization=False,
142 | remove_longer_than=None,
143 | remove_wiki=False,
144 | dereify=True,
145 | ):
146 | paths = []
147 |     if isinstance(glob_pattn, (str, Path)):
148 | glob_pattn = [glob_pattn]
149 | for gpattn in glob_pattn:
150 | paths += [Path(p) for p in glob(gpattn)]
151 | if evaluation:
152 | assert out is not None
153 | Path(out).write_text(
154 | '\n\n'.join([p.read_text() for p in paths]))
155 | dataset = AMRDataset(
156 | paths,
157 | tokenizer,
158 | use_recategorization=use_recategorization,
159 | remove_longer_than=remove_longer_than,
160 | remove_wiki=remove_wiki,
161 | dereify=dereify,
162 | )
163 | loader = AMRDatasetTokenBatcherAndLoader(
164 | dataset,
165 | batch_size=batch_size,
166 | shuffle=not evaluation,
167 | )
168 | return loader
169 |
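# Typical call (glob pattern and output path are illustrative):
#   loader = instantiate_loader(
#       'data/amrs/*.txt', tokenizer, batch_size=500,
#       evaluation=True, out='gold.txt')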
--------------------------------------------------------------------------------