├── .Rhistory ├── .idea ├── AMR_AS_GRAPH.iml ├── misc.xml ├── modules.xml ├── vcs.xml └── workspace.xml ├── AMR_FEATURE ├── .classpath ├── .gitignore ├── .project ├── .settings │ └── org.eclipse.jdt.core.prefs ├── bin │ ├── convertingAMR.class │ └── json-20170516.jar ├── joints.txt └── src │ ├── convertingAMR.java │ └── json-20170516.jar ├── README.md ├── data ├── aux_dict ├── category_dict ├── graph_to_node_dict.txt ├── graph_to_node_dict_extended_without_jamr ├── graph_to_node_dict_extended_without_jamr.txt ├── high_dict ├── joints.txt ├── lemma_dict ├── ner_dict ├── non_rule_set ├── pos_dict ├── rel_dict ├── rule_f_without_jamr ├── sensed_dict └── word_dict ├── np_sents.txt ├── np_sents.txt_parsed ├── parser ├── AMRProcessors.py ├── DataIterator.py ├── Dict.py ├── Optim.py ├── __init__.py ├── __pycache__ │ ├── AMRProcessors.cpython-36.pyc │ ├── DataIterator.cpython-36.pyc │ ├── Dict.cpython-36.pyc │ ├── Optim.cpython-36.pyc │ └── __init__.cpython-36.pyc ├── models │ ├── ConceptModel.py │ ├── MultiPassRelModel.py │ ├── __init__.py │ └── __pycache__ │ │ ├── ConceptModel.cpython-36.pyc │ │ ├── MultiPassRelModel.cpython-36.pyc │ │ └── __init__.cpython-36.pyc └── modules │ ├── GumbelSoftMax.py │ ├── __initial__.py │ ├── __pycache__ │ ├── GumbelSoftMax.cpython-36.pyc │ └── helper_module.cpython-36.pyc │ └── helper_module.py ├── src ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ └── train.cpython-36.pyc ├── data_build.py ├── generate.py ├── parse.py ├── preprocessing.py ├── rule_system_build.py └── train.py └── utility ├── AMRGraph.py ├── Naive_Scores.py ├── PropbankReader.py ├── ReCategorization.py ├── StringCopyRules.py ├── __init__.py ├── __init__.pyc ├── __pycache__ ├── AMRGraph.cpython-36.pyc ├── Naive_Scores.cpython-36.pyc ├── PropbankReader.cpython-36.pyc ├── ReCategorization.cpython-36.pyc ├── StringCopyRules.cpython-36.pyc ├── __init__.cpython-36.pyc ├── amr.cpython-36.pyc ├── constants.cpython-36.pyc └── data_helper.cpython-36.pyc ├── amr.peg 
├── amr.py ├── amr.pyc ├── constants.py ├── constants.pyc └── data_helper.py /.Rhistory: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChunchuanLv/AMR_AS_GRAPH_PREDICTION/3375123c6b00bdfbe3395706769175073716b699/.Rhistory -------------------------------------------------------------------------------- /.idea/AMR_AS_GRAPH.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /AMR_FEATURE/.classpath: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | -------------------------------------------------------------------------------- /AMR_FEATURE/.gitignore: -------------------------------------------------------------------------------- 1 | /.metadata/ 2 | -------------------------------------------------------------------------------- /AMR_FEATURE/.project: -------------------------------------------------------------------------------- 1 | 2 | 3 | AMR_FEATURE 4 | 5 | 6 | 7 | 8 | 9 | org.eclipse.jdt.core.javabuilder 10 | 11 | 12 | 13 | 14 | 15 | org.eclipse.jdt.core.javanature 16 | 17 | 18 | -------------------------------------------------------------------------------- 
/AMR_FEATURE/.settings/org.eclipse.jdt.core.prefs: -------------------------------------------------------------------------------- 1 | eclipse.preferences.version=1 2 | org.eclipse.jdt.core.compiler.codegen.inlineJsrBytecode=enabled 3 | org.eclipse.jdt.core.compiler.codegen.targetPlatform=1.8 4 | org.eclipse.jdt.core.compiler.codegen.unusedLocal=preserve 5 | org.eclipse.jdt.core.compiler.compliance=1.8 6 | org.eclipse.jdt.core.compiler.debug.lineNumber=generate 7 | org.eclipse.jdt.core.compiler.debug.localVariable=generate 8 | org.eclipse.jdt.core.compiler.debug.sourceFile=generate 9 | org.eclipse.jdt.core.compiler.problem.assertIdentifier=error 10 | org.eclipse.jdt.core.compiler.problem.enumIdentifier=error 11 | org.eclipse.jdt.core.compiler.source=1.8 12 | -------------------------------------------------------------------------------- /AMR_FEATURE/bin/convertingAMR.class: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChunchuanLv/AMR_AS_GRAPH_PREDICTION/3375123c6b00bdfbe3395706769175073716b699/AMR_FEATURE/bin/convertingAMR.class -------------------------------------------------------------------------------- /AMR_FEATURE/bin/json-20170516.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChunchuanLv/AMR_AS_GRAPH_PREDICTION/3375123c6b00bdfbe3395706769175073716b699/AMR_FEATURE/bin/json-20170516.jar -------------------------------------------------------------------------------- /AMR_FEATURE/joints.txt: -------------------------------------------------------------------------------- 1 | have to 2 | at all 3 | so far 4 | more than 5 | less than 6 | no one 7 | as well 8 | at least 9 | right wing 10 | left wing 11 | as long as 12 | all over 13 | of course 14 | kind of 15 | after all 16 | by oneself 17 | by the way 18 | in fact 19 | be all 20 | head up 21 | come out 22 | coop up 23 | seize up 24 | bust up 25 | hang out 26 | limber 
up 27 | quieten down 28 | crack up 29 | fuck up 30 | get out 31 | clear out 32 | rip up 33 | rock on 34 | shout down 35 | bundle up 36 | pump up 37 | smooth out 38 | set down 39 | drop off 40 | think over 41 | core out 42 | tidy up 43 | make off 44 | fight on 45 | set out 46 | think up 47 | try out 48 | sign in 49 | take out 50 | top off 51 | nail down 52 | block up 53 | cash in 54 | fork out 55 | mark down 56 | rattle off 57 | bandage up 58 | sleep over 59 | patch up 60 | freeze over 61 | seal off 62 | free up 63 | clown around 64 | tear down 65 | dust off 66 | live up 67 | cut loose 68 | louse up 69 | sit down 70 | stand by 71 | take up 72 | steal away 73 | lay off 74 | turn in 75 | meet up 76 | check up 77 | taper off 78 | dole out 79 | catch up 80 | shape up 81 | tax away 82 | pass off 83 | give in 84 | speak up 85 | call upon 86 | stall out 87 | butt in 88 | carve out 89 | step up 90 | trigger off 91 | prop up 92 | scoop up 93 | summon forth 94 | boss around 95 | cool down 96 | give back 97 | cut down 98 | jot down 99 | doze off 100 | drum up 101 | bog down 102 | throw out 103 | shy away 104 | frost over 105 | rack up 106 | even out 107 | light up 108 | shack up 109 | bone up 110 | cut out 111 | sum up 112 | shut up 113 | send out 114 | pine away 115 | take over 116 | gobble up 117 | shoot back 118 | lay on 119 | swear off 120 | spread out 121 | pin down 122 | find out 123 | drag on 124 | thaw out 125 | bump off 126 | fatten up 127 | get back 128 | arm up 129 | load up 130 | give vent 131 | top up 132 | bounce back 133 | bad off 134 | come by 135 | single out 136 | call out 137 | slow down 138 | ask out 139 | slice up 140 | roll up 141 | divide up 142 | hold over 143 | touch off 144 | pass out 145 | have mod 146 | screw up 147 | iron out 148 | tell on 149 | dry out 150 | zero out 151 | rev up 152 | request confirmation 153 | scrawl out 154 | tie in 155 | pass up 156 | scratch out 157 | miss out 158 | root out 159 | frighten off 160 | have subevent 161 | go on 
162 | follow through 163 | lighten up 164 | trade off 165 | carry over 166 | pay out 167 | mellow out 168 | fool around 169 | get down 170 | stretch out 171 | run down 172 | scrub up 173 | splash out 174 | stop by 175 | touch upon 176 | dig out 177 | stick around 178 | act out 179 | pass by 180 | watch out 181 | share out 182 | shut out 183 | get along 184 | go through 185 | tease out 186 | kill off 187 | slug out 188 | bottom out 189 | tie down 190 | neaten up 191 | dress down 192 | turn off 193 | bandy around 194 | yammer away 195 | gulp down 196 | cut back 197 | chatter away 198 | glaze over 199 | drop by 200 | slack off 201 | fess up 202 | seek out 203 | creep out 204 | hold up 205 | knock up 206 | shine through 207 | fence off 208 | zero in 209 | flip out 210 | rein in 211 | screen out 212 | cheer up 213 | saw up 214 | sign off 215 | flatten out 216 | heat up 217 | add on 218 | clip off 219 | doll up 220 | touch on 221 | fall off 222 | suit up 223 | palm off 224 | mist over 225 | flesh out 226 | burn up 227 | sweat out 228 | work up 229 | brazen out 230 | peel off 231 | pay up 232 | get even 233 | fill out 234 | whip up 235 | shout out 236 | kick in 237 | draw up 238 | thrash out 239 | head off 240 | come in 241 | break up 242 | speed up 243 | spout off 244 | type up 245 | polish off 246 | trot out 247 | puke up 248 | bank up 249 | rip off 250 | dry up 251 | settle down 252 | cry out 253 | go out 254 | face off 255 | ride up 256 | buckle up 257 | pair up 258 | come off 259 | auction off 260 | roll back 261 | throw in 262 | eat up 263 | suck up 264 | shut down 265 | wipe out 266 | nod off 267 | choke off 268 | sleep off 269 | stand up 270 | frost up 271 | join in 272 | mix up 273 | crisp up 274 | knock out 275 | talk out 276 | set off 277 | sit in 278 | bang on 279 | flake out 280 | take off 281 | queue up 282 | square off 283 | make over 284 | ramp up 285 | let down 286 | toss out 287 | finish up 288 | blow over 289 | sound off 290 | cut up 291 | rough in 292 
| blot out 293 | stave off 294 | stop off 295 | act up 296 | scout out 297 | pay off 298 | beat out 299 | copy out 300 | wolf down 301 | have manner 302 | get through 303 | break off 304 | drug up 305 | pump out 306 | take hold 307 | polish up 308 | pucker up 309 | write off 310 | shell out 311 | come over 312 | color in 313 | tamp down 314 | shut off 315 | have mode 316 | strike up 317 | beat up 318 | sweep up 319 | come up 320 | blast off 321 | lie in 322 | warm over 323 | ratchet up 324 | bump up 325 | play out 326 | look out 327 | tip over 328 | fudge over 329 | warm up 330 | throw away 331 | crank up 332 | tip off 333 | have quant 334 | go back 335 | roll out 336 | trim down 337 | set up 338 | rake in 339 | piss off 340 | give over 341 | buoy up 342 | pen up 343 | touch up 344 | parcel out 345 | boom out 346 | give off 347 | jump up 348 | leave over 349 | tone down 350 | dream on 351 | lock in 352 | win over 353 | stop over 354 | turn over 355 | play on 356 | edge out 357 | get up 358 | leave off 359 | finish off 360 | slim down 361 | wall off 362 | puff up 363 | plug up 364 | write out 365 | let out 366 | stop up 367 | calm down 368 | bring about 369 | phase out 370 | belly up 371 | break down 372 | stick up 373 | lock up 374 | pull out 375 | set upon 376 | jet off 377 | pay down 378 | fart around 379 | zone out 380 | bear out 381 | take away 382 | bleed off 383 | write up 384 | lash out 385 | lam out 386 | tie up 387 | siphon off 388 | dress up 389 | stamp out 390 | black out 391 | snuff out 392 | whip out 393 | go off 394 | ease up 395 | tune out 396 | gun down 397 | freak out 398 | chop down 399 | strip away 400 | step down 401 | hit up 402 | read up 403 | chew up 404 | start out 405 | own up 406 | close down 407 | come upon 408 | cone down 409 | yield up 410 | get away 411 | gear up 412 | bring on 413 | figure out 414 | turn up 415 | check out 416 | bead up 417 | ship out 418 | crank out 419 | flush out 420 | let on 421 | put on 422 | usher in 423 | spin 
off 424 | knock off 425 | skim off 426 | pass on 427 | finish out 428 | instead of 429 | leave out 430 | frighten away 431 | buy up 432 | knock over 433 | straighten out 434 | wear off 435 | whiz away 436 | call on 437 | put out 438 | totter around 439 | salt away 440 | spell out 441 | creep up 442 | hold out 443 | sign up 444 | branch out 445 | mark up 446 | hail down 447 | pick out 448 | shoot off 449 | din out 450 | beef up 451 | get off 452 | break through 453 | smarten up 454 | help out 455 | buy out 456 | stake out 457 | take in 458 | do in 459 | come to 460 | sell out 461 | shore up 462 | hem in 463 | hang up 464 | boil over 465 | sort out 466 | wipe up 467 | curl up 468 | whack off 469 | track down 470 | dig up 471 | run out 472 | haul out 473 | plot out 474 | loan out 475 | coil up 476 | die off 477 | pipe down 478 | kick off 479 | come through 480 | print out 481 | pick away 482 | gloss over 483 | ring up 484 | go down 485 | read off 486 | pitch in 487 | choke up 488 | break in 489 | crack down 490 | boot up 491 | blurt out 492 | sluice down 493 | fill up 494 | spring up 495 | lock out 496 | pack up 497 | look over 498 | whittle down 499 | chicken out 500 | bandy about 501 | cart off 502 | plug in 503 | buy off 504 | pick on 505 | crash out 506 | total up 507 | pile on 508 | pan out 509 | prick up 510 | dish up 511 | stash away 512 | round up 513 | shoot up 514 | balance out 515 | bring along 516 | quiet down 517 | cut off 518 | vamp up 519 | run off 520 | pull down 521 | team up 522 | hold back 523 | hammer out 524 | stack up 525 | think through 526 | match up 527 | rise up 528 | have concession 529 | wipe off 530 | hash out 531 | come down 532 | sock away 533 | jump in 534 | hang on 535 | ferret out 536 | wake up 537 | brick over 538 | burst out 539 | tack down 540 | spike out 541 | use up 542 | carry on 543 | bottle up 544 | tighten up 545 | start up 546 | carry off 547 | speak out 548 | set about 549 | tag along 550 | hook up 551 | oil up 552 | fend 
off 553 | start over 554 | sit up 555 | sign on 556 | take down 557 | study up 558 | while away 559 | fold up 560 | cheer on 561 | bust out 562 | rate entity 563 | play down 564 | book up 565 | bind up 566 | stay on 567 | come about 568 | put up 569 | dine out 570 | have frequency 571 | store up 572 | give up 573 | vote down 574 | bring up 575 | tape up 576 | leave behind 577 | turn on 578 | save up 579 | break out 580 | wash up 581 | fork over 582 | hollow out 583 | freshen up 584 | screw over 585 | dash off 586 | have part 587 | mess up 588 | buy into 589 | burn out 590 | cave in 591 | lead up 592 | clear up 593 | cry down 594 | stand out 595 | turn away 596 | drown out 597 | run in 598 | cover up 599 | spill over 600 | die out 601 | farm out 602 | hand over 603 | poke around 604 | ride out 605 | come across 606 | give away 607 | tack on 608 | bow out 609 | squeeze out 610 | write in 611 | show up 612 | come on 613 | fix up 614 | sew up 615 | fort up 616 | do away 617 | liven up 618 | scrunch up 619 | log on 620 | ham up 621 | look down 622 | firm up 623 | tally up 624 | tool up 625 | weigh in 626 | flare up 627 | strike down 628 | thin out 629 | blast away 630 | reel off 631 | feed up 632 | camp out 633 | well off 634 | crop up 635 | be like 636 | open up 637 | link up 638 | lick up 639 | look up 640 | statistical test 641 | charge off 642 | drop out 643 | keep up 644 | tick off 645 | tune in 646 | write down 647 | bat in 648 | stay over 649 | gas up 650 | pick up 651 | cook up 652 | boil down 653 | pull through 654 | call off 655 | pop off 656 | hand out 657 | push up 658 | fritter away 659 | trail off 660 | chop up 661 | rear end 662 | fuck around 663 | rattle on 664 | tire out 665 | street address 666 | keep on 667 | pack away 668 | keg stand 669 | close off 670 | lose out 671 | wring out 672 | make believe 673 | soak up 674 | tee off 675 | shake up 676 | scent out 677 | steer clear 678 | have instrument 679 | tear up 680 | feel up 681 | live down 682 | bowl 
over 683 | step in 684 | hobnob around 685 | bow down 686 | buzz off 687 | tangle up 688 | catch on 689 | price out 690 | snap up 691 | live out 692 | touch base 693 | be done 694 | have li 695 | vomit up 696 | clean out 697 | laid back 698 | buckle down 699 | slip in 700 | swear in 701 | stall off 702 | shoot down 703 | be from 704 | serve up 705 | join up 706 | back up 707 | well up 708 | pull up 709 | put down 710 | wash down 711 | dish out 712 | age out 713 | fight back 714 | bring down 715 | run up 716 | zip up 717 | switch over 718 | spend down 719 | call up 720 | be polite 721 | pop up 722 | fall apart 723 | net out 724 | jut out 725 | wind up 726 | rent out 727 | cross out 728 | rough up 729 | broke ass 730 | dredge up 731 | wait out 732 | shuffle off 733 | build up 734 | box in 735 | shake off 736 | cool off 737 | get on 738 | hit on 739 | straighten up 740 | start off 741 | belch out 742 | lie down 743 | play up 744 | give out 745 | haul in 746 | hard put 747 | make up 748 | snap off 749 | follow suit 750 | pass away 751 | smooth over 752 | hole up 753 | turn out 754 | clog up 755 | sober up 756 | smash up 757 | contract out 758 | go over 759 | dope up 760 | bed down 761 | sit out 762 | hype up 763 | drop in 764 | put off 765 | ward off 766 | get together 767 | turn down 768 | back off 769 | swoop up 770 | out trade 771 | size up 772 | pull off 773 | conjure up 774 | stock up 775 | sleep away 776 | monkey around 777 | break away 778 | pile up 779 | put in 780 | dream up 781 | wrap up 782 | gum up 783 | bound up 784 | tuck away 785 | board up 786 | have purpose 787 | stick out 788 | fall out 789 | take aback 790 | chart out 791 | latch on 792 | belt out 793 | wear on 794 | muck up 795 | step aside 796 | lead off 797 | point out 798 | line up 799 | check in 800 | start in 801 | bunch up 802 | watch over 803 | fill in 804 | work out 805 | joke around 806 | hum along 807 | lock down 808 | wear out 809 | rip out 810 | bleed out 811 | come along 812 | play off 
813 | show off 814 | have extent 815 | concrete over 816 | narrow down 817 | jack up 818 | stare down 819 | pipe up 820 | loosen up 821 | wear down 822 | bear up 823 | cover over 824 | have polarity 825 | mic up 826 | make do 827 | close over 828 | deck out 829 | blow out 830 | play to 831 | hammer away 832 | ration out 833 | sell off 834 | have name 835 | strike out 836 | shuttle off 837 | call in 838 | shrug off 839 | chalk up 840 | perk up 841 | knock down 842 | follow up 843 | pass over 844 | brush off 845 | drink up 846 | fly out 847 | close in 848 | grow up 849 | eat away 850 | have condition 851 | snatch away 852 | pick off 853 | stress out 854 | take on 855 | muddle up 856 | tuck in 857 | live on 858 | skip off 859 | look forward 860 | stir up 861 | bail out 862 | stand down 863 | close up 864 | run over 865 | throw up 866 | fuck off 867 | swallow up 868 | spill out 869 | fall back 870 | fight off 871 | rig up 872 | sweat off 873 | hide out 874 | divvy up 875 | flash back 876 | end up 877 | make it 878 | toss in 879 | round out 880 | sniff out 881 | grind up 882 | chip in 883 | cough up 884 | phase in 885 | let up 886 | water down 887 | hold on 888 | level off 889 | have value 890 | fit in 891 | yammer on 892 | key in 893 | hold off 894 | silt up 895 | get by 896 | split up 897 | make out 898 | look after 899 | rubber stamp 900 | sketch out 901 | pull over 902 | spruce up 903 | glass over 904 | add up 905 | mist up 906 | brush up 907 | wind down 908 | clutch on 909 | knock back 910 | pare down 911 | rule out 912 | fall through 913 | hack away 914 | asphalt over 915 | clean up 916 | pound out 917 | die down 918 | carry out 919 | fall over 920 | blow up 921 | weasel out 922 | break even 923 | -------------------------------------------------------------------------------- /AMR_FEATURE/src/json-20170516.jar: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/ChunchuanLv/AMR_AS_GRAPH_PREDICTION/3375123c6b00bdfbe3395706769175073716b699/AMR_FEATURE/src/json-20170516.jar -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # AMR AS GRAPH PREDICTION 2 | 3 | This repository contains code for training and using the Abstract Meaning Representation model described in: 4 | [AMR Parsing as Graph Prediction with Latent Alignment](https://arxiv.org/pdf/1805.05286.pdf) 5 | 6 | If you use our code, please cite our paper as follows: 7 | > @inproceedings{Lyu2018AMRPA, 8 | >     title={AMR Parsing as Graph Prediction with Latent Alignment}, 9 | >     author={Chunchuan Lyu and Ivan Titov}, 10 | >     booktitle={Proceedings of the Annual Meeting of the Association for Computational Linguistics}, 11 | >     year={2018} 12 | > } 13 | 14 | ## Prerequisites 15 | 16 | * Python 3.6 17 | * Stanford CoreNLP 3.9.1 (the Python wrapper is not compatible with newer versions) 18 | * PyTorch 0.2.0 19 | * [GloVe](https://nlp.stanford.edu/projects/glove/) embeddings 20 | * [AMR dataset and resource files](https://amr.isi.edu/download.html) 21 | 22 | ## Configuration 23 | 24 | * Set up the [Stanford CoreNLP server](https://stanfordnlp.github.io/CoreNLP/corenlp-server.html), on which feature extraction relies. 25 | * Change the file paths in utility/constants.py accordingly. 26 | 27 | 28 | ## Preprocessing 29 | 30 | Either a) combine all `*.txt` files into a single one, and use Stanford CoreNLP to extract NER, POS and lemma features. 31 | The processed file is saved in the same folder. 32 | 33 | python src/preprocessing.py 34 | 35 | or b) process the output of the [AMR-to-English aligner](https://www.isi.edu/natural-language/mt/amr_eng_align.pdf) using the Java program in AMR_FEATURE (I used Eclipse to run it). 36 | 37 | Build the copying dictionary and recategorization system (this step can be skipped, as they are already in data/).
38 | 39 | python src/rule_system_build.py 40 | 41 | Build the data into tensors. 42 | 43 | python src/data_build.py 44 | 45 | ## Training 46 | 47 | The default model is saved to [save_to]/gpus_0valid_best.pt (save_to is defined in utility/constants.py). 48 | 49 | python src/train.py 50 | 51 | ## Testing 52 | 53 | Load a model to parse the pre-built data. 54 | 55 | python src/generate.py -train_from [gpus_0valid_best.pt] 56 | 57 | ## Evaluation 58 | 59 | Please use [amr-evaluation-tool-enhanced](https://github.com/ChunchuanLv/amr-evaluation-tool-enhanced). 60 | This is based on Marco Damonte's [amr-evaluation-tool](https://github.com/mdtux89/amr-evaluation), 61 | but with a correction to the unlabeled edge score. 62 | 63 | ## Parsing 64 | 65 | Either a) parse a file where each line consists of a single sentence; the output is saved at `[file]_parsed` 66 | 67 | python src/parse.py -train_from [gpus_0valid_best.pt] -input [file] 68 | 69 | or b) parse a single sentence passed directly on the command line 70 | 71 | python src/parse.py -train_from [gpus_0valid_best.pt] -text [type sentence here] 72 | 73 | ## Pretrained models 74 | 75 | Keeping the files under the data/ folder unchanged, download the [model](https://drive.google.com/open?id=1jNTG3tuIfS-WoUpqGQydRgYWst51kjHx); 76 | this should allow one to run parsing. 77 | 78 | ## Notes 79 | 80 | `python src/preprocessing.py` starts from the original sentences in the AMR files, while the paper version was trained on the tokenized version provided by the [AMR-to-English aligner](https://www.isi.edu/natural-language/mt/amr_eng_align.pdf), 81 | so the results could be slightly different. Also, to build a parser for out-of-domain data, please start preprocessing with `python src/preprocessing.py` to keep everything consistent. 82 | 83 | ## Contact 84 | 85 | Contact us if you have any questions!
86 | -------------------------------------------------------------------------------- /data/aux_dict: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChunchuanLv/AMR_AS_GRAPH_PREDICTION/3375123c6b00bdfbe3395706769175073716b699/data/aux_dict -------------------------------------------------------------------------------- /data/category_dict: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChunchuanLv/AMR_AS_GRAPH_PREDICTION/3375123c6b00bdfbe3395706769175073716b699/data/category_dict -------------------------------------------------------------------------------- /data/graph_to_node_dict_extended_without_jamr: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChunchuanLv/AMR_AS_GRAPH_PREDICTION/3375123c6b00bdfbe3395706769175073716b699/data/graph_to_node_dict_extended_without_jamr -------------------------------------------------------------------------------- /data/high_dict: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChunchuanLv/AMR_AS_GRAPH_PREDICTION/3375123c6b00bdfbe3395706769175073716b699/data/high_dict -------------------------------------------------------------------------------- /data/joints.txt: -------------------------------------------------------------------------------- 1 | have to 2 | at all 3 | so far 4 | more than 5 | less than 6 | no one 7 | as well 8 | at least 9 | right wing 10 | left wing 11 | as long as 12 | all over 13 | of course 14 | kind of 15 | after all 16 | by oneself 17 | by the way 18 | in fact 19 | be all 20 | head up 21 | come out 22 | coop up 23 | seize up 24 | bust up 25 | hang out 26 | limber up 27 | quieten down 28 | crack up 29 | fuck up 30 | get out 31 | clear out 32 | rip up 33 | rock on 34 | shout down 35 | bundle up 36 | pump up 37 | smooth out 38 | set down 
39 | drop off 40 | think over 41 | core out 42 | tidy up 43 | make off 44 | fight on 45 | set out 46 | think up 47 | try out 48 | sign in 49 | take out 50 | top off 51 | nail down 52 | block up 53 | cash in 54 | fork out 55 | mark down 56 | rattle off 57 | bandage up 58 | sleep over 59 | patch up 60 | freeze over 61 | seal off 62 | free up 63 | clown around 64 | tear down 65 | dust off 66 | live up 67 | cut loose 68 | louse up 69 | sit down 70 | stand by 71 | take up 72 | steal away 73 | lay off 74 | turn in 75 | meet up 76 | check up 77 | taper off 78 | dole out 79 | catch up 80 | shape up 81 | tax away 82 | pass off 83 | give in 84 | speak up 85 | call upon 86 | stall out 87 | butt in 88 | carve out 89 | step up 90 | trigger off 91 | prop up 92 | scoop up 93 | summon forth 94 | boss around 95 | cool down 96 | give back 97 | cut down 98 | jot down 99 | doze off 100 | drum up 101 | bog down 102 | throw out 103 | shy away 104 | frost over 105 | rack up 106 | even out 107 | light up 108 | shack up 109 | bone up 110 | cut out 111 | sum up 112 | shut up 113 | send out 114 | pine away 115 | take over 116 | gobble up 117 | shoot back 118 | lay on 119 | swear off 120 | spread out 121 | pin down 122 | find out 123 | drag on 124 | thaw out 125 | bump off 126 | fatten up 127 | get back 128 | arm up 129 | load up 130 | give vent 131 | top up 132 | bounce back 133 | bad off 134 | come by 135 | single out 136 | call out 137 | slow down 138 | ask out 139 | slice up 140 | roll up 141 | divide up 142 | hold over 143 | touch off 144 | pass out 145 | have mod 146 | screw up 147 | iron out 148 | tell on 149 | dry out 150 | zero out 151 | rev up 152 | request confirmation 153 | scrawl out 154 | tie in 155 | pass up 156 | scratch out 157 | miss out 158 | root out 159 | frighten off 160 | have subevent 161 | go on 162 | follow through 163 | lighten up 164 | trade off 165 | carry over 166 | pay out 167 | mellow out 168 | fool around 169 | get down 170 | stretch out 171 | run down 172 | 
scrub up 173 | splash out 174 | stop by 175 | touch upon 176 | dig out 177 | stick around 178 | act out 179 | pass by 180 | watch out 181 | share out 182 | shut out 183 | get along 184 | go through 185 | tease out 186 | kill off 187 | slug out 188 | bottom out 189 | tie down 190 | neaten up 191 | dress down 192 | turn off 193 | bandy around 194 | yammer away 195 | gulp down 196 | cut back 197 | chatter away 198 | glaze over 199 | drop by 200 | slack off 201 | fess up 202 | seek out 203 | creep out 204 | hold up 205 | knock up 206 | shine through 207 | fence off 208 | zero in 209 | flip out 210 | rein in 211 | screen out 212 | cheer up 213 | saw up 214 | sign off 215 | flatten out 216 | heat up 217 | add on 218 | clip off 219 | doll up 220 | touch on 221 | fall off 222 | suit up 223 | palm off 224 | mist over 225 | flesh out 226 | burn up 227 | sweat out 228 | work up 229 | brazen out 230 | peel off 231 | pay up 232 | get even 233 | fill out 234 | whip up 235 | shout out 236 | kick in 237 | draw up 238 | thrash out 239 | head off 240 | come in 241 | break up 242 | speed up 243 | spout off 244 | type up 245 | polish off 246 | trot out 247 | puke up 248 | bank up 249 | rip off 250 | dry up 251 | settle down 252 | cry out 253 | go out 254 | face off 255 | ride up 256 | buckle up 257 | pair up 258 | come off 259 | auction off 260 | roll back 261 | throw in 262 | eat up 263 | suck up 264 | shut down 265 | wipe out 266 | nod off 267 | choke off 268 | sleep off 269 | stand up 270 | frost up 271 | join in 272 | mix up 273 | crisp up 274 | knock out 275 | talk out 276 | set off 277 | sit in 278 | bang on 279 | flake out 280 | take off 281 | queue up 282 | square off 283 | make over 284 | ramp up 285 | let down 286 | toss out 287 | finish up 288 | blow over 289 | sound off 290 | cut up 291 | rough in 292 | blot out 293 | stave off 294 | stop off 295 | act up 296 | scout out 297 | pay off 298 | beat out 299 | copy out 300 | wolf down 301 | have manner 302 | get through 303 | 
break off 304 | drug up 305 | pump out 306 | take hold 307 | polish up 308 | pucker up 309 | write off 310 | shell out 311 | come over 312 | color in 313 | tamp down 314 | shut off 315 | have mode 316 | strike up 317 | beat up 318 | sweep up 319 | come up 320 | blast off 321 | lie in 322 | warm over 323 | ratchet up 324 | bump up 325 | play out 326 | look out 327 | tip over 328 | fudge over 329 | warm up 330 | throw away 331 | crank up 332 | tip off 333 | have quant 334 | go back 335 | roll out 336 | trim down 337 | set up 338 | rake in 339 | piss off 340 | give over 341 | buoy up 342 | pen up 343 | touch up 344 | parcel out 345 | boom out 346 | give off 347 | jump up 348 | leave over 349 | tone down 350 | dream on 351 | lock in 352 | win over 353 | stop over 354 | turn over 355 | play on 356 | edge out 357 | get up 358 | leave off 359 | finish off 360 | slim down 361 | wall off 362 | puff up 363 | plug up 364 | write out 365 | let out 366 | stop up 367 | calm down 368 | bring about 369 | phase out 370 | belly up 371 | break down 372 | stick up 373 | lock up 374 | pull out 375 | set upon 376 | jet off 377 | pay down 378 | fart around 379 | zone out 380 | bear out 381 | take away 382 | bleed off 383 | write up 384 | lash out 385 | lam out 386 | tie up 387 | siphon off 388 | dress up 389 | stamp out 390 | black out 391 | snuff out 392 | whip out 393 | go off 394 | ease up 395 | tune out 396 | gun down 397 | freak out 398 | chop down 399 | strip away 400 | step down 401 | hit up 402 | read up 403 | chew up 404 | start out 405 | own up 406 | close down 407 | come upon 408 | cone down 409 | yield up 410 | get away 411 | gear up 412 | bring on 413 | figure out 414 | turn up 415 | check out 416 | bead up 417 | ship out 418 | crank out 419 | flush out 420 | let on 421 | put on 422 | usher in 423 | spin off 424 | knock off 425 | skim off 426 | pass on 427 | finish out 428 | instead of 429 | leave out 430 | frighten away 431 | buy up 432 | knock over 433 | straighten out 434 
| wear off 435 | whiz away 436 | call on 437 | put out 438 | totter around 439 | salt away 440 | spell out 441 | creep up 442 | hold out 443 | sign up 444 | branch out 445 | mark up 446 | hail down 447 | pick out 448 | shoot off 449 | din out 450 | beef up 451 | get off 452 | break through 453 | smarten up 454 | help out 455 | buy out 456 | stake out 457 | take in 458 | do in 459 | come to 460 | sell out 461 | shore up 462 | hem in 463 | hang up 464 | boil over 465 | sort out 466 | wipe up 467 | curl up 468 | whack off 469 | track down 470 | dig up 471 | run out 472 | haul out 473 | plot out 474 | loan out 475 | coil up 476 | die off 477 | pipe down 478 | kick off 479 | come through 480 | print out 481 | pick away 482 | gloss over 483 | ring up 484 | go down 485 | read off 486 | pitch in 487 | choke up 488 | break in 489 | crack down 490 | boot up 491 | blurt out 492 | sluice down 493 | fill up 494 | spring up 495 | lock out 496 | pack up 497 | look over 498 | whittle down 499 | chicken out 500 | bandy about 501 | cart off 502 | plug in 503 | buy off 504 | pick on 505 | crash out 506 | total up 507 | pile on 508 | pan out 509 | prick up 510 | dish up 511 | stash away 512 | round up 513 | shoot up 514 | balance out 515 | bring along 516 | quiet down 517 | cut off 518 | vamp up 519 | run off 520 | pull down 521 | team up 522 | hold back 523 | hammer out 524 | stack up 525 | think through 526 | match up 527 | rise up 528 | have concession 529 | wipe off 530 | hash out 531 | come down 532 | sock away 533 | jump in 534 | hang on 535 | ferret out 536 | wake up 537 | brick over 538 | burst out 539 | tack down 540 | spike out 541 | use up 542 | carry on 543 | bottle up 544 | tighten up 545 | start up 546 | carry off 547 | speak out 548 | set about 549 | tag along 550 | hook up 551 | oil up 552 | fend off 553 | start over 554 | sit up 555 | sign on 556 | take down 557 | study up 558 | while away 559 | fold up 560 | cheer on 561 | bust out 562 | rate entity 563 | play down 
564 | book up 565 | bind up 566 | stay on 567 | come about 568 | put up 569 | dine out 570 | have frequency 571 | store up 572 | give up 573 | vote down 574 | bring up 575 | tape up 576 | leave behind 577 | turn on 578 | save up 579 | break out 580 | wash up 581 | fork over 582 | hollow out 583 | freshen up 584 | screw over 585 | dash off 586 | have part 587 | mess up 588 | buy into 589 | burn out 590 | cave in 591 | lead up 592 | clear up 593 | cry down 594 | stand out 595 | turn away 596 | drown out 597 | run in 598 | cover up 599 | spill over 600 | die out 601 | farm out 602 | hand over 603 | poke around 604 | ride out 605 | come across 606 | give away 607 | tack on 608 | bow out 609 | squeeze out 610 | write in 611 | show up 612 | come on 613 | fix up 614 | sew up 615 | fort up 616 | do away 617 | liven up 618 | scrunch up 619 | log on 620 | ham up 621 | look down 622 | firm up 623 | tally up 624 | tool up 625 | weigh in 626 | flare up 627 | strike down 628 | thin out 629 | blast away 630 | reel off 631 | feed up 632 | camp out 633 | well off 634 | crop up 635 | be like 636 | open up 637 | link up 638 | lick up 639 | look up 640 | statistical test 641 | charge off 642 | drop out 643 | keep up 644 | tick off 645 | tune in 646 | write down 647 | bat in 648 | stay over 649 | gas up 650 | pick up 651 | cook up 652 | boil down 653 | pull through 654 | call off 655 | pop off 656 | hand out 657 | push up 658 | fritter away 659 | trail off 660 | chop up 661 | rear end 662 | fuck around 663 | rattle on 664 | tire out 665 | street address 666 | keep on 667 | pack away 668 | keg stand 669 | close off 670 | lose out 671 | wring out 672 | make believe 673 | soak up 674 | tee off 675 | shake up 676 | scent out 677 | steer clear 678 | have instrument 679 | tear up 680 | feel up 681 | live down 682 | bowl over 683 | step in 684 | hobnob around 685 | bow down 686 | buzz off 687 | tangle up 688 | catch on 689 | price out 690 | snap up 691 | live out 692 | touch base 693 | be 
done 694 | have li 695 | vomit up 696 | clean out 697 | laid back 698 | buckle down 699 | slip in 700 | swear in 701 | stall off 702 | shoot down 703 | be from 704 | serve up 705 | join up 706 | back up 707 | well up 708 | pull up 709 | put down 710 | wash down 711 | dish out 712 | age out 713 | fight back 714 | bring down 715 | run up 716 | zip up 717 | switch over 718 | spend down 719 | call up 720 | be polite 721 | pop up 722 | fall apart 723 | net out 724 | jut out 725 | wind up 726 | rent out 727 | cross out 728 | rough up 729 | broke ass 730 | dredge up 731 | wait out 732 | shuffle off 733 | build up 734 | box in 735 | shake off 736 | cool off 737 | get on 738 | hit on 739 | straighten up 740 | start off 741 | belch out 742 | lie down 743 | play up 744 | give out 745 | haul in 746 | hard put 747 | make up 748 | snap off 749 | follow suit 750 | pass away 751 | smooth over 752 | hole up 753 | turn out 754 | clog up 755 | sober up 756 | smash up 757 | contract out 758 | go over 759 | dope up 760 | bed down 761 | sit out 762 | hype up 763 | drop in 764 | put off 765 | ward off 766 | get together 767 | turn down 768 | back off 769 | swoop up 770 | out trade 771 | size up 772 | pull off 773 | conjure up 774 | stock up 775 | sleep away 776 | monkey around 777 | break away 778 | pile up 779 | put in 780 | dream up 781 | wrap up 782 | gum up 783 | bound up 784 | tuck away 785 | board up 786 | have purpose 787 | stick out 788 | fall out 789 | take aback 790 | chart out 791 | latch on 792 | belt out 793 | wear on 794 | muck up 795 | step aside 796 | lead off 797 | point out 798 | line up 799 | check in 800 | start in 801 | bunch up 802 | watch over 803 | fill in 804 | work out 805 | joke around 806 | hum along 807 | lock down 808 | wear out 809 | rip out 810 | bleed out 811 | come along 812 | play off 813 | show off 814 | have extent 815 | concrete over 816 | narrow down 817 | jack up 818 | stare down 819 | pipe up 820 | loosen up 821 | wear down 822 | bear up 823 | 
cover over 824 | have polarity 825 | mic up 826 | make do 827 | close over 828 | deck out 829 | blow out 830 | play to 831 | hammer away 832 | ration out 833 | sell off 834 | have name 835 | strike out 836 | shuttle off 837 | call in 838 | shrug off 839 | chalk up 840 | perk up 841 | knock down 842 | follow up 843 | pass over 844 | brush off 845 | drink up 846 | fly out 847 | close in 848 | grow up 849 | eat away 850 | have condition 851 | snatch away 852 | pick off 853 | stress out 854 | take on 855 | muddle up 856 | tuck in 857 | live on 858 | skip off 859 | look forward 860 | stir up 861 | bail out 862 | stand down 863 | close up 864 | run over 865 | throw up 866 | fuck off 867 | swallow up 868 | spill out 869 | fall back 870 | fight off 871 | rig up 872 | sweat off 873 | hide out 874 | divvy up 875 | flash back 876 | end up 877 | make it 878 | toss in 879 | round out 880 | sniff out 881 | grind up 882 | chip in 883 | cough up 884 | phase in 885 | let up 886 | water down 887 | hold on 888 | level off 889 | have value 890 | fit in 891 | yammer on 892 | key in 893 | hold off 894 | silt up 895 | get by 896 | split up 897 | make out 898 | look after 899 | rubber stamp 900 | sketch out 901 | pull over 902 | spruce up 903 | glass over 904 | add up 905 | mist up 906 | brush up 907 | wind down 908 | clutch on 909 | knock back 910 | pare down 911 | rule out 912 | fall through 913 | hack away 914 | asphalt over 915 | clean up 916 | pound out 917 | die down 918 | carry out 919 | fall over 920 | blow up 921 | weasel out 922 | break even 923 | -------------------------------------------------------------------------------- /data/lemma_dict: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChunchuanLv/AMR_AS_GRAPH_PREDICTION/3375123c6b00bdfbe3395706769175073716b699/data/lemma_dict -------------------------------------------------------------------------------- /data/ner_dict: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChunchuanLv/AMR_AS_GRAPH_PREDICTION/3375123c6b00bdfbe3395706769175073716b699/data/ner_dict -------------------------------------------------------------------------------- /data/non_rule_set: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChunchuanLv/AMR_AS_GRAPH_PREDICTION/3375123c6b00bdfbe3395706769175073716b699/data/non_rule_set -------------------------------------------------------------------------------- /data/pos_dict: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChunchuanLv/AMR_AS_GRAPH_PREDICTION/3375123c6b00bdfbe3395706769175073716b699/data/pos_dict -------------------------------------------------------------------------------- /data/rel_dict: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChunchuanLv/AMR_AS_GRAPH_PREDICTION/3375123c6b00bdfbe3395706769175073716b699/data/rel_dict -------------------------------------------------------------------------------- /data/rule_f_without_jamr: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChunchuanLv/AMR_AS_GRAPH_PREDICTION/3375123c6b00bdfbe3395706769175073716b699/data/rule_f_without_jamr -------------------------------------------------------------------------------- /data/sensed_dict: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChunchuanLv/AMR_AS_GRAPH_PREDICTION/3375123c6b00bdfbe3395706769175073716b699/data/sensed_dict -------------------------------------------------------------------------------- /data/word_dict: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/ChunchuanLv/AMR_AS_GRAPH_PREDICTION/3375123c6b00bdfbe3395706769175073716b699/data/word_dict -------------------------------------------------------------------------------- /parser/DataIterator.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3.6 2 | # coding=utf-8 3 | ''' 4 | 5 | Iterating over data set 6 | 7 | @author: Chunchuan Lyu (chunchuan.lv@gmail.com) 8 | @since: 2018-05-30 9 | ''' 10 | from utility.constants import * 11 | from utility.data_helper import * 12 | import torch 13 | from torch.autograd import Variable 14 | import math 15 | from torch.nn.utils.rnn import PackedSequence 16 | from parser.modules.helper_module import MyPackedSequence 17 | from torch.nn.utils.rnn import pack_padded_sequence as pack 18 | import re 19 | end= re.compile(".txt\_[a-z]*") 20 | def rel_to_batch(rel_batch_p,rel_index_batch_p,data_iterator,dicts): 21 | lemma_dict,category_dict = dicts["lemma_dict"], dicts["category_dict"] 22 | 23 | data = [torch.LongTensor([[category_dict[uni.cat],lemma_dict[uni.le],0] for uni in uni_seq]) for uni_seq in rel_batch_p ] 24 | rel_index = [torch.LongTensor(index) for index in rel_index_batch_p] 25 | 26 | rel_batch,rel_index_batch,rel_lengths = data_iterator._batchify_rel_concept(data,rel_index) 27 | return MyPackedSequence(rel_batch,rel_lengths),rel_index_batch 28 | 29 | class DataIterator(object): 30 | 31 | def __init__(self, filePathes,opt,rel_dict,volatile = False ,all_data = None): 32 | self.cuda = opt.gpus[0] != -1 33 | self.volatile = volatile 34 | self.rel_dict = rel_dict 35 | self.all = [] 36 | self.opt = opt 37 | # break 38 | 39 | # self.all = sorted(self.all, key=lambda x: x[0]) 40 | self.src = [] 41 | self.tgt = [] 42 | self.align_index = [] 43 | self.rel_seq = [] 44 | self.rel_index = [] 45 | self.rel_mat = [] 46 | self.root = [] 47 | self.src_source = [] 48 | self.tgt_source = [] 49 | self.rel_tgt = [] 50 | if all_data: 51 | for data in 
all_data: 52 | self.read_sentence(data) 53 | self.batchSize = len(all_data) 54 | self.numBatches = 1 55 | else: 56 | 57 | for filepath in filePathes: 58 | n = self.readFile(filepath) 59 | self.batchSize = opt.batch_size 60 | self.numBatches = math.ceil(len(self.src)/self.batchSize) 61 | 62 | self.source_only = len(self.root) == 0 63 | 64 | def read_sentence(self,data): 65 | def role_mat_to_sparse(role_mat,rel_dict): 66 | index =[] 67 | value = [] 68 | for i,role_list in enumerate(role_mat): 69 | for role_index in role_list: 70 | if role_index[0] in rel_dict: 71 | index.append([i,role_index[1]]) 72 | value.append(rel_dict[role_index[0]]) 73 | size = torch.Size([len(role_mat),len(role_mat)]) 74 | v = torch.LongTensor(value) 75 | if len(v) == 0: 76 | i = torch.LongTensor([[0,0]]).t() 77 | v = torch.LongTensor([0]) 78 | return torch.sparse.LongTensor(i,v,size) 79 | 80 | i = torch.LongTensor(index).t() 81 | return torch.sparse.LongTensor(i,v,size) 82 | 83 | #src: length x n_feature 84 | 85 | self.src.append(torch.LongTensor([data["snt_id"],data["lemma_id"],data["pos_id"],data["ner_id"]]).t().contiguous()) 86 | 87 | #source 88 | 89 | self.src_source.append([data["tok"],data["lem"],data["pos"],data["ner"]]) 90 | 91 | #tgt: length x n_feature 92 | # print (data["amr_id"]) 93 | if "amr_id" in data: 94 | self.tgt.append(torch.LongTensor(data["amr_id"])) # lemma,cat, lemma_sense,ner,is_high 95 | self.align_index.append(data["index"]) 96 | 97 | amrl = len(data["amr_id"]) 98 | for i in data["amr_rel_index"]: 99 | assert i 0, (data,rel_index) 151 | second = max([x.size(1) for x in data]) 152 | total = sum(lengths) 153 | out = data[0].new(total, second) 154 | out_index = [] 155 | current = 0 156 | for i in range(len(data)): 157 | data_t = data[i].clone() 158 | out.narrow(0, current, lengths[i]).copy_(data_t) 159 | index_t = rel_index[i].clone() 160 | if self.cuda: 161 | index_t = index_t.cuda() 162 | out_index.append(Variable(index_t,volatile=self.volatile,requires_grad = False)) 
163 | # out_index.append(index_t) 164 | current += lengths[i] 165 | out = Variable(out,volatile=self.volatile,requires_grad = False) 166 | 167 | if self.cuda: 168 | out = out.cuda() 169 | return out,out_index,lengths 170 | 171 | 172 | #rel_mat: batch_size x var(len) x var(len) 173 | #rel_index: batch_size x var(len) 174 | 175 | #out : (batch_size x var(len) x var(len)) 176 | def _batchify_rel_roles(self, all_data ): 177 | length_squares = [x.size(0)**2 for x in all_data] 178 | total = sum(length_squares) 179 | out = torch.LongTensor(total) 180 | current = 0 181 | for i in range(len(all_data)): 182 | data_t = all_data[i].to_dense().clone().view(-1) 183 | out.narrow(0, current, length_squares[i]).copy_(data_t) 184 | current += length_squares[i] 185 | 186 | out = Variable(out,volatile=self.volatile,requires_grad = False) 187 | if self.cuda: 188 | out = out.cuda() 189 | 190 | return out,length_squares 191 | 192 | 193 | #data: batch_size x var(len) x n_feature 194 | #out : batch_size x tgt_len x n_feature 195 | def _batchify_tgt(self, data,max_src ): 196 | lengths = [x.size(0) for x in data] 197 | max_length = max(max(x.size(0) for x in data),max_src) #if y, we need max_x 198 | out = data[0].new(len(data), max_length,data[0].size(1)).fill_(PAD) 199 | for i in range(len(data)): 200 | data_t = data[i].clone() 201 | data_length = data[i].size(0) 202 | out[i].narrow(0, 0, data_length).copy_(data_t) 203 | return out 204 | 205 | #data: batch_size x var(len) x n_feature 206 | #out : batch_size x src_len x n_feature 207 | def _batchify_src(self, data,max_length ): 208 | out = data[0].new(len(data), max_length,data[0].size(1)).fill_(PAD) 209 | 210 | for i in range(len(data)): 211 | data_t = data[i].clone() 212 | data_length = data[i].size(0) 213 | out[i].narrow(0, 0, data_length).copy_(data_t) 214 | 215 | return out 216 | 217 | def getLengths(self,index): 218 | src_data = self.src[index*self.batchSize:(index+1)*self.batchSize] 219 | src_lengths = [x.size(0) for x in src_data] 
220 | if self.source_only: 221 | return src_lengths,max(src_lengths) 222 | 223 | tgt_data = self.tgt[index*self.batchSize:(index+1)*self.batchSize] 224 | tgt_lengths = [x.size(0) for x in tgt_data] 225 | lengths = [] 226 | for i,l in enumerate(src_lengths): 227 | lengths.append(max(l,tgt_lengths[i])) 228 | return lengths,max(lengths) 229 | 230 | def __getitem__(self, index): 231 | assert index < self.numBatches, "%d > %d" % (index, self.numBatches) 232 | lengths,max_len = self.getLengths(index ) 233 | def wrap(b,l ): 234 | #batch, len, feature 235 | if b is None: 236 | return b 237 | b = torch.stack(b, 0).transpose(0,1).contiguous() 238 | if self.cuda: 239 | b = b.cuda() 240 | packed = pack(b,list(l)) 241 | return PackedSequence(Variable(packed[0], volatile=self.volatile,requires_grad = False),packed[1]) 242 | 243 | def wrap_align(b,l ): 244 | #batch, len_tgt, len_src 245 | if b is None: 246 | return b 247 | b = torch.stack(b, 0).transpose(0,1).contiguous().float() 248 | if self.cuda: 249 | b = b.cuda() 250 | packed = pack(b,list(l)) 251 | return PackedSequence(Variable(packed[0], volatile=self.volatile,requires_grad = False),packed[1]) 252 | 253 | srcBatch = self._batchify_src( 254 | self.src[index*self.batchSize:(index+1)*self.batchSize],max_len) 255 | 256 | if self.source_only: 257 | src_sourceBatch = self.src_source[index*self.batchSize:(index+1)*self.batchSize] 258 | 259 | batch = zip( srcBatch,src_sourceBatch) 260 | lengths,max_len = self.getLengths(index ) 261 | order_data = sorted(list(enumerate(list(zip(batch, lengths)))),key = lambda x:-x[1][1]) 262 | order,data = zip(*order_data) 263 | batch, lengths = zip(*data) 264 | srcBatch,src_sourceBatch = zip(*batch) 265 | return order,wrap(srcBatch,lengths),src_sourceBatch 266 | 267 | else: 268 | tgtBatch = self._batchify_tgt( 269 | self.tgt[index*self.batchSize:(index+1)*self.batchSize],max_len) 270 | alignBatch = self._batchify_align( 271 | 
self.align_index[index*self.batchSize:(index+1)*self.batchSize],max_len) 272 | 273 | rel_seq_pre = self.rel_seq[index*self.batchSize:(index+1)*self.batchSize] 274 | rel_index_pre = self.rel_index[index*self.batchSize:(index+1)*self.batchSize] 275 | rel_role_pre = self.rel_mat[index*self.batchSize:(index+1)*self.batchSize] 276 | 277 | # roots = Variable(torch.IntTensor(self.root[index*self.batchSize:(index+1)*self.batchSize]),volatile = True) 278 | roots =self.root[index*self.batchSize:(index+1)*self.batchSize] 279 | 280 | src_sourceBatch = self.src_source[index*self.batchSize:(index+1)*self.batchSize] 281 | tgt_sourceBatch = self.tgt_source[index*self.batchSize:(index+1)*self.batchSize] 282 | sourceBatch = [ src_s +tgt_s for src_s,tgt_s in zip(src_sourceBatch,tgt_sourceBatch)] 283 | # within batch sorting by decreasing length for variable length rnns 284 | indices = range(len(srcBatch)) 285 | 286 | batch = zip(indices, srcBatch ,tgtBatch,alignBatch,rel_seq_pre,rel_index_pre,rel_role_pre,sourceBatch,roots) 287 | order_data = sorted(list(enumerate(list(zip(batch, lengths)))),key = lambda x:-x[1][1]) 288 | order,data = zip(*order_data) 289 | batch, lengths = zip(*data) 290 | indices, srcBatch,tgtBatch,alignBatch ,rel_seq_pre,rel_index_pre,rel_role_pre,sourceBatch,roots= zip(*batch) 291 | 292 | rel_batch,rel_index_batch,rel_lengths = self._batchify_rel_concept(rel_seq_pre,rel_index_pre) 293 | rel_roles,length_squares = self._batchify_rel_roles(rel_role_pre) 294 | 295 | 296 | #,wrap(charBatch)) 297 | return order,wrap(srcBatch,lengths), wrap(tgtBatch,lengths), wrap_align(alignBatch,lengths),\ 298 | MyPackedSequence(rel_batch,rel_lengths),rel_index_batch,MyPackedSequence(rel_roles,length_squares),roots,sourceBatch 299 | 300 | def __len__(self): 301 | return self.numBatches 302 | 303 | 304 | def shuffle(self): 305 | # if True: return 306 | if self.source_only: #if data set if for testing 307 | data = list(zip(self.src,self.src_source)) 308 | self.src,self.src_source = 
zip(*[data[i] for i in torch.randperm(len(data))]) 309 | else: 310 | data = list(zip(self.src, self.tgt,self.align_index,self.rel_seq,self.rel_index,self.rel_mat,self.root,self.src_source,self.tgt_source)) 311 | self.src, self.tgt,self.align_index,self.rel_seq,self.rel_index,self.rel_mat,self.root,self.src_source,self.tgt_source = zip(*[data[i] for i in torch.randperm(len(data))]) 312 | 313 | -------------------------------------------------------------------------------- /parser/Dict.py: -------------------------------------------------------------------------------- 1 | from utility.amr import * 2 | from utility.data_helper import * 3 | import torch 4 | 5 | def seq_to_id(dictionary,seq): 6 | id_seq = [] 7 | freq_seq = [] 8 | for i in seq: 9 | id_seq.append(dictionary[i]) 10 | freq_seq.append(dictionary.frequencies[dictionary[i]]) 11 | return id_seq,freq_seq 12 | 13 | 14 | 15 | def read_dicts(): 16 | 17 | word_dict = Dict("data/word_dict") 18 | lemma_dict = Dict("data/lemma_dict") 19 | aux_dict = Dict("data/aux_dict") 20 | high_dict = Dict("data/high_dict") 21 | pos_dict = Dict("data/pos_dict") 22 | ner_dict = Dict("data/ner_dict") 23 | rel_dict = Dict("data/rel_dict") 24 | category_dict = Dict("data/category_dict") 25 | 26 | word_dict.load() 27 | lemma_dict.load() 28 | pos_dict.load() 29 | ner_dict.load() 30 | rel_dict.load() 31 | category_dict.load() 32 | high_dict.load() 33 | aux_dict.load() 34 | dicts = dict() 35 | 36 | dicts["rel_dict"] = rel_dict 37 | dicts["word_dict"] = word_dict 38 | dicts["pos_dict"] = pos_dict 39 | dicts["ner_dict"] = ner_dict 40 | dicts["lemma_dict"] = lemma_dict 41 | dicts["category_dict"] = category_dict 42 | dicts["aux_dict"] = aux_dict 43 | dicts["high_dict"] = high_dict 44 | return dicts 45 | 46 | class Dict(object): 47 | def __init__(self, fileName,dictionary=None): 48 | self.idxToLabel = {} 49 | self.labelToIdx = {} 50 | self.frequencies = {} 51 | 52 | # Special entries will not be pruned. 
53 | self.special = [] 54 | 55 | if dictionary : 56 | for label in dictionary: 57 | self.labelToIdx[label] = dictionary[label][0] 58 | self.idxToLabel[dictionary[label][0]] = label 59 | self.frequencies[dictionary[label][0]] = dictionary[label][1] 60 | self.fileName = fileName 61 | 62 | 63 | 64 | def size(self): 65 | return len(self.idxToLabel) 66 | 67 | def __len__(self): 68 | return len(self.idxToLabel) 69 | 70 | # Load entries from a file. 71 | def load(self, filename =None): 72 | if filename: 73 | self.fileName = filename 74 | else: 75 | filename = self.fileName 76 | f = Pickle_Helper(filename) 77 | data = f.load() 78 | self.idxToLabel=data["idxToLabel"] 79 | self.labelToIdx=data["labelToIdx"] 80 | self.frequencies=data["frequencies"] 81 | 82 | # Write entries to a file. 83 | def save(self, filename =None): 84 | if filename: 85 | self.fileName = filename 86 | else: 87 | filename = self.fileName 88 | f = Pickle_Helper(filename) 89 | f.dump( self.idxToLabel,"idxToLabel") 90 | f.dump( self.labelToIdx,"labelToIdx") 91 | f.dump( self.frequencies,"frequencies") 92 | f.save() 93 | 94 | def lookup(self, key, default=None): 95 | try: 96 | return self.labelToIdx[key] 97 | except KeyError: 98 | if default: return default 99 | 100 | return self.labelToIdx[UNK_WORD] 101 | def __str__(self): 102 | out_str = [] 103 | for k in self.frequencies: 104 | if k not in self.special: 105 | out_str.append(self.idxToLabel[k]+": "+str(self.frequencies[k])) 106 | return " \n".join(out_str) 107 | def __getitem__(self, label,default=None): 108 | try: 109 | return self.labelToIdx[label] 110 | except KeyError: 111 | if default: return default 112 | 113 | return self.labelToIdx[UNK_WORD] 114 | 115 | def getLabel(self, idx, default=UNK_WORD): 116 | try: 117 | return self.idxToLabel[idx] 118 | except KeyError: 119 | return default 120 | 121 | def __iter__(self): return self.labelToIdx.__iter__() 122 | def __next__(self): return self.labelToIdx.__next__() 123 | # Mark this `label` and `idx` as 
special (i.e. will not be pruned). 124 | def addSpecial(self, label, idx=None): 125 | idx = self.add(label, idx) 126 | self.special += [idx] 127 | 128 | # Mark all labels in `labels` as specials (i.e. will not be pruned). 129 | def addSpecials(self, labels): 130 | for label in labels: 131 | self.addSpecial(label) 132 | 133 | # Add `label` in the dictionary. Use `idx` as its index if given. 134 | def add(self, label, idx=None): 135 | if idx is not None: 136 | self.idxToLabel[idx] = label 137 | self.labelToIdx[label] = idx 138 | else: 139 | if label in self.labelToIdx: 140 | idx = self.labelToIdx[label] 141 | else: 142 | idx = len(self.idxToLabel) 143 | self.idxToLabel[idx] = label 144 | self.labelToIdx[label] = idx 145 | 146 | if idx not in self.frequencies: 147 | self.frequencies[idx] = 1 148 | else: 149 | self.frequencies[idx] += 1 150 | 151 | return idx 152 | 153 | def __setitem__(self, label, idx): 154 | self.add(label,idx) 155 | 156 | 157 | # Return a new dictionary with the `size` most frequent entries. 158 | def prune(self, size): 159 | if size >= self.size(): 160 | return self 161 | 162 | # Only keep the `size` most frequent entries. 163 | freq = torch.Tensor( 164 | [self.frequencies[i] for i in range(len(self.frequencies))]) 165 | _, idx = torch.sort(freq, 0, False) 166 | 167 | newDict = Dict(self.fileName) 168 | 169 | # Add special entries in all cases. 170 | for i in self.special: 171 | newDict.addSpecial(self.idxToLabel[i]) 172 | 173 | for i in idx[:size]: 174 | newDict.add(self.idxToLabel[i]) 175 | 176 | return newDict 177 | # Return a new dictionary with the `size` most frequent entries. 178 | def pruneByThreshold(self, threshold): 179 | # Only keep the `size` most frequent entries. 180 | high_freq = [ (self.frequencies[i],i) for i in range(len(self.frequencies)) if self.frequencies[i]>threshold] 181 | 182 | newDict = Dict(self.fileName) 183 | 184 | # Add special entries in all cases. 
185 | for i in self.special: 186 | newDict.addSpecial(self.idxToLabel[i]) 187 | 188 | for freq,i in high_freq: 189 | newDict.add(self.idxToLabel[i]) 190 | newDict.frequencies[newDict.labelToIdx[self.idxToLabel[i]]] = freq 191 | 192 | return newDict 193 | # Convert `labels` to indices. Use `unkWord` if not found. 194 | # Optionally insert `bosWord` at the beginning and `eosWord` at the . 195 | def convertToIdx(self, labels, unkWord = UNK_WORD, bosWord=BOS_WORD, eosWord=EOS_WORD): 196 | vec = [] 197 | 198 | if bosWord is not None: 199 | vec += [self.lookup(bosWord)] 200 | 201 | unk = self.lookup(unkWord) 202 | vec += [self.lookup(label, default=unk) for label in labels] 203 | 204 | if eosWord is not None: 205 | vec += [self.lookup(eosWord)] 206 | 207 | return torch.LongTensor(vec) 208 | 209 | # Convert `idx` to labels. If index `stop` is reached, convert it and return. 210 | def convertToLabels(self, idx, stop=[]): 211 | labels = [] 212 | 213 | for i in idx: 214 | if i in stop: 215 | break 216 | labels += [self.getLabel(i)] 217 | 218 | return labels -------------------------------------------------------------------------------- /parser/Optim.py: -------------------------------------------------------------------------------- 1 | import math,torch 2 | import torch.optim as optim 3 | import numpy as np 4 | class Optim(object): 5 | 6 | def _makeOptimizer(self): 7 | if self.method == 'sgd': 8 | self.optimizer = optim.SGD(self.params, lr=self.lr,weight_decay = 0) 9 | elif self.method == 'adagrad': 10 | self.optimizer = optim.Adagrad(self.params, lr=self.lr,weight_decay = 0) 11 | elif self.method == 'adadelta': 12 | self.optimizer = optim.Adadelta(self.params, lr=self.lr,weight_decay = 0) 13 | elif self.method == 'adam': 14 | self.optimizer = optim.Adam(self.params, betas=[0.9,0.9],lr=self.lr,weight_decay = 0) 15 | elif self.method == "RMSprop": 16 | self.optimizer = optim.RMSprop(self.params, lr=self.lr, weight_decay=0) 17 | 18 | else: 19 | raise RuntimeError("Invalid 
optim method: " + self.method) 20 | 21 | def __init__(self, params, method, lr, max_grad_norm, lr_decay=1, start_decay_at=None, weight_decay=0,perturb = 0): 22 | self.params = list(params) # careful: params may be a generator 23 | self.last_ppl = None 24 | self.lr = lr 25 | self.max_grad_norm = max_grad_norm 26 | self.method = method 27 | self.lr_decay = lr_decay 28 | self.start_decay_at = start_decay_at 29 | self.start_decay = False 30 | self.weight_decay = weight_decay 31 | self.weight_shirnk = 1.0 -weight_decay 32 | self._makeOptimizer() 33 | 34 | def step(self): 35 | # Compute gradients norm. 36 | grad_norm = 0 37 | for param in self.params: 38 | grad_norm += math.pow(param.grad.data.norm(), 2) 39 | 40 | grad_norm = math.sqrt(grad_norm) 41 | shrinkage = self.max_grad_norm / grad_norm 42 | nan_size = [] 43 | fine = [] 44 | for param in self.params: 45 | if shrinkage < 1: 46 | param.grad.data.mul_(shrinkage) 47 | # assert not np.isnan(np.sum(param.data.cpu().numpy())),("befotr optim\n",param) 48 | # if np.isnan(np.sum(param.grad.data.cpu().numpy())): 49 | # nan_size.append(param.grad.size()) 50 | # else: fine.append(param.grad.size()) 51 | if len(nan_size) > 0: 52 | print ("befotr optim grad explodes, abandon update, still weight_decay\n",fine) 53 | self.optimizer.step() 54 | for param in self.params: 55 | assert not np.isnan(np.sum(param.data.cpu().numpy())),("befotr shrink\n",param) 56 | param.data.mul_(self.weight_shirnk) #+ torch.normal(0,1e-3*torch.ones(param.data.size()).cuda()) 57 | assert not np.isnan(np.sum(param.data.cpu().numpy())),("after shrink\n",param) 58 | return grad_norm 59 | 60 | # decay learning rate if val perf does not improve or we hit the start_decay_at limit 61 | def updateLearningRate(self, ppl, epoch): 62 | if self.start_decay_at is not None and epoch >= self.start_decay_at: 63 | self.start_decay = True 64 | if self.last_ppl is not None and ppl > self.last_ppl: 65 | self.start_decay = True 66 | 67 | if self.start_decay: 68 | self.lr = 
self.lr * self.lr_decay 69 | print("Decaying learning rate to %g" % self.lr) 70 | 71 | self.last_ppl = ppl 72 | 73 | self._makeOptimizer() 74 | 75 | -------------------------------------------------------------------------------- /parser/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | import parser.models 3 | import parser.Optim 4 | import parser.AMRProcessors -------------------------------------------------------------------------------- /parser/__pycache__/AMRProcessors.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChunchuanLv/AMR_AS_GRAPH_PREDICTION/3375123c6b00bdfbe3395706769175073716b699/parser/__pycache__/AMRProcessors.cpython-36.pyc -------------------------------------------------------------------------------- /parser/__pycache__/DataIterator.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChunchuanLv/AMR_AS_GRAPH_PREDICTION/3375123c6b00bdfbe3395706769175073716b699/parser/__pycache__/DataIterator.cpython-36.pyc -------------------------------------------------------------------------------- /parser/__pycache__/Dict.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChunchuanLv/AMR_AS_GRAPH_PREDICTION/3375123c6b00bdfbe3395706769175073716b699/parser/__pycache__/Dict.cpython-36.pyc -------------------------------------------------------------------------------- /parser/__pycache__/Optim.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChunchuanLv/AMR_AS_GRAPH_PREDICTION/3375123c6b00bdfbe3395706769175073716b699/parser/__pycache__/Optim.cpython-36.pyc -------------------------------------------------------------------------------- /parser/__pycache__/__init__.cpython-36.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChunchuanLv/AMR_AS_GRAPH_PREDICTION/3375123c6b00bdfbe3395706769175073716b699/parser/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /parser/models/ConceptModel.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3.6 2 | # coding=utf-8 3 | ''' 4 | 5 | Deep Learning Models for concept identification 6 | 7 | @author: Chunchuan Lyu (chunchuan.lv@gmail.com) 8 | @since: 2018-05-30 9 | ''' 10 | 11 | import torch 12 | import torch.nn as nn 13 | from parser.modules.helper_module import data_dropout 14 | from torch.nn.utils.rnn import PackedSequence 15 | from utility.constants import * 16 | 17 | 18 | class SentenceEncoder(nn.Module): 19 | def __init__(self, opt, embs): 20 | self.layers = opt.txt_enlayers 21 | self.num_directions = 2 if opt.brnn else 1 22 | assert opt.txt_rnn_size % self.num_directions == 0 23 | self.hidden_size = opt.txt_rnn_size // self.num_directions 24 | # inputSize = opt.word_dim*2 + opt.lemma_dim + opt.pos_dim +opt.ner_dim 25 | inputSize = embs["word_fix_lut"].embedding_dim + embs["lemma_lut"].embedding_dim\ 26 | +embs["pos_lut"].embedding_dim + embs["ner_lut"].embedding_dim 27 | 28 | super(SentenceEncoder, self).__init__() 29 | self.rnn = nn.LSTM(inputSize, self.hidden_size, 30 | num_layers=self.layers, 31 | dropout=opt.dropout, 32 | bidirectional=opt.brnn) 33 | 34 | 35 | self.lemma_lut = embs["lemma_lut"] 36 | 37 | self.word_fix_lut = embs["word_fix_lut"] 38 | 39 | 40 | self.pos_lut = embs["pos_lut"] 41 | 42 | self.ner_lut = embs["ner_lut"] 43 | 44 | self.drop_emb = nn.Dropout(opt.dropout) 45 | self.alpha = opt.alpha 46 | 47 | if opt.cuda: 48 | self.rnn.cuda() 49 | 50 | def forward(self, packed_input: PackedSequence,hidden=None): 51 | #input: pack(data x n_feature ,batch_size) 52 | input = packed_input.data 53 | 
if self.alpha and self.training: 54 | input = data_dropout(input,self.alpha) 55 | 56 | word_fix_embed = self.word_fix_lut(input[:,TXT_WORD]) 57 | lemma_emb = self.lemma_lut(input[:,TXT_LEMMA]) 58 | pos_emb = self.pos_lut(input[:,TXT_POS]) 59 | ner_emb = self.ner_lut(input[:,TXT_NER]) 60 | 61 | 62 | emb = self.drop_emb(torch.cat([lemma_emb,pos_emb,ner_emb],1))# data,embed 63 | emb = torch.cat([word_fix_embed,emb],1)# data,embed 64 | emb = PackedSequence(emb, packed_input.batch_sizes) 65 | outputs, hidden_t = self.rnn(emb, hidden) 66 | return outputs 67 | 68 | class Concept_Classifier(nn.Module): 69 | 70 | def __init__(self, opt, embs): 71 | super(Concept_Classifier, self).__init__() 72 | self.txt_rnn_size = opt.txt_rnn_size 73 | 74 | self.n_cat = embs["cat_lut"].num_embeddings 75 | self.n_high = embs["high_lut"].num_embeddings 76 | self.n_aux = embs["aux_lut"].num_embeddings 77 | 78 | self.cat_score =nn.Sequential( 79 | nn.Dropout(opt.dropout), 80 | nn.Linear(self.txt_rnn_size,self.n_cat,bias = opt.cat_bias)) 81 | 82 | self.le_score =nn.Sequential( 83 | nn.Dropout(opt.dropout), 84 | nn.Linear(self.txt_rnn_size,self.n_high+1,bias = opt.lemma_bias)) 85 | 86 | self.ner_score =nn.Sequential( 87 | nn.Dropout(opt.dropout), 88 | nn.Linear(self.txt_rnn_size,self.n_aux,bias = opt.cat_bias)) 89 | 90 | self.t = 1 91 | self.sm = nn.Softmax() 92 | if opt.cuda: 93 | self.cuda() 94 | 95 | 96 | 97 | def forward(self, src_enc ): 98 | ''' 99 | src_enc: pack(data x txt_rnn_size ,batch_size) 100 | src_le: pack(data x 1 ,batch_size) 101 | 102 | out: (datax n_cat, batch_size), (data x n_high+1,batch_size) 103 | ''' 104 | 105 | assert isinstance(src_enc,PackedSequence) 106 | 107 | 108 | # high_embs = self.high_lut.weight.expand(le_score.size(0),self.n_high,self.dim) 109 | # le_self_embs = self.lemma_lut(src_le.data).unsqueeze(1) 110 | # le_emb = torch.cat([high_embs,le_self_embs],dim=1) #data x high+1 x dim 111 | 112 | pre_enc =src_enc.data 113 | 114 | cat_score = self.cat_score(pre_enc) 
# n_data x n_cat 115 | ner_score = self.ner_score(pre_enc)# n_data x n_cat 116 | le_score = self.le_score (src_enc.data) 117 | le_prob = self.sm(le_score) 118 | cat_prob = self.sm(cat_score) 119 | ner_prob = self.sm(ner_score) 120 | batch_sizes = src_enc.batch_sizes 121 | return PackedSequence(cat_prob,batch_sizes),PackedSequence(le_prob,batch_sizes),PackedSequence(ner_prob,batch_sizes) 122 | 123 | class ConceptIdentifier(nn.Module): 124 | #could share encoder with other model 125 | def __init__(self, opt,embs,encoder = None): 126 | super(ConceptIdentifier, self).__init__() 127 | if encoder: 128 | self.encoder = encoder 129 | else: 130 | self.encoder = SentenceEncoder( opt, embs) 131 | self.generator = Concept_Classifier( opt, embs) 132 | 133 | 134 | def forward(self, srcBatch): 135 | src_enc = self.encoder(srcBatch) 136 | probBatch = self.generator(src_enc) 137 | return probBatch,src_enc 138 | -------------------------------------------------------------------------------- /parser/models/MultiPassRelModel.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3.6 2 | # coding=utf-8 3 | ''' 4 | 5 | Deep Learning Models for relation identification 6 | 7 | @author: Chunchuan Lyu (chunchuan.lv@gmail.com) 8 | @since: 2018-05-30 9 | ''' 10 | 11 | import torch 12 | import torch.nn as nn 13 | from torch.autograd import Variable 14 | from parser.modules.helper_module import mypack ,myunpack,MyPackedSequence,MyDoublePackedSequence,mydoubleunpack,mydoublepack,DoublePackedSequence,doubleunpack,data_dropout 15 | from torch.nn.utils.rnn import pad_packed_sequence as unpack 16 | from torch.nn.utils.rnn import pack_padded_sequence as pack 17 | from torch.nn.utils.rnn import PackedSequence 18 | import torch.nn.functional as F 19 | from utility.constants import * 20 | 21 | 22 | 23 | #sentence encoder for root identification 24 | class RootSentenceEncoder(nn.Module): 25 | 26 | def __init__(self, opt, embs): 27 | self.layers = 
opt.root_enlayers 28 | self.num_directions = 2 if opt.brnn else 1 29 | assert opt.txt_rnn_size % self.num_directions == 0 30 | self.hidden_size = opt.rel_rnn_size // self.num_directions 31 | inputSize = embs["word_fix_lut"].embedding_dim + embs["lemma_lut"].embedding_dim\ 32 | +embs["pos_lut"].embedding_dim+embs["ner_lut"].embedding_dim 33 | 34 | super(RootSentenceEncoder, self).__init__() 35 | 36 | 37 | self.rnn =nn.LSTM(inputSize, self.hidden_size, 38 | num_layers=self.layers, 39 | dropout=opt.dropout, 40 | bidirectional=opt.brnn, 41 | batch_first=True) 42 | 43 | self.lemma_lut = embs["lemma_lut"] 44 | 45 | self.word_fix_lut = embs["word_fix_lut"] 46 | 47 | self.pos_lut = embs["pos_lut"] 48 | 49 | 50 | self.ner_lut = embs["ner_lut"] 51 | 52 | self.alpha = opt.alpha 53 | if opt.cuda: 54 | self.rnn.cuda() 55 | 56 | 57 | 58 | def forward(self, packed_input,hidden=None): 59 | #input: pack(data x n_feature ,batch_size) 60 | #posterior: pack(data x src_len ,batch_size) 61 | assert isinstance(packed_input,PackedSequence) 62 | input = packed_input.data 63 | 64 | if self.alpha and self.training: 65 | input = data_dropout(input,self.alpha) 66 | 67 | word_fix_embed = self.word_fix_lut(input[:,TXT_WORD]) 68 | lemma_emb = self.lemma_lut(input[:,TXT_LEMMA]) 69 | pos_emb = self.pos_lut(input[:,TXT_POS]) 70 | ner_emb = self.ner_lut(input[:,TXT_NER]) 71 | 72 | emb = torch.cat([word_fix_embed,lemma_emb,pos_emb,ner_emb],1)# data,embed 73 | 74 | emb = PackedSequence(emb, packed_input.batch_sizes) 75 | 76 | outputs = self.rnn(emb, hidden)[0] 77 | 78 | return outputs 79 | 80 | #combine amr node embedding and aligned sentence token embedding 81 | class RootEncoder(nn.Module): 82 | 83 | def __init__(self, opt, embs): 84 | self.layers = opt.amr_enlayers 85 | #share hyper parameter with relation model 86 | self.size = opt.rel_dim 87 | inputSize = embs["cat_lut"].embedding_dim + embs["lemma_lut"].embedding_dim+opt.rel_rnn_size 88 | super(RootEncoder, self).__init__() 89 | 90 | self.cat_lut 
= embs["cat_lut"] 91 | 92 | self.lemma_lut = embs["lemma_lut"] 93 | 94 | self.root = nn.Sequential( 95 | nn.Dropout(opt.dropout), 96 | nn.Linear(inputSize,self.size ), 97 | nn.ReLU() 98 | ) 99 | 100 | 101 | self.alpha = opt.alpha 102 | if opt.cuda: 103 | self.cuda() 104 | 105 | def getEmb(self,indexes,src_enc): 106 | head_emb,lengths = [],[] 107 | src_enc = myunpack(*src_enc) # pre_amr_l/src_l x batch x dim 108 | for i, index in enumerate(indexes): 109 | enc = src_enc[i] #src_l x dim 110 | head_emb.append(enc[index]) #var(amr_l x dim) 111 | lengths.append(len(index)) 112 | return mypack(head_emb,lengths) 113 | 114 | #input: all_data x n_feature, lengths 115 | #index: batch_size x var(amr_len) 116 | #src_enc (batch x amr_len) x src_len x txt_rnn_size 117 | 118 | #head: batch x var( amr_len x txt_rnn_size ) 119 | 120 | #dep : batch x var( amr_len x amr_len x txt_rnn_size ) 121 | 122 | #heads: [var(len),rel_dim] 123 | #deps: [var(len)**2,rel_dim] 124 | def forward(self, input, index,src_enc): 125 | assert isinstance(input, MyPackedSequence),input 126 | input,lengths = input 127 | if self.alpha and self.training: 128 | input = data_dropout(input,self.alpha) 129 | cat_embed = self.cat_lut(input[:,AMR_CAT]) 130 | lemma_embed = self.lemma_lut(input[:,AMR_LE]) 131 | 132 | amr_emb = torch.cat([cat_embed,lemma_embed],1) 133 | # print (input,lengths) 134 | 135 | head_emb = self.getEmb(index,src_enc) #packed, mydoublepacked 136 | 137 | 138 | root_emb = torch.cat([amr_emb,head_emb.data],1) 139 | root_emb = self.root(root_emb) 140 | 141 | return MyPackedSequence(root_emb,lengths) 142 | 143 | #multi pass sentence encoder for relation identification 144 | class RelSentenceEncoder(nn.Module): 145 | 146 | def __init__(self, opt, embs): 147 | self.layers = opt.rel_enlayers 148 | self.num_directions = 2 if opt.brnn else 1 149 | assert opt.txt_rnn_size % self.num_directions == 0 150 | self.hidden_size = opt.rel_rnn_size // self.num_directions 151 | inputSize = 
embs["word_fix_lut"].embedding_dim + embs["lemma_lut"].embedding_dim\ 152 | +embs["pos_lut"].embedding_dim+embs["ner_lut"].embedding_dim+1 153 | super(RelSentenceEncoder, self).__init__() 154 | 155 | 156 | self.rnn =nn.LSTM(inputSize, self.hidden_size, 157 | num_layers=self.layers, 158 | dropout=opt.dropout, 159 | bidirectional=opt.brnn, 160 | batch_first=True) #first is for root 161 | 162 | self.lemma_lut = embs["lemma_lut"] 163 | 164 | self.word_fix_lut = embs["word_fix_lut"] 165 | 166 | self.pos_lut = embs["pos_lut"] 167 | 168 | 169 | self.ner_lut = embs["ner_lut"] 170 | 171 | self.alpha = opt.alpha 172 | if opt.cuda: 173 | self.rnn.cuda() 174 | 175 | def posteriorIndictedEmb(self,embs,posterior): 176 | #real alignment is sent in as list of index 177 | #variational relaxed posterior is sent in as MyPackedSequence 178 | 179 | #out (batch x amr_len) x src_len x (dim+1) 180 | embs,src_len = unpack(embs) 181 | 182 | if isinstance(posterior,MyPackedSequence): 183 | # print ("posterior is packed") 184 | posterior = myunpack(*posterior) 185 | embs = embs.transpose(0,1) 186 | out = [] 187 | lengths = [] 188 | amr_len = [len(p) for p in posterior] 189 | for i,emb in enumerate(embs): 190 | expanded_emb = emb.unsqueeze(0).expand([amr_len[i]]+[i for i in emb.size()]) # amr_len x src_len x dim 191 | indicator = posterior[i].unsqueeze(2) # amr_len x src_len x 1 192 | out.append(torch.cat([expanded_emb,indicator],2)) # amr_len x src_len x (dim+1) 193 | lengths = lengths + [src_len[i]]*amr_len[i] 194 | data = torch.cat(out,dim=0) 195 | 196 | return pack(data,lengths,batch_first=True),amr_len 197 | elif isinstance(posterior,list): 198 | embs = embs.transpose(0,1) 199 | src_l = embs.size(1) 200 | amr_len = [len(i) for i in posterior] 201 | out = [] 202 | lengths = [] 203 | for i,emb in enumerate(embs): 204 | amr_l = len(posterior[i]) 205 | expanded_emb = emb.unsqueeze(0).expand([amr_l]+[i for i in emb.size()]) # amr_len x src_len x dim 206 | indicator = 
emb.data.new(amr_l,src_l).zero_() 207 | indicator.scatter_(1, posterior[i].data.unsqueeze(1), 1.0) # amr_len x src_len x 1 208 | indicator = Variable(indicator.unsqueeze(2)) 209 | out.append(torch.cat([expanded_emb,indicator],2)) # amr_len x src_len x (dim+1) 210 | lengths = lengths + [src_len[i]]*amr_l 211 | data = torch.cat(out,dim=0) 212 | 213 | return pack(data,lengths,batch_first=True),amr_len 214 | 215 | 216 | def forward(self, packed_input, packed_posterior,hidden=None): 217 | #input: pack(data x n_feature ,batch_size) 218 | #posterior: pack(data x src_len ,batch_size) 219 | assert isinstance(packed_input,PackedSequence) 220 | input = packed_input.data 221 | 222 | if self.alpha and self.training: 223 | input = data_dropout(input,self.alpha) 224 | word_fix_embed = self.word_fix_lut(input[:,TXT_WORD]) 225 | lemma_emb = self.lemma_lut(input[:,TXT_LEMMA]) 226 | pos_emb = self.pos_lut(input[:,TXT_POS]) 227 | ner_emb = self.ner_lut(input[:,TXT_NER]) 228 | 229 | emb = torch.cat([word_fix_embed,lemma_emb,pos_emb,ner_emb],1)# data,embed 230 | 231 | emb = PackedSequence(emb, packed_input.batch_sizes) 232 | poster_emb,amr_len = self.posteriorIndictedEmb(emb,packed_posterior) 233 | 234 | Outputs = self.rnn(poster_emb, hidden)[0] 235 | 236 | return DoublePackedSequence(Outputs,amr_len,Outputs.data) 237 | 238 | 239 | #combine amr node embedding and aligned sentence token embedding 240 | class RelEncoder(nn.Module): 241 | 242 | def __init__(self, opt, embs): 243 | super(RelEncoder, self).__init__() 244 | 245 | self.layers = opt.amr_enlayers 246 | 247 | self.size = opt.rel_dim 248 | inputSize = embs["cat_lut"].embedding_dim + embs["lemma_lut"].embedding_dim+opt.rel_rnn_size 249 | 250 | self.head = nn.Sequential( 251 | nn.Dropout(opt.dropout), 252 | nn.Linear(inputSize,self.size ) 253 | ) 254 | 255 | self.dep = nn.Sequential( 256 | nn.Dropout(opt.dropout), 257 | nn.Linear(inputSize,self.size ) 258 | ) 259 | 260 | self.cat_lut = embs["cat_lut"] 261 | 262 | self.lemma_lut = 
embs["lemma_lut"] 263 | self.alpha = opt.alpha 264 | 265 | if opt.cuda: 266 | self.cuda() 267 | 268 | def getEmb(self,indexes,src_enc): 269 | head_emb,dep_emb = [],[] 270 | src_enc,src_l = doubleunpack(src_enc) # batch x var(amr_l x src_l x dim) 271 | length_pairs = [] 272 | for i, index in enumerate(indexes): 273 | enc = src_enc[i] #amr_l src_l dim 274 | dep_emb.append(enc.index_select(1,index)) #var(amr_l x amr_l x dim) 275 | head_index = index.unsqueeze(1).unsqueeze(2).expand(enc.size(0),1,enc.size(-1)) 276 | # print ("getEmb",enc.size(),dep_index.size(),head_index.size()) 277 | head_emb.append(enc.gather(1,head_index).squeeze(1)) #var(amr_l x dim) 278 | length_pairs.append([len(index),len(index)]) 279 | return mypack(head_emb,[ls[0] for ls in length_pairs]),mydoublepack(dep_emb,length_pairs),length_pairs 280 | 281 | #input: all_data x n_feature, lengths 282 | #index: batch_size x var(amr_len) 283 | #src_enc (batch x amr_len) x src_len x txt_rnn_size 284 | 285 | #head: batch x var( amr_len x txt_rnn_size ) 286 | 287 | #dep : batch x var( amr_len x amr_len x txt_rnn_size ) 288 | 289 | #heads: [var(len),rel_dim] 290 | #deps: [var(len)**2,rel_dim] 291 | def forward(self, input, index,src_enc): 292 | assert isinstance(input, MyPackedSequence),input 293 | input,lengths = input 294 | if self.alpha and self.training: 295 | input = data_dropout(input,self.alpha) 296 | cat_embed = self.cat_lut(input[:,AMR_CAT]) 297 | lemma_embed = self.lemma_lut(input[:,AMR_LE]) 298 | 299 | amr_emb = torch.cat([cat_embed,lemma_embed],1) 300 | # print (input,lengths) 301 | 302 | head_emb_t,dep_emb_t,length_pairs = self.getEmb(index,src_enc) #packed, mydoublepacked 303 | 304 | 305 | head_emb = torch.cat([amr_emb,head_emb_t.data],1) 306 | 307 | dep_amr_emb_t = myunpack(*MyPackedSequence(amr_emb,lengths)) 308 | dep_amr_emb = [ emb.unsqueeze(0).expand(emb.size(0),emb.size(0),emb.size(-1)) for emb in dep_amr_emb_t] 309 | 310 | mydouble_amr_emb = mydoublepack(dep_amr_emb,length_pairs) 311 | 312 
| # print ("rel_encoder",mydouble_amr_emb.data.size(),dep_emb_t.data.size()) 313 | dep_emb = torch.cat([mydouble_amr_emb.data,dep_emb_t.data],-1) 314 | 315 | # emb_unpacked = myunpack(emb,lengths) 316 | 317 | head_packed = MyPackedSequence(self.head(head_emb),lengths) # total,rel_dim 318 | head_amr_packed = MyPackedSequence(amr_emb,lengths) # total,rel_dim 319 | 320 | # print ("dep_emb",dep_emb.size()) 321 | size = dep_emb.size() 322 | dep = self.dep(dep_emb.view(-1,size[-1])).view(size[0],size[1],-1) 323 | 324 | dep_packed = MyDoublePackedSequence(MyPackedSequence(dep,mydouble_amr_emb[0][1]),mydouble_amr_emb[1],dep) 325 | 326 | return head_amr_packed,head_packed,dep_packed #,MyPackedSequence(emb,lengths) 327 | 328 | 329 | class RelModel(nn.Module): 330 | def __init__(self, opt,embs): 331 | super(RelModel, self).__init__() 332 | self.root_encoder = RootEncoder(opt,embs) 333 | self.encoder = RelEncoder( opt, embs) 334 | self.generator = RelCalssifierBiLinear( opt, embs,embs["rel_lut"].num_embeddings) 335 | 336 | self.root = nn.Linear(opt.rel_dim,1) 337 | self.LogSoftmax = nn.LogSoftmax() 338 | 339 | 340 | def root_score(self,mypackedhead): 341 | heads = myunpack(*mypackedhead) 342 | output = [] 343 | for head in heads: 344 | score = self.root(head).squeeze(1) 345 | output.append(self.LogSoftmax(score)) 346 | return output 347 | 348 | def forward(self, srlBatch, index,src_enc,root_enc): 349 | mypacked_root_enc = self.root_encoder(srlBatch, index,root_enc) #with information from le cat enc 350 | roots = self.root_score(mypacked_root_enc) 351 | 352 | encoded= self.encoder(srlBatch, index,src_enc) 353 | score_packed = self.generator(*encoded) 354 | 355 | return score_packed,roots #,arg_logit_packed 356 | 357 | 358 | class RelCalssifierBiLinear(nn.Module): 359 | 360 | def __init__(self, opt, embs,n_rel): 361 | super(RelCalssifierBiLinear, self).__init__() 362 | self.n_rel = n_rel 363 | self.cat_lut = embs["cat_lut"] 364 | self.inputSize = opt.rel_dim 365 | 366 | 367 | 
self.bilinear = nn.Sequential(nn.Dropout(opt.dropout), 368 | nn.Linear(self.inputSize,self.inputSize* self.n_rel)) 369 | self.head_bias = nn.Sequential(nn.Dropout(opt.dropout), 370 | nn.Linear(self.inputSize,self.n_rel)) 371 | self.dep_bias = nn.Sequential(nn.Dropout(opt.dropout), 372 | nn.Linear(self.inputSize,self.n_rel)) 373 | self.bias = nn.Parameter(torch.normal(torch.zeros(self.n_rel)).cuda()) 374 | 375 | 376 | # self.lsm = nn.LogSoftmax() 377 | self.cat_lut = embs["cat_lut"] 378 | self.lemma_lut = embs["lemma_lut"] 379 | if opt.cuda: 380 | self.cuda() 381 | 382 | def bilinearForParallel(self,inputs,length_pairs): 383 | output = [] 384 | ls = [] 385 | for i,input in enumerate(inputs): 386 | 387 | #head_t : amr_l x ( rel_dim x n_rel) 388 | #dep_t : amr_l x amr_l x rel_dim 389 | #head_bias : amr_l x n_rel 390 | #dep_bias : amr_l x amr_l x n_rel 391 | head_t,dep_t,head_bias,dep_bias = input 392 | l = len(head_t) 393 | ls.append(l) 394 | head_t = head_t.view(l,-1,self.n_rel) 395 | score =dep_t[:,:length_pairs[i][1]].bmm( head_t.view(l,-1,self.n_rel)).view(l,l,self.n_rel).transpose(0,1) 396 | 397 | dep_bias = dep_bias[:,:length_pairs[i][1]] 398 | score = score + dep_bias 399 | 400 | score = score + head_bias.unsqueeze(1).expand_as(score) 401 | score = score+self.bias.unsqueeze(0).unsqueeze(1).expand_as(score) 402 | score = F.log_softmax(score.view(ls[-1]*ls[-1],self.n_rel)) # - score.exp().sum(2,keepdim=True).log().expand_as(score) 403 | 404 | output.append(score.view(ls[-1]*ls[-1],self.n_rel)) 405 | return output,[l**2 for l in ls] 406 | 407 | 408 | def forward(self, _,heads,deps): 409 | '''heads.data: mypacked amr_l x rel_dim 410 | deps.data: mydoublepacked amr_l x amr_l x rel_dim 411 | ''' 412 | heads_data = heads.data 413 | deps_data = deps.data 414 | 415 | head_bilinear_transformed = self.bilinear (heads_data) #all_data x ( n_rel x inputsize) 416 | 417 | head_bias_unpacked = myunpack(self.head_bias(heads_data),heads.lengths) #[len x n_rel] 418 | 419 | size = 
deps_data.size() 420 | dep_bias = self.dep_bias(deps_data.view(-1,size[-1])).view(size[0],size[1],-1) 421 | 422 | dep_bias_unpacked,length_pairs = mydoubleunpack(MyDoublePackedSequence(MyPackedSequence( dep_bias,deps[0][1]),deps[1],dep_bias) ) #[len x n_rel] 423 | 424 | bilinear_unpacked = myunpack(head_bilinear_transformed,heads.lengths) 425 | 426 | deps_unpacked,length_pairs = mydoubleunpack(deps) 427 | output,l = self.bilinearForParallel( zip(bilinear_unpacked,deps_unpacked,head_bias_unpacked,dep_bias_unpacked),length_pairs) 428 | myscore_packed = mypack(output,l) 429 | 430 | # prob_packed = MyPackedSequence(myscore_packed.data,l) 431 | return myscore_packed -------------------------------------------------------------------------------- /parser/models/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3.6 2 | # coding=utf-8 3 | ''' 4 | 5 | Deep Learning Models for variational inference of alignment. 6 | Posterior , LikeliHood helps computing posterior weighted likelihood regarding relaxation. 7 | 8 | Also the whole AMR model is combined here. 
9 | 10 | @author: Chunchuan Lyu (chunchuan.lv@gmail.com) 11 | @since: 2018-05-30 12 | ''' 13 | 14 | import numpy as np 15 | from parser.models.ConceptModel import * 16 | from parser.models.MultiPassRelModel import * 17 | 18 | from parser.modules.GumbelSoftMax import renormalize,sink_horn,gumbel_noise_sample 19 | from parser.modules.helper_module import doublepack 20 | 21 | from copy import deepcopy 22 | 23 | #Encoding linearized AMR concepts for vartiaonal alignment model 24 | class AmrEncoder(nn.Module): 25 | 26 | def __init__(self, opt, embs): 27 | self.layers = opt.amr_enlayers 28 | self.num_directions = 2 if opt.brnn else 1 29 | assert opt.amr_rnn_size % self.num_directions == 0 30 | self.hidden_size = opt.amr_rnn_size // self.num_directions 31 | inputSize = embs["cat_lut"].embedding_dim + embs["lemma_lut"].embedding_dim 32 | super(AmrEncoder, self).__init__() 33 | 34 | self.rnn = nn.LSTM(inputSize, self.hidden_size, 35 | num_layers=opt.amr_enlayers, 36 | dropout=opt.dropout, 37 | bidirectional=opt.brnn) 38 | self.cat_lut = embs["cat_lut"] 39 | 40 | self.lemma_lut = embs["lemma_lut"] 41 | 42 | 43 | 44 | self.alpha = opt.alpha #unk with alpha 45 | if opt.cuda: 46 | self.cuda() 47 | 48 | #input:len, batch, n_feature 49 | #output: len, batch, hidden_size * num_directions 50 | def forward(self, packed_input, hidden=None): 51 | assert isinstance(packed_input,PackedSequence) 52 | input = packed_input.data 53 | 54 | if self.alpha and self.training: 55 | input = data_dropout(input,self.alpha) 56 | 57 | cat_embed = self.cat_lut(input[:,AMR_CAT]) 58 | lemma_embed = self.lemma_lut(input[:,AMR_LE]) 59 | 60 | emb = torch.cat([cat_embed,lemma_embed],1) # len,batch,embed 61 | emb = PackedSequence(emb, packed_input.batch_sizes) 62 | outputs, hidden_t = self.rnn(emb, hidden) 63 | return outputs, hidden_t 64 | 65 | #Model to compute relaxed posteior 66 | # we constraint alignment if copying mechanism can be used 67 | class Posterior(nn.Module): 68 | def __init__(self,opt): 69 | 
super(Posterior, self).__init__() 70 | self.txt_rnn_size = opt.txt_rnn_size 71 | self.amr_rnn_size = opt.amr_rnn_size 72 | self.jamr = opt.jamr 73 | if self.jamr : #if use fixed alignment, then no need for variational model 74 | return 75 | self.transform = nn.Sequential( 76 | nn.Dropout(opt.dropout), 77 | nn.Linear(self.txt_rnn_size,self.amr_rnn_size,bias = opt.lemma_bias)) 78 | self.sm = nn.Softmax() 79 | self.sink = opt.sink 80 | self.sink_t = opt.sink_t 81 | if opt.cuda: 82 | self.cuda() 83 | 84 | def forward(self,src_enc,amr_enc,aligns): 85 | 86 | '''src_enc: src_len x batch x txt_rnn_size, src_l 87 | amr_enc: amr_len x batch x amr_rnn_size, amr_l 88 | aligns: amr_len x batch x src_len , amr_l 89 | 90 | 91 | posterior: amr_len x batch x src_len , amr_l 92 | ''' 93 | if self.jamr : 94 | return aligns,aligns,0 95 | src_enc,amr_enc,aligns =unpack(src_enc),unpack(amr_enc),unpack(aligns) 96 | 97 | src_enc = src_enc[0] 98 | amr_enc = amr_enc[0] 99 | lengths = aligns[1] 100 | aligns = aligns[0] 101 | assert not np.isnan(np.sum(src_enc.data.cpu().numpy())),("src_enc \n",src_enc) 102 | assert not np.isnan(np.sum(amr_enc.data.cpu().numpy())),("amr_enc \n",amr_enc) 103 | src_len , batch , src_rnn_size = src_enc.size() 104 | src_transformed = self.transform(src_enc.view(-1,src_rnn_size)).view(src_len,batch,-1).transpose(0,1) #batch x src_len x amr_rnn_size 105 | amr_enc = amr_enc.transpose(0,1).transpose(1,2) #batch x amr_rnn_size x amr_len 106 | score = src_transformed.bmm(amr_enc).transpose(1,2).transpose(0,1) #/ self.amr_rnn_size #amr_len x batch x src_len 107 | assert not np.isnan(np.sum(score.data.cpu().numpy())),("score \n",score) 108 | final_score = gumbel_noise_sample(score) if self.training else score 109 | assert not np.isnan(np.sum(final_score.data.cpu().numpy())),("final_score \n",final_score) 110 | if self.sink: 111 | posterior = sink_horn((final_score- (1-aligns)*1e6 ,lengths),k=self.sink,t=self.sink_t ) 112 | else: 113 | final_score = final_score- 
(1-aligns)*1e6 114 | dim = final_score.size() 115 | final_score = final_score.view(-1, final_score.size(-1)) 116 | posterior =self.sm(final_score).view(dim) 117 | return pack(posterior, lengths),pack(score,lengths) #amr_len x batch x src_len 118 | 119 | #directly compute likelihood of concept being generated at words (a matrix for each training example) 120 | def LikeliHood(tgtBatch,probBatch): 121 | '''tgtBatch: data x [n_feature + 1 (AMR_CAN_COPY)], batch_sizes 122 | probaBatch: (data x n_out, lengths ) * 123 | aligns: amr_len x batch x src_len , amr_l 124 | 125 | likelihood: data (amr) x src_len , batch_sizes 126 | ''' 127 | 128 | batch_sizes = tgtBatch.batch_sizes 129 | likelihoods = [] 130 | for i,prob in enumerate(probBatch): 131 | assert isinstance(prob, PackedSequence),"only support packed" 132 | if i == AMR_LE: 133 | prob_batch,lengths = unpack(prob) 134 | prob_batch = prob_batch.transpose(0,1) # batch x src_len x n_out 135 | n_out = prob_batch.size(-1) 136 | src_len = prob_batch.size(1) 137 | packed_index_data = tgtBatch.data[:,i].clamp(max=n_out-1) #so lemma not in high maps to last index ,data x 1 138 | 139 | copy_data = (packed_index_data re-categorized_id 281 | # posterior: re-categorized_id -> alignment_soft_posterior 282 | rel_batch,rel_index,srcBatch,posterior = input 283 | assert not np.isnan(np.sum(posterior.data.data.cpu().numpy())),("posterior.data \n",posterior.data) 284 | posterior_data = renormalize(posterior.data+epsilon) 285 | assert not np.isnan(np.sum(posterior_data.data.cpu().numpy())),("posterior_data \n",posterior_data) 286 | posterior = PackedSequence(posterior_data,posterior.batch_sizes) 287 | indexed_posterior = self.index_posterior(posterior,rel_index) 288 | 289 | src_enc = self.rel_encoder(srcBatch,indexed_posterior) 290 | root_enc = self.root_encoder(srcBatch) 291 | 292 | weighted_root_enc = self.root_posterior_enc(posterior,root_enc) 293 | weighted_enc= self.weight_posterior_enc(posterior,src_enc) #src_enc 
MyDoublePackedSequence, amr_len 294 | 295 | # self_rel_index = [ Variable(index.data.new(list(range(len(index))))) for index in rel_index] 296 | rel_prob = self.relModel(rel_batch,rel_index,weighted_enc,weighted_root_enc) 297 | # assert not np.isnan(np.sum(rel_prob[0].data.data.cpu().numpy())),("inside srl\n",rel_prob[0].data.data) 298 | return rel_prob 299 | if len(input)==3 and rel: 300 | # relation identification evaluation 301 | rel_batch,srcBatch,alginBatch = input # 302 | src_enc = self.rel_encoder(srcBatch,alginBatch) 303 | root_enc = self.root_encoder(srcBatch) 304 | root_data,lengths = unpack(root_enc) 305 | mypacked_root_enc = mypack(root_data.transpose(0,1).contiguous(),lengths) 306 | rel_prob = self.relModel(rel_batch,alginBatch,src_enc,mypacked_root_enc) 307 | return rel_prob 308 | else: 309 | # concept identification evaluation 310 | srcBatch = input 311 | probBatch,src_enc= self.concept_decoder(srcBatch) 312 | return probBatch 313 | 314 | 315 | #encoding relaxation for root identification 316 | def root_posterior_enc(self,posterior,src_enc): 317 | '''src_enc: # batch x var( src_l x dim) 318 | posterior = pre_amr_len x batch x src_len , amr_l 319 | 320 | out: batch x amr_len x txt_rnn_size 321 | ''' 322 | posterior,lengths = unpack(posterior) 323 | enc,length_src = unpack(src_enc) 324 | # print ("length_pairs",length_pairs) 325 | # print ("lengths",lengths) 326 | weighted = [] 327 | for i, src_l in enumerate(length_src): #src_len x dim 328 | p = posterior[:,i,:src_l] #pre_full_amr_len x src_len 329 | enc_t = enc[:src_l,i,:] 330 | weighted_enc = p.mm(enc_t) #pre_amr_len x dim 331 | weighted.append(weighted_enc) #pre_amr_len x dim 332 | # print ("length_pairs",length_pairs) 333 | return mypack(weighted,lengths) 334 | 335 | #encoding relaxation for relation identification 336 | def weight_posterior_enc(self,posterior,src_enc): 337 | '''src_enc: # batch x var(pre_amr_len x src_l x dim) 338 | posterior = pre_amr_len x batch x src_len , amr_l 339 | 340 | 
out: batch x amr_len x txt_rnn_size 341 | ''' 342 | posterior,lengths = unpack(posterior) 343 | def handle_enc(enc): 344 | enc,length_pairs = doubleunpack(enc) 345 | # print ("length_pairs",length_pairs) 346 | # print ("lengths",lengths) 347 | dim = enc[0].size(-1) 348 | weighted = [] 349 | new_length_pairs = [] 350 | for i, src_enc_t in enumerate(enc): 351 | p = posterior[:lengths[i],i,:] #pre_amr_len x src_len 352 | enc_trans = src_enc_t.transpose(0,1).contiguous().view(p.size(-1),-1) #src_len x (pre_amr_len x dim) 353 | weighted_enc = p.mm(enc_trans) #pre_amr_len x (pre_amr_len x dim) 354 | weighted.append(weighted_enc.view(lengths[i],length_pairs[i][0],dim).transpose(0,1).contiguous()) #pre_amr_len x pre_amr_len x dim 355 | new_length_pairs.append([length_pairs[i][0],lengths[i]]) 356 | # print ("length_pairs",length_pairs) 357 | return doublepack(weighted,length_pairs) 358 | 359 | return handle_enc(src_enc) 360 | 361 | -------------------------------------------------------------------------------- /parser/models/__pycache__/ConceptModel.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChunchuanLv/AMR_AS_GRAPH_PREDICTION/3375123c6b00bdfbe3395706769175073716b699/parser/models/__pycache__/ConceptModel.cpython-36.pyc -------------------------------------------------------------------------------- /parser/models/__pycache__/MultiPassRelModel.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChunchuanLv/AMR_AS_GRAPH_PREDICTION/3375123c6b00bdfbe3395706769175073716b699/parser/models/__pycache__/MultiPassRelModel.cpython-36.pyc -------------------------------------------------------------------------------- /parser/models/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/ChunchuanLv/AMR_AS_GRAPH_PREDICTION/3375123c6b00bdfbe3395706769175073716b699/parser/models/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /parser/modules/GumbelSoftMax.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3.6 2 | # coding=utf-8 3 | ''' 4 | 5 | Helper functions regarding gumbel noise 6 | 7 | @author: Chunchuan Lyu (chunchuan.lv@gmail.com) 8 | @since: 2018-05-30 9 | ''' 10 | 11 | import torch 12 | from torch.autograd import Variable 13 | import torch.nn.functional as F 14 | from torch.nn.utils.rnn import pad_packed_sequence as unpack 15 | from torch.nn.utils.rnn import pack_padded_sequence as pack 16 | from torch.nn.utils.rnn import PackedSequence 17 | 18 | eps = 1e-8 19 | def sample_gumbel(input): 20 | noise = torch.rand(input.size()).type_as(input.data) 21 | noise.add_(eps).log_().neg_() 22 | noise.add_(eps).log_().neg_() 23 | return Variable(noise,requires_grad=False) 24 | 25 | 26 | def gumbel_noise_sample(input,temperature = 1): 27 | noise = sample_gumbel(input) 28 | x = (input + noise) / temperature 29 | return x.view_as(input) 30 | 31 | 32 | import numpy as np 33 | 34 | def sink_horn(input,k = 5,t = 1,batch_first = False): 35 | def sink_horn_data(x,lengths): 36 | assert not np.isnan(np.sum(x.data.cpu().numpy())),("start x\n",x.data) 37 | over_flow = x-80*t 38 | x = x.clamp(max=80*t)+F.tanh(over_flow)*(over_flow>0).float() 39 | x = torch.exp(x/t) 40 | assert not np.isnan(np.sum(x.data.cpu().numpy())),("exp x\n",x.data) 41 | musks = torch.zeros(x.size()) 42 | for i,l in enumerate(lengths): 43 | musks[:l,i,:l] = 1 44 | musks = Variable(musks,requires_grad=False).type_as(x) 45 | x = x*musks+eps 46 | for i in range(0,k): 47 | x = x / x.sum(0,keepdim=True).expand_as(x) 48 | x = x*musks+eps 49 | x = x / x.sum(2,keepdim=True).expand_as(x) 50 | x = x*musks+eps 51 | 52 | assert not 
np.isnan(np.sum(x.data.cpu().numpy())),("end x\n",x.data) 53 | return x 54 | if isinstance(input,PackedSequence): 55 | data,l = unpack(input,batch_first=batch_first) 56 | data = sink_horn_data(data,l) 57 | return pack(data,l,batch_first) 58 | else: 59 | return sink_horn_data(*input) 60 | 61 | 62 | def renormalize(input,t=1): 63 | 64 | x = ((input+eps).log() ) / t 65 | x = F.softmax(x) 66 | return x.view_as(input) 67 | 68 | -------------------------------------------------------------------------------- /parser/modules/__initial__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChunchuanLv/AMR_AS_GRAPH_PREDICTION/3375123c6b00bdfbe3395706769175073716b699/parser/modules/__initial__.py -------------------------------------------------------------------------------- /parser/modules/__pycache__/GumbelSoftMax.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChunchuanLv/AMR_AS_GRAPH_PREDICTION/3375123c6b00bdfbe3395706769175073716b699/parser/modules/__pycache__/GumbelSoftMax.cpython-36.pyc -------------------------------------------------------------------------------- /parser/modules/__pycache__/helper_module.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChunchuanLv/AMR_AS_GRAPH_PREDICTION/3375123c6b00bdfbe3395706769175073716b699/parser/modules/__pycache__/helper_module.cpython-36.pyc -------------------------------------------------------------------------------- /parser/modules/helper_module.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3.6 2 | # coding=utf-8 3 | ''' 4 | 5 | Some data structure to save memory for packing variable lengthed data into batch, 6 | Not actually sure whether it's better (time or space) than zero padding, 7 | 8 | @author: Chunchuan Lyu 
(chunchuan.lv@gmail.com) 9 | @since: 2018-05-30 10 | ''' 11 | import torch 12 | from torch.autograd import Variable 13 | from torch.nn.utils.rnn import pad_packed_sequence as unpack 14 | from torch.nn.utils.rnn import pack_padded_sequence as pack 15 | from collections import namedtuple 16 | MyPackedSequence = namedtuple('MyPackedSequence', ['data', 'lengths']) 17 | MyDoublePackedSequence = namedtuple('MyDoublePackedSequence', ['PackedSequence', 'length_pairs','data']) #packed sequence must be batch_first, inner length 18 | DoublePackedSequence = namedtuple('DoublePackedSequence', ['PackedSequence', 'outer_lengths','data']) #packed sequence must be batch_first, inner length 19 | 20 | def sort_index(seq): 21 | sorted([(v, i) for (i, v) in enumerate(seq)],reverse = True) 22 | 23 | def mypack(data,lengths): 24 | if isinstance(data,list): 25 | return MyPackedSequence(torch.cat(data,0),lengths) 26 | else: 27 | data_list = [] 28 | for i, l in enumerate(lengths): 29 | data_list.append(data[i][:l]) 30 | return mypack(data_list,lengths) 31 | 32 | 33 | def myunpack(*mypacked): 34 | data,lengths = mypacked 35 | data_list = [] 36 | current = 0 37 | for i, l in enumerate(lengths): 38 | data_list.append(data[current:l+current]) 39 | current += l 40 | return data_list 41 | 42 | def mydoubleunpack(mydoublepacked): 43 | packeddata,length_pairs,data = mydoublepacked 44 | data = myunpack(*packeddata) 45 | data_list = [] 46 | for i, ls in enumerate(length_pairs): 47 | out_l,in_l = ls 48 | data_list.append(data[i][:,:in_l]) #outl x max_l x dim 49 | return data_list,length_pairs 50 | 51 | 52 | def mydoublepack(data_list,length_pairs): #batch x var(amr_l x src_l x dim) 53 | data = [] 54 | max_in_l = max([ls[1] for ls in length_pairs]) 55 | outer_l = [] 56 | for d, ls in list(zip(data_list,length_pairs)): 57 | outl,inl = ls 58 | size = [i for i in d.size()] 59 | if size[1] == max_in_l: 60 | tdata = d 61 | else: 62 | size[1] = max_in_l 63 | tdata = Variable(d.data.new(*size).fill_(0)) 64 | 
# print (tdata) 65 | tdata[:,:inl] = d 66 | data.append( tdata) #amr_l x src_l x dim 67 | outer_l.append(outl) 68 | 69 | packed = mypack(data,outer_l) 70 | 71 | return MyDoublePackedSequence(packed,length_pairs,packed.data) 72 | 73 | def doubleunpack(doublepacked): 74 | assert isinstance(doublepacked,DoublePackedSequence) 75 | packeddata,outer_lengths,data = doublepacked 76 | data,in_l = unpack(packeddata,batch_first=True) 77 | data_list = [] 78 | length_pairs = [] 79 | current = 0 80 | for i, l in enumerate(outer_lengths): 81 | data_list.append(data[current:l+current]) #outl x max_l x dim 82 | length_pairs.append((l,in_l[current])) 83 | current += l 84 | return data_list,length_pairs 85 | 86 | 87 | def doublepack(data_list,length_pairs): #batch x var(amr_l x src_l x dim) 88 | data = [] 89 | lengths = [] 90 | max_in_l = max([ls[1] for ls in length_pairs]) 91 | outer_l = [] 92 | for d, ls in list(zip(data_list,length_pairs)): 93 | outl,inl = ls 94 | size = [i for i in d.size()] 95 | if size[1] == max_in_l: 96 | tdata = d 97 | else: 98 | size[1] = max_in_l 99 | tdata = Variable(d.data.new(*size).fill_(0)) 100 | # print (tdata) 101 | tdata[:,:inl] = d 102 | data.append( tdata) #amr_l x src_l x dim 103 | lengths = lengths + [inl]*outl 104 | outer_l.append(outl) 105 | 106 | packed = pack(torch.cat(data,0),lengths,batch_first=True) 107 | 108 | return DoublePackedSequence(packed,outer_l,packed.data) 109 | 110 | 111 | 112 | def data_dropout(data:Variable,frequency,UNK = 1)->Variable: 113 | if frequency == 0: return data 114 | if isinstance(frequency,Variable): 115 | f = frequency 116 | unk_mask = Variable(torch.bernoulli(f.data),requires_grad = False).cuda() 117 | data = data*(1-unk_mask).long()+(unk_mask*Variable(torch.ones(data.size()).cuda()*UNK,requires_grad = False)).long() 118 | else: 119 | f = torch.ones(data.size()).cuda()*frequency 120 | unk_mask = Variable(torch.bernoulli(f),requires_grad = False) 121 | data = 
data*(1-unk_mask).long()+(unk_mask*Variable(torch.ones(data.size()).cuda()*UNK,requires_grad = False)).long() 122 | return data 123 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import torch 4 | import torch.nn as nn 5 | def freeze(m,t=0): 6 | if isinstance(m,nn.Dropout): 7 | m.p = t 8 | m.dropout =t 9 | 10 | 11 | from copy import deepcopy 12 | def load_old_model(dicts,opt,generate=False): 13 | model_from = opt.train_from 14 | print('Loading from checkpoint at %s' % model_from) 15 | if opt.gpus[0] != -1: 16 | print ('from model in gpus:'+str(opt.from_gpus[0]),' to gpu:'+str(opt.gpus[0])) 17 | checkpoint = torch.load(model_from, map_location={'cuda:'+str(opt.from_gpus[0]): 'cuda:'+str(opt.gpus[0])}) 18 | else: 19 | print ('from model in gpus:'+str(opt.from_gpus[0]),'to cpu ') 20 | checkpoint = torch.load(model_from, map_location={'cuda:'+str(opt.from_gpus[0]): 'cpu'}) 21 | print("Model loaded") 22 | optt = checkpoint["opt"] 23 | rel = optt.rel 24 | AmrModel = checkpoint['model'] 25 | if optt.rel == 1: 26 | if not opt.train_all: 27 | AmrModel.concept_decoder = deepcopy(AmrModel.concept_decoder) 28 | for name, param in AmrModel.concept_decoder.named_parameters(): 29 | param.requires_grad = False 30 | AmrModel.concept_decoder.apply(freeze) 31 | 32 | parameters_to_train = [] 33 | for name, param in AmrModel.named_parameters(): 34 | if name == "word_fix_lut" or param.size(0) == len(dicts["word_dict"]): 35 | param.requires_grad = False 36 | if param.requires_grad: 37 | parameters_to_train.append(param) 38 | print (AmrModel) 39 | print ("training parameters: "+str(len(parameters_to_train))) 40 | return AmrModel,parameters_to_train,optt 41 | 42 | optt.rel = opt.rel 43 | if opt.rel and not rel : 44 | if opt.jamr == 0: 45 | AmrModel.poserior_m.align_weight = 1 46 | AmrModel.concept_decoder.apply(freeze) 47 | opt.independent = True 
48 | AmrModel.start_rel(opt) 49 | embs = AmrModel.embs 50 | embs["lemma_lut"].requires_grad = False ##need load 51 | embs["pos_lut"].requires_grad = False 52 | embs["ner_lut"].requires_grad = False 53 | embs["word_fix_lut"].requires_grad = False 54 | embs["rel_lut"] = nn.Embedding(dicts["rel_dict"].size(), 55 | opt.rel_dim) 56 | for param in AmrModel.concept_decoder.parameters(): 57 | param.requires_grad = False 58 | if not generate and opt.jamr == 0: 59 | AmrModel.poserior_m.posterior.ST = opt.ST 60 | AmrModel.poserior_m.posterior.sink = opt.sink 61 | AmrModel.poserior_m.posterior.sink_t = opt.sink_t 62 | 63 | if opt.cuda: 64 | AmrModel.cuda() 65 | else: 66 | AmrModel.cpu() 67 | 68 | if not generate and opt.jamr == 0: 69 | if opt.train_posterior: 70 | for param in AmrModel.poserior_m.parameters(): 71 | param.requires_grad = True 72 | AmrModel.poserior_m.apply(lambda x: freeze(x,opt.dropout)) 73 | else: 74 | opt.prior_t = 0 75 | opt.sink_re = 0 76 | for param in AmrModel.poserior_m.parameters(): 77 | param.requires_grad = False 78 | parameters_to_train = [] 79 | if opt.train_all: 80 | for name, param in AmrModel.named_parameters(): 81 | if name != "word_fix_lut": 82 | param.requires_grad = True 83 | parameters_to_train.append(param) 84 | else: 85 | print ("not updating "+name) 86 | 87 | else: 88 | if opt.rel: 89 | for param in AmrModel.concept_decoder.parameters(): 90 | if param.requires_grad: 91 | param.requires_grad = False 92 | print("turing off concept model: ",param) 93 | for name,p in AmrModel.named_parameters(): 94 | if name == "word_fix_lut" or p.size(0) == len(dicts["word_dict"]): 95 | p.requires_grad = False 96 | if p.requires_grad: 97 | parameters_to_train.append(p) 98 | else: 99 | print ([p.size() for p in AmrModel.concept_decoder.parameters()]) 100 | AmrModel.apply(freeze) 101 | for p in AmrModel.concept_decoder.parameters(): 102 | p.requires_grad = True 103 | parameters_to_train.append(p) 104 | print (AmrModel) 105 | print ("training parameters: 
"+str(len(parameters_to_train))) 106 | return AmrModel,parameters_to_train,optt 107 | -------------------------------------------------------------------------------- /src/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChunchuanLv/AMR_AS_GRAPH_PREDICTION/3375123c6b00bdfbe3395706769175073716b699/src/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /src/__pycache__/train.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChunchuanLv/AMR_AS_GRAPH_PREDICTION/3375123c6b00bdfbe3395706769175073716b699/src/__pycache__/train.cpython-36.pyc -------------------------------------------------------------------------------- /src/data_build.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3.6 2 | # coding=utf-8 3 | ''' 4 | 5 | Scripts build dictionary and data into numbers 6 | 7 | Data path information should also be specified here for 8 | trainFolderPath, devFolderPath and testFolderPath 9 | as we allow option to choose from two version of data. 
10 | 11 | @author: Chunchuan Lyu (chunchuan.lv@gmail.com) 12 | @since: 2018-05-30 13 | ''' 14 | 15 | from utility.StringCopyRules import * 16 | from utility.ReCategorization import * 17 | from parser.Dict import * 18 | 19 | import argparse 20 | 21 | 22 | def data_build_parser(): 23 | parser = argparse.ArgumentParser(description='data_build.py') 24 | 25 | ## Data options 26 | parser.add_argument('-threshold', default=10, type=int, 27 | help="""threshold for high frequency concepts""") 28 | 29 | parser.add_argument('-jamr', default=0, type=int, 30 | help="""wheather to add .jamr at the end""") 31 | parser.add_argument('-skip', default=0, type=int, 32 | help="""skip dict build if dictionary already built""") 33 | parser.add_argument('-suffix', default=".txt_pre_processed", type=str, 34 | help="""suffix of files to combine""") 35 | parser.add_argument('-folder', default=allFolderPath, type=str, 36 | help="""the folder""") 37 | return parser 38 | 39 | 40 | parser = data_build_parser() 41 | 42 | opt = parser.parse_args() 43 | 44 | suffix = opt.suffix + "_jamr" if opt.jamr else opt.suffix 45 | with_jamr = "_with_jamr" if opt.jamr else "_without_jamr" 46 | trainFolderPath = opt.folder + "/training/" 47 | trainingFilesPath = folder_to_files_path(trainFolderPath, suffix) 48 | 49 | devFolderPath = opt.folder + "/dev/" 50 | devFilesPath = folder_to_files_path(devFolderPath, suffix) 51 | 52 | testFolderPath = opt.folder + "/test/" 53 | testFilesPath = folder_to_files_path(testFolderPath, suffix) 54 | 55 | 56 | def myamr_to_seq(amr, snt_token, lemma_token, pos, rl, fragment_to_node_converter, 57 | high_freq): # high_freq should be a dict() 58 | 59 | def uni_to_list(uni, can_copy=0): 60 | # if can_copy: print (uni) 61 | le = uni.le 62 | cat = uni.cat # use right category anyway 63 | ner = uni.aux 64 | data = [0, 0, 0, 0, 0] 65 | data[AMR_AUX] = ner 66 | data[AMR_LE_SENSE] = uni.sense 67 | data[AMR_LE] = le 68 | data[AMR_CAT] = cat 69 | data[AMR_CAN_COPY] = 1 if can_copy else 0 70 
| return data 71 | 72 | output_concepts = [] 73 | lemma_str = " ".join(lemma_token) 74 | fragment_to_node_converter.convert(amr, rl, snt_token, lemma_token, lemma_str) 75 | concepts, rel, rel_prefix, root_id = amr.node_value(keys=["value", "align"], all=True) 76 | 77 | results = rl.get_matched_concepts(snt_token, concepts, lemma_token, pos, jamr=opt.jamr) 78 | aligned_index = [] 79 | n_amr = len(results) 80 | n_snt = len(snt_token) 81 | l = len(lemma_token) if lemma_token[-1] != "." else len(lemma_token) - 1 82 | 83 | # hello, linguistic prior here 84 | old_unaligned_index = [i for i in range(l) if not ( 85 | pos[i] in ["IN", "POS"] or lemma_token[i] == "would" or lemma_token[i] == "will" and pos[i] == "MD" 86 | or lemma_token[i] == "have" and pos[i] not in ["VB", "VBG"]) 87 | or lemma_token[i] in ["although", "while", "of", "if", "in", "per", "like", "by", "for"]] 88 | 89 | for i, n_c_a in enumerate(results): 90 | uni = n_c_a[1] 91 | align = [a[0] for a in n_c_a[2]] if len(n_c_a[2]) > 0 else old_unaligned_index 92 | aligned_index += align 93 | 94 | data = uni_to_list(uni, len(n_c_a[2]) > 0) 95 | data.append(align) 96 | output_concepts.append(data) 97 | if len(aligned_index) == 0: 98 | output_concepts[0][-1] = [int((len(lemma_token) - 1) / 2)] 99 | aligned_index = [int((len(lemma_token) - 1) / 2)] 100 | assert len(aligned_index) > 0, (results, amr._anno, " ".join(lemma_token)) 101 | unaligned_index = [i for i in range(n_snt) if i not in aligned_index] # or [-1 n_snt] for all 102 | if len(unaligned_index) == 0: unaligned_index = [-1, n_snt] 103 | # assert n_snt <= n_amr or unaligned_index != [],(n_amr,n_snt,concepts,snt_token,amr 104 | for i in range(n_amr, n_snt): 105 | output_concepts.append([NULL_WORD, NULL_WORD, NULL_WORD, NULL_WORD, 0, [-1, n_snt]]) # len(amr) >= len(snt) 106 | printed = False 107 | for i in range(len(output_concepts)): 108 | if output_concepts[i][-1] == []: 109 | if not printed: 110 | print(output_concepts[i]) 111 | print (list(zip(snt_token, 
lemma_token, pos))) 112 | print(concepts, amr) 113 | printed = True 114 | output_concepts[i][-1] = [-1, n_snt] 115 | 116 | rel_feature = [] 117 | rel_tgt = [] 118 | for i, (amr_index, role_list) in enumerate(rel): 119 | amr_concept = uni_to_list(amr_index[ 120 | 0]) # if align else uni_to_list(AMRUniversal(UNK_WORD,output_concepts[amr_index[1]][AMR_CAT],NULL_WORD)) 121 | rel_feature.append(amr_concept[:4] + [amr_index[1]]+[rel_prefix[i]]) 122 | # assert amr_index[1] < len(results), (concepts, rel) 123 | rel_tgt.append(role_list) # [role,rel_index] 124 | return output_concepts, [rel_feature, rel_tgt, root_id], unaligned_index # [[[lemma1,lemma2],category,relation]] 125 | 126 | 127 | def filter_non_aligned(input_concepts, rel, unaligned_index): 128 | rel_feature, rel_tgt, root_id = rel 129 | 130 | filtered_index = {} # original -> filtered 131 | 132 | output_concepts = [] 133 | for i, data in enumerate(input_concepts): 134 | if len(data[-1]) == 0: 135 | output_concepts.append( 136 | [NULL_WORD, NULL_WORD, NULL_WORD, NULL_WORD, 0, unaligned_index]) # len(amr) >= len(snt) 137 | elif len(data[-1]) == 1 or data[AMR_CAT] == NULL_WORD: 138 | output_concepts.append(data) 139 | filtered_index[i] = len(output_concepts) - 1 140 | else: 141 | assert False, (i, data, input_concepts, rel) 142 | out_rel_feature, out_rel_tgt = [], [] 143 | filtered_rel_index = {} # original -> filtered for dependency indexing 144 | for i, data in enumerate(rel_feature): 145 | index = data[-1] 146 | if index in filtered_index: 147 | new_index = filtered_index[index] 148 | out_rel_feature.append(data[:-1] + [new_index]) 149 | filtered_rel_index[i] = len(out_rel_feature) - 1 150 | 151 | for i, roles in enumerate(rel_tgt): 152 | if i in filtered_rel_index: 153 | new_roles = [[role, filtered_rel_index[j]] for role, j in roles if j in filtered_rel_index] 154 | out_rel_tgt.append(new_roles) 155 | 156 | if root_id not in filtered_rel_index: 157 | root_id = 0 158 | 159 | assert len(output_concepts) > 0, 
(input_concepts, rel, unaligned_index) 160 | 161 | return output_concepts, [out_rel_feature, out_rel_tgt, root_id] 162 | 163 | 164 | def add_seq_to_dict(dictionary, seq): 165 | for i in seq: 166 | dictionary.add(i) 167 | 168 | 169 | def aligned(align_list): 170 | return align_list[0] == -1 171 | 172 | 173 | # id_seq : [(lemma,cat,lemma_sensed,ner])] 174 | def amr_seq_to_id(lemma_dict, category_dict, lemma_sensed_dict, aux_dict, amr_seq): 175 | id_seq = [] 176 | for l in amr_seq: 177 | data = [0] * 5 178 | data[AMR_CAT] = category_dict[l[AMR_CAT]] 179 | data[AMR_LE] = lemma_dict[l[AMR_LE]] 180 | data[AMR_AUX] = aux_dict[l[AMR_AUX]] 181 | data[AMR_SENSE] = sensed_dict[l[AMR_SENSE]] 182 | data[AMR_CAN_COPY] = l[AMR_CAN_COPY] 183 | id_seq.append(data) 184 | return id_seq 185 | 186 | 187 | def amr_seq_to_dict(lemma_dict, category_dict, sensed_dict, aux_dict, amr_seq): # le,cat,le_sense,ner,align 188 | for i in amr_seq: 189 | category_dict.add(i[AMR_CAT]) 190 | lemma_dict.add(i[AMR_LE]) 191 | aux_dict.add(i[AMR_NER]) 192 | sensed_dict.add(i[AMR_SENSE]) 193 | 194 | 195 | def rel_seq_to_dict(lemma_dict, category_dict, sensed_dict, rel_dict, rel): # (amr,index,[[role,amr,index]]) 196 | rel_feature, rel_tgt, root_id = rel 197 | for i in rel_feature: 198 | category_dict.add(i[AMR_CAT]) 199 | lemma_dict.add(i[AMR_LE]) 200 | # sensed_dict.add(i[AMR_SENSE]) 201 | for role_list in rel_tgt: 202 | for role_index in role_list: 203 | # assert (role_index[0]==":top"),rel_tgt 204 | rel_dict.add(role_index[0]) 205 | 206 | 207 | def rel_seq_to_id(lemma_dict, category_dict, sensed_dict, rel_dict, rel): 208 | rel_feature, rel_tgt, root_id = rel 209 | feature_seq = [] 210 | index_seq = [] 211 | prefix_seq = [] 212 | roles_mat = [] 213 | for l in rel_feature: 214 | data = [0] * 3 215 | data[0] = category_dict[l[AMR_CAT]] 216 | data[1] = lemma_dict[l[AMR_LE]] 217 | data[2] = sensed_dict[l[AMR_SENSE]] 218 | feature_seq.append(data) 219 | index_seq.append(l[-2]) 220 | prefix_seq.append(l[-1]) 
221 | for role_list in rel_tgt: 222 | roles_id = [] 223 | for role_index in role_list: 224 | roles_id.append([role_index[0], role_index[1]]) 225 | roles_mat.append(roles_id) 226 | 227 | return feature_seq, index_seq, roles_mat, root_id,prefix_seq 228 | 229 | 230 | def handle_sentence(data, filepath, build_dict, n, word_only): 231 | if n % 1000 == 0: 232 | print (n) 233 | 234 | ner = data["ner"] 235 | snt_token = data["tok"] 236 | pos = data["pos"] 237 | lemma_token = data["lem"] 238 | amr_t = data["amr_t"] 239 | 240 | if build_dict: 241 | if word_only: 242 | add_seq_to_dict(word_dict, snt_token) 243 | else: 244 | add_seq_to_dict(word_dict, snt_token) 245 | add_seq_to_dict(lemma_dict, lemma_token) 246 | add_seq_to_dict(pos_dict, pos) 247 | add_seq_to_dict(ner_dict, ner) 248 | amr = AMRGraph(amr_t) 249 | amr_seq, rel, unaligned_index = myamr_to_seq(amr, snt_token, lemma_token, pos, rl, 250 | fragment_to_node_converter, high_freq) 251 | amr_seq_to_dict(lemma_dict, category_dict, sensed_dict, aux_dict, amr_seq) 252 | rel_seq_to_dict(lemma_dict, category_dict, sensed_dict, rel_dict, rel) 253 | else: 254 | amr = AMRGraph(amr_t) 255 | amr_seq, rel, unaligned_index = myamr_to_seq(amr, snt_token, lemma_token, pos, rl, fragment_to_node_converter, 256 | high_freq) 257 | if opt.jamr: 258 | amr_seq, rel = filter_non_aligned(amr_seq, rel, unaligned_index) 259 | data["snt_id"] = seq_to_id(word_dict, snt_token)[0] 260 | data["lemma_id"] = seq_to_id(lemma_dict, lemma_token)[0] 261 | data["pos_id"] = seq_to_id(pos_dict, pos)[0] 262 | data["ner_id"] = seq_to_id(ner_dict, ner)[0] 263 | 264 | l = len(data["pos_id"]) 265 | if not (l == len(data["snt_id"]) and l == len(data["lemma_id"]) and l == len(data["ner_id"])): 266 | print (l, len(data["snt_id"]), len(data["lemma_id"]), len(data["ner_id"])) 267 | print (data["pos_id"]) 268 | print (data["snt_id"]) 269 | print (data["lemma_id"]) 270 | print (data["ner_id"]) 271 | print (pos) 272 | print (snt_token) 273 | print (lemma_token) 274 | 
print (ner) 275 | print (data["snt"]) 276 | assert (False) 277 | data["amr_seq"] = amr_seq 278 | data["convertedl_seq"] = amr.node_value() 279 | data["rel_seq"], data["rel_triples"] = amr.get_gold() 280 | data["amr_id"] = amr_seq_to_id(lemma_dict, category_dict, sensed_dict, aux_dict, amr_seq) 281 | data["amr_rel_id"], data["amr_rel_index"], data["roles_mat"], data["root"],data["prefix"] = rel_seq_to_id(lemma_dict, 282 | category_dict, 283 | sensed_dict, 284 | rel_dict, rel) 285 | 286 | for i in data["amr_rel_index"]: 287 | assert i < len(data["amr_id"]), (data["amr_rel_index"], amr_seq, data["amr_id"]) 288 | data["index"] = [all[-1] for all in amr_seq] 289 | 290 | 291 | def readFile(filepath, build_dict=False, word_only=False): 292 | all_data = load_text_jamr(filepath) 293 | 294 | n = 0 295 | for data in all_data: 296 | n = n + 1 297 | handle_sentence(data, filepath, build_dict, n, word_only) 298 | if not build_dict: 299 | outfile = Pickle_Helper(re.sub(end, ".pickle" + with_jamr, filepath)) 300 | outfile.dump(all_data, "data") 301 | outfile.save() 302 | return len(all_data) 303 | 304 | 305 | # Creating ReUsable Object 306 | rl = rules() 307 | rl.load("data/rule_f" + with_jamr) 308 | # initializer = lasagne.init.Uniform() 309 | fragment_to_node_converter = ReCategorizor(from_file=True, path="data/graph_to_node_dict_extended" + with_jamr, 310 | training=False, auto_convert_threshold=opt.threshold) 311 | non_rule_set_f = Pickle_Helper("data/non_rule_set") 312 | non_rule_set = non_rule_set_f.load()["non_rule_set"] 313 | threshold = opt.threshold 314 | high_text_num, high_frequency, low_frequency, low_text_num = unmixe(non_rule_set, threshold) 315 | print ( 316 | "initial converted,threshold,len(non_rule_set),high_text_num,high_frequency,low_frequency,low_text_num,high_freq") 317 | high_freq = {**high_text_num, **high_frequency} 318 | 319 | # high_freq =high_frequency 320 | 321 | print ("initial converted", threshold, len(non_rule_set), len(high_text_num), 
len(high_frequency), len(low_frequency), 322 | len(low_text_num), len(high_freq)) 323 | 324 | 325 | def initial_dict(filename, with_unk=False): 326 | d = Dict(filename) 327 | d.addSpecial(NULL_WORD) 328 | if with_unk: 329 | d.addSpecial(UNK_WORD) 330 | # d.addSpecial(BOS_WORD) 331 | return d 332 | 333 | 334 | if not opt.skip: 335 | word_dict = initial_dict("data/word_dict", with_unk=True) 336 | pos_dict = initial_dict("data/pos_dict", with_unk=True) 337 | 338 | ner_dict = initial_dict("data/ner_dict", with_unk=True) # from stanford 339 | 340 | high_dict = initial_dict("data/high_dict", with_unk=True) 341 | 342 | lemma_dict = initial_dict("data/lemma_dict", with_unk=True) 343 | 344 | aux_dict = initial_dict("data/aux_dict", with_unk=True) 345 | 346 | rel_dict = initial_dict("data/rel_dict", with_unk=True) 347 | 348 | category_dict = initial_dict("data/category_dict", with_unk=True) 349 | sensed_dict = initial_dict("data/sensed_dict", with_unk=True) 350 | 351 | # print ("high freq") 352 | for uni in high_freq: 353 | le = uni.le 354 | lemma_dict.add(le) 355 | high_dict.add(le) 356 | # print (le,high_freq[uni][0]) 357 | 358 | for filepath in trainingFilesPath: 359 | print(("reading " + filepath.split("/")[-1] + "......")) 360 | n = readFile(filepath, build_dict=True) 361 | print(("done reading " + filepath.split("/")[-1] + ", " + str(n) + " sentences processed")) 362 | 363 | # only to allow fixed word embedding to be used for those data, alternatively we can build a huge word_embedding for all words from GLOVE... 
364 | for filepath in devFilesPath: 365 | print(("reading " + filepath.split("/")[-1] + "......")) 366 | n = readFile(filepath, build_dict=True, word_only=True) 367 | print(("done reading " + filepath.split("/")[-1] + ", " + str(n) + " sentences processed")) 368 | 369 | for filepath in testFilesPath: 370 | print(("reading " + filepath.split("/")[-1] + "......")) 371 | n = readFile(filepath, build_dict=True, word_only=True) 372 | print(("done reading " + filepath.split("/")[-1] + ", " + str(n) + " sentences processed")) 373 | 374 | print ("len(aux_dict),len(rel_dict),threshold", len(aux_dict), len(rel_dict), threshold) 375 | 376 | rel_dict = rel_dict.pruneByThreshold(threshold) 377 | aux_dict = aux_dict.pruneByThreshold(threshold) 378 | category_dict = category_dict.pruneByThreshold(threshold) 379 | # print (rel_dict) 380 | word_dict.save() 381 | lemma_dict.save() 382 | pos_dict.save() 383 | aux_dict.save() 384 | ner_dict.save() 385 | high_dict.save() 386 | category_dict.save() 387 | rel_dict.save() 388 | sensed_dict.save() 389 | else: 390 | 391 | word_dict = Dict("data/word_dict") 392 | lemma_dict = Dict("data/lemma_dict") 393 | aux_dict = Dict("data/aux_dict") 394 | high_dict = Dict("data/high_dict") 395 | pos_dict = Dict("data/pos_dict") 396 | ner_dict = Dict("data/ner_dict") 397 | rel_dict = Dict("data/rel_dict") 398 | category_dict = Dict("data/category_dict") 399 | sensed_dict = Dict("data/sensed_dict") 400 | 401 | word_dict.load() 402 | lemma_dict.load() 403 | pos_dict.load() 404 | ner_dict.load() 405 | rel_dict.load() 406 | category_dict.load() 407 | high_dict.load() 408 | aux_dict.load() 409 | sensed_dict.save() 410 | 411 | fragment_to_node_converter = ReCategorizor(from_file=True, path="data/graph_to_node_dict_extended" + with_jamr, 412 | training=False, ner_cat_dict=aux_dict) 413 | print("dictionary building done") 414 | print("word_dict \t lemma_dict \tpos_dict \tner_dict \thigh_dict\tsensed_dict \tcategory_dict \taux_dict\trel_dict") 415 | print( 416 | 
len(word_dict), len(lemma_dict), len(pos_dict), len(ner_dict), len(high_dict), len(sensed_dict), len(category_dict), 417 | len(aux_dict), len(rel_dict)) 418 | 419 | print(("processing development set")) 420 | for filepath in devFilesPath: 421 | print(("reading " + filepath.split("/")[-1] + "......")) 422 | n = readFile(filepath, build_dict=False) 423 | print(("done reading " + filepath.split("/")[-1] + ", " + str(n) + " sentences processed")) 424 | 425 | print(("processing test set")) 426 | for filepath in testFilesPath: 427 | print(("reading " + filepath.split("/")[-1] + "......")) 428 | n = readFile(filepath, build_dict=False) 429 | 430 | print(("processing training set")) 431 | for filepath in trainingFilesPath: 432 | print(("reading " + filepath.split("/")[-1] + "......")) 433 | n = readFile(filepath, build_dict=False) 434 | print(("done reading " + filepath.split("/")[-1] + ", " + str(n) + " sentences processed")) 435 | 436 | print ("initial converted,threshold,len(non_rule_set),high_text_num,high_frequency,low_frequency,low_text_num") 437 | print ("initial converted", threshold, len(non_rule_set), len(high_text_num), len(high_frequency), len(low_frequency), 438 | len(low_text_num)) 439 | 440 | -------------------------------------------------------------------------------- /src/generate.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3.6 2 | # coding=utf-8 3 | ''' 4 | 5 | Scripts to run the model over preprocessed data to generate evaluatable results 6 | 7 | @author: Chunchuan Lyu (chunchuan.lv@gmail.com) 8 | @since: 2018-05-30 9 | ''' 10 | 11 | from parser.DataIterator import DataIterator,rel_to_batch 12 | import parser 13 | import torch 14 | from torch import cuda 15 | from utility.Naive_Scores import * 16 | from parser.AMRProcessors import graph_to_amr 17 | from utility.data_helper import folder_to_files_path 18 | 19 | from src.train import read_dicts,load_old_model,train_parser 20 | 21 | def 
generate_parser(): 22 | parser = train_parser() 23 | parser.add_argument('-output', default="_generate") 24 | parser.add_argument('-with_graphs', type=int,default=1) 25 | return parser 26 | 27 | 28 | 29 | def generate_graph(model,AmrDecoder, data_set,dicts,file): 30 | 31 | concept_scores = concept_score_initial(dicts) 32 | 33 | rel_scores = rel_scores_initial() 34 | 35 | model.eval() 36 | AmrDecoder.eval() 37 | output = [] 38 | gold_file = [] 39 | for batchIdx in range(len(data_set)): 40 | order,srcBatch,_,_,_,_,_,gold_roots,sourceBatch =data_set[batchIdx] 41 | 42 | probBatch = model(srcBatch ) 43 | 44 | 45 | 46 | amr_pred_seq,concept_batches,aligns_raw,dependent_mark_batch = AmrDecoder.probAndSourceToAmr(sourceBatch,srcBatch,probBatch,getsense = opt.get_sense ) 47 | 48 | amr_pred_seq = [ [(uni.cat,uni.le,uni.aux,uni.sense,uni) for uni in seq ] for seq in amr_pred_seq ] 49 | 50 | 51 | rel_batch,aligns = rel_to_batch(concept_batches,aligns_raw,data_set,dicts) 52 | rel_prob,roots = model((rel_batch,srcBatch,aligns),rel=True) 53 | graphs,rel_triples = AmrDecoder.relProbAndConToGraph(concept_batches,rel_prob,roots,(dependent_mark_batch,aligns_raw),opt.get_sense,opt.get_wiki) 54 | batch_out = [0]*len(graphs) 55 | for score_h in rel_scores: 56 | if score_h.second_filter: 57 | t,p,tp = score_h.T_P_TP_Batch(rel_triples,list(zip(*sourceBatch))[5],second_filter_material = (concept_batches,list(zip(*sourceBatch))[4])) 58 | else: 59 | t,p,tp = score_h.T_P_TP_Batch(rel_triples,list(zip(*sourceBatch))[5]) 60 | for score_h in concept_scores: 61 | t,p,tp = score_h.T_P_TP_Batch(concept_batches,list(zip(*sourceBatch))[4]) 62 | for i,data in enumerate(zip( sourceBatch,amr_pred_seq,concept_batches,rel_triples,graphs)): 63 | source,amr_pred,concept, rel_triple,graph= data 64 | predicated_graph = graph_to_amr(graph) 65 | 66 | out = [] 67 | out.append( "# ::tok "+" ".join(source[0])+"\n") 68 | out.append( "# ::lem "+" ".join(source[1])+"\n") 69 | out.append( "# ::pos "+" 
".join(source[2])+"\n") 70 | out.append( "# ::ner "+" ".join(source[3])+"\n") 71 | out.append( "# ::predicated "+" ".join([str(re_cat[-1]) for re_cat in amr_pred])+"\n") 72 | out.append( "# ::transformed final predication "+" ".join([str(c) for c in concept])+"\n") 73 | out.append( AmrDecoder.nodes_jamr(graph)) 74 | out.append( AmrDecoder.edges_jamr(graph)) 75 | out.append( predicated_graph) 76 | batch_out[order[i]] = "".join(out)+"\n" 77 | output += batch_out 78 | t_p_tp = list(map(lambda a,b:a+b, concept_scores[1].t_p_tp,rel_scores[1].t_p_tp)) 79 | total_out = "Smatch"+"\nT,P,TP: "+ " ".join([str(i) for i in t_p_tp])+"\nPrecesion,Recall,F1: "+ " ".join([str(i)for i in P_R_F1(*t_p_tp)]) 80 | print(total_out) 81 | for score_h in rel_scores: 82 | print("") 83 | print(score_h) 84 | file = file.replace(".pickle",".txt") 85 | with open(file+ opt.output, 'w+') as the_file: 86 | for data in output: 87 | the_file.write(data+'\n') 88 | print(file+ opt.output+" written.") 89 | return concept_scores,rel_scores,output 90 | 91 | 92 | def main(opt): 93 | dicts = read_dicts() 94 | assert opt.train_from 95 | with_jamr = "_with_jamr" if opt.jamr else "_without_jamr" 96 | suffix = ".pickle"+with_jamr+"_processed" 97 | trainFolderPath = opt.folder+"/training/" 98 | trainingFilesPath = folder_to_files_path(trainFolderPath,suffix) 99 | 100 | devFolderPath = opt.folder+"/dev/" 101 | devFilesPath = folder_to_files_path(devFolderPath,suffix) 102 | 103 | testFolderPath = opt.folder+"/test/" 104 | testFilesPath = folder_to_files_path(testFolderPath,suffix) 105 | 106 | 107 | 108 | AmrDecoder = parser.AMRProcessors.AMRDecoder(opt,dicts) 109 | AmrDecoder.eval() 110 | AmrModel,parameters,optt = load_old_model(dicts,opt,True) 111 | opt.start_epoch = 1 112 | 113 | out = "/".join(testFilesPath[0].split("/")[:-2])+ "/model" 114 | with open(out, 'w') as outfile: 115 | outfile.write(opt.train_from+"\n") 116 | outfile.write(str(AmrModel)+"\n") 117 | outfile.write(str(optt)+"\n") 118 | 
outfile.write(str(opt)) 119 | 120 | print('processing testing') 121 | for file in testFilesPath: 122 | dev_data = DataIterator([file],opt,dicts["rel_dict"],volatile = True) 123 | concept_scores,rel_scores,output =generate_graph(AmrModel,AmrDecoder,dev_data,dicts,file) 124 | 125 | print('processing validation') 126 | for file in devFilesPath: 127 | dev_data = DataIterator([file],opt,dicts["rel_dict"],volatile = True) 128 | concept_scores,rel_scores,output =generate_graph(AmrModel,AmrDecoder,dev_data,dicts,file) 129 | 130 | 131 | 132 | print('processing training') 133 | for file in trainingFilesPath: 134 | dev_data = DataIterator([file],opt,dicts["rel_dict"],volatile = True) 135 | concept_scores,rel_scores,output =generate_graph(AmrModel,AmrDecoder,dev_data,dicts,file) 136 | 137 | 138 | if __name__ == "__main__": 139 | print (" ") 140 | print (" ") 141 | global opt 142 | opt = generate_parser().parse_args() 143 | opt.lemma_dim = opt.dim 144 | opt.high_dim = opt.dim 145 | 146 | opt.cuda = len(opt.gpus) 147 | 148 | print(opt) 149 | 150 | if torch.cuda.is_available() and not opt.cuda: 151 | print("WARNING: You have a CUDA device, so you should probably run with -cuda") 152 | 153 | if opt.cuda and opt.gpus[0] != -1: 154 | cuda.set_device(opt.gpus[0]) 155 | main(opt) -------------------------------------------------------------------------------- /src/parse.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3.6 2 | # coding=utf-8 3 | ''' 4 | 5 | 6 | data["ner"] = [] 7 | data["tok"] = [] 8 | data["lem"] = [] 9 | data["pos"] = [] 10 | for snt_tok in snt: 11 | data["ner"].append(snt_tok['ner']) 12 | data["tok"].append(snt_tok['word']) 13 | data["lem"].append(snt_tok['lemma']) 14 | data["pos"].append(snt_tok['pos']) 15 | data["ner"].append(snt_tok['ner']) 16 | data["tok"].append(snt_tok['word']) 17 | data["lem"].append(snt_tok['lemma']) 18 | data["pos"].append(snt_tok['pos']) 19 | 20 | Scripts to run the model to 
parse a file. Input file should contain each sentence per line 21 | A file containing output will be generated at the same folder unless output is specified. 22 | @author: Chunchuan Lyu (chunchuan.lv@gmail.com) 23 | @since: 2018-05-30 24 | ''' 25 | 26 | from torch import cuda 27 | from parser.AMRProcessors import * 28 | from src.train import read_dicts,train_parser 29 | 30 | def generate_parser(): 31 | parser = train_parser() 32 | parser.add_argument('-output', default=None) 33 | parser.add_argument('-with_graphs', type=int,default=1) 34 | parser.add_argument("-input",default=None,type=str, 35 | help="""input file path""") 36 | parser.add_argument("-text",default=None,type=str, 37 | help="""a single sentence to parse""") 38 | parser.add_argument("-processed",default=0,type=int, 39 | help="""a single sentence to parse""") 40 | return parser 41 | 42 | if __name__ == "__main__": 43 | global opt 44 | opt = generate_parser().parse_args() 45 | opt.lemma_dim = opt.dim 46 | opt.high_dim = opt.dim 47 | 48 | opt.cuda = len(opt.gpus) 49 | 50 | if opt.cuda and opt.gpus[0] != -1: 51 | cuda.set_device(opt.gpus[0]) 52 | dicts = read_dicts() 53 | processed = opt.processed==1 54 | Parser = AMRParser(opt,dicts,parse_from_processed= processed) 55 | 56 | if opt.input: 57 | 58 | filepath = opt.input 59 | out = opt.output if opt.output else filepath+"_parsed" 60 | print ("processing "+filepath) 61 | n = 0 62 | processed_sentences = 0 63 | with open(out,'w') as out_f: 64 | with open(filepath,'r') as f: 65 | line = f.readline() 66 | batch = [] 67 | while line and line.strip() != "": 68 | while line and line.strip() != "" and len(batch) < opt.batch_size: 69 | batch.append(line.strip()) 70 | line = f.readline() 71 | 72 | output = Parser.parse_batch(batch) 73 | for snt, others in zip(batch,output): 74 | out_f.write("# ::snt "+snt+"\n") 75 | out_f.write(others) 76 | out_f.write("\n") 77 | processed_sentences = processed_sentences + len(batch) 78 | print ("processed_sentences" , 
processed_sentences) 79 | batch = [] 80 | print ("done processing "+filepath) 81 | print (out +" is generated") 82 | 83 | elif opt.input: 84 | filepath = opt.input 85 | out = opt.output if opt.output else filepath+"_parsed" 86 | print ("processing "+filepath) 87 | n = 0 88 | with open(out,'w') as out_f: 89 | with open(filepath,'r') as f: 90 | line = f.readline() 91 | while line != '' : 92 | if line.strip() != "": 93 | output = Parser.parse_batch([line.strip()]) 94 | out_f.write("# ::snt "+line) 95 | out_f.write(output[0]) 96 | out_f.write("\n") 97 | line = f.readline() 98 | print ("done processing "+filepath) 99 | print (out +" is generated") 100 | elif opt.text: 101 | output = Parser.parse_one(opt.text) 102 | print ("# ::snt "+opt.text) 103 | for i in output: 104 | print (i) 105 | else: 106 | print ("option -input [file] or -text [sentence] is required.") -------------------------------------------------------------------------------- /src/preprocessing.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3.6 2 | # coding=utf-8 3 | ''' 4 | 5 | Combine multiple AMR data files in the same directory into a single one 6 | Need to specify folder containing all subfolders of training, dev and test 7 | 8 | Then extract features for futher process based on stanford core nlp tools 9 | 10 | @author: Chunchuan Lyu (chunchuan.lv@gmail.com) 11 | @since: 2018-06-01 12 | ''' 13 | 14 | from parser.AMRProcessors import * 15 | 16 | import argparse 17 | 18 | 19 | def combine_files(files): 20 | out = "/".join(files[0].split("/")[:-1]) 21 | out = out + "/combined.txt_" 22 | with open(out, 'w+') as outfile: 23 | for fname in files: 24 | with open(fname) as infile: 25 | line = infile.readline() 26 | line = infile.readline() 27 | while line != '' : 28 | line = infile.readline() 29 | outfile.write(line) 30 | outfile.write("\n") 31 | 32 | def write_features(filepath,feature_extractor:AMRInputPreprocessor): 33 | out = filepath + 
"pre_processed" 34 | print ("processing "+filepath) 35 | n = 0 36 | with open(out,'w') as out_f: 37 | with open(filepath,'r') as f: 38 | line = f.readline() 39 | while line != '' : 40 | if line.startswith("# ::snt") or line.startswith("# ::tok"): 41 | text = line[7:] 42 | data = feature_extractor.preprocess(text) 43 | out_f.write(line.replace("# ::tok","# ::snt")) 44 | for key in ["tok","lem","pos","ner"]: 45 | out_f.write("# ::"+key+"\t"+"\t".join(data[key])+"\n") 46 | n = n+1 47 | if n % 500 ==0: 48 | print (str(n)+" sentences processed") 49 | elif not line.startswith("# AMR release; "): 50 | out_f.write(line) 51 | line = f.readline() 52 | print ("done processing "+filepath) 53 | print (out +" is generated") 54 | 55 | def combine_arg(): 56 | parser = argparse.ArgumentParser(description='preprocessing.py') 57 | 58 | ## Data options 59 | parser.add_argument('-suffix', default="txt", type=str, 60 | help="""suffix of files to combine""") 61 | parser.add_argument('-folder', default=allFolderPath, type=str , 62 | help="""the folder""") 63 | return parser 64 | 65 | 66 | parser = combine_arg() 67 | 68 | 69 | opt = parser.parse_args() 70 | feature_extractor = AMRInputPreprocessor() 71 | 72 | trainFolderPath = opt.folder+"/training/" 73 | trainingFilesPath = folder_to_files_path(trainFolderPath,opt.suffix) 74 | combine_files(trainingFilesPath) 75 | write_features(trainFolderPath+"/combined.txt_",feature_extractor) 76 | 77 | devFolderPath = opt.folder+"/dev/" 78 | devFilesPath = folder_to_files_path(devFolderPath,opt.suffix) 79 | combine_files(devFilesPath) 80 | write_features(devFolderPath+"/combined.txt_",feature_extractor) 81 | 82 | testFolderPath = opt.folder+"/test/" 83 | testFilesPath = folder_to_files_path(testFolderPath,opt.suffix) 84 | combine_files(testFilesPath) 85 | write_features(testFolderPath+"/combined.txt_",feature_extractor) 86 | 87 | -------------------------------------------------------------------------------- /src/rule_system_build.py: 
#!/usr/bin/env python3.6
# coding=utf-8
'''

