├── LICENSE ├── README.md ├── analysis ├── __init__.py ├── analyze_chinese.py ├── analyzer.py ├── comparator.py ├── compare_heatmap.py ├── evaluate_result.py ├── heatmap.py ├── intro_examples.py ├── length.py ├── length_analysis.py ├── significant.py └── stator.py ├── common ├── instance.py └── sentence.py ├── config ├── __init__.py ├── config.py ├── eval.py ├── reader.py └── utils.py ├── data ├── catalan │ ├── dev.sd.conllx │ ├── test.sd.conllx │ └── train.sd.conllx ├── readme.txt └── spanish │ ├── dev.sd.conllx │ ├── test.sd.conllx │ └── train.sd.conllx ├── main.py ├── model ├── charbilstm.py ├── deplabel_gcn.py └── lstmcrf.py ├── preprocess ├── convert_sem_eng.py ├── convert_sem_other.py ├── elmo_others.py ├── prebert.py ├── preelmo.py └── preflair.py └── scripts ├── run.bash ├── run_pytorch.bash └── run_pytorch_all.bash /LICENSE: -------------------------------------------------------------------------------- 1 | GNU GENERAL PUBLIC LICENSE 2 | Version 3, 29 June 2007 3 | 4 | Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/> 5 | Everyone is permitted to copy and distribute verbatim copies 6 | of this license document, but changing it is not allowed. 7 | 8 | Preamble 9 | 10 | The GNU General Public License is a free, copyleft license for 11 | software and other kinds of works. 12 | 13 | The licenses for most software and other practical works are designed 14 | to take away your freedom to share and change the works. By contrast, 15 | the GNU General Public License is intended to guarantee your freedom to 16 | share and change all versions of a program--to make sure it remains free 17 | software for all its users. We, the Free Software Foundation, use the 18 | GNU General Public License for most of our software; it applies also to 19 | any other work released this way by its authors. You can apply it to 20 | your programs, too. 21 | 22 | When we speak of free software, we are referring to freedom, not 23 | price. 
Our General Public Licenses are designed to make sure that you 24 | have the freedom to distribute copies of free software (and charge for 25 | them if you wish), that you receive source code or can get it if you 26 | want it, that you can change the software or use pieces of it in new 27 | free programs, and that you know you can do these things. 28 | 29 | To protect your rights, we need to prevent others from denying you 30 | these rights or asking you to surrender the rights. Therefore, you have 31 | certain responsibilities if you distribute copies of the software, or if 32 | you modify it: responsibilities to respect the freedom of others. 33 | 34 | For example, if you distribute copies of such a program, whether 35 | gratis or for a fee, you must pass on to the recipients the same 36 | freedoms that you received. You must make sure that they, too, receive 37 | or can get the source code. And you must show them these terms so they 38 | know their rights. 39 | 40 | Developers that use the GNU GPL protect your rights with two steps: 41 | (1) assert copyright on the software, and (2) offer you this License 42 | giving you legal permission to copy, distribute and/or modify it. 43 | 44 | For the developers' and authors' protection, the GPL clearly explains 45 | that there is no warranty for this free software. For both users' and 46 | authors' sake, the GPL requires that modified versions be marked as 47 | changed, so that their problems will not be attributed erroneously to 48 | authors of previous versions. 49 | 50 | Some devices are designed to deny users access to install or run 51 | modified versions of the software inside them, although the manufacturer 52 | can do so. This is fundamentally incompatible with the aim of 53 | protecting users' freedom to change the software. The systematic 54 | pattern of such abuse occurs in the area of products for individuals to 55 | use, which is precisely where it is most unacceptable. 
Therefore, we 56 | have designed this version of the GPL to prohibit the practice for those 57 | products. If such problems arise substantially in other domains, we 58 | stand ready to extend this provision to those domains in future versions 59 | of the GPL, as needed to protect the freedom of users. 60 | 61 | Finally, every program is threatened constantly by software patents. 62 | States should not allow patents to restrict development and use of 63 | software on general-purpose computers, but in those that do, we wish to 64 | avoid the special danger that patents applied to a free program could 65 | make it effectively proprietary. To prevent this, the GPL assures that 66 | patents cannot be used to render the program non-free. 67 | 68 | The precise terms and conditions for copying, distribution and 69 | modification follow. 70 | 71 | TERMS AND CONDITIONS 72 | 73 | 0. Definitions. 74 | 75 | "This License" refers to version 3 of the GNU General Public License. 76 | 77 | "Copyright" also means copyright-like laws that apply to other kinds of 78 | works, such as semiconductor masks. 79 | 80 | "The Program" refers to any copyrightable work licensed under this 81 | License. Each licensee is addressed as "you". "Licensees" and 82 | "recipients" may be individuals or organizations. 83 | 84 | To "modify" a work means to copy from or adapt all or part of the work 85 | in a fashion requiring copyright permission, other than the making of an 86 | exact copy. The resulting work is called a "modified version" of the 87 | earlier work or a work "based on" the earlier work. 88 | 89 | A "covered work" means either the unmodified Program or a work based 90 | on the Program. 91 | 92 | To "propagate" a work means to do anything with it that, without 93 | permission, would make you directly or secondarily liable for 94 | infringement under applicable copyright law, except executing it on a 95 | computer or modifying a private copy. 
Propagation includes copying, 96 | distribution (with or without modification), making available to the 97 | public, and in some countries other activities as well. 98 | 99 | To "convey" a work means any kind of propagation that enables other 100 | parties to make or receive copies. Mere interaction with a user through 101 | a computer network, with no transfer of a copy, is not conveying. 102 | 103 | An interactive user interface displays "Appropriate Legal Notices" 104 | to the extent that it includes a convenient and prominently visible 105 | feature that (1) displays an appropriate copyright notice, and (2) 106 | tells the user that there is no warranty for the work (except to the 107 | extent that warranties are provided), that licensees may convey the 108 | work under this License, and how to view a copy of this License. If 109 | the interface presents a list of user commands or options, such as a 110 | menu, a prominent item in the list meets this criterion. 111 | 112 | 1. Source Code. 113 | 114 | The "source code" for a work means the preferred form of the work 115 | for making modifications to it. "Object code" means any non-source 116 | form of a work. 117 | 118 | A "Standard Interface" means an interface that either is an official 119 | standard defined by a recognized standards body, or, in the case of 120 | interfaces specified for a particular programming language, one that 121 | is widely used among developers working in that language. 122 | 123 | The "System Libraries" of an executable work include anything, other 124 | than the work as a whole, that (a) is included in the normal form of 125 | packaging a Major Component, but which is not part of that Major 126 | Component, and (b) serves only to enable use of the work with that 127 | Major Component, or to implement a Standard Interface for which an 128 | implementation is available to the public in source code form. 
A 129 | "Major Component", in this context, means a major essential component 130 | (kernel, window system, and so on) of the specific operating system 131 | (if any) on which the executable work runs, or a compiler used to 132 | produce the work, or an object code interpreter used to run it. 133 | 134 | The "Corresponding Source" for a work in object code form means all 135 | the source code needed to generate, install, and (for an executable 136 | work) run the object code and to modify the work, including scripts to 137 | control those activities. However, it does not include the work's 138 | System Libraries, or general-purpose tools or generally available free 139 | programs which are used unmodified in performing those activities but 140 | which are not part of the work. For example, Corresponding Source 141 | includes interface definition files associated with source files for 142 | the work, and the source code for shared libraries and dynamically 143 | linked subprograms that the work is specifically designed to require, 144 | such as by intimate data communication or control flow between those 145 | subprograms and other parts of the work. 146 | 147 | The Corresponding Source need not include anything that users 148 | can regenerate automatically from other parts of the Corresponding 149 | Source. 150 | 151 | The Corresponding Source for a work in source code form is that 152 | same work. 153 | 154 | 2. Basic Permissions. 155 | 156 | All rights granted under this License are granted for the term of 157 | copyright on the Program, and are irrevocable provided the stated 158 | conditions are met. This License explicitly affirms your unlimited 159 | permission to run the unmodified Program. The output from running a 160 | covered work is covered by this License only if the output, given its 161 | content, constitutes a covered work. This License acknowledges your 162 | rights of fair use or other equivalent, as provided by copyright law. 
163 | 164 | You may make, run and propagate covered works that you do not 165 | convey, without conditions so long as your license otherwise remains 166 | in force. You may convey covered works to others for the sole purpose 167 | of having them make modifications exclusively for you, or provide you 168 | with facilities for running those works, provided that you comply with 169 | the terms of this License in conveying all material for which you do 170 | not control copyright. Those thus making or running the covered works 171 | for you must do so exclusively on your behalf, under your direction 172 | and control, on terms that prohibit them from making any copies of 173 | your copyrighted material outside their relationship with you. 174 | 175 | Conveying under any other circumstances is permitted solely under 176 | the conditions stated below. Sublicensing is not allowed; section 10 177 | makes it unnecessary. 178 | 179 | 3. Protecting Users' Legal Rights From Anti-Circumvention Law. 180 | 181 | No covered work shall be deemed part of an effective technological 182 | measure under any applicable law fulfilling obligations under article 183 | 11 of the WIPO copyright treaty adopted on 20 December 1996, or 184 | similar laws prohibiting or restricting circumvention of such 185 | measures. 186 | 187 | When you convey a covered work, you waive any legal power to forbid 188 | circumvention of technological measures to the extent such circumvention 189 | is effected by exercising rights under this License with respect to 190 | the covered work, and you disclaim any intention to limit operation or 191 | modification of the work as a means of enforcing, against the work's 192 | users, your or third parties' legal rights to forbid circumvention of 193 | technological measures. 194 | 195 | 4. Conveying Verbatim Copies. 
196 | 197 | You may convey verbatim copies of the Program's source code as you 198 | receive it, in any medium, provided that you conspicuously and 199 | appropriately publish on each copy an appropriate copyright notice; 200 | keep intact all notices stating that this License and any 201 | non-permissive terms added in accord with section 7 apply to the code; 202 | keep intact all notices of the absence of any warranty; and give all 203 | recipients a copy of this License along with the Program. 204 | 205 | You may charge any price or no price for each copy that you convey, 206 | and you may offer support or warranty protection for a fee. 207 | 208 | 5. Conveying Modified Source Versions. 209 | 210 | You may convey a work based on the Program, or the modifications to 211 | produce it from the Program, in the form of source code under the 212 | terms of section 4, provided that you also meet all of these conditions: 213 | 214 | a) The work must carry prominent notices stating that you modified 215 | it, and giving a relevant date. 216 | 217 | b) The work must carry prominent notices stating that it is 218 | released under this License and any conditions added under section 219 | 7. This requirement modifies the requirement in section 4 to 220 | "keep intact all notices". 221 | 222 | c) You must license the entire work, as a whole, under this 223 | License to anyone who comes into possession of a copy. This 224 | License will therefore apply, along with any applicable section 7 225 | additional terms, to the whole of the work, and all its parts, 226 | regardless of how they are packaged. This License gives no 227 | permission to license the work in any other way, but it does not 228 | invalidate such permission if you have separately received it. 
229 | 230 | d) If the work has interactive user interfaces, each must display 231 | Appropriate Legal Notices; however, if the Program has interactive 232 | interfaces that do not display Appropriate Legal Notices, your 233 | work need not make them do so. 234 | 235 | A compilation of a covered work with other separate and independent 236 | works, which are not by their nature extensions of the covered work, 237 | and which are not combined with it such as to form a larger program, 238 | in or on a volume of a storage or distribution medium, is called an 239 | "aggregate" if the compilation and its resulting copyright are not 240 | used to limit the access or legal rights of the compilation's users 241 | beyond what the individual works permit. Inclusion of a covered work 242 | in an aggregate does not cause this License to apply to the other 243 | parts of the aggregate. 244 | 245 | 6. Conveying Non-Source Forms. 246 | 247 | You may convey a covered work in object code form under the terms 248 | of sections 4 and 5, provided that you also convey the 249 | machine-readable Corresponding Source under the terms of this License, 250 | in one of these ways: 251 | 252 | a) Convey the object code in, or embodied in, a physical product 253 | (including a physical distribution medium), accompanied by the 254 | Corresponding Source fixed on a durable physical medium 255 | customarily used for software interchange. 
256 | 257 | b) Convey the object code in, or embodied in, a physical product 258 | (including a physical distribution medium), accompanied by a 259 | written offer, valid for at least three years and valid for as 260 | long as you offer spare parts or customer support for that product 261 | model, to give anyone who possesses the object code either (1) a 262 | copy of the Corresponding Source for all the software in the 263 | product that is covered by this License, on a durable physical 264 | medium customarily used for software interchange, for a price no 265 | more than your reasonable cost of physically performing this 266 | conveying of source, or (2) access to copy the 267 | Corresponding Source from a network server at no charge. 268 | 269 | c) Convey individual copies of the object code with a copy of the 270 | written offer to provide the Corresponding Source. This 271 | alternative is allowed only occasionally and noncommercially, and 272 | only if you received the object code with such an offer, in accord 273 | with subsection 6b. 274 | 275 | d) Convey the object code by offering access from a designated 276 | place (gratis or for a charge), and offer equivalent access to the 277 | Corresponding Source in the same way through the same place at no 278 | further charge. You need not require recipients to copy the 279 | Corresponding Source along with the object code. If the place to 280 | copy the object code is a network server, the Corresponding Source 281 | may be on a different server (operated by you or a third party) 282 | that supports equivalent copying facilities, provided you maintain 283 | clear directions next to the object code saying where to find the 284 | Corresponding Source. Regardless of what server hosts the 285 | Corresponding Source, you remain obligated to ensure that it is 286 | available for as long as needed to satisfy these requirements. 
287 | 288 | e) Convey the object code using peer-to-peer transmission, provided 289 | you inform other peers where the object code and Corresponding 290 | Source of the work are being offered to the general public at no 291 | charge under subsection 6d. 292 | 293 | A separable portion of the object code, whose source code is excluded 294 | from the Corresponding Source as a System Library, need not be 295 | included in conveying the object code work. 296 | 297 | A "User Product" is either (1) a "consumer product", which means any 298 | tangible personal property which is normally used for personal, family, 299 | or household purposes, or (2) anything designed or sold for incorporation 300 | into a dwelling. In determining whether a product is a consumer product, 301 | doubtful cases shall be resolved in favor of coverage. For a particular 302 | product received by a particular user, "normally used" refers to a 303 | typical or common use of that class of product, regardless of the status 304 | of the particular user or of the way in which the particular user 305 | actually uses, or expects or is expected to use, the product. A product 306 | is a consumer product regardless of whether the product has substantial 307 | commercial, industrial or non-consumer uses, unless such uses represent 308 | the only significant mode of use of the product. 309 | 310 | "Installation Information" for a User Product means any methods, 311 | procedures, authorization keys, or other information required to install 312 | and execute modified versions of a covered work in that User Product from 313 | a modified version of its Corresponding Source. The information must 314 | suffice to ensure that the continued functioning of the modified object 315 | code is in no case prevented or interfered with solely because 316 | modification has been made. 
317 | 318 | If you convey an object code work under this section in, or with, or 319 | specifically for use in, a User Product, and the conveying occurs as 320 | part of a transaction in which the right of possession and use of the 321 | User Product is transferred to the recipient in perpetuity or for a 322 | fixed term (regardless of how the transaction is characterized), the 323 | Corresponding Source conveyed under this section must be accompanied 324 | by the Installation Information. But this requirement does not apply 325 | if neither you nor any third party retains the ability to install 326 | modified object code on the User Product (for example, the work has 327 | been installed in ROM). 328 | 329 | The requirement to provide Installation Information does not include a 330 | requirement to continue to provide support service, warranty, or updates 331 | for a work that has been modified or installed by the recipient, or for 332 | the User Product in which it has been modified or installed. Access to a 333 | network may be denied when the modification itself materially and 334 | adversely affects the operation of the network or violates the rules and 335 | protocols for communication across the network. 336 | 337 | Corresponding Source conveyed, and Installation Information provided, 338 | in accord with this section must be in a format that is publicly 339 | documented (and with an implementation available to the public in 340 | source code form), and must require no special password or key for 341 | unpacking, reading or copying. 342 | 343 | 7. Additional Terms. 344 | 345 | "Additional permissions" are terms that supplement the terms of this 346 | License by making exceptions from one or more of its conditions. 347 | Additional permissions that are applicable to the entire Program shall 348 | be treated as though they were included in this License, to the extent 349 | that they are valid under applicable law. 
If additional permissions 350 | apply only to part of the Program, that part may be used separately 351 | under those permissions, but the entire Program remains governed by 352 | this License without regard to the additional permissions. 353 | 354 | When you convey a copy of a covered work, you may at your option 355 | remove any additional permissions from that copy, or from any part of 356 | it. (Additional permissions may be written to require their own 357 | removal in certain cases when you modify the work.) You may place 358 | additional permissions on material, added by you to a covered work, 359 | for which you have or can give appropriate copyright permission. 360 | 361 | Notwithstanding any other provision of this License, for material you 362 | add to a covered work, you may (if authorized by the copyright holders of 363 | that material) supplement the terms of this License with terms: 364 | 365 | a) Disclaiming warranty or limiting liability differently from the 366 | terms of sections 15 and 16 of this License; or 367 | 368 | b) Requiring preservation of specified reasonable legal notices or 369 | author attributions in that material or in the Appropriate Legal 370 | Notices displayed by works containing it; or 371 | 372 | c) Prohibiting misrepresentation of the origin of that material, or 373 | requiring that modified versions of such material be marked in 374 | reasonable ways as different from the original version; or 375 | 376 | d) Limiting the use for publicity purposes of names of licensors or 377 | authors of the material; or 378 | 379 | e) Declining to grant rights under trademark law for use of some 380 | trade names, trademarks, or service marks; or 381 | 382 | f) Requiring indemnification of licensors and authors of that 383 | material by anyone who conveys the material (or modified versions of 384 | it) with contractual assumptions of liability to the recipient, for 385 | any liability that these contractual assumptions directly impose on 
386 | those licensors and authors. 387 | 388 | All other non-permissive additional terms are considered "further 389 | restrictions" within the meaning of section 10. If the Program as you 390 | received it, or any part of it, contains a notice stating that it is 391 | governed by this License along with a term that is a further 392 | restriction, you may remove that term. If a license document contains 393 | a further restriction but permits relicensing or conveying under this 394 | License, you may add to a covered work material governed by the terms 395 | of that license document, provided that the further restriction does 396 | not survive such relicensing or conveying. 397 | 398 | If you add terms to a covered work in accord with this section, you 399 | must place, in the relevant source files, a statement of the 400 | additional terms that apply to those files, or a notice indicating 401 | where to find the applicable terms. 402 | 403 | Additional terms, permissive or non-permissive, may be stated in the 404 | form of a separately written license, or stated as exceptions; 405 | the above requirements apply either way. 406 | 407 | 8. Termination. 408 | 409 | You may not propagate or modify a covered work except as expressly 410 | provided under this License. Any attempt otherwise to propagate or 411 | modify it is void, and will automatically terminate your rights under 412 | this License (including any patent licenses granted under the third 413 | paragraph of section 11). 414 | 415 | However, if you cease all violation of this License, then your 416 | license from a particular copyright holder is reinstated (a) 417 | provisionally, unless and until the copyright holder explicitly and 418 | finally terminates your license, and (b) permanently, if the copyright 419 | holder fails to notify you of the violation by some reasonable means 420 | prior to 60 days after the cessation. 
421 | 422 | Moreover, your license from a particular copyright holder is 423 | reinstated permanently if the copyright holder notifies you of the 424 | violation by some reasonable means, this is the first time you have 425 | received notice of violation of this License (for any work) from that 426 | copyright holder, and you cure the violation prior to 30 days after 427 | your receipt of the notice. 428 | 429 | Termination of your rights under this section does not terminate the 430 | licenses of parties who have received copies or rights from you under 431 | this License. If your rights have been terminated and not permanently 432 | reinstated, you do not qualify to receive new licenses for the same 433 | material under section 10. 434 | 435 | 9. Acceptance Not Required for Having Copies. 436 | 437 | You are not required to accept this License in order to receive or 438 | run a copy of the Program. Ancillary propagation of a covered work 439 | occurring solely as a consequence of using peer-to-peer transmission 440 | to receive a copy likewise does not require acceptance. However, 441 | nothing other than this License grants you permission to propagate or 442 | modify any covered work. These actions infringe copyright if you do 443 | not accept this License. Therefore, by modifying or propagating a 444 | covered work, you indicate your acceptance of this License to do so. 445 | 446 | 10. Automatic Licensing of Downstream Recipients. 447 | 448 | Each time you convey a covered work, the recipient automatically 449 | receives a license from the original licensors, to run, modify and 450 | propagate that work, subject to this License. You are not responsible 451 | for enforcing compliance by third parties with this License. 452 | 453 | An "entity transaction" is a transaction transferring control of an 454 | organization, or substantially all assets of one, or subdividing an 455 | organization, or merging organizations. 
If propagation of a covered 456 | work results from an entity transaction, each party to that 457 | transaction who receives a copy of the work also receives whatever 458 | licenses to the work the party's predecessor in interest had or could 459 | give under the previous paragraph, plus a right to possession of the 460 | Corresponding Source of the work from the predecessor in interest, if 461 | the predecessor has it or can get it with reasonable efforts. 462 | 463 | You may not impose any further restrictions on the exercise of the 464 | rights granted or affirmed under this License. For example, you may 465 | not impose a license fee, royalty, or other charge for exercise of 466 | rights granted under this License, and you may not initiate litigation 467 | (including a cross-claim or counterclaim in a lawsuit) alleging that 468 | any patent claim is infringed by making, using, selling, offering for 469 | sale, or importing the Program or any portion of it. 470 | 471 | 11. Patents. 472 | 473 | A "contributor" is a copyright holder who authorizes use under this 474 | License of the Program or a work on which the Program is based. The 475 | work thus licensed is called the contributor's "contributor version". 476 | 477 | A contributor's "essential patent claims" are all patent claims 478 | owned or controlled by the contributor, whether already acquired or 479 | hereafter acquired, that would be infringed by some manner, permitted 480 | by this License, of making, using, or selling its contributor version, 481 | but do not include claims that would be infringed only as a 482 | consequence of further modification of the contributor version. For 483 | purposes of this definition, "control" includes the right to grant 484 | patent sublicenses in a manner consistent with the requirements of 485 | this License. 
486 | 487 | Each contributor grants you a non-exclusive, worldwide, royalty-free 488 | patent license under the contributor's essential patent claims, to 489 | make, use, sell, offer for sale, import and otherwise run, modify and 490 | propagate the contents of its contributor version. 491 | 492 | In the following three paragraphs, a "patent license" is any express 493 | agreement or commitment, however denominated, not to enforce a patent 494 | (such as an express permission to practice a patent or covenant not to 495 | sue for patent infringement). To "grant" such a patent license to a 496 | party means to make such an agreement or commitment not to enforce a 497 | patent against the party. 498 | 499 | If you convey a covered work, knowingly relying on a patent license, 500 | and the Corresponding Source of the work is not available for anyone 501 | to copy, free of charge and under the terms of this License, through a 502 | publicly available network server or other readily accessible means, 503 | then you must either (1) cause the Corresponding Source to be so 504 | available, or (2) arrange to deprive yourself of the benefit of the 505 | patent license for this particular work, or (3) arrange, in a manner 506 | consistent with the requirements of this License, to extend the patent 507 | license to downstream recipients. "Knowingly relying" means you have 508 | actual knowledge that, but for the patent license, your conveying the 509 | covered work in a country, or your recipient's use of the covered work 510 | in a country, would infringe one or more identifiable patents in that 511 | country that you have reason to believe are valid. 
512 | 513 | If, pursuant to or in connection with a single transaction or 514 | arrangement, you convey, or propagate by procuring conveyance of, a 515 | covered work, and grant a patent license to some of the parties 516 | receiving the covered work authorizing them to use, propagate, modify 517 | or convey a specific copy of the covered work, then the patent license 518 | you grant is automatically extended to all recipients of the covered 519 | work and works based on it. 520 | 521 | A patent license is "discriminatory" if it does not include within 522 | the scope of its coverage, prohibits the exercise of, or is 523 | conditioned on the non-exercise of one or more of the rights that are 524 | specifically granted under this License. You may not convey a covered 525 | work if you are a party to an arrangement with a third party that is 526 | in the business of distributing software, under which you make payment 527 | to the third party based on the extent of your activity of conveying 528 | the work, and under which the third party grants, to any of the 529 | parties who would receive the covered work from you, a discriminatory 530 | patent license (a) in connection with copies of the covered work 531 | conveyed by you (or copies made from those copies), or (b) primarily 532 | for and in connection with specific products or compilations that 533 | contain the covered work, unless you entered into that arrangement, 534 | or that patent license was granted, prior to 28 March 2007. 535 | 536 | Nothing in this License shall be construed as excluding or limiting 537 | any implied license or other defenses to infringement that may 538 | otherwise be available to you under applicable patent law. 539 | 540 | 12. No Surrender of Others' Freedom. 541 | 542 | If conditions are imposed on you (whether by court order, agreement or 543 | otherwise) that contradict the conditions of this License, they do not 544 | excuse you from the conditions of this License. 
If you cannot convey a 545 | covered work so as to satisfy simultaneously your obligations under this 546 | License and any other pertinent obligations, then as a consequence you may 547 | not convey it at all. For example, if you agree to terms that obligate you 548 | to collect a royalty for further conveying from those to whom you convey 549 | the Program, the only way you could satisfy both those terms and this 550 | License would be to refrain entirely from conveying the Program. 551 | 552 | 13. Use with the GNU Affero General Public License. 553 | 554 | Notwithstanding any other provision of this License, you have 555 | permission to link or combine any covered work with a work licensed 556 | under version 3 of the GNU Affero General Public License into a single 557 | combined work, and to convey the resulting work. The terms of this 558 | License will continue to apply to the part which is the covered work, 559 | but the special requirements of the GNU Affero General Public License, 560 | section 13, concerning interaction through a network will apply to the 561 | combination as such. 562 | 563 | 14. Revised Versions of this License. 564 | 565 | The Free Software Foundation may publish revised and/or new versions of 566 | the GNU General Public License from time to time. Such new versions will 567 | be similar in spirit to the present version, but may differ in detail to 568 | address new problems or concerns. 569 | 570 | Each version is given a distinguishing version number. If the 571 | Program specifies that a certain numbered version of the GNU General 572 | Public License "or any later version" applies to it, you have the 573 | option of following the terms and conditions either of that numbered 574 | version or of any later version published by the Free Software 575 | Foundation. If the Program does not specify a version number of the 576 | GNU General Public License, you may choose any version ever published 577 | by the Free Software Foundation. 
578 | 579 | If the Program specifies that a proxy can decide which future 580 | versions of the GNU General Public License can be used, that proxy's 581 | public statement of acceptance of a version permanently authorizes you 582 | to choose that version for the Program. 583 | 584 | Later license versions may give you additional or different 585 | permissions. However, no additional obligations are imposed on any 586 | author or copyright holder as a result of your choosing to follow a 587 | later version. 588 | 589 | 15. Disclaimer of Warranty. 590 | 591 | THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY 592 | APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT 593 | HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY 594 | OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, 595 | THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 596 | PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM 597 | IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF 598 | ALL NECESSARY SERVICING, REPAIR OR CORRECTION. 599 | 600 | 16. Limitation of Liability. 601 | 602 | IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING 603 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS 604 | THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY 605 | GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE 606 | USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF 607 | DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD 608 | PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), 609 | EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF 610 | SUCH DAMAGES. 611 | 612 | 17. Interpretation of Sections 15 and 16. 
613 | 614 | If the disclaimer of warranty and limitation of liability provided 615 | above cannot be given local legal effect according to their terms, 616 | reviewing courts shall apply local law that most closely approximates 617 | an absolute waiver of all civil liability in connection with the 618 | Program, unless a warranty or assumption of liability accompanies a 619 | copy of the Program in return for a fee. 620 | 621 | END OF TERMS AND CONDITIONS 622 | 623 | How to Apply These Terms to Your New Programs 624 | 625 | If you develop a new program, and you want it to be of the greatest 626 | possible use to the public, the best way to achieve this is to make it 627 | free software which everyone can redistribute and change under these terms. 628 | 629 | To do so, attach the following notices to the program. It is safest 630 | to attach them to the start of each source file to most effectively 631 | state the exclusion of warranty; and each file should have at least 632 | the "copyright" line and a pointer to where the full notice is found. 633 | 634 | 635 | Copyright (C) 636 | 637 | This program is free software: you can redistribute it and/or modify 638 | it under the terms of the GNU General Public License as published by 639 | the Free Software Foundation, either version 3 of the License, or 640 | (at your option) any later version. 641 | 642 | This program is distributed in the hope that it will be useful, 643 | but WITHOUT ANY WARRANTY; without even the implied warranty of 644 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 645 | GNU General Public License for more details. 646 | 647 | You should have received a copy of the GNU General Public License 648 | along with this program. If not, see . 649 | 650 | Also add information on how to contact you by electronic and paper mail. 
651 | 652 | If the program does terminal interaction, make it output a short 653 | notice like this when it starts in an interactive mode: 654 | 655 | Copyright (C) 656 | This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'. 657 | This is free software, and you are welcome to redistribute it 658 | under certain conditions; type `show c' for details. 659 | 660 | The hypothetical commands `show w' and `show c' should show the appropriate 661 | parts of the General Public License. Of course, your program's commands 662 | might be different; for a GUI interface, you would use an "about box". 663 | 664 | You should also get your employer (if you work as a programmer) or school, 665 | if any, to sign a "copyright disclaimer" for the program, if necessary. 666 | For more information on this, and how to apply and follow the GNU GPL, see 667 | . 668 | 669 | The GNU General Public License does not permit incorporating your program 670 | into proprietary programs. If your program is a subroutine library, you 671 | may consider it more useful to permit linking proprietary applications with 672 | the library. If this is what you want to do, use the GNU Lesser General 673 | Public License instead of this License. But first, please read 674 | . 675 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Dependency-Guided LSTM-CRF Model for Named Entity Recognition 2 | 3 | Codebase for the upcoming paper "[Dependency-Guided LSTM-CRF for Named Entity Recognition](https://www.aclweb.org/anthology/D19-1399.pdf)" in EMNLP 2019. 4 | The usage code below make sure you can reproduce almost same results as shown in the paper. 5 | 6 | ### Requirements 7 | * PyTorch 1.1 (Also tested on PyTorch 1.3) 8 | * Python 3.6 9 | 10 | ### Dataset Format 11 | 12 | I have uploaded the preprocessed `Catalan` and `Spanish` datasets. 
(Please contact me with your license if you need the preprocessed OntoNotes dataset.) 13 | If you have a new dataset, please make sure we follow the CoNLL-X format and we put the entity label at the end. 14 | The sentence below is an example. 15 | Note that we only use the columns for *word*, *dependency head index*, *dependency relation label* and the last *entity label*. 16 | ``` 17 | 1 Brasil _ n n _ 2 suj _ _ B-org 18 | 2 buscará _ v v _ 0 root _ _ O 19 | 3 a_partir_de _ s s _ 2 cc _ _ O 20 | 4 mañana _ n n _ 3 sn _ _ O 21 | 5 , _ f f _ 6 f _ _ B-misc 22 | 6 viernes _ w w _ 4 sn _ _ I-misc 23 | 7 , _ f f _ 6 f _ _ I-misc 24 | 8 el _ d d _ 9 spec _ _ O 25 | 9 pase _ n n _ 2 cd _ _ O 26 | ``` 27 | Entity labels follow the `IOB` tagging scheme and will be converted to `IOBES` in this codebase. 28 | 29 | ### Usage 30 | 31 | Baseline **BiLSTM-CRF**: 32 | ```bash 33 | python main.py --dataset ontonotes --embedding_file data/glove.6B.100d.txt \ 34 | --num_lstm_layer 1 --dep_model none 35 | ``` 36 | Change `embedding_file` if you are using other languages, change `dataset` for other datasets, change `num_lstm_layer` for different `L = 0,1,2,3`. Use `--device cuda:0` if you are using gpu. 37 | 38 | **DGLSTM-CRF** 39 | ```bash 40 | python main.py --dataset ontonotes --embedding_file data/glove.6B.100d.txt \ 41 | --num_lstm_layer 1 --dep_model dglstm --inter_func mlp 42 | ``` 43 | Change the interaction function `inter_func = concatenation, addition, mlp` for other interactions. 44 | 45 | 46 | ### Usage for other datasets and other languages 47 | Remember to put the dataset under the data folder. The naming rule for `train/dev/test` is `train.sd.conllx`, `dev.sd.conllx` and `test.sd.conllx`. 48 | Then simply change the `--dataset` name and `--embedding_file`. 
49 | 50 | Dataset | Embedding 51 | ------------ | ------------- 52 | OntoNotes English | glove.6B.100d.txt 53 | OntoNotes Chinese | cc.zh.300.vec (FastText) 54 | Catalan | cc.ca.300.vec (FastText) 55 | Spanish | cc.es.300.vec (FastText) 56 | 57 | 58 | 59 | ### Using ELMo 60 | In any case, once we have obtained the pretrained ELMo vector files ready. 61 | For example, download the `Catalan ELMo` vectors from [here](https://drive.google.com/open?id=1bGCRy4pYDWBcEae5sTSIcdu6PwWgz7Kn), decompressed all the files (`train.conllx.elmo.vec`,`dev.conllx.elmo.vec`, `test.conllx.elmo.vec`) into `data/catalan/`. 62 | We can then simply run the command below (we take the **DGLSTM-CRF** for example) 63 | ```bash 64 | python main.py --dataset ontonotes --embedding_file data/glove.6B.100d.txt \ 65 | --num_lstm_layer 1 --dep_model dglstm --inter_func mlp \ 66 | --context_emb elmo 67 | ``` 68 | ### Obtain ELMo vectors for other languages: 69 | We use the ELMo from AllenNLP for English, and use [ELMoForManyLangs](https://github.com/HIT-SCIR/ELMoForManyLangs) for other languages. 70 | * English, run the `preprocess/preelmo.py` code (remember to change the `dataset` name) 71 | ```bash 72 | python preprocess/preelmo.py 73 | ``` 74 | * Chinese, Catalan, and Spanish 75 | Download the ELMo models from [ELMoForManyLangs](https://github.com/HIT-SCIR/ELMoForManyLangs). NOTE: remember to follow the instruction to slighly modify some paths inside. 76 | Then you can run `preprocess/elmo_others.py`: (again remember to change `dataset` name and ELMo model path) 77 | ```bash 78 | python preprocess/elmo_others.py 79 | ``` 80 | 81 | 82 | ### Notes on Dataset Preprocessing (Two Options) 83 | 84 | #### OntoNotes Preprocessing 85 | Many people are asking for the OntoNotes 5.0 dataset. 86 | I understand that it is hard to get the correct split as in previous work (Chiu and Nichols, 2016; Li et al., 2017; Ghaddar and Langlais, 2018;). 
If you want to get the correct split, you can refer to a guide [here](https://github.com/allanj/pytorch_lstmcrf/blob/master/docs/benchmark.md) where
I summarize how to preprocess the OntoNotes dataset.
output[i].startswith("S-"): 23 | output_spans.add(Span(i, i, output[i][2:])) 24 | return output_spans 25 | 26 | def read_conll(res_file: str, number: int = -1) -> List[Instance]: 27 | print("Reading file: " + res_file) 28 | insts = [] 29 | # vocab = set() ## build the vocabulary 30 | with open(res_file, 'r', encoding='utf-8') as f: 31 | words = [] 32 | heads = [] 33 | deps = [] 34 | labels = [] 35 | tags = [] 36 | preds = [] 37 | for line in tqdm(f.readlines()): 38 | line = line.rstrip() 39 | if line == "": 40 | inst = Instance(Sentence(words, heads, deps, tags), labels) 41 | inst.prediction = preds 42 | insts.append(inst) 43 | words = [] 44 | heads = [] 45 | deps = [] 46 | labels = [] 47 | tags = [] 48 | preds = [] 49 | 50 | if len(insts) == number: 51 | break 52 | continue 53 | vals = line.split() 54 | word = vals[1] 55 | pos = vals[2] 56 | head = int(vals[3]) 57 | dep_label = vals[4] 58 | 59 | label = vals[5] 60 | pred_label = vals[6] 61 | 62 | words.append(word) 63 | heads.append(head) ## because of 0-indexed. 
64 | deps.append(dep_label) 65 | tags.append(pos) 66 | labels.append(label) 67 | preds.append(pred_label) 68 | print("number of sentences: {}".format(len(insts))) 69 | return insts 70 | 71 | res1 = "../final_results/lstm_3_200_crf_ontonotes_sd_-1_dep_feat_emb_elmo_elmo_sgd_gate_0_base_-1_epoch_200_lr_0.01.results" 72 | insts1 = read_conll(res1) 73 | 74 | res2 = "../final_results/lstm_2_200_crf_ontonotes_sd_-1_dep_feat_emb_elmo_none_sgd_gate_0_base_-1_epoch_150_lr_0.01.results" 75 | insts2 = read_conll(res2) 76 | 77 | print(evaluate(insts1)) 78 | print(evaluate(insts2)) 79 | num = 0 80 | total_entity = 0 81 | type2num = {} 82 | length2num = {} 83 | dep_label2num = {} 84 | gc2num = {} 85 | for i in range(len(insts1)): 86 | 87 | first = insts1[i] 88 | second = insts2[i] 89 | gold_spans = get_spans(first.output) 90 | 91 | pred_first = get_spans(first.prediction) 92 | pred_second = get_spans(second.prediction) 93 | 94 | 95 | # for span in pred_first: 96 | # if span in gold_spans and (span not in pred_second): 97 | for span in gold_spans: 98 | if span in pred_first and (span not in pred_second): 99 | num += 1 100 | print(span.to_str(first.input.words)) 101 | if span.type in type2num: 102 | type2num[span.type] +=1 103 | else: 104 | type2num[span.type] = 1 105 | length = span.right - span.left + 1 106 | if length in length2num: 107 | length2num[length] += 1 108 | else: 109 | length2num[length] = 1 110 | 111 | for k in range(span.left, span.right + 1): 112 | if first.input.heads[k] == -1 or (first.input.heads[k] > span.right or first.input.heads[k] < span.left): 113 | if first.input.dep_labels[k] in dep_label2num: 114 | dep_label2num[first.input.dep_labels[k]] +=1 115 | else: 116 | dep_label2num[first.input.dep_labels[k]] = 1 117 | 118 | if first.input.heads[k]!= -1 and first.input.heads[first.input.heads[k]] != -1: 119 | h = first.input.heads[first.input.heads[k]] 120 | if first.input.dep_labels[k] + "," + first.input.dep_labels[h] in gc2num: 121 | 
gc2num[first.input.dep_labels[k] + "," + first.input.dep_labels[h]] += 1 122 | else: 123 | gc2num[first.input.dep_labels[k] + "," + first.input.dep_labels[h]] = 1 124 | 125 | total_entity +=1 126 | 127 | print(num, total_entity) 128 | print(type2num) 129 | print("length 2 number: {}".format(length2num)) 130 | 131 | print() 132 | print("dependency label 2 num: {}".format(dep_label2num)) 133 | total_amount = sum([dep_label2num[key] for key in dep_label2num]) 134 | print("total number of dep 2 num: {}".format(total_amount)) 135 | print() 136 | 137 | counts = [(key, dep_label2num[key]) for key in dep_label2num] 138 | counts = sorted(counts, key=lambda vals: vals[1], reverse=True) 139 | print(counts) 140 | print() 141 | 142 | print(gc2num) 143 | counts = [(key, gc2num[key]) for key in gc2num] 144 | counts = sorted(counts, key=lambda vals: vals[1], reverse=True) 145 | print(counts) 146 | -------------------------------------------------------------------------------- /analysis/analyzer.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | from common.sentence import Sentence 3 | from common.instance import Instance 4 | from typing import List 5 | 6 | 7 | def read_conll(res_file: str, number: int = -1) -> List[Instance]: 8 | print("Reading file: " + res_file) 9 | insts = [] 10 | # vocab = set() ## build the vocabulary 11 | with open(res_file, 'r', encoding='utf-8') as f: 12 | words = [] 13 | heads = [] 14 | deps = [] 15 | labels = [] 16 | tags = [] 17 | preds = [] 18 | for line in tqdm(f.readlines()): 19 | line = line.rstrip() 20 | if line == "": 21 | inst = Instance(Sentence(words, heads, deps, tags), labels) 22 | inst.prediction = preds 23 | insts.append(inst) 24 | words = [] 25 | heads = [] 26 | deps = [] 27 | labels = [] 28 | tags = [] 29 | preds = [] 30 | 31 | if len(insts) == number: 32 | break 33 | continue 34 | vals = line.split() 35 | word = vals[1] 36 | pos = vals[2] 37 | head = int(vals[3]) 38 | 
dep_label = vals[4] 39 | 40 | label = vals[5] 41 | pred_label = vals[6] 42 | 43 | words.append(word) 44 | heads.append(head) ## because of 0-indexed. 45 | deps.append(dep_label) 46 | tags.append(pos) 47 | labels.append(label) 48 | preds.append(pred_label) 49 | print("number of sentences: {}".format(len(insts))) 50 | return insts 51 | 52 | res_file = "../results/lstm_200_crf_conll2003_-1_dep_none_elmo_1_sgd_gate_0.results" 53 | insts = read_conll(res_file) 54 | 55 | total = 0 56 | total_word = 0 57 | for inst in insts: 58 | gold = inst.output 59 | prediction = inst.prediction 60 | words = inst.input.words 61 | heads = inst.input.heads 62 | dep_labels = inst.input.dep_labels 63 | have_error= False 64 | for idx in range(len(gold)): 65 | if gold[idx] != 'O' and prediction[idx] == 'O': 66 | have_error = True 67 | total_word += 1 68 | print("{}\t{}\t{}\t{}\t{}\t{}\t".format(idx, words[idx], heads[idx]+1, dep_labels[idx], gold[idx], prediction[idx])) 69 | if have_error: 70 | print(words) 71 | print(gold) 72 | print(prediction) 73 | total +=1 74 | print() 75 | print("number of sentences have errors: {}".format(total)) 76 | print("number of words have errors: {}".format(total_word)) 77 | 78 | -------------------------------------------------------------------------------- /analysis/comparator.py: -------------------------------------------------------------------------------- 1 | # 2 | # @author: Allan 3 | # 4 | from tqdm import tqdm 5 | from common.sentence import Sentence 6 | from common.instance import Instance 7 | from typing import List 8 | 9 | 10 | 11 | def read_conll(res_file: str, number: int = -1) -> List[Instance]: 12 | print("Reading file: " + res_file) 13 | insts = [] 14 | # vocab = set() ## build the vocabulary 15 | with open(res_file, 'r', encoding='utf-8') as f: 16 | words = [] 17 | heads = [] 18 | deps = [] 19 | labels = [] 20 | tags = [] 21 | preds = [] 22 | for line in tqdm(f.readlines()): 23 | line = line.rstrip() 24 | if line == "": 25 | inst = 
Instance(Sentence(words, heads, deps, tags), labels) 26 | inst.prediction = preds 27 | insts.append(inst) 28 | words = [] 29 | heads = [] 30 | deps = [] 31 | labels = [] 32 | tags = [] 33 | preds = [] 34 | 35 | if len(insts) == number: 36 | break 37 | continue 38 | vals = line.split() 39 | word = vals[1] 40 | pos = vals[2] 41 | head = int(vals[3]) 42 | dep_label = vals[4] 43 | 44 | label = vals[5] 45 | pred_label = vals[6] 46 | 47 | words.append(word) 48 | heads.append(head) ## because of 0-indexed. 49 | deps.append(dep_label) 50 | tags.append(pos) 51 | labels.append(label) 52 | preds.append(pred_label) 53 | print("number of sentences: {}".format(len(insts))) 54 | return insts 55 | 56 | 57 | 58 | lgcn_file = "../final_results/lstm_200_crf_ontonotes_sd_-1_dep_lstm_lgcn_elmo_elmo_sgd_gate_0_epoch_100_lr_0.01.results" 59 | elmo_file = "../final_results/lstm_200_crf_ontonotes_.sd_-1_dep_none_elmo_elmo_sgd_gate_0_epoch_100_lr_0.01.results" 60 | lgcn_res = read_conll(lgcn_file) 61 | elmo_res = read_conll(elmo_file) 62 | 63 | 64 | 65 | 66 | total = 0 67 | total_word = 0 68 | for dep_inst, inst in zip(lgcn_res, elmo_res): 69 | gold = inst.output 70 | normal_pred = inst.prediction 71 | dep_pred = dep_inst.prediction 72 | words = inst.input.words 73 | heads = inst.input.heads 74 | dep_labels = inst.input.dep_labels 75 | have_error= False 76 | for idx in range(len(gold)): 77 | if normal_pred[idx] != dep_pred[idx]: 78 | if gold[idx] == dep_pred[idx]: 79 | print("{}\t{}\t{}\t{}\t{}\t{}\t".format(idx, words[idx], heads[idx] + 1, dep_labels[idx], gold[idx], normal_pred[idx])) 80 | print("") 81 | -------------------------------------------------------------------------------- /analysis/compare_heatmap.py: -------------------------------------------------------------------------------- 1 | from config.reader import Reader 2 | 3 | from common.sentence import Sentence 4 | from common.instance import Instance 5 | from typing import List 6 | from tqdm import tqdm 7 | import numpy as np 

import seaborn as sns; sns.set(font_scale=0.8)
import matplotlib.pyplot as plt
import random


def read_results(res_file: str, number: int = -1) -> List[Instance]:
    """Read a tab/space-separated NER result file into Instance objects.

    Each non-blank line holds the columns: index, word, POS tag, head index,
    dependency label, gold entity label, predicted entity label (vals[0]..vals[6]).
    Sentences are separated by blank lines.

    :param res_file: path to the .results file produced by the model.
    :param number: stop after this many sentences; -1 reads the whole file.
    :return: list of Instance objects with gold labels in ``output`` and
             predicted labels attached as ``prediction``.
    """
    print("Reading file: " + res_file)
    insts = []
    # vocab = set() ## build the vocabulary
    with open(res_file, 'r', encoding='utf-8') as f:
        words = []
        heads = []
        deps = []
        labels = []
        tags = []
        preds = []
        for line in tqdm(f.readlines()):
            line = line.rstrip()
            if line == "":
                # Blank line ends the current sentence: flush buffers into an Instance.
                inst = Instance(Sentence(words, heads, deps, tags), labels)
                inst.prediction = preds
                insts.append(inst)
                words = []
                heads = []
                deps = []
                labels = []
                tags = []
                preds = []

                if len(insts) == number:
                    break
                continue
            vals = line.split()
            word = vals[1]
            pos = vals[2]
            head = int(vals[3])
            dep_label = vals[4]

            label = vals[5]
            pred_label = vals[6]

            words.append(word)
            heads.append(head) ## because of 0-indexed.
            # NOTE(review): the comment above suggests heads are already 0-indexed
            # with -1 for root in the .results file — TODO confirm against the writer.
            deps.append(dep_label)
            tags.append(pos)
            labels.append(label)
            preds.append(pred_label)
    print("number of sentences: {}".format(len(insts)))
    return insts


# file = "data/ontonotes/test.sd.conllx"
# digit2zero = False
# reader = Reader(digit2zero)
#
# insts = reader.read_conll(file, -1, True)

# Result file of the dependency-guided model (the system being analyzed).
file = "final_results/lstm_2_200_crf_ontonotes_sd_-1_dep_feat_emb_elmo_none_sgd_gate_0_base_-1_epoch_150_lr_0.01.results"
insts = read_results(file) ##change inst.output -> inst.prediction

# Result file of the baseline (no dependency features) used for comparison.
comp_file = "final_results/lstm_2_200_crf_ontonotes_sd_-1_dep_none_elmo_none_sgd_gate_0_base_-1_epoch_100_lr_0.01.results"
comp_insts = read_results(comp_file) ##change inst.output -> inst.prediction

# Collect the set of entity types (strip the "B-"/"I-"/... prefix) from gold labels.
entities = set([ label[2:] for inst in insts for label in inst.output if len(label)>1])
print(entities)
# Collect dependency labels that co-occur with predicted (non-"O") entity tokens.
dep_labels = set([ dep for inst in insts for label, dep in zip(inst.prediction, inst.input.dep_labels) if len(label)>1] )
print(len(dep_labels), dep_labels)

### add grandchild relation as well.
78 | for inst in insts: 79 | for head, dep in zip(inst.input.heads, inst.input.dep_labels): 80 | if head == -1: 81 | continue 82 | dep_labels.add(dep+", " + inst.input.dep_labels[head]) 83 | print(len(dep_labels), dep_labels) 84 | ### 85 | 86 | 87 | ent2idx = {} 88 | ents = list(entities) 89 | ents.sort() 90 | for i, label in enumerate(ents): 91 | ent2idx[label] = i 92 | 93 | 94 | dep2idx = {} 95 | deps = list(dep_labels) 96 | deps.sort() 97 | for i, label in enumerate(deps): 98 | dep2idx[label] = i 99 | 100 | ent_dep_mat = np.zeros((len(entities), len(dep_labels))) 101 | print(ent_dep_mat.shape) 102 | for inst, comp_inst in zip(insts,comp_insts): 103 | for label, dep, gold, comp_label, head in zip(inst.prediction, inst.input.dep_labels, inst.output, comp_inst.prediction, inst.input.heads): 104 | if gold == "O": 105 | continue 106 | if label == "O": 107 | continue 108 | if label == gold and label != comp_label: 109 | ent_dep_mat[ent2idx[label[2:]]][dep2idx[dep]] += 1 110 | if head != -1: 111 | if inst.output[head] == inst.prediction[head] and inst.output[head] != comp_inst.prediction[head]: 112 | ent_dep_mat[ent2idx[label[2:]]][dep2idx[dep+", " + inst.input.dep_labels[head]]] += 1 113 | 114 | sum_labels = [ sum(ent_dep_mat[i]) for i in range(ent_dep_mat.shape[0])] 115 | ent_dep_mat = np.stack([ (ent_dep_mat[i]/sum_labels[i]) * 100 for i in range(ent_dep_mat.shape[0])], axis=0) 116 | print(ent_dep_mat.shape) 117 | 118 | indexs = [i for i in range(ent_dep_mat.shape[1]) if len(ent_dep_mat[:,i][ ent_dep_mat[:,i] >5.0 ]) ] 119 | print(np.asarray(deps)[indexs]) 120 | 121 | xlabels = [deps[i] for i in indexs] 122 | # cmap = sns.light_palette("#2ecc71", as_cmap=True) 123 | # cmap = sns.light_palette("#8e44ad", as_cmap=True) 124 | cmap = sns.cubehelix_palette(8,as_cmap=True) 125 | ax = sns.heatmap(ent_dep_mat[:, indexs], annot=True, vmin=0, vmax=100, cmap=cmap,fmt='.0f', xticklabels=xlabels, yticklabels=ents, cbar=True) 126 | # ,annot_kws = {"size": 10}) 127 | # , 
cbar_kws={'label': 'percentage (%)'}) 128 | plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor") 129 | plt.show() -------------------------------------------------------------------------------- /analysis/evaluate_result.py: -------------------------------------------------------------------------------- 1 | 2 | from common.sentence import Sentence 3 | from common.instance import Instance 4 | from typing import List 5 | from config.eval import Span 6 | from tqdm import tqdm 7 | 8 | from config.eval import evaluate 9 | 10 | def read_conll(res_file: str, number: int = -1) -> List[Instance]: 11 | print("Reading file: " + res_file) 12 | insts = [] 13 | # vocab = set() ## build the vocabulary 14 | with open(res_file, 'r', encoding='utf-8') as f: 15 | words = [] 16 | heads = [] 17 | deps = [] 18 | labels = [] 19 | tags = [] 20 | preds = [] 21 | for line in tqdm(f.readlines()): 22 | line = line.rstrip() 23 | if line == "": 24 | inst = Instance(Sentence(words, heads, deps, tags), labels) 25 | inst.prediction = preds 26 | insts.append(inst) 27 | words = [] 28 | heads = [] 29 | deps = [] 30 | labels = [] 31 | tags = [] 32 | preds = [] 33 | 34 | if len(insts) == number: 35 | break 36 | continue 37 | vals = line.split() 38 | word = vals[1] 39 | pos = vals[2] 40 | head = int(vals[3]) 41 | dep_label = vals[4] 42 | 43 | label = vals[5] 44 | pred_label = vals[6] 45 | 46 | words.append(word) 47 | heads.append(head) ## because of 0-indexed. 
48 | deps.append(dep_label) 49 | tags.append(pos) 50 | labels.append(label) 51 | preds.append(pred_label) 52 | print("number of sentences: {}".format(len(insts))) 53 | return insts 54 | 55 | 56 | 57 | res1 = "./final_results/lstm_2_200_crf_semes_sd_-1_dep_feat_emb_elmo_elmo_sgd_gate_0_base_-1_epoch_300_lr_0.01_doubledep_0_comb_3.results" 58 | insts1 = read_conll(res1) 59 | 60 | 61 | print(evaluate(insts1)) -------------------------------------------------------------------------------- /analysis/heatmap.py: -------------------------------------------------------------------------------- 1 | from config.reader import Reader 2 | 3 | from common.sentence import Sentence 4 | from common.instance import Instance 5 | from typing import List 6 | from tqdm import tqdm 7 | import numpy as np 8 | 9 | import seaborn as sns; sns.set(font_scale=0.8) 10 | import matplotlib.pyplot as plt 11 | import random 12 | 13 | 14 | def read_results(res_file: str, number: int = -1) -> List[Instance]: 15 | print("Reading file: " + res_file) 16 | insts = [] 17 | # vocab = set() ## build the vocabulary 18 | with open(res_file, 'r', encoding='utf-8') as f: 19 | words = [] 20 | heads = [] 21 | deps = [] 22 | labels = [] 23 | tags = [] 24 | preds = [] 25 | for line in tqdm(f.readlines()): 26 | line = line.rstrip() 27 | if line == "": 28 | inst = Instance(Sentence(words, heads, deps, tags), labels) 29 | inst.prediction = preds 30 | insts.append(inst) 31 | words = [] 32 | heads = [] 33 | deps = [] 34 | labels = [] 35 | tags = [] 36 | preds = [] 37 | 38 | if len(insts) == number: 39 | break 40 | continue 41 | vals = line.split() 42 | word = vals[1] 43 | pos = vals[2] 44 | head = int(vals[3]) 45 | dep_label = vals[4] 46 | 47 | label = vals[5] 48 | pred_label = vals[6] 49 | 50 | words.append(word) 51 | heads.append(head) ## because of 0-indexed. 
52 | deps.append(dep_label) 53 | tags.append(pos) 54 | labels.append(label) 55 | preds.append(pred_label) 56 | print("number of sentences: {}".format(len(insts))) 57 | return insts 58 | 59 | 60 | # file = "data/ontonotes/test.sd.conllx" 61 | # digit2zero = False 62 | # reader = Reader(digit2zero) 63 | # 64 | # insts = reader.read_conll(file, -1, True) 65 | 66 | file = "final_results/lstm_2_200_crf_ontonotes_sd_-1_dep_feat_emb_elmo_none_sgd_gate_0_base_-1_epoch_150_lr_0.01.results" 67 | insts = read_results(file) ##change inst.output -> inst.prediction 68 | 69 | entities = set([ label[2:] for inst in insts for label in inst.output if len(label)>1]) 70 | print(entities) 71 | dep_labels = set([ dep for inst in insts for label, dep in zip(inst.prediction, inst.input.dep_labels) if len(label)>1] ) 72 | print(len(dep_labels), dep_labels) 73 | 74 | ent2idx = {} 75 | ents = list(entities) 76 | ents.sort() 77 | for i, label in enumerate(ents): 78 | ent2idx[label] = i 79 | 80 | 81 | dep2idx = {} 82 | deps = list(dep_labels) 83 | deps.sort() 84 | for i, label in enumerate(deps): 85 | dep2idx[label] = i 86 | 87 | ent_dep_mat = np.zeros((len(entities), len(dep_labels))) 88 | print(ent_dep_mat.shape) 89 | for inst in insts: 90 | for label, dep in zip(inst.prediction, inst.input.dep_labels): 91 | if label == "O": 92 | continue 93 | ent_dep_mat[ent2idx[label[2:]]] [dep2idx[dep]] += 1 94 | 95 | sum_labels = [ sum(ent_dep_mat[i]) for i in range(ent_dep_mat.shape[0])] 96 | ent_dep_mat = np.stack([ (ent_dep_mat[i]/sum_labels[i]) * 100 for i in range(ent_dep_mat.shape[0])], axis=0) 97 | print(ent_dep_mat.shape) 98 | 99 | indexs = [i for i in range(ent_dep_mat.shape[1]) if len(ent_dep_mat[:,i][ ent_dep_mat[:,i] >5.0 ]) ] 100 | print(np.asarray(deps)[indexs]) 101 | 102 | xlabels = [deps[i] for i in indexs] 103 | # cmap = sns.light_palette("#2ecc71", as_cmap=True) 104 | # cmap = sns.light_palette("#8e44ad", as_cmap=True) 105 | cmap = sns.cubehelix_palette(8,as_cmap=True) 106 | ax = 
sns.heatmap(ent_dep_mat[:, indexs], annot=True, vmin=0, vmax=100, cmap=cmap,fmt='.0f', xticklabels=xlabels, yticklabels=ents, cbar=True) 107 | # ,annot_kws = {"size": 10}) 108 | # , cbar_kws={'label': 'percentage (%)'}) 109 | plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor") 110 | plt.show() -------------------------------------------------------------------------------- /analysis/intro_examples.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | from config.reader import Reader 4 | from config.eval import Span 5 | 6 | 7 | def get_spans(output): 8 | output_spans = set() 9 | start = -1 10 | for i in range(len(output)): 11 | if output[i].startswith("B-"): 12 | start = i 13 | if output[i].startswith("E-"): 14 | end = i 15 | output_spans.add(Span(start, end, output[i][2:])) 16 | if output[i].startswith("S-"): 17 | output_spans.add(Span(i, i, output[i][2:])) 18 | return output_spans 19 | 20 | def use_iobes(insts): 21 | for inst in insts: 22 | output = inst.output 23 | for pos in range(len(inst)): 24 | curr_entity = output[pos] 25 | if pos == len(inst) - 1: 26 | if curr_entity.startswith("B-"): 27 | output[pos] = curr_entity.replace("B-", "S-") 28 | elif curr_entity.startswith("I-"): 29 | output[pos] = curr_entity.replace("I-", "E-") 30 | else: 31 | next_entity = output[pos + 1] 32 | if curr_entity.startswith("B-"): 33 | if next_entity.startswith("O") or next_entity.startswith("B-"): 34 | output[pos] = curr_entity.replace("B-", "S-") 35 | elif curr_entity.startswith("I-"): 36 | if next_entity.startswith("O") or next_entity.startswith("B-"): 37 | output[pos] = curr_entity.replace("I-", "E-") 38 | 39 | 40 | file = "data/ontonotes/train.sd.conllx" 41 | digit2zero = False 42 | reader = Reader(digit2zero) 43 | 44 | insts = reader.read_conll(file, -1, True) 45 | use_iobes(insts) 46 | 47 | for i in range(len(insts)): 48 | 49 | inst = insts[i] 50 | gold_spans = get_spans(inst.output) 51 | 52 | 53 | for 
span in gold_spans: 54 | ent_words = ' '.join(inst.input.words[span.left:span.right+1]) 55 | # if ent_words.islower() and span.type != "DATE" and span.type != "ORDINAL" and span.type != "PERCENT"\ 56 | # and span.type != "CARDINAL" and span.type != "MONEY" and span.type != "QUANTITY" and span.type != "TIME" \ 57 | # and span.type != "NORP" and span.type != "PERSON": 58 | # print(ent_words + " " + span.type) 59 | # print(inst.input.words) 60 | # print() 61 | for k in range(span.left, span.right + 1): 62 | head_k = inst.input.heads[k] 63 | if abs (head_k - k) >= 4 and span.type != "DATE" and span.type != "ORDINAL" and span.type != "PERCENT" \ 64 | and ent_words.islower() and span.type != "MONEY" and span.type != "QUANTITY" and span.type != "TIME" and span.type != "CARDINAL" : 65 | print(ent_words + " " + span.type) 66 | print(inst.input.words) 67 | print() 68 | # if span.right - span.left >= 4 and span.type != "DATE" and span.type != "ORDINAL" and span.type != "PERCENT" \ 69 | # and ent_words.islower() and span.type != "MONEY" and span.type != "QUANTITY" and span.type != "TIME" and span.type != "CARDINAL" : 70 | # print(ent_words + " " + span.type) 71 | # print(inst.input.words) 72 | # print() 73 | ## book of the dead. 
-------------------------------------------------------------------------------- /analysis/length.py: -------------------------------------------------------------------------------- 1 | # 2 | # @author: Allan 3 | # 4 | 5 | 6 | from config.reader import Reader 7 | import numpy as np 8 | 9 | import matplotlib.pyplot as plt 10 | import random 11 | from config.eval import evaluate, Span 12 | 13 | from collections import defaultdict 14 | 15 | 16 | 17 | def use_iobes(insts): 18 | for inst in insts: 19 | output = inst.output 20 | for pos in range(len(inst)): 21 | curr_entity = output[pos] 22 | if pos == len(inst) - 1: 23 | if curr_entity.startswith("B-"): 24 | output[pos] = curr_entity.replace("B-", "S-") 25 | elif curr_entity.startswith("I-"): 26 | output[pos] = curr_entity.replace("I-", "E-") 27 | else: 28 | next_entity = output[pos + 1] 29 | if curr_entity.startswith("B-"): 30 | if next_entity.startswith("O") or next_entity.startswith("B-"): 31 | output[pos] = curr_entity.replace("B-", "S-") 32 | elif curr_entity.startswith("I-"): 33 | if next_entity.startswith("O") or next_entity.startswith("B-"): 34 | output[pos] = curr_entity.replace("I-", "E-") 35 | 36 | 37 | dataset = "ontonotes_chinese" 38 | train = "../data/"+dataset+"/train.sd.conllx" 39 | dev = "../data/"+dataset+"/dev.sd.conllx" 40 | test = "../data/"+dataset+"/test.sd.conllx" 41 | digit2zero = False 42 | reader = Reader(digit2zero) 43 | 44 | insts = reader.read_conll(train, -1, True) 45 | insts += reader.read_conll(dev, -1, False) 46 | insts += reader.read_conll(test, -1, False) 47 | use_iobes(insts) 48 | L = 3 49 | 50 | 51 | def get_spans(output): 52 | output_spans = set() 53 | start = -1 54 | for i in range(len(output)): 55 | if output[i].startswith("B-"): 56 | start = i 57 | if output[i].startswith("E-"): 58 | end = i 59 | output_spans.add(Span(start, end, output[i][2:])) 60 | if output[i].startswith("S-"): 61 | output_spans.add(Span(i, i, output[i][2:])) 62 | return output_spans 63 | 64 | count_all = 0 
65 | count_have_sub = 0 66 | count_grand = 0 67 | length2num = defaultdict(int) 68 | for inst in insts: 69 | output = inst.output 70 | spans = get_spans(output) 71 | # print(spans) 72 | for span in spans: 73 | length2num[span.right - span.left + 1] += 1 74 | if span.right - span.left + 1 < L: 75 | continue 76 | count_dep = 0 77 | count_all += 1 78 | has_grand = False 79 | for i in range(span.left, span.right + 1): 80 | if inst.input.heads[i] >= span.left and inst.input.heads[i] <= span.right: 81 | count_dep += 1 82 | if inst.input.heads[i] >= span.left and inst.input.heads[i] <= span.right: 83 | head_i = inst.input.heads[i] 84 | if head_i != -1 and inst.input.heads[head_i] >= span.left and inst.input.heads[head_i] <= span.right: 85 | 86 | has_grand = True 87 | if has_grand: 88 | count_grand += 1 89 | if count_dep == (span.right - span.left): 90 | count_have_sub += 1 91 | else: 92 | pass 93 | # print(inst.input.words) 94 | 95 | 96 | print(count_have_sub, count_all, count_have_sub/count_all*100) 97 | print(count_grand, count_all, count_grand/count_all*100) 98 | print(length2num) 99 | 100 | -------------------------------------------------------------------------------- /analysis/length_analysis.py: -------------------------------------------------------------------------------- 1 | # 2 | # @author: Allan 3 | # 4 | 5 | 6 | from config.reader import Reader 7 | import numpy as np 8 | 9 | import matplotlib.pyplot as plt 10 | import random 11 | 12 | from common.sentence import Sentence 13 | from common.instance import Instance 14 | from typing import List 15 | from config.eval import Span 16 | from tqdm import tqdm 17 | 18 | def use_iobes(insts): 19 | for inst in insts: 20 | output = inst.output 21 | for pos in range(len(inst)): 22 | curr_entity = output[pos] 23 | if pos == len(inst) - 1: 24 | if curr_entity.startswith("B-"): 25 | output[pos] = curr_entity.replace("B-", "S-") 26 | elif curr_entity.startswith("I-"): 27 | output[pos] = curr_entity.replace("I-", "E-") 28 | 
else: 29 | next_entity = output[pos + 1] 30 | if curr_entity.startswith("B-"): 31 | if next_entity.startswith("O") or next_entity.startswith("B-"): 32 | output[pos] = curr_entity.replace("B-", "S-") 33 | elif curr_entity.startswith("I-"): 34 | if next_entity.startswith("O") or next_entity.startswith("B-"): 35 | output[pos] = curr_entity.replace("I-", "E-") 36 | 37 | 38 | 39 | def read_conll(res_file: str, number: int = -1) -> List[Instance]: 40 | print("Reading file: " + res_file) 41 | insts = [] 42 | # vocab = set() ## build the vocabulary 43 | with open(res_file, 'r', encoding='utf-8') as f: 44 | words = [] 45 | heads = [] 46 | deps = [] 47 | labels = [] 48 | tags = [] 49 | preds = [] 50 | for line in tqdm(f.readlines()): 51 | line = line.rstrip() 52 | if line == "": 53 | inst = Instance(Sentence(words, heads, deps, tags), labels) 54 | inst.prediction = preds 55 | insts.append(inst) 56 | words = [] 57 | heads = [] 58 | deps = [] 59 | labels = [] 60 | tags = [] 61 | preds = [] 62 | 63 | if len(insts) == number: 64 | break 65 | continue 66 | vals = line.split() 67 | word = vals[1] 68 | pos = vals[2] 69 | head = int(vals[3]) 70 | dep_label = vals[4] 71 | 72 | label = vals[5] 73 | pred_label = vals[6] 74 | 75 | words.append(word) 76 | heads.append(head) ## because of 0-indexed. 
77 | deps.append(dep_label) 78 | tags.append(pos) 79 | labels.append(label) 80 | preds.append(pred_label) 81 | print("number of sentences: {}".format(len(insts))) 82 | return insts 83 | 84 | 85 | def get_spans(output): 86 | output_spans = set() 87 | start = -1 88 | for i in range(len(output)): 89 | if output[i].startswith("B-"): 90 | start = i 91 | if output[i].startswith("E-"): 92 | end = i 93 | output_spans.add(Span(start, end, output[i][2:])) 94 | if output[i].startswith("S-"): 95 | output_spans.add(Span(i, i, output[i][2:])) 96 | return output_spans 97 | 98 | def evaluate(insts, maximum_length = 4): 99 | 100 | p = {} 101 | total_entity = {} 102 | total_predict = {} 103 | 104 | for inst in insts: 105 | 106 | output = inst.output 107 | prediction = inst.prediction 108 | #convert to span 109 | output_spans = set() 110 | start = -1 111 | for i in range(len(output)): 112 | if output[i].startswith("B-"): 113 | start = i 114 | if output[i].startswith("E-"): 115 | end = i 116 | output_spans.add(Span(start, end, output[i][2:])) 117 | if output[i].startswith("S-"): 118 | output_spans.add(Span(i, i, output[i][2:])) 119 | predict_spans = set() 120 | for i in range(len(prediction)): 121 | if prediction[i].startswith("B-"): 122 | start = i 123 | if prediction[i].startswith("E-"): 124 | end = i 125 | predict_spans.add(Span(start, end, prediction[i][2:])) 126 | if prediction[i].startswith("S-"): 127 | predict_spans.add(Span(i, i, prediction[i][2:])) 128 | 129 | # total_entity += len(output_spans) 130 | # total_predict += len(predict_spans) 131 | # p += len(predict_spans.intersection(output_spans)) 132 | 133 | for span in output_spans: 134 | length = span.right - span.left + 1 135 | if length >= maximum_length: 136 | length = maximum_length 137 | if length in total_entity: 138 | total_entity[length] += 1 139 | else: 140 | total_entity[length] = 1 141 | 142 | for span in predict_spans: 143 | length = span.right - span.left + 1 144 | if length >= maximum_length: 145 | length = 
maximum_length 146 | if length in total_predict: 147 | total_predict[length] += 1 148 | else: 149 | total_predict[length] = 1 150 | 151 | for span in predict_spans.intersection(output_spans): 152 | length = span.right - span.left + 1 153 | if length >= maximum_length: 154 | length = maximum_length 155 | if length in p: 156 | p[length] += 1 157 | else: 158 | p[length] = 1 159 | 160 | max_len = max([key for key in p]) 161 | # precision = p * 1.0 / total_predict * 100 if total_predict != 0 else 0 162 | # recall = p * 1.0 / total_entity * 100 if total_entity != 0 else 0 163 | # fscore = 2.0 * precision * recall / (precision + recall) if precision != 0 or recall != 0 else 0 164 | 165 | f = {} 166 | for length in range(1, max_len + 1): 167 | if length not in p: 168 | continue 169 | precision = p[length] * 1.0 / total_predict[length] * 100 if total_predict[length] != 0 else 0 170 | recall = p[length] * 1.0 / total_entity[length] * 100 if total_entity[length] != 0 else 0 171 | f[length] = 2.0 * precision * recall / (precision + recall) if precision != 0 or recall != 0 else 0 172 | 173 | return f 174 | 175 | 176 | def grand_child(insts1, insts2): 177 | num = 0 178 | gc_num = 0 179 | ld_num = 0 180 | for i in range(len(insts1)): 181 | 182 | first = insts1[i] 183 | second = insts2[i] 184 | inst = insts1[i] 185 | gold_spans = get_spans(first.output) 186 | 187 | pred_first = get_spans(first.prediction) 188 | pred_second = get_spans(second.prediction) 189 | 190 | # for span in pred_first: 191 | # if span in gold_spans and (span not in pred_second): 192 | for span in gold_spans: 193 | if span in pred_first and (span not in pred_second): 194 | if span.right - span.left < 2: 195 | continue 196 | num += 1 197 | # print(span.to_str(first.input.words)) 198 | has_grand = False 199 | has_ld = False 200 | for k in range(span.left, span.right + 1): 201 | if inst.input.heads[k] >= span.left and inst.input.heads[k] <= span.right: 202 | head_i = inst.input.heads[k] 203 | if abs(head_i - k) > 
1: 204 | has_ld = True 205 | if head_i != -1 and inst.input.heads[head_i] >= span.left and inst.input.heads[ 206 | head_i] <= span.right: 207 | has_grand = True 208 | 209 | if has_grand: 210 | gc_num +=1 211 | if has_ld: 212 | ld_num += 1 213 | return gc_num, ld_num, num 214 | 215 | 216 | ## Chinese Comparison 217 | res1 = "./final_results/lstm_2_200_crf_ontonotes_chinese_sd_-1_dep_none_elmo_elmo_sgd_gate_0_base_-1_epoch_150_lr_0.01.results" 218 | insts1 = read_conll(res1) 219 | 220 | res2 = "./final_results/lstm_2_200_crf_ontonotes_chinese_sd_-1_dep_feat_emb_elmo_elmo_sgd_gate_0_base_-1_epoch_100_lr_0.01_doubledep_0_comb_3.results" 221 | insts2 = read_conll(res2) 222 | 223 | # Catalan Comparison 224 | # res1 = "./final_results/lstm_2_200_crf_semca_sd_-1_dep_none_elmo_elmo_sgd_gate_0_base_-1_epoch_150_lr_0.01.results" 225 | # insts1 = read_conll(res1) 226 | # 227 | # res2 = "./final_results/lstm_2_200_crf_semca_sd_-1_dep_feat_emb_elmo_elmo_sgd_gate_0_base_-1_epoch_300_lr_0.01_doubledep_0_comb_3.results" 228 | # insts2 = read_conll(res2) 229 | 230 | 231 | ## Spanish Comparison 232 | # res1 = "./final_results/lstm_2_200_crf_semes_sd_-1_dep_none_elmo_elmo_sgd_gate_0_base_-1_epoch_150_lr_0.01.results" 233 | # insts1 = read_conll(res1) 234 | # 235 | # res2 = "./final_results/lstm_2_200_crf_semes_sd_-1_dep_feat_emb_elmo_elmo_sgd_gate_0_base_-1_epoch_300_lr_0.01_doubledep_0_comb_3.results" 236 | # insts2 = read_conll(res2) 237 | 238 | 239 | 240 | 241 | maximum_length = 6 242 | print(evaluate(insts1, maximum_length)) 243 | print(evaluate(insts2, maximum_length)) 244 | 245 | print(grand_child(insts1, insts2)) 246 | -------------------------------------------------------------------------------- /analysis/significant.py: -------------------------------------------------------------------------------- 1 | # 2 | # @author: Allan 3 | # 4 | 5 | 6 | from tqdm import tqdm 7 | from common.sentence import Sentence 8 | from common.instance import Instance 9 | from typing import List 
from config.eval import evaluate
import random

