├── LICENSE
├── README.md
├── bin
│   ├── blinkify.py
│   ├── eval_bleu.py
│   ├── inspect_.py
│   ├── patch_legacy_checkpoint.py
│   ├── predict_amrs.py
│   ├── predict_amrs_from_plaintext.py
│   ├── predict_sentences.py
│   └── train.py
├── configs
│   └── config.yaml
├── data
│   └── vocab
│       ├── additions.txt
│       ├── predicates.txt
│       └── recategorizations.txt
├── docs
│   ├── appendix.pdf
│   ├── camera-ready.pdf
│   └── preprint.pdf
├── requirements.txt
├── sample.txt
├── setup.py
└── spring_amr
    ├── IO.py
    ├── __init__.py
    ├── dataset.py
    ├── entities.py
    ├── evaluation.py
    ├── linearization.py
    ├── modeling_bart.py
    ├── optim.py
    ├── penman.py
    ├── postprocessing.py
    ├── tokenization_bart.py
    └── utils.py

/LICENSE:
--------------------------------------------------------------------------------
1 | ======================================================================= 2 | 3 | Attribution-NonCommercial-ShareAlike 4.0 International 4 | 5 | ======================================================================= 6 | 7 | Creative Commons Corporation ("Creative Commons") is not a law firm and 8 | does not provide legal services or legal advice. Distribution of 9 | Creative Commons public licenses does not create a lawyer-client or 10 | other relationship. Creative Commons makes its licenses and related 11 | information available on an "as-is" basis. Creative Commons gives no 12 | warranties regarding its licenses, any material licensed under their 13 | terms and conditions, or any related information. Creative Commons 14 | disclaims all liability for damages resulting from their use to the 15 | fullest extent possible. 16 | 17 | Using Creative Commons Public Licenses 18 | 19 | Creative Commons public licenses provide a standard set of terms and 20 | conditions that creators and other rights holders may use to share 21 | original works of authorship and other material subject to copyright 22 | and certain other rights specified in the public license below. The 23 | following considerations are for informational purposes only, are not 24 | exhaustive, and do not form part of our licenses. 25 | 26 | Considerations for licensors: Our public licenses are 27 | intended for use by those authorized to give the public 28 | permission to use material in ways otherwise restricted by 29 | copyright and certain other rights. Our licenses are 30 | irrevocable. Licensors should read and understand the terms 31 | and conditions of the license they choose before applying it. 32 | Licensors should also secure all rights necessary before 33 | applying our licenses so that the public can reuse the 34 | material as expected. Licensors should clearly mark any 35 | material not subject to the license. This includes other CC- 36 | licensed material, or material used under an exception or 37 | limitation to copyright. More considerations for licensors: 38 | wiki.creativecommons.org/Considerations_for_licensors 39 | 40 | Considerations for the public: By using one of our public 41 | licenses, a licensor grants the public permission to use the 42 | licensed material under specified terms and conditions. If 43 | the licensor's permission is not necessary for any reason--for 44 | example, because of any applicable exception or limitation to 45 | copyright--then that use is not regulated by the license. Our 46 | licenses grant only permissions under copyright and certain 47 | other rights that a licensor has authority to grant.
Use of 48 | the licensed material may still be restricted for other 49 | reasons, including because others have copyright or other 50 | rights in the material. A licensor may make special requests, 51 | such as asking that all changes be marked or described. 52 | Although not required by our licenses, you are encouraged to 53 | respect those requests where reasonable. More considerations 54 | for the public: 55 | wiki.creativecommons.org/Considerations_for_licensees 56 | 57 | ======================================================================= 58 | 59 | Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International 60 | Public License 61 | 62 | By exercising the Licensed Rights (defined below), You accept and agree 63 | to be bound by the terms and conditions of this Creative Commons 64 | Attribution-NonCommercial-ShareAlike 4.0 International Public License 65 | ("Public License"). To the extent this Public License may be 66 | interpreted as a contract, You are granted the Licensed Rights in 67 | consideration of Your acceptance of these terms and conditions, and the 68 | Licensor grants You such rights in consideration of benefits the 69 | Licensor receives from making the Licensed Material available under 70 | these terms and conditions. 71 | 72 | 73 | Section 1 -- Definitions. 74 | 75 | a. Adapted Material means material subject to Copyright and Similar 76 | Rights that is derived from or based upon the Licensed Material 77 | and in which the Licensed Material is translated, altered, 78 | arranged, transformed, or otherwise modified in a manner requiring 79 | permission under the Copyright and Similar Rights held by the 80 | Licensor. For purposes of this Public License, where the Licensed 81 | Material is a musical work, performance, or sound recording, 82 | Adapted Material is always produced where the Licensed Material is 83 | synched in timed relation with a moving image. 84 | 85 | b. Adapter's License means the license You apply to Your Copyright 86 | and Similar Rights in Your contributions to Adapted Material in 87 | accordance with the terms and conditions of this Public License. 88 | 89 | c. BY-NC-SA Compatible License means a license listed at 90 | creativecommons.org/compatiblelicenses, approved by Creative 91 | Commons as essentially the equivalent of this Public License. 92 | 93 | d. Copyright and Similar Rights means copyright and/or similar rights 94 | closely related to copyright including, without limitation, 95 | performance, broadcast, sound recording, and Sui Generis Database 96 | Rights, without regard to how the rights are labeled or 97 | categorized. For purposes of this Public License, the rights 98 | specified in Section 2(b)(1)-(2) are not Copyright and Similar 99 | Rights. 100 | 101 | e. Effective Technological Measures means those measures that, in the 102 | absence of proper authority, may not be circumvented under laws 103 | fulfilling obligations under Article 11 of the WIPO Copyright 104 | Treaty adopted on December 20, 1996, and/or similar international 105 | agreements. 106 | 107 | f. Exceptions and Limitations means fair use, fair dealing, and/or 108 | any other exception or limitation to Copyright and Similar Rights 109 | that applies to Your use of the Licensed Material. 110 | 111 | g. License Elements means the license attributes listed in the name 112 | of a Creative Commons Public License. The License Elements of this 113 | Public License are Attribution, NonCommercial, and ShareAlike. 114 | 115 | h. 
Licensed Material means the artistic or literary work, database, 116 | or other material to which the Licensor applied this Public 117 | License. 118 | 119 | i. Licensed Rights means the rights granted to You subject to the 120 | terms and conditions of this Public License, which are limited to 121 | all Copyright and Similar Rights that apply to Your use of the 122 | Licensed Material and that the Licensor has authority to license. 123 | 124 | j. Licensor means the individual(s) or entity(ies) granting rights 125 | under this Public License. 126 | 127 | k. NonCommercial means not primarily intended for or directed towards 128 | commercial advantage or monetary compensation. For purposes of 129 | this Public License, the exchange of the Licensed Material for 130 | other material subject to Copyright and Similar Rights by digital 131 | file-sharing or similar means is NonCommercial provided there is 132 | no payment of monetary compensation in connection with the 133 | exchange. 134 | 135 | l. Share means to provide material to the public by any means or 136 | process that requires permission under the Licensed Rights, such 137 | as reproduction, public display, public performance, distribution, 138 | dissemination, communication, or importation, and to make material 139 | available to the public including in ways that members of the 140 | public may access the material from a place and at a time 141 | individually chosen by them. 142 | 143 | m. Sui Generis Database Rights means rights other than copyright 144 | resulting from Directive 96/9/EC of the European Parliament and of 145 | the Council of 11 March 1996 on the legal protection of databases, 146 | as amended and/or succeeded, as well as other essentially 147 | equivalent rights anywhere in the world. 148 | 149 | n. You means the individual or entity exercising the Licensed Rights 150 | under this Public License. Your has a corresponding meaning. 151 | 152 | 153 | Section 2 -- Scope. 154 | 155 | a. License grant. 156 | 157 | 1. Subject to the terms and conditions of this Public License, 158 | the Licensor hereby grants You a worldwide, royalty-free, 159 | non-sublicensable, non-exclusive, irrevocable license to 160 | exercise the Licensed Rights in the Licensed Material to: 161 | 162 | a. reproduce and Share the Licensed Material, in whole or 163 | in part, for NonCommercial purposes only; and 164 | 165 | b. produce, reproduce, and Share Adapted Material for 166 | NonCommercial purposes only. 167 | 168 | 2. Exceptions and Limitations. For the avoidance of doubt, where 169 | Exceptions and Limitations apply to Your use, this Public 170 | License does not apply, and You do not need to comply with 171 | its terms and conditions. 172 | 173 | 3. Term. The term of this Public License is specified in Section 174 | 6(a). 175 | 176 | 4. Media and formats; technical modifications allowed. The 177 | Licensor authorizes You to exercise the Licensed Rights in 178 | all media and formats whether now known or hereafter created, 179 | and to make technical modifications necessary to do so. The 180 | Licensor waives and/or agrees not to assert any right or 181 | authority to forbid You from making technical modifications 182 | necessary to exercise the Licensed Rights, including 183 | technical modifications necessary to circumvent Effective 184 | Technological Measures. For purposes of this Public License, 185 | simply making modifications authorized by this Section 2(a) 186 | (4) never produces Adapted Material. 187 | 188 | 5. 
Downstream recipients. 189 | 190 | a. Offer from the Licensor -- Licensed Material. Every 191 | recipient of the Licensed Material automatically 192 | receives an offer from the Licensor to exercise the 193 | Licensed Rights under the terms and conditions of this 194 | Public License. 195 | 196 | b. Additional offer from the Licensor -- Adapted Material. 197 | Every recipient of Adapted Material from You 198 | automatically receives an offer from the Licensor to 199 | exercise the Licensed Rights in the Adapted Material 200 | under the conditions of the Adapter's License You apply. 201 | 202 | c. No downstream restrictions. You may not offer or impose 203 | any additional or different terms or conditions on, or 204 | apply any Effective Technological Measures to, the 205 | Licensed Material if doing so restricts exercise of the 206 | Licensed Rights by any recipient of the Licensed 207 | Material. 208 | 209 | 6. No endorsement. Nothing in this Public License constitutes or 210 | may be construed as permission to assert or imply that You 211 | are, or that Your use of the Licensed Material is, connected 212 | with, or sponsored, endorsed, or granted official status by, 213 | the Licensor or others designated to receive attribution as 214 | provided in Section 3(a)(1)(A)(i). 215 | 216 | b. Other rights. 217 | 218 | 1. Moral rights, such as the right of integrity, are not 219 | licensed under this Public License, nor are publicity, 220 | privacy, and/or other similar personality rights; however, to 221 | the extent possible, the Licensor waives and/or agrees not to 222 | assert any such rights held by the Licensor to the limited 223 | extent necessary to allow You to exercise the Licensed 224 | Rights, but not otherwise. 225 | 226 | 2. Patent and trademark rights are not licensed under this 227 | Public License. 228 | 229 | 3. To the extent possible, the Licensor waives any right to 230 | collect royalties from You for the exercise of the Licensed 231 | Rights, whether directly or through a collecting society 232 | under any voluntary or waivable statutory or compulsory 233 | licensing scheme. In all other cases the Licensor expressly 234 | reserves any right to collect such royalties, including when 235 | the Licensed Material is used other than for NonCommercial 236 | purposes. 237 | 238 | 239 | Section 3 -- License Conditions. 240 | 241 | Your exercise of the Licensed Rights is expressly made subject to the 242 | following conditions. 243 | 244 | a. Attribution. 245 | 246 | 1. If You Share the Licensed Material (including in modified 247 | form), You must: 248 | 249 | a. retain the following if it is supplied by the Licensor 250 | with the Licensed Material: 251 | 252 | i. identification of the creator(s) of the Licensed 253 | Material and any others designated to receive 254 | attribution, in any reasonable manner requested by 255 | the Licensor (including by pseudonym if 256 | designated); 257 | 258 | ii. a copyright notice; 259 | 260 | iii. a notice that refers to this Public License; 261 | 262 | iv. a notice that refers to the disclaimer of 263 | warranties; 264 | 265 | v. a URI or hyperlink to the Licensed Material to the 266 | extent reasonably practicable; 267 | 268 | b. indicate if You modified the Licensed Material and 269 | retain an indication of any previous modifications; and 270 | 271 | c. indicate the Licensed Material is licensed under this 272 | Public License, and include the text of, or the URI or 273 | hyperlink to, this Public License. 274 | 275 | 2. 
You may satisfy the conditions in Section 3(a)(1) in any 276 | reasonable manner based on the medium, means, and context in 277 | which You Share the Licensed Material. For example, it may be 278 | reasonable to satisfy the conditions by providing a URI or 279 | hyperlink to a resource that includes the required 280 | information. 281 | 3. If requested by the Licensor, You must remove any of the 282 | information required by Section 3(a)(1)(A) to the extent 283 | reasonably practicable. 284 | 285 | b. ShareAlike. 286 | 287 | In addition to the conditions in Section 3(a), if You Share 288 | Adapted Material You produce, the following conditions also apply. 289 | 290 | 1. The Adapter's License You apply must be a Creative Commons 291 | license with the same License Elements, this version or 292 | later, or a BY-NC-SA Compatible License. 293 | 294 | 2. You must include the text of, or the URI or hyperlink to, the 295 | Adapter's License You apply. You may satisfy this condition 296 | in any reasonable manner based on the medium, means, and 297 | context in which You Share Adapted Material. 298 | 299 | 3. You may not offer or impose any additional or different terms 300 | or conditions on, or apply any Effective Technological 301 | Measures to, Adapted Material that restrict exercise of the 302 | rights granted under the Adapter's License You apply. 303 | 304 | 305 | Section 4 -- Sui Generis Database Rights. 306 | 307 | Where the Licensed Rights include Sui Generis Database Rights that 308 | apply to Your use of the Licensed Material: 309 | 310 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right 311 | to extract, reuse, reproduce, and Share all or a substantial 312 | portion of the contents of the database for NonCommercial purposes 313 | only; 314 | 315 | b. if You include all or a substantial portion of the database 316 | contents in a database in which You have Sui Generis Database 317 | Rights, then the database in which You have Sui Generis Database 318 | Rights (but not its individual contents) is Adapted Material, 319 | including for purposes of Section 3(b); and 320 | 321 | c. You must comply with the conditions in Section 3(a) if You Share 322 | all or a substantial portion of the contents of the database. 323 | 324 | For the avoidance of doubt, this Section 4 supplements and does not 325 | replace Your obligations under this Public License where the Licensed 326 | Rights include other Copyright and Similar Rights. 327 | 328 | 329 | Section 5 -- Disclaimer of Warranties and Limitation of Liability. 330 | 331 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE 332 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS 333 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF 334 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, 335 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, 336 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR 337 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, 338 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT 339 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT 340 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. 341 | 342 | b. 
TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE 343 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, 344 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, 345 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, 346 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR 347 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN 348 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR 349 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR 350 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. 351 | 352 | c. The disclaimer of warranties and limitation of liability provided 353 | above shall be interpreted in a manner that, to the extent 354 | possible, most closely approximates an absolute disclaimer and 355 | waiver of all liability. 356 | 357 | 358 | Section 6 -- Term and Termination. 359 | 360 | a. This Public License applies for the term of the Copyright and 361 | Similar Rights licensed here. However, if You fail to comply with 362 | this Public License, then Your rights under this Public License 363 | terminate automatically. 364 | 365 | b. Where Your right to use the Licensed Material has terminated under 366 | Section 6(a), it reinstates: 367 | 368 | 1. automatically as of the date the violation is cured, provided 369 | it is cured within 30 days of Your discovery of the 370 | violation; or 371 | 372 | 2. upon express reinstatement by the Licensor. 373 | 374 | For the avoidance of doubt, this Section 6(b) does not affect any 375 | right the Licensor may have to seek remedies for Your violations 376 | of this Public License. 377 | 378 | c. For the avoidance of doubt, the Licensor may also offer the 379 | Licensed Material under separate terms or conditions or stop 380 | distributing the Licensed Material at any time; however, doing so 381 | will not terminate this Public License. 382 | 383 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public 384 | License. 385 | 386 | 387 | Section 7 -- Other Terms and Conditions. 388 | 389 | a. The Licensor shall not be bound by any additional or different 390 | terms or conditions communicated by You unless expressly agreed. 391 | 392 | b. Any arrangements, understandings, or agreements regarding the 393 | Licensed Material not stated herein are separate from and 394 | independent of the terms and conditions of this Public License. 395 | 396 | 397 | Section 8 -- Interpretation. 398 | 399 | a. For the avoidance of doubt, this Public License does not, and 400 | shall not be interpreted to, reduce, limit, restrict, or impose 401 | conditions on any use of the Licensed Material that could lawfully 402 | be made without permission under this Public License. 403 | 404 | b. To the extent possible, if any provision of this Public License is 405 | deemed unenforceable, it shall be automatically reformed to the 406 | minimum extent necessary to make it enforceable. If the provision 407 | cannot be reformed, it shall be severed from this Public License 408 | without affecting the enforceability of the remaining terms and 409 | conditions. 410 | 411 | c. No term or condition of this Public License will be waived and no 412 | failure to comply consented to unless expressly agreed to by the 413 | Licensor. 414 | 415 | d. 
Nothing in this Public License constitutes or may be interpreted 416 | as a limitation upon, or waiver of, any privileges and immunities 417 | that apply to the Licensor or You, including from the legal 418 | processes of any jurisdiction or authority. 419 | 420 | ======================================================================= 421 | 422 | Creative Commons is not a party to its public 423 | licenses. Notwithstanding, Creative Commons may elect to apply one of 424 | its public licenses to material it publishes and in those instances 425 | will be considered the “Licensor.” The text of the Creative Commons 426 | public licenses is dedicated to the public domain under the CC0 Public 427 | Domain Dedication. Except for the limited purpose of indicating that 428 | material is shared under a Creative Commons public license or as 429 | otherwise permitted by the Creative Commons policies published at 430 | creativecommons.org/policies, Creative Commons does not authorize the 431 | use of the trademark "Creative Commons" or any other trademark or logo 432 | of Creative Commons without its prior written consent including, 433 | without limitation, in connection with any unauthorized modifications 434 | to any of its public licenses or any other arrangements, 435 | understandings, or agreements concerning use of licensed material. For 436 | the avoidance of doubt, this paragraph does not form part of the 437 | public licenses. 438 | 439 | Creative Commons may be contacted at creativecommons.org. 440 |
--------------------------------------------------------------------------------

/README.md:
--------------------------------------------------------------------------------

# SPRING

[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/one-spring-to-rule-them-both-symmetric-amr/amr-parsing-on-ldc2017t10)](https://paperswithcode.com/sota/amr-parsing-on-ldc2017t10?p=one-spring-to-rule-them-both-symmetric-amr)

[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/one-spring-to-rule-them-both-symmetric-amr/amr-parsing-on-ldc2020t02)](https://paperswithcode.com/sota/amr-parsing-on-ldc2020t02?p=one-spring-to-rule-them-both-symmetric-amr)

[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/one-spring-to-rule-them-both-symmetric-amr/amr-to-text-generation-on-ldc2017t10)](https://paperswithcode.com/sota/amr-to-text-generation-on-ldc2017t10?p=one-spring-to-rule-them-both-symmetric-amr)

[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/one-spring-to-rule-them-both-symmetric-amr/amr-to-text-generation-on-ldc2020t02)](https://paperswithcode.com/sota/amr-to-text-generation-on-ldc2020t02?p=one-spring-to-rule-them-both-symmetric-amr)

This is the repo for [SPRING (*Symmetric ParsIng aNd Generation*)](https://ojs.aaai.org/index.php/AAAI/article/view/17489), a novel approach to semantic parsing and generation, presented at AAAI 2021.

With SPRING you can perform both state-of-the-art Text-to-AMR parsing and AMR-to-Text generation without many cumbersome external components.
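
For a quick test of the parser with one of the pretrained checkpoints below, you can use the bundled plaintext script (a minimal sketch; the checkpoint and input paths are illustrative):

```shell script
# sentences.txt: one sentence per line;
# AMR3.parsing.pt: a checkpoint extracted from, e.g., AMR3.parsing-1.0.tar.bz2
python bin/predict_amrs_from_plaintext.py \
    --texts sentences.txt \
    --checkpoint AMR3.parsing.pt \
    --beam-size 5 \
    --device cuda \
    --penman-linearization --use-pointer-tokens
```
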
If you use the code, please reference this work in your paper:

```
@inproceedings{bevilacqua-etal-2021-one,
    title = {One {SPRING} to Rule Them Both: {S}ymmetric {AMR} Semantic Parsing and Generation without a Complex Pipeline},
    author = {Bevilacqua, Michele and Blloshmi, Rexhina and Navigli, Roberto},
    booktitle = {Proceedings of AAAI},
    year = {2021}
}
```

## Pretrained Checkpoints

Here we release our best SPRING models, which are based on the DFS linearization.

### Text-to-AMR Parsing
- Model trained on the AMR 2.0 training set: [AMR2.parsing-1.0.tar.bz2](http://nlp.uniroma1.it/AMR/AMR2.parsing-1.0.tar.bz2)

- Model trained on the AMR 3.0 training set: [AMR3.parsing-1.0.tar.bz2](http://nlp.uniroma1.it/AMR/AMR3.parsing-1.0.tar.bz2)

### AMR-to-Text Generation
- Model trained on the AMR 2.0 training set: [AMR2.generation-1.0.tar.bz2](http://nlp.uniroma1.it/AMR/AMR2.generation-1.0.tar.bz2)

- Model trained on the AMR 3.0 training set: [AMR3.generation-1.0.tar.bz2](http://nlp.uniroma1.it/AMR/AMR3.generation-1.0.tar.bz2)


If you need the checkpoints of other experiments in the paper, please send us an email.

## Installation
```shell script
cd spring
pip install -r requirements.txt
pip install -e .
```

The code only works with `transformers` < 3.0 because of a breaking change in positional embeddings.
The code works fine with `torch` 1.5. We recommend using a fresh `conda` environment.

## Train
Modify `config.yaml` in `configs`. Instructions are in the comments within the file. Also see the [appendix](docs/appendix.pdf).

### Text-to-AMR
```shell script
python bin/train.py --config configs/config.yaml --direction amr
```
Results in `runs/`

### AMR-to-Text
```shell script
python bin/train.py --config configs/config.yaml --direction text
```
Results in `runs/`

## Evaluate
### Text-to-AMR
```shell script
python bin/predict_amrs.py \
    --datasets /data/amrs/split/test/*.txt \
    --gold-path data/tmp/amr2.0/gold.amr.txt \
    --pred-path data/tmp/amr2.0/pred.amr.txt \
    --checkpoint runs/<checkpoint>.pt \
    --beam-size 5 \
    --batch-size 500 \
    --device cuda \
    --penman-linearization --use-pointer-tokens
```
`gold.amr.txt` and `pred.amr.txt` will contain, respectively, the concatenated gold and the predictions.

To reproduce our paper's results, you will also need to run the [BLINK](https://github.com/facebookresearch/BLINK)
entity linking system on the prediction file (`data/tmp/amr2.0/pred.amr.txt` in the previous code snippet).
To do so, you will need to install BLINK, and download their models:
```shell script
git clone https://github.com/facebookresearch/BLINK.git
cd BLINK
pip install -r requirements.txt
sh download_blink_models.sh
cd models
wget http://dl.fbaipublicfiles.com/BLINK//faiss_flat_index.pkl
cd ../..
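# After these steps, BLINK/models should contain (among others) the files that
# bin/blinkify.py expects in its config: biencoder_wiki_large.bin,
# biencoder_wiki_large.json, crossencoder_wiki_large.bin,
# crossencoder_wiki_large.json, entity.jsonl, all_entities_large.t7
# and faiss_flat_index.pkl.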
```
Then, you will be able to launch the `blinkify.py` script:
```shell
python bin/blinkify.py \
    --datasets data/tmp/amr2.0/pred.amr.txt \
    --out data/tmp/amr2.0/pred.amr.blinkified.txt \
    --device cuda \
    --blink-models-dir BLINK/models/
```
Note that `--blink-models-dir` should end with a trailing slash, since `blinkify.py` joins it to the model file names by plain string concatenation.

To have comparable Smatch scores you will also need to use the scripts available at https://github.com/mdtux89/amr-evaluation, which provide results that are around 0.3 Smatch points lower than those returned by `bin/predict_amrs.py`.
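
For instance, assuming the amr-evaluation repo is cloned alongside this one (the `evaluation.sh` entry point and its `<parsed data> <gold data>` argument order come from that repo's README; double-check them there):

```shell script
cd amr-evaluation
./evaluation.sh ../spring/data/tmp/amr2.0/pred.amr.blinkified.txt \
                ../spring/data/tmp/amr2.0/gold.amr.txt
```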

### AMR-to-Text
```shell script
python bin/predict_sentences.py \
    --datasets /data/amrs/split/test/*.txt \
    --gold-path data/tmp/amr2.0/gold.text.txt \
    --pred-path data/tmp/amr2.0/pred.text.txt \
    --checkpoint runs/<checkpoint>.pt \
    --beam-size 5 \
    --batch-size 500 \
    --device cuda \
    --penman-linearization --use-pointer-tokens
```
`gold.text.txt` and `pred.text.txt` will contain, respectively, the concatenated gold and the predictions.
For BLEU, chrF++, and METEOR scores to be comparable, you will need to tokenize both gold and predictions using the [JAMR tokenizer](https://github.com/redpony/cdec/blob/master/corpus/tokenize-anything.sh).
To compute BLEU and chrF++, please use `bin/eval_bleu.py`. For METEOR, use https://www.cs.cmu.edu/~alavie/METEOR/ .
For BLEURT, do not tokenize; run the evaluation with https://github.com/google-research/bleurt. Also see the [appendix](docs/appendix.pdf).
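
A possible scoring pipeline for BLEU and chrF++ (a sketch; it assumes `tokenize-anything.sh` has been downloaded locally and, as in cdec, works as a stdin/stdout filter):

```shell script
sh tokenize-anything.sh < data/tmp/amr2.0/gold.text.txt > gold.tok.txt
sh tokenize-anything.sh < data/tmp/amr2.0/pred.text.txt > pred.tok.txt
# flags as defined in bin/eval_bleu.py
python bin/eval_bleu.py \
    --in-tokens pred.tok.txt \
    --in-reference-tokens gold.tok.txt
```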

## Linearizations
The previously shown commands assume the use of the DFS-based linearization. To use BFS or PENMAN, uncomment the relevant lines in `configs/config.yaml` (for training). As for the evaluation scripts, substitute the `--penman-linearization --use-pointer-tokens` line with `--use-pointer-tokens` for BFS or with `--penman-linearization` for PENMAN, as in the example below.
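
For example, to evaluate a parsing checkpoint trained with the BFS linearization:

```shell script
python bin/predict_amrs.py \
    --datasets /data/amrs/split/test/*.txt \
    --gold-path data/tmp/amr2.0/gold.amr.txt \
    --pred-path data/tmp/amr2.0/pred.amr.txt \
    --checkpoint runs/<checkpoint>.pt \
    --beam-size 5 \
    --batch-size 500 \
    --device cuda \
    --use-pointer-tokens
```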

## License
This project is released under the CC-BY-NC-SA 4.0 license (see `LICENSE`). If you use SPRING, please put a link to this repo.

## Acknowledgements
The authors gratefully acknowledge the support of the [ERC Consolidator Grant MOUSSE](http://mousse-project.org) No. 726487 and the [ELEXIS project](https://elex.is/) No. 731015 under the European Union’s Horizon 2020 research and innovation programme.

This work was supported in part by the MIUR under the grant "Dipartimenti di eccellenza 2018-2022" of the Department of Computer Science of the Sapienza University of Rome.
--------------------------------------------------------------------------------

/bin/blinkify.py:
--------------------------------------------------------------------------------
import blink.main_dense as main_dense
from logging import getLogger
from penman import Triple, Graph
from spring_amr.evaluation import write_predictions
from spring_amr.tokenization_bart import AMRBartTokenizer
import json
from pathlib import Path
from spring_amr.IO import read_raw_amr_data
from spring_amr.entities import read_entities

if __name__ == '__main__':

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--datasets', nargs='+', required=True)
    parser.add_argument('--blink-models-dir', type=str, required=True)
    parser.add_argument('--out', type=str, required=True)
    parser.add_argument('--device', type=str, default='cuda',
                        help="Device. 'cpu', 'cuda', 'cuda:<n>'.")
    parser.add_argument('--all', action='store_true')
    parser.add_argument('--fast', action='store_true')
    args = parser.parse_args()

    graphs = read_raw_amr_data(args.datasets)
    sentences = [g.metadata['snt'] for g in graphs]
    for_blink = []
    sample_id = 0

    # Collect every :name mention that can be matched verbatim in the source
    # sentence, so that BLINK can link it in context.
    for sent, (i, with_wikis, name_to_entity, name_to_ops) in zip(sentences, read_entities(sentences, graphs, just_tagged=not args.all)):
        for name, parent in name_to_entity.items():
            nt, wiki = with_wikis[parent]
            ops_triples = name_to_ops[name]
            ops_triples = sorted(ops_triples, key=lambda t: t[1])
            ops_triples = [t[2].strip('"') for t in ops_triples]
            string = ' '.join(ops_triples)
            found = string.lower() in sent.lower()
            if found:
                left = sent.lower().find(string.lower())
                right = left + len(string)

                sample = {
                    "id": sample_id,
                    "label": "unknown",
                    "label_id": -1,
                    "context_left": sent[:left].strip().lower(),
                    "mention": string.lower(),
                    "context_right": sent[right:].strip().lower(),
                    "graph_n": i,
                    "triple_n": nt,
                }
                sample_id += 1
                for_blink.append(sample)

    main_dense.logger = logger = getLogger('BLINK')
    # The path where you stored the BLINK models; note that it must end with a
    # slash, since it is joined to the file names by string concatenation below.
    models_path = args.blink_models_dir

    config = {
        "test_entities": None,
        "test_mentions": None,
        "interactive": False,
        "biencoder_model": models_path+"biencoder_wiki_large.bin",
        "biencoder_config": models_path+"biencoder_wiki_large.json",
        "entity_catalogue": models_path+"entity.jsonl",
        "entity_encoding": models_path+"all_entities_large.t7",
        "crossencoder_model": models_path+"crossencoder_wiki_large.bin",
        "crossencoder_config": models_path+"crossencoder_wiki_large.json",
        "top_k": 10,
        "show_url": False,
        "fast": args.fast,  # set this to be true if speed is a concern
        "output_path": models_path+"logs/",  # logging directory
        "faiss_index": None,  # or "flat"
        "index_path": models_path+"faiss_flat_index.pkl",
    }

    args_blink = argparse.Namespace(**config)
    models = main_dense.load_models(args_blink, logger=logger)
    _, _, _, _, _, predictions, scores, = main_dense.run(args_blink, logger, *models, test_data=for_blink, device=args.device)

    # Overwrite the :wiki triple of each linked mention with the top BLINK
    # prediction (or '-' if nothing usable was returned).
    for s, pp in zip(for_blink, predictions):
        pp = [p for p in pp if not p.startswith('List of')]
        p = f'"{pp[0]}"' if pp else '-'
        p = p.replace(' ', '_')
        graph_n = s['graph_n']
        triple_n = s['triple_n']
        triples = [g for g in graphs[graph_n].triples]
        n, rel, w = triples[triple_n]
        triples[triple_n] = Triple(n, rel, p)
        g = Graph(triples)
        g.metadata = graphs[graph_n].metadata
        graphs[graph_n] = g


    write_predictions(args.out, AMRBartTokenizer, graphs)
--------------------------------------------------------------------------------

/bin/eval_bleu.py:
--------------------------------------------------------------------------------
import sys
import argparse
from typing import Iterable, Optional
import sacrebleu
import re


def argument_parser():

    parser = argparse.ArgumentParser(description='Preprocess AMR data')
    # Multiple input parameters
    parser.add_argument(
        "--in-tokens",
        help="input tokens",
        required=True,
        type=str
    )
    parser.add_argument(
        "--in-reference-tokens",
        help="reference tokens to compute metric",
        type=str
    )
    args = parser.parse_args()

    return args


def tokenize_sentence(text, debug=False):
    text = re.sub(r"('ll|n't|'m|'s|'d|'re)", r" \1", text)
    text = re.sub(r"(\s+)", r" ", text)
    return text


def raw_corpus_bleu(hypothesis: Iterable[str], reference: Iterable[str],
                    offset: Optional[float] = 0.01) -> float:
    bleu = sacrebleu.corpus_bleu(hypothesis, reference, smooth_value=offset,
                                 force=True, use_effective_order=False,
                                 lowercase=True)
    return bleu.score


def raw_corpus_chrf(hypotheses: Iterable[str],
                    references: Iterable[str]) -> float:
    return sacrebleu.corpus_chrf(hypotheses, references,
                                 order=sacrebleu.CHRF_ORDER,
                                 beta=sacrebleu.CHRF_BETA,
                                 remove_whitespace=True)

def read_tokens(in_tokens_file):
    with open(in_tokens_file) as fid:
        lines = fid.readlines()
    return lines


if __name__ == '__main__':

    # Argument handling
    args = argument_parser()

    # read files
    ref = read_tokens(args.in_reference_tokens)
    hyp = read_tokens(args.in_tokens)

    # Lower evaluation
    for i in range(len(ref)):
        ref[i] = ref[i].lower()

    # Lower case output
    for i in range(len(hyp)):
        # The special token below was lost in the original dump; '</s>' (the
        # BART end-of-sequence tag) is an educated guess for what is stripped.
        if '</s>' in hyp[i]:
            hyp[i] = hyp[i].split('</s>')[-1]
        hyp[i] = tokenize_sentence(hyp[i].lower())

    # results

    bleu = raw_corpus_bleu(hyp, [ref])
    print('BLEU {:.2f}'.format(bleu))
    chrFpp = raw_corpus_chrf(hyp, ref).score * 100
    print('chrF++ {:.2f}'.format(chrFpp))
--------------------------------------------------------------------------------

/bin/inspect_.py:
--------------------------------------------------------------------------------
import torch
import penman
from spring_amr.utils import instantiate_model_and_tokenizer

if __name__ == '__main__':

    from argparse import ArgumentParser
    parser = ArgumentParser()
    parser.add_argument('--checkpoint', type=str, required=True)
    parser.add_argument('--beam-size', type=int, default=1)
    parser.add_argument('--device', type=str, default='cpu')
    parser.add_argument('--penman-linearization', action='store_true',
                        help="Predict using PENMAN linearization instead of ours.")
    parser.add_argument('--use-pointer-tokens', action='store_true')
    parser.add_argument('--restore-name-ops', action='store_true')
    args = parser.parse_args()

    device = torch.device(args.device)
    model, tokenizer = instantiate_model_and_tokenizer(
        name='facebook/bart-large',
        checkpoint=args.checkpoint,
        dropout=0., attention_dropout=0.,
        penman_linearization=args.penman_linearization,
        use_pointer_tokens=args.use_pointer_tokens,
    )
    model.eval().to(device)

    while True:
        sentence = [input('Sentence to parse:\n')]
        x, extra = tokenizer.batch_encode_sentences(sentence, device)
        with torch.no_grad():
            out = model.generate(**x, max_length=1024, decoder_start_token_id=0, num_beams=args.beam_size)
        out = out[0].tolist()
        graph, status, (lin, backr) = tokenizer.decode_amr(out, restore_name_ops=args.restore_name_ops)
        print('-' * 5)
        print('Status:', status)
        print('-' * 5)
        print('Graph:')
        print(penman.encode(graph))
        print('-' * 5)
        print('Linearization:')
        print(lin)
        print('\n')
--------------------------------------------------------------------------------

/bin/patch_legacy_checkpoint.py:
--------------------------------------------------------------------------------
if __name__ == '__main__':

    from argparse import ArgumentParser
    import torch

    parser = ArgumentParser()
    parser.add_argument('legacy_checkpoint')
    parser.add_argument('patched_checkpoint')

    args = parser.parse_args()

    to_remove = []

    fixed = False
    w = torch.load(args.legacy_checkpoint, map_location='cpu')
    for name in w['model']:
        if 'backreferences' in name:
            fixed = True
            to_remove.append(name)
            print('Deleting parameters:', name)

    if not fixed:
        print('The checkpoint was fine as it was!')
    else:
        for name in to_remove:
            del w['model'][name]
        torch.save(w, args.patched_checkpoint)
--------------------------------------------------------------------------------

/bin/predict_amrs.py:
--------------------------------------------------------------------------------
from pathlib import Path

import penman
import torch

from spring_amr import ROOT
from spring_amr.evaluation import predict_amrs, compute_smatch
from spring_amr.penman import encode
from spring_amr.utils import instantiate_loader, instantiate_model_and_tokenizer

if __name__ == '__main__':

    from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter

    parser = ArgumentParser(
        description="Script to predict AMR graphs given sentences. LDC format as input.",
        formatter_class=ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument('--datasets', type=str, required=True, nargs='+',
                        help="Required. One or more glob patterns to use to load amr files.")
    parser.add_argument('--checkpoint', type=str, required=True,
                        help="Required. Checkpoint to restore.")
    parser.add_argument('--model', type=str, default='facebook/bart-large',
                        help="Model config to use to load the model class.")
    parser.add_argument('--beam-size', type=int, default=1,
                        help="Beam size.")
    parser.add_argument('--batch-size', type=int, default=1000,
                        help="Batch size (as number of linearized graph tokens per batch).")
    parser.add_argument('--device', type=str, default='cuda',
                        help="Device. 'cpu', 'cuda', 'cuda:<n>'.")
    parser.add_argument('--pred-path', type=Path, default=ROOT / 'data/tmp/inf-pred.txt',
                        help="Where to write predictions.")
    parser.add_argument('--gold-path', type=Path, default=ROOT / 'data/tmp/inf-gold.txt',
                        help="Where to write the gold file.")
    parser.add_argument('--use-recategorization', action='store_true',
                        help="Predict using Zhang recategorization on top of our linearization (requires recategorized sentences in input).")
    parser.add_argument('--penman-linearization', action='store_true',
                        help="Predict using PENMAN linearization instead of ours.")
    parser.add_argument('--use-pointer-tokens', action='store_true')
    parser.add_argument('--raw-graph', action='store_true')
    parser.add_argument('--restore-name-ops', action='store_true')
    parser.add_argument('--return-all', action='store_true')

    args = parser.parse_args()

    device = torch.device(args.device)
    model, tokenizer = instantiate_model_and_tokenizer(
        args.model,
        dropout=0.,
        attention_dropout=0.,
        penman_linearization=args.penman_linearization,
        use_pointer_tokens=args.use_pointer_tokens,
        raw_graph=args.raw_graph,
    )
    model.amr_mode = True
    model.load_state_dict(torch.load(args.checkpoint, map_location='cpu')['model'])
    model.to(device)

    gold_path = args.gold_path
    pred_path = args.pred_path
    loader = instantiate_loader(
        args.datasets,
        tokenizer,
        batch_size=args.batch_size,
        evaluation=True, out=gold_path,
        use_recategorization=args.use_recategorization,
    )
    loader.device = device

    graphs = predict_amrs(
        loader,
        model,
        tokenizer,
        beam_size=args.beam_size,
        restore_name_ops=args.restore_name_ops,
        return_all=args.return_all,
    )
    if args.return_all:
        graphs = [g for gg in graphs for g in gg]

    pieces = [encode(g) for g in graphs]
    pred_path.write_text('\n\n'.join(pieces))

    if not args.return_all:
        score = compute_smatch(gold_path, pred_path)
        print(f'Smatch: {score:.3f}')
--------------------------------------------------------------------------------

/bin/predict_amrs_from_plaintext.py:
--------------------------------------------------------------------------------
from pathlib import Path

import penman
import torch
from tqdm import tqdm

from spring_amr.penman import encode
from spring_amr.utils import instantiate_model_and_tokenizer

def read_file_in_batches(path, batch_size=1000, max_length=100):
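    # Reads one sentence per line, skipping sentences longer than `max_length`
    # words, and yields batches whose cost (longest sentence in the batch times
    # the number of sentences) stays within the `batch_size` token budget.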
    data = []
    idx = 0
    for line in Path(path).read_text().strip().splitlines():
        line = line.strip()
        if not line:
            continue
        n = len(line.split())
        if n > max_length:
            continue
        data.append((idx, line, n))
        idx += 1

    def _iterator(data):

        data = sorted(data, key=lambda x: x[2], reverse=True)

        maxn = 0
        batch = []

        for sample in data:
            idx, line, n = sample
            if n > batch_size:
                if batch:
                    yield batch
                    maxn = 0
                    batch = []
                yield [sample]
            else:
                curr_batch_size = maxn * len(batch)
                cand_batch_size = max(maxn, n) * (len(batch) + 1)

                if 0 < curr_batch_size <= batch_size and cand_batch_size > batch_size:
                    yield batch
                    maxn = 0
                    batch = []
                maxn = max(maxn, n)
                batch.append(sample)

        if batch:
            yield batch

    return _iterator(data), len(data)

if __name__ == '__main__':

    from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter

    parser = ArgumentParser(
        description="Script to predict AMR graphs given sentences. LDC format as input.",
        formatter_class=ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument('--texts', type=str, required=True, nargs='+',
                        help="Required. One or more files containing \\n-separated sentences.")
    parser.add_argument('--checkpoint', type=str, required=True,
                        help="Required. Checkpoint to restore.")
    parser.add_argument('--model', type=str, default='facebook/bart-large',
                        help="Model config to use to load the model class.")
    parser.add_argument('--beam-size', type=int, default=1,
                        help="Beam size.")
    parser.add_argument('--batch-size', type=int, default=1000,
                        help="Batch size (as number of linearized graph tokens per batch).")
    parser.add_argument('--penman-linearization', action='store_true',
                        help="Predict using PENMAN linearization instead of ours.")
    parser.add_argument('--use-pointer-tokens', action='store_true')
    parser.add_argument('--restore-name-ops', action='store_true')
    parser.add_argument('--device', type=str, default='cuda',
                        help="Device. 'cpu', 'cuda', 'cuda:<n>'.")
    parser.add_argument('--only-ok', action='store_true')
    args = parser.parse_args()

    device = torch.device(args.device)
    model, tokenizer = instantiate_model_and_tokenizer(
        args.model,
        dropout=0.,
        attention_dropout=0.,
        penman_linearization=args.penman_linearization,
        use_pointer_tokens=args.use_pointer_tokens,
    )
    model.load_state_dict(torch.load(args.checkpoint, map_location='cpu')['model'])
    model.to(device)
    model.eval()

    for path in tqdm(args.texts, desc='Files:'):

        iterator, nsent = read_file_in_batches(path, args.batch_size)

        with tqdm(desc=path, total=nsent) as bar:
            for batch in iterator:
                if not batch:
                    continue
                ids, sentences, _ = zip(*batch)
                x, _ = tokenizer.batch_encode_sentences(sentences, device=device)
                with torch.no_grad():
                    model.amr_mode = True
                    out = model.generate(**x, max_length=512, decoder_start_token_id=0, num_beams=args.beam_size)

                bgraphs = []
                for idx, sent, tokk in zip(ids, sentences, out):
                    graph, status, (lin, backr) = tokenizer.decode_amr(tokk.tolist(), restore_name_ops=args.restore_name_ops)
                    if args.only_ok and ('OK' not in str(status)):
                        continue
                    graph.metadata['status'] = str(status)
                    graph.metadata['source'] = path
                    graph.metadata['nsent'] = str(idx)
                    graph.metadata['snt'] = sent
                    bgraphs.append((idx, graph))

                for i, g in bgraphs:
                    print(encode(g))
                    print()

                # if bgraphs and args.reverse:
                #     bgraphs = [x[1] for x in bgraphs]
                #     x, _ = tokenizer.batch_encode_graphs(bgraphs, device)
                #     x = torch.cat([x['decoder_input_ids'], x['lm_labels'][:, -1:]], 1)
                #     att = torch.ones_like(x)
                #     att[att == tokenizer.pad_token_id] = 0
                #     x = {
                #         'input_ids': x,
                #         #'attention_mask': att,
                #     }
                #     with torch.no_grad():
                #         model.amr_mode = False
                #         out = model.generate(**x, max_length=1024, decoder_start_token_id=0, num_beams=args.beam_size)
                #
                #     for graph, tokk in zip(bgraphs, out):
                #         tokk = [t for t in tokk.tolist() if t > 2]
                #         graph.metadata['snt-pred'] = tokenizer.decode(tokk).strip()
                bar.update(len(sentences))
--------------------------------------------------------------------------------

/bin/predict_sentences.py:
--------------------------------------------------------------------------------
from pathlib import Path

import penman
import torch

from spring_amr import ROOT
from spring_amr.evaluation import predict_amrs, compute_smatch, predict_sentences, compute_bleu
from spring_amr.penman import encode
from spring_amr.utils import instantiate_loader, instantiate_model_and_tokenizer

if __name__ == '__main__':

    from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter

    parser = ArgumentParser(
        description="Script to predict sentences given AMR graphs. LDC format as input.",
        formatter_class=ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument('--datasets', type=str, required=True, nargs='+',
                        help="Required. One or more glob patterns to use to load amr files.")
    parser.add_argument('--checkpoint', type=str, required=True,
                        help="Required. Checkpoint to restore.")
    parser.add_argument('--model', type=str, default='facebook/bart-large',
                        help="Model config to use to load the model class.")
    parser.add_argument('--beam-size', type=int, default=1,
                        help="Beam size.")
    parser.add_argument('--batch-size', type=int, default=1000,
                        help="Batch size (as number of linearized graph tokens per batch).")
    parser.add_argument('--device', type=str, default='cuda',
                        help="Device. 'cpu', 'cuda', 'cuda:<n>'.")
    parser.add_argument('--pred-path', type=Path, default=ROOT / 'data/tmp/inf-pred-sentences.txt',
                        help="Where to write predictions.")
    parser.add_argument('--gold-path', type=Path, default=ROOT / 'data/tmp/inf-gold-sentences.txt',
                        help="Where to write the gold file.")
    parser.add_argument('--add-to-graph-file', action='store_true')
    parser.add_argument('--use-reverse-decoder', action='store_true')
    parser.add_argument('--deinvert', action='store_true')
    parser.add_argument('--penman-linearization', action='store_true',
                        help="Predict using PENMAN linearization instead of ours.")
    parser.add_argument('--collapse-name-ops', action='store_true')
    parser.add_argument('--use-pointer-tokens', action='store_true')
    parser.add_argument('--raw-graph', action='store_true')
    parser.add_argument('--return-all', action='store_true')
    args = parser.parse_args()

    device = torch.device(args.device)
    model, tokenizer = instantiate_model_and_tokenizer(
        args.model,
        dropout=0.,
        attention_dropout=0.,
        penman_linearization=args.penman_linearization,
        use_pointer_tokens=args.use_pointer_tokens,
        collapse_name_ops=args.collapse_name_ops,
        init_reverse=args.use_reverse_decoder,
        raw_graph=args.raw_graph,
    )
    model.load_state_dict(torch.load(args.checkpoint, map_location='cpu')['model'])
    model.to(device)
    model.rev.amr_mode = False

    loader = instantiate_loader(
        args.datasets,
        tokenizer,
        batch_size=args.batch_size,
        evaluation=True, out='/tmp/a.txt',
        dereify=args.deinvert)
    loader.device = device

    pred_sentences = predict_sentences(loader, model.rev, tokenizer, beam_size=args.beam_size, return_all=args.return_all)
    # Either attach predictions to the graphs as ::snt-pred metadata, or write
    # parallel gold/pred text files.
    if args.add_to_graph_file:
        graphs = loader.dataset.graphs
        for ss, g in zip(pred_sentences, graphs):
            if args.return_all:
                g.metadata['snt-pred'] = '\t\t'.join(ss)
            else:
                g.metadata['snt-pred'] = ss
        args.pred_path.write_text('\n\n'.join([encode(g) for g in graphs]))
    else:
        if args.return_all:
            pred_sentences = [s for ss in pred_sentences for s in ss]
        args.gold_path.write_text('\n'.join(loader.dataset.sentences))
        args.pred_path.write_text('\n'.join(pred_sentences))
        if not args.return_all:
            score = compute_bleu(loader.dataset.sentences, pred_sentences)
            print(f'BLEU: {score.score:.2f}')
--------------------------------------------------------------------------------

/bin/train.py:
--------------------------------------------------------------------------------
from pathlib import Path

import torch
try:
    from torch.cuda.amp import autocast
    autocast_available = True
except ImportError:
    class autocast:
        def __init__(self, enabled=True): pass
        def __enter__(self): return self
        def __exit__(self, exc_type, exc_value, exc_traceback): pass
    autocast_available = False

from torch.cuda.amp.grad_scaler import GradScaler
import transformers

from spring_amr import ROOT
from spring_amr.dataset import reverse_direction
from spring_amr.optim import RAdam
from spring_amr.evaluation import write_predictions, compute_smatch, predict_amrs, predict_sentences, compute_bleu
from spring_amr.utils import instantiate_model_and_tokenizer, instantiate_loader

from ignite.engine import Engine, Events
from ignite.metrics import RunningAverage
from ignite.handlers import ModelCheckpoint, global_step_from_engine

def do_train(checkpoint=None, direction='amr', split_both_decoder=False, fp16=False):
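    # direction='amr' trains Text->AMR parsing, 'text' trains AMR->Text
    # generation; 'both' backpropagates both losses (each weighted by 0.5),
    # optionally with a separate reverse decoder when split_both_decoder=True.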

    assert direction in ('amr', 'text', 'both')

    model, tokenizer = instantiate_model_and_tokenizer(
        config['model'],
        checkpoint=checkpoint,
        additional_tokens_smart_init=config['smart_init'],
        dropout=config['dropout'],
        attention_dropout=config['attention_dropout'],
        from_pretrained=config['warm_start'],
        init_reverse=split_both_decoder,
        penman_linearization=config['penman_linearization'],
        collapse_name_ops=config['collapse_name_ops'],
        use_pointer_tokens=config['use_pointer_tokens'],
        raw_graph=config.get('raw_graph', False)
    )

    print(model)
    print(model.config)

    if checkpoint is not None:
        print(f'Checkpoint restored ({checkpoint})!')

    if direction == 'both' and split_both_decoder:
        params_dir_enc = list(model.model.encoder.parameters())
        params_dir_enc_check = {id(p) for p in params_dir_enc}
        params_dir_dec = set()
        params_dir_dec |= {p for p in model.model.decoder.parameters() if id(p) not in params_dir_enc_check}
        params_dir_dec |= {p for p in model.rev.model.decoder.parameters() if id(p) not in params_dir_enc_check}
        params_dir_dec = list(params_dir_dec)
        optimizer = RAdam(
            [{'params': params_dir_enc, 'lr': config['learning_rate']},
             {'params': params_dir_dec, 'lr': config['learning_rate'] * 2},],
            weight_decay=config['weight_decay'])
    else:
        optimizer = RAdam(
            model.parameters(),
            lr=config['learning_rate'],
            weight_decay=config['weight_decay'])
    if checkpoint is not None:
        optimizer.load_state_dict(torch.load(checkpoint)['optimizer'])

    if config['scheduler'] == 'cosine':
        scheduler = transformers.get_cosine_schedule_with_warmup(
            optimizer,
            num_warmup_steps=config['warmup_steps'],
            num_training_steps=config['training_steps'])
    elif config['scheduler'] == 'constant':
        scheduler = transformers.get_constant_schedule_with_warmup(
            optimizer,
            num_warmup_steps=config['warmup_steps'])
    else:
        raise ValueError

    scaler = GradScaler(enabled=fp16)

    train_loader = instantiate_loader(
        config['train'],
        tokenizer,
        batch_size=config['batch_size'],
        evaluation=False,
        use_recategorization=config['use_recategorization'],
        remove_longer_than=config['remove_longer_than'],
        remove_wiki=config['remove_wiki'],
        dereify=config['dereify'],
    )

    dev_gold_path = ROOT / 'data/tmp/dev-gold.txt'
    dev_pred_path = ROOT / 'data/tmp/dev-pred.txt'
    dev_loader = instantiate_loader(
        config['dev'],
        tokenizer,
        batch_size=config['batch_size'],
        evaluation=True, out=dev_gold_path,
        use_recategorization=config['use_recategorization'],
        remove_wiki=config['remove_wiki'],
        dereify=config['dereify'],
    )

    if direction == 'amr':

        def train_step(engine, batch):
            model.train()
            x, y, extra = batch
            model.amr_mode = True
            with autocast(enabled=fp16):
                loss, *_ = model(**x, **y)
            scaler.scale((loss / config['accum_steps'])).backward()
            return loss.item()

        @torch.no_grad()
        def eval_step(engine, batch):
            model.eval()
            x, y, extra = batch
            model.amr_mode = True
            loss, *_ = model(**x, **y)
            return loss.item()

    elif direction == 'text':

        def train_step(engine, batch):
            model.train()
            x, y, extra = batch
            x, y = reverse_direction(x, y)
            model.rev.amr_mode = False
            with autocast(enabled=fp16):
                loss, *_ = model.rev(**x, **y)
            scaler.scale((loss / config['accum_steps'])).backward()
            return loss.item()

        @torch.no_grad()
        def eval_step(engine, batch):
            model.eval()
            x, y, extra = batch
            x, y = reverse_direction(x, y)
            model.rev.amr_mode = False
            loss, *_ = model(**x, **y)
            return loss.item()

    elif direction == 'both':

        def train_step(engine, batch):
            model.train()
            x, y, extra = batch
            model.amr_mode = True
            with autocast(enabled=fp16):
                loss1, *_ = model(**x, **y)
            scaler.scale((loss1 / config['accum_steps'] * 0.5)).backward()
            loss1 = loss1.item()
            x, y = reverse_direction(x, y)
            model.rev.amr_mode = False
            with autocast(enabled=fp16):
                loss2, *_ = model.rev(**x, **y)
            scaler.scale((loss2 / config['accum_steps'] * 0.5)).backward()
            return loss1, loss2.item()

        @torch.no_grad()
        def eval_step(engine, batch):
            model.eval()
            x, y, extra = batch
            model.amr_mode = True
            loss1, *_ = model(**x, **y)
            x, y = reverse_direction(x, y)
            model.rev.amr_mode = False
            loss2, *_ = model.rev(**x, **y)
            return loss1.item(), loss2.item()

    else:
        raise ValueError

    trainer = Engine(train_step)
    evaluator = Engine(eval_step)

    @trainer.on(Events.STARTED)
    def update(engine):
        print('training started!')

    @trainer.on(Events.EPOCH_COMPLETED)
    @trainer.on(Events.ITERATION_COMPLETED(every=config['accum_steps']))
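    # Runs every accum_steps iterations (and at epoch end): un-scales the
    # gradients, clips them, takes an optimizer step, and advances the LR
    # schedule, implementing gradient accumulation.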
194 | 195 | @trainer.on(Events.EPOCH_COMPLETED) 196 | def log_trn_loss(engine): 197 | log_msg = f"training epoch: {engine.state.epoch}" 198 | if direction in ('amr', 'both'): 199 | log_msg += f" | loss_amr: {engine.state.metrics['trn_amr_loss']:.3f}" 200 | if direction in ('text', 'both'): 201 | log_msg += f" | loss_text: {engine.state.metrics['trn_text_loss']:.3f}" 202 | print(log_msg) 203 | 204 | @trainer.on(Events.EPOCH_COMPLETED) 205 | def run_dev_eval(engine): 206 | dev_loader.batch_size = config['batch_size'] 207 | dev_loader.device = next(model.parameters()).device 208 | evaluator.run(dev_loader) 209 | 210 | if not config['best_loss']: 211 | if direction in ('amr', 'both'): 212 | @evaluator.on(Events.EPOCH_COMPLETED) 213 | def smatch_eval(engine): 214 | device = next(model.parameters()).device 215 | dev_loader.device = device 216 | graphs = predict_amrs(dev_loader, model, tokenizer, restore_name_ops=config['collapse_name_ops']) 217 | write_predictions(dev_pred_path, tokenizer, graphs) 218 | try: 219 | smatch = compute_smatch(dev_gold_path, dev_pred_path) 220 | except Exception: 221 | smatch = 0. 222 | engine.state.metrics['dev_smatch'] = smatch 223 | 224 | if direction in ('text', 'both'): 225 | @evaluator.on(Events.EPOCH_COMPLETED) 226 | def bleu_eval(engine): 227 | device = next(model.parameters()).device 228 | dev_loader.device = device 229 | pred_sentences = predict_sentences(dev_loader, model.rev, tokenizer, beam_size=config['beam_size']) 230 | bleu = compute_bleu(dev_loader.dataset.sentences, pred_sentences) 231 | engine.state.metrics['dev_bleu'] = bleu.score 232 | 233 | @evaluator.on(Events.EPOCH_COMPLETED) 234 | def log_dev_loss(engine): 235 | log_msg = f"dev epoch: {trainer.state.epoch}" 236 | if direction in ('amr', 'both'): 237 | log_msg += f" | loss_amr: {engine.state.metrics['dev_amr_loss']:.3f}" 238 | if not config['best_loss']: 239 | log_msg += f" | smatch: {engine.state.metrics['dev_smatch']:.3f}" 240 | if direction in ('text', 'both'): 241 | log_msg += f" | loss_text: {engine.state.metrics['dev_text_loss']:.3f}" 242 | if not config['best_loss']: 243 | log_msg += f" | bleu: {engine.state.metrics['dev_bleu']:.3f}" 244 | print(log_msg) 245 | 246 | if direction == 'amr': 247 | RunningAverage(output_transform=lambda out: out).attach(trainer, 'trn_amr_loss') 248 | RunningAverage(output_transform=lambda out: out).attach(evaluator, 'dev_amr_loss') 249 | elif direction == 'text': 250 | RunningAverage(output_transform=lambda out: out).attach(trainer, 'trn_text_loss') 251 | RunningAverage(output_transform=lambda out: out).attach(evaluator, 'dev_text_loss') 252 | elif direction == 'both': 253 | RunningAverage(output_transform=lambda out: out[0]).attach(trainer, 'trn_amr_loss') 254 | RunningAverage(output_transform=lambda out: out[1]).attach(trainer, 'trn_text_loss') 255 | RunningAverage(output_transform=lambda out: out[0]).attach(evaluator, 'dev_amr_loss') 256 | RunningAverage(output_transform=lambda out: out[1]).attach(evaluator, 'dev_text_loss') 257 | 258 |
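# [Illustrative sketch, not part of the repository] RunningAverage, from
# pytorch-ignite, attaches an exponentially smoothed average of each step's
# output to engine.state.metrics under the given name; those names are exactly
# the keys read by the logging handlers above. A tiny standalone sketch, with
# a toy engine whose step output is just the batch value:
#
#   from ignite.engine import Engine
#   from ignite.metrics import RunningAverage
#
#   toy = Engine(lambda engine, batch: float(batch))
#   RunningAverage(output_transform=lambda out: out).attach(toy, 'loss')
#   state = toy.run([3.0, 1.0, 2.0], max_epochs=1)
#   print(state.metrics['loss'])  # smoothed running average of the three outputs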
259 | if config['log_wandb']: 260 | from ignite.contrib.handlers.wandb_logger import WandBLogger 261 | wandb_logger = WandBLogger(init=False) 262 | 263 | if direction == 'amr': 264 | wandb_logger.attach_output_handler( 265 | trainer, 266 | event_name=Events.ITERATION_COMPLETED, 267 | tag="iterations/trn_amr_loss", 268 | output_transform=lambda loss: loss 269 | ) 270 | elif direction == 'text': 271 | wandb_logger.attach_output_handler( 272 | trainer, 273 | event_name=Events.ITERATION_COMPLETED, 274 | tag="iterations/trn_text_loss", 275 | output_transform=lambda loss: loss 276 | ) 277 | elif direction == 'both': 278 | wandb_logger.attach_output_handler( 279 | trainer, 280 | event_name=Events.ITERATION_COMPLETED, 281 | tag="iterations/trn_amr_loss", 282 | output_transform=lambda loss: loss[0] 283 | ) 284 | wandb_logger.attach_output_handler( 285 | trainer, 286 | event_name=Events.ITERATION_COMPLETED, 287 | tag="iterations/trn_text_loss", 288 | output_transform=lambda loss: loss[1] 289 | ) 290 | 291 | if direction == 'amr': 292 | metric_names_trn = ['trn_amr_loss'] 293 | metric_names_dev = ['dev_amr_loss'] 294 | if not config['best_loss']: 295 | metric_names_dev.append('dev_smatch') 296 | elif direction == 'text': 297 | metric_names_trn = ['trn_text_loss'] 298 | metric_names_dev = ['dev_text_loss'] 299 | if not config['best_loss']: 300 | metric_names_dev.append('dev_bleu') 301 | elif direction == 'both': 302 | metric_names_trn = ['trn_amr_loss', 'trn_text_loss'] 303 | metric_names_dev = ['dev_amr_loss', 'dev_text_loss'] 304 | if not config['best_loss']: 305 | metric_names_dev.extend(['dev_smatch', 'dev_bleu']) 306 | 307 | wandb_logger.attach_output_handler( 308 | trainer, 309 | event_name=Events.EPOCH_COMPLETED, 310 | tag="epochs", 311 | metric_names=metric_names_trn, 312 | global_step_transform=lambda *_: trainer.state.iteration, 313 | ) 314 | 315 | wandb_logger.attach_output_handler( 316 | evaluator, 317 | event_name=Events.EPOCH_COMPLETED, 318 | tag="epochs", 319 | metric_names=metric_names_dev, 320 | global_step_transform=lambda *_: trainer.state.iteration, 321 | ) 322 | 323 | @trainer.on(Events.ITERATION_COMPLETED) 324 | def wandb_log_lr(engine): 325 | wandb.log({'lr': scheduler.get_last_lr()[0]}, step=engine.state.iteration) 326 | 327 | if config['save_checkpoints']: 328 | 329 | if direction in ('amr', 'both'): 330 | if config['best_loss']: 331 | prefix = 'best-loss-amr' 332 | score_function = lambda x: 1 / evaluator.state.metrics['dev_amr_loss'] 333 | else: 334 | prefix = 'best-smatch' 335 | score_function = lambda x: evaluator.state.metrics['dev_smatch'] 336 | else: 337 | if config['best_loss']: 338 | prefix = 'best-loss-text' 339 | score_function = lambda x: 1 / evaluator.state.metrics['dev_text_loss'] 340 | else: 341 | prefix = 'best-bleu' 342 | score_function = lambda x: evaluator.state.metrics['dev_bleu'] 343 | 344 | to_save = {'model': model, 'optimizer': optimizer} 345 | if config['log_wandb']: 346 | where_checkpoints = str(wandb_logger.run.dir) 347 | else: 348 | root = ROOT/'runs' 349 | try: 350 | root.mkdir() 351 | except FileExistsError: 352 | pass 353 | where_checkpoints = root/str(len(list(root.iterdir()))) 354 | try: 355 | where_checkpoints.mkdir() 356 | except FileExistsError: 357 | pass 358 | where_checkpoints = str(where_checkpoints) 359 | 360 | print(where_checkpoints) 361 | handler = ModelCheckpoint( 362 | where_checkpoints, 363 | prefix, 364 | n_saved=1, 365 | create_dir=True, 366 | score_function=score_function, 367 | global_step_transform=global_step_from_engine(trainer), 368 | ) 369 | evaluator.add_event_handler(Events.EPOCH_COMPLETED, handler, to_save) 370 | 371 | model.cuda() 372 | device = next(model.parameters()).device 373 | train_loader.device = device 374 | trainer.run(train_loader, max_epochs=config['max_epochs']) 375 | 376 | if __name__ == '__main__': 377 | 378 | from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter 379 | import yaml 380 | 381 | import wandb 382 | 383 | parser = ArgumentParser( 384 | description="Trainer script", 385 | formatter_class=ArgumentDefaultsHelpFormatter, 386 | ) 387 | parser.add_argument('--direction', type=str, default='amr', choices=['amr', 'text', 'both'], 388 | help='Train a unidirectional (amr, text) or bidirectional (both) model.') 389 | parser.add_argument('--split-both-decoder', action='store_true') 390 | parser.add_argument('--config', type=Path, default=ROOT/'configs/config.yaml', 391 | help='Use the following config for hparams.') 392 | parser.add_argument('--checkpoint', type=str, 393 | help='Warm-start from a previous fine-tuned checkpoint.') 394 | parser.add_argument('--fp16', action='store_true') 395 | args, unknown = parser.parse_known_args() 396 | 397 | if args.fp16 and not autocast_available: 398 | raise ValueError('You\'ll need a newer PyTorch version to enable fp16 training.') 399 | 400 | with args.config.open() as y: 401 | config = yaml.load(y, Loader=yaml.FullLoader) 402 | 403 | if config['log_wandb']: 404 | wandb.init( 405 | entity="SOME-RUNS", 406 | project="SOME-PROJECT", 407 | config=config, 408 | dir=str(ROOT / 'runs/')) 409 | config = wandb.config 410 | 411 | print(config) 412 | 413 | if args.checkpoint: 414 | checkpoint = args.checkpoint 415 | else: 416 | checkpoint = None 417 | 418 | do_train( 419 | checkpoint=checkpoint, 420 | direction=args.direction, 421 | split_both_decoder=args.split_both_decoder, 422 | fp16=args.fp16, 423 | )
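# [Illustrative usage note, not part of the repository] A typical invocation of
# the trainer above, assuming the DATA placeholders in the config below have
# been pointed at an unpacked AMR 2.0/3.0 release:
#
#   python bin/train.py --config configs/config.yaml --direction amr
#
# Add --fp16 only on a PyTorch build that ships torch.cuda.amp, and pass
# --checkpoint to warm-start from a previously fine-tuned model.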
-------------------------------------------------------------------------------- /configs/config.yaml: -------------------------------------------------------------------------------- 1 | name: baseline+smart_init 2 | model: facebook/bart-large 3 | 4 | # <-------------- 5 | # Linearizations 6 | # Comment DFS and uncomment the relevant block if you want to use a different linearization scheme 7 | 8 | # DFS 9 | penman_linearization: True 10 | use_pointer_tokens: True 11 | raw_graph: False 12 | 13 | # BFS 14 | # penman_linearization: False 15 | # use_pointer_tokens: True 16 | # raw_graph: False 17 | 18 | # PENMAN 19 | # penman_linearization: True 20 | # use_pointer_tokens: False 21 | # raw_graph: False 22 | 23 | # BART baseline 24 | # penman_linearization: True 25 | # use_pointer_tokens: False 26 | # raw_graph: True 27 | 28 | remove_wiki: False 29 | dereify: False 30 | collapse_name_ops: False 31 | 32 | # Hparams 33 | batch_size: 500 34 | beam_size: 1 35 | dropout: 0.25 36 | attention_dropout: 0.0 37 | smart_init: True 38 | accum_steps: 10 39 | warmup_steps: 1 40 | training_steps: 250000 41 | weight_decay: 0.004 42 | grad_norm: 2.5 43 | scheduler: constant 44 | learning_rate: 0.00005 45 | max_epochs: 30 46 | save_checkpoints: True 47 | log_wandb: False 48 | warm_start: True 49 | use_recategorization: False 50 | best_loss: False 51 | remove_longer_than: 1024 52 | 53 | # <------------------ 54 | # Data: replace DATA below with the root of your AMR 2/3 release folder 55 | train: DATA/data/amrs/split/training/*.txt 56 | dev: DATA/data/amrs/split/dev/*.txt 57 | test: DATA/data/amrs/split/test/*.txt 58 | -------------------------------------------------------------------------------- /data/vocab/additions.txt: -------------------------------------------------------------------------------- 1 | date-entity 2 | government-organization 3 | temporal-quantity 4 | amr-unknown 5 | multi-sentence 6 | political-party 7 | :compared-to 8 | monetary-quantity 9 | ordinal-entity 10 | religious-group 11 | percentage-entity 12 | world-region 13 | :consist 14 | url-entity
15 | political-movement 16 | et-cetera 17 | at-least 18 | mass-quantity 19 | have-org-role-91 20 | have-rel-role-91 21 | include-91 22 | have-concession-91 23 | have-condition-91 24 | be-located-at-91 25 | rate-entity-91 26 | instead-of-91 27 | hyperlink-91 28 | request-confirmation-91 29 | have-purpose-91 30 | be-temporally-at-91 31 | regardless-91 32 | have-polarity-91 33 | byline-91 34 | have-manner-91 35 | have-part-91 36 | have-quant-91 37 | publication-91 38 | be-from-91 39 | have-mod-91 40 | have-frequency-91 41 | score-on-scale-91 42 | have-li-91 43 | be-compared-to-91 44 | be-destined-for-91 45 | course-91 46 | have-subevent-91 47 | street-address-91 48 | have-extent-91 49 | statistical-test-91 50 | have-instrument-91 51 | have-name-91 52 | be-polite-91 53 | -00 54 | -01 55 | -02 56 | -03 57 | -04 58 | -05 59 | -06 60 | -07 61 | -08 62 | -09 63 | -10 64 | -11 65 | -12 66 | -13 67 | -14 68 | -15 69 | -16 70 | -17 71 | -18 72 | -19 73 | -20 74 | -21 75 | -22 76 | -23 77 | -24 78 | -25 79 | -26 80 | -27 81 | -28 82 | -29 83 | -30 84 | -31 85 | -32 86 | -33 87 | -34 88 | -35 89 | -36 90 | -37 91 | -38 92 | -39 93 | -40 94 | -41 95 | -42 96 | -43 97 | -44 98 | -45 99 | -46 100 | -47 101 | -48 102 | -49 103 | -50 104 | -51 105 | -52 106 | -53 107 | -54 108 | -55 109 | -56 110 | -57 111 | -58 112 | -59 113 | -60 114 | -61 115 | -62 116 | -63 117 | -64 118 | -65 119 | -66 120 | -67 121 | -68 122 | -69 123 | -70 124 | -71 125 | -72 126 | -73 127 | -74 128 | -75 129 | -76 130 | -77 131 | -78 132 | -79 133 | -80 134 | -81 135 | -82 136 | -83 137 | -84 138 | -85 139 | -86 140 | -87 141 | -88 142 | -89 143 | -90 144 | -91 145 | -92 146 | -93 147 | -94 148 | -95 149 | -96 150 | -97 151 | -98 152 | -of 153 | :op1 154 | :op2 155 | :op3 156 | :op4 157 | :op5 158 | :ARG0 159 | :ARG1 160 | :ARG2 161 | :ARG3 162 | :ARG4 163 | :ARG5 164 | :ARG6 165 | :ARG7 166 | :ARG8 167 | :ARG9 168 | :ARG10 169 | :ARG11 170 | :ARG12 171 | :ARG13 172 | :ARG14 173 | :ARG15 174 | :ARG16 175 | :ARG17 176 | :ARG18 177 | :ARG19 178 | :ARG20 179 | :accompanier 180 | :age 181 | :beneficiary 182 | :calendar 183 | :cause 184 | :century 185 | :concession 186 | :condition 187 | :conj-as-if 188 | :consist-of 189 | :cost 190 | :day 191 | :dayperiod 192 | :decade 193 | :degree 194 | :destination 195 | :direction 196 | :domain 197 | :duration 198 | :employed-by 199 | :era 200 | :example 201 | :extent 202 | :frequency 203 | :instrument 204 | :li 205 | :location 206 | :manner 207 | :meaning 208 | :medium 209 | :mod 210 | :mode 211 | :month 212 | :name 213 | :ord 214 | :part 215 | :path 216 | :polarity 217 | :polite 218 | :poss 219 | :purpose 220 | :quant 221 | :quarter 222 | :range 223 | :relation 224 | :role 225 | :scale 226 | :season 227 | :source 228 | :subevent 229 | :subset 230 | :superset 231 | :time 232 | :timezone 233 | :topic 234 | :unit 235 | :value 236 | :weekday 237 | :wiki 238 | :year 239 | :year2 240 | :snt0 241 | :snt1 242 | :snt2 243 | :snt3 244 | :snt4 245 | :snt5 246 | -------------------------------------------------------------------------------- /data/vocab/recategorizations.txt: -------------------------------------------------------------------------------- 1 | PERSON 2 | COUNTRY 3 | QUANTITY 4 | ORGANIZATION 5 | DATE_ATTRS 6 | NATIONALITY 7 | LOCATION 8 | ENTITY 9 | CITY 10 | MISC 11 | ORDINAL_ENTITY 12 | IDEOLOGY 13 | RELIGION 14 | STATE_OR_PROVINCE 15 | URL 16 | CAUSE_OF_DEATH 17 | O 18 | TITLE 19 | DATE 20 | NUMBER 21 | HANDLE 22 | SCORE_ENTITY 23 | DURATION 24 | ORDINAL 25 | MONEY 26 | SET 27 | CRIMINAL_CHARGE
28 | _1 29 | _2 30 | _3 31 | _4 32 | _2 33 | _5 34 | _6 35 | _7 36 | _8 37 | _9 38 | _10 39 | _11 40 | _12 41 | _13 42 | _14 43 | _15 -------------------------------------------------------------------------------- /docs/appendix.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SapienzaNLP/spring/39079940d028ba0dde4c1af60432be49f67d76f8/docs/appendix.pdf -------------------------------------------------------------------------------- /docs/camera-ready.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SapienzaNLP/spring/39079940d028ba0dde4c1af60432be49f67d76f8/docs/camera-ready.pdf -------------------------------------------------------------------------------- /docs/preprint.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SapienzaNLP/spring/39079940d028ba0dde4c1af60432be49f67d76f8/docs/preprint.pdf -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | cached_property 2 | networkx 3 | penman>=1.1.0 4 | pytorch-ignite 5 | regex 6 | sacrebleu 7 | smatch 8 | transformers==2.11.0 9 | wandb 10 | PyYAML>=5.1 -------------------------------------------------------------------------------- /sample.txt: -------------------------------------------------------------------------------- 1 | # ::status ParsedStatus.OK 2 | # ::source sample.txt 3 | # ::nsent 6 4 | # ::snt In a shocking finding, scientist discovered a herd of unicorns living in a remote, previously unexplored valley, in the Andes Mountains. 5 | # ::snt-pred Scientists were shocked to discover a herd of unicorns living in a remote valley inaccessible in the Andes Mountains. 6 | (z0 / discover-01 7 | :ARG0 (z1 / scientist) 8 | :ARG1 (z2 / herd 9 | :consist-of (z3 / unicorn) 10 | :ARG0-of (z4 / live-01 11 | :location (z5 / valley 12 | :mod (z6 / remote) 13 | :ARG1-of (z7 / explore-01 14 | :polarity - 15 | :time (z8 / previous)) 16 | :location (z9 / mountain 17 | :wiki "Andes" 18 | :name (z10 / name 19 | :op1 "Andes" 20 | :op2 "Mountains"))))) 21 | :ARG0-of (z11 / shock-01)) 22 | 23 | # ::status ParsedStatus.OK 24 | # ::source sample.txt 25 | # ::nsent 5 26 | # ::snt Emily loves mint chocolate cake, but she requires that it be paired with mini chocolate chips, so I threw some of those in between the layers. 27 | # ::snt-pred Emily loves chocolate cake, but it requires to be paired with mini chocolate chips, so I threw some of them in between the layers. 28 | (z0 / love-01 29 | :ARG0 (z1 / person 30 | :wiki - 31 | :name (z2 / name 32 | :op1 "Emily")) 33 | :ARG1 (z3 / cake 34 | :consist-of (z4 / chocolate 35 | :mod (z5 / mint))) 36 | :concession-of (z6 / require-01 37 | :ARG0 z1 38 | :ARG1 (z7 / pair-01 39 | :ARG1 z3 40 | :ARG2 (z8 / chip 41 | :consist-of (z9 / chocolate 42 | :mod (z10 / mini))))) 43 | :ARG0-of (z11 / cause-01 44 | :ARG1 (z12 / throw-01 45 | :ARG0 (z13 / i) 46 | :ARG1 (z14 / some 47 | :ARG1-of (z15 / include-91 48 | :ARG2 z3)) 49 | :ARG2 (z16 / between 50 | :op1 (z17 / layer))))) 51 | 52 | # ::status ParsedStatus.OK 53 | # ::source sample.txt 54 | # ::nsent 7 55 | # ::snt Prehistoric man sketched an incredible array of prehistoric beasts on the rough limestone walls of a cave in modern day France 36,000 years ago. 
56 | # ::snt-pred 36,000 years ago, prehistoric men drew an incredible array of prehistoric beasts on a rough limestone wall of a cave in modern-day France. 57 | (z0 / draw-01 58 | :ARG0 (z1 / man 59 | :mod (z2 / prehistoric)) 60 | :ARG1 (z3 / array 61 | :mod (z4 / incredible) 62 | :consist-of (z5 / beast 63 | :mod (z6 / prehistoric))) 64 | :location (z7 / wall 65 | :consist-of (z8 / limestone) 66 | :ARG1-of (z9 / rough-04) 67 | :part-of (z10 / cave 68 | :location (z11 / country 69 | :wiki "France" 70 | :name (z12 / name 71 | :op1 "France") 72 | :time (z13 / day 73 | :ARG1-of (z14 / modern-02))))) 74 | :time (z15 / before 75 | :op1 (z16 / now) 76 | :quant (z17 / temporal-quantity 77 | :quant 36000 78 | :unit (z18 / year)))) 79 | 80 | # ::status ParsedStatus.OK 81 | # ::source sample.txt 82 | # ::nsent 3 83 | # ::snt Corporal Michael P. Goeldin was an unskilled laborer from Ireland when he enlisted in Company A in November 1860. 84 | # ::snt-pred When Michael P. Goeldin enlisted in Company A in November, 1860, he was an Irish labourer with no skills. 85 | (z0 / person 86 | :ARG0-of (z1 / labor-01 87 | :manner (z2 / skill 88 | :polarity -)) 89 | :domain (z3 / person 90 | :wiki - 91 | :name (z4 / name 92 | :op1 "Michael" 93 | :op2 "P." 94 | :op3 "Goeldin") 95 | :ARG0-of (z5 / have-org-role-91 96 | :ARG2 (z6 / corporal))) 97 | :mod (z7 / country 98 | :wiki "Ireland" 99 | :name (z8 / name 100 | :op1 "Ireland")) 101 | :time (z9 / enlist-01 102 | :ARG1 z3 103 | :ARG2 (z10 / military 104 | :wiki - 105 | :name (z11 / name 106 | :op1 "Company" 107 | :op2 "A")) 108 | :time (z12 / date-entity 109 | :year 1860 110 | :month 11))) 111 | 112 | # ::status ParsedStatus.OK 113 | # ::source sample.txt 114 | # ::nsent 0 115 | # ::snt This pairing was the first outfit I thought of when I bought the shoes. 116 | # ::snt-pred This pair is the first outfit I thought of when I bought shoes. 117 | (z0 / outfit 118 | :ord (z1 / ordinal-entity 119 | :value 1) 120 | :ARG1-of (z2 / think-01 121 | :ARG0 (z3 / i) 122 | :time (z4 / buy-01 123 | :ARG0 z3 124 | :ARG1 (z5 / shoe))) 125 | :domain (z6 / pair-01 126 | :mod (z7 / this))) 127 | 128 | # ::status ParsedStatus.OK 129 | # ::source sample.txt 130 | # ::nsent 2 131 | # ::snt The pink ghost’s AI is designed to ”feel” opposite of the red ghost’s behavior. 132 | # ::snt-pred The artificial system of the pink ghosts was designed to feel the opposite of the way the red ghosts behaved. 133 | (z0 / design-01 134 | :ARG1 (z1 / system 135 | :mod (z2 / artificial) 136 | :poss (z3 / ghost 137 | :ARG1-of (z4 / pink-04))) 138 | :ARG3 (z5 / feel-01 139 | :ARG0 z1 140 | :ARG1 (z6 / opposite-01 141 | :ARG2 (z7 / behave-01 142 | :ARG0 (z8 / ghost 143 | :mod (z9 / red)))))) 144 | 145 | # ::status ParsedStatus.OK 146 | # ::source sample.txt 147 | # ::nsent 4 148 | # ::snt Xresources can be an absolute pain (they were for me). 149 | # ::snt-pred The x-resoures could absolutely cause pain to me. 
150 | (z0 / possible-01 151 | :ARG1 (z1 / pain-01 152 | :ARG0 (z2 / resource 153 | :mod (z3 / xresources)) 154 | :mod (z4 / absolute) 155 | :ARG1-of (z5 / cause-01 156 | :ARG0 (z6 / they 157 | :beneficiary (z7 / i))))) -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | setup( 4 | name='spring_amr', 5 | version='1.0', 6 | packages=['spring_amr'], 7 | url='https://github.com/SapienzaNLP/spring', 8 | license='CC BY-NC-SA 4.0', 9 | author='Michele Bevilacqua, Rexhina Blloshmi and Roberto Navigli', 10 | author_email='{bevilacqua,blloshmi,navigli}@di.uniroma1.it', 11 | description='Parse sentences into AMR graphs and generate sentences from AMR graphs without breaking a sweat!' 12 | ) 13 | -------------------------------------------------------------------------------- /spring_amr/IO.py: -------------------------------------------------------------------------------- 1 | import glob 2 | from typing import List, Union, Iterable 3 | from pathlib import Path 4 | from spring_amr.penman import load as pm_load 5 | 6 | def read_raw_amr_data( 7 | paths: List[Union[str, Path]], 8 | use_recategorization=False, 9 | dereify=True, 10 | remove_wiki=False, 11 | ): 12 | assert paths 13 | 14 | if not isinstance(paths, Iterable): 15 | paths = [paths] 16 | 17 | graphs = [] 18 | for path_ in paths: 19 | for path in glob.glob(str(path_)): 20 | path = Path(path) 21 | graphs.extend(pm_load(path, dereify=dereify, remove_wiki=remove_wiki)) 22 | 23 | assert graphs 24 | 25 | if use_recategorization: 26 | for g in graphs: 27 | metadata = g.metadata 28 | metadata['snt_orig'] = metadata['snt'] 29 | tokens = eval(metadata['tokens']) 30 | metadata['snt'] = ' '.join([t for t in tokens if not ((t.startswith('-L') or t.startswith('-R')) and t.endswith('-'))]) 31 | 32 | return graphs -------------------------------------------------------------------------------- /spring_amr/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.0.1" 2 | 3 | from pathlib import Path 4 | 5 | ROOT = Path(__file__).parent.parent 6 | -------------------------------------------------------------------------------- /spring_amr/dataset.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import random 3 | import torch 4 | from cached_property import cached_property 5 | from torch.utils.data import Dataset 6 | from spring_amr.IO import read_raw_amr_data 7 | 8 | def reverse_direction(x, y, pad_token_id=1): 9 | input_ids = torch.cat([y['decoder_input_ids'], y['lm_labels'][:, -1:]], 1) 10 | attention_mask = torch.ones_like(input_ids) 11 | attention_mask[input_ids == pad_token_id] = 0 12 | decoder_input_ids = x['input_ids'][:,:-1] 13 | lm_labels = x['input_ids'][:,1:] 14 | x = {'input_ids': input_ids, 'attention_mask': attention_mask} 15 | y = {'decoder_input_ids': decoder_input_ids, 'lm_labels': lm_labels} 16 | return x, y 17 | 18 | class AMRDataset(Dataset): 19 | 20 | def __init__( 21 | self, 22 | paths, 23 | tokenizer, 24 | device=torch.device('cpu'), 25 | use_recategorization=False, 26 | remove_longer_than=None, 27 | remove_wiki=False, 28 | dereify=True, 29 | ): 30 | self.paths = paths 31 | self.tokenizer = tokenizer 32 | self.device = device 33 | graphs = read_raw_amr_data(paths, use_recategorization, remove_wiki=remove_wiki, dereify=dereify) 34 | self.graphs = [] 35 | 
self.sentences = [] 36 | self.linearized = [] 37 | self.linearized_extra = [] 38 | self.remove_longer_than = remove_longer_than 39 | for g in graphs: 40 | l, e = self.tokenizer.linearize(g) 41 | 42 | try: 43 | self.tokenizer.batch_encode_sentences([g.metadata['snt']]) 44 | except: 45 | logging.warning('Invalid sentence!') 46 | continue 47 | 48 | if remove_longer_than and len(l) > remove_longer_than: 49 | continue 50 | if len(l) > 1024: 51 | logging.warning('Sequence longer than 1024 included. BART does not support it!') 52 | 53 | self.sentences.append(g.metadata['snt']) 54 | self.graphs.append(g) 55 | self.linearized.append(l) 56 | self.linearized_extra.append(e) 57 | 58 | def __len__(self): 59 | return len(self.sentences) 60 | 61 | def __getitem__(self, idx): 62 | sample = {} 63 | sample['id'] = idx 64 | sample['sentences'] = self.sentences[idx] 65 | if self.linearized is not None: 66 | sample['linearized_graphs_ids'] = self.linearized[idx] 67 | sample.update(self.linearized_extra[idx]) 68 | return sample 69 | 70 | def size(self, sample): 71 | return len(sample['linearized_graphs_ids']) 72 | 73 | def collate_fn(self, samples, device=torch.device('cpu')): 74 | x = [s['sentences'] for s in samples] 75 | x, extra = self.tokenizer.batch_encode_sentences(x, device=device) 76 | if 'linearized_graphs_ids' in samples[0]: 77 | y = [s['linearized_graphs_ids'] for s in samples] 78 | y, extra_y = self.tokenizer.batch_encode_graphs_from_linearized(y, samples, device=device) 79 | extra.update(extra_y) 80 | else: 81 | y = None 82 | extra['ids'] = [s['id'] for s in samples] 83 | return x, y, extra 84 | 85 | class AMRDatasetTokenBatcherAndLoader: 86 | 87 | def __init__(self, dataset, batch_size=800 ,device=torch.device('cpu'), shuffle=False, sort=False): 88 | assert not (shuffle and sort) 89 | self.batch_size = batch_size 90 | self.tokenizer = dataset.tokenizer 91 | self.dataset = dataset 92 | self.device = device 93 | self.shuffle = shuffle 94 | self.sort = sort 95 | 96 | def __iter__(self): 97 | it = self.sampler() 98 | it = ([[self.dataset[s] for s in b] for b in it]) 99 | it = (self.dataset.collate_fn(b, device=self.device) for b in it) 100 | return it 101 | 102 | @cached_property 103 | def sort_ids(self): 104 | lengths = [len(s.split()) for s in self.dataset.sentences] 105 | ids, _ = zip(*sorted(enumerate(lengths), reverse=True)) 106 | ids = list(ids) 107 | return ids 108 | 109 | def sampler(self): 110 | ids = list(range(len(self.dataset)))[::-1] 111 | 112 | if self.shuffle: 113 | random.shuffle(ids) 114 | if self.sort: 115 | ids = self.sort_ids.copy() 116 | 117 | batch_longest = 0 118 | batch_nexamps = 0 119 | batch_ntokens = 0 120 | batch_ids = [] 121 | 122 | def discharge(): 123 | nonlocal batch_longest 124 | nonlocal batch_nexamps 125 | nonlocal batch_ntokens 126 | ret = batch_ids.copy() 127 | batch_longest *= 0 128 | batch_nexamps *= 0 129 | batch_ntokens *= 0 130 | batch_ids[:] = [] 131 | return ret 132 | 133 | while ids: 134 | idx = ids.pop() 135 | size = self.dataset.size(self.dataset[idx]) 136 | cand_batch_ntokens = max(size, batch_longest) * (batch_nexamps + 1) 137 | if cand_batch_ntokens > self.batch_size and batch_ids: 138 | yield discharge() 139 | batch_longest = max(batch_longest, size) 140 | batch_nexamps += 1 141 | batch_ntokens = batch_longest * batch_nexamps 142 | batch_ids.append(idx) 143 | 144 | if len(batch_ids) == 1 and batch_ntokens > self.batch_size: 145 | yield discharge() 146 | 147 | if batch_ids: 148 | yield discharge() 149 | 
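# [Illustrative trace, not part of the repository] The sampler above fills
# batches against a token budget rather than a fixed example count: adding an
# example costs max(batch_longest, size) * (batch_nexamps + 1), since every
# sequence in a batch is padded to the longest one. A worked trace, assuming
# batch_size=10 and linearized lengths [4, 4, 5, 9]:
#
#   [4]        cost 4*1 = 4   -> fits
#   [4, 4]     cost 4*2 = 8   -> fits
#   [4, 4, 5]  cost 5*3 = 15  -> over budget, so [4, 4] is yielded
#   [5]        cost 5*1 = 5   -> fits; [5, 9] would cost 9*2 = 18, so [5] is yielded
#   [9]        yielded by the final discharge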
-------------------------------------------------------------------------------- /spring_amr/entities.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | 3 | def read_entities(sentences, graphs, just_tagged=True): 4 | 5 | for i, (s, g) in enumerate(zip(sentences, graphs)): 6 | 7 | with_wikis = {} 8 | name_to_entity = {} 9 | name_to_ops = defaultdict(list) 10 | 11 | for nt, t in enumerate(g.triples): 12 | n1, rel, n2 = t 13 | 14 | if n2 == '-' and just_tagged: 15 | continue 16 | 17 | if rel == ':wiki': 18 | with_wikis[n1] = (nt, n2) 19 | 20 | for t in g.triples: 21 | n1, rel, n2 = t 22 | if (n1 in with_wikis) and (rel == ':name'): 23 | name_to_entity[n2] = n1 24 | 25 | for nt, t in enumerate(g.triples): 26 | n1, rel, n2 = t 27 | if (n1 in name_to_entity) and rel.startswith(':op'): 28 | name_to_ops[n1].append(t) 29 | 30 | yield (i, with_wikis, name_to_entity, name_to_ops) -------------------------------------------------------------------------------- /spring_amr/evaluation.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | from pathlib import Path 3 | 4 | import penman 5 | from sacrebleu import corpus_bleu 6 | import torch 7 | from tqdm import tqdm 8 | import smatch 9 | 10 | from spring_amr.dataset import reverse_direction 11 | 12 | def predict_amrs( 13 | loader, model, tokenizer, beam_size=1, tokens=None, restore_name_ops=False, return_all=False): 14 | 15 | shuffle_orig = loader.shuffle 16 | sort_orig = loader.sort 17 | 18 | loader.shuffle = False 19 | loader.sort = True 20 | 21 | total = len(loader.dataset) 22 | model.eval() 23 | model.amr_mode = True 24 | 25 | if tokens is None: 26 | ids = [] 27 | tokens = [] 28 | with tqdm(total=total) as bar: 29 | for x, y, extra in loader: 30 | ii = extra['ids'] 31 | ids.extend(ii) 32 | with torch.no_grad(): 33 | out = model.generate( 34 | **x, 35 | max_length=1024, 36 | decoder_start_token_id=0, 37 | num_beams=beam_size, 38 | num_return_sequences=beam_size) 39 | nseq = len(ii) 40 | for i1 in range(0, out.size(0), beam_size): 41 | tokens_same_source = [] 42 | tokens.append(tokens_same_source) 43 | for i2 in range(i1, i1+beam_size): 44 | tokk = out[i2].tolist() 45 | tokens_same_source.append(tokk) 46 | bar.update(nseq) 47 | # reorder 48 | tokens = [tokens[i] for i in ids] 49 | tokens = [t for tt in tokens for t in tt] 50 | 51 | graphs = [] 52 | for i1 in range(0, len(tokens), beam_size): 53 | graphs_same_source = [] 54 | graphs.append(graphs_same_source) 55 | for i2 in range(i1, i1+beam_size): 56 | tokk = tokens[i2] 57 | graph, status, (lin, backr) = tokenizer.decode_amr(tokk, restore_name_ops=restore_name_ops) 58 | graph.status = status 59 | graph.nodes = lin 60 | graph.backreferences = backr 61 | graph.tokens = tokk 62 | graphs_same_source.append(graph) 63 | graphs_same_source[:] = tuple(zip(*sorted(enumerate(graphs_same_source), key=lambda x: (x[1].status.value, x[0]))))[1] 64 | 65 | for gps, gg in zip(graphs, loader.dataset.graphs): 66 | for gp in gps: 67 | metadata = gg.metadata.copy() 68 | metadata['annotator'] = 'bart-amr' 69 | metadata['date'] = str(datetime.datetime.now()) 70 | if 'save-date' in metadata: 71 | del metadata['save-date'] 72 | gp.metadata = metadata 73 | 74 | loader.shuffle = shuffle_orig 75 | loader.sort = sort_orig 76 | 77 | if not return_all: 78 | graphs = [gg[0] for gg in graphs] 79 | 80 | return graphs 81 | 82 | def predict_sentences(loader, model, tokenizer, beam_size=1, tokens=None, 
return_all=False): 83 | 84 | shuffle_orig = loader.shuffle 85 | sort_orig = loader.sort 86 | 87 | loader.shuffle = False 88 | loader.sort = True 89 | 90 | total = len(loader.dataset) 91 | model.eval() 92 | model.amr_mode = False 93 | 94 | if tokens is None: 95 | ids = [] 96 | tokens = [] 97 | with tqdm(total=total) as bar: 98 | for x, y, extra in loader: 99 | ids.extend(extra['ids']) 100 | x, y = reverse_direction(x, y) 101 | x['input_ids'] = x['input_ids'][:, :1024] 102 | x['attention_mask'] = x['attention_mask'][:, :1024] 103 | with torch.no_grad(): 104 | out = model.generate( 105 | **x, 106 | max_length=350, 107 | decoder_start_token_id=0, 108 | num_beams=beam_size, 109 | num_return_sequences=beam_size) 110 | for i1 in range(0, len(out), beam_size): 111 | tokens_same_source = [] 112 | tokens.append(tokens_same_source) 113 | for i2 in range(i1, i1+beam_size): 114 | tokk = out[i2] 115 | tokk = [t for t in tokk.tolist() if t > 2] 116 | tokens_same_source.append(tokk) 117 | bar.update(out.size(0) // beam_size) 118 | # reorder 119 | tokens = [tokens[i] for i in ids] 120 | 121 | sentences = [] 122 | for tokens_same_source in tokens: 123 | if return_all: 124 | sentences.append([tokenizer.decode(tokk).strip() for tokk in tokens_same_source]) 125 | else: 126 | sentences.append(tokenizer.decode(tokens_same_source[0]).strip()) 127 | 128 | loader.shuffle = shuffle_orig 129 | loader.sort = sort_orig 130 | 131 | return sentences 132 | 133 | def write_predictions(predictions_path, tokenizer, graphs): 134 | pieces = [penman.encode(g) for g in graphs] 135 | Path(predictions_path).write_text('\n\n'.join(pieces).replace(tokenizer.INIT, '')) 136 | return predictions_path 137 | 138 | def compute_smatch(test_path, predictions_path): 139 | with Path(predictions_path).open() as p, Path(test_path).open() as g: 140 | score = next(smatch.score_amr_pairs(p, g)) 141 | return score[2] 142 | 143 | def compute_bleu(gold_sentences, pred_sentences): 144 | return corpus_bleu(pred_sentences, [gold_sentences]) 145 | -------------------------------------------------------------------------------- /spring_amr/linearization.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import itertools 3 | from collections import deque, defaultdict 4 | import re 5 | from typing import List, Optional, Dict, Any, Set, TypeVar 6 | 7 | from cached_property import cached_property 8 | from dataclasses import dataclass 9 | import networkx as nx 10 | import penman 11 | 12 | @dataclass 13 | class SemanticGraph: 14 | 15 | nodes_var: List[str] 16 | """ 17 | List of linearized nodes, with special tokens. 18 | """ 19 | edges: Optional[List[str]] 20 | """ 21 | List of linearized edges, with special tokens. 22 | """ 23 | backreferences: List[int] 24 | """ 25 | List of backpointers to handle reentrancies and cycles. 26 | """ 27 | var2instance: Dict[str, str] 28 | """ 29 | Dict from var ids to 'lemmatized' readable strings qualifying the node (collapsing the :instance edge for AMR). 30 | """ 31 | extra: Dict[str, Any] 32 | """ 33 | Holds extra stuff that might be useful, e.g. alignments, NER, EL.
34 | """ 35 | 36 | @cached_property 37 | def variables(self) -> Set[str]: 38 | """Set of variables in this semantic graph""" 39 | variables = {v for v in self.nodes_var if not v.startswith('<')} 40 | return variables 41 | 42 | @property 43 | def resolved_nodes_var(self) -> List[str]: 44 | return [self.nodes_var[b] for b in self.backreferences] 45 | 46 | @cached_property 47 | def nodes(self) -> List[str]: 48 | """Linearized nodes with varids replaced by instances""" 49 | return [self.var2instance.get(node, node) for node in self.nodes_var] 50 | 51 | @property 52 | def resolved_nodes(self) -> List[str]: 53 | return [self.nodes[b] for b in self.backreferences] 54 | 55 | def src_occurrence(self, var: str) -> int: 56 | pass 57 | 58 | 59 | class BaseLinearizer(metaclass=abc.ABCMeta): 60 | 61 | @abc.abstractmethod 62 | def linearize(self, *args, **kwargs) -> SemanticGraph: 63 | pass 64 | 65 | class AMRTokens: 66 | 67 | START, END = '<', '>' 68 | _TEMPL = START + '{}' + END 69 | 70 | BOS_N = _TEMPL.format('s') 71 | EOS_N = _TEMPL.format('/s') 72 | START_N = _TEMPL.format('start') 73 | STOP_N = _TEMPL.format('stop') 74 | PNTR_N = _TEMPL.format('pointer') 75 | 76 | LIT_START = _TEMPL.format( 'lit') 77 | LIT_END = _TEMPL.format('/lit') 78 | 79 | BACKR_SRC_N = _TEMPL.format('backr:src:XXX') 80 | BACKR_TRG_N = _TEMPL.format('backr:trg:XXX') 81 | 82 | BOS_E = _TEMPL.format('s') 83 | EOS_E = _TEMPL.format('/s') 84 | START_E = _TEMPL.format('start') 85 | STOP_E = _TEMPL.format('stop') 86 | 87 | _FIXED_SPECIAL_TOKENS_N = { 88 | BOS_N, EOS_N, START_N, STOP_N} 89 | _FIXED_SPECIAL_TOKENS_E = { 90 | BOS_E, EOS_E, START_E, STOP_E} 91 | _FIXED_SPECIAL_TOKENS = _FIXED_SPECIAL_TOKENS_N | _FIXED_SPECIAL_TOKENS_E 92 | 93 | # match and read backreferences 94 | _re_BACKR_SRC_N = re.compile(BACKR_SRC_N.replace('XXX', r'([0-9]+)')) 95 | _re_BACKR_TRG_N = re.compile(BACKR_TRG_N.replace('XXX', r'([0-9]+)')) 96 | 97 | @classmethod 98 | def is_node(cls, string: str) -> bool: 99 | if isinstance(string, str) and string.startswith(':'): 100 | return False 101 | elif string in cls._FIXED_SPECIAL_TOKENS_E: 102 | return False 103 | return True 104 | 105 | @classmethod 106 | def read_backr(cls, string: str) -> Optional: 107 | m_src = cls._re_BACKR_SRC_N.search(string) 108 | if m_src is not None: 109 | return m_src 110 | m_trg = cls._re_BACKR_TRG_N.search(string) 111 | if m_trg is not None: 112 | return m_trg 113 | return None 114 | 115 | 116 | T = TypeVar('T') 117 | 118 | 119 | def index_default( 120 | item: T, list_: List[T], 121 | start: Optional[int] = None, 122 | stop: Optional[int] = None, 123 | default: Optional[int] = None 124 | ): 125 | if start is None: 126 | start = 0 127 | if stop is None: 128 | stop = len(list_) 129 | return next((i for i, x in enumerate(list_[start:stop], start=start) if x == item), default) 130 | 131 | class AMRLinearizer(BaseLinearizer): 132 | 133 | def __init__( 134 | self, 135 | use_pointer_tokens: bool = True, 136 | collapse_name_ops: bool = False, 137 | ): 138 | self.collapse_name_ops = collapse_name_ops 139 | self.interleave_edges = False 140 | self.use_pointer_tokens = use_pointer_tokens 141 | 142 | def _collapse_name_ops(self, amr): 143 | # identify name triples 144 | name_vars = {} 145 | for i, (v1, rel, v2) in enumerate(amr.triples): 146 | if rel == ':instance' and v2 == 'name': 147 | name_vars[v1] = 1 148 | 149 | # check if they have ops 150 | name_vars_to_ops = defaultdict(list) 151 | for i, (v1, rel, v2) in enumerate(amr.triples): 152 | if v1 in name_vars and rel.startswith(':op'): 153 | 
name_vars_to_ops[v1].append((i, rel, v2.strip('"'))) 154 | 155 | triples = amr.triples.copy() 156 | for nv, ops in name_vars_to_ops.items(): 157 | ops = sorted(ops, key=lambda x: int(x[1][3:])) 158 | idx, _, lits = zip(*ops) 159 | for i in idx: 160 | triples[i] = None 161 | lit = '"' + '_'.join(lits) + '"' 162 | triples[min(idx)] = penman.Triple(nv, ':op1', lit) 163 | 164 | triples = [t for t in triples if t is not None] 165 | amr_ = penman.Graph(triples) 166 | amr_.metadata = amr.metadata 167 | return amr_ 168 | 169 | 170 | def linearize(self, amr: penman.Graph) -> SemanticGraph: 171 | if self.collapse_name_ops: 172 | amr = self._collapse_name_ops(amr) 173 | linearized = self._linearize(amr) 174 | linearized = self._interleave(linearized) 175 | if self.use_pointer_tokens: 176 | linearized = self._add_pointer_tokens(linearized) 177 | return linearized 178 | 179 | def _linearize(self, amr: penman.Graph) -> SemanticGraph: 180 | variables = set(amr.variables()) 181 | variables = {'var:' + v for v in variables} 182 | var2instance = {} 183 | 184 | graph = nx.MultiDiGraph() 185 | 186 | triples2order = {k: i for i, k in enumerate(amr.triples)} 187 | 188 | for triple in amr.triples: 189 | var, rel, instance = triple 190 | order = triples2order[triple] 191 | if rel != ':instance': 192 | continue 193 | for expansion_candidate in itertools.chain(range(order - 1, -1, -1), range(order + 1, len(amr.triples))): 194 | if var == amr.triples[expansion_candidate][2]: 195 | expansion = expansion_candidate 196 | break 197 | else: 198 | expansion = 0 199 | var = 'var:' + var 200 | var2instance[var] = instance 201 | graph.add_node(var, instance=instance, order=order, expansion=expansion) 202 | 203 | for triple in amr.edges(): 204 | var1, rel, var2 = triple 205 | order = triples2order[triple] 206 | if rel == ':instance': 207 | continue 208 | var1 = 'var:' + var1 209 | var2 = 'var:' + var2 210 | graph.add_edge(var1, var2, rel=rel, order=order) 211 | 212 | for triple in amr.attributes(): 213 | var, rel, attr = triple 214 | order = triples2order[triple] 215 | if rel == ':instance': 216 | continue 217 | var = 'var:' + var 218 | graph.add_edge(var, attr, rel=rel, order=order) 219 | 220 | # nodes that are not reachable from the root (e.g.
because of reification) 221 | # will be present in the not_explored queue 222 | # undirected_graph = graph.to_undirected() 223 | # print(amr.variables()) 224 | not_explored = deque(sorted(variables, key=lambda x: nx.get_node_attributes(graph, 'order')[x])) 225 | # ( 226 | # len(nx.shortest_path(undirected_graph, 'var:' + amr.top, x)), 227 | # -graph.out_degree(x), 228 | # ) 229 | 230 | first_index = {} 231 | explored = set() 232 | added_to_queue = set() 233 | nodes_visit = [AMRTokens.BOS_N] 234 | edges_visit = [AMRTokens.BOS_E] 235 | backreferences = [0] 236 | queue = deque() 237 | queue.append('var:' + amr.top) 238 | 239 | while queue or not_explored: 240 | 241 | if queue: 242 | node1 = queue.popleft() 243 | else: 244 | node1 = not_explored.popleft() 245 | if node1 in added_to_queue: 246 | continue 247 | if not list(graph.successors(node1)): 248 | continue 249 | 250 | if node1 in variables: 251 | if node1 in explored: 252 | continue 253 | if node1 in first_index: 254 | nodes_visit.append(AMRTokens.BACKR_TRG_N) 255 | backreferences.append(first_index[node1]) 256 | else: 257 | backreferences.append(len(nodes_visit)) 258 | first_index[node1] = len(nodes_visit) 259 | nodes_visit.append(node1) 260 | edges_visit.append(AMRTokens.START_E) 261 | 262 | successors = [] 263 | for node2 in graph.successors(node1): 264 | for edge_data in graph.get_edge_data(node1, node2).values(): 265 | rel = edge_data['rel'] 266 | order = edge_data['order'] 267 | successors.append((order, rel, node2)) 268 | successors = sorted(successors) 269 | 270 | for order, rel, node2 in successors: 271 | edges_visit.append(rel) 272 | 273 | # node2 is a variable 274 | if node2 in variables: 275 | # ... which was mentioned before 276 | if node2 in first_index: 277 | nodes_visit.append(AMRTokens.BACKR_TRG_N) 278 | backreferences.append(first_index[node2]) 279 | 280 | # .. 
which is mentioned for the first time 281 | else: 282 | backreferences.append(len(nodes_visit)) 283 | first_index[node2] = len(nodes_visit) 284 | nodes_visit.append(node2) 285 | 286 | # 1) not already in Q 287 | # 2) has children 288 | # 3) the edge right before its expansion has been encountered 289 | if (node2 not in added_to_queue) and list(graph.successors(node2)) and (nx.get_node_attributes(graph, 'expansion')[node2] <= order): 290 | queue.append(node2) 291 | added_to_queue.add(node2) 292 | 293 | # node2 is a constant 294 | else: 295 | backreferences.append(len(nodes_visit)) 296 | nodes_visit.append(node2) 297 | 298 | backreferences.append(len(nodes_visit)) 299 | nodes_visit.append(AMRTokens.STOP_N) 300 | edges_visit.append(AMRTokens.STOP_E) 301 | explored.add(node1) 302 | 303 | else: 304 | backreferences.append(len(nodes_visit)) 305 | nodes_visit.append(node1) 306 | explored.add(node1) 307 | 308 | backreferences.append(len(nodes_visit)) 309 | nodes_visit.append(AMRTokens.EOS_N) 310 | edges_visit.append(AMRTokens.EOS_E) 311 | assert len(nodes_visit) == len(edges_visit) == len(backreferences) 312 | return SemanticGraph( 313 | nodes_visit, 314 | edges_visit, 315 | backreferences, 316 | var2instance, 317 | extra={'graph': graph, 'amr': amr} 318 | ) 319 | 320 | def _interleave(self, graph: SemanticGraph) -> SemanticGraph: 321 | 322 | new_backreferences_map = [] 323 | new_nodes = [] 324 | new_edges = None 325 | new_backreferences = [] 326 | 327 | # to isolate sublist to the stop token 328 | start_i = 1 329 | end_i = index_default(AMRTokens.STOP_N, graph.nodes_var, start_i, -1, -1) 330 | 331 | def add_node(node, backr = None): 332 | old_n_node = len(new_backreferences_map) 333 | new_n_node = len(new_nodes) 334 | 335 | if backr is None: 336 | backr = old_n_node 337 | 338 | new_backreferences_map.append(new_n_node) 339 | new_nodes.append(node) 340 | if old_n_node == backr: 341 | new_backreferences.append(new_n_node) 342 | else: 343 | new_backreferences.append(new_backreferences_map[backr]) 344 | 345 | def add_edge(edge): 346 | new_nodes.append(edge) 347 | new_backreferences.append(len(new_backreferences)) 348 | 349 | add_node(AMRTokens.BOS_N) 350 | 351 | while end_i > -1: 352 | 353 | # src node 354 | add_node(graph.nodes_var[start_i], graph.backreferences[start_i]) 355 | 356 | # edges and trg nodes, interleaved 357 | nodes = graph.nodes_var[start_i+1:end_i] 358 | edges = graph.edges[start_i+1:end_i] 359 | backr = graph.backreferences[start_i+1:end_i] 360 | for n, e, b in zip(nodes, edges, backr): 361 | add_edge(e) 362 | add_node(n, b) 363 | 364 | # stop 365 | add_node(graph.nodes_var[end_i], graph.backreferences[end_i]) 366 | 367 | start_i = end_i + 1 368 | end_i = index_default(AMRTokens.STOP_N, graph.nodes_var, start_i, -1, -1) 369 | 370 | add_node(AMRTokens.EOS_N) 371 | 372 | new_graph = SemanticGraph( 373 | new_nodes, 374 | None, 375 | new_backreferences, 376 | graph.var2instance, 377 | extra=graph.extra, 378 | ) 379 | return new_graph 380 | 381 | def _add_pointer_tokens(self, graph: SemanticGraph) -> SemanticGraph: 382 | new_nodes = [] 383 | var2pointer = {} 384 | for node, backr in zip(graph.nodes_var, graph.backreferences): 385 | 386 | if node == AMRTokens.BACKR_TRG_N: 387 | node = graph.nodes_var[backr] 388 | pointer = var2pointer[node] 389 | new_nodes.append(pointer) 390 | elif node in graph.var2instance: 391 | pointer = var2pointer.setdefault(node, f"<pointer:{len(var2pointer)}>") 392 | new_nodes.append(pointer) 393 | new_nodes.append(node) 394 | else: 395 | new_nodes.append(node) 396 | 397 |
new_backreferences = list(range(len(new_nodes))) 398 | new_graph = SemanticGraph( 399 | new_nodes, 400 | None, 401 | new_backreferences, 402 | graph.var2instance, 403 | extra=graph.extra, 404 | ) 405 | return new_graph -------------------------------------------------------------------------------- /spring_amr/optim.py: -------------------------------------------------------------------------------- 1 | # taken from 2 | 3 | import math 4 | import torch 5 | from torch.optim.optimizer import Optimizer, required 6 | 7 | 8 | class RAdam(Optimizer): 9 | 10 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, degenerated_to_sgd=True): 11 | if not 0.0 <= lr: 12 | raise ValueError("Invalid learning rate: {}".format(lr)) 13 | if not 0.0 <= eps: 14 | raise ValueError("Invalid epsilon value: {}".format(eps)) 15 | if not 0.0 <= betas[0] < 1.0: 16 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 17 | if not 0.0 <= betas[1] < 1.0: 18 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 19 | 20 | self.degenerated_to_sgd = degenerated_to_sgd 21 | if isinstance(params, (list, tuple)) and len(params) > 0 and isinstance(params[0], dict): 22 | for param in params: 23 | if 'betas' in param and (param['betas'][0] != betas[0] or param['betas'][1] != betas[1]): 24 | param['buffer'] = [[None, None, None] for _ in range(10)] 25 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, 26 | buffer=[[None, None, None] for _ in range(10)]) 27 | super(RAdam, self).__init__(params, defaults) 28 | 29 | def __setstate__(self, state): 30 | super(RAdam, self).__setstate__(state) 31 | 32 | def step(self, closure=None): 33 | 34 | loss = None 35 | if closure is not None: 36 | loss = closure() 37 | 38 | for group in self.param_groups: 39 | 40 | for p in group['params']: 41 | if p.grad is None: 42 | continue 43 | grad = p.grad.data.float() 44 | if grad.is_sparse: 45 | raise RuntimeError('RAdam does not support sparse gradients') 46 | 47 | p_data_fp32 = p.data.float() 48 | 49 | state = self.state[p] 50 | 51 | if len(state) == 0: 52 | state['step'] = 0 53 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 54 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) 55 | else: 56 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 57 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) 58 | 59 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 60 | beta1, beta2 = group['betas'] 61 | 62 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 63 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 64 | 65 | state['step'] += 1 66 | buffered = group['buffer'][int(state['step'] % 10)] 67 | if state['step'] == buffered[0]: 68 | N_sma, step_size = buffered[1], buffered[2] 69 | else: 70 | buffered[0] = state['step'] 71 | beta2_t = beta2 ** state['step'] 72 | N_sma_max = 2 / (1 - beta2) - 1 73 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) 74 | buffered[1] = N_sma 75 | 76 | # more conservative since it's an approximated value 77 | if N_sma >= 5: 78 | step_size = math.sqrt( 79 | (1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / ( 80 | N_sma_max - 2)) / (1 - beta1 ** state['step']) 81 | elif self.degenerated_to_sgd: 82 | step_size = 1.0 / (1 - beta1 ** state['step']) 83 | else: 84 | step_size = -1 85 | buffered[2] = step_size 86 | 87 | # more conservative since it's an approximated value 88 | if N_sma >= 5: 89 | if group['weight_decay'] != 0: 90 | 
p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) 91 | denom = exp_avg_sq.sqrt().add_(group['eps']) 92 | p_data_fp32.addcdiv_(-step_size * group['lr'], exp_avg, denom) 93 | p.data.copy_(p_data_fp32) 94 | elif step_size > 0: 95 | if group['weight_decay'] != 0: 96 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) 97 | p_data_fp32.add_(-step_size * group['lr'], exp_avg) 98 | p.data.copy_(p_data_fp32) 99 | 100 | return loss -------------------------------------------------------------------------------- /spring_amr/penman.py: -------------------------------------------------------------------------------- 1 | from penman import load as load_, Graph, Triple 2 | from penman import loads as loads_ 3 | from penman import encode as encode_ 4 | from penman.model import Model 5 | from penman.models.noop import NoOpModel 6 | from penman.models import amr 7 | 8 | op_model = Model() 9 | noop_model = NoOpModel() 10 | amr_model = amr.model 11 | DEFAULT = op_model 12 | 13 | def _get_model(dereify): 14 | if dereify is None: 15 | return DEFAULT 16 | 17 | 18 | elif dereify: 19 | return op_model 20 | 21 | else: 22 | return noop_model 23 | 24 | def _remove_wiki(graph): 25 | metadata = graph.metadata 26 | triples = [] 27 | for t in graph.triples: 28 | v1, rel, v2 = t 29 | if rel == ':wiki': 30 | t = Triple(v1, rel, '+') 31 | triples.append(t) 32 | graph = Graph(triples) 33 | graph.metadata = metadata 34 | return graph 35 | 36 | def load(source, dereify=None, remove_wiki=False): 37 | model = _get_model(dereify) 38 | out = load_(source=source, model=model) 39 | if remove_wiki: 40 | for i in range(len(out)): 41 | out[i] = _remove_wiki(out[i]) 42 | return out 43 | 44 | def loads(string, dereify=None, remove_wiki=False): 45 | model = _get_model(dereify) 46 | out = loads_(string=string, model=model) 47 | if remove_wiki: 48 | for i in range(len(out)): 49 | out[i] = _remove_wiki(out[i]) 50 | return out 51 | 52 | def encode(g, top=None, indent=-1, compact=False): 53 | model = amr_model 54 | return encode_(g=g, top=top, indent=indent, compact=compact, model=model) -------------------------------------------------------------------------------- /spring_amr/postprocessing.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict, Counter 2 | import enum 3 | import re 4 | 5 | import networkx as nx 6 | import penman 7 | 8 | from spring_amr.penman import encode 9 | 10 | from spring_amr.linearization import AMRTokens 11 | 12 | BACKOFF = penman.Graph([ 13 | penman.Triple('d2', ':instance', 'dog'), 14 | penman.Triple('b1', ':instance', 'bark-01'), 15 | penman.Triple('b1', ':ARG0', 'd2'),]) 16 | 17 | def token_processing(tok): 18 | if tok is None: 19 | return None 20 | elif tok.isdigit(): 21 | try: 22 | return eval(tok) 23 | except: 24 | return tok 25 | elif tok.startswith('"') and (not tok.endswith('"')): 26 | return tok + '"' 27 | elif tok.endswith('"') and (not tok.startswith('"')): 28 | return '"' + tok 29 | else: 30 | return tok 31 | 32 | def decode_into_node_and_backreferences(subtoken_ids, tokenizer): 33 | rex_arg = re.compile(f"^{tokenizer.INIT}(op|snt|conj|prep)") 34 | rex_spc = re.compile(r"<(s|/s|lit|/lit|stop|unk|pad|mask)>") 35 | 36 | # get strings 37 | subtokens = [tokenizer.decoder.get(t) for t in subtoken_ids] 38 | # fix backreferences 39 | subtoken_backreferences = [max(t - len(tokenizer.encoder), -1) for t in subtoken_ids] 40 | # strip padding 41 | subtokens, subtoken_backreferences = zip( 42 | *[(s, b) for s, 
b in zip(subtokens, subtoken_backreferences) if s != (tokenizer.INIT + '<pad>')]) 43 | 44 | # subword collapse 45 | tokens = [] 46 | backreferences = [] 47 | subword_to_token_map = {} 48 | current_token_i = 0 49 | for subw_i, (subw_backr, subtok) in enumerate(zip(subtoken_backreferences, subtokens)): 50 | subword_to_token_map[subw_i] = current_token_i 51 | 52 | # if empty you cannot do anything but add a new word 53 | if not tokens: 54 | tokens.append(subtok.lstrip(tokenizer.INIT)) 55 | backreferences.append(-1) 56 | current_token_i += 1 57 | 58 | # a backreference can't be split 59 | elif subw_backr > -1: 60 | tokens.append(None) 61 | backreferences.append(subword_to_token_map[subw_backr]) 62 | current_token_i += 1 63 | 64 | # after a special token release 65 | elif isinstance(tokens[-1], str) and rex_spc.match(tokens[-1]): 66 | tokens.append(subtok.lstrip(tokenizer.INIT)) 67 | backreferences.append(-1) 68 | current_token_i += 1 69 | 70 | # after a subtoken ':' (which should be followed by the rest of the edge) ignore tokenizer.INIT 71 | # TODO: this is an ugly patch due to the fact that BART tokenizer splits after ':' 72 | elif (tokens[-1] == ':') and rex_arg.match(subtok): 73 | tokens[-1] = tokens[-1] + subtok[1:] 74 | 75 | # leading tokenizer.INIT 76 | elif subtok.startswith(tokenizer.INIT): 77 | tokens.append(subtok.lstrip(tokenizer.INIT)) 78 | backreferences.append(-1) 79 | current_token_i += 1 80 | 81 | # very ugly patch for some cases in which tokenizer.INIT is not in the following token to the edge 82 | elif isinstance(tokens[-1], str) and tokens[-1].startswith(':') and tokens[-1][-1].isdigit() and (subtok != '-of'): 83 | tokens.append(subtok.lstrip(tokenizer.INIT)) 84 | backreferences.append(-1) 85 | current_token_i += 1 86 | 87 | # in any other case attach to the previous 88 | else: 89 | tokens[-1] = tokens[-1] + subtok 90 | 91 | # strip INIT and fix byte-level 92 | tokens = [tokenizer.convert_tokens_to_string(list(t)).lstrip() if isinstance(t, str) else t for t in tokens] 93 | # tokens = [t.replace(tokenizer.INIT, '') if isinstance(t, str) else t for t in tokens] 94 | 95 | # unks are substituted with thing 96 | tokens = [t if t != '<unk>' else 'thing' for t in tokens] 97 | 98 | old_tokens = tokens 99 | old_backreferences = backreferences 100 | 101 | # Barack Obama -> "Barack Obama" 102 | tokens = [] 103 | backreferences = [] 104 | token_to_token_map = {} 105 | start_search = 0 106 | removed = 0 107 | while True: 108 | try: 109 | 110 | lit_start = old_tokens.index('<lit>', start_search) 111 | token_addition = old_tokens[start_search:lit_start] 112 | for i, t in enumerate(token_addition, start=start_search): 113 | token_to_token_map[i] = i - removed 114 | tokens += token_addition 115 | 116 | backreferences_addition = [token_to_token_map[b] if b > -1 else -1 for b in 117 | old_backreferences[start_search:lit_start]] 118 | backreferences += backreferences_addition 119 | 120 | lit_end = min(lit_start + 2, len(old_tokens) - 1) 121 | 122 | while lit_end < len(old_tokens): 123 | old_tok = old_tokens[lit_end] 124 | 125 | if isinstance(old_tok, str) and ( 126 | (old_tok.startswith(':') and len(old_tok) > 3) or (old_tok == '<stop>')): 127 | res_tok = old_tokens[lit_start + 1:lit_end] 128 | for i in range(lit_start, lit_end): 129 | token_to_token_map[i] = len(tokens) 130 | 131 | # Remove possible wrong None 132 | res = old_tokens[lit_start+1:lit_end] 133 | res = [str(r) for r in res if r is not None] 134 | res = '"' + '_'.join(res) + '"' 135 | 136 | removed += len(res_tok) 137 | start_search = lit_end 138 | tokens += [res, old_tok] 139 | backreferences += [-1, -1] 140 | break 141 | 142 | elif old_tok == '</lit>': 143 | res_tok = old_tokens[lit_start + 1:lit_end] 144 | for i in range(lit_start, lit_end + 1): 145 | token_to_token_map[i] = len(tokens) 146 | 147 | # Remove possible wrong None 148 | res = old_tokens[lit_start+1:lit_end] 149 | res = [str(r) for r in res if r is not None] 150 | res = '"' + '_'.join(res) + '"' 151 | 152 | removed += len(res_tok) + 1 153 | start_search = lit_end + 1 154 | tokens.append(res) 155 | backreferences.append(-1) 156 | break 157 | 158 | else: 159 | lit_end += 1 160 | start_search = lit_end 161 | 162 | except ValueError: 163 | token_addition = old_tokens[start_search:] 164 | for i, t in enumerate(token_addition, start=start_search): 165 | token_to_token_map[i] = i - removed 166 | backreferences_addition = [token_to_token_map[b] if b > -1 else b for b in 167 | old_backreferences[start_search:]] 168 | tokens += token_addition 169 | backreferences += backreferences_addition 170 | break 171 | 172 | tokens = [token_processing(t) for t in tokens] 173 | 174 | shift = 1 175 | if tokens[1] == '<s>': 176 | shift = 2 177 | 178 | tokens = tokens[shift:] 179 | backreferences = [b if b == -1 else b - shift for b in backreferences[shift:]] 180 | 181 | if tokens[-1] == '</s>': 182 | tokens.pop() 183 | backreferences.pop() 184 | 185 | return tokens, backreferences
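# [Illustrative sketch, not part of the repository] A toy picture of what the
# decoding above feeds into: with use_pointer_tokens, the model emits node
# sequences like the following hypothetical one,
#
#   ['<pointer:0>', 'dog', ':ARG0-of', '<pointer:1>', 'bark-01',
#    ':ARG1', '<pointer:0>', '<stop>']
#
# restore_backreferences_from_pointers (defined later in this module) drops
# each pointer token, remembers where its first mention landed, and turns any
# repeated pointer into a None node whose backreference index points at that
# first mention, so the graph builder can reconnect the reentrant variable.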
186 | 187 | 188 | def index_of(element, iterable, default=None, start=None, end=None): 189 | if not callable(element): 190 | def check(x): 191 | return element == x 192 | else: 193 | check = element 194 | if start is None: 195 | start = 0 196 | if end is None: 197 | end = len(iterable) 198 | item = start 199 | while item < end: 200 | if check(iterable[item]): 201 | return item 202 | item += 1 203 | return default 204 | 205 | 206 | def separate_edges_nodes(edges_nodes_slice, *other): 207 | is_arg = lambda x: isinstance(x, str) and x.startswith(':') 208 | start = 0 209 | edges = [] 210 | nodes = [] 211 | l = len(edges_nodes_slice) 212 | while start < l: 213 | edge_index = index_of( 214 | is_arg, 215 | edges_nodes_slice, 216 | start=start) 217 | if edge_index is None or edge_index == (l - 1): 218 | break 219 | if is_arg(edges_nodes_slice[edge_index + 1]): 220 | start = edge_index + 1 221 | continue 222 | edges.append(edge_index) 223 | nodes.append(edge_index + 1) 224 | start = edge_index + 2 225 | ret = [] 226 | for oth in other: 227 | edges_oth = [oth[i] for i in edges] 228 | nodes_oth = [oth[i] for i in nodes] 229 | ret.append((edges_oth, nodes_oth)) 230 | return ret 231 | 232 | def _split_name_ops(graph): 233 | # identify name triples 234 | name_vars = {} 235 | for i, (v1, rel, v2) in enumerate(graph.triples): 236 | if rel == ':instance' and v2 == 'name': 237 | name_vars[v1] = 1 238 | 239 | # check if they have ops 240 | name_vars_to_ops = defaultdict(list) 241 | for i, (v1, rel, v2) in enumerate(graph.triples): 242 | if v1 in name_vars and rel.startswith(':op'): 243 | name_vars_to_ops[v1].append((i, rel, v2.strip('"'))) 244 | 245 | triples = graph.triples.copy() 246 | for nv, ops in name_vars_to_ops.items(): 247 | ops = sorted(ops, key=lambda x: int(x[1][3:])) 248 | idx, _, lits = zip(*ops) 249 | for i in idx: 250 | triples[i] = None 251 | 252 | lits = ['"' + l + '"' for lit in lits for l in lit.split('_')] 253 | 254 | tt = [] 255 | for i, l in enumerate(lits, start=1): 256 | rel = ':op' + str(i) 257 | tt.append(penman.Triple(nv, rel, l)) 258 | 259 | triples[min(idx)] = tt 260 | 261 | triples = [t if isinstance(t, list) else [t] for t in triples if t is not None]
def _split_name_ops(graph):
    # identify name triples
    name_vars = {}
    for i, (v1, rel, v2) in enumerate(graph.triples):
        if rel == ':instance' and v2 == 'name':
            name_vars[v1] = 1

    # check if they have ops
    name_vars_to_ops = defaultdict(list)
    for i, (v1, rel, v2) in enumerate(graph.triples):
        if v1 in name_vars and rel.startswith(':op'):
            name_vars_to_ops[v1].append((i, rel, v2.strip('"')))

    triples = graph.triples.copy()
    for nv, ops in name_vars_to_ops.items():
        ops = sorted(ops, key=lambda x: int(x[1][3:]))
        idx, _, lits = zip(*ops)
        for i in idx:
            triples[i] = None

        lits = ['"' + l + '"' for lit in lits for l in lit.split('_')]

        tt = []
        for i, l in enumerate(lits, start=1):
            rel = ':op' + str(i)
            tt.append(penman.Triple(nv, rel, l))

        triples[min(idx)] = tt

    triples = [t if isinstance(t, list) else [t] for t in triples if t is not None]
    triples = [t for tt in triples for t in tt]

    graph_ = penman.Graph(triples)
    graph_.metadata = graph.metadata
    return graph_
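# NOTE (editorial example, not in the original source): a sketch of what
# _split_name_ops does to a collapsed name literal, assuming the `penman` package:
#
#     g = penman.decode('(c / city :name (n / name :op1 "New_York"))')
#     g2 = _split_name_ops(g)
#     # the single op triple is split on '_':
#     # -> (n, :op1, "New") and (n, :op2, "York")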
def _reconstruct_graph_from_nodes(nodes, backreferences):
    triples = []
    triples_added = set()

    variable2index = {}
    index2variable = {}
    start_index = 0

    cnt = defaultdict(Counter)

    while start_index < len(nodes):
        stop_index = index_of('<stop>', nodes, default=len(nodes) + 1, start=start_index)
        old_start_index = start_index
        start_index = stop_index + 1

        src_node, src_backr = nodes[old_start_index], backreferences[old_start_index]

        if src_node == '<stop>':
            continue

        trg_nodes_edges = nodes[old_start_index:stop_index]
        trg_nodes_edges_backr = backreferences[old_start_index:stop_index]
        trg_nodes_edges_indices = list(range(old_start_index, stop_index))

        if isinstance(src_node, str):
            if src_node in ('<s>', '</s>', '<pad>'):
                continue
            elif ('/' in src_node) or (':' in src_node) or ('(' in src_node) or (')' in src_node):
                src_node = 'thing'

        if src_node is not None:
            src_node = str(src_node)
            src_var = src_node[0].lower()
            if src_var not in 'abcdefghijklmnopqrstuvwxyz':
                src_var = 'x'
            src_var = f'{src_var}{len(variable2index)}'
            src_var_i = old_start_index
            variable2index[src_var] = src_var_i
            index2variable[src_var_i] = src_var
            triple = penman.Triple(src_var, ':instance', src_node)
            if triple not in triples_added:
                triples.append(triple)
                triples_added.add(triple)
        else:
            if src_backr in index2variable:
                src_var = index2variable[src_backr]

        # more resilient logic here
        (trg_edges, trg_nodes), (_, trg_nodes_backr), (_, trg_nodes_indices) = \
            separate_edges_nodes(
                trg_nodes_edges,
                trg_nodes_edges,
                trg_nodes_edges_backr,
                trg_nodes_edges_indices)

        for n, e, nb, ni in zip(trg_nodes, trg_edges, trg_nodes_backr, trg_nodes_indices):

            if isinstance(n, str) and n.startswith(':'):
                continue
            if isinstance(n, str) and n.startswith('<') and n.endswith('>'):
                continue
            if e == ':li':
                pass
            elif len(e) < 4 or (not e.startswith(':')):
                continue

            # same edge more than once
            num = cnt[src_var][e]
            if num:
                if e.startswith(':op') or e.startswith(':snt'):
                    continue
                elif num > 3:
                    continue

            if n is None:
                if nb not in index2variable:
                    continue
                trg_var = index2variable[nb]
                trg = trg_var
            elif e == ':mode':
                trg = n
            elif (not isinstance(n, str)) or re.match(r"^[+-]?\d+\.?\d*$", n) or (n == '-') or (n == '+'):
                trg = str(n)
            elif n.startswith('"') and n.endswith('"') and len(n) > 2:
                trg = '"' + n.replace('"', '') + '"'
            elif ('/' in n) or (':' in n) or ('(' in n) or (')' in n) or ('=' in n):
                trg = f'"{n}"'
            elif n == '"':
                continue
            elif (n.startswith('"') and (not n.endswith('"'))) or (not n.startswith('"') and n.endswith('"')) or ('"' in n):
                trg = '"' + n.replace('"', '') + '"'
            else:
                trg_var = n[0].lower()
                if trg_var not in 'abcdefghijklmnopqrstuvwxyz':
                    trg_var = 'x'
                trg_var = f'{trg_var}{len(variable2index)}'
                trg_var_i = ni
                variable2index[trg_var] = trg_var_i
                index2variable[trg_var_i] = trg_var
                triple = penman.Triple(trg_var, ':instance', n)
                if triple not in triples_added:
                    triples.append(triple)
                    triples_added.add(triple)
                trg = trg_var

            triple = penman.Triple(src_var, e, trg)
            if triple not in triples_added:
                triples.append(triple)
                triples_added.add(triple)

            cnt[src_var][e] += 1

    return penman.Graph(triples)


def build_graph(nodes, backreferences, restore_name_ops=False):
    graph = _reconstruct_graph_from_nodes(nodes, backreferences)
    if restore_name_ops:
        graph = _split_name_ops(graph)
    return graph


class ParsedStatus(enum.Enum):
    OK = 0
    FIXED = 1
    BACKOFF = 2


def connect_graph_if_not_connected(graph):

    try:
        encoded = encode(graph)
        return graph, ParsedStatus.OK
    except Exception:
        pass

    nxgraph = nx.MultiGraph()
    variables = graph.variables()
    for v1, _, v2 in graph.triples:
        if v1 in variables and v2 in variables:
            nxgraph.add_edge(v1, v2)
        elif v1 in variables:
            nxgraph.add_edge(v1, v1)

    triples = graph.triples.copy()
    new_triples = []
    addition = f'a{len(variables) + 1}'
    triples.append(penman.Triple(addition, ':instance', 'and'))
    for i, conn_set in enumerate(nx.connected_components(nxgraph), start=1):
        edge = f':op{i}'
        conn_set = sorted(conn_set, key=lambda x: int(x[1:]))
        conn_set = [c for c in conn_set if c in variables]
        node = conn_set[0]
        new_triples.append(penman.Triple(addition, edge, node))
    triples = new_triples + triples
    metadata = graph.metadata
    graph = penman.Graph(triples)
    graph.metadata.update(metadata)
    encode(graph)

    return graph, ParsedStatus.FIXED


def restore_backreferences_from_pointers(nodes):
    new_nodes, new_backreferences = [], []
    prev_pointer = None
    pointer2i = {}
    for n in nodes:
        is_pointer = isinstance(n, str) and n.startswith('<pointer:')

        if not is_pointer:
            if prev_pointer is not None:
                if prev_pointer in pointer2i:
                    new_nodes.append(None)
                    new_backreferences.append(pointer2i[prev_pointer])
                    new_nodes.append(n)
                    new_backreferences.append(-1)

                else:
                    pointer2i[prev_pointer] = len(new_nodes)
                    new_nodes.append(n)
                    new_backreferences.append(-1)
            else:
                new_nodes.append(n)
                new_backreferences.append(-1)

            prev_pointer = None
        else:
            prev_pointer = n
    return new_nodes, new_backreferences
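# NOTE (editorial example, not in the original source): with pointer tokens, a
# re-mention becomes a None node whose backreference is the position of the
# first mention:
#
#     nodes = ['<pointer:0>', 'want-01', ':ARG0', '<pointer:1>', 'boy',
#              ':ARG1', '<pointer:1>', '<stop>']
#     restore_backreferences_from_pointers(nodes)
#     # -> (['want-01', ':ARG0', 'boy', ':ARG1', None, '<stop>'],
#     #     [-1, -1, -1, -1, 2, -1])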
--------------------------------------------------------------------------------
/spring_amr/tokenization_bart.py:
--------------------------------------------------------------------------------
import copy
import sys
from pathlib import Path

import penman
import regex as re
import torch
from transformers import BartTokenizer

from spring_amr import ROOT, postprocessing
from spring_amr.linearization import AMRTokens, AMRLinearizer
from spring_amr.penman import encode


class AMRBartTokenizer(BartTokenizer):

    INIT = 'Ġ'

    ADDITIONAL = [
        AMRTokens.PNTR_N,
        AMRTokens.STOP_N,
        AMRTokens.LIT_START,
        AMRTokens.LIT_END,
        AMRTokens.BACKR_SRC_N,
        AMRTokens.BACKR_TRG_N,
    ]

    def __init__(self, *args, use_pointer_tokens=False, collapse_name_ops=False, **kwargs):
        super().__init__(*args, **kwargs)
        self.patterns = re.compile(
            r""" ?<[a-z]+:?\d*>| ?:[^\s]+|'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
        self.linearizer = AMRLinearizer(use_pointer_tokens=use_pointer_tokens, collapse_name_ops=collapse_name_ops)
        self.use_pointer_tokens = use_pointer_tokens
        self.collapse_name_ops = collapse_name_ops
        self.recategorizations = set()
        self.modified = 0

    @classmethod
    def from_pretrained(cls, pretrained_model_path, pred_min=5, *args, **kwargs):
        inst = super().from_pretrained(pretrained_model_path, *args, **kwargs)
        inst.init_amr_vocabulary(pred_min=pred_min)
        return inst

    def init_amr_vocabulary(self, pred_min=5):
        for tok in [self.bos_token, self.eos_token, self.pad_token, '<mask>', '<unk>']:
            ntok = self.INIT + tok
            i = self.encoder[tok]
            self.decoder[i] = ntok
            del self.encoder[tok]
            self.encoder[ntok] = i

        tokens = []
        for line in Path(ROOT / 'data/vocab/predicates.txt').read_text().strip().splitlines():
            tok, count = line.split()
            if int(count) >= pred_min:
                tokens.append(tok)

        for tok in Path(ROOT / 'data/vocab/additions.txt').read_text().strip().splitlines():
            tokens.append(tok)

        for tok in Path(ROOT / 'data/vocab/recategorizations.txt').read_text().strip().splitlines():
            if not tok.startswith('_'):
                self.recategorizations.add(tok)
            tokens.append(tok)

        if self.use_pointer_tokens:
            for cnt in range(512):
                tokens.append(f'<pointer:{cnt}>')

        tokens += self.ADDITIONAL
        tokens = [self.INIT + t if t[0] not in ('_', '-') else t for t in tokens]
        tokens = [t for t in tokens if t not in self.encoder]
        self.old_enc_size = old_enc_size = len(self.encoder)
        for i, t in enumerate(tokens, start=old_enc_size):
            self.encoder[t] = i

        self.encoder = {k: i for i, (k, v) in enumerate(sorted(self.encoder.items(), key=lambda x: x[1]))}
        self.decoder = {v: k for k, v in sorted(self.encoder.items(), key=lambda x: x[1])}
        self.modified = len(tokens)

        self.bos_token = self.INIT + '<s>'
        self.pad_token = self.INIT + '<pad>'
        self.eos_token = self.INIT + '</s>'
        self.unk_token = self.INIT + '<unk>'

    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
        output = [self.bos_token_id] + token_ids_0 + [self.eos_token_id]
        if token_ids_1 is None:
            return output
        return output + [self.eos_token_id] + token_ids_1 + [self.eos_token_id]
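    # NOTE (editorial example, not in the original source): this keeps BART's
    # usual <s> ... </s> wrapping, so with BART's conventional bos/eos ids 0 and 2:
    #
    #     build_inputs_with_special_tokens([10, 11])    # -> [0, 10, 11, 2]
    #     build_inputs_with_special_tokens([10], [11])  # -> [0, 10, 2, 2, 11, 2]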
    def _tokenize(self, text):
        """Tokenize a string. Modified to handle sentences with recategorization pointers."""
        bpe_tokens = []
        for tok_span in text.lstrip().split(' '):
            tok_span = tok_span.strip()
            recats = tok_span.rsplit('_', 1)
            if len(recats) == 2 and recats[0] in self.recategorizations and ('_' + recats[1]) in self.encoder:
                bpe_tokens.extend([self.INIT + recats[0], '_' + recats[1]])
            else:
                for token in re.findall(self.pat, ' ' + tok_span):
                    token = "".join(
                        self.byte_encoder[b] for b in token.encode("utf-8")
                    )  # Maps all our bytes to unicode strings, avoiding control tokens of the BPE (spaces in our case)
                    bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(" "))

        return bpe_tokens

    def _tok_bpe(self, token, add_space=True):
        tokk = []
        tok = token.strip()
        recats = tok.rsplit('_', 1)
        if len(recats) == 2 and recats[0] in self.recategorizations and ('_' + recats[1]) in self.encoder:
            tokk.extend([self.INIT + recats[0], '_' + recats[1]])
        else:
            for tok in self.patterns.findall(' ' + token):
                tok = "".join(self.byte_encoder[b] for b in tok.encode("utf-8"))
                toks = self.bpe(tok).split(' ')
                tokk.extend(toks)
        return tokk

    def _get_nodes_and_backreferences(self, graph):
        lin = self.linearizer.linearize(graph)
        linearized_nodes, backreferences = lin.nodes, lin.backreferences
        return linearized_nodes, backreferences

    def tokenize_amr(self, graph):
        linearized_nodes, backreferences = self._get_nodes_and_backreferences(graph)

        bpe_tokens = []
        bpe_backreferences = []
        counter = 0

        for i, (backr, tokk) in enumerate(zip(backreferences, linearized_nodes)):
            is_in_enc = self.INIT + tokk in self.encoder
            is_rel = tokk.startswith(':') and len(tokk) > 1
            is_spc = tokk.startswith('<') and tokk.endswith('>')
            is_of = tokk.startswith(':') and tokk.endswith('-of')
            is_frame = re.match(r'.+-\d\d', tokk) is not None

            if tokk.startswith('"') and tokk.endswith('"'):
                tokk = tokk[1:-1].replace('_', ' ')
                bpe_toks = [self.INIT + AMRTokens.LIT_START]
                bpe_toks += self._tok_bpe(tokk, add_space=True)
                bpe_toks.append(self.INIT + AMRTokens.LIT_END)

            elif is_rel or is_spc or is_frame or is_of:
                if is_in_enc:
                    bpe_toks = [self.INIT + tokk]
                elif is_frame:
                    bpe_toks = self._tok_bpe(tokk[:-3], add_space=True) + [tokk[-3:]]
                elif is_of:
                    rel = tokk[:-3]
                    if self.INIT + rel in self.encoder:
                        bpe_toks = [self.INIT + rel, '-of']
                    else:
                        bpe_toks = [self.INIT + ':'] + self._tok_bpe(rel[1:], add_space=True) + ['-of']
                elif is_rel:
                    bpe_toks = [self.INIT + ':'] + self._tok_bpe(tokk[1:], add_space=True)
                else:
                    raise ValueError(f'unrecognized AMR token: {tokk}')

            else:
                if is_in_enc:
                    bpe_toks = [self.INIT + tokk]
                else:
                    bpe_toks = self._tok_bpe(tokk, add_space=True)

            bpe_tokens.append(bpe_toks)

            if i == backr:
                bpe_backr = list(range(counter, counter + len(bpe_toks)))
                counter += len(bpe_toks)
                bpe_backreferences.append(bpe_backr)
            else:
                bpe_backreferences.append(bpe_backreferences[backr][0:1])
                counter += 1
        bpe_tokens = [b for bb in bpe_tokens for b in bb]
        bpe_token_ids = [self.encoder.get(b, self.unk_token_id) for b in bpe_tokens]
        bpe_backreferences = [b for bb in bpe_backreferences for b in bb]
        return bpe_tokens, bpe_token_ids, bpe_backreferences
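    # NOTE (editorial example, not in the original source): a rough picture of
    # the branches above, assuming AMRTokens.LIT_START/LIT_END are '<lit>'/'</lit>'.
    # A frame stays whole if it was added to the vocabulary from predicates.txt,
    # otherwise its stem is BPE-split and the sense suffix kept intact; literals
    # are unquoted, de-underscored, and wrapped:
    #
    #     'want-01'     -> ['Ġwant-01']                        (if in the AMR vocab)
    #     'foo-99'      -> ['Ġfoo', '-99']                     (stem BPE + suffix)
    #     '"New_York"'  -> ['Ġ<lit>', 'ĠNew', 'ĠYork', 'Ġ</lit>']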
    def batch_encode_sentences(self, sentences, device=torch.device('cpu')):
        sentences = [s for s in sentences]
        extra = {'sentences': sentences}
        batch = super().batch_encode_plus(sentences, return_tensors='pt', pad_to_max_length=True)
        batch = {k: v.to(device) for k, v in batch.items()}
        return batch, extra

    def linearize(self, graph):
        shift = len(self.encoder)
        tokens, token_ids, backreferences = self.tokenize_amr(graph)
        extra = {'linearized_graphs': tokens, 'graphs': graph}
        token_uni_ids = \
            [idx if i == b else b + shift for i, (idx, b) in enumerate(zip(token_ids, backreferences))]
        if tokens[-1] != (self.INIT + AMRTokens.EOS_N):
            tokens.append(self.INIT + AMRTokens.EOS_N)
            token_ids.append(self.eos_token_id)
            token_uni_ids.append(self.eos_token_id)
            backreferences.append(len(backreferences))
        return token_uni_ids, extra

    def batch_encode_graphs(self, graphs, device=torch.device('cpu')):
        linearized, extras = zip(*[self.linearize(g) for g in graphs])
        return self.batch_encode_graphs_from_linearized(linearized, extras, device=device)

    def batch_encode_graphs_from_linearized(self, linearized, extras=None, device=torch.device('cpu')):
        if extras is not None:
            batch_extra = {'linearized_graphs': [], 'graphs': []}
            for extra in extras:
                batch_extra['graphs'].append(extra['graphs'])
                batch_extra['linearized_graphs'].append(extra['linearized_graphs'])
        else:
            batch_extra = {}
        maxlen = 0
        batch = []
        for token_uni_ids in linearized:
            maxlen = max(len(token_uni_ids), maxlen)
            batch.append(token_uni_ids)
        batch = [x + [self.pad_token_id] * (maxlen - len(x)) for x in batch]
        batch = torch.tensor(batch).to(device)
        batch = {'decoder_input_ids': batch[:, :-1], 'lm_labels': batch[:, 1:]}
        return batch, batch_extra
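    # NOTE (editorial example, not in the original source): the one-position
    # shift above is standard teacher forcing; for one padded row the two views are:
    #
    #     row               = [bos, w1, w2, eos, pad]
    #     decoder_input_ids = [bos, w1, w2, eos]        # row[:, :-1]
    #     lm_labels         = [w1, w2, eos, pad]        # row[:, 1:]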
    def decode_amr(self, tokens, restore_name_ops=False):
        try:
            nodes, backreferences = postprocessing.decode_into_node_and_backreferences(tokens, self)
        except Exception as e:
            print('Decoding failure:', file=sys.stderr)
            print(e, file=sys.stderr)
            return postprocessing.BACKOFF, postprocessing.ParsedStatus.BACKOFF, (None, None)
        if self.use_pointer_tokens:
            nodes, backreferences = postprocessing.restore_backreferences_from_pointers(nodes)
        try:
            graph_ = graph = postprocessing.build_graph(nodes, backreferences, restore_name_ops=restore_name_ops)
        except Exception as e:
            print('Building failure:', file=sys.stderr)
            print(nodes, file=sys.stderr)
            print(backreferences, file=sys.stderr)
            print(e, file=sys.stderr)
            return postprocessing.BACKOFF, postprocessing.ParsedStatus.BACKOFF, (None, None)
        try:
            graph, status = postprocessing.connect_graph_if_not_connected(graph)
            if status == postprocessing.ParsedStatus.BACKOFF:
                print('Reconnection 1 failure:', file=sys.stderr)
                print(nodes, file=sys.stderr)
                print(backreferences, file=sys.stderr)
                print(graph_, file=sys.stderr)
            return graph, status, (nodes, backreferences)
        except Exception as e:
            print('Reconnection 2 failure:', file=sys.stderr)
            print(e, file=sys.stderr)
            print(nodes, file=sys.stderr)
            print(backreferences, file=sys.stderr)
            print(graph_, file=sys.stderr)
            return postprocessing.BACKOFF, postprocessing.ParsedStatus.BACKOFF, (nodes, backreferences)


class PENMANBartTokenizer(AMRBartTokenizer):

    def __init__(self, *args, raw_graph=False, **kwargs):
        super().__init__(*args, **kwargs)
        self.linearizer = None
        self.remove_pars = False
        self.raw_graph = raw_graph

    def _tokenize_encoded_graph(self, encoded):
        linearized = re.sub(r"(\".+?\")", r' \1 ', encoded)
        pieces = []
        for piece in linearized.split():
            if piece.startswith('"') and piece.endswith('"'):
                pieces.append(piece)
            else:
                piece = piece.replace('(', ' ( ')
                piece = piece.replace(')', ' ) ')
                piece = piece.replace(':', ' :')
                piece = piece.replace('/', ' / ')
                piece = piece.strip()
                pieces.append(piece)
        linearized = re.sub(r'\s+', ' ', ' '.join(pieces)).strip()
        linearized_nodes = [AMRTokens.BOS_N] + linearized.split(' ')
        return linearized_nodes

    def tokenize_amr(self, graph):
        if self.raw_graph:
            graph_ = copy.deepcopy(graph)
            graph_.metadata = {}
            linearized = penman.encode(graph_)
            linearized = re.sub(r"\s+", ' ', linearized)
            bpe_tokens = [self.bos_token] + self._tokenize(linearized)[:1022]
            bpe_token_ids = [self.encoder.get(b, self.unk_token_id) for b in bpe_tokens]
            bpe_backreferences = list(range(len(bpe_token_ids)))
            return bpe_tokens, bpe_token_ids, bpe_backreferences
        else:
            return super().tokenize_amr(graph)

    def _get_nodes_and_backreferences(self, graph):
        graph_ = copy.deepcopy(graph)
        graph_.metadata = {}
        linearized = penman.encode(graph_)
        linearized_nodes = self._tokenize_encoded_graph(linearized)

        if self.use_pointer_tokens:
            remap = {}
            for i in range(1, len(linearized_nodes)):
                nxt = linearized_nodes[i]
                lst = linearized_nodes[i - 1]
                if nxt == '/':
                    remap[lst] = f'<pointer:{len(remap)}>'
            i = 1
            linearized_nodes_ = [linearized_nodes[0]]
            while i < (len(linearized_nodes)):
                nxt = linearized_nodes[i]
                lst = linearized_nodes_[-1]
                if nxt in remap:
                    if lst == '(' and linearized_nodes[i + 1] == '/':
                        nxt = remap[nxt]
                        i += 1
                    elif lst.startswith(':'):
                        nxt = remap[nxt]
                linearized_nodes_.append(nxt)
                i += 1
            linearized_nodes = linearized_nodes_
        if self.remove_pars:
            linearized_nodes = [n for n in linearized_nodes if n != '(']
        backreferences = list(range(len(linearized_nodes)))
        return linearized_nodes, backreferences
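    # NOTE (editorial example, not in the original source): the remapping above
    # replaces each PENMAN variable with a pointer token and drops the '/' that
    # follows its defining occurrence:
    #
    #     (z0 / want-01 :ARG0 (z1 / boy))
    #     -> ['<s>', '(', '<pointer:0>', 'want-01', ':ARG0',
    #         '(', '<pointer:1>', 'boy', ')', ')']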
    def _classify(self, node):
        if not isinstance(node, str):
            return "CONST"
        elif node == 'i':
            return "I"
        elif re.match(r'^[a-z]\d*$', node) is not None:
            return "VAR"
        elif node[0].isdigit():
            return "CONST"
        elif node.startswith('"') and node.endswith('"'):
            return "CONST"
        elif node in ('+', '-'):
            return "CONST"
        elif node == ':mode':
            return 'MODE'
        elif node.startswith(':'):
            return "EDGE"
        elif node in ['/', '(', ')']:
            return node
        elif node[0].isalpha():
            for char in (',', ':', '/', '(', ')', '.', '!', '?', '\\'):
                if char in node:
                    return "CONST"
            return "INST"
        else:
            return 'CONST'

    def _fix_and_make_graph(self, nodes):

        nodes_ = []
        for n in nodes:
            if isinstance(n, str):
                if n.startswith('<') and n.endswith('>') and (not n.startswith('<pointer:')):
                    pass
                else:
                    nodes_.append(n)
            else:
                nodes_.append(n)
        nodes = nodes_

        i = 0
        nodes_ = []
        while i < len(nodes):
            nxt = nodes[i]
            pst = None
            if isinstance(nxt, str) and nxt.startswith('<pointer:'):
                e = nxt.find('>')
                if e != len(nxt) - 1:
                    pst = nxt[e + 1:]
                    nxt = nxt[:e + 1]
                nodes_.append(nxt)
                if pst is not None:
                    nodes_.append(pst)
            else:
                nodes_.append(nxt)
            i += 1
        nodes = nodes_

        i = 1
        nodes_ = [nodes[0]]
        while i < len(nodes):
            nxt = nodes[i]
            if isinstance(nxt, str) and nxt.startswith('<pointer:'):
                pass  # [gap in the source dump: the body of this rewrite, and the
                # further token-repair passes that end by assembling the repaired
                # tokens into the string `linearized`, are missing from here down
                # to the commented-out block below]

        """
            if i > 0:
                line = line[:i].strip()
                break
        old_line = line
        while True:
            open_count = len(re.findall(r'\(', line))
            close_count = len(re.findall(r'\)', line))
            if open_count > close_count:
                line += ')' * (open_count - close_count)
            elif close_count > open_count:
                for i in range(close_count - open_count):
                    line = line.rstrip(')')
                line = line.rstrip(' ')
            if old_line == line:
                break
            old_line = line
        """

        graph = penman.decode(linearized + ' ')
        triples = []
        newvars = 2000
        for triple in graph.triples:
            x, rel, y = triple
            if x is None:
                pass
            elif rel == ':instance' and y is None:
                triples.append(penman.Triple(x, rel, 'thing'))
            elif y is None:
                var = f'z{newvars}'
                newvars += 1
                triples.append(penman.Triple(x, rel, var))
                triples.append(penman.Triple(var, ':instance', 'thing'))
            else:
                triples.append(triple)
        graph = penman.Graph(triples)
        linearized = encode(graph)

        def fix_text(linearized=linearized):
            n = 0

            def _repl1(match):
                nonlocal n
                out = match.group(1) + match.group(2) + str(3000 + n) + ' / ' + match.group(2) + match.group(3)
                n += 1
                return out
            linearized = re.sub(r'(\(\s?)([a-z])([^\/:\)]+[:\)])', _repl1, linearized,
                                flags=re.IGNORECASE | re.MULTILINE)

            def _repl2(match):
                return match.group(1)
            linearized = re.sub(r'(\(\s*[a-z][\d+]\s*\/\s*[^\s\)\(:\/]+\s*)((?:/\s*[^\s\)\(:\/]+\s*)+)', _repl2,
                                linearized,
                                flags=re.IGNORECASE | re.MULTILINE)

            # adds a ':' to args w/o it
            linearized = re.sub(r'([^:])(ARG)', r'\1 :\2', linearized)

            # removes edges with no node
            # linearized = re.sub(r':[^\s\)\(:\/]+?\s*\)', ')', linearized, flags=re.MULTILINE)

            return linearized

        linearized = fix_text(linearized)

        g = penman.decode(linearized)
        return g
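    # NOTE (editorial example, not in the original source): the `newvars` loop
    # above patches dangling triples so the graph still decodes; a relation whose
    # target was lost becomes a fresh 'thing' node:
    #
    #     (x, ':ARG1', None)
    #     -> (x, ':ARG1', 'z2000') plus ('z2000', ':instance', 'thing')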
    def decode_amr(self, tokens, restore_name_ops=None):
        try:
            if self.raw_graph:
                nodes = self._tokenize_encoded_graph(self.decode(tokens))
                backreferences = list(range(len(nodes)))
            else:
                nodes, backreferences = postprocessing.decode_into_node_and_backreferences(tokens, self)
            nodes_ = nodes
        except Exception as e:
            print('Decoding failure:', file=sys.stderr)
            print(e, file=sys.stderr)
            return postprocessing.BACKOFF, postprocessing.ParsedStatus.BACKOFF, (None, None)
        try:
            graph_ = graph = self._fix_and_make_graph(nodes)
            if self.collapse_name_ops:
                graph_ = graph = postprocessing._split_name_ops(graph)
        except Exception as e:
            print('Building failure:', file=sys.stderr)
            print(nodes, file=sys.stderr)
            print(backreferences, file=sys.stderr)
            print(e, file=sys.stderr)
            return postprocessing.BACKOFF, postprocessing.ParsedStatus.BACKOFF, (None, None)
        try:
            graph, status = postprocessing.connect_graph_if_not_connected(graph)
            if status == postprocessing.ParsedStatus.BACKOFF:
                print('Reconnection 1 failure:', file=sys.stderr)
                print(nodes, file=sys.stderr)
                print(backreferences, file=sys.stderr)
                print(graph_, file=sys.stderr)
            return graph, status, (nodes_, backreferences)
        except Exception as e:
            print('Reconnection 2 failure:', file=sys.stderr)
            print(e, file=sys.stderr)
            print(nodes, file=sys.stderr)
            print(backreferences, file=sys.stderr)
            print(graph_, file=sys.stderr)
            return postprocessing.BACKOFF, postprocessing.ParsedStatus.BACKOFF, (nodes_, backreferences)
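    # NOTE (editorial usage sketch, not in the original source): decode_amr is
    # the prediction-time entry point; `token_ids` would come from
    # model.generate(...) on the finetuned BART:
    #
    #     graph, status, (nodes, backrefs) = tokenizer.decode_amr(token_ids)
    #     if status == postprocessing.ParsedStatus.OK:
    #         print(encode(graph))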
--------------------------------------------------------------------------------
/spring_amr/utils.py:
--------------------------------------------------------------------------------
from glob import glob
from pathlib import Path

import torch
from transformers import AutoConfig

from spring_amr.dataset import AMRDataset, AMRDatasetTokenBatcherAndLoader
from spring_amr.modeling_bart import AMRBartForConditionalGeneration
from spring_amr.tokenization_bart import AMRBartTokenizer, PENMANBartTokenizer


def instantiate_model_and_tokenizer(
    name=None,
    checkpoint=None,
    additional_tokens_smart_init=True,
    dropout=0.15,
    attention_dropout=0.15,
    from_pretrained=True,
    init_reverse=False,
    collapse_name_ops=False,
    penman_linearization=False,
    use_pointer_tokens=False,
    raw_graph=False,
):
    if raw_graph:
        assert penman_linearization

    skip_relations = False

    if name is None:
        name = 'facebook/bart-large'

    if name == 'facebook/bart-base':
        tokenizer_name = 'facebook/bart-large'
    else:
        tokenizer_name = name

    config = AutoConfig.from_pretrained(name)
    config.output_past = False
    config.no_repeat_ngram_size = 0
    config.prefix = " "
    config.output_attentions = True
    config.dropout = dropout
    config.attention_dropout = attention_dropout

    if penman_linearization:
        tokenizer = PENMANBartTokenizer.from_pretrained(
            tokenizer_name,
            collapse_name_ops=collapse_name_ops,
            use_pointer_tokens=use_pointer_tokens,
            raw_graph=raw_graph,
            config=config,
        )
    else:
        tokenizer = AMRBartTokenizer.from_pretrained(
            tokenizer_name,
            collapse_name_ops=collapse_name_ops,
            use_pointer_tokens=use_pointer_tokens,
            config=config,
        )

    if from_pretrained:
        model = AMRBartForConditionalGeneration.from_pretrained(name, config=config)
    else:
        model = AMRBartForConditionalGeneration(config)

    model.resize_token_embeddings(len(tokenizer.encoder))

    if additional_tokens_smart_init:
        modified = 0
        for tok, idx in tokenizer.encoder.items():
            tok = tok.lstrip(tokenizer.INIT)

            if idx < tokenizer.old_enc_size:
                continue

            elif tok.startswith('<pointer:'):
                tok_split = ['pointer', str(tok.split(':')[1].strip('>'))]

            elif tok.startswith('<'):
                continue

            elif tok.startswith(':'):

                if skip_relations:
                    continue

                elif tok.startswith(':op'):
                    tok_split = ['relation', 'operator', str(int(tok[3:]))]

                elif tok.startswith(':snt'):
                    tok_split = ['relation', 'sentence', str(int(tok[4:]))]

                elif tok.startswith(':ARG'):
                    tok_split = ['relation', 'argument', str(int(tok[4:]))]

                else:
                    tok_split = ['relation'] + tok.lstrip(':').split('-')

            else:
                tok_split = tok.split('-')

            tok_split_ = tok_split
            tok_split = []
            for s in tok_split_:
                # look the piece up as a whole space-prefixed token first,
                # otherwise fall back to its BPE pieces
                s_ = tokenizer.INIT + s
                if s_ in tokenizer.encoder:
                    tok_split.append(s_)
                else:
                    tok_split.extend(tokenizer._tok_bpe(s))

            vecs = []
            for s in tok_split:
                idx_split = tokenizer.encoder.get(s, -1)
                if idx_split > -1:
                    vec_split = model.model.shared.weight.data[idx_split].clone()
                    vecs.append(vec_split)

            if vecs:
                vec = torch.stack(vecs, 0).mean(0)
                noise = torch.empty_like(vec)
                noise.uniform_(-0.1, +0.1)
                model.model.shared.weight.data[idx] = vec + noise
                modified += 1

    if init_reverse:
        model.init_reverse_model()

    if checkpoint is not None:
        model.load_state_dict(torch.load(checkpoint, map_location='cpu')['model'])

    return model, tokenizer


def instantiate_loader(
    glob_pattn,
    tokenizer,
    batch_size=500,
    evaluation=True,
    out=None,
    use_recategorization=False,
    remove_longer_than=None,
    remove_wiki=False,
    dereify=True,
):
    paths = []
    if isinstance(glob_pattn, (str, Path)):
        glob_pattn = [glob_pattn]
    for gpattn in glob_pattn:
        paths += [Path(p) for p in glob(gpattn)]
    if evaluation:
        assert out is not None
        Path(out).write_text(
            '\n\n'.join([p.read_text() for p in paths]))
    dataset = AMRDataset(
        paths,
        tokenizer,
        use_recategorization=use_recategorization,
        remove_longer_than=remove_longer_than,
        remove_wiki=remove_wiki,
        dereify=dereify,
    )
    loader = AMRDatasetTokenBatcherAndLoader(
        dataset,
        batch_size=batch_size,
        shuffle=not evaluation,
    )
    return loader
--------------------------------------------------------------------------------
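A minimal end-to-end sketch of how these two helpers fit together (editorial
addition, not part of the repository; the data glob is a placeholder path):

    from spring_amr.utils import instantiate_loader, instantiate_model_and_tokenizer

    model, tokenizer = instantiate_model_and_tokenizer(
        'facebook/bart-large',
        penman_linearization=True,
        use_pointer_tokens=True,
    )
    loader = instantiate_loader(
        'data/amrs/train/*.txt',   # placeholder glob over AMR files
        tokenizer,
        batch_size=500,
        evaluation=False,          # training mode: no `out` dump, shuffling on
    )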