Script to build StringCopyRules (the copying dictionary) and the ReCategorizor
templates from the pre-processed training data.

The training files are read twice:
  1. with templates enabled, to collect recategorization templates and lemma
     frequencies; the ReCategorizor and rules are then saved;
  2. with the saved ReCategorizor reloaded, to rebuild the set of concepts that
     no copying rule explains (non_rule_set), which is pickled to data/non_rule_set.

Data path information should also be specified here for
trainFolderPath, devFolderPath and testFolderPath,
as we allow the option to choose between two versions of the data.
@author: Chunchuan Lyu (chunchuan.lv@gmail.com)
@since: 2018-05-30
'''

from utility.StringCopyRules import *
from utility.ReCategorization import *
from utility.data_helper import *


import argparse
import threading


def arg_parser():
    """Build the command-line parser for this script."""
    parser = argparse.ArgumentParser(description='rule_system_build.py')

    ## Data options
    parser.add_argument('-threshold', default=5, type=int,
                        help="""threshold for non-aligned high frequency concepts""")

    parser.add_argument('-jamr', default=0, type=int,
                        help="""wheather to enhance string matching with additional jamr alignment""")
    parser.add_argument('-suffix', default=".txt_pre_processed", type=str,
                        help="""suffix of files to combine""")
    parser.add_argument('-folder', default=allFolderPath, type=str,
                        help="""the folder""")
    return parser


parser = arg_parser()
opt = parser.parse_args()
threshold = opt.threshold
# Files aligned with JAMR carry an extra "_jamr" suffix; outputs are tagged accordingly.
suffix = opt.suffix + "_jamr" if opt.jamr else opt.suffix
with_jamr = "_with_jamr" if opt.jamr else "_without_jamr"
trainFolderPath = opt.folder + "/training/"
trainingFilesPath = folder_to_files_path(trainFolderPath, suffix)

devFolderPath = opt.folder + "/dev/"
devFilesPath = folder_to_files_path(devFolderPath, suffix)

testFolderPath = opt.folder + "/test/"
testFilesPath = folder_to_files_path(testFolderPath, suffix)


# Serialises updates made through add_count.
lock = threading.Lock()


def add_count(store, new, additional=None):
    """Count every item of `new` in `store`, remembering `additional` each time.

    `store` maps item -> [count, [additional, additional, ...]], with one
    `additional` entry appended per occurrence. Mutates `store` in place.
    """
    # `with` guarantees the lock is released even if an update raises.
    with lock:
        for item in new:
            if item not in store:
                store[item] = [1, [additional]]
            else:
                store[item][0] += 1
                store[item][1].append(additional)


def handle_sentence(data, n, update_freq, use_template, jamr=False):
    """Apply recategorization and string-copy matching to one sentence.

    :param data: sentence record read by load_text_jamr; uses the keys
                 "tok", "pos", "lem", "amr_t", "align" and "node".
    :param n: 1-based sentence counter, used only for progress printing.
    :param update_freq: if true, feed matched lemmas into rl.add_lemma_freq.
    :param use_template: if true, let the ReCategorizor match (learn) templates
                         before converting the graph.
    :param jamr: forwarded to the matching calls.

    Concepts with no matching copy rule are counted in the global
    `non_rule_set`, together with the sentence text.
    """
    if n % 500 == 0:
        print (n)
    snt_token = data["tok"]
    pos_token = data["pos"]
    lemma_token = data["lem"]
    amr_t = data["amr_t"]
    aligns = data["align"]
    v2c = data["node"]
    amr = AMRGraph(amr_t, aligns=aligns)
    amr.check_consistency(v2c)
    lemma_str = " ".join(lemma_token)
    if use_template:
        fragment_to_node_converter.match(amr, rl, snt_token, lemma_token, pos_token, lemma_str, jamr=jamr)
    fragment_to_node_converter.convert(amr, rl, snt_token, lemma_token, pos_token, lemma_str)
    results = rl.get_matched_concepts(snt_token, amr, lemma_token, pos_token, with_target=update_freq, jamr=jamr)
    if update_freq:
        # NOTE(review): each result appears to be (node, concept, [(?, lemma), ...]); confirm in StringCopyRules.
        for n_c_a in results:
            for i_le in n_c_a[2]:
                rl.add_lemma_freq(i_le[1], n_c_a[1].le, n_c_a[1].cat, sense=n_c_a[1].sense)

    snt_str = " ".join(snt_token)
    none_rule = [n_c_a[1] for n_c_a in results if len(n_c_a[2]) == 0]
    add_count(non_rule_set, none_rule, snt_str)


def readFile(filepath, update_freq=False, use_template=True):
    """Process every sentence of `filepath` and dump its tokens to a .tok file.

    Tokens are written tab-separated, one sentence per line, to `filepath`
    with ".txt" replaced by ".tok" (or with ".tok" appended when the name
    contains no ".txt", so the input file is never overwritten).

    :returns: the number of sentences processed.
    """
    all_data = load_text_jamr(filepath)

    tok_path = filepath.replace(".txt", ".tok")
    if tok_path == filepath:
        # Opening the unchanged path with 'w' would truncate the input file.
        tok_path = filepath + ".tok"

    n = 0
    with open(tok_path, 'w') as output_file:
        for n, data in enumerate(all_data, start=1):
            # One sentence per line; previously the sentences were written with no separator at all.
            output_file.write("\t".join(data["tok"]) + "\n")
            handle_sentence(data, n, update_freq, use_template, jamr=bool(opt.jamr))
    return n


rl = rules()
non_rule_set = dict()
fragment_to_node_converter = ReCategorizor(training=True)

non_rule_set_last = non_rule_set
rl.build_lemma_cheat()

# Pass 1: learn recategorization templates and lemma frequencies.
non_rule_set = dict()
for filepath in trainingFilesPath:  # actually already combined into one
    print(("reading " + filepath.split("/")[-1] + "......"))
    n = readFile(filepath, update_freq=True, use_template=True)
    print(("done reading " + filepath.split("/")[-1] + ", " + str(n) + " sentences processed"))
high_text_num, high_frequency, low_frequency, low_text_num = unmixe(non_rule_set, threshold)
print ("initial converted,threshold,len(non_rule_set),high_text_num,high_frequency,low_frequency,low_text_num")
print ("initial converted", threshold, len(non_rule_set), len(high_text_num), len(high_frequency), len(low_frequency), len(low_text_num))

non_rule_set_initial_converted = non_rule_set
rl.build_lemma_cheat()
fragment_to_node_converter.save(path="data/graph_to_node_dict_extended" + with_jamr)
fragment_to_node_converter = ReCategorizor(from_file=False, path="data/graph_to_node_dict_extended" + with_jamr, training=False)
rl.save("data/rule_f" + with_jamr)
non_rule_set = dict()
NERS = {}

# Pass 2: rebuild the copying dictionary based on the recategorized graph.
for filepath in trainingFilesPath:
    print(("reading " + filepath.split("/")[-1] + "......"))
    n = readFile(filepath, update_freq=False, use_template=False)
    print(("done reading " + filepath.split("/")[-1] + ", " + str(n) + " sentences processed"))

# Checkpoint of the final set; the save below uses the same path and presumably overwrites it.
non_rule_set_f = Pickle_Helper("data/non_rule_set")
non_rule_set_f.dump(non_rule_set, "non_rule_set")
non_rule_set_f.save()

# only intermediate data, won't be useful for final parser
non_rule_set_f = Pickle_Helper("data/non_rule_set")
non_rule_set_f.dump(non_rule_set_last, "initial_non_rule_set")
non_rule_set_f.dump(non_rule_set_initial_converted, "initial_converted_non_rule_set")
non_rule_set_f.dump(non_rule_set, "non_rule_set")
non_rule_set_f.save()
high_text_num,high_frequency,low_frequency,low_text_num=unmixe(non_rule_set,threshold ) 154 | print ("final converted,threshold,len(non_rule_set),high_text_num,high_frequency,low_frequency,low_text_num") 155 | print ("final converted",threshold,len(non_rule_set),len(high_text_num),len(high_frequency),len(low_frequency),len(low_text_num)) -------------------------------------------------------------------------------- /utility/AMRGraph.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3.6 2 | # coding=utf-8 3 | ''' 4 | 5 | AMRGraph builds on top of AMR from amr.py 6 | representing AMR graph as graph, 7 | and extract named entity (t1,..,tn, ner type, wiki) tuple. (we use model predicting for deciding ner type though) 8 | Being able to apply recategorization to original graph, 9 | which involves collapsing nodes for concept identification and unpacking for relation identification. 10 | 11 | @author: Chunchuan Lyu (chunchuan.lv@gmail.com) 12 | @since: 2018-05-28 13 | ''' 14 | from utility.amr import * 15 | from utility.constants import * 16 | import networkx as nx 17 | 18 | class AMRGraph(AMR): 19 | def __init__(self, anno, normalize_inverses=True, 20 | normalize_mod=False, tokens=None,aligns={}): 21 | ''' 22 | create AMR from text, and convert AMR to AMRGraph of standard representation 23 | ''' 24 | super().__init__(anno, tokens) 25 | self.ners = [] 26 | self.gold_concept = [] 27 | self.gold_triple = [] 28 | self.graph = nx.DiGraph() 29 | self.wikis = [] 30 | for h, r, d in [(h, r, d) for h, r, d in self.triples(normalize_inverses=normalize_inverses, 31 | normalize_mod=normalize_mod) if 32 | (r != ":instance" )]: 33 | if r == ':wiki': 34 | h, h_v = self.var_get_uni(h, True,(h, r, d )) 35 | d, d_v = self.var_get_uni(d) 36 | self.wikis.append(d) 37 | self.ners.append((h,d_v)) 38 | continue 39 | elif r == ':top': 40 | d, d_v = self.var_get_uni(d) 41 | self.root = d 42 | self.graph.add_node(d, value=d_v, 
align=None,gold=True,prefix = self._index_inv[d]) 43 | else: 44 | h_prefix = self._index_inv[h] 45 | d_prefix = self._index_inv[d] if d in self._index_inv else d #d will be the prefix if it is constant 46 | assert isinstance(h_prefix,str) and isinstance(d_prefix,str) 47 | h, h_v = self.var_get_uni(h, True,(h, r, d )) 48 | d, d_v = self.var_get_uni(d) 49 | self.graph.add_node(h, value=h_v, align=None,gold=True,prefix = h_prefix) 50 | self.graph.add_node(d, value=d_v, align=None,gold=True,prefix = d_prefix) 51 | self.graph.add_edge(h, d, role=r) 52 | self.graph.add_edge(d, h, role=r + "-of") 53 | 54 | # for i in self._triples: 55 | # print(i) 56 | self.read_align(aligns) 57 | 58 | #alignment from copying mechanism 59 | def read_align(self, aligns): 60 | for prefix in aligns: 61 | i = self._index[prefix] 62 | if isinstance(i,Var): 63 | assert i in self.graph.node,(self.graph.nodes(True),self.triples(normalize_inverses=True, 64 | normalize_mod=False),self._anno) 65 | self.graph.node[i]["align"] = aligns[prefix] 66 | else: 67 | if Var(prefix) in self.wikis: continue 68 | assert Var(prefix) in self.graph.node,(prefix,aligns,self._index,self.graph.nodes(True),self._anno) 69 | self.graph.node[Var(prefix)]["align"] = aligns[prefix] 70 | 71 | 72 | def check_consistency(self,pre2c): 73 | for prefix in pre2c: 74 | var = self._index[prefix] 75 | if not isinstance(var,Var): var = Var(prefix) 76 | if var in self.wikis: continue 77 | assert var in self.graph.node,(prefix, "\n",pre2c,"\n",self.graph.node,"\n",self._anno) 78 | amr_c = self.graph.node[var]["value"] 79 | 80 | assert amr_c.gold_str() == pre2c[prefix],(prefix, var,amr_c.gold_str() ,pre2c[prefix],"\n",pre2c,"\n",self.graph.nodes(True)) 81 | 82 | def get_gold(self): 83 | cons = [] 84 | roles = [] 85 | for n, d in self.graph.nodes(True): 86 | if "gold" in d: 87 | v = d["value"] 88 | cons.append(v) 89 | 90 | for h, d, rel in self.graph.edges(data=True): 91 | r = rel["role"] 92 | if self.cannonical(r): 93 | assert "gold" in 
self.graph.node[h] and "gold" in self.graph.node[d] 94 | h = self.graph.node[h]["value"] 95 | d = self.graph.node[d]["value"] 96 | roles.append([h,d,r]) 97 | 98 | root = self.graph.node[self.root]["value"] 99 | roles.append([AMRUniversal(BOS_WORD,BOS_WORD,NULL_WORD),root,':top']) 100 | return cons,roles 101 | 102 | def get_ners(self): 103 | ners = [] 104 | for v,wiki in self.ners: #v is name variable 105 | name = None 106 | names = [] 107 | for nearb in self.graph[v]: 108 | if self.graph[v][nearb]["role"] == ":name": 109 | name = nearb 110 | break 111 | if name is None: 112 | print (self.graph[v],self._anno) 113 | continue 114 | ner_type = self.graph.node[v]["value"] 115 | for node in self.graph[name]: 116 | if self.graph.node[node]["value"].cat == Rule_String and ":op" in self.graph[name][node]["role"]: 117 | names.append(( self.graph.node[node]["value"],int(self.graph[name][node]["role"][-1]))) # (role, con,node) 118 | 119 | names = [t[0] for t in sorted(names,key = lambda t: t[1])] 120 | ners.append([names,wiki,ner_type]) 121 | return ners 122 | 123 | 124 | 125 | def rely(self,o_node,n_node): 126 | if "rely" in self.graph.node[o_node]: 127 | return 128 | self.graph.node[o_node].setdefault("rely",n_node) 129 | 130 | #link old node to new node 131 | def link(self,o_node,n_node,rel): 132 | self.graph.node[o_node].setdefault("original-of",[]).append( n_node ) # for storing order of replacement 133 | if n_node: 134 | self.graph.node[n_node]["has-original"] = o_node # for storing order of replacement 135 | self.graph.node[n_node]["align"] = self.graph.node[o_node]["align"] 136 | if rel: self.rely(o_node,n_node) 137 | 138 | def replace(self,node,cat_or_uni,aux=None,rel=False): 139 | 140 | aux_le = self.graph.node[aux]['value'].le if aux else None 141 | 142 | if isinstance(cat_or_uni,AMRUniversal): 143 | universal = cat_or_uni 144 | else: 145 | le = self.graph.node[node]['value'].le 146 | universal = AMRUniversal(le, cat_or_uni, None, aux_le) #aux_le is usually named 
entity type 147 | # create a new recategorized node 148 | # gold is not marked, so new recategorized node won't be used for relation identification 149 | var = Var(node._name+"_"+universal.cat) 150 | self.graph.add_node(var, value=universal, align=None) 151 | self.link(node,var,rel) 152 | 153 | return var 154 | 155 | 156 | #get a amr universal node from a variable in AMR or a constant in AMR 157 | def var_get_uni(self, a, head=False,tri=None): 158 | if isinstance(a,Var): 159 | return a, AMRUniversal(concept=self._v2c[a]) 160 | else: 161 | if head: 162 | assert False, "constant as head" + "\n" + a + self._anno+"\n"+str(tri) 163 | return Var(a), AMRUniversal(concept=self._index[a]) 164 | 165 | 166 | 167 | def __getitem__(self, item): 168 | return self.graph.node[item] 169 | 170 | #check whether the relation is in the cannonical direction 171 | def cannonical(self,r): 172 | return "-of" in r and not self.is_core(r) or "-of" not in r and self.is_core(r) 173 | 174 | def getRoles(self,node,index_dict,rel_index,relyed = None): 175 | # (amr,index,[[role,rel_index]]) 176 | if relyed and relyed not in index_dict: 177 | print ("rely",node,relyed,self.graph.node[relyed]["value"],index_dict,self._anno) 178 | elif relyed is None and node not in index_dict: print (self.graph.node[node]["value"]) 179 | index = index_dict[node] if relyed is None else index_dict[relyed] 180 | out = [] 181 | # if self.graph.node[node]["value"].le != "name": 182 | for n2 in self.graph[node]: 183 | r = self.graph[node][n2]["role"] 184 | if self.cannonical(r): 185 | if n2 not in rel_index: 186 | print(node,n2) 187 | print(self._anno) 188 | out.append([r,rel_index[n2]]) 189 | return [[self.graph.node[node]["value"],index], out] 190 | 191 | #return data for training concept identification or relation identification 192 | def node_value(self, keys=["value"], all=False): 193 | def concept_concept(): 194 | out = [] 195 | index = 0 196 | index_dict ={} 197 | for n, d in self.graph.nodes(True): 198 | if 
"original-of"in d: 199 | comps = d["original-of"] 200 | for comp in comps: 201 | if comp is None: 202 | continue 203 | comp_d = self.graph.node[comp] 204 | out.append([comp] + [comp_d[k] for k in keys]) 205 | index_dict[comp] = index 206 | index += 1 207 | elif not ("has-original" in d or "rely" in d): 208 | out.append([n] + [d[k] for k in keys]) 209 | index_dict[n] = index 210 | index += 1 211 | return out,index_dict 212 | def rel_concept(): 213 | index = 0 214 | rel_index ={} 215 | rel_prefix = [] 216 | rel_out = [] 217 | for n, d in self.graph.nodes(True): 218 | if "gold" in d: 219 | rel_out.append([n,d]) 220 | rel_index[n] = index 221 | rel_prefix.append( d["prefix"]) 222 | index += 1 223 | return rel_out,rel_index,rel_prefix 224 | 225 | out,index_dict = concept_concept() 226 | if all: 227 | rel_out, rel_index, rel_prefix = rel_concept() 228 | for i, n_d in enumerate( rel_out): 229 | n,d = n_d 230 | if "rely" in d: 231 | rel_out[i] =self.getRoles(n,index_dict,rel_index,d["rely"]) 232 | elif not ("has-original" in d or "original-of" in d): 233 | rel_out[i] = self.getRoles(n,index_dict,rel_index) 234 | else: 235 | assert False , (self._anno,n,d["value"]) 236 | assert (self.root in rel_index),(self.graph.nodes[self.root],rel_index,self._anno) 237 | return out,rel_out,rel_prefix, rel_index[self.root] 238 | else: 239 | return out -------------------------------------------------------------------------------- /utility/Naive_Scores.py: -------------------------------------------------------------------------------- 1 | __author__ = 's1544871' 2 | from utility.constants import * 3 | from utility.amr import AMRUniversal 4 | 5 | class ScoreHelper: 6 | 7 | def __init__(self,name, filter ,second_filter=None): 8 | self.t_p_tp = [0,0,0] 9 | self.name = name 10 | self.f = filter 11 | self.second_filter = second_filter 12 | self.false_positive = {} 13 | self.false_negative = {} 14 | 15 | def T_P_TP_Batch(self,hypos,golds,accumulate=True,second_filter_material =None): 16 | if 
self.second_filter: 17 | T,P,TP,fp,fn = T_P_TP_Batch(hypos,golds,self.f,self.second_filter,second_filter_material) 18 | else: 19 | # assert self.name != "Unlabled SRL Triple",(hypos[-20],"STOP!",golds[-20]) 20 | T,P,TP,fp,fn = T_P_TP_Batch(hypos,golds,self.f) 21 | if accumulate: 22 | self.add_t_p_tp(T,P,TP) 23 | self.add_content(fp,fn) 24 | return T,P,TP 25 | 26 | def add_t_p_tp(self,T,P,TP): 27 | self.t_p_tp[0] += T 28 | self.t_p_tp[1] += P 29 | self.t_p_tp[2] += TP 30 | 31 | def add_content(self,fp,fn ): 32 | for i in fp: 33 | self.false_positive[i] = self.false_positive.setdefault(i,0)+1 34 | for i in fn: 35 | self.false_negative[i] = self.false_negative.setdefault(i,0)+1 36 | 37 | def show_error(self,t = 5): 38 | print ("false_positive",[(k,self.false_positive[k]) for k in sorted(self.false_positive,key=self.false_positive.get) if self.false_positive[k]> t]) 39 | print ("") 40 | print ("false_negative",[(k,self.false_negative[k]) for k in sorted(self.false_negative,key=self.false_negative.get) if self.false_negative[k]>t]) 41 | def __str__(self): 42 | s = self.name+"\nT,P,TP: "+ " ".join([str(i) for i in self.t_p_tp])+"\nPrecesion,Recall,F1: "+ " ".join([str(i)for i in P_R_F1(*self.t_p_tp)]) 43 | return s 44 | 45 | 46 | 47 | def filter_mutual(hypo,gold,mutual_filter): 48 | filtered_hypo = [item for sublist in filter_seq(mutual_filter,hypo) for item in sublist] 49 | out_hypo = [] 50 | filtered_gold = [item for sublist in filter_seq(mutual_filter,gold) for item in sublist] 51 | out_gold = [] 52 | 53 | for data in hypo: 54 | d1,d2 = mutual_filter(data) 55 | if d1 in filtered_gold and d2 in filtered_gold: 56 | out_hypo.append(data) 57 | 58 | 59 | for data in gold: 60 | d1,d2 = mutual_filter(data) 61 | if d1 in filtered_hypo and d2 in filtered_hypo: 62 | out_gold.append(data) 63 | 64 | return out_hypo,out_gold 65 | 66 | def list_to_mulset(l): 67 | s = dict() 68 | for i in l: 69 | if isinstance(i,AMRUniversal) and i.le == "i"and i.cat == Rule_Concept : 70 | s[i] = 1 
71 | else: 72 | s[i] = s.setdefault(i,0)+1 73 | return s 74 | 75 | def legal_concept(uni): 76 | if isinstance(uni,AMRUniversal): 77 | return (uni.cat,uni.le,uni.sense) if not uni.le in Special and not uni.cat in Special else None 78 | else: 79 | return uni 80 | 81 | def nonsense_concept(uni): 82 | return (uni.cat,uni.le) if not uni.le in Special and not uni.cat in Special else None 83 | 84 | def dynamics_filter(triple,concept_seq): 85 | if triple[0] in concept_seq and triple[1] in concept_seq or BOS_WORD in triple[0]: 86 | return triple[:3] 87 | 88 | # print (triple,concept_seq[0]) 89 | return None 90 | 91 | def filter_seq(filter,seq): 92 | out = [] 93 | for t in seq: 94 | filtered = filter(t) 95 | if filtered and filtered[0] != BOS_WORD and filtered != BOS_WORD: 96 | out.append(filtered) 97 | return out 98 | 99 | def remove_sense(uni): 100 | return (uni.cat,uni.le) 101 | 102 | def T_TP_Seq(hypo,gold,filter,second_filter = None,second_filter_material = None): 103 | gold = filter_seq(filter,gold) 104 | hypo = filter_seq(filter,hypo) 105 | fp = [] 106 | fn = [] 107 | if second_filter: #only for triple given concept 108 | 109 | 110 | second_filter_predicated = filter_seq(legal_concept,second_filter_material[0]) 111 | second_filter_with_material = lambda x: second_filter(x,second_filter_predicated) 112 | gold = filter_seq(second_filter_with_material,gold) 113 | 114 | 115 | second_filter_gold = filter_seq(legal_concept,second_filter_material[1]) 116 | second_filter_with_material = lambda x: second_filter(x,second_filter_gold) 117 | 118 | hypo = filter_seq(second_filter_with_material,hypo) 119 | 120 | if len(gold)>0 and isinstance(gold[0],tuple) and len(gold[0])==3 and False: 121 | print ("") 122 | print ("source based prediction") 123 | for t in hypo: 124 | print (t) 125 | print ("") 126 | print ("source gold seq") 127 | for t in gold: 128 | print (t) 129 | print ("") 130 | TP = 0 131 | T = len(gold) 132 | P = len(hypo) 133 | gold = list_to_mulset(gold) 134 | hypo = 
list_to_mulset(hypo) 135 | for d_g in gold: 136 | if d_g in hypo : 137 | TP += min(gold[d_g],hypo[d_g]) 138 | fn = fn + [d_g] *min(gold[d_g]-hypo[d_g],0) 139 | else: 140 | fn = fn + [d_g] *gold[d_g] 141 | for d_g in hypo: 142 | if d_g in gold : 143 | fp = fp + [d_g] *min(hypo[d_g]-gold[d_g],0) 144 | else: 145 | fp = fp + [d_g] *hypo[d_g] 146 | return T,P,TP,fp,fn 147 | 148 | def T_P_TP_Batch(hypos,golds,filter=legal_concept,second_filter=None,second_filter_material_batch = None): 149 | TP,T,P = 0,0,0 150 | FP,FN = [],[] 151 | assert hypos, golds 152 | for i in range(len(hypos)): 153 | if second_filter: 154 | t,p,tp,fp,fn = T_TP_Seq(hypos[i],golds[i],filter,second_filter,(second_filter_material_batch[0][i],second_filter_material_batch[1][i])) 155 | else: 156 | t,p,tp,fp,fn = T_TP_Seq(hypos[i],golds[i],filter) 157 | T += t 158 | P +=p 159 | TP += tp 160 | FP += fp 161 | FN += fn 162 | return T,P,TP,FP,FN 163 | 164 | 165 | def P_R_F1(T,P,TP): 166 | if TP == 0: 167 | return 0,0,0 168 | P = TP/P 169 | R = TP/T 170 | F1 = 2.0/(1.0/P+1.0/R) 171 | return P,R,F1 172 | 173 | 174 | 175 | #naive set overlapping for different kinds of relations 176 | def rel_scores_initial(): 177 | 178 | 179 | root_filter = lambda t:(legal_concept(t[0]),legal_concept(t[1]),t[2]) if legal_concept(t[0]) and legal_concept(t[1]) and nonsense_concept(t[0]) == (BOS_WORD,BOS_WORD) else None 180 | 181 | root_score = ScoreHelper("Root",filter=root_filter) 182 | 183 | rel_filter = lambda t:(legal_concept(t[0]),legal_concept(t[1]),t[2]) if legal_concept(t[0]) and legal_concept(t[1]) else None 184 | rel_score = ScoreHelper("REL Triple",filter=rel_filter) 185 | 186 | non_sense_rel_filter = lambda t:(nonsense_concept(t[0]),nonsense_concept(t[1]),t[2]) if legal_concept(t[0]) and legal_concept(t[1]) else None 187 | nonsense_rel_score = ScoreHelper("Nonsense REL Triple",filter=non_sense_rel_filter) 188 | 189 | unlabeled_filter =lambda t:(legal_concept(t[0]),legal_concept(t[1])) if legal_concept(t[0]) and 
legal_concept(t[1]) else None 190 | 191 | unlabeled_rel_score = ScoreHelper("Unlabeled Rel Triple",filter=unlabeled_filter) 192 | 193 | labeled_rel_score_given_concept = ScoreHelper("REL Triple given concept",filter = rel_filter, second_filter=dynamics_filter) 194 | 195 | 196 | un_srl_filter =lambda t:(legal_concept(t[0]),legal_concept(t[1])) if legal_concept(t[0]) and legal_concept(t[1]) and t[2].startswith(':ARG') else None 197 | 198 | un_frame_score = ScoreHelper("Unlabled SRL Triple",filter=un_srl_filter) 199 | 200 | srl_filter = lambda t:(legal_concept(t[0]),legal_concept(t[1]),t[2]) if legal_concept(t[0]) and legal_concept(t[1]) and t[2].startswith(':ARG') else None 201 | frame_score = ScoreHelper("SRL Triple",filter=srl_filter) 202 | 203 | labeled_srl_score_given_concept = ScoreHelper("SRL Triple given concept",filter = srl_filter, second_filter=dynamics_filter) 204 | 205 | unlabeled_srl_score_given_concept = ScoreHelper("Unlabeled SRL Triple given concept",filter = un_srl_filter, second_filter=dynamics_filter) 206 | 207 | return [nonsense_rel_score,rel_score,root_score,unlabeled_rel_score,labeled_rel_score_given_concept,frame_score,un_frame_score,labeled_srl_score_given_concept,unlabeled_srl_score_given_concept] 208 | 209 | 210 | #naive set overlapping for different kinds of concepts 211 | def concept_score_initial(dicts): 212 | 213 | Non_Sense = ScoreHelper("Non_Sense",filter=nonsense_concept) 214 | concept_score = ScoreHelper("Full Concept",filter=legal_concept) 215 | category_score = ScoreHelper("Category Only",filter=lambda uni:(uni.cat) 216 | if legal_concept(uni) else None) 217 | lemma_score = ScoreHelper("Lemma Only",filter=lambda uni: (uni.le) 218 | if legal_concept(uni) else None) 219 | frame_score = ScoreHelper("Frame Only",filter=lambda uni: (uni.le) 220 | if legal_concept(uni) and uni.cat==Rule_Frame else None) 221 | frame_sense_score = ScoreHelper("Frame Sensed Only",filter=lambda uni: (uni.le,uni.sense) 222 | if legal_concept(uni) and 
uni.cat==Rule_Frame else None) 223 | frame_non_91_score = ScoreHelper("Frame non 91 Only",filter=lambda uni: (uni.le,uni.sense) 224 | if legal_concept(uni) and uni.cat==Rule_Frame and "91" not in uni.sense else None) 225 | high_score = ScoreHelper("High Freq Only",filter=lambda uni: (uni.le,uni.cat) 226 | if uni.le in dicts["high_dict"] and legal_concept(uni) else None) 227 | default_score = ScoreHelper("Copy Only",filter=lambda uni: (uni.le,uni.cat) 228 | if uni.le not in dicts["high_dict"] and legal_concept(uni) else None) 229 | return [Non_Sense,concept_score,category_score,frame_score,frame_sense_score,frame_non_91_score,lemma_score,high_score,default_score] 230 | -------------------------------------------------------------------------------- /utility/PropbankReader.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3.6 2 | # coding=utf-8 3 | ''' 4 | 5 | This reader reads all amr propbank file, 6 | and add possible cannonical amr lemma 7 | to the corresponding copying dictionary of a word and aliases of the word 8 | 9 | @author: Chunchuan Lyu (chunchuan.lv@gmail.com) 10 | @since: 2018-05-28 11 | ''' 12 | 13 | import xml.etree.ElementTree as ET 14 | from nltk.stem import WordNetLemmatizer 15 | from utility.amr import * 16 | from utility.data_helper import folder_to_files_path 17 | wordnet_lemmatizer = WordNetLemmatizer() 18 | 19 | def add_concept(lemmas_to_concept,le,con): 20 | 21 | if not le in lemmas_to_concept: 22 | lemmas_to_concept[le]= set([con]) 23 | else: 24 | lemmas_to_concept[le].add(con) 25 | 26 | 27 | 28 | class PropbankReader: 29 | def parse(self): 30 | self.frames = dict() 31 | self.non_sense_frames = dict() 32 | self.frame_lemmas = set() 33 | self.joints = set() 34 | for f in self.frame_files_path: 35 | self.parse_file(f) 36 | 37 | def __init__(self, folder_path=frame_folder_path): 38 | self.frame_files_path = folder_to_files_path(folder_path,".xml") 39 | self.parse() 40 | 41 | def 
parse_file(self,f): 42 | tree = ET.parse(f) 43 | root = tree.getroot() 44 | for child in root: 45 | if child.tag == "predicate": 46 | self.add_lemma(child) 47 | 48 | #add cannonical amr lemma to possible set of words including for aliases of the words 49 | def add_lemma(self,node): 50 | lemma = node.attrib["lemma"].replace("_","-") 51 | self.frames.setdefault(lemma,set()) 52 | self.non_sense_frames.setdefault(lemma,set()) 53 | # self.frames[lemma] = set() 54 | for child in node: 55 | if child.tag == "roleset": 56 | if "." not in child.attrib["id"]: 57 | if len(child.attrib["id"].split("-")) == 1: 58 | le,sense = child.attrib["id"],NULL_WORD 59 | else: 60 | le,sense = child.attrib["id"].split("-") 61 | # print (child.attrib["id"],lemma) 62 | else: 63 | le,sense = child.attrib["id"].replace("_","-").split(".") 64 | self.frame_lemmas.add(le) 65 | role = AMRUniversal(le,Rule_Frame,"-"+sense) 66 | if len(role.le.split("-")) == 2: 67 | k,v = role.le.split("-") 68 | self.joints.add((k,v)) 69 | no_sense_con = AMRUniversal(role.le,role.cat,None) 70 | add_concept(self.frames,lemma,role) 71 | add_concept(self.non_sense_frames,lemma,no_sense_con) 72 | aliases = child.find('aliases') 73 | if aliases: 74 | for alias in aliases.findall('alias'): 75 | if alias.text != le and alias.text not in self.frames: 76 | alias_t = alias.text.replace("_","-") 77 | add_concept(self.frames,alias_t,role) 78 | add_concept(self.non_sense_frames,alias_t,no_sense_con) 79 | #print (le, self.frames[le]) 80 | def get_frames(self): 81 | return self.frames 82 | def main(): 83 | f_r = PropbankReader() 84 | for k,v in f_r.joints: 85 | print (k+" "+v) 86 | 87 | 88 | if __name__ == "__main__": 89 | main() 90 | -------------------------------------------------------------------------------- /utility/StringCopyRules.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | #coding=utf-8 3 | ''' 4 | Building and hanlding category based dictionary for 
copying mechanism 5 | Also used by ReCategorization to preduce training set, and templates (which partially rely on string matching). 6 | 7 | @author: Chunchuan Lyu (chunchuan.lv@gmail.com) 8 | @since: 2018-05-28 9 | ''' 10 | 11 | import threading 12 | from utility.data_helper import * 13 | from utility.AMRGraph import * 14 | from utility.constants import * 15 | from utility.PropbankReader import PropbankReader 16 | 17 | from nltk.metrics.distance import edit_distance 18 | 19 | 20 | def de_polarity(lemma): 21 | if len(lemma) == 0: return None 22 | if lemma[0] == "a" and len(lemma)> 5: 23 | return lemma[1:] 24 | if (lemma[:2]) in ["in","un","il","ir","im"] and len(lemma)>5: 25 | return lemma[2:] 26 | if lemma[:3] in ["dis","non"] and len(lemma)>6: 27 | return lemma[3:] 28 | if lemma[-4:] in ["less"] and len(lemma)>6: 29 | return lemma[:-4] 30 | return None 31 | def polarity_match(con_lemma,lemma): 32 | lemma = de_polarity(lemma) 33 | if lemma is not None: 34 | if disMatch(lemma,con_lemma )<1: 35 | return True 36 | return False 37 | 38 | 39 | #computing string dissimilarity (e.g. 
0 means perfect match) 40 | def disMatch(lemma,con_lemma,t=0.5): 41 | # if (con_lemma == "and" and lemma == ";" ): return True 42 | # if (con_lemma == "multi-sentence" and lemma in [".",";"]): return True 43 | if lemma == con_lemma: return 0 44 | if de_polarity(lemma) == con_lemma: return 1 #not match if depolaritized matched 45 | if (con_lemma in lemma or lemma in con_lemma and "-role" not in con_lemma) and len(lemma)>2 and len(con_lemma)>2 : 46 | return 0 47 | if lemma.endswith("ily") and lemma[:-3]+"y"==con_lemma: 48 | return 0 49 | if lemma.endswith("ing") and (lemma[:-3]+"e"==con_lemma or lemma[:-3]==con_lemma): 50 | return 0 51 | if lemma.endswith("ical") and lemma[:-4]+"y"==con_lemma: 52 | return 0 53 | if lemma.endswith("ially") and lemma[:-5] in con_lemma: 54 | return 0 55 | if lemma.endswith("ion") and (lemma[:-3]+"e"==con_lemma or lemma[:-3]==con_lemma): 56 | return 0 57 | if lemma in con_lemma and len(lemma)>3 and len(con_lemma)-len(lemma)<5: 58 | return 0 59 | if lemma.endswith("y") and lemma[:-1]+"ize"==con_lemma: 60 | return 0 61 | if lemma.endswith("er") and (lemma[:-2]==con_lemma or lemma[:-3]==con_lemma or lemma[:-1]==con_lemma): 62 | return 0 63 | dis = 1.0*edit_distance(lemma,con_lemma)/min(12,max(len(lemma),len(con_lemma))) 64 | if (dis < t ) : 65 | return dis 66 | return 1 67 | 68 | import calendar 69 | month_abbr = {name: num for num, name in enumerate(calendar.month_abbr) if num} 70 | 71 | _float_regexp = re.compile(r"^[-+]?(?:\b[0-9]+(?:\.[0-9]*)?|\.[0-9]+\b)(?:[eE][-+]?[0-9]+\b)?$") 72 | def is_float_re(str): 73 | return re.match(_float_regexp, str) 74 | super_scripts = '⁰¹²³⁴⁵⁶⁷⁸⁹' 75 | def parseStr(x): 76 | if x.isdigit(): 77 | if x in super_scripts: 78 | return super_scripts.find(x) 79 | return int(x) 80 | elif is_float_re(x): 81 | return float(x) 82 | return None 83 | 84 | units = [ 85 | "zero", "one", "two", "three", "four", "five", "six", "seven", "eight", 86 | "nine", "ten", "eleven", "twelve", "thirteen", "fourteen", "fifteen", 87 
| "sixteen", "seventeen", "eighteen", "nineteen", 88 | ] 89 | tens = ["", "", "twenty", "thirty", "forty", "fifty", "sixty", "seventy", "eighty", "ninety"] 90 | months = ["","january","february","march","april","may","june","july","august","september","october","november","december"] 91 | scales = ["hundred", "thousand", "million", "billion", "trillion"] 92 | scaless = ["hundreds", "thousands", "millions", "billions", "trillions"] 93 | numwords = {} 94 | numwords["and"] = (1, 0) 95 | for idx, word in enumerate(units): numwords[word] = (1, idx) 96 | for idx, word in enumerate(months): numwords[word] = (1, idx) 97 | for idx, word in enumerate(tens): numwords[word] = (1, idx * 10) 98 | for idx, word in enumerate(scales): numwords[word] = (10 ** (idx * 3 or 2), 0) 99 | for idx, word in enumerate(scaless): numwords[word] = (10 ** (idx * 3 or 2), 0) 100 | ordinal_words = {'first':1, 'second':2, 'third':3, 'fifth':5, 'eighth':8, 'ninth':9, 'twelfth':12} 101 | ordinal_endings = [('ieth', 'y'), ('th', ''), ('st', ''), ('nd', ''), ('rd', '')] 102 | 103 | def text2int(textnum): 104 | for k, v in month_abbr.items(): 105 | if textnum == k.lower(): 106 | return v 107 | if " and " in textnum : 108 | textnums = textnum.split(" and ") 109 | out = [ j for j in [text2int(i) for i in textnums ] if j ] 110 | if len(out) > 1: return sum(out) 111 | else: return None 112 | textnum = textnum.replace(',', ' ') 113 | textnum = textnum.replace('-', ' ') 114 | textnum = textnum.replace('@-@', ' ') 115 | current = result = 0 116 | for word in textnum.split(): 117 | w_num = parseStr(word) 118 | if word in ordinal_words: 119 | scale, increment = (1, ordinal_words[word]) 120 | elif w_num: 121 | scale, increment = (1,w_num) 122 | else: 123 | for ending, replacement in ordinal_endings: 124 | if word.endswith(ending): 125 | word = "%s%s" % (word[:-len(ending)], replacement) 126 | if word not in numwords: 127 | return None 128 | scale, increment = numwords[word] 129 | current = current * scale + 
increment 130 | if scale > 100: 131 | result += current 132 | current = 0 133 | return int(result + current) 134 | 135 | 136 | 137 | 138 | 139 | 140 | def unmixe(mixed,threshold = 50): 141 | high_frequency = dict() 142 | low_frequency = dict() 143 | low_text_num = dict() 144 | high_text_num = dict() 145 | for i in mixed: 146 | cat = i.cat 147 | if mixed[i][0] >= threshold and ( cat in Rule_All_Constants) : 148 | high_text_num[i] = mixed[i] 149 | elif mixed[i][0] >= threshold: 150 | high_frequency[i] = mixed[i] 151 | elif (cat in Rule_All_Constants): 152 | low_text_num[i] = mixed[i] 153 | else: 154 | low_frequency[i] = mixed[i] 155 | 156 | return high_text_num,high_frequency,low_frequency,low_text_num 157 | 158 | class rules: 159 | RE_FRAME_NUM = re.compile(r'-\d\d$') 160 | NUM = re.compile(r'[-]?[1-9][0-9]*[:.]?[0-9]*') 161 | Pure_NUM = re.compile(r'[-]?[1-9][0-9]*[,]?[0-9]*') 162 | 163 | def save(self,filepath="data/rule_f"): 164 | pickle_helper= Pickle_Helper(filepath) 165 | pickle_helper.dump(self.lemma_back,"lemma_back") 166 | pickle_helper.dump([k for k in self.lemma_freq_cat.keys()],"keys") 167 | for cat in self.lemma_freq_cat: 168 | pickle_helper.dump(self.lemma_freq_cat[cat] ,cat) 169 | pickle_helper.save() 170 | 171 | self.load(filepath) 172 | 173 | lock = threading.Lock() 174 | def load(self,filepath="data/rule_f"): 175 | pickle_helper= Pickle_Helper(filepath) 176 | data = pickle_helper.load() 177 | keys = data["keys"] 178 | self.lemma_freq_cat = {} 179 | self.lemma_back = data["lemma_back"] 180 | for key in keys: 181 | self.lemma_freq_cat[key] = data[key] 182 | self.build_lemma_cheat() 183 | return self 184 | 185 | def set_rules(self): 186 | self.rules = {} 187 | self.rules[Rule_Frame]= lambda _,l,con_l ,sense: self.standard_rule(l,Rule_Frame,con_l,sense) 188 | self.rules[Rule_String]= lambda x,_,con_l ,__: self.standard_rule(x,Rule_String,con_l) 189 | self.rules[Rule_Ner]= lambda x,_ ,con_l,__:self.standard_rule(x,Rule_Ner,con_l) 190 | 
self.rules[Rule_B_Ner]= lambda x,_ ,con_l,__:self.standard_rule(x,Rule_B_Ner,con_l) 191 | self.rules[Rule_Constant]= lambda _,l,con_l,__: self.standard_rule(l,Rule_Constant,con_l) 192 | self.rules[Rule_Concept]= lambda _,l,con_l,__: self.standard_rule(l,Rule_Concept,con_l) 193 | self.rules[Rule_Num]= lambda _,l ,con_l,__: self.num(l) 194 | 195 | def entity(self,lemma,cat,con_lemma = None): 196 | num = self.num(lemma) 197 | if num is not None and num.le != NULL_WORD: 198 | num.cat = cat 199 | if cat == "date-entity" and self.Pure_NUM.search(lemma) and len(lemma) == 6: 200 | num.le = lemma 201 | return num 202 | return self.standard_rule(lemma,cat,con_lemma) 203 | 204 | def read_veb(self): 205 | RE_FRAME_NUM = re.compile(r'-\d\d$') 206 | f = open(verbalization,"r") 207 | f.readline() 208 | line = f.readline() 209 | while line != '' : 210 | tokens = line.replace("\n","").split(" ") 211 | if len(tokens)<= 4 and (tokens[0] =="VERBALIZE" or tokens[0] =="MAYBE-VERBALIZE"): 212 | old_lemma = tokens[1] 213 | amr_lemma = re.sub(RE_FRAME_NUM, '', tokens[3]) 214 | if tokens[0] =="MAYBE-VERBALIZE": 215 | self.add_lemma_freq(old_lemma,amr_lemma,Rule_Frame,freq=1,sense = tokens[3][-3:]) 216 | else: 217 | self.add_lemma_freq(old_lemma,amr_lemma,Rule_Frame,freq=100,sense = tokens[3][-3:]) 218 | 219 | line = f.readline() 220 | f.close() 221 | 222 | 223 | def read_frame(self): 224 | f_r = PropbankReader() 225 | f_r = f_r.frames 226 | for le,concepts in f_r.items(): 227 | i=0 228 | for c in concepts: 229 | self.add_lemma_freq(le,c.le,Rule_Frame,freq=10,sense = c.sense) 230 | i = i+1 231 | 232 | def __init__(self): 233 | self.lemma_freq_cat = {} 234 | self.lemmatize_cheat = {} 235 | self.lemma_back = {} 236 | self.read_frame() 237 | self.read_veb() 238 | self.frame_lemmas = PropbankReader().frame_lemmas 239 | self.build_lemma_cheat() 240 | self.set_rules() 241 | # self.rules[Rule_Re]= lambda _,l = wordnet.NOUN: self.re(l) 242 | 243 | def 
standard_rule(self,lemma,cat,con_lemma=None,sense=NULL_WORD): 244 | if con_lemma is None: #testing 245 | if cat in [Rule_Ner,Rule_B_Ner ]and len(lemma)>3: 246 | lemma = lemma.capitalize() 247 | if (lemma,cat) in self.lemmatize_cheat: 248 | # if lemma == "cooperation" and cat == Rule_Frame: 249 | # print ("before cooperation",self.lemmatize_cheat[(lemma,cat)],AMRUniversal(lemma,cat,sense)) 250 | lemma = self.lemmatize_cheat[(lemma,cat)] 251 | # if lemma == "cooperate" and cat == Rule_Frame: 252 | # print ("after cooperate",AMRUniversal(lemma,cat,sense)) 253 | # elif lemma == "cooperation" and cat == Rule_Frame: 254 | # print ("after cooperation",AMRUniversal(lemma,cat,sense)) 255 | return AMRUniversal(lemma,cat,sense) 256 | return AMRUniversal(lemma,cat,sense) 257 | else: #training 258 | if cat in [Rule_Ner,Rule_B_Ner ] and len(lemma)>3: 259 | lemma = lemma.capitalize() 260 | if cat not in self.lemma_freq_cat or lemma not in self.lemma_freq_cat[cat]: 261 | return AMRUniversal(lemma,cat,sense) 262 | candidates = self.lemma_freq_cat[cat][lemma] 263 | if con_lemma in candidates.keys(): 264 | return AMRUniversal(con_lemma,cat,sense) 265 | return AMRUniversal(lemma,cat,sense) 266 | 267 | def clear_freq(self): 268 | self.lemma_freq_cat = {} 269 | self.lemmatize_cheat = {} 270 | 271 | def add_lemma_freq(self,old_lemma,amr_lemma,cat,freq=1,sense=NULL_WORD): 272 | # if cat in Rule_All_Constants: 273 | # return 274 | self.lock.acquire() 275 | if old_lemma == amr_lemma: freq *= 10 276 | amr_con = amr_lemma 277 | self.lemma_back[amr_con][old_lemma] = self.lemma_back.setdefault(amr_con,{}).setdefault(old_lemma,0)+freq 278 | lemma_freq = self.lemma_freq_cat.setdefault(cat,{}).setdefault(old_lemma,{}) 279 | lemma_freq[amr_con] = lemma_freq.setdefault(amr_con,0)+freq 280 | self.lock.release() 281 | 282 | 283 | def build_lemma_cheat(self): 284 | for cat in self.lemma_freq_cat: 285 | lemma_freqs = self.lemma_freq_cat[cat] 286 | for word in lemma_freqs: 287 | max_score = 0 288 | 
max_lemma = word 289 | for arm_le in lemma_freqs[word]: 290 | score = 1.0*lemma_freqs[word][arm_le] 291 | assert (score > 0) 292 | if score >max_score: 293 | max_score = score 294 | max_lemma = arm_le 295 | self.lemmatize_cheat[(word,cat)] = max_lemma 296 | 297 | # print (self.lemmatize_cheat) 298 | 299 | # fragments_to_break = set(["up","down","make"]) 300 | 301 | def num(self,lemma): 302 | r = text2int(lemma) 303 | if r is None and self.Pure_NUM.search(lemma) is not None: 304 | lemma = lemma.replace(",","") 305 | return AMRUniversal(lemma,Rule_Num,None) 306 | if r is not None: 307 | return AMRUniversal(str(r),Rule_Num,None) 308 | return AMRUniversal(NULL_WORD,Rule_Num,None) 309 | 310 | #old_ids : batch x (cat,le,lemma,word) only cat is id 311 | def toAmrSeq(self,cats,snt,lemma,high,auxs,senses = None,ners = None): 312 | out = [] 313 | for i in range(len(snt)): 314 | sense = senses[i] if senses else None 315 | txt, le,cat,h,aux = snt[i],lemma[i],cats[i],high[i],auxs[i] 316 | assert h is None or isinstance(h,str) or isinstance(h,tuple)and isinstance(cat,str) ,(txt, le,cat,h) 317 | if h and h != UNK_WORD: 318 | if cat == Rule_Num: 319 | uni = self.to_concept(txt,h,Rule_Num,sense) 320 | if uni.le == NULL_WORD: 321 | uni = AMRUniversal(h,Rule_Concept,sense) 322 | else: 323 | uni = AMRUniversal(h,cat,sense) 324 | else: 325 | try_num = self.to_concept(txt,le,Rule_Num,sense) 326 | if " " in le and try_num.le != NULL_WORD and cat not in [Rule_String,Rule_B_Ner,Rule_Ner] and "entity" not in cat: 327 | uni = try_num 328 | else: 329 | uni = self.to_concept(txt,le,cat,sense) 330 | 331 | if cat == Rule_B_Ner: 332 | if not aux in [UNK_WORD,NULL_WORD]: 333 | uni.aux = aux 334 | elif ners[i] == "PERSON": 335 | uni.aux = "person" 336 | elif ners[i] == "LOCATION": 337 | uni.aux = "location" 338 | elif ners[i] == "ORGANIZATION": 339 | uni.aux = "organization" 340 | else: 341 | uni.aux = UNK_WORD 342 | assert isinstance(uni.le,str) and isinstance(uni.cat,str ),(txt, 
le,cat,h,uni.le,uni,cat) 343 | 344 | 345 | if ners[i] == "URL": #over write ML decision, otherwise creating bug 346 | uni = AMRUniversal(le,"url-entity",None) 347 | 348 | out.append(uni) 349 | 350 | return out 351 | 352 | 353 | def to_concept(self,txt,le,cat,con_lemma=None,sense=NULL_WORD): 354 | if cat in self.rules: 355 | return self.rules[cat](txt,le,con_lemma,sense) 356 | elif cat.endswith("-entity"): # entity 357 | return self.entity(le,cat,con_lemma) 358 | else: 359 | return self.standard_rule(le,cat,con_lemma) 360 | 361 | 362 | #amr is myamr 363 | def get_matched_concepts(self,snt_token,amr_or_node_value,lemma,pos,with_target = False,jamr=False,full=1): 364 | results = [] 365 | node_value = amr_or_node_value.node_value(keys=["value","align"]) if isinstance(amr_or_node_value,AMRGraph) else amr_or_node_value 366 | for n,c,a in node_value: 367 | if full == 1: 368 | align = self.match_concept(snt_token,c,lemma,pos,with_target) 369 | if jamr and a is not None: 370 | exist = False 371 | for i,l,p in align: 372 | if i == a: 373 | exist = True 374 | if not exist: 375 | align += [(a,lemma[a],pos[a])] 376 | results.append([n,c,align]) 377 | else: 378 | if jamr and a is not None: 379 | align = [(a,lemma[a],pos[a])] 380 | else: 381 | align = self.match_concept(snt_token,c,lemma,pos,with_target) 382 | results.append([n,c,align]) 383 | return results 384 | 385 | def match_concept(self,snt_token,concept,lemma,pos,with_target = False,candidate = None): 386 | if len(lemma) == 1: return [[0,lemma[0],pos[0]]] 387 | le,cat,sense = decompose(concept) 388 | align = [] 389 | if candidate is None: 390 | candidate = range(len(snt_token)) 391 | for i in candidate: 392 | if with_target and disMatch(lemma[i],le) <1: # and pos[i] not in ["IN"]: 393 | align.append((i,lemma[i],pos[i])) 394 | continue 395 | if with_target: 396 | amr_c = self.to_concept(snt_token[i],lemma[i],cat,le,sense) 397 | else: 398 | amr_c = self.to_concept(snt_token[i],lemma[i],cat) 399 | if amr_c is None: 400 | 
continue 401 | le_i,cat_i,sen_i = decompose(amr_c) 402 | 403 | assert cat == cat_i, "cat mismatch "+ snt_token[i]+" "+lemma[i]+" "+cat+" "+le+" "+cat_i+" "+le_i+"\n"+" ".join(snt_token) 404 | if amr_c.non_sense_equal(concept): # and pos[i] not in ["IN"]: 405 | align.append((i,lemma[i],pos[i])) 406 | 407 | if le == "and" and len(align) == 0: 408 | for i in range(len(lemma)): 409 | if lemma[i] == ";" or lemma[i] == "and": 410 | align.append((i,lemma[i],pos[i])) 411 | if len(align)>0: return [align[-1]] 412 | 413 | # if len(align) > 0 : print (le,align,lemma) 414 | 415 | if le == "multi-sentence" and len(align) == 0 and False: 416 | for i in range(len(lemma)): 417 | if lemma[i] in [".",";","?","!"]: 418 | align.append((i,lemma[i],pos[i])) 419 | return align 420 | # if len(align) > 0 : print (le,align,lemma) 421 | return align 422 | 423 | -------------------------------------------------------------------------------- /utility/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChunchuanLv/AMR_AS_GRAPH_PREDICTION/3375123c6b00bdfbe3395706769175073716b699/utility/__init__.py -------------------------------------------------------------------------------- /utility/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChunchuanLv/AMR_AS_GRAPH_PREDICTION/3375123c6b00bdfbe3395706769175073716b699/utility/__init__.pyc -------------------------------------------------------------------------------- /utility/__pycache__/AMRGraph.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChunchuanLv/AMR_AS_GRAPH_PREDICTION/3375123c6b00bdfbe3395706769175073716b699/utility/__pycache__/AMRGraph.cpython-36.pyc -------------------------------------------------------------------------------- /utility/__pycache__/Naive_Scores.cpython-36.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChunchuanLv/AMR_AS_GRAPH_PREDICTION/3375123c6b00bdfbe3395706769175073716b699/utility/__pycache__/Naive_Scores.cpython-36.pyc -------------------------------------------------------------------------------- /utility/__pycache__/PropbankReader.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChunchuanLv/AMR_AS_GRAPH_PREDICTION/3375123c6b00bdfbe3395706769175073716b699/utility/__pycache__/PropbankReader.cpython-36.pyc -------------------------------------------------------------------------------- /utility/__pycache__/ReCategorization.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChunchuanLv/AMR_AS_GRAPH_PREDICTION/3375123c6b00bdfbe3395706769175073716b699/utility/__pycache__/ReCategorization.cpython-36.pyc -------------------------------------------------------------------------------- /utility/__pycache__/StringCopyRules.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChunchuanLv/AMR_AS_GRAPH_PREDICTION/3375123c6b00bdfbe3395706769175073716b699/utility/__pycache__/StringCopyRules.cpython-36.pyc -------------------------------------------------------------------------------- /utility/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChunchuanLv/AMR_AS_GRAPH_PREDICTION/3375123c6b00bdfbe3395706769175073716b699/utility/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /utility/__pycache__/amr.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/ChunchuanLv/AMR_AS_GRAPH_PREDICTION/3375123c6b00bdfbe3395706769175073716b699/utility/__pycache__/amr.cpython-36.pyc -------------------------------------------------------------------------------- /utility/__pycache__/constants.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChunchuanLv/AMR_AS_GRAPH_PREDICTION/3375123c6b00bdfbe3395706769175073716b699/utility/__pycache__/constants.cpython-36.pyc -------------------------------------------------------------------------------- /utility/__pycache__/data_helper.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChunchuanLv/AMR_AS_GRAPH_PREDICTION/3375123c6b00bdfbe3395706769175073716b699/utility/__pycache__/data_helper.cpython-36.pyc -------------------------------------------------------------------------------- /utility/amr.peg: -------------------------------------------------------------------------------- 1 | # PEG (parsing expression grammar) for a single AMR annotation. 2 | # Designed for the Parsimonious library (https://github.com/erikrose/parsimonious), 3 | # though a bit of automatic cleanup is required when loading this file. 4 | # Nathan Schneider, 2015-05-05 5 | 6 | ALL = ~r"\s*" X ~r"\s*$" 7 | 8 | X = "(" ` VAR _ "/" _ CONCEPT (_ REL _ Y)* ` ")" 9 | Y = X / NAMEDCONST / VAR / STR / NUM 10 | VAR = ~r"[a-z]+[0-9]*" ALIGNMENT? 11 | NAMEDCONST = ~r"[a-z]{2,}\b" ALIGNMENT? # aside from + and -, named constants must have at least 2 letters (to distinguish from variable names) 12 | STR = "\"" ~r"[^\"\s]([^\"\n\r]*[^\"\s])?" "\"" ALIGNMENT? # quoted string literal. nonempty; may not start or end with whitespace 13 | CONCEPT = ~r"[A-Za-z0-9.\!\?,:;'][A-Za-z0-9.i\!\?.;:'-]*" ALIGNMENT?
# seen in data: :li (x3 / 3) and :quant (x / 355.02) and :mod (x / friggin') 14 | REL = ~r":[A-Za-z][A-Za-z0-9-]*" ALIGNMENT? 15 | NUM = ~r"[+-]?\d*(\.\d+)?" ALIGNMENT? 16 | ALIGNMENT = "~" ~r"[A-Za-z0-9\.]+(\,[0-9]+)*" 17 | # TODO: the regexes, especially NUM, need checking 18 | 19 | _ = ~r"([ \t]*[\n\r][ \t]*)|[ \t]+" 20 | ` = ~r"[ \t]*[\n\r]?[ \t]*" 21 | -------------------------------------------------------------------------------- /utility/amr.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChunchuanLv/AMR_AS_GRAPH_PREDICTION/3375123c6b00bdfbe3395706769175073716b699/utility/amr.pyc -------------------------------------------------------------------------------- /utility/constants.py: -------------------------------------------------------------------------------- 1 | '''Global constants, and file paths''' 2 | import os,re 3 | 4 | # Change the path according to your system 5 | 6 | save_to = '/disk/scratch/s1544871/model/' #the folder amr model will be saved to (model name is parameterized by some hyper parameter) 7 | train_from = '/disk/scratch/s1544871/model/gpus_0valid_best.pt' #default model loading 8 | embed_path = "/disk/scratch/s1544871/glove.840B.300d.txt" #file containing glove embedding 9 | core_nlp_url = 'http://localhost:9000' #local host url of standford corenlp server 10 | root_path = "/disk/scratch/s1544871" 11 | allFolderPath = root_path + "/amr_annotation_r2/data/alignments/split" 12 | resource_folder_path = root_path +"/amr_annotation_r2/" 13 | frame_folder_path = resource_folder_path+"data/frames/propbank-frames-xml-2016-03-08/" 14 | have_org_role = resource_folder_path+"have-org-role-91-roles-v1.06.txt" #not used 15 | have_rel_role = resource_folder_path+"have-rel-role-91-roles-v1.06.txt" #not used 16 | morph_verbalization = resource_folder_path+"morph-verbalization-v1.01.txt" #not used 17 | verbalization = resource_folder_path+"verbalization-list-v1.06.txt" 18 | 19 | 20 
| PAD = 0 21 | UNK = 1 22 | 23 | PAD_WORD = '' 24 | UNK_WORD = '' 25 | BOS_WORD = '' 26 | EOS_WORD = '' 27 | NULL_WORD = "" 28 | UNK_WIKI = '' 29 | Special = [NULL_WORD,UNK_WORD,PAD_WORD] 30 | #Categories 31 | Rule_Frame = "Frame" 32 | Rule_Constant = "Constant" 33 | Rule_String = "String" 34 | Rule_Concept = "Concept" 35 | Rule_Comp = "COMPO" 36 | Rule_Num = "Num" 37 | Rule_Re = "Re" #corenference 38 | Rule_Ner = "Ner" 39 | Rule_B_Ner = "B_Ner" 40 | Rule_Other = "Entity" 41 | Other_Cats = {"person","thing",} 42 | COMP = "0" 43 | Rule_All_Constants = [Rule_Num,Rule_Constant,Rule_String,Rule_Ner] 44 | Splish = "$£%%££%£%£%£%" 45 | Rule_Basics = Rule_All_Constants + [Rule_Frame,Rule_Concept,UNK_WORD,BOS_WORD,EOS_WORD,NULL_WORD,PAD_WORD] 46 | 47 | RULE = 0 48 | HIGH = 1 49 | LOW = 2 50 | 51 | RE_FRAME_NUM = re.compile(r'-\d\d$') 52 | RE_COMP = re.compile(r'_\d$') 53 | end= re.compile(".txt\_[a-z]*") 54 | epsilon = 1e-8 55 | 56 | TXT_WORD = 0 57 | TXT_LEMMA = 1 58 | TXT_POS = 2 59 | TXT_NER = 3 60 | 61 | 62 | AMR_CAT = 0 63 | AMR_LE = 1 64 | AMR_NER = 2 65 | AMR_AUX = 2 66 | AMR_LE_SENSE = 3 67 | AMR_SENSE = 3 68 | AMR_CAN_COPY = 4 69 | 70 | threshold = 5 71 | 72 | 73 | -------------------------------------------------------------------------------- /utility/constants.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChunchuanLv/AMR_AS_GRAPH_PREDICTION/3375123c6b00bdfbe3395706769175073716b699/utility/constants.pyc -------------------------------------------------------------------------------- /utility/data_helper.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3.6 2 | # coding=utf-8 3 | ''' 4 | 5 | Some helper functions for storing and reading data 6 | 7 | @author: Chunchuan Lyu (chunchuan.lv@gmail.com) 8 | @since: 2018-05-29 9 | ''' 10 | import json,os,re 11 | import pickle 12 | 13 | 14 | class Pickle_Helper: 15 | 16 | def __init__(self, 
filePath): 17 | self.path = filePath 18 | self.objects = dict() 19 | 20 | def dump(self,obj,name): 21 | self.objects[name] = obj 22 | 23 | def save(self): 24 | f = open(self.path , "wb") 25 | pickle.dump(self.objects ,f,protocol=pickle.HIGHEST_PROTOCOL) 26 | f.close() 27 | 28 | def load(self): 29 | f = open(self.path , "rb") 30 | self.objects = pickle.load(f) 31 | f.close() 32 | return self.objects 33 | 34 | def get_path(self): 35 | return self.path 36 | 37 | class Json_Helper: 38 | 39 | def __init__(self, filePath): 40 | self.path = filePath 41 | self.objects = dict() 42 | 43 | def dump(self,obj,name): 44 | self.objects[name] = obj 45 | 46 | def save(self): 47 | if not os.path.exists(self.path): 48 | os.makedirs(self.path) 49 | for name in self.objects: 50 | with open(self.path+"/"+name+".json", 'w+') as fp: 51 | json.dump(self.objects[name], fp) 52 | 53 | def load(self): 54 | files_path = folder_to_files_path(self.path,ends =".json") 55 | for f in files_path: 56 | name = f.split("/") 57 | with open(f) as data_file: 58 | data = json.load(data_file) 59 | self.objects[name] = data 60 | return self.objects 61 | 62 | def get_path(self): 63 | return self.path 64 | 65 | def folder_to_files_path(folder,ends =".txt"): 66 | files = os.listdir(folder ) 67 | files_path = [] 68 | for f in files: 69 | if f.endswith(ends): 70 | files_path.append(folder+f) 71 | # break 72 | return files_path 73 | def load_line(line,data): 74 | if "\t" in line: 75 | tokens = line[4:].split("\t") 76 | else: 77 | tokens = line[4:].split() 78 | if tokens[0] == "root": return 79 | 80 | if tokens[0] == "node": 81 | data["node"][tokens[1]] = tokens[2] 82 | if tokens.__len__() > 3: 83 | data["align"][tokens[1]] = int(tokens[3].split("-")[0]) 84 | return 85 | if tokens[0] == "edge": 86 | data["edge"][tokens[4],tokens[5]] = tokens[2] 87 | return 88 | data[tokens[0]] = tokens[1:] 89 | def asserting_equal_length(data): 90 | assert len(data["tok"]) ==len(data["lem"]) , ( len(data["tok"]) 
,len(data["lem"]),"\n",list(zip(data["tok"],data["lem"])) ,data["tok"],data["lem"]) 91 | assert len(data["tok"]) ==len(data["ner"]) , ( len(data["tok"]) ,len(data["ner"]),"\n",list(zip(data["tok"],data["ner"])) ,data["tok"],data["ner"]) 92 | assert len(data["tok"]) ==len(data["pos"]) , ( len(data["tok"]) ,len(data["pos"]),"\n",list(zip(data["tok"],data["pos"])) ,data["tok"],data["pos"]) 93 | 94 | def load_text_jamr(filepath): 95 | all_data = [] 96 | with open(filepath,'r') as f: 97 | line = f.readline() 98 | while line != '' : 99 | while line != '' and not line.startswith("# ::") : 100 | line = f.readline() 101 | 102 | if line == "": return all_data 103 | 104 | data = {} 105 | data.setdefault("align",{}) 106 | data.setdefault("node",{}) 107 | data.setdefault("edge",{}) 108 | while line.startswith("# ::"): 109 | load_line(line.replace("\n","").strip(),data) 110 | line = f.readline() 111 | amr_t = "" 112 | while line.strip() != '' and not line.startswith("# AMR release"): 113 | amr_t = amr_t+line 114 | line = f.readline() 115 | data["amr_t"] = amr_t 116 | asserting_equal_length(data) 117 | all_data.append(data) 118 | line = f.readline() 119 | return all_data 120 | 121 | 122 | def load_text_input(filepath): 123 | all_data = [] 124 | with open(filepath,'r') as f: 125 | line = f.readline() 126 | while line != '' : 127 | while line != '' and not line.startswith("# ::"): 128 | line = f.readline() 129 | 130 | if line == "": return all_data 131 | 132 | data = {} 133 | while line.startswith("# ::"): 134 | load_line(line.replace("\n","").strip(),data) 135 | line = f.readline() 136 | all_data.append(data) 137 | line = f.readline() 138 | return all_data --------------------------------------------------------------------------------