def read_conll(res_file: str, number: int = -1) -> List[Instance]:
    """Load a results file (one token per line: index, word, pos, head,
    dep label, gold label, predicted label; blank line between sentences).

    :param res_file: path to the results file.
    :param number: stop after this many sentences (-1 reads all).
    :return: instances with ``prediction`` attached.
    """
    print("Reading file: " + res_file)
    insts = []

    def _flush(words, heads, deps, labels, tags, preds):
        # One sentence is complete: wrap it into an Instance.
        inst = Instance(Sentence(words, heads, deps, tags), labels)
        inst.prediction = preds
        insts.append(inst)

    with open(res_file, 'r', encoding='utf-8') as f:
        words, heads, deps, labels, tags, preds = [], [], [], [], [], []
        for line in tqdm(f.readlines()):
            line = line.rstrip()
            if line == "":
                _flush(words, heads, deps, labels, tags, preds)
                words, heads, deps, labels, tags, preds = [], [], [], [], [], []
                if len(insts) == number:
                    break
                continue
            vals = line.split()
            words.append(vals[1])
            tags.append(vals[2])
            # NOTE(review): stored as-is; the results file is presumably
            # already 0-indexed — confirm against the code that wrote it.
            heads.append(int(vals[3]))
            deps.append(vals[4])
            labels.append(vals[5])
            preds.append(vals[6])
        if words:
            # Bug fix: a file without a trailing blank line used to drop
            # its final sentence silently.
            _flush(words, heads, deps, labels, tags, preds)
    print("number of sentences: {}".format(len(insts)))
    return insts

res1 = "../final_results/lstm_2_200_crf_ontonotes_chinese_sd_-1_dep_feat_emb_elmo_elmo_sgd_gate_0_base_-1_epoch_150_lr_0.01.results"
insts1 = read_conll(res1)

res2 = "../final_results/lstm_1_200_crf_ontonotes_chinese_sd_-1_dep_none_elmo_elmo_sgd_gate_0_base_-1_epoch_150_lr_0.01.results"
insts2 = read_conll(res2)


sample_num = 10000  # number of bootstrap resamples

p = 0  # resamples in which system 1 beats system 2 on F1
for i in range(sample_num):
    # Paired bootstrap: resample sentence indices with replacement and
    # score both systems on the same resampled set.
    sinsts = []
    sinsts_2 = []
    for _ in range(len(insts1)):
        n = random.randint(0, len(insts1) - 1)
        sinsts.append(insts1[n])
        sinsts_2.append(insts2[n])

    f1 = evaluate(sinsts)[2]
    f2 = evaluate(sinsts_2)[2]

    if f1 > f2:
        p += 1

    # Running estimate of P(system 1 not better) — the significance level.
    p_val = (i + 1 - p) / (i + 1)
    print("current p value: {}".format(p_val))



-------------------------------------------------------------------------------- /analysis/stator.py: -------------------------------------------------------------------------------- 1 | # 2 | # @author: Allan 3 | # 4 | 5 | from config.reader import Reader 6 | from collections import defaultdict 7 | 8 | file = "../data/ontonotes/dev.sd.conllx" 9 | digit2zero = False 10 | reader = Reader(digit2zero) 11 | 12 | insts = reader.read_conll(file, -1, True) 13 | # devs = reader.read_conll(conf.dev_file, conf.dev_num, False) 14 | # tests = reader.read_conll(conf.test_file, conf.test_num, False) 15 | 16 | out_dep_label2num = {} 17 | 18 | out_doubledep2num = {} 19 | 20 | out_word2num = {} 21 | 22 | label2idx = {} 23 | 24 | ent2num = defaultdict(int) 25 | 26 | def not_entity(label:str): 27 | if label.startswith("B-") or label.startswith("I-"): 28 | return False 29 | return True 30 | 31 | def is_entity(label:str): 32 | if label.startswith("B-") or label.startswith("I-"): 33 | return True 34 | return False 35 | 36 | for inst in insts: 37 | output = inst.output 38 | sent = inst.input 39 | 40 | for idx, (word, head_idx, ent, dep) in enumerate(zip(sent.words, sent.heads, output, sent.dep_labels)): 41 | if ent.startswith('B-'): 42 | ent2num[ent[2:]] += 1 43 | 44 | if dep not in label2idx: 45 | label2idx[dep] = len(label2idx) 46 | if is_entity(ent): 47 | if head_idx == -1 or not_entity(output[head_idx]): 48 | if dep in out_dep_label2num: 49 | out_dep_label2num[dep] +=1 50 | else: 51 | out_dep_label2num[dep] = 1 52 | head_word = "root" if head_idx == -1 else sent.words[head_idx] 53 | if head_word in out_word2num: 54 | out_word2num[head_word] += 1 55 | else: 56 | out_word2num[head_word] = 1 57 | 58 | if head_idx != -1: 59 | head_dep = sent.dep_labels[head_idx] 60 | if (head_dep, dep) in out_doubledep2num: 61 | out_doubledep2num[(head_dep, dep)] += 1 62 | else: 63 | out_doubledep2num[(head_dep, dep)] = 1 64 | 65 | 66 | counts = [(key, out_dep_label2num[key]) for key in 
out_dep_label2num] 67 | counts = sorted(counts, key=lambda vals: vals[1], reverse=True) 68 | total_ent_dep = sum([nums[1] for nums in counts]) 69 | print(counts) 70 | print("total is {}".format(total_ent_dep)) 71 | 72 | 73 | # counts = [(key, out_word2num[key]) for key in out_word2num] 74 | # counts = sorted(counts, key=lambda vals: vals[1], reverse=True) 75 | # total_ent_dep = sum([nums[1] for nums in counts]) 76 | # print(counts) 77 | # print("total is {}".format(total_ent_dep)) 78 | 79 | 80 | counts = [(key, out_doubledep2num[key]) for key in out_doubledep2num] 81 | counts = sorted(counts, key=lambda vals: vals[1], reverse=True) 82 | total_ent_dep = sum([nums[1] for nums in counts]) 83 | print(counts) 84 | print("total is {}".format(total_ent_dep)) 85 | 86 | 87 | 88 | print(f"entity2number: {ent2num}") 89 | -------------------------------------------------------------------------------- /common/instance.py: -------------------------------------------------------------------------------- 1 | # 2 | # @author: Allan 3 | # 4 | from common.sentence import Sentence 5 | class Instance: 6 | 7 | def __init__(self, input: Sentence, output): 8 | self.input = input 9 | self.output = output 10 | self.elmo_vec = None 11 | self.word_ids = None 12 | self.char_ids = None 13 | self.dep_label_ids = None 14 | self.dep_head_ids = None 15 | self.output_ids = None 16 | 17 | def __len__(self): 18 | return len(self.input) 19 | -------------------------------------------------------------------------------- /common/sentence.py: -------------------------------------------------------------------------------- 1 | # 2 | # @author: Allan 3 | # 4 | 5 | from typing import List 6 | 7 | class Sentence: 8 | 9 | def __init__(self, words: List[str], heads: List[int]=None , dep_labels: List[str]=None, pos_tags:List[str] = None): 10 | self.words = words 11 | self.heads = heads 12 | self.dep_labels = dep_labels 13 | self.pos_tags = pos_tags 14 | 15 | def __len__(self): 16 | return len(self.words) 17 | 
18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | # if __name__ == "__main__": 26 | # 27 | # words = ["a" ,"sdfsdf"] 28 | # sent = Sentence(words) 29 | # 30 | # print(len(sent)) -------------------------------------------------------------------------------- /config/__init__.py: -------------------------------------------------------------------------------- 1 | from config.config import DepModelType, ContextEmb, InteractionFunction -------------------------------------------------------------------------------- /config/config.py: -------------------------------------------------------------------------------- 1 | # 2 | # @author: Allan 3 | # 4 | 5 | import numpy as np 6 | from tqdm import tqdm 7 | from typing import List 8 | from common.instance import Instance 9 | from config.utils import PAD, START, STOP, ROOT, ROOT_DEP_LABEL, SELF_DEP_LABEL 10 | import torch 11 | from enum import Enum 12 | from termcolor import colored 13 | 14 | class DepModelType(Enum): 15 | none = 0 16 | dglstm = 1 17 | dggcn = 2 18 | 19 | 20 | class ContextEmb(Enum): 21 | none = 0 22 | elmo = 1 23 | bert = 2 24 | flair = 3 25 | 26 | 27 | class InteractionFunction(Enum): 28 | concatenation = 0 29 | addition = 1 30 | mlp = 2 31 | 32 | 33 | 34 | class Config: 35 | def __init__(self, args): 36 | 37 | self.PAD = PAD 38 | self.B = "B-" 39 | self.I = "I-" 40 | self.S = "S-" 41 | self.E = "E-" 42 | self.O = "O" 43 | self.START_TAG = START 44 | self.STOP_TAG = STOP 45 | self.ROOT = ROOT 46 | self.UNK = "" 47 | self.unk_id = -1 48 | self.root_dep_label = ROOT_DEP_LABEL 49 | self.self_label = SELF_DEP_LABEL 50 | 51 | print(colored("[Info] remember to chec the root dependency label if changing the data. 
current: {}".format(self.root_dep_label), "red" )) 52 | 53 | # self.device = torch.device("cuda" if args.gpu else "cpu") 54 | self.embedding_file = args.embedding_file 55 | self.embedding_dim = args.embedding_dim 56 | self.context_emb = ContextEmb[args.context_emb] 57 | self.context_emb_size = 0 58 | self.embedding, self.embedding_dim = self.read_pretrain_embedding() 59 | self.word_embedding = None 60 | self.seed = args.seed 61 | self.digit2zero = args.digit2zero 62 | 63 | self.dataset = args.dataset 64 | 65 | self.affix = args.affix 66 | train_affix = self.affix.replace("pred", "") if "pred" in self.affix else self.affix 67 | self.train_file = "data/" + self.dataset + "/train."+train_affix+".conllx" 68 | self.dev_file = "data/" + self.dataset + "/dev."+train_affix+".conllx" 69 | self.test_file = "data/" + self.dataset + "/test."+self.affix+".conllx" 70 | self.label2idx = {} 71 | self.idx2labels = [] 72 | self.char2idx = {} 73 | self.idx2char = [] 74 | self.num_char = 0 75 | 76 | 77 | self.optimizer = args.optimizer.lower() 78 | self.learning_rate = args.learning_rate 79 | self.momentum = args.momentum 80 | self.l2 = args.l2 81 | self.num_epochs = args.num_epochs 82 | # self.lr_decay = 0.05 83 | self.use_dev = True 84 | self.train_num = args.train_num 85 | self.dev_num = args.dev_num 86 | self.test_num = args.test_num 87 | self.batch_size = args.batch_size 88 | self.clip = 5 89 | self.lr_decay = args.lr_decay 90 | self.device = torch.device(args.device) 91 | 92 | self.hidden_dim = args.hidden_dim 93 | self.num_lstm_layer = args.num_lstm_layer 94 | self.use_brnn = True 95 | self.num_layers = 1 96 | self.dropout = args.dropout 97 | self.char_emb_size = 30 98 | self.charlstm_hidden_dim = 50 99 | self.use_char_rnn = args.use_char_rnn 100 | # self.use_head = args.use_head 101 | self.dep_model = DepModelType[args.dep_model] 102 | 103 | self.dep_hidden_dim = args.dep_hidden_dim 104 | self.num_gcn_layers = args.num_gcn_layers 105 | self.gcn_mlp_layers = args.gcn_mlp_layers 
106 | self.gcn_dropout = args.gcn_dropout 107 | self.adj_directed = args.gcn_adj_directed 108 | self.adj_self_loop = args.gcn_adj_selfloop 109 | self.edge_gate = args.gcn_gate 110 | 111 | self.dep_emb_size = args.dep_emb_size 112 | self.deplabel2idx = {} 113 | self.deplabels = [] 114 | 115 | 116 | self.eval_epoch = args.eval_epoch 117 | 118 | 119 | self.interaction_func = InteractionFunction[args.inter_func] ## 0:concat, 1: addition, 2:gcn 120 | 121 | 122 | # def print(self): 123 | # print("") 124 | # print("\tuse gpu: " + ) 125 | 126 | ''' 127 | read all the pretrain embeddings 128 | ''' 129 | def read_pretrain_embedding(self): 130 | print("reading the pretraing embedding: %s" % (self.embedding_file)) 131 | if self.embedding_file is None: 132 | print("pretrain embedding in None, using random embedding") 133 | return None, self.embedding_dim 134 | embedding_dim = -1 135 | embedding = dict() 136 | with open(self.embedding_file, 'r', encoding='utf-8') as file: 137 | for line in tqdm(file.readlines()): 138 | line = line.strip() 139 | if len(line) == 0: 140 | continue 141 | tokens = line.split() 142 | if len(tokens) == 2: 143 | continue 144 | if embedding_dim < 0: 145 | embedding_dim = len(tokens) - 1 146 | else: 147 | # print(tokens) 148 | # print(embedding_dim) 149 | # assert (embedding_dim + 1 == len(tokens)) 150 | if (embedding_dim + 1) != len(tokens): 151 | continue 152 | pass 153 | embedd = np.empty([1, embedding_dim]) 154 | embedd[:] = tokens[1:] 155 | first_col = tokens[0] 156 | embedding[first_col] = embedd 157 | return embedding, embedding_dim 158 | 159 | 160 | def build_word_idx(self, train_insts, dev_insts, test_insts): 161 | self.word2idx = dict() 162 | self.idx2word = [] 163 | self.word2idx[self.PAD] = 0 164 | self.idx2word.append(self.PAD) 165 | self.word2idx[self.UNK] = 1 166 | self.unk_id = 1 167 | self.idx2word.append(self.UNK) 168 | 169 | self.word2idx[self.ROOT] = 2 170 | self.idx2word.append(self.ROOT) 171 | 172 | self.char2idx[self.PAD] = 0 173 | 
self.idx2char.append(self.PAD) 174 | self.char2idx[self.UNK] = 1 175 | self.idx2char.append(self.UNK) 176 | 177 | ##extract char on train, dev, test 178 | for inst in train_insts + dev_insts + test_insts: 179 | for word in inst.input.words: 180 | if word not in self.word2idx: 181 | self.word2idx[word] = len(self.word2idx) 182 | self.idx2word.append(word) 183 | ##extract char only on train 184 | for inst in train_insts: 185 | for word in inst.input.words: 186 | for c in word: 187 | if c not in self.char2idx: 188 | self.char2idx[c] = len(self.idx2char) 189 | self.idx2char.append(c) 190 | self.num_char = len(self.idx2char) 191 | # print(self.idx2word) 192 | # print(self.idx2char) 193 | # for idx, char in enumerate(self.idx2char): 194 | # print(idx, ":", char) 195 | # print("separator") 196 | # for idx, word in enumerate(self.idx2word): 197 | # print(idx, ":", word) 198 | ''' 199 | build the embedding table 200 | obtain the word2idx and idx2word as well. 201 | ''' 202 | def build_emb_table(self): 203 | print("Building the embedding table for vocabulary...") 204 | scale = np.sqrt(3.0 / self.embedding_dim) 205 | if self.embedding is not None: 206 | print("[Info] Use the pretrained word embedding to initialize: %d x %d" % (len(self.word2idx), self.embedding_dim)) 207 | word_found_in_emb_vocab = 0 208 | self.word_embedding = np.empty([len(self.word2idx), self.embedding_dim]) 209 | for word in self.word2idx: 210 | if word in self.embedding: 211 | self.word_embedding[self.word2idx[word], :] = self.embedding[word] 212 | word_found_in_emb_vocab += 1 213 | elif word.lower() in self.embedding: 214 | self.word_embedding[self.word2idx[word], :] = self.embedding[word.lower()] 215 | word_found_in_emb_vocab += 1 216 | else: 217 | # self.word_embedding[self.word2idx[word], :] = self.embedding[self.UNK] 218 | self.word_embedding[self.word2idx[word], :] = np.random.uniform(-scale, scale, [1, self.embedding_dim]) 219 | print(f"[Info] {word_found_in_emb_vocab} out of {len(self.word2idx)} 
found in the pretrained embedding.") 220 | self.embedding = None 221 | else: 222 | self.word_embedding = np.empty([len(self.word2idx), self.embedding_dim]) 223 | for word in self.word2idx: 224 | self.word_embedding[self.word2idx[word], :] = np.random.uniform(-scale, scale, [1, self.embedding_dim]) 225 | 226 | def build_deplabel_idx(self, insts): 227 | if self.self_label not in self.deplabel2idx: 228 | self.deplabels.append(self.self_label) 229 | self.deplabel2idx[self.self_label] = len(self.deplabel2idx) 230 | for inst in insts: 231 | for label in inst.input.dep_labels: 232 | if label not in self.deplabels: 233 | self.deplabels.append(label) 234 | self.deplabel2idx[label] = len(self.deplabel2idx) 235 | self.root_dep_label_id = self.deplabel2idx[self.root_dep_label] 236 | 237 | def build_label_idx(self, insts): 238 | self.label2idx[self.PAD] = len(self.label2idx) 239 | self.idx2labels.append(self.PAD) 240 | for inst in insts: 241 | for label in inst.output: 242 | if label not in self.label2idx: 243 | self.idx2labels.append(label) 244 | self.label2idx[label] = len(self.label2idx) 245 | 246 | self.label2idx[self.START_TAG] = len(self.label2idx) 247 | self.idx2labels.append(self.START_TAG) 248 | self.label2idx[self.STOP_TAG] = len(self.label2idx) 249 | self.idx2labels.append(self.STOP_TAG) 250 | self.label_size = len(self.label2idx) 251 | print("#labels: " + str(self.label_size)) 252 | print("label 2idx: " + str(self.label2idx)) 253 | 254 | def use_iobes(self, insts): 255 | for inst in insts: 256 | output = inst.output 257 | for pos in range(len(inst)): 258 | curr_entity = output[pos] 259 | if pos == len(inst) - 1: 260 | if curr_entity.startswith(self.B): 261 | output[pos] = curr_entity.replace(self.B, self.S) 262 | elif curr_entity.startswith(self.I): 263 | output[pos] = curr_entity.replace(self.I, self.E) 264 | else: 265 | next_entity = output[pos + 1] 266 | if curr_entity.startswith(self.B): 267 | if next_entity.startswith(self.O) or next_entity.startswith(self.B): 
268 | output[pos] = curr_entity.replace(self.B, self.S) 269 | elif curr_entity.startswith(self.I): 270 | if next_entity.startswith(self.O) or next_entity.startswith(self.B): 271 | output[pos] = curr_entity.replace(self.I, self.E) 272 | 273 | def map_insts_ids(self, insts: List[Instance]): 274 | insts_ids = [] 275 | for inst in insts: 276 | words = inst.input.words 277 | inst.word_ids = [] 278 | inst.char_ids = [] 279 | inst.dep_label_ids = [] 280 | inst.dep_head_ids = [] 281 | inst.output_ids = [] 282 | for word in words: 283 | if word in self.word2idx: 284 | inst.word_ids.append(self.word2idx[word]) 285 | else: 286 | inst.word_ids.append(self.word2idx[self.UNK]) 287 | char_id = [] 288 | for c in word: 289 | if c in self.char2idx: 290 | char_id.append(self.char2idx[c]) 291 | else: 292 | char_id.append(self.char2idx[self.UNK]) 293 | inst.char_ids.append(char_id) 294 | for i, head in enumerate(inst.input.heads): 295 | if head == -1: 296 | inst.dep_head_ids.append(i) ## appended it self. 297 | else: 298 | inst.dep_head_ids.append(head) 299 | for label in inst.input.dep_labels: 300 | inst.dep_label_ids.append(self.deplabel2idx[label]) 301 | for label in inst.output: 302 | inst.output_ids.append(self.label2idx[label]) 303 | insts_ids.append([inst.word_ids, inst.char_ids, inst.output_ids]) 304 | return insts_ids 305 | -------------------------------------------------------------------------------- /config/eval.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | from typing import Tuple 4 | 5 | from collections import defaultdict 6 | class Span: 7 | 8 | def __init__(self, left, right, type): 9 | self.left = left 10 | self.right = right 11 | self.type = type 12 | 13 | def __eq__(self, other): 14 | return self.left == other.left and self.right == other.right and self.type == other.type 15 | 16 | def __hash__(self): 17 | return hash((self.left, self.right, self.type)) 18 | 19 | def to_str(self, sent): 20 | 
return str(sent[self.left: (self.right+1)]) + ","+self.type 21 | 22 | ## the input to the evaluation should already have 23 | ## have the predictions which is the label. 24 | ## iobest tagging scheme 25 | ### NOTE: this function is used to evaluate the instances with prediction ready. 26 | def evaluate(insts): 27 | 28 | p = 0 29 | total_entity = 0 30 | total_predict = 0 31 | 32 | batch_p_dict = defaultdict(int) 33 | batch_total_entity_dict = defaultdict(int) 34 | batch_total_predict_dict = defaultdict(int) 35 | 36 | for inst in insts: 37 | 38 | output = inst.output 39 | prediction = inst.prediction 40 | #convert to span 41 | output_spans = set() 42 | start = -1 43 | for i in range(len(output)): 44 | if output[i].startswith("B-"): 45 | start = i 46 | if output[i].startswith("E-"): 47 | end = i 48 | output_spans.add(Span(start, end, output[i][2:])) 49 | batch_total_entity_dict[output[i][2:]] += 1 50 | if output[i].startswith("S-"): 51 | output_spans.add(Span(i, i, output[i][2:])) 52 | batch_total_entity_dict[output[i][2:]] += 1 53 | start = -1 54 | predict_spans = set() 55 | for i in range(len(prediction)): 56 | if prediction[i].startswith("B-"): 57 | start = i 58 | if prediction[i].startswith("E-"): 59 | end = i 60 | predict_spans.add(Span(start, end, prediction[i][2:])) 61 | batch_total_predict_dict[prediction[i][2:]] += 1 62 | if prediction[i].startswith("S-"): 63 | predict_spans.add(Span(i, i, prediction[i][2:])) 64 | batch_total_predict_dict[prediction[i][2:]] += 1 65 | 66 | total_entity += len(output_spans) 67 | total_predict += len(predict_spans) 68 | correct_spans = predict_spans.intersection(output_spans) 69 | p += len(correct_spans) 70 | for span in correct_spans: 71 | batch_p_dict[span.type] += 1 72 | 73 | for key in batch_total_entity_dict: 74 | precision_key, recall_key, fscore_key = get_metric(batch_p_dict[key], batch_total_entity_dict[key], batch_total_predict_dict[key]) 75 | print("[%s] Prec.: %.2f, Rec.: %.2f, F1: %.2f" % (key, precision_key, 
recall_key, fscore_key)) 76 | 77 | precision = p * 1.0 / total_predict * 100 if total_predict != 0 else 0 78 | recall = p * 1.0 / total_entity * 100 if total_entity != 0 else 0 79 | fscore = 2.0 * precision * recall / (precision + recall) if precision != 0 or recall != 0 else 0 80 | 81 | return [precision, recall, fscore] 82 | 83 | def get_metric(p_num: int, total_num: int, total_predicted_num: int) -> Tuple[float, float, float]: 84 | """ 85 | Return the metrics of precision, recall and f-score, based on the number 86 | (We make this small piece of function in order to reduce the code effort and less possible to have typo error) 87 | :param p_num: 88 | :param total_num: 89 | :param total_predicted_num: 90 | :return: 91 | """ 92 | precision = p_num * 1.0 / total_predicted_num * 100 if total_predicted_num != 0 else 0 93 | recall = p_num * 1.0 / total_num * 100 if total_num != 0 else 0 94 | fscore = 2.0 * precision * recall / (precision + recall) if precision != 0 or recall != 0 else 0 95 | return precision, recall, fscore 96 | 97 | 98 | 99 | def evaluate_num(batch_insts, batch_pred_ids, batch_gold_ids, word_seq_lens, idx2label): 100 | """ 101 | evaluate the batch of instances 102 | :param batch_insts: 103 | :param batch_pred_ids: 104 | :param batch_gold_ids: 105 | :param word_seq_lens: 106 | :param idx2label: 107 | :return: 108 | """ 109 | p = 0 110 | total_entity = 0 111 | total_predict = 0 112 | word_seq_lens = word_seq_lens.tolist() 113 | for idx in range(len(batch_pred_ids)): 114 | length = word_seq_lens[idx] 115 | output = batch_gold_ids[idx][:length].tolist() 116 | prediction = batch_pred_ids[idx][:length].tolist() 117 | prediction = prediction[::-1] 118 | output = [idx2label[l] for l in output] 119 | prediction =[idx2label[l] for l in prediction] 120 | batch_insts[idx].prediction = prediction 121 | #convert to span 122 | output_spans = set() 123 | start = -1 124 | for i in range(len(output)): 125 | if output[i].startswith("B-"): 126 | start = i 127 | if 
output[i].startswith("E-"): 128 | end = i 129 | output_spans.add(Span(start, end, output[i][2:])) 130 | if output[i].startswith("S-"): 131 | output_spans.add(Span(i, i, output[i][2:])) 132 | predict_spans = set() 133 | for i in range(len(prediction)): 134 | if prediction[i].startswith("B-"): 135 | start = i 136 | if prediction[i].startswith("E-"): 137 | end = i 138 | predict_spans.add(Span(start, end, prediction[i][2:])) 139 | if prediction[i].startswith("S-"): 140 | predict_spans.add(Span(i, i, prediction[i][2:])) 141 | 142 | total_entity += len(output_spans) 143 | total_predict += len(predict_spans) 144 | p += len(predict_spans.intersection(output_spans)) 145 | 146 | # precision = p * 1.0 / total_predict * 100 if total_predict != 0 else 0 147 | # recall = p * 1.0 / total_entity * 100 if total_entity != 0 else 0 148 | # fscore = 2.0 * precision * recall / (precision + recall) if precision != 0 or recall != 0 else 0 149 | 150 | return np.asarray([p, total_predict, total_entity], dtype=int) -------------------------------------------------------------------------------- /config/reader.py: -------------------------------------------------------------------------------- 1 | # 2 | # @author: Allan 3 | # 4 | 5 | from tqdm import tqdm 6 | from common.sentence import Sentence 7 | from common.instance import Instance 8 | from typing import List 9 | import re 10 | import pickle 11 | 12 | class Reader: 13 | 14 | 15 | def __init__(self, digit2zero:bool=True): 16 | self.digit2zero = digit2zero 17 | self.vocab = set() 18 | 19 | def read_conll(self, file: str, number: int = -1, is_train: bool = True) -> List[Instance]: 20 | print("Reading file: " + file) 21 | insts = [] 22 | num_entity = 0 23 | # vocab = set() ## build the vocabulary 24 | find_root = False 25 | with open(file, 'r', encoding='utf-8') as f: 26 | words = [] 27 | heads = [] 28 | deps = [] 29 | labels = [] 30 | tags = [] 31 | for line in tqdm(f.readlines()): 32 | line = line.rstrip() 33 | if line == "": 34 | 
class Reader:
    """Reads CoNLL-format NER data (with or without dependency columns)."""

    def __init__(self, digit2zero: bool = True):
        """
        :param digit2zero: if True, every digit inside a token is replaced by '0'.
        """
        self.digit2zero = digit2zero
        self.vocab = set()  # vocabulary collected over every file read

    def read_conll(self, file: str, number: int = -1, is_train: bool = True) -> List[Instance]:
        """
        Read a dependency-annotated CoNLL-X file.

        Columns used: vals[1]=word, vals[3]=POS tag, vals[6]=head index
        (1-based, 0 = root), vals[7]=dependency label, vals[10]=entity label.
        :param number: stop after this many sentences (-1 reads all)
        :raises ValueError: if a sentence contains more than one root
        """
        print("Reading file: " + file)
        insts = []
        num_entity = 0
        find_root = False
        with open(file, 'r', encoding='utf-8') as f:
            words = []
            heads = []
            deps = []
            labels = []
            tags = []
            for line in tqdm(f.readlines()):
                line = line.rstrip()
                if line == "":
                    # blank line terminates a sentence
                    insts.append(Instance(Sentence(words, heads, deps, tags), labels))
                    words = []
                    heads = []
                    deps = []
                    labels = []
                    tags = []
                    find_root = False
                    if len(insts) == number:
                        break
                    continue
                vals = line.split()
                word = vals[1]
                head = int(vals[6])
                dep_label = vals[7]
                pos = vals[3]
                label = vals[10]
                if self.digit2zero:
                    word = re.sub(r'\d', '0', word)  # replace digit with 0.
                words.append(word)
                if head == 0:
                    # BUG FIX: the duplicate-root guard previously raised the
                    # undefined name `err` and `find_root` was never set True,
                    # so the check could not fire at all.
                    if find_root:
                        raise ValueError("already have a root")
                    find_root = True
                heads.append(head - 1)  ## because of 0-indexed; root becomes -1
                deps.append(dep_label)
                tags.append(pos)
                self.vocab.add(word)
                labels.append(label)
                if label.startswith("B-"):
                    num_entity += 1
            if words and len(insts) != number:
                # BUG FIX: keep the final sentence when the file has no
                # trailing blank line (it used to be silently dropped).
                insts.append(Instance(Sentence(words, heads, deps, tags), labels))
        print("number of sentences: {}, number of entities: {}".format(len(insts), num_entity))
        return insts

    def read_txt(self, file: str, number: int = -1, is_train: bool = True) -> List[Instance]:
        """
        Read a plain (no dependency) column file; CoNLL-2003 files have
        3 columns, otherwise the CoNLL-X column layout is assumed.
        :param number: stop after this many sentences (-1 reads all)
        """
        print("Reading file: " + file)
        insts = []
        with open(file, 'r', encoding='utf-8') as f:
            words = []
            labels = []
            tags = []
            for line in tqdm(f.readlines()):
                line = line.rstrip()
                if line == "":
                    insts.append(Instance(Sentence(words, None, None, tags), labels))
                    words = []
                    labels = []
                    tags = []
                    if len(insts) == number:
                        break
                    continue
                if "conll2003" in file:
                    word, pos, label = line.split()
                else:
                    vals = line.split()
                    word = vals[1]
                    pos = vals[3]
                    label = vals[10]
                if self.digit2zero:
                    word = re.sub(r'\d', '0', word)  # replace digit with 0.
                words.append(word)
                tags.append(pos)
                self.vocab.add(word)
                labels.append(label)
            if words and len(insts) != number:
                # BUG FIX: flush the final sentence if no trailing blank line.
                insts.append(Instance(Sentence(words, None, None, tags), labels))
        print("number of sentences: {}".format(len(insts)))
        return insts

    def load_elmo_vec(self, file, insts):
        """
        Attach precomputed contextual vectors (a pickled list of arrays,
        one array of shape (sent_len, emb_size) per instance) to ``insts``.
        :return: the embedding size (second dimension of the vectors)
        """
        with open(file, 'rb') as f:  # context manager instead of manual close
            all_vecs = pickle.load(f)  # variables come out in the order you put them in
        size = 0
        for vec, inst in zip(all_vecs, insts):
            inst.elmo_vec = vec
            size = vec.shape[1]
            # one vector row per token is required
            assert vec.shape[0] == len(inst.input.words)
        return size
word_seq_len = torch.LongTensor(list(map(lambda inst: len(inst.input.words), batch_data))) 44 | max_seq_len = word_seq_len.max() 45 | ### NOTE: the 1 here might be used later?? We will make this as padding, because later we have to do a deduction. 46 | #### Use 1 here because the CharBiLSTM accepts 47 | char_seq_len = torch.LongTensor([list(map(len, inst.input.words)) + [1] * (int(max_seq_len) - len(inst.input.words)) for inst in batch_data]) 48 | max_char_seq_len = char_seq_len.max() 49 | 50 | word_emb_tensor = None 51 | if config.context_emb != ContextEmb.none: 52 | emb_size = insts[0].elmo_vec.shape[1] 53 | word_emb_tensor = torch.zeros((batch_size, max_seq_len, emb_size)) 54 | 55 | word_seq_tensor = torch.zeros((batch_size, max_seq_len), dtype=torch.long) 56 | label_seq_tensor = torch.zeros((batch_size, max_seq_len), dtype=torch.long) 57 | char_seq_tensor = torch.zeros((batch_size, max_seq_len, max_char_seq_len), dtype=torch.long) 58 | adjs = None 59 | adjs_in = None 60 | adjs_out = None 61 | dep_label_adj = None 62 | dep_label_tensor = None 63 | batch_dep_heads = None 64 | trees = None 65 | graphs = None 66 | if config.dep_model != DepModelType.none: 67 | if config.dep_model == DepModelType.dggcn: 68 | adjs = [ head_to_adj(max_seq_len, inst, config) for inst in batch_data] 69 | adjs = np.stack(adjs, axis=0) 70 | adjs = torch.from_numpy(adjs) 71 | dep_label_adj = [head_to_adj_label(max_seq_len, inst, config) for inst in batch_data] 72 | dep_label_adj = torch.from_numpy(np.stack(dep_label_adj, axis=0)).long() 73 | 74 | if config.dep_model == DepModelType.dglstm: 75 | batch_dep_heads = torch.zeros((batch_size, max_seq_len), dtype=torch.long) 76 | dep_label_tensor = torch.zeros((batch_size, max_seq_len), dtype=torch.long) 77 | # trees = [inst.tree for inst in batch_data] 78 | for idx in range(batch_size): 79 | word_seq_tensor[idx, :word_seq_len[idx]] = torch.LongTensor(batch_data[idx].word_ids) 80 | label_seq_tensor[idx, :word_seq_len[idx]] = 
torch.LongTensor(batch_data[idx].output_ids) 81 | if config.context_emb != ContextEmb.none: 82 | word_emb_tensor[idx, :word_seq_len[idx], :] = torch.from_numpy(batch_data[idx].elmo_vec) 83 | 84 | if config.dep_model == DepModelType.dglstm: 85 | batch_dep_heads[idx, :word_seq_len[idx]] = torch.LongTensor(batch_data[idx].dep_head_ids) 86 | dep_label_tensor[idx, :word_seq_len[idx]] = torch.LongTensor(batch_data[idx].dep_label_ids) 87 | for word_idx in range(word_seq_len[idx]): 88 | char_seq_tensor[idx, word_idx, :char_seq_len[idx, word_idx]] = torch.LongTensor(batch_data[idx].char_ids[word_idx]) 89 | for wordIdx in range(word_seq_len[idx], max_seq_len): 90 | char_seq_tensor[idx, wordIdx, 0: 1] = torch.LongTensor([config.char2idx[PAD]]) ###because line 119 makes it 1, every single character should have a id. but actually 0 is enough 91 | 92 | ### NOTE: make this step during forward if you have limited GPU resource. 93 | word_seq_tensor = word_seq_tensor.to(config.device) 94 | label_seq_tensor = label_seq_tensor.to(config.device) 95 | char_seq_tensor = char_seq_tensor.to(config.device) 96 | word_seq_len = word_seq_len.to(config.device) 97 | char_seq_len = char_seq_len.to(config.device) 98 | if config.dep_model != DepModelType.none: 99 | if config.dep_model == DepModelType.dglstm: 100 | batch_dep_heads = batch_dep_heads.to(config.device) 101 | dep_label_tensor = dep_label_tensor.to(config.device) 102 | 103 | return word_seq_tensor, word_seq_len, word_emb_tensor, char_seq_tensor, char_seq_len, adjs, adjs_in, adjs_out, graphs, dep_label_adj, batch_dep_heads, trees, label_seq_tensor, dep_label_tensor 104 | 105 | 106 | 107 | def lr_decay(config, optimizer, epoch): 108 | lr = config.learning_rate / (1 + config.lr_decay * (epoch - 1)) 109 | for param_group in optimizer.param_groups: 110 | param_group['lr'] = lr 111 | print('learning rate is set to: ', lr) 112 | return optimizer 113 | 114 | 115 | 116 | def head_to_adj(max_len, inst, config): 117 | """ 118 | Convert a tree 
object to an (numpy) adjacency matrix. 119 | """ 120 | directed = config.adj_directed 121 | self_loop = False #config.adj_self_loop 122 | ret = np.zeros((max_len, max_len), dtype=np.float32) 123 | 124 | for i, head in enumerate(inst.input.heads): 125 | if head == -1: 126 | continue 127 | ret[head, i] = 1 128 | 129 | if not directed: 130 | ret = ret + ret.T 131 | 132 | if self_loop: 133 | for i in range(len(inst.input.words)): 134 | ret[i, i] = 1 135 | 136 | return ret 137 | 138 | 139 | def head_to_adj_label(max_len, inst, config): 140 | """ 141 | Convert a tree object to an (numpy) adjacency matrix. 142 | """ 143 | directed = config.adj_directed 144 | self_loop = config.adj_self_loop 145 | 146 | dep_label_ret = np.zeros((max_len, max_len), dtype=np.long) 147 | 148 | for i, head in enumerate(inst.input.heads): 149 | if head == -1: 150 | continue 151 | dep_label_ret[head, i] = inst.dep_label_ids[i] 152 | 153 | if not directed: 154 | dep_label_ret = dep_label_ret + dep_label_ret.T 155 | 156 | if self_loop: 157 | for i in range(len(inst.input.words)): 158 | dep_label_ret[i, i] = config.root_dep_label_id 159 | 160 | return dep_label_ret 161 | 162 | 163 | def get_spans(output): 164 | output_spans = set() 165 | start = -1 166 | for i in range(len(output)): 167 | if output[i].startswith("B-"): 168 | start = i 169 | if output[i].startswith("E-"): 170 | end = i 171 | output_spans.add(Span(start, end, output[i][2:])) 172 | if output[i].startswith("S-"): 173 | output_spans.add(Span(i, i, output[i][2:])) 174 | return output_spans 175 | 176 | def preprocess(conf, insts, file_type:str): 177 | print("[Preprocess Info]Doing preprocessing for the CoNLL-2003 dataset: {}.".format(file_type)) 178 | for inst in insts: 179 | output = inst.output 180 | spans = get_spans(output) 181 | for span in spans: 182 | if span.right - span.left + 1 < 2: 183 | continue 184 | count_dep = 0 185 | for i in range(span.left, span.right + 1): 186 | if inst.input.heads[i] >= span.left and 
inst.input.heads[i] <= span.right: 187 | count_dep += 1 188 | if count_dep != (span.right - span.left): 189 | 190 | for i in range(span.left, span.right + 1): 191 | if inst.input.heads[i] < span.left or inst.input.heads[i] > span.right: 192 | if i != span.right: 193 | inst.input.heads[i] = span.right 194 | inst.input.dep_labels[i] = "nn" if "sd" in conf.affix else "compound" -------------------------------------------------------------------------------- /data/readme.txt: -------------------------------------------------------------------------------- 1 | In terms of the OntoNotes English dataset: 2 | This is the standard train/dev/test split 3 | 4 | train dev are taken from the conll2012-processed 5 | test is taken from the pradhan-processed 6 | 7 | train: 1,088,503 tokens, 81,828 entities 8 | dev: 147,724 tokens, 11,066 entities 9 | test: 152,728 tokens, 11,257 entities 10 | 11 | The thing that we did this is because of fair comparison with previous work. 12 | 13 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | 2 | import argparse 3 | import random 4 | import numpy as np 5 | from config.reader import Reader 6 | from config import eval 7 | from config.config import Config, ContextEmb, DepModelType 8 | import time 9 | from model.lstmcrf import NNCRF 10 | import torch 11 | import torch.optim as optim 12 | import torch.nn as nn 13 | from config.utils import lr_decay, simple_batching, get_spans, preprocess 14 | from typing import List 15 | from common.instance import Instance 16 | from termcolor import colored 17 | import os 18 | 19 | 20 | def setSeed(opt, seed): 21 | random.seed(seed) 22 | np.random.seed(seed) 23 | torch.manual_seed(seed) 24 | if opt.device.startswith("cuda"): 25 | print("using GPU...", torch.cuda.current_device()) 26 | torch.cuda.manual_seed(seed) 27 | torch.cuda.manual_seed_all(seed) 28 | 29 | 30 | def 
def parse_arguments(parser):
    """
    Register all command-line options, parse them, echo every value,
    and return the parsed namespace.
    """
    ###Training Hyperparameters
    parser.add_argument('--mode', type=str, default='train')  # "train" or anything else for test-only
    parser.add_argument('--device', type=str, default="cpu")
    parser.add_argument('--seed', type=int, default=42)
    # NOTE(review): action="store_true" with default=True means this flag can
    # never be turned off from the CLI; consider type=int choices=[0, 1] like
    # the other boolean options (changing it would alter the CLI interface).
    parser.add_argument('--digit2zero', action="store_true", default=True)
    parser.add_argument('--dataset', type=str, default="ontonotes")
    parser.add_argument('--affix', type=str, default="sd")  # dependency flavor suffix of the data files
    parser.add_argument('--embedding_file', type=str, default="data/glove.6B.100d.txt")
    parser.add_argument('--embedding_dim', type=int, default=100)
    parser.add_argument('--optimizer', type=str, default="sgd")
    parser.add_argument('--learning_rate', type=float, default=0.01) ##only for sgd now
    parser.add_argument('--momentum', type=float, default=0.0)
    parser.add_argument('--l2', type=float, default=1e-8)
    parser.add_argument('--lr_decay', type=float, default=0)
    parser.add_argument('--batch_size', type=int, default=10)
    parser.add_argument('--num_epochs', type=int, default=100)
    # -1 means "use all instances"
    parser.add_argument('--train_num', type=int, default=-1)
    parser.add_argument('--dev_num', type=int, default=-1)
    parser.add_argument('--test_num', type=int, default=-1)
    parser.add_argument('--eval_freq', type=int, default=4000, help="evaluate frequency (iteration)")
    parser.add_argument('--eval_epoch', type=int, default=0, help="evaluate the dev set after this number of epoch")

    ## model hyperparameter
    parser.add_argument('--hidden_dim', type=int, default=200, help="hidden size of the LSTM")
    parser.add_argument('--num_lstm_layer', type=int, default=1, help="number of lstm layers")
    parser.add_argument('--dep_emb_size', type=int, default=50, help="embedding size of dependency")
    parser.add_argument('--dep_hidden_dim', type=int, default=200, help="hidden size of gcn, tree lstm")

    ### NOTE: GCN parameters, useless if we are not using GCN
    parser.add_argument('--num_gcn_layers', type=int, default=1, help="number of gcn layers")
    parser.add_argument('--gcn_mlp_layers', type=int, default=1, help="number of mlp layers after gcn")
    parser.add_argument('--gcn_dropout', type=float, default=0.5, help="GCN dropout")
    parser.add_argument('--gcn_adj_directed', type=int, default=0, choices=[0, 1], help="GCN ajacent matrix directed")
    parser.add_argument('--gcn_adj_selfloop', type=int, default=0, choices=[0, 1], help="GCN selfloop in adjacent matrix, now always false as add it in the model")
    parser.add_argument('--gcn_gate', type=int, default=0, choices=[0, 1], help="add edge_wise gating")

    ##NOTE: this dropout applies to many places
    parser.add_argument('--dropout', type=float, default=0.5, help="dropout for embedding")
    parser.add_argument('--use_char_rnn', type=int, default=1, choices=[0, 1], help="use character-level lstm, 0 or 1")
    parser.add_argument('--dep_model', type=str, default="none", choices=["none", "dggcn", "dglstm"], help="dependency method")
    parser.add_argument('--inter_func', type=str, default="mlp", choices=["concatenation", "addition", "mlp"], help="combination method, 0 concat, 1 additon, 2 gcn, 3 more parameter gcn")
    parser.add_argument('--context_emb', type=str, default="none", choices=["none", "bert", "elmo", "flair"], help="contextual word embedding")

    args = parser.parse_args()
    # echo the full configuration for reproducibility
    for k in args.__dict__:
        print(k + ": " + str(args.__dict__[k]))
    return args
lr=config.learning_rate, weight_decay=float(config.l2)) 90 | elif config.optimizer.lower() == "adam": 91 | print(colored("Using Adam", 'yellow')) 92 | return optim.Adam(params) 93 | else: 94 | print("Illegal optimizer: {}".format(config.optimizer)) 95 | exit(1) 96 | 97 | def batching_list_instances(config: Config, insts:List[Instance]): 98 | train_num = len(insts) 99 | batch_size = config.batch_size 100 | total_batch = train_num // batch_size + 1 if train_num % batch_size != 0 else train_num // batch_size 101 | batched_data = [] 102 | for batch_id in range(total_batch): 103 | one_batch_insts = insts[batch_id * batch_size:(batch_id + 1) * batch_size] 104 | batched_data.append(simple_batching(config, one_batch_insts)) 105 | 106 | return batched_data 107 | 108 | def learn_from_insts(config:Config, epoch: int, train_insts, dev_insts, test_insts): 109 | # train_insts: List[Instance], dev_insts: List[Instance], test_insts: List[Instance], batch_size: int = 1 110 | model = NNCRF(config) 111 | optimizer = get_optimizer(config, model) 112 | train_num = len(train_insts) 113 | print("number of instances: %d" % (train_num)) 114 | print(colored("[Shuffled] Shuffle the training instance ids", "red")) 115 | random.shuffle(train_insts) 116 | 117 | 118 | 119 | batched_data = batching_list_instances(config, train_insts) 120 | dev_batches = batching_list_instances(config, dev_insts) 121 | test_batches = batching_list_instances(config, test_insts) 122 | 123 | best_dev = [-1, 0] 124 | best_test = [-1, 0] 125 | 126 | dep_model_name = config.dep_model.name 127 | if config.dep_model == DepModelType.dggcn: 128 | dep_model_name += '(' + str(config.num_gcn_layers) + "," + str(config.gcn_dropout) + "," + str( 129 | config.gcn_mlp_layers) + ")" 130 | model_name = "model_files/lstm_{}_{}_crf_{}_{}_{}_dep_{}_elmo_{}_{}_gate_{}_epoch_{}_lr_{}_comb_{}.m".format(config.num_lstm_layer, config.hidden_dim, config.dataset, config.affix, config.train_num, dep_model_name, config.context_emb.name, 
config.optimizer.lower(), config.edge_gate, epoch, config.learning_rate, config.interaction_func) 131 | res_name = "results/lstm_{}_{}_crf_{}_{}_{}_dep_{}_elmo_{}_{}_gate_{}_epoch_{}_lr_{}_comb_{}.results".format(config.num_lstm_layer, config.hidden_dim, config.dataset, config.affix, config.train_num, dep_model_name, config.context_emb.name, config.optimizer.lower(), config.edge_gate, epoch, config.learning_rate, config.interaction_func) 132 | print("[Info] The model will be saved to: %s, please ensure models folder exist" % (model_name)) 133 | if not os.path.exists("model_files"): 134 | os.makedirs("model_files") 135 | if not os.path.exists("results"): 136 | os.makedirs("results") 137 | 138 | for i in range(1, epoch + 1): 139 | epoch_loss = 0 140 | start_time = time.time() 141 | model.zero_grad() 142 | if config.optimizer.lower() == "sgd": 143 | optimizer = lr_decay(config, optimizer, i) 144 | for index in np.random.permutation(len(batched_data)): 145 | # for index in range(len(batched_data)): 146 | model.train() 147 | batch_word, batch_wordlen, batch_context_emb, batch_char, batch_charlen, adj_matrixs, adjs_in, adjs_out, graphs, dep_label_adj, batch_dep_heads, trees, batch_label, batch_dep_label = batched_data[index] 148 | loss = model.neg_log_obj(batch_word, batch_wordlen, batch_context_emb,batch_char, batch_charlen, adj_matrixs, adjs_in, adjs_out, graphs, dep_label_adj, batch_dep_heads, batch_label, batch_dep_label, trees) 149 | epoch_loss += loss.item() 150 | loss.backward() 151 | if config.dep_model == DepModelType.dggcn: 152 | torch.nn.utils.clip_grad_norm_(model.parameters(), config.clip) ##clipping the gradient 153 | optimizer.step() 154 | model.zero_grad() 155 | 156 | end_time = time.time() 157 | print("Epoch %d: %.5f, Time is %.2fs" % (i, epoch_loss, end_time - start_time), flush=True) 158 | 159 | if i + 1 >= config.eval_epoch: 160 | model.eval() 161 | dev_metrics = evaluate(config, model, dev_batches, "dev", dev_insts) 162 | test_metrics = 
evaluate(config, model, test_batches, "test", test_insts) 163 | if dev_metrics[2] > best_dev[0]: 164 | print("saving the best model...") 165 | best_dev[0] = dev_metrics[2] 166 | best_dev[1] = i 167 | best_test[0] = test_metrics[2] 168 | best_test[1] = i 169 | torch.save(model.state_dict(), model_name) 170 | write_results(res_name, test_insts) 171 | model.zero_grad() 172 | 173 | print("The best dev: %.2f" % (best_dev[0])) 174 | print("The corresponding test: %.2f" % (best_test[0])) 175 | print("Final testing.") 176 | model.load_state_dict(torch.load(model_name)) 177 | model.eval() 178 | evaluate(config, model, test_batches, "test", test_insts) 179 | write_results(res_name, test_insts) 180 | 181 | 182 | 183 | def evaluate(config:Config, model: NNCRF, batch_insts_ids, name:str, insts: List[Instance]): 184 | ## evaluation 185 | metrics = np.asarray([0, 0, 0], dtype=int) 186 | batch_id = 0 187 | batch_size = config.batch_size 188 | for batch in batch_insts_ids: 189 | one_batch_insts = insts[batch_id * batch_size:(batch_id + 1) * batch_size] 190 | sorted_batch_insts = sorted(one_batch_insts, key=lambda inst: len(inst.input.words), reverse=True) 191 | batch_max_scores, batch_max_ids = model.decode(batch) 192 | metrics += eval.evaluate_num(sorted_batch_insts, batch_max_ids, batch[-2], batch[1], config.idx2labels) 193 | batch_id += 1 194 | p, total_predict, total_entity = metrics[0], metrics[1], metrics[2] 195 | precision = p * 1.0 / total_predict * 100 if total_predict != 0 else 0 196 | recall = p * 1.0 / total_entity * 100 if total_entity != 0 else 0 197 | fscore = 2.0 * precision * recall / (precision + recall) if precision != 0 or recall != 0 else 0 198 | print("[%s set] Precision: %.2f, Recall: %.2f, F1: %.2f" % (name, precision, recall,fscore), flush=True) 199 | return [precision, recall, fscore] 200 | 201 | 202 | def test_model(config: Config, test_insts): 203 | dep_model_name = config.dep_model.name 204 | if config.dep_model == DepModelType.dggcn: 205 | 
dep_model_name += '(' + str(config.num_gcn_layers) + ","+str(config.gcn_dropout)+ ","+str(config.gcn_mlp_layers)+")" 206 | model_name = "model_files/lstm_{}_{}_crf_{}_{}_{}_dep_{}_elmo_{}_{}_gate_{}_epoch_{}_lr_{}_comb_{}.m".format(config.num_lstm_layer, config.hidden_dim, 207 | config.dataset, config.affix, 208 | config.train_num, 209 | dep_model_name, 210 | config.context_emb.name, 211 | config.optimizer.lower(), 212 | config.edge_gate, 213 | config.num_epochs, 214 | config.learning_rate, config.interaction_func) 215 | res_name = "results/lstm_{}_{}_crf_{}_{}_{}_dep_{}_elmo_{}_{}_gate_{}_epoch_{}_lr_{}_comb_{}.results".format(config.num_lstm_layer, config.hidden_dim, 216 | config.dataset, config.affix, 217 | config.train_num, 218 | dep_model_name, 219 | config.context_emb.name, 220 | config.optimizer.lower(), 221 | config.edge_gate, 222 | config.num_epochs, 223 | config.learning_rate, config.interaction_func) 224 | model = NNCRF(config) 225 | model.load_state_dict(torch.load(model_name)) 226 | model.eval() 227 | test_batches = batching_list_instances(config, test_insts) 228 | evaluate(config, model, test_batches, "test", test_insts) 229 | write_results(res_name, test_insts) 230 | 231 | def write_results(filename:str, insts): 232 | f = open(filename, 'w', encoding='utf-8') 233 | for inst in insts: 234 | for i in range(len(inst.input)): 235 | words = inst.input.words 236 | tags = inst.input.pos_tags 237 | heads = inst.input.heads 238 | dep_labels = inst.input.dep_labels 239 | output = inst.output 240 | prediction = inst.prediction 241 | assert len(output) == len(prediction) 242 | f.write("{}\t{}\t{}\t{}\t{}\t{}\t{}\n".format(i, words[i], tags[i], heads[i], dep_labels[i], output[i], prediction[i])) 243 | f.write("\n") 244 | f.close() 245 | 246 | 247 | 248 | 249 | 250 | 251 | def main(): 252 | parser = argparse.ArgumentParser(description="Dependency-Guided LSTM CRF implementation") 253 | opt = parse_arguments(parser) 254 | conf = Config(opt) 255 | 256 | reader = 
Reader(conf.digit2zero) 257 | setSeed(opt, conf.seed) 258 | 259 | trains = reader.read_conll(conf.train_file, -1, True) 260 | devs = reader.read_conll(conf.dev_file, conf.dev_num, False) 261 | tests = reader.read_conll(conf.test_file, conf.test_num, False) 262 | 263 | if conf.context_emb != ContextEmb.none: 264 | print('Loading the {} vectors for all datasets.'.format(conf.context_emb.name)) 265 | conf.context_emb_size = reader.load_elmo_vec(conf.train_file.replace(".sd", "").replace(".ud", "").replace(".sud", "").replace(".predsd", "").replace(".predud", "").replace(".stud", "").replace(".ssd", "") + "."+conf.context_emb.name+".vec", trains) 266 | reader.load_elmo_vec(conf.dev_file.replace(".sd", "").replace(".ud", "").replace(".sud", "").replace(".predsd", "").replace(".predud", "").replace(".stud", "").replace(".ssd", "") + "."+conf.context_emb.name+".vec", devs) 267 | reader.load_elmo_vec(conf.test_file.replace(".sd", "").replace(".ud", "").replace(".sud", "").replace(".predsd", "").replace(".predud", "").replace(".stud", "").replace(".ssd", "") + "."+conf.context_emb.name+".vec", tests) 268 | 269 | conf.use_iobes(trains + devs + tests) 270 | conf.build_label_idx(trains) 271 | 272 | conf.build_deplabel_idx(trains + devs + tests) 273 | print("# deplabels: ", len(conf.deplabels)) 274 | print("dep label 2idx: ", conf.deplabel2idx) 275 | 276 | 277 | conf.build_word_idx(trains, devs, tests) 278 | conf.build_emb_table() 279 | conf.map_insts_ids(trains + devs + tests) 280 | 281 | 282 | print("num chars: " + str(conf.num_char)) 283 | # print(str(config.char2idx)) 284 | 285 | print("num words: " + str(len(conf.word2idx))) 286 | # print(config.word2idx) 287 | if opt.mode == "train": 288 | if conf.train_num != -1: 289 | random.shuffle(trains) 290 | trains = trains[:conf.train_num] 291 | learn_from_insts(conf, conf.num_epochs, trains, devs, tests) 292 | else: 293 | ## Load the trained model. 
class CharBiLSTM(nn.Module):
    """
    Character-level LSTM that produces one fixed-size vector per word
    (the LSTM's last hidden state over the word's characters).
    """

    def __init__(self, config):
        super(CharBiLSTM, self).__init__()
        print("[Info] Building character-level LSTM")
        self.char_emb_size = config.char_emb_size
        self.char2idx = config.char2idx
        self.chars = config.idx2char
        self.char_size = len(self.chars)
        self.device = config.device
        self.hidden = config.charlstm_hidden_dim
        self.dropout = nn.Dropout(config.dropout).to(self.device)
        self.char_embeddings = nn.Embedding(self.char_size, self.char_emb_size)
        self.char_embeddings = self.char_embeddings.to(self.device)

        # NOTE(review): despite the class name, the LSTM is built with
        # bidirectional=False, so the output size is self.hidden (not 2x).
        self.char_lstm = nn.LSTM(self.char_emb_size, self.hidden ,num_layers=1, batch_first=True, bidirectional=False).to(self.device)

    def get_last_hiddens(self, char_seq_tensor, char_seq_len):
        """
        input:
            char_seq_tensor: (batch_size, sent_len, word_length)
            char_seq_len: (batch_size, sent_len)
        output:
            Variable(batch_size, sent_len, char_hidden_dim )
        """
        batch_size = char_seq_tensor.size(0)
        sent_len = char_seq_tensor.size(1)
        # flatten words across the batch so every word is one LSTM sequence
        char_seq_tensor = char_seq_tensor.view(batch_size * sent_len, -1)
        char_seq_len = char_seq_len.view(batch_size * sent_len)
        # pack_padded_sequence requires sequences sorted by decreasing length;
        # recover_idx restores the original order afterwards
        sorted_seq_len, permIdx = char_seq_len.sort(0, descending=True)
        _, recover_idx = permIdx.sort(0, descending=False)
        sorted_seq_tensor = char_seq_tensor[permIdx]

        char_embeds = self.dropout(self.char_embeddings(sorted_seq_tensor))
        pack_input = pack_padded_sequence(char_embeds, sorted_seq_len, batch_first=True)

        _, char_hidden = self.char_lstm(pack_input, None)
        ## char_hidden = (h_t, c_t); h_t is (num_layers * num_directions, batch, hidden)
        ## transpose because the first dimension is num_direction x num-layer
        hidden = char_hidden[0].transpose(1,0).contiguous().view(batch_size * sent_len, 1, -1)
        # undo the length sort and restore the (batch, sent_len, hidden) shape
        return hidden[recover_idx].view(batch_size, sent_len, -1)

    def forward(self, char_input, seq_lengths):
        # thin alias so the module can be called directly
        return self.get_last_hiddens(char_input, seq_lengths)
class DepLabeledGCN(nn.Module):
    """
    Graph convolutional network over the dependency tree where each arc
    additionally carries a learned scalar weight per dependency label.
    """

    def __init__(self, config, input_dim):
        super().__init__()

        self.gcn_hidden_dim = config.dep_hidden_dim
        self.num_gcn_layers = config.num_gcn_layers
        self.gcn_mlp_layers = config.gcn_mlp_layers
        self.edge_gate = config.edge_gate
        # gcn layer
        self.layers = self.num_gcn_layers
        self.device = config.device
        self.mem_dim = self.gcn_hidden_dim
        self.in_dim = input_dim  ## lstm hidden dim
        # label id used for the implicit self-loop edge of every token
        self.self_dep_label_id = torch.tensor(config.deplabel2idx[config.self_label]).long().to(self.device)

        print("[Model Info] GCN Input Size: {}, # GCN Layers: {}, #MLP: {}".format(self.in_dim, self.num_gcn_layers, config.gcn_mlp_layers))
        self.gcn_drop = nn.Dropout(config.gcn_dropout).to(self.device)

        # gcn layers: W operates on the plain adjacency, W_label on the
        # label-weighted adjacency
        self.W = nn.ModuleList()
        self.W_label = nn.ModuleList()

        if self.edge_gate:
            print("[Info] Labeled GCN model will be added edge-wise gating.")
            self.gates = nn.ModuleList()

        for layer in range(self.layers):
            # first layer maps from the LSTM dim, later layers are mem_dim -> mem_dim
            input_dim = self.in_dim if layer == 0 else self.mem_dim
            self.W.append(nn.Linear(input_dim, self.mem_dim).to(self.device))
            self.W_label.append(nn.Linear(input_dim, self.mem_dim).to(self.device))
            if self.edge_gate:
                self.gates.append(nn.Linear(input_dim, self.mem_dim).to(self.device))

        # one learned scalar per dependency label
        self.dep_emb = nn.Embedding(len(config.deplabels), 1).to(config.device)

        # output mlp layers
        # NOTE(review): the MLP input is config.hidden_dim, not self.mem_dim —
        # this assumes dep_hidden_dim == hidden_dim; confirm for other configs.
        in_dim = config.hidden_dim
        layers = [nn.Linear(in_dim, self.gcn_hidden_dim).to(self.device), nn.ReLU().to(self.device)]
        for _ in range(self.gcn_mlp_layers - 1):
            layers += [nn.Linear(self.gcn_hidden_dim, self.gcn_hidden_dim).to(self.device), nn.ReLU().to(self.device)]

        self.out_mlp = nn.Sequential(*layers).to(self.device)

    def forward(self, gcn_inputs, word_seq_len, adj_matrix, dep_label_matrix):
        """
        :param gcn_inputs: (batch, sent_len, input_dim) token representations
        :param word_seq_len: sentence lengths (unused here)
        :param adj_matrix: (batch, sent_len, sent_len) adjacency; should
            already contain the self loop
        :param dep_label_matrix: (batch, sent_len, sent_len) dependency-label ids
        :return: (batch, sent_len, gcn_hidden_dim)
        """
        adj_matrix = adj_matrix.to(self.device)
        dep_label_matrix = dep_label_matrix.to(self.device)
        batch_size, sent_len, input_dim = gcn_inputs.size()

        # degree normalizer (+1 accounts for the self loop)
        denom = adj_matrix.sum(2).unsqueeze(2) + 1

        ##dep_label_matrix: NxN -> scalar weight per edge, masked by adjacency
        dep_embs = self.dep_emb(dep_label_matrix)  ## B x N x N x 1
        dep_embs = dep_embs.squeeze(3) * adj_matrix

        # self-loop weight and the matching normalizer for the labeled branch
        self_val = self.dep_emb(self.self_dep_label_id)
        dep_denom = dep_embs.sum(2).unsqueeze(2) + self_val

        for l in range(self.layers):
            # unlabeled branch: mean over neighbors plus self loop
            Ax = adj_matrix.bmm(gcn_inputs)  ## N x N times N x h = N x h
            AxW = self.W[l](Ax)  ## N x m
            AxW = AxW + self.W[l](gcn_inputs)  ## self loop N x h
            AxW = AxW / denom

            # labeled branch: neighbors weighted by the per-label scalars
            Bx = dep_embs.bmm(gcn_inputs)
            BxW = self.W_label[l](Bx)
            BxW = BxW + self.W_label[l](gcn_inputs * self_val)
            BxW = BxW / dep_denom

            if self.edge_gate:
                # sigmoid gate computed from the same neighborhood aggregate
                gx = adj_matrix.bmm(gcn_inputs)
                gxW = self.gates[l](gx)  ## N x m
                gate_val = torch.sigmoid(gxW + self.gates[l](gcn_inputs))  ## self loop N x h
                gAxW = F.relu(gate_val * (AxW + BxW))
            else:
                gAxW = F.relu(AxW + BxW)

            # dropout between layers, but not after the last one
            gcn_inputs = self.gcn_drop(gAxW) if l < self.layers - 1 else gAxW

        outputs = self.out_mlp(gcn_inputs)
        return outputs
        # --- continuation of NNCRF.__init__ (def line is in the previous chunk) ---
        self.label_size = config.label_size
        self.device = config.device
        self.use_char = config.use_char_rnn
        self.dep_model = config.dep_model
        self.context_emb = config.context_emb
        self.interaction_func = config.interaction_func


        self.label2idx = config.label2idx
        self.labels = config.idx2labels
        self.start_idx = self.label2idx[START]
        self.end_idx = self.label2idx[STOP]
        self.pad_idx = self.label2idx[PAD]


        self.input_size = config.embedding_dim

        if self.use_char:
            self.char_feature = CharBiLSTM(config)
            self.input_size += config.charlstm_hidden_dim


        vocab_size = len(config.word2idx)
        self.word_embedding = nn.Embedding.from_pretrained(torch.FloatTensor(config.word_embedding), freeze=False).to(self.device)
        self.word_drop = nn.Dropout(config.dropout).to(self.device)

        # Extra MLPs used only for the dglstm model with the MLP interaction function.
        if self.dep_model == DepModelType.dglstm and self.interaction_func == InteractionFunction.mlp:
            self.mlp_layers = nn.ModuleList()
            for i in range(config.num_lstm_layer - 1):
                self.mlp_layers.append(nn.Linear(config.hidden_dim, 2 * config.hidden_dim).to(self.device))
            self.mlp_head_linears = nn.ModuleList()
            for i in range(config.num_lstm_layer - 1):
                self.mlp_head_linears.append(nn.Linear(config.hidden_dim, 2 * config.hidden_dim).to(self.device))

        """
        Input size to LSTM description
        """
        self.charlstm_dim = config.charlstm_hidden_dim
        if self.dep_model == DepModelType.dglstm:
            # dglstm concatenates the dependency head's embedding and the label embedding.
            self.input_size += config.embedding_dim + config.dep_emb_size
            if self.use_char:
                self.input_size += config.charlstm_hidden_dim

        if self.context_emb != ContextEmb.none:
            self.input_size += config.context_emb_size

        print("[Model Info] Input size to LSTM: {}".format(self.input_size))
        print("[Model Info] LSTM Hidden Size: {}".format(config.hidden_dim))


        # dglstm stacks separate single-layer LSTMs (see add_lstms below) instead of
        # one multi-layer LSTM.
        num_layers = 1
        if config.num_lstm_layer > 1 and self.dep_model != DepModelType.dglstm:
            num_layers = config.num_lstm_layer
        if config.num_lstm_layer > 0:
            self.lstm = nn.LSTM(self.input_size, config.hidden_dim // 2, num_layers=num_layers, batch_first=True, bidirectional=True).to(self.device)

        self.num_lstm_layer = config.num_lstm_layer
        self.lstm_hidden_dim = config.hidden_dim
        self.embedding_dim = config.embedding_dim
        if config.num_lstm_layer > 1 and self.dep_model == DepModelType.dglstm:
            self.add_lstms = nn.ModuleList()
            if self.interaction_func == InteractionFunction.concatenation or \
                self.interaction_func == InteractionFunction.mlp:
                hidden_size = 2 * config.hidden_dim
            elif self.interaction_func == InteractionFunction.addition:
                hidden_size = config.hidden_dim

            print("[Model Info] Building {} more LSTMs, with size: {} x {} (without dep label highway connection)".format(config.num_lstm_layer-1, hidden_size, config.hidden_dim))
            for i in range(config.num_lstm_layer - 1):
                self.add_lstms.append(nn.LSTM(hidden_size, config.hidden_dim // 2, num_layers=1, batch_first=True, bidirectional=True).to(self.device))

        self.drop_lstm = nn.Dropout(config.dropout).to(self.device)


        final_hidden_dim = config.hidden_dim if self.num_lstm_layer >0 else self.input_size
        """
        Model description
        """
        print("[Model Info] Dep Method: {}, hidden size: {}".format(self.dep_model.name, config.dep_hidden_dim))
        if self.dep_model != DepModelType.none:
            print("Initializing the dependency label embedding")
            self.dep_label_embedding = nn.Embedding(len(config.deplabel2idx), config.dep_emb_size).to(self.device)
            if self.dep_model == DepModelType.dggcn:
                self.gcn = DepLabeledGCN(config, config.hidden_dim)  ### lstm hidden size
                final_hidden_dim = config.dep_hidden_dim

        print("[Model Info] Final Hidden Size: {}".format(final_hidden_dim))
        self.hidden2tag = nn.Linear(final_hidden_dim, self.label_size).to(self.device)

        # Transitions into START / out of STOP / any PAD transition are forbidden.
        init_transition = torch.randn(self.label_size, self.label_size).to(self.device)
        init_transition[:, self.start_idx] = -10000.0
        init_transition[self.end_idx, :] = -10000.0
        init_transition[:, self.pad_idx] = -10000.0
        init_transition[self.pad_idx, :] = -10000.0

        self.transition = nn.Parameter(init_transition)


    def neural_scoring(self, word_seq_tensor, word_seq_lens, batch_context_emb, char_inputs, char_seq_lens, adj_matrixs, adjs_in, adjs_out, graphs, dep_label_adj, dep_head_tensor, dep_label_tensor, trees=None):
        """
        Compute CRF emission scores from the encoder stack.

        :param word_seq_tensor: (batch_size, sent_len) NOTE: the word seq actually is already ordered before coming here.
        :param word_seq_lens: (batch_size, 1)
        :param char_inputs: (batch_size * sent_len * word_length)
        :param char_seq_lens: numpy (batch_size * sent_len, 1)
        :param dep_label_tensor: (batch_size, max_sent_len)
        :return: emission scores (batch_size, sent_len, label_size)
        """
        batch_size = word_seq_tensor.size(0)
        sent_len = word_seq_tensor.size(1)

        word_emb = self.word_embedding(word_seq_tensor)
        # For dglstm, char features must be concatenated BEFORE gathering the head
        # embeddings so heads carry the same representation as their dependents.
        if self.use_char:
            if self.dep_model == DepModelType.dglstm:
                char_features = self.char_feature.get_last_hiddens(char_inputs, char_seq_lens)
                word_emb = torch.cat((word_emb, char_features), 2)
        if self.dep_model == DepModelType.dglstm:
            size = self.embedding_dim if not self.use_char else (self.embedding_dim + self.charlstm_dim)
            # Gather each token's dependency-head embedding along the sequence dim.
            dep_head_emb = torch.gather(word_emb, 1, dep_head_tensor.view(batch_size, sent_len, 1).expand(batch_size, sent_len, size))

        if self.context_emb != ContextEmb.none:
            word_emb = torch.cat((word_emb, batch_context_emb.to(self.device)), 2)

        if self.use_char:
            if self.dep_model != DepModelType.dglstm:
                char_features = self.char_feature.get_last_hiddens(char_inputs, char_seq_lens)
                word_emb = torch.cat((word_emb, char_features), 2)

        """
        Word Representation
        """
        if self.dep_model == DepModelType.dglstm:
            dep_emb = self.dep_label_embedding(dep_label_tensor)
            word_emb = torch.cat((word_emb, dep_head_emb, dep_emb), 2)

        word_rep = self.word_drop(word_emb)

        # Sort by length for packing; recover_idx restores the original order at the end.
        sorted_seq_len, permIdx = word_seq_lens.sort(0, descending=True)
        _, recover_idx = permIdx.sort(0, descending=False)
        sorted_seq_tensor = word_rep[permIdx]


        if self.num_lstm_layer > 0:
            packed_words = pack_padded_sequence(sorted_seq_tensor, sorted_seq_len, True)
            lstm_out, _ = self.lstm(packed_words, None)
            lstm_out, _ = pad_packed_sequence(lstm_out, batch_first=True)  ## CARE: make sure here is batch_first, otherwise need to transpose.
            feature_out = self.drop_lstm(lstm_out)
        else:
            feature_out = sorted_seq_tensor

        """
        Higher order interactions
        """
        if self.num_lstm_layer > 1 and (self.dep_model == DepModelType.dglstm):
            for l in range(self.num_lstm_layer-1):
                # Re-gather head representations from the CURRENT layer's output
                # (note dep_head_tensor must be permuted like the features).
                dep_head_emb = torch.gather(feature_out, 1, dep_head_tensor[permIdx].view(batch_size, sent_len, 1).expand(batch_size, sent_len, self.lstm_hidden_dim))
                if self.interaction_func == InteractionFunction.concatenation:
                    feature_out = torch.cat((feature_out, dep_head_emb), 2)
                elif self.interaction_func == InteractionFunction.addition:
                    feature_out = feature_out + dep_head_emb
                elif self.interaction_func == InteractionFunction.mlp:
                    feature_out = F.relu(self.mlp_layers[l](feature_out) + self.mlp_head_linears[l](dep_head_emb))

                packed_words = pack_padded_sequence(feature_out, sorted_seq_len, True)
                lstm_out, _ = self.add_lstms[l](packed_words, None)
                lstm_out, _ = pad_packed_sequence(lstm_out, batch_first=True)  ## CARE: make sure here is batch_first, otherwise need to transpose.
                feature_out = self.drop_lstm(lstm_out)

        """
        Model forward if we have GCN
        """
        if self.dep_model == DepModelType.dggcn:
            feature_out = self.gcn(feature_out, sorted_seq_len, adj_matrixs[permIdx], dep_label_adj[permIdx])

        outputs = self.hidden2tag(feature_out)

        return outputs[recover_idx]

    def calculate_all_scores(self, features):
        # Combine transition scores (label -> label) with per-position emission scores:
        # result (batch, seq_len, from_label, to_label).
        batch_size = features.size(0)
        seq_len = features.size(1)
        scores = self.transition.view(1, 1, self.label_size, self.label_size).expand(batch_size, seq_len, self.label_size, self.label_size) + \
                 features.view(batch_size, seq_len, 1, self.label_size).expand(batch_size,seq_len,self.label_size, self.label_size)
        return scores

    def forward_unlabeled(self, all_scores, word_seq_lens, masks):
        # CRF forward algorithm: log partition function summed over the batch.
        batch_size = all_scores.size(0)
        seq_len = all_scores.size(1)
        alpha = torch.zeros(batch_size, seq_len, self.label_size).to(self.device)

        alpha[:, 0, :] = all_scores[:, 0, self.start_idx, :]  ## the first position of all labels = (the transition from start -> all labels) + current emission.

        for word_idx in range(1, seq_len):
            ## batch_size, self.label_size, self.label_size
            before_log_sum_exp = alpha[:, word_idx-1, :].view(batch_size, self.label_size, 1).expand(batch_size, self.label_size, self.label_size) + all_scores[:, word_idx, :, :]
            alpha[:, word_idx, :] = log_sum_exp_pytorch(before_log_sum_exp)

        ### batch_size x label_size -- pick alpha at each sentence's true last position.
        last_alpha = torch.gather(alpha, 1, word_seq_lens.view(batch_size, 1, 1).expand(batch_size, 1, self.label_size)-1).view(batch_size, self.label_size)
        last_alpha += self.transition[:, self.end_idx].view(1, self.label_size).expand(batch_size, self.label_size)
        last_alpha = log_sum_exp_pytorch(last_alpha.view(batch_size, self.label_size, 1)).view(batch_size)

        return torch.sum(last_alpha)

    def forward_labeled(self, all_scores, word_seq_lens, tags, masks):
        '''
        :param all_scores: (batch, seq_len, label_size, label_size)
        :param word_seq_lens: (batch, seq_len)
        :param tags: (batch, seq_len)
        :param masks: batch, seq_len
        :return: sum of score for the gold sequences
        '''
        batchSize = all_scores.shape[0]
        sentLength = all_scores.shape[1]

        ## all the scores to current labels: batch, seq_len, all_from_label
        currentTagScores = torch.gather(all_scores, 3, tags.view(batchSize, sentLength, 1, 1).expand(batchSize, sentLength, self.label_size, 1)).view(batchSize, -1, self.label_size)
        if sentLength != 1:
            tagTransScoresMiddle = torch.gather(currentTagScores[:, 1:, :], 2, tags[:, : sentLength - 1].view(batchSize, sentLength - 1, 1)).view(batchSize, -1)
        tagTransScoresBegin = currentTagScores[:, 0, self.start_idx]
        endTagIds = torch.gather(tags, 1, word_seq_lens.view(batchSize, 1) - 1)
        tagTransScoresEnd = torch.gather(self.transition[:, self.end_idx].view(1, self.label_size).expand(batchSize, self.label_size), 1, endTagIds).view(batchSize)
        score = torch.sum(tagTransScoresBegin) + torch.sum(tagTransScoresEnd)
        if sentLength != 1:
            # Only count middle transitions inside the real (unpadded) sequence.
            score += torch.sum(tagTransScoresMiddle.masked_select(masks[:, 1:]))
        return score

    def neg_log_obj(self, words, word_seq_lens, batch_context_emb, chars, char_seq_lens, adj_matrixs, adjs_in, adjs_out, graphs, dep_label_adj, batch_dep_heads, tags, batch_dep_label, trees=None):
        # CRF negative log-likelihood: log Z - score(gold path).
        features = self.neural_scoring(words, word_seq_lens, batch_context_emb, chars, char_seq_lens, adj_matrixs, adjs_in, adjs_out, graphs, dep_label_adj, batch_dep_heads, batch_dep_label, trees)

        all_scores = self.calculate_all_scores(features)

        batch_size = words.size(0)
        sent_len = words.size(1)

        # mask[b, i] is True for positions within the b-th sentence's length.
        maskTemp = torch.arange(1, sent_len + 1, dtype=torch.long).view(1, sent_len).expand(batch_size, sent_len).to(self.device)
        mask = torch.le(maskTemp, word_seq_lens.view(batch_size, 1).expand(batch_size, sent_len)).to(self.device)

        unlabed_score = self.forward_unlabeled(all_scores, word_seq_lens, mask)
        labeled_score = self.forward_labeled(all_scores, word_seq_lens, tags, mask)
        return unlabed_score - labeled_score


    def viterbiDecode(self, all_scores, word_seq_lens):
        # Batched Viterbi decoding; returns best path scores and label ids
        # (decodeIdx is filled from the END of each sentence backwards).
        batchSize = all_scores.shape[0]
        sentLength = all_scores.shape[1]
        # sent_len =
        scoresRecord = torch.zeros([batchSize, sentLength, self.label_size]).to(self.device)
        idxRecord = torch.zeros([batchSize, sentLength, self.label_size], dtype=torch.int64).to(self.device)
        mask = torch.ones_like(word_seq_lens, dtype=torch.int64).to(self.device)
        startIds = torch.full((batchSize, self.label_size), self.start_idx, dtype=torch.int64).to(self.device)
        decodeIdx = torch.LongTensor(batchSize, sentLength).to(self.device)

        scores = all_scores
        # scoresRecord[:, 0, :] = self.getInitAlphaWithBatchSize(batchSize).view(batchSize, self.label_size)
        scoresRecord[:, 0, :] = scores[:, 0, self.start_idx, :]  ## best score so far for each label at position 0 (coming from START)
        idxRecord[:, 0, :] = startIds
        for wordIdx in range(1, sentLength):
            ### scoresIdx: batch x from_label x to_label at current index.
            scoresIdx = scoresRecord[:, wordIdx - 1, :].view(batchSize, self.label_size, 1).expand(batchSize, self.label_size,
                                                                                                   self.label_size) + scores[:, wordIdx, :, :]
            idxRecord[:, wordIdx, :] = torch.argmax(scoresIdx, 1)  ## the best previous label idx to current labels
            scoresRecord[:, wordIdx, :] = torch.gather(scoresIdx, 1, idxRecord[:, wordIdx, :].view(batchSize, 1, self.label_size)).view(batchSize, self.label_size)

        # Select each sentence's last real position, then add transition to STOP.
        lastScores = torch.gather(scoresRecord, 1, word_seq_lens.view(batchSize, 1, 1).expand(batchSize, 1, self.label_size) - 1).view(batchSize, self.label_size)  ##select position
        lastScores += self.transition[:, self.end_idx].view(1, self.label_size).expand(batchSize, self.label_size)
        decodeIdx[:, 0] = torch.argmax(lastScores, 1)
        bestScores = torch.gather(lastScores, 1, decodeIdx[:, 0].view(batchSize, 1))

        # Backtrack; the torch.where clamps indices to >= 1 for padded positions.
        for distance2Last in range(sentLength - 1):
            lastNIdxRecord = torch.gather(idxRecord, 1, torch.where(word_seq_lens - distance2Last - 1 > 0, word_seq_lens - distance2Last - 1, mask).view(batchSize, 1, 1).expand(batchSize, 1, self.label_size)).view(batchSize, self.label_size)
            decodeIdx[:, distance2Last + 1] = torch.gather(lastNIdxRecord, 1, decodeIdx[:, distance2Last].view(batchSize, 1)).view(batchSize)

        return bestScores, decodeIdx

    def decode(self, batchInput):
        # Unpack a full batch tuple, score it, and Viterbi-decode the best paths.
        wordSeqTensor, wordSeqLengths, batch_context_emb, charSeqTensor, charSeqLengths, adj_matrixs, adjs_in, adjs_out, graphs, dep_label_adj, batch_dep_heads, trees, tagSeqTensor, batch_dep_label = batchInput
        features = self.neural_scoring(wordSeqTensor, wordSeqLengths, batch_context_emb,charSeqTensor,charSeqLengths, adj_matrixs, adjs_in, adjs_out, graphs, dep_label_adj, batch_dep_heads, batch_dep_label, trees)
        all_scores = self.calculate_all_scores(features)
        bestScores, decodeIdx = self.viterbiDecode(all_scores, wordSeqLengths)
        return bestScores, decodeIdx
-------------------------------------------------------------------------------- /preprocess/convert_sem_eng.py: --------------------------------------------------------------------------------
#
# @author: Allan
#

### This file is used to convert the semeval English into our conllx format

def process(filename:str, out:str):
    """Convert one SemEval English file into conllx format, rewriting the
    bracketed entity column into BIO labels."""
    fres = open(out, 'w', encoding='utf-8')
    print(filename)
    with open(filename, 'r', encoding='utf-8') as f:
        words = []
        heads = []
        deps =[]
        labels = []
        prev_label = "O"
        prev_raw_label = ""
        for line in f.readlines():
            line = line.rstrip()
            # print(line)
            if line.startswith("#"):
                # Comment/sentence-header line: reset the BIO state.
                prev_label = "O"
                prev_raw_label = ""
                continue
            if line == "":
                # Sentence boundary: flush the buffered sentence.
                idx = 1
                for w, h, dep, label in zip(words, heads, deps, labels):
                    if dep == "sentence":
                        dep = "root"
                    fres.write("{}\t{}\t_\t_\t_\t_\t{}\t{}\t_\t_\t{}\n".format(idx, w, h, dep, label))
                    idx += 1
                fres.write('\n')
                words = []
                heads = []
                deps = []
                labels = []
                prev_label = "O"
                continue
            # Example row:
            # 1 West _
            # (example row, cont.)  NNP NNP _ 5 compound _ _ B-MISC
            vals = line.split()
            idx = vals[0]
            word = vals[1]
            head = vals[8]
            dep_label = vals[10]
            label = vals[12]

            # Bracketed entity column -> BIO. "(TYPE" opens, "TYPE)" closes.
            if label.startswith("("):
                if label.endswith(")"):
                    label = "B-" + label[1:-1]
                else:
                    label = "B-" + label[1:]
            elif label.startswith(")"):
                # NOTE(review): closing markers appear to END with ")" (see the
                # prev_raw_label.endswith(")") check below); startswith(")") looks
                # suspicious but the else-branch yields the same I- label — confirm.
                label = "I-" + label[:-1]
            else:
                if prev_label == "O":
                    label = "O"
                else:
                    if prev_raw_label.endswith(")"):
                        label = "O"
                    else:
                        label = "I-" + prev_label[2:]

            words.append(word)
            heads.append(head)
            labels.append(label)
            deps.append(dep_label)
            prev_label = label
            prev_raw_label = vals[12]
    fres.close()




# process("data/semeval10t1/en.train.txt", "data/semeval10t1/train.sd.conllx")
# process("data/semeval10t1/en.devel.txt", "data/semeval10t1/dev.sd.conllx")
# process("data/semeval10t1/en.test.txt", "data/semeval10t1/test.sd.conllx")

lang = "it"
folder="sem" + lang
process("data/"+folder+"/"+lang+".train.txt", "data/"+folder+"/train.sd.conllx")
process("data/"+folder+"/"+lang+".devel.txt", "data/"+folder+"/dev.sd.conllx")
process("data/"+folder+"/"+lang+".test.txt", "data/"+folder+"/test.sd.conllx")
-------------------------------------------------------------------------------- /preprocess/convert_sem_other.py: --------------------------------------------------------------------------------
#
# @author: Allan
#
from typing import List

# Global tally of entity types emitted, printed by the driver code at the bottom.
type2num = {}

def extract(words: List[str], labels:List[str], heads:List[str], deps:List[str]):
    """Turn the bracketed (possibly nested, '|'-separated) entity annotations of
    one sentence into flat BIO labels, keeping only the shortest span at each
    overlapping position. Returns the list of BIO labels."""
    entity_pool = []  ## open spans: (type, left, right)
    completed_pool = []
    print(words)
    # print(labels)
    for i, label in enumerate(labels):
        if label == "_":
            continue
        if "|" in label:
            # Multiple annotations on one token.
            vals = label.split("|")
            for val in vals:
                if val.startswith("(") and val.endswith(")"):
                    # Single-token entity.
                    completed_pool.append((i, i, val[1:-1]))
                elif val.startswith("("):
                    entity_pool.append((i, -1, val[1:]))
                elif val.endswith(")"):
                    # Close the most recently opened span of the same type.
                    found = False
                    for tup in entity_pool[::-1]:
                        start, end, cur = tup
                        if cur == val[:-1]:
                            completed_pool.append((start, i, cur))
                            entity_pool.remove(tup)
                            found = True
                            break
                    if not found:
                        raise Exception("not found the entity:{}".format(val))
                else:
                    raise Exception("not val type".format(val))
        else:
            if label.startswith("(") and label.endswith(")"):
                completed_pool.append((i, i, label[1:-1]))
            elif label.startswith("("):
                entity_pool.append((i, -1, label[1:]))
            elif label.endswith(")"):
                found = False
                for tup in entity_pool[::-1]:
                    start, end, cur = tup
                    if cur == label[:-1]:
                        completed_pool.append((start, i, cur))
                        entity_pool.remove(tup)
                        found = True
                        break
                if not found:
                    raise Exception("not found the entity:{}".format(label))
            else:
                raise Exception("not val type {}".format(label))
    # Every opened span must be closed by the end of the sentence.
    assert (len(entity_pool) == 0)


    # For overlapping spans at any position, keep only the first after sorting
    # by span length (the shortest), removing the rest.
    for i in range(len(words)):
        curr_pos = []
        for span in completed_pool:
            start, end, label = span
            if i >= start and i <= end:
                curr_pos.append(span)
        curr_pos = sorted(curr_pos, key=lambda span: span[1] - span[0])
        for span in curr_pos[1:]:
            completed_pool.remove(span)

    labels = ["O"] * len(words)
    visited = [False] * len(words)
    for span in completed_pool:
        start, end, label = span

        # Sanity check: spans must not overlap after the pruning above.
        for check in visited[start:(end+1)]:
            if check:
                raise Exception("this position is checked.")

        # Collapse all non-core types to 'misc'.
        if label not in ('person', 'loc', 'org'):
            label = 'misc'

        labels[start] = "B-"+label
        labels[(start+1):end] = ["I-" + label] * (end - start)
        visited[start: (end+1)] = [True] * (end-start + 1)

        if label in type2num:
            type2num[label] += 1
        else:
            type2num[label] = 1

    # print(labels)
    return labels


def read_all_sents(filename:str, out:str, use_gold_dep: bool = True):
    """Read a SemEval file, convert each sentence's entity column to BIO via
    extract(), and write conllx rows (gold or predicted dependencies)."""
    print(filename)
    fres = open(out, 'w', encoding='utf-8')
    sents = []
    with open(filename, 'r', encoding='utf-8') as f:
        words = []
        heads = []
        deps = []
        labels = []
        pos_tags = []
        for line in f.readlines():
            line = line.rstrip()
            # print(line)
            if line.startswith("#"):
                continue
            if line == "":
                idx = 1
                labels = extract(words, labels, heads, deps)
                idx = 1
                for w, h, dep, label, pos_tag in zip(words, heads, deps, labels, pos_tags):
                    if dep == "sentence":
                        dep = "root"
                    fres.write("{}\t{}\t_\t{}\t{}\t_\t{}\t{}\t_\t_\t{}\n".format(idx, w, pos_tag, pos_tag, h, dep, label))
                    idx += 1
                fres.write('\n')

                words = []
                heads = []
                deps = []
                labels = []
                continue
            # 1 West _ NNP NNP _ 5 compound _ _ B-MISC
            vals = line.split()
            idx = vals[0]
            word = vals[1]
            pos_tag = vals[4]
            head = vals[8] if use_gold_dep else vals[9]
            dep_label = vals[10] if use_gold_dep else vals[11]
            label = vals[12]
            words.append(word)
            pos_tags.append(pos_tag)
            heads.append(head)
            labels.append(label)
            deps.append(dep_label)
    fres.close()

def process(filename:str, out:str):
    """Same bracket-to-BIO conversion as convert_sem_eng.process (kept in sync)."""
    fres = open(out, 'w', encoding='utf-8')
    print(filename)
    with open(filename, 'r', encoding='utf-8') as f:
        words = []
        heads = []
        deps =[]
        labels = []
        prev_label = "O"
        prev_raw_label = ""
        for line in f.readlines():
            line = line.rstrip()
            # print(line)
            if line.startswith("#"):
                prev_label = "O"
                prev_raw_label = ""
                continue
            if line == "":
                idx = 1
                for w, h, dep, label in zip(words, heads, deps, labels):
                    if dep == "sentence":
                        dep = "root"
                    # --- continuation of process() (def line is in the previous chunk) ---
                    fres.write("{}\t{}\t_\t_\t_\t_\t{}\t{}\t_\t_\t{}\n".format(idx, w, h, dep, label))
                    idx += 1
                fres.write('\n')
                words = []
                heads = []
                deps = []
                labels = []
                prev_label = "O"
                continue
            # 1 West _ NNP NNP _ 5 compound _ _ B-MISC
            vals = line.split()
            idx = vals[0]
            word = vals[1]
            head = vals[8]
            dep_label = vals[10]
            label = vals[12]

            if label.startswith("("):
                if label.endswith(")"):
                    label = "B-" + label[1:-1]
                else:
                    label = "B-" + label[1:]
            elif label.startswith(")"):
                label = "I-" + label[:-1]
            else:
                if prev_label == "O":
                    label = "O"
                else:
                    if prev_raw_label.endswith(")"):
                        label = "O"
                    else:
                        label = "I-" + prev_label[2:]

            words.append(word)
            heads.append(head)
            labels.append(label)
            deps.append(dep_label)
            prev_label = label
            prev_raw_label = vals[12]
    fres.close()




### This file is used to convert the semeval Catalan and Spanish into our conllx format

lang = "ca"
folder="sem" + lang
use_gold_dep = False
affix = "sd" if use_gold_dep else "sud"
read_all_sents("data/"+folder+"/"+lang+".train.txt", "data/"+folder+"/train."+affix+".conllx", use_gold_dep)
print(type2num)
type2num = {}
read_all_sents("data/"+folder+"/"+lang+".devel.txt", "data/"+folder+"/dev."+affix+".conllx", use_gold_dep)

print(type2num)
type2num = {}
read_all_sents("data/"+folder+"/"+lang+".test.txt", "data/"+folder+"/test."+affix+".conllx", use_gold_dep)

print(type2num)
-------------------------------------------------------------------------------- /preprocess/elmo_others.py: --------------------------------------------------------------------------------
from elmoformanylangs import Embedder
import pickle

"""
This file should be deprecated since every time the
result is different.

"""


def read_conllx(filename:str):
    """Read a conllx file and return a list of sentences (each a list of word
    forms from column 2)."""
    print(filename)
    sents = []
    with open(filename, 'r', encoding='utf-8') as f:
        words = []
        for line in f.readlines():
            line = line.rstrip()
            if line == "":
                if len(words) == 0:
                    print("len is 0")
                sents.append(words)
                words = []
                continue
            vals = line.split()
            words.append(vals[1])
    return sents


def context_emb(emb, sents):
    # Layer argument to sents2elmo:
    ## 0: word encoder
    ## 1: the first LSTM hidden layer
    ## 2: the second LSTM hidden layer
    ## -1: an average of the 3 layers (default)
    ## -2: all 3 layers
    return emb.sents2elmo(sents, -1)


def read_parse_write(elmo, in_file, out_file):
    """Embed every sentence of in_file with ELMo and pickle the list of vectors."""
    sents = read_conllx(in_file)
    print("number of sentences: {} in {}".format(len(sents), in_file))
    f = open(out_file, 'wb')
    batch_size = 1
    all_vecs = []
    # NOTE(review): idx already advances in steps of batch_size, yet start/end
    # multiply by batch_size again — this is only correct for batch_size == 1.
    for idx in range(0, len(sents), batch_size):
        start = idx*batch_size
        end = (idx+1)*batch_size if (idx+1)*batch_size < len(sents) else len(sents)
        batch_sents = sents[start: end]
        #print(batch_sents)
        embs = context_emb(elmo, batch_sents)
        for emb in embs:
            all_vecs.append(emb)
    pickle.dump(all_vecs, f)
    f.close()


## NOTE: Remember to download the model and change the path here
elmo = Embedder('/data/allan/embeddings/Spanish_ELMo', batch_size=1)

dataset = "spanish"
read_parse_write(elmo, f"data/{dataset}/train.sd.conllx", f"data/{dataset}/train.conllx.elmo.vec")
read_parse_write(elmo, f"data/{dataset}/dev.sd.conllx", f"data/{dataset}/dev.conllx.elmo.vec")
read_parse_write(elmo, f"data/{dataset}/test.sd.conllx", f"data/{dataset}/test.conllx.elmo.vec")
-------------------------------------------------------------------------------- /preprocess/prebert.py: --------------------------------------------------------------------------------
#
@author: Allan 3 | # 4 | 5 | from config.reader import Reader 6 | import numpy as np 7 | import pickle 8 | import torch 9 | from pytorch_pretrained_bert import BertTokenizer, BertModel, BertForMaskedLM 10 | 11 | # # OPTIONAL: if you want to have more information on what's happening, activate the logger as follows 12 | # import logging 13 | # logging.basicConfig(level=logging.INFO) 14 | 15 | 16 | def parse_sentence(tokenizer, model, words, mode:str="average"): 17 | model.eval() 18 | indexed_tokens = tokenizer.convert_tokens_to_ids(words) 19 | segments_ids = [0] * len(indexed_tokens) 20 | tokens_tensor = torch.LongTensor([indexed_tokens]).to(device) 21 | segments_tensors = torch.LongTensor([segments_ids]).to(device) 22 | with torch.no_grad(): 23 | encoded_layers, _ = model(tokens_tensor, segments_tensors) 24 | return encoded_layers 25 | 26 | def read_parse_write(tokenizer, model, infile, outfile, mode): 27 | reader = Reader() 28 | insts = reader.read_conll(infile, -1, True) 29 | f = open(outfile, 'wb') 30 | all_vecs = [] 31 | for inst in insts: 32 | vec = parse_sentence(tokenizer, model, inst.input.words, mode=mode) 33 | all_vecs.append(vec) 34 | pickle.dump(all_vecs, f) 35 | f.close() 36 | 37 | 38 | def load_bert(): 39 | # Load pre-trained model tokenizer (vocabulary) 40 | tokenizer = BertTokenizer.from_pretrained('bert-base-cased') 41 | model = BertModel.from_pretrained('bert-base-cased') 42 | model.eval() 43 | model.to(device) 44 | return tokenizer, model 45 | 46 | 47 | device = torch.device('cuda:0') 48 | tokenizer, bert_model = load_bert() 49 | mode= "average" 50 | dataset="conll2003" 51 | dep = "" 52 | file = "../data/"+dataset+"/train"+dep+".conllx" 53 | outfile = file + ".bert."+mode+".vec" 54 | read_parse_write(tokenizer, bert_model, file, outfile, mode) 55 | file = "../data/"+dataset+"/dev"+dep+".conllx" 56 | outfile = file + ".bert."+mode+".vec" 57 | read_parse_write(tokenizer, bert_model, file, outfile, mode) 58 | file = 
"../data/"+dataset+"/test"+dep+".conllx" 59 | outfile = file + ".bert."+mode+".vec" 60 | read_parse_write(tokenizer, bert_model, file, outfile, mode) -------------------------------------------------------------------------------- /preprocess/preelmo.py: -------------------------------------------------------------------------------- 1 | # 2 | # @author: Allan 3 | # 4 | 5 | from config.reader import Reader 6 | import numpy as np 7 | from allennlp.commands.elmo import ElmoEmbedder 8 | import pickle 9 | 10 | 11 | def parse_sentence(elmo, words, mode:str="average"): 12 | vectors = elmo.embed_sentence(words) 13 | if mode == "average": 14 | return np.average(vectors, 0) 15 | elif mode == 'weighted_average': 16 | return np.swapaxes(vectors, 0, 1) 17 | elif mode == 'last': 18 | return vectors[-1, :, :] 19 | elif mode == 'all': 20 | return vectors 21 | else: 22 | return vectors 23 | 24 | 25 | def load_elmo(): 26 | return ElmoEmbedder(cuda_device=0) 27 | 28 | 29 | 30 | def read_parse_write(elmo, infile, outfile, mode): 31 | reader = Reader() 32 | insts = reader.read_conll(infile, -1, True) 33 | f = open(outfile, 'wb') 34 | all_vecs = [] 35 | for inst in insts: 36 | vec = parse_sentence(elmo, inst.input.words, mode=mode) 37 | all_vecs.append(vec) 38 | pickle.dump(all_vecs, f) 39 | f.close() 40 | 41 | 42 | elmo = load_elmo() 43 | mode= "average" 44 | dataset="ontonotes" 45 | dep = "" 46 | file = "../data/"+dataset+"/train"+dep+".conllx" 47 | outfile = file + ".elmo."+mode+".vec" 48 | read_parse_write(elmo, file, outfile, mode) 49 | file = "../data/"+dataset+"/dev"+dep+".conllx" 50 | outfile = file + ".elmo."+mode+".vec" 51 | read_parse_write(elmo, file, outfile, mode) 52 | file = "../data/"+dataset+"/test"+dep+".conllx" 53 | outfile = file + ".elmo."+mode+".vec" 54 | read_parse_write(elmo, file, outfile, mode) 55 | 56 | -------------------------------------------------------------------------------- /preprocess/preflair.py: 
-------------------------------------------------------------------------------- 1 | from flair.embeddings import WordEmbeddings, FlairEmbeddings, StackedEmbeddings, BertEmbeddings, PooledFlairEmbeddings 2 | import pickle 3 | from config.reader import Reader 4 | import numpy as np 5 | from flair.data import Sentence 6 | 7 | def load_flair(mode = 'flair'): 8 | if mode == 'flair': 9 | stacked_embeddings = StackedEmbeddings([ 10 | WordEmbeddings('glove'), 11 | PooledFlairEmbeddings('news-forward', pooling='min'), 12 | PooledFlairEmbeddings('news-backward', pooling='min') 13 | ]) 14 | else:##bert 15 | stacked_embeddings = BertEmbeddings('bert-base-uncased') ##concat last 4 layers give the best 16 | return stacked_embeddings 17 | 18 | def embed_sent(embeder, sent): 19 | sent = Sentence(' '.join(sent)) 20 | embeder.embed(sent) 21 | return sent 22 | 23 | 24 | def read_parse_write(elmo, infile, outfile,): 25 | reader = Reader() 26 | insts = reader.read_conll(infile, -1, True) 27 | f = open(outfile, 'wb') 28 | all_vecs = [] 29 | for inst in insts: 30 | sent = embed_sent(elmo, inst.input.words) 31 | # np.empty((len(sent)),dtype=np.float32) 32 | arr = [] 33 | for token in sent: 34 | # print(token) 35 | # print(token.embedding) 36 | arr.append(np.expand_dims(token.embedding.numpy(), axis=0)) 37 | # all_vecs.append(vec) 38 | all_vecs.append(np.concatenate(arr)) 39 | pickle.dump(all_vecs, f) 40 | f.close() 41 | 42 | 43 | mode = 'flair' 44 | model = load_flair(mode=mode) 45 | # mode= "average" 46 | dataset="conll2003" 47 | dep = ".sd" 48 | file = "./data/"+dataset+"/train"+dep+".conllx" 49 | outfile = file.replace(".sd", "") + "."+mode+".vec" 50 | read_parse_write(model, file, outfile) 51 | file = "./data/"+dataset+"/dev"+dep+".conllx" 52 | outfile = file.replace(".sd", "") + "."+mode+".vec" 53 | read_parse_write(model, file, outfile) 54 | file = "./data/"+dataset+"/test"+dep+".conllx" 55 | outfile = file.replace(".sd", "") + "."+mode+".vec" 56 | read_parse_write(model, file, 
outfile)
# --------------------------------------------------------------------------------
# /scripts/run.bash:
# --------------------------------------------------------------------------------
#!/bin/bash

# Launch one training run per dataset, each pinned to its own GPU and
# epoch budget, with the matching pre-trained embedding file.
datasets=(ontonotes ontonotes_chinese catalan spanish)
embs=(data/glove.6B.100d.txt data/cc.zh.300.vec data/cc.ca.300.vec data/cc.es.300.vec)
devices=(cuda:0 cuda:1 cuda:2 cuda:3) ##cpu, cuda:0, cuda:1
num_epochs_all=(100 100 300 300)
context_emb=elmo
dep_model=dglstm ## none, dglstm, dggcn means do not use head features
num_lstm_layer=2
inter_func=mlp

for (( i=0; i<${#datasets[@]}; i++ )); do
    dataset=${datasets[$i]}
    emb=${embs[$i]}
    device=${devices[$i]}
    num_epochs=${num_epochs_all[$i]}
    first_part=logs/hidden_${num_lstm_layer}_${dataset}_${dep_model}_asfeat_${context_emb}
    logfile=${first_part}_epoch_${num_epochs}_if_${inter_func}.log
    python3.6 main.py --context_emb ${context_emb} \
        --dataset ${dataset} --num_epochs ${num_epochs} --device ${device} --num_lstm_layer ${num_lstm_layer} \
        --dep_model ${dep_model} \
        --embedding_file ${emb} --inter_func ${inter_func} > ${logfile} 2>&1
done

# --------------------------------------------------------------------------------
# /scripts/run_pytorch.bash:
# --------------------------------------------------------------------------------
#!/bin/bash

# Single-dataset run exposing the full set of dependency/GCN hyper-parameters.
datasets=(conll2003)
embs=(data/glove.6B.100d.txt)
device=cuda:1 ##cpu, cuda:0, cuda:1
context_emb=elmo
hidden=200
optim=sgd
batch=10
num_epochs=100
eval_freq=10000
lr=0.01
num_lstm_layer=2
inter_func=mlp
## dependency / GCN settings
dep_model=dglstm ## none, dglstm, dggcn means do not use head features
dep_hidden_dim=200
dep_double_label=0
affix=ssd ## NOTE(review): the other scripts use "sd"/"sud" -- confirm "ssd" is intentional
gcn_layer=1
gcn_dropout=0.5
gcn_mlp_layers=1
gcn_adj_directed=0 ##bidirection
gcn_adj_selfloop=0 ## keep to zero because we always add self loop in gcn
gcn_gate=0 ##without gcn gate
num_base=-1 ## number of bases in relational gcn

for (( i=0; i<${#datasets[@]}; i++ )); do
    dataset=${datasets[$i]}
    emb=${embs[$i]}
    first_part=logs/hidden_${num_lstm_layer}_${hidden}_${dataset}_${affix}_head_${dep_model}_asfeat_${context_emb}_gcn_${gcn_layer}_${gcn_mlp_layers}_${gcn_dropout}_gate_${gcn_gate}
    logfile=${first_part}_dir_${gcn_adj_directed}_loop_${gcn_adj_selfloop}_base_${num_base}_epoch_${num_epochs}_lr_${lr}_dd_${dep_double_label}_if_${inter_func}.log
    python3.6 main.py --context_emb ${context_emb} --hidden_dim ${hidden} --optimizer ${optim} --gcn_adj_directed ${gcn_adj_directed} --gcn_adj_selfloop ${gcn_adj_selfloop} \
        --dataset ${dataset} --eval_freq ${eval_freq} --num_epochs ${num_epochs} --device ${device} --dep_hidden_dim ${dep_hidden_dim} --num_lstm_layer ${num_lstm_layer} \
        --batch_size ${batch} --num_gcn_layers ${gcn_layer} --gcn_mlp_layers ${gcn_mlp_layers} --dep_model ${dep_model} --gcn_gate ${gcn_gate} --dep_double_label ${dep_double_label} \
        --gcn_dropout ${gcn_dropout} --affix ${affix} --lr_decay 0 --learning_rate ${lr} --embedding_file ${emb} --inter_func ${inter_func} \
        --num_base ${num_base} > ${logfile} 2>&1
done

--------------------------------------------------------------------------------
/scripts/run_pytorch_all.bash:
--------------------------------------------------------------------------------
#!/bin/bash

# Same hyper-parameters as run_pytorch.bash, fanned out over all four datasets,
# each with its own embedding file, GPU and epoch budget.
datasets=(ontonotes ontonotes_chinese catalan spanish)
embs=(data/glove.6B.100d.txt data/cc.zh.300.vec data/cc.ca.300.vec data/cc.es.300.vec)
devices=(cuda:0 cuda:1 cuda:2 cuda:3) ##cpu, cuda:0, cuda:1
num_epochs_all=(100 100 300 300)
context_emb=elmo
hidden=200
optim=sgd
batch=10
lr=0.01
num_lstm_layer=2
inter_func=mlp
## dependency / GCN settings
dep_model=dglstm ## none, dglstm, dggcn means do not use head features
dep_hidden_dim=200
dep_double_label=0
affix=sd
gcn_layer=1
gcn_dropout=0.5
gcn_mlp_layers=1
gcn_adj_directed=0 ##bidirection
gcn_adj_selfloop=0 ## keep to zero because we always add self loop in gcn
gcn_gate=0 ##without gcn gate
num_base=-1 ## number of bases in relational gcn

for (( i=0; i<${#datasets[@]}; i++ )); do
    dataset=${datasets[$i]}
    emb=${embs[$i]}
    device=${devices[$i]}
    num_epochs=${num_epochs_all[$i]}
    first_part=logs/hidden_${num_lstm_layer}_${hidden}_${dataset}_${affix}_head_${dep_model}_asfeat_${context_emb}_gcn_${gcn_layer}_${gcn_mlp_layers}_${gcn_dropout}_gate_${gcn_gate}
    logfile=${first_part}_dir_${gcn_adj_directed}_loop_${gcn_adj_selfloop}_base_${num_base}_epoch_${num_epochs}_lr_${lr}_dd_${dep_double_label}_if_${inter_func}.log
    python3.6 main.py --context_emb ${context_emb} --hidden_dim ${hidden} --optimizer ${optim} --gcn_adj_directed ${gcn_adj_directed} --gcn_adj_selfloop ${gcn_adj_selfloop} \
        --dataset ${dataset} --num_epochs ${num_epochs} --device ${device} --dep_hidden_dim ${dep_hidden_dim} --num_lstm_layer ${num_lstm_layer} \
        --batch_size ${batch} --num_gcn_layers ${gcn_layer} --gcn_mlp_layers ${gcn_mlp_layers} --dep_model ${dep_model} --gcn_gate ${gcn_gate} --dep_double_label ${dep_double_label} \
        --gcn_dropout ${gcn_dropout} --affix ${affix} --lr_decay 0 --learning_rate ${lr} --embedding_file ${emb} --inter_func ${inter_func} \
        --num_base ${num_base} > ${logfile} 2>&1
done

--------------------------------------------------------------------------------