├── LICENSE
├── README.md
├── analysis
│   ├── __init__.py
│   ├── analyze_chinese.py
│   ├── analyzer.py
│   ├── comparator.py
│   ├── compare_heatmap.py
│   ├── evaluate_result.py
│   ├── heatmap.py
│   ├── intro_examples.py
│   ├── length.py
│   ├── length_analysis.py
│   ├── significant.py
│   └── stator.py
├── common
│   ├── instance.py
│   └── sentence.py
├── config
│   ├── __init__.py
│   ├── config.py
│   ├── eval.py
│   ├── reader.py
│   └── utils.py
├── data
│   ├── catalan
│   │   ├── dev.sd.conllx
│   │   ├── test.sd.conllx
│   │   └── train.sd.conllx
│   ├── readme.txt
│   └── spanish
│       ├── dev.sd.conllx
│       ├── test.sd.conllx
│       └── train.sd.conllx
├── main.py
├── model
│   ├── charbilstm.py
│   ├── deplabel_gcn.py
│   └── lstmcrf.py
├── preprocess
│   ├── convert_sem_eng.py
│   ├── convert_sem_other.py
│   ├── elmo_others.py
│   ├── prebert.py
│   ├── preelmo.py
│   └── preflair.py
└── scripts
    ├── run.bash
    ├── run_pytorch.bash
    └── run_pytorch_all.bash
/LICENSE:
--------------------------------------------------------------------------------
1 | GNU GENERAL PUBLIC LICENSE
2 | Version 3, 29 June 2007
3 |
4 | Copyright (C) 2007 Free Software Foundation, Inc.
5 | Everyone is permitted to copy and distribute verbatim copies
6 | of this license document, but changing it is not allowed.
7 |
8 | Preamble
9 |
10 | The GNU General Public License is a free, copyleft license for
11 | software and other kinds of works.
12 |
13 | The licenses for most software and other practical works are designed
14 | to take away your freedom to share and change the works. By contrast,
15 | the GNU General Public License is intended to guarantee your freedom to
16 | share and change all versions of a program--to make sure it remains free
17 | software for all its users. We, the Free Software Foundation, use the
18 | GNU General Public License for most of our software; it applies also to
19 | any other work released this way by its authors. You can apply it to
20 | your programs, too.
21 |
22 | When we speak of free software, we are referring to freedom, not
23 | price. Our General Public Licenses are designed to make sure that you
24 | have the freedom to distribute copies of free software (and charge for
25 | them if you wish), that you receive source code or can get it if you
26 | want it, that you can change the software or use pieces of it in new
27 | free programs, and that you know you can do these things.
28 |
29 | To protect your rights, we need to prevent others from denying you
30 | these rights or asking you to surrender the rights. Therefore, you have
31 | certain responsibilities if you distribute copies of the software, or if
32 | you modify it: responsibilities to respect the freedom of others.
33 |
34 | For example, if you distribute copies of such a program, whether
35 | gratis or for a fee, you must pass on to the recipients the same
36 | freedoms that you received. You must make sure that they, too, receive
37 | or can get the source code. And you must show them these terms so they
38 | know their rights.
39 |
40 | Developers that use the GNU GPL protect your rights with two steps:
41 | (1) assert copyright on the software, and (2) offer you this License
42 | giving you legal permission to copy, distribute and/or modify it.
43 |
44 | For the developers' and authors' protection, the GPL clearly explains
45 | that there is no warranty for this free software. For both users' and
46 | authors' sake, the GPL requires that modified versions be marked as
47 | changed, so that their problems will not be attributed erroneously to
48 | authors of previous versions.
49 |
50 | Some devices are designed to deny users access to install or run
51 | modified versions of the software inside them, although the manufacturer
52 | can do so. This is fundamentally incompatible with the aim of
53 | protecting users' freedom to change the software. The systematic
54 | pattern of such abuse occurs in the area of products for individuals to
55 | use, which is precisely where it is most unacceptable. Therefore, we
56 | have designed this version of the GPL to prohibit the practice for those
57 | products. If such problems arise substantially in other domains, we
58 | stand ready to extend this provision to those domains in future versions
59 | of the GPL, as needed to protect the freedom of users.
60 |
61 | Finally, every program is threatened constantly by software patents.
62 | States should not allow patents to restrict development and use of
63 | software on general-purpose computers, but in those that do, we wish to
64 | avoid the special danger that patents applied to a free program could
65 | make it effectively proprietary. To prevent this, the GPL assures that
66 | patents cannot be used to render the program non-free.
67 |
68 | The precise terms and conditions for copying, distribution and
69 | modification follow.
70 |
71 | TERMS AND CONDITIONS
72 |
73 | 0. Definitions.
74 |
75 | "This License" refers to version 3 of the GNU General Public License.
76 |
77 | "Copyright" also means copyright-like laws that apply to other kinds of
78 | works, such as semiconductor masks.
79 |
80 | "The Program" refers to any copyrightable work licensed under this
81 | License. Each licensee is addressed as "you". "Licensees" and
82 | "recipients" may be individuals or organizations.
83 |
84 | To "modify" a work means to copy from or adapt all or part of the work
85 | in a fashion requiring copyright permission, other than the making of an
86 | exact copy. The resulting work is called a "modified version" of the
87 | earlier work or a work "based on" the earlier work.
88 |
89 | A "covered work" means either the unmodified Program or a work based
90 | on the Program.
91 |
92 | To "propagate" a work means to do anything with it that, without
93 | permission, would make you directly or secondarily liable for
94 | infringement under applicable copyright law, except executing it on a
95 | computer or modifying a private copy. Propagation includes copying,
96 | distribution (with or without modification), making available to the
97 | public, and in some countries other activities as well.
98 |
99 | To "convey" a work means any kind of propagation that enables other
100 | parties to make or receive copies. Mere interaction with a user through
101 | a computer network, with no transfer of a copy, is not conveying.
102 |
103 | An interactive user interface displays "Appropriate Legal Notices"
104 | to the extent that it includes a convenient and prominently visible
105 | feature that (1) displays an appropriate copyright notice, and (2)
106 | tells the user that there is no warranty for the work (except to the
107 | extent that warranties are provided), that licensees may convey the
108 | work under this License, and how to view a copy of this License. If
109 | the interface presents a list of user commands or options, such as a
110 | menu, a prominent item in the list meets this criterion.
111 |
112 | 1. Source Code.
113 |
114 | The "source code" for a work means the preferred form of the work
115 | for making modifications to it. "Object code" means any non-source
116 | form of a work.
117 |
118 | A "Standard Interface" means an interface that either is an official
119 | standard defined by a recognized standards body, or, in the case of
120 | interfaces specified for a particular programming language, one that
121 | is widely used among developers working in that language.
122 |
123 | The "System Libraries" of an executable work include anything, other
124 | than the work as a whole, that (a) is included in the normal form of
125 | packaging a Major Component, but which is not part of that Major
126 | Component, and (b) serves only to enable use of the work with that
127 | Major Component, or to implement a Standard Interface for which an
128 | implementation is available to the public in source code form. A
129 | "Major Component", in this context, means a major essential component
130 | (kernel, window system, and so on) of the specific operating system
131 | (if any) on which the executable work runs, or a compiler used to
132 | produce the work, or an object code interpreter used to run it.
133 |
134 | The "Corresponding Source" for a work in object code form means all
135 | the source code needed to generate, install, and (for an executable
136 | work) run the object code and to modify the work, including scripts to
137 | control those activities. However, it does not include the work's
138 | System Libraries, or general-purpose tools or generally available free
139 | programs which are used unmodified in performing those activities but
140 | which are not part of the work. For example, Corresponding Source
141 | includes interface definition files associated with source files for
142 | the work, and the source code for shared libraries and dynamically
143 | linked subprograms that the work is specifically designed to require,
144 | such as by intimate data communication or control flow between those
145 | subprograms and other parts of the work.
146 |
147 | The Corresponding Source need not include anything that users
148 | can regenerate automatically from other parts of the Corresponding
149 | Source.
150 |
151 | The Corresponding Source for a work in source code form is that
152 | same work.
153 |
154 | 2. Basic Permissions.
155 |
156 | All rights granted under this License are granted for the term of
157 | copyright on the Program, and are irrevocable provided the stated
158 | conditions are met. This License explicitly affirms your unlimited
159 | permission to run the unmodified Program. The output from running a
160 | covered work is covered by this License only if the output, given its
161 | content, constitutes a covered work. This License acknowledges your
162 | rights of fair use or other equivalent, as provided by copyright law.
163 |
164 | You may make, run and propagate covered works that you do not
165 | convey, without conditions so long as your license otherwise remains
166 | in force. You may convey covered works to others for the sole purpose
167 | of having them make modifications exclusively for you, or provide you
168 | with facilities for running those works, provided that you comply with
169 | the terms of this License in conveying all material for which you do
170 | not control copyright. Those thus making or running the covered works
171 | for you must do so exclusively on your behalf, under your direction
172 | and control, on terms that prohibit them from making any copies of
173 | your copyrighted material outside their relationship with you.
174 |
175 | Conveying under any other circumstances is permitted solely under
176 | the conditions stated below. Sublicensing is not allowed; section 10
177 | makes it unnecessary.
178 |
179 | 3. Protecting Users' Legal Rights From Anti-Circumvention Law.
180 |
181 | No covered work shall be deemed part of an effective technological
182 | measure under any applicable law fulfilling obligations under article
183 | 11 of the WIPO copyright treaty adopted on 20 December 1996, or
184 | similar laws prohibiting or restricting circumvention of such
185 | measures.
186 |
187 | When you convey a covered work, you waive any legal power to forbid
188 | circumvention of technological measures to the extent such circumvention
189 | is effected by exercising rights under this License with respect to
190 | the covered work, and you disclaim any intention to limit operation or
191 | modification of the work as a means of enforcing, against the work's
192 | users, your or third parties' legal rights to forbid circumvention of
193 | technological measures.
194 |
195 | 4. Conveying Verbatim Copies.
196 |
197 | You may convey verbatim copies of the Program's source code as you
198 | receive it, in any medium, provided that you conspicuously and
199 | appropriately publish on each copy an appropriate copyright notice;
200 | keep intact all notices stating that this License and any
201 | non-permissive terms added in accord with section 7 apply to the code;
202 | keep intact all notices of the absence of any warranty; and give all
203 | recipients a copy of this License along with the Program.
204 |
205 | You may charge any price or no price for each copy that you convey,
206 | and you may offer support or warranty protection for a fee.
207 |
208 | 5. Conveying Modified Source Versions.
209 |
210 | You may convey a work based on the Program, or the modifications to
211 | produce it from the Program, in the form of source code under the
212 | terms of section 4, provided that you also meet all of these conditions:
213 |
214 | a) The work must carry prominent notices stating that you modified
215 | it, and giving a relevant date.
216 |
217 | b) The work must carry prominent notices stating that it is
218 | released under this License and any conditions added under section
219 | 7. This requirement modifies the requirement in section 4 to
220 | "keep intact all notices".
221 |
222 | c) You must license the entire work, as a whole, under this
223 | License to anyone who comes into possession of a copy. This
224 | License will therefore apply, along with any applicable section 7
225 | additional terms, to the whole of the work, and all its parts,
226 | regardless of how they are packaged. This License gives no
227 | permission to license the work in any other way, but it does not
228 | invalidate such permission if you have separately received it.
229 |
230 | d) If the work has interactive user interfaces, each must display
231 | Appropriate Legal Notices; however, if the Program has interactive
232 | interfaces that do not display Appropriate Legal Notices, your
233 | work need not make them do so.
234 |
235 | A compilation of a covered work with other separate and independent
236 | works, which are not by their nature extensions of the covered work,
237 | and which are not combined with it such as to form a larger program,
238 | in or on a volume of a storage or distribution medium, is called an
239 | "aggregate" if the compilation and its resulting copyright are not
240 | used to limit the access or legal rights of the compilation's users
241 | beyond what the individual works permit. Inclusion of a covered work
242 | in an aggregate does not cause this License to apply to the other
243 | parts of the aggregate.
244 |
245 | 6. Conveying Non-Source Forms.
246 |
247 | You may convey a covered work in object code form under the terms
248 | of sections 4 and 5, provided that you also convey the
249 | machine-readable Corresponding Source under the terms of this License,
250 | in one of these ways:
251 |
252 | a) Convey the object code in, or embodied in, a physical product
253 | (including a physical distribution medium), accompanied by the
254 | Corresponding Source fixed on a durable physical medium
255 | customarily used for software interchange.
256 |
257 | b) Convey the object code in, or embodied in, a physical product
258 | (including a physical distribution medium), accompanied by a
259 | written offer, valid for at least three years and valid for as
260 | long as you offer spare parts or customer support for that product
261 | model, to give anyone who possesses the object code either (1) a
262 | copy of the Corresponding Source for all the software in the
263 | product that is covered by this License, on a durable physical
264 | medium customarily used for software interchange, for a price no
265 | more than your reasonable cost of physically performing this
266 | conveying of source, or (2) access to copy the
267 | Corresponding Source from a network server at no charge.
268 |
269 | c) Convey individual copies of the object code with a copy of the
270 | written offer to provide the Corresponding Source. This
271 | alternative is allowed only occasionally and noncommercially, and
272 | only if you received the object code with such an offer, in accord
273 | with subsection 6b.
274 |
275 | d) Convey the object code by offering access from a designated
276 | place (gratis or for a charge), and offer equivalent access to the
277 | Corresponding Source in the same way through the same place at no
278 | further charge. You need not require recipients to copy the
279 | Corresponding Source along with the object code. If the place to
280 | copy the object code is a network server, the Corresponding Source
281 | may be on a different server (operated by you or a third party)
282 | that supports equivalent copying facilities, provided you maintain
283 | clear directions next to the object code saying where to find the
284 | Corresponding Source. Regardless of what server hosts the
285 | Corresponding Source, you remain obligated to ensure that it is
286 | available for as long as needed to satisfy these requirements.
287 |
288 | e) Convey the object code using peer-to-peer transmission, provided
289 | you inform other peers where the object code and Corresponding
290 | Source of the work are being offered to the general public at no
291 | charge under subsection 6d.
292 |
293 | A separable portion of the object code, whose source code is excluded
294 | from the Corresponding Source as a System Library, need not be
295 | included in conveying the object code work.
296 |
297 | A "User Product" is either (1) a "consumer product", which means any
298 | tangible personal property which is normally used for personal, family,
299 | or household purposes, or (2) anything designed or sold for incorporation
300 | into a dwelling. In determining whether a product is a consumer product,
301 | doubtful cases shall be resolved in favor of coverage. For a particular
302 | product received by a particular user, "normally used" refers to a
303 | typical or common use of that class of product, regardless of the status
304 | of the particular user or of the way in which the particular user
305 | actually uses, or expects or is expected to use, the product. A product
306 | is a consumer product regardless of whether the product has substantial
307 | commercial, industrial or non-consumer uses, unless such uses represent
308 | the only significant mode of use of the product.
309 |
310 | "Installation Information" for a User Product means any methods,
311 | procedures, authorization keys, or other information required to install
312 | and execute modified versions of a covered work in that User Product from
313 | a modified version of its Corresponding Source. The information must
314 | suffice to ensure that the continued functioning of the modified object
315 | code is in no case prevented or interfered with solely because
316 | modification has been made.
317 |
318 | If you convey an object code work under this section in, or with, or
319 | specifically for use in, a User Product, and the conveying occurs as
320 | part of a transaction in which the right of possession and use of the
321 | User Product is transferred to the recipient in perpetuity or for a
322 | fixed term (regardless of how the transaction is characterized), the
323 | Corresponding Source conveyed under this section must be accompanied
324 | by the Installation Information. But this requirement does not apply
325 | if neither you nor any third party retains the ability to install
326 | modified object code on the User Product (for example, the work has
327 | been installed in ROM).
328 |
329 | The requirement to provide Installation Information does not include a
330 | requirement to continue to provide support service, warranty, or updates
331 | for a work that has been modified or installed by the recipient, or for
332 | the User Product in which it has been modified or installed. Access to a
333 | network may be denied when the modification itself materially and
334 | adversely affects the operation of the network or violates the rules and
335 | protocols for communication across the network.
336 |
337 | Corresponding Source conveyed, and Installation Information provided,
338 | in accord with this section must be in a format that is publicly
339 | documented (and with an implementation available to the public in
340 | source code form), and must require no special password or key for
341 | unpacking, reading or copying.
342 |
343 | 7. Additional Terms.
344 |
345 | "Additional permissions" are terms that supplement the terms of this
346 | License by making exceptions from one or more of its conditions.
347 | Additional permissions that are applicable to the entire Program shall
348 | be treated as though they were included in this License, to the extent
349 | that they are valid under applicable law. If additional permissions
350 | apply only to part of the Program, that part may be used separately
351 | under those permissions, but the entire Program remains governed by
352 | this License without regard to the additional permissions.
353 |
354 | When you convey a copy of a covered work, you may at your option
355 | remove any additional permissions from that copy, or from any part of
356 | it. (Additional permissions may be written to require their own
357 | removal in certain cases when you modify the work.) You may place
358 | additional permissions on material, added by you to a covered work,
359 | for which you have or can give appropriate copyright permission.
360 |
361 | Notwithstanding any other provision of this License, for material you
362 | add to a covered work, you may (if authorized by the copyright holders of
363 | that material) supplement the terms of this License with terms:
364 |
365 | a) Disclaiming warranty or limiting liability differently from the
366 | terms of sections 15 and 16 of this License; or
367 |
368 | b) Requiring preservation of specified reasonable legal notices or
369 | author attributions in that material or in the Appropriate Legal
370 | Notices displayed by works containing it; or
371 |
372 | c) Prohibiting misrepresentation of the origin of that material, or
373 | requiring that modified versions of such material be marked in
374 | reasonable ways as different from the original version; or
375 |
376 | d) Limiting the use for publicity purposes of names of licensors or
377 | authors of the material; or
378 |
379 | e) Declining to grant rights under trademark law for use of some
380 | trade names, trademarks, or service marks; or
381 |
382 | f) Requiring indemnification of licensors and authors of that
383 | material by anyone who conveys the material (or modified versions of
384 | it) with contractual assumptions of liability to the recipient, for
385 | any liability that these contractual assumptions directly impose on
386 | those licensors and authors.
387 |
388 | All other non-permissive additional terms are considered "further
389 | restrictions" within the meaning of section 10. If the Program as you
390 | received it, or any part of it, contains a notice stating that it is
391 | governed by this License along with a term that is a further
392 | restriction, you may remove that term. If a license document contains
393 | a further restriction but permits relicensing or conveying under this
394 | License, you may add to a covered work material governed by the terms
395 | of that license document, provided that the further restriction does
396 | not survive such relicensing or conveying.
397 |
398 | If you add terms to a covered work in accord with this section, you
399 | must place, in the relevant source files, a statement of the
400 | additional terms that apply to those files, or a notice indicating
401 | where to find the applicable terms.
402 |
403 | Additional terms, permissive or non-permissive, may be stated in the
404 | form of a separately written license, or stated as exceptions;
405 | the above requirements apply either way.
406 |
407 | 8. Termination.
408 |
409 | You may not propagate or modify a covered work except as expressly
410 | provided under this License. Any attempt otherwise to propagate or
411 | modify it is void, and will automatically terminate your rights under
412 | this License (including any patent licenses granted under the third
413 | paragraph of section 11).
414 |
415 | However, if you cease all violation of this License, then your
416 | license from a particular copyright holder is reinstated (a)
417 | provisionally, unless and until the copyright holder explicitly and
418 | finally terminates your license, and (b) permanently, if the copyright
419 | holder fails to notify you of the violation by some reasonable means
420 | prior to 60 days after the cessation.
421 |
422 | Moreover, your license from a particular copyright holder is
423 | reinstated permanently if the copyright holder notifies you of the
424 | violation by some reasonable means, this is the first time you have
425 | received notice of violation of this License (for any work) from that
426 | copyright holder, and you cure the violation prior to 30 days after
427 | your receipt of the notice.
428 |
429 | Termination of your rights under this section does not terminate the
430 | licenses of parties who have received copies or rights from you under
431 | this License. If your rights have been terminated and not permanently
432 | reinstated, you do not qualify to receive new licenses for the same
433 | material under section 10.
434 |
435 | 9. Acceptance Not Required for Having Copies.
436 |
437 | You are not required to accept this License in order to receive or
438 | run a copy of the Program. Ancillary propagation of a covered work
439 | occurring solely as a consequence of using peer-to-peer transmission
440 | to receive a copy likewise does not require acceptance. However,
441 | nothing other than this License grants you permission to propagate or
442 | modify any covered work. These actions infringe copyright if you do
443 | not accept this License. Therefore, by modifying or propagating a
444 | covered work, you indicate your acceptance of this License to do so.
445 |
446 | 10. Automatic Licensing of Downstream Recipients.
447 |
448 | Each time you convey a covered work, the recipient automatically
449 | receives a license from the original licensors, to run, modify and
450 | propagate that work, subject to this License. You are not responsible
451 | for enforcing compliance by third parties with this License.
452 |
453 | An "entity transaction" is a transaction transferring control of an
454 | organization, or substantially all assets of one, or subdividing an
455 | organization, or merging organizations. If propagation of a covered
456 | work results from an entity transaction, each party to that
457 | transaction who receives a copy of the work also receives whatever
458 | licenses to the work the party's predecessor in interest had or could
459 | give under the previous paragraph, plus a right to possession of the
460 | Corresponding Source of the work from the predecessor in interest, if
461 | the predecessor has it or can get it with reasonable efforts.
462 |
463 | You may not impose any further restrictions on the exercise of the
464 | rights granted or affirmed under this License. For example, you may
465 | not impose a license fee, royalty, or other charge for exercise of
466 | rights granted under this License, and you may not initiate litigation
467 | (including a cross-claim or counterclaim in a lawsuit) alleging that
468 | any patent claim is infringed by making, using, selling, offering for
469 | sale, or importing the Program or any portion of it.
470 |
471 | 11. Patents.
472 |
473 | A "contributor" is a copyright holder who authorizes use under this
474 | License of the Program or a work on which the Program is based. The
475 | work thus licensed is called the contributor's "contributor version".
476 |
477 | A contributor's "essential patent claims" are all patent claims
478 | owned or controlled by the contributor, whether already acquired or
479 | hereafter acquired, that would be infringed by some manner, permitted
480 | by this License, of making, using, or selling its contributor version,
481 | but do not include claims that would be infringed only as a
482 | consequence of further modification of the contributor version. For
483 | purposes of this definition, "control" includes the right to grant
484 | patent sublicenses in a manner consistent with the requirements of
485 | this License.
486 |
487 | Each contributor grants you a non-exclusive, worldwide, royalty-free
488 | patent license under the contributor's essential patent claims, to
489 | make, use, sell, offer for sale, import and otherwise run, modify and
490 | propagate the contents of its contributor version.
491 |
492 | In the following three paragraphs, a "patent license" is any express
493 | agreement or commitment, however denominated, not to enforce a patent
494 | (such as an express permission to practice a patent or covenant not to
495 | sue for patent infringement). To "grant" such a patent license to a
496 | party means to make such an agreement or commitment not to enforce a
497 | patent against the party.
498 |
499 | If you convey a covered work, knowingly relying on a patent license,
500 | and the Corresponding Source of the work is not available for anyone
501 | to copy, free of charge and under the terms of this License, through a
502 | publicly available network server or other readily accessible means,
503 | then you must either (1) cause the Corresponding Source to be so
504 | available, or (2) arrange to deprive yourself of the benefit of the
505 | patent license for this particular work, or (3) arrange, in a manner
506 | consistent with the requirements of this License, to extend the patent
507 | license to downstream recipients. "Knowingly relying" means you have
508 | actual knowledge that, but for the patent license, your conveying the
509 | covered work in a country, or your recipient's use of the covered work
510 | in a country, would infringe one or more identifiable patents in that
511 | country that you have reason to believe are valid.
512 |
513 | If, pursuant to or in connection with a single transaction or
514 | arrangement, you convey, or propagate by procuring conveyance of, a
515 | covered work, and grant a patent license to some of the parties
516 | receiving the covered work authorizing them to use, propagate, modify
517 | or convey a specific copy of the covered work, then the patent license
518 | you grant is automatically extended to all recipients of the covered
519 | work and works based on it.
520 |
521 | A patent license is "discriminatory" if it does not include within
522 | the scope of its coverage, prohibits the exercise of, or is
523 | conditioned on the non-exercise of one or more of the rights that are
524 | specifically granted under this License. You may not convey a covered
525 | work if you are a party to an arrangement with a third party that is
526 | in the business of distributing software, under which you make payment
527 | to the third party based on the extent of your activity of conveying
528 | the work, and under which the third party grants, to any of the
529 | parties who would receive the covered work from you, a discriminatory
530 | patent license (a) in connection with copies of the covered work
531 | conveyed by you (or copies made from those copies), or (b) primarily
532 | for and in connection with specific products or compilations that
533 | contain the covered work, unless you entered into that arrangement,
534 | or that patent license was granted, prior to 28 March 2007.
535 |
536 | Nothing in this License shall be construed as excluding or limiting
537 | any implied license or other defenses to infringement that may
538 | otherwise be available to you under applicable patent law.
539 |
540 | 12. No Surrender of Others' Freedom.
541 |
542 | If conditions are imposed on you (whether by court order, agreement or
543 | otherwise) that contradict the conditions of this License, they do not
544 | excuse you from the conditions of this License. If you cannot convey a
545 | covered work so as to satisfy simultaneously your obligations under this
546 | License and any other pertinent obligations, then as a consequence you may
547 | not convey it at all. For example, if you agree to terms that obligate you
548 | to collect a royalty for further conveying from those to whom you convey
549 | the Program, the only way you could satisfy both those terms and this
550 | License would be to refrain entirely from conveying the Program.
551 |
552 | 13. Use with the GNU Affero General Public License.
553 |
554 | Notwithstanding any other provision of this License, you have
555 | permission to link or combine any covered work with a work licensed
556 | under version 3 of the GNU Affero General Public License into a single
557 | combined work, and to convey the resulting work. The terms of this
558 | License will continue to apply to the part which is the covered work,
559 | but the special requirements of the GNU Affero General Public License,
560 | section 13, concerning interaction through a network will apply to the
561 | combination as such.
562 |
563 | 14. Revised Versions of this License.
564 |
565 | The Free Software Foundation may publish revised and/or new versions of
566 | the GNU General Public License from time to time. Such new versions will
567 | be similar in spirit to the present version, but may differ in detail to
568 | address new problems or concerns.
569 |
570 | Each version is given a distinguishing version number. If the
571 | Program specifies that a certain numbered version of the GNU General
572 | Public License "or any later version" applies to it, you have the
573 | option of following the terms and conditions either of that numbered
574 | version or of any later version published by the Free Software
575 | Foundation. If the Program does not specify a version number of the
576 | GNU General Public License, you may choose any version ever published
577 | by the Free Software Foundation.
578 |
579 | If the Program specifies that a proxy can decide which future
580 | versions of the GNU General Public License can be used, that proxy's
581 | public statement of acceptance of a version permanently authorizes you
582 | to choose that version for the Program.
583 |
584 | Later license versions may give you additional or different
585 | permissions. However, no additional obligations are imposed on any
586 | author or copyright holder as a result of your choosing to follow a
587 | later version.
588 |
589 | 15. Disclaimer of Warranty.
590 |
591 | THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
592 | APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
593 | HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
594 | OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
595 | THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
596 | PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
597 | IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
598 | ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
599 |
600 | 16. Limitation of Liability.
601 |
602 | IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
603 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
604 | THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
605 | GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
606 | USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
607 | DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
608 | PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
609 | EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
610 | SUCH DAMAGES.
611 |
612 | 17. Interpretation of Sections 15 and 16.
613 |
614 | If the disclaimer of warranty and limitation of liability provided
615 | above cannot be given local legal effect according to their terms,
616 | reviewing courts shall apply local law that most closely approximates
617 | an absolute waiver of all civil liability in connection with the
618 | Program, unless a warranty or assumption of liability accompanies a
619 | copy of the Program in return for a fee.
620 |
621 | END OF TERMS AND CONDITIONS
622 |
623 | How to Apply These Terms to Your New Programs
624 |
625 | If you develop a new program, and you want it to be of the greatest
626 | possible use to the public, the best way to achieve this is to make it
627 | free software which everyone can redistribute and change under these terms.
628 |
629 | To do so, attach the following notices to the program. It is safest
630 | to attach them to the start of each source file to most effectively
631 | state the exclusion of warranty; and each file should have at least
632 | the "copyright" line and a pointer to where the full notice is found.
633 |
634 |
635 | Copyright (C)
636 |
637 | This program is free software: you can redistribute it and/or modify
638 | it under the terms of the GNU General Public License as published by
639 | the Free Software Foundation, either version 3 of the License, or
640 | (at your option) any later version.
641 |
642 | This program is distributed in the hope that it will be useful,
643 | but WITHOUT ANY WARRANTY; without even the implied warranty of
644 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
645 | GNU General Public License for more details.
646 |
647 | You should have received a copy of the GNU General Public License
648 | along with this program. If not, see .
649 |
650 | Also add information on how to contact you by electronic and paper mail.
651 |
652 | If the program does terminal interaction, make it output a short
653 | notice like this when it starts in an interactive mode:
654 |
655 | Copyright (C)
656 | This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
657 | This is free software, and you are welcome to redistribute it
658 | under certain conditions; type `show c' for details.
659 |
660 | The hypothetical commands `show w' and `show c' should show the appropriate
661 | parts of the General Public License. Of course, your program's commands
662 | might be different; for a GUI interface, you would use an "about box".
663 |
664 | You should also get your employer (if you work as a programmer) or school,
665 | if any, to sign a "copyright disclaimer" for the program, if necessary.
666 | For more information on this, and how to apply and follow the GNU GPL, see
667 | .
668 |
669 | The GNU General Public License does not permit incorporating your program
670 | into proprietary programs. If your program is a subroutine library, you
671 | may consider it more useful to permit linking proprietary applications with
672 | the library. If this is what you want to do, use the GNU Lesser General
673 | Public License instead of this License. But first, please read
674 | .
675 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## Dependency-Guided LSTM-CRF Model for Named Entity Recognition
2 |
3 | Codebase for the paper "[Dependency-Guided LSTM-CRF for Named Entity Recognition](https://www.aclweb.org/anthology/D19-1399.pdf)" in EMNLP 2019.
4 | Following the usage instructions below should reproduce results close to those reported in the paper.
5 |
6 | ### Requirements
7 | * PyTorch 1.1 (Also tested on PyTorch 1.3)
8 | * Python 3.6
9 |
10 | ### Dataset Format
11 |
12 | I have uploaded the preprocessed `Catalan` and `Spanish` datasets. (Please contact me with your license if you need the preprocessed OntoNotes dataset.)
13 | If you have a new dataset, please make sure it follows the CoNLL-X format, with the entity label in the last column.
14 | The sentence below is an example.
15 | Note that we only use the columns for *word*, *dependency head index*, *dependency relation label* and the last *entity label*.
16 | ```
17 | 1 Brasil _ n n _ 2 suj _ _ B-org
18 | 2 buscará _ v v _ 0 root _ _ O
19 | 3 a_partir_de _ s s _ 2 cc _ _ O
20 | 4 mañana _ n n _ 3 sn _ _ O
21 | 5 , _ f f _ 6 f _ _ B-misc
22 | 6 viernes _ w w _ 4 sn _ _ I-misc
23 | 7 , _ f f _ 6 f _ _ I-misc
24 | 8 el _ d d _ 9 spec _ _ O
25 | 9 pase _ n n _ 2 cd _ _ O
26 | ```
27 | Entity labels follow the `IOB` tagging scheme and will be converted to `IOBES` in this codebase.
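The IOB-to-IOBES conversion can be sketched as below (a minimal standalone version for illustration; the codebase performs the equivalent conversion internally):

```python
def iob_to_iobes(labels):
    """Convert one sentence's IOB entity labels to the IOBES scheme."""
    out = []
    for i, cur in enumerate(labels):
        nxt = labels[i + 1] if i + 1 < len(labels) else "O"
        if cur.startswith("B-"):
            # A B- tag not continued by I- is a single-token entity.
            out.append(cur.replace("B-", "S-", 1) if not nxt.startswith("I-") else cur)
        elif cur.startswith("I-"):
            # An I- tag not continued by I- ends the entity.
            out.append(cur.replace("I-", "E-", 1) if not nxt.startswith("I-") else cur)
        else:
            out.append(cur)
    return out

print(iob_to_iobes(["B-org", "O", "B-misc", "I-misc", "I-misc"]))
# → ['S-org', 'O', 'B-misc', 'I-misc', 'E-misc']
```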
28 |
29 | ### Usage
30 |
31 | Baseline **BiLSTM-CRF**:
32 | ```bash
33 | python main.py --dataset ontonotes --embedding_file data/glove.6B.100d.txt \
34 | --num_lstm_layer 1 --dep_model none
35 | ```
36 | Change `embedding_file` if you are using other languages, change `dataset` for other datasets, and change `num_lstm_layer` for different `L = 0,1,2,3`. Use `--device cuda:0` if you are using a GPU.
37 |
38 | **DGLSTM-CRF**
39 | ```bash
40 | python main.py --dataset ontonotes --embedding_file data/glove.6B.100d.txt \
41 | --num_lstm_layer 1 --dep_model dglstm --inter_func mlp
42 | ```
43 | Set the interaction function with `--inter_func` (`concatenation`, `addition`, or `mlp`) to try other interactions.
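Roughly, the three interaction functions combine a token representation `h` with its dependency-head representation `x`. The sketch below is illustrative only (the function name, shapes, and the single ReLU layer for `mlp` are assumptions, not the exact model code):

```python
import numpy as np

def interact(h, x, mode, W=None, b=None):
    """Illustrative sketch of inter_func: combine token h with head x."""
    if mode == "concatenation":
        return np.concatenate([h, x])        # output dim doubles
    if mode == "addition":
        return h + x                         # output dim preserved
    if mode == "mlp":
        z = np.concatenate([h, x])
        return np.maximum(0.0, W @ z + b)    # one nonlinear layer over [h; x]
    raise ValueError(mode)

h, x = np.ones(4), np.zeros(4)
W, b = np.ones((4, 8)), np.zeros(4)
print(interact(h, x, "concatenation").shape)  # (8,)
print(interact(h, x, "mlp", W, b).shape)      # (4,)
```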
44 |
45 |
46 | ### Usage for other datasets and other languages
47 | Remember to put the dataset under the data folder. The naming rule for `train/dev/test` is `train.sd.conllx`, `dev.sd.conllx` and `test.sd.conllx`.
48 | Then simply change the `--dataset` name and `--embedding_file`.
49 |
50 | Dataset | Embedding
51 | ------------ | -------------
52 | OntoNotes English | glove.6B.100d.txt
53 | OntoNotes Chinese | cc.zh.300.vec (FastText)
54 | Catalan | cc.ca.300.vec (FastText)
55 | Spanish | cc.es.300.vec (FastText)
56 |
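For instance, preparing a hypothetical new dataset named `mylang` would look like this (the dataset name and embedding file below are placeholders; only the folder layout and file names are fixed):

```shell
# Create the expected folder/file layout for a hypothetical dataset "mylang"
mkdir -p data/mylang
touch data/mylang/train.sd.conllx data/mylang/dev.sd.conllx data/mylang/test.sd.conllx
ls data/mylang
# Then train with, e.g.:
# python main.py --dataset mylang --embedding_file data/cc.mylang.300.vec \
#     --num_lstm_layer 1 --dep_model dglstm --inter_func mlp
```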
57 |
58 |
59 | ### Using ELMo
60 | Before running with ELMo, we first need the pretrained ELMo vector files ready.
61 | For example, download the `Catalan ELMo` vectors from [here](https://drive.google.com/open?id=1bGCRy4pYDWBcEae5sTSIcdu6PwWgz7Kn) and decompress all the files (`train.conllx.elmo.vec`, `dev.conllx.elmo.vec`, `test.conllx.elmo.vec`) into `data/catalan/`.
62 | We can then simply run the command below (taking **DGLSTM-CRF** as an example):
63 | ```bash
64 | python main.py --dataset ontonotes --embedding_file data/glove.6B.100d.txt \
65 | --num_lstm_layer 1 --dep_model dglstm --inter_func mlp \
66 | --context_emb elmo
67 | ```
68 | ### Obtain ELMo vectors for other languages:
69 | We use the ELMo from AllenNLP for English, and use [ELMoForManyLangs](https://github.com/HIT-SCIR/ELMoForManyLangs) for other languages.
70 | * For English, run `preprocess/preelmo.py` (remember to change the `dataset` name):
71 | ```bash
72 | python preprocess/preelmo.py
73 | ```
74 | * For Chinese, Catalan, and Spanish:
75 | Download the ELMo models from [ELMoForManyLangs](https://github.com/HIT-SCIR/ELMoForManyLangs). NOTE: remember to follow the instructions to slightly modify some paths inside.
76 | Then run `preprocess/elmo_others.py` (again, remember to change the `dataset` name and the ELMo model path):
77 | ```bash
78 | python preprocess/elmo_others.py
79 | ```
80 |
81 |
82 | ### Notes on Dataset Preprocessing (Two Options)
83 |
84 | #### OntoNotes Preprocessing
85 | Many people have asked for the OntoNotes 5.0 dataset.
86 | I understand that it is hard to get the correct split as in previous work (Chiu and Nichols, 2016; Li et al., 2017; Ghaddar and Langlais, 2018).
87 | If you want to get the correct split, you can refer to the guide [here](https://github.com/allanj/pytorch_lstmcrf/blob/master/docs/benchmark.md) where
88 | I summarize how to preprocess the OntoNotes dataset.
89 |
90 | #### Download Our Preprocessed dataset
91 | We notice that the OntoNotes 5.0 dataset is now freely available on LDC. We also release links to our pre-processed OntoNotes here ([__English__](https://drive.google.com/file/d/1AAWnb5GlDiNMj3yNoaoQtoKHj7iSqNey/view?usp=sharing), [__Chinese__](https://drive.google.com/file/d/10t3XpZzsD67ji0a7sw9nHM7I5UhrJcdf/view?usp=sharing)).
92 |
93 | ### Citation
94 | ```
95 | @InProceedings{jie2019dependency,
96 | author = "Jie, Zhanming and Lu, Wei",
97 | title = "Dependency-Guided LSTM-CRF for Named Entity Recognition",
98 | booktitle = "Proceedings of EMNLP",
99 | year = "2019",
100 | url = "https://www.aclweb.org/anthology/D19-1399",
101 | doi = "10.18653/v1/D19-1399",
102 | pages = "3860--3870"
103 | }
104 | ```
--------------------------------------------------------------------------------
/analysis/__init__.py:
--------------------------------------------------------------------------------
1 |
2 | ### NOTE: the code in this folder is only used to analyze the results and data statistics
3 | ##
--------------------------------------------------------------------------------
/analysis/analyze_chinese.py:
--------------------------------------------------------------------------------
1 | #
2 | # @author: Allan
3 | #
4 |
5 |
6 | from tqdm import tqdm
7 | from common.sentence import Sentence
8 | from common.instance import Instance
9 | from typing import List
10 | from config.eval import evaluate, Span
11 | import random
12 |
13 | def get_spans(output):
14 | output_spans = set()
15 | start = -1
16 | for i in range(len(output)):
17 | if output[i].startswith("B-"):
18 | start = i
19 | if output[i].startswith("E-"):
20 | end = i
21 | output_spans.add(Span(start, end, output[i][2:]))
22 | if output[i].startswith("S-"):
23 | output_spans.add(Span(i, i, output[i][2:]))
24 | return output_spans
25 |
26 | def read_conll(res_file: str, number: int = -1) -> List[Instance]:
27 | print("Reading file: " + res_file)
28 | insts = []
29 | # vocab = set() ## build the vocabulary
30 | with open(res_file, 'r', encoding='utf-8') as f:
31 | words = []
32 | heads = []
33 | deps = []
34 | labels = []
35 | tags = []
36 | preds = []
37 | for line in tqdm(f.readlines()):
38 | line = line.rstrip()
39 | if line == "":
40 | inst = Instance(Sentence(words, heads, deps, tags), labels)
41 | inst.prediction = preds
42 | insts.append(inst)
43 | words = []
44 | heads = []
45 | deps = []
46 | labels = []
47 | tags = []
48 | preds = []
49 |
50 | if len(insts) == number:
51 | break
52 | continue
53 | vals = line.split()
54 | word = vals[1]
55 | pos = vals[2]
56 | head = int(vals[3])
57 | dep_label = vals[4]
58 |
59 | label = vals[5]
60 | pred_label = vals[6]
61 |
62 | words.append(word)
63 | heads.append(head) ## heads in the results file are already 0-indexed (-1 for root)
64 | deps.append(dep_label)
65 | tags.append(pos)
66 | labels.append(label)
67 | preds.append(pred_label)
68 | print("number of sentences: {}".format(len(insts)))
69 | return insts
70 |
71 | res1 = "../final_results/lstm_3_200_crf_ontonotes_sd_-1_dep_feat_emb_elmo_elmo_sgd_gate_0_base_-1_epoch_200_lr_0.01.results"
72 | insts1 = read_conll(res1)
73 |
74 | res2 = "../final_results/lstm_2_200_crf_ontonotes_sd_-1_dep_feat_emb_elmo_none_sgd_gate_0_base_-1_epoch_150_lr_0.01.results"
75 | insts2 = read_conll(res2)
76 |
77 | print(evaluate(insts1))
78 | print(evaluate(insts2))
79 | num = 0
80 | total_entity = 0
81 | type2num = {}
82 | length2num = {}
83 | dep_label2num = {}
84 | gc2num = {}
85 | for i in range(len(insts1)):
86 |
87 | first = insts1[i]
88 | second = insts2[i]
89 | gold_spans = get_spans(first.output)
90 |
91 | pred_first = get_spans(first.prediction)
92 | pred_second = get_spans(second.prediction)
93 |
94 |
95 | # for span in pred_first:
96 | # if span in gold_spans and (span not in pred_second):
97 | for span in gold_spans:
98 | if span in pred_first and (span not in pred_second):
99 | num += 1
100 | print(span.to_str(first.input.words))
101 | if span.type in type2num:
102 | type2num[span.type] +=1
103 | else:
104 | type2num[span.type] = 1
105 | length = span.right - span.left + 1
106 | if length in length2num:
107 | length2num[length] += 1
108 | else:
109 | length2num[length] = 1
110 |
111 | for k in range(span.left, span.right + 1):
112 | if first.input.heads[k] == -1 or (first.input.heads[k] > span.right or first.input.heads[k] < span.left):
113 | if first.input.dep_labels[k] in dep_label2num:
114 | dep_label2num[first.input.dep_labels[k]] +=1
115 | else:
116 | dep_label2num[first.input.dep_labels[k]] = 1
117 |
118 | if first.input.heads[k]!= -1 and first.input.heads[first.input.heads[k]] != -1:
119 | h = first.input.heads[first.input.heads[k]]
120 | if first.input.dep_labels[k] + "," + first.input.dep_labels[h] in gc2num:
121 | gc2num[first.input.dep_labels[k] + "," + first.input.dep_labels[h]] += 1
122 | else:
123 | gc2num[first.input.dep_labels[k] + "," + first.input.dep_labels[h]] = 1
124 |
125 | total_entity +=1
126 |
127 | print(num, total_entity)
128 | print(type2num)
129 | print("length 2 number: {}".format(length2num))
130 |
131 | print()
132 | print("dependency label 2 num: {}".format(dep_label2num))
133 | total_amount = sum([dep_label2num[key] for key in dep_label2num])
134 | print("total number of dep 2 num: {}".format(total_amount))
135 | print()
136 |
137 | counts = [(key, dep_label2num[key]) for key in dep_label2num]
138 | counts = sorted(counts, key=lambda vals: vals[1], reverse=True)
139 | print(counts)
140 | print()
141 |
142 | print(gc2num)
143 | counts = [(key, gc2num[key]) for key in gc2num]
144 | counts = sorted(counts, key=lambda vals: vals[1], reverse=True)
145 | print(counts)
146 |
--------------------------------------------------------------------------------
/analysis/analyzer.py:
--------------------------------------------------------------------------------
1 | from tqdm import tqdm
2 | from common.sentence import Sentence
3 | from common.instance import Instance
4 | from typing import List
5 |
6 |
7 | def read_conll(res_file: str, number: int = -1) -> List[Instance]:
8 | print("Reading file: " + res_file)
9 | insts = []
10 | # vocab = set() ## build the vocabulary
11 | with open(res_file, 'r', encoding='utf-8') as f:
12 | words = []
13 | heads = []
14 | deps = []
15 | labels = []
16 | tags = []
17 | preds = []
18 | for line in tqdm(f.readlines()):
19 | line = line.rstrip()
20 | if line == "":
21 | inst = Instance(Sentence(words, heads, deps, tags), labels)
22 | inst.prediction = preds
23 | insts.append(inst)
24 | words = []
25 | heads = []
26 | deps = []
27 | labels = []
28 | tags = []
29 | preds = []
30 |
31 | if len(insts) == number:
32 | break
33 | continue
34 | vals = line.split()
35 | word = vals[1]
36 | pos = vals[2]
37 | head = int(vals[3])
38 | dep_label = vals[4]
39 |
40 | label = vals[5]
41 | pred_label = vals[6]
42 |
43 | words.append(word)
44 | heads.append(head) ## heads in the results file are already 0-indexed (-1 for root)
45 | deps.append(dep_label)
46 | tags.append(pos)
47 | labels.append(label)
48 | preds.append(pred_label)
49 | print("number of sentences: {}".format(len(insts)))
50 | return insts
51 |
52 | res_file = "../results/lstm_200_crf_conll2003_-1_dep_none_elmo_1_sgd_gate_0.results"
53 | insts = read_conll(res_file)
54 |
55 | total = 0
56 | total_word = 0
57 | for inst in insts:
58 | gold = inst.output
59 | prediction = inst.prediction
60 | words = inst.input.words
61 | heads = inst.input.heads
62 | dep_labels = inst.input.dep_labels
63 | have_error= False
64 | for idx in range(len(gold)):
65 | if gold[idx] != 'O' and prediction[idx] == 'O':
66 | have_error = True
67 | total_word += 1
68 | print("{}\t{}\t{}\t{}\t{}\t{}\t".format(idx, words[idx], heads[idx]+1, dep_labels[idx], gold[idx], prediction[idx]))
69 | if have_error:
70 | print(words)
71 | print(gold)
72 | print(prediction)
73 | total +=1
74 | print()
75 | print("number of sentences with errors: {}".format(total))
76 | print("number of words with errors: {}".format(total_word))
77 |
78 |
--------------------------------------------------------------------------------
/analysis/comparator.py:
--------------------------------------------------------------------------------
1 | #
2 | # @author: Allan
3 | #
4 | from tqdm import tqdm
5 | from common.sentence import Sentence
6 | from common.instance import Instance
7 | from typing import List
8 |
9 |
10 |
11 | def read_conll(res_file: str, number: int = -1) -> List[Instance]:
12 | print("Reading file: " + res_file)
13 | insts = []
14 | # vocab = set() ## build the vocabulary
15 | with open(res_file, 'r', encoding='utf-8') as f:
16 | words = []
17 | heads = []
18 | deps = []
19 | labels = []
20 | tags = []
21 | preds = []
22 | for line in tqdm(f.readlines()):
23 | line = line.rstrip()
24 | if line == "":
25 | inst = Instance(Sentence(words, heads, deps, tags), labels)
26 | inst.prediction = preds
27 | insts.append(inst)
28 | words = []
29 | heads = []
30 | deps = []
31 | labels = []
32 | tags = []
33 | preds = []
34 |
35 | if len(insts) == number:
36 | break
37 | continue
38 | vals = line.split()
39 | word = vals[1]
40 | pos = vals[2]
41 | head = int(vals[3])
42 | dep_label = vals[4]
43 |
44 | label = vals[5]
45 | pred_label = vals[6]
46 |
47 | words.append(word)
48 | heads.append(head) ## heads in the results file are already 0-indexed (-1 for root)
49 | deps.append(dep_label)
50 | tags.append(pos)
51 | labels.append(label)
52 | preds.append(pred_label)
53 | print("number of sentences: {}".format(len(insts)))
54 | return insts
55 |
56 |
57 |
58 | lgcn_file = "../final_results/lstm_200_crf_ontonotes_sd_-1_dep_lstm_lgcn_elmo_elmo_sgd_gate_0_epoch_100_lr_0.01.results"
59 | elmo_file = "../final_results/lstm_200_crf_ontonotes_.sd_-1_dep_none_elmo_elmo_sgd_gate_0_epoch_100_lr_0.01.results"
60 | lgcn_res = read_conll(lgcn_file)
61 | elmo_res = read_conll(elmo_file)
62 |
63 |
64 |
65 |
66 | total = 0
67 | total_word = 0
68 | for dep_inst, inst in zip(lgcn_res, elmo_res):
69 | gold = inst.output
70 | normal_pred = inst.prediction
71 | dep_pred = dep_inst.prediction
72 | words = inst.input.words
73 | heads = inst.input.heads
74 | dep_labels = inst.input.dep_labels
75 | have_error= False
76 | for idx in range(len(gold)):
77 | if normal_pred[idx] != dep_pred[idx]:
78 | if gold[idx] == dep_pred[idx]:
79 | print("{}\t{}\t{}\t{}\t{}\t{}\t".format(idx, words[idx], heads[idx] + 1, dep_labels[idx], gold[idx], normal_pred[idx]))
80 | print("")
81 |
--------------------------------------------------------------------------------
/analysis/compare_heatmap.py:
--------------------------------------------------------------------------------
1 | from config.reader import Reader
2 |
3 | from common.sentence import Sentence
4 | from common.instance import Instance
5 | from typing import List
6 | from tqdm import tqdm
7 | import numpy as np
8 |
9 | import seaborn as sns; sns.set(font_scale=0.8)
10 | import matplotlib.pyplot as plt
11 | import random
12 |
13 |
14 | def read_results(res_file: str, number: int = -1) -> List[Instance]:
15 | print("Reading file: " + res_file)
16 | insts = []
17 | # vocab = set() ## build the vocabulary
18 | with open(res_file, 'r', encoding='utf-8') as f:
19 | words = []
20 | heads = []
21 | deps = []
22 | labels = []
23 | tags = []
24 | preds = []
25 | for line in tqdm(f.readlines()):
26 | line = line.rstrip()
27 | if line == "":
28 | inst = Instance(Sentence(words, heads, deps, tags), labels)
29 | inst.prediction = preds
30 | insts.append(inst)
31 | words = []
32 | heads = []
33 | deps = []
34 | labels = []
35 | tags = []
36 | preds = []
37 |
38 | if len(insts) == number:
39 | break
40 | continue
41 | vals = line.split()
42 | word = vals[1]
43 | pos = vals[2]
44 | head = int(vals[3])
45 | dep_label = vals[4]
46 |
47 | label = vals[5]
48 | pred_label = vals[6]
49 |
50 | words.append(word)
51 | heads.append(head) ## heads in the results file are already 0-indexed (-1 for root)
52 | deps.append(dep_label)
53 | tags.append(pos)
54 | labels.append(label)
55 | preds.append(pred_label)
56 | print("number of sentences: {}".format(len(insts)))
57 | return insts
58 |
59 |
60 | # file = "data/ontonotes/test.sd.conllx"
61 | # digit2zero = False
62 | # reader = Reader(digit2zero)
63 | #
64 | # insts = reader.read_conll(file, -1, True)
65 |
66 | file = "final_results/lstm_2_200_crf_ontonotes_sd_-1_dep_feat_emb_elmo_none_sgd_gate_0_base_-1_epoch_150_lr_0.01.results"
67 | insts = read_results(file) ##change inst.output -> inst.prediction
68 |
69 | comp_file = "final_results/lstm_2_200_crf_ontonotes_sd_-1_dep_none_elmo_none_sgd_gate_0_base_-1_epoch_100_lr_0.01.results"
70 | comp_insts = read_results(comp_file) ##change inst.output -> inst.prediction
71 |
72 | entities = set([ label[2:] for inst in insts for label in inst.output if len(label)>1])
73 | print(entities)
74 | dep_labels = set([ dep for inst in insts for label, dep in zip(inst.prediction, inst.input.dep_labels) if len(label)>1] )
75 | print(len(dep_labels), dep_labels)
76 |
77 | ### add grandchild relation as well.
78 | for inst in insts:
79 | for head, dep in zip(inst.input.heads, inst.input.dep_labels):
80 | if head == -1:
81 | continue
82 | dep_labels.add(dep+", " + inst.input.dep_labels[head])
83 | print(len(dep_labels), dep_labels)
84 | ###
85 |
86 |
87 | ent2idx = {}
88 | ents = list(entities)
89 | ents.sort()
90 | for i, label in enumerate(ents):
91 | ent2idx[label] = i
92 |
93 |
94 | dep2idx = {}
95 | deps = list(dep_labels)
96 | deps.sort()
97 | for i, label in enumerate(deps):
98 | dep2idx[label] = i
99 |
100 | ent_dep_mat = np.zeros((len(entities), len(dep_labels)))
101 | print(ent_dep_mat.shape)
102 | for inst, comp_inst in zip(insts,comp_insts):
103 | for label, dep, gold, comp_label, head in zip(inst.prediction, inst.input.dep_labels, inst.output, comp_inst.prediction, inst.input.heads):
104 | if gold == "O":
105 | continue
106 | if label == "O":
107 | continue
108 | if label == gold and label != comp_label:
109 | ent_dep_mat[ent2idx[label[2:]]][dep2idx[dep]] += 1
110 | if head != -1:
111 | if inst.output[head] == inst.prediction[head] and inst.output[head] != comp_inst.prediction[head]:
112 | ent_dep_mat[ent2idx[label[2:]]][dep2idx[dep+", " + inst.input.dep_labels[head]]] += 1
113 |
114 | sum_labels = [ sum(ent_dep_mat[i]) for i in range(ent_dep_mat.shape[0])]
115 | ent_dep_mat = np.stack([ (ent_dep_mat[i]/sum_labels[i]) * 100 for i in range(ent_dep_mat.shape[0])], axis=0)
116 | print(ent_dep_mat.shape)
117 |
118 | indexs = [i for i in range(ent_dep_mat.shape[1]) if len(ent_dep_mat[:,i][ ent_dep_mat[:,i] >5.0 ]) ]
119 | print(np.asarray(deps)[indexs])
120 |
121 | xlabels = [deps[i] for i in indexs]
122 | # cmap = sns.light_palette("#2ecc71", as_cmap=True)
123 | # cmap = sns.light_palette("#8e44ad", as_cmap=True)
124 | cmap = sns.cubehelix_palette(8,as_cmap=True)
125 | ax = sns.heatmap(ent_dep_mat[:, indexs], annot=True, vmin=0, vmax=100, cmap=cmap,fmt='.0f', xticklabels=xlabels, yticklabels=ents, cbar=True)
126 | # ,annot_kws = {"size": 10})
127 | # , cbar_kws={'label': 'percentage (%)'})
128 | plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
129 | plt.show()
--------------------------------------------------------------------------------
/analysis/evaluate_result.py:
--------------------------------------------------------------------------------
1 |
2 | from common.sentence import Sentence
3 | from common.instance import Instance
4 | from typing import List
5 | from config.eval import Span
6 | from tqdm import tqdm
7 |
8 | from config.eval import evaluate
9 |
10 | def read_conll(res_file: str, number: int = -1) -> List[Instance]:
11 | print("Reading file: " + res_file)
12 | insts = []
13 | # vocab = set() ## build the vocabulary
14 | with open(res_file, 'r', encoding='utf-8') as f:
15 | words = []
16 | heads = []
17 | deps = []
18 | labels = []
19 | tags = []
20 | preds = []
21 | for line in tqdm(f.readlines()):
22 | line = line.rstrip()
23 | if line == "":
24 | inst = Instance(Sentence(words, heads, deps, tags), labels)
25 | inst.prediction = preds
26 | insts.append(inst)
27 | words = []
28 | heads = []
29 | deps = []
30 | labels = []
31 | tags = []
32 | preds = []
33 |
34 | if len(insts) == number:
35 | break
36 | continue
37 | vals = line.split()
38 | word = vals[1]
39 | pos = vals[2]
40 | head = int(vals[3])
41 | dep_label = vals[4]
42 |
43 | label = vals[5]
44 | pred_label = vals[6]
45 |
46 | words.append(word)
47 | heads.append(head) ## heads in the results file are already 0-indexed (-1 for root)
48 | deps.append(dep_label)
49 | tags.append(pos)
50 | labels.append(label)
51 | preds.append(pred_label)
52 | print("number of sentences: {}".format(len(insts)))
53 | return insts
54 |
55 |
56 |
57 | res1 = "./final_results/lstm_2_200_crf_semes_sd_-1_dep_feat_emb_elmo_elmo_sgd_gate_0_base_-1_epoch_300_lr_0.01_doubledep_0_comb_3.results"
58 | insts1 = read_conll(res1)
59 |
60 |
61 | print(evaluate(insts1))
--------------------------------------------------------------------------------
/analysis/heatmap.py:
--------------------------------------------------------------------------------
1 | from config.reader import Reader
2 |
3 | from common.sentence import Sentence
4 | from common.instance import Instance
5 | from typing import List
6 | from tqdm import tqdm
7 | import numpy as np
8 |
9 | import seaborn as sns; sns.set(font_scale=0.8)
10 | import matplotlib.pyplot as plt
11 | import random
12 |
13 |
14 | def read_results(res_file: str, number: int = -1) -> List[Instance]:
15 | print("Reading file: " + res_file)
16 | insts = []
17 | # vocab = set() ## build the vocabulary
18 | with open(res_file, 'r', encoding='utf-8') as f:
19 | words = []
20 | heads = []
21 | deps = []
22 | labels = []
23 | tags = []
24 | preds = []
25 | for line in tqdm(f.readlines()):
26 | line = line.rstrip()
27 | if line == "":
28 | inst = Instance(Sentence(words, heads, deps, tags), labels)
29 | inst.prediction = preds
30 | insts.append(inst)
31 | words = []
32 | heads = []
33 | deps = []
34 | labels = []
35 | tags = []
36 | preds = []
37 |
38 | if len(insts) == number:
39 | break
40 | continue
41 | vals = line.split()
42 | word = vals[1]
43 | pos = vals[2]
44 | head = int(vals[3])
45 | dep_label = vals[4]
46 |
47 | label = vals[5]
48 | pred_label = vals[6]
49 |
50 | words.append(word)
51 | heads.append(head) ## heads in the results file are already 0-indexed (-1 for root)
52 | deps.append(dep_label)
53 | tags.append(pos)
54 | labels.append(label)
55 | preds.append(pred_label)
56 | print("number of sentences: {}".format(len(insts)))
57 | return insts
58 |
59 |
60 | # file = "data/ontonotes/test.sd.conllx"
61 | # digit2zero = False
62 | # reader = Reader(digit2zero)
63 | #
64 | # insts = reader.read_conll(file, -1, True)
65 |
66 | file = "final_results/lstm_2_200_crf_ontonotes_sd_-1_dep_feat_emb_elmo_none_sgd_gate_0_base_-1_epoch_150_lr_0.01.results"
67 | insts = read_results(file) ##change inst.output -> inst.prediction
68 |
69 | entities = set([ label[2:] for inst in insts for label in inst.output if len(label)>1])
70 | print(entities)
71 | dep_labels = set([ dep for inst in insts for label, dep in zip(inst.prediction, inst.input.dep_labels) if len(label)>1] )
72 | print(len(dep_labels), dep_labels)
73 |
74 | ent2idx = {}
75 | ents = list(entities)
76 | ents.sort()
77 | for i, label in enumerate(ents):
78 | ent2idx[label] = i
79 |
80 |
81 | dep2idx = {}
82 | deps = list(dep_labels)
83 | deps.sort()
84 | for i, label in enumerate(deps):
85 | dep2idx[label] = i
86 |
87 | ent_dep_mat = np.zeros((len(entities), len(dep_labels)))
88 | print(ent_dep_mat.shape)
89 | for inst in insts:
90 | for label, dep in zip(inst.prediction, inst.input.dep_labels):
91 | if label == "O":
92 | continue
93 | ent_dep_mat[ent2idx[label[2:]]] [dep2idx[dep]] += 1
94 |
95 | sum_labels = [ sum(ent_dep_mat[i]) for i in range(ent_dep_mat.shape[0])]
96 | ent_dep_mat = np.stack([ (ent_dep_mat[i]/sum_labels[i]) * 100 for i in range(ent_dep_mat.shape[0])], axis=0)
97 | print(ent_dep_mat.shape)
98 |
99 | indexs = [i for i in range(ent_dep_mat.shape[1]) if len(ent_dep_mat[:,i][ ent_dep_mat[:,i] >5.0 ]) ]
100 | print(np.asarray(deps)[indexs])
101 |
102 | xlabels = [deps[i] for i in indexs]
103 | # cmap = sns.light_palette("#2ecc71", as_cmap=True)
104 | # cmap = sns.light_palette("#8e44ad", as_cmap=True)
105 | cmap = sns.cubehelix_palette(8,as_cmap=True)
106 | ax = sns.heatmap(ent_dep_mat[:, indexs], annot=True, vmin=0, vmax=100, cmap=cmap,fmt='.0f', xticklabels=xlabels, yticklabels=ents, cbar=True)
107 | # ,annot_kws = {"size": 10})
108 | # , cbar_kws={'label': 'percentage (%)'})
109 | plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
110 | plt.show()
--------------------------------------------------------------------------------
/analysis/intro_examples.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | from config.reader import Reader
4 | from config.eval import Span
5 |
6 |
7 | def get_spans(output):
8 | output_spans = set()
9 | start = -1
10 | for i in range(len(output)):
11 | if output[i].startswith("B-"):
12 | start = i
13 | if output[i].startswith("E-"):
14 | end = i
15 | output_spans.add(Span(start, end, output[i][2:]))
16 | if output[i].startswith("S-"):
17 | output_spans.add(Span(i, i, output[i][2:]))
18 | return output_spans
19 |
20 | def use_iobes(insts):
21 | for inst in insts:
22 | output = inst.output
23 | for pos in range(len(inst)):
24 | curr_entity = output[pos]
25 | if pos == len(inst) - 1:
26 | if curr_entity.startswith("B-"):
27 | output[pos] = curr_entity.replace("B-", "S-")
28 | elif curr_entity.startswith("I-"):
29 | output[pos] = curr_entity.replace("I-", "E-")
30 | else:
31 | next_entity = output[pos + 1]
32 | if curr_entity.startswith("B-"):
33 | if next_entity.startswith("O") or next_entity.startswith("B-"):
34 | output[pos] = curr_entity.replace("B-", "S-")
35 | elif curr_entity.startswith("I-"):
36 | if next_entity.startswith("O") or next_entity.startswith("B-"):
37 | output[pos] = curr_entity.replace("I-", "E-")
38 |
39 |
40 | file = "data/ontonotes/train.sd.conllx"
41 | digit2zero = False
42 | reader = Reader(digit2zero)
43 |
44 | insts = reader.read_conll(file, -1, True)
45 | use_iobes(insts)
46 |
47 | for i in range(len(insts)):
48 |
49 | inst = insts[i]
50 | gold_spans = get_spans(inst.output)
51 |
52 |
53 | for span in gold_spans:
54 | ent_words = ' '.join(inst.input.words[span.left:span.right+1])
55 | # if ent_words.islower() and span.type != "DATE" and span.type != "ORDINAL" and span.type != "PERCENT"\
56 | # and span.type != "CARDINAL" and span.type != "MONEY" and span.type != "QUANTITY" and span.type != "TIME" \
57 | # and span.type != "NORP" and span.type != "PERSON":
58 | # print(ent_words + " " + span.type)
59 | # print(inst.input.words)
60 | # print()
61 | for k in range(span.left, span.right + 1):
62 | head_k = inst.input.heads[k]
63 |             if abs(head_k - k) >= 4 and ent_words.islower() \
64 |                     and span.type not in {"DATE", "ORDINAL", "PERCENT", "MONEY", "QUANTITY", "TIME", "CARDINAL"}:
65 | print(ent_words + " " + span.type)
66 | print(inst.input.words)
67 | print()
68 | # if span.right - span.left >= 4 and span.type != "DATE" and span.type != "ORDINAL" and span.type != "PERCENT" \
69 | # and ent_words.islower() and span.type != "MONEY" and span.type != "QUANTITY" and span.type != "TIME" and span.type != "CARDINAL" :
70 | # print(ent_words + " " + span.type)
71 | # print(inst.input.words)
72 | # print()
73 | ## book of the dead.
--------------------------------------------------------------------------------
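`use_iobes` and `get_spans` together convert BIO tags to IOBES and then read entity spans back off the tag sequence. A standalone sketch of the same two steps, with made-up tags (simplified to plain functions and tuples instead of the repo's `Instance`/`Span` classes):

```python
def bio_to_iobes(tags):
    """Convert a BIO tag sequence to IOBES (mirrors use_iobes above)."""
    out = list(tags)
    for pos, curr in enumerate(out):
        nxt = out[pos + 1] if pos + 1 < len(out) else "O"
        if curr.startswith("B-") and not nxt.startswith("I-"):
            out[pos] = curr.replace("B-", "S-")   # single-token entity
        elif curr.startswith("I-") and not nxt.startswith("I-"):
            out[pos] = curr.replace("I-", "E-")   # entity ends here
    return out

def spans(tags):
    """Extract (left, right, type) spans from an IOBES sequence."""
    result, start = set(), -1
    for i, t in enumerate(tags):
        if t.startswith("B-"):
            start = i
        elif t.startswith("E-"):
            result.add((start, i, t[2:]))
        elif t.startswith("S-"):
            result.add((i, i, t[2:]))
    return result

tags = ["B-ORG", "I-ORG", "O", "B-PER"]
iobes = bio_to_iobes(tags)
print(iobes)          # ['B-ORG', 'E-ORG', 'O', 'S-PER']
print(spans(iobes))   # {(0, 1, 'ORG'), (3, 3, 'PER')}
```

Working in IOBES makes span extraction unambiguous: every entity is closed by exactly one `E-` or `S-` tag.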
/analysis/length.py:
--------------------------------------------------------------------------------
1 | #
2 | # @author: Allan
3 | #
4 |
5 |
6 | from config.reader import Reader
7 | import numpy as np
8 |
9 | import matplotlib.pyplot as plt
10 | import random
11 | from config.eval import evaluate, Span
12 |
13 | from collections import defaultdict
14 |
15 |
16 |
17 | def use_iobes(insts):
18 | for inst in insts:
19 | output = inst.output
20 | for pos in range(len(inst)):
21 | curr_entity = output[pos]
22 | if pos == len(inst) - 1:
23 | if curr_entity.startswith("B-"):
24 | output[pos] = curr_entity.replace("B-", "S-")
25 | elif curr_entity.startswith("I-"):
26 | output[pos] = curr_entity.replace("I-", "E-")
27 | else:
28 | next_entity = output[pos + 1]
29 | if curr_entity.startswith("B-"):
30 | if next_entity.startswith("O") or next_entity.startswith("B-"):
31 | output[pos] = curr_entity.replace("B-", "S-")
32 | elif curr_entity.startswith("I-"):
33 | if next_entity.startswith("O") or next_entity.startswith("B-"):
34 | output[pos] = curr_entity.replace("I-", "E-")
35 |
36 |
37 | dataset = "ontonotes_chinese"
38 | train = "../data/"+dataset+"/train.sd.conllx"
39 | dev = "../data/"+dataset+"/dev.sd.conllx"
40 | test = "../data/"+dataset+"/test.sd.conllx"
41 | digit2zero = False
42 | reader = Reader(digit2zero)
43 |
44 | insts = reader.read_conll(train, -1, True)
45 | insts += reader.read_conll(dev, -1, False)
46 | insts += reader.read_conll(test, -1, False)
47 | use_iobes(insts)
48 | L = 3  ## only analyze entity spans with at least L tokens
49 |
50 |
51 | def get_spans(output):
52 | output_spans = set()
53 | start = -1
54 | for i in range(len(output)):
55 | if output[i].startswith("B-"):
56 | start = i
57 | if output[i].startswith("E-"):
58 | end = i
59 | output_spans.add(Span(start, end, output[i][2:]))
60 | if output[i].startswith("S-"):
61 | output_spans.add(Span(i, i, output[i][2:]))
62 | return output_spans
63 |
64 | count_all = 0
65 | count_have_sub = 0
66 | count_grand = 0
67 | length2num = defaultdict(int)
68 | for inst in insts:
69 | output = inst.output
70 | spans = get_spans(output)
71 | # print(spans)
72 | for span in spans:
73 | length2num[span.right - span.left + 1] += 1
74 | if span.right - span.left + 1 < L:
75 | continue
76 | count_dep = 0
77 | count_all += 1
78 | has_grand = False
79 |         for i in range(span.left, span.right + 1):
80 |             head_i = inst.input.heads[i]
81 |             if span.left <= head_i <= span.right:
82 |                 count_dep += 1
83 |                 ## check whether the grandparent (head of the head) also falls inside the span
84 |                 if head_i != -1 and span.left <= inst.input.heads[head_i] <= span.right:
85 |                     has_grand = True
86 |
87 | if has_grand:
88 | count_grand += 1
89 | if count_dep == (span.right - span.left):
90 | count_have_sub += 1
91 | else:
92 | pass
93 | # print(inst.input.words)
94 |
95 |
96 | print(count_have_sub, count_all, count_have_sub/count_all*100)
97 | print(count_grand, count_all, count_grand/count_all*100)
98 | print(length2num)
99 |
100 |
--------------------------------------------------------------------------------
/analysis/length_analysis.py:
--------------------------------------------------------------------------------
1 | #
2 | # @author: Allan
3 | #
4 |
5 |
6 | from config.reader import Reader
7 | import numpy as np
8 |
9 | import matplotlib.pyplot as plt
10 | import random
11 |
12 | from common.sentence import Sentence
13 | from common.instance import Instance
14 | from typing import List
15 | from config.eval import Span
16 | from tqdm import tqdm
17 |
18 | def use_iobes(insts):
19 | for inst in insts:
20 | output = inst.output
21 | for pos in range(len(inst)):
22 | curr_entity = output[pos]
23 | if pos == len(inst) - 1:
24 | if curr_entity.startswith("B-"):
25 | output[pos] = curr_entity.replace("B-", "S-")
26 | elif curr_entity.startswith("I-"):
27 | output[pos] = curr_entity.replace("I-", "E-")
28 | else:
29 | next_entity = output[pos + 1]
30 | if curr_entity.startswith("B-"):
31 | if next_entity.startswith("O") or next_entity.startswith("B-"):
32 | output[pos] = curr_entity.replace("B-", "S-")
33 | elif curr_entity.startswith("I-"):
34 | if next_entity.startswith("O") or next_entity.startswith("B-"):
35 | output[pos] = curr_entity.replace("I-", "E-")
36 |
37 |
38 |
39 | def read_conll(res_file: str, number: int = -1) -> List[Instance]:
40 | print("Reading file: " + res_file)
41 | insts = []
42 | # vocab = set() ## build the vocabulary
43 | with open(res_file, 'r', encoding='utf-8') as f:
44 | words = []
45 | heads = []
46 | deps = []
47 | labels = []
48 | tags = []
49 | preds = []
50 | for line in tqdm(f.readlines()):
51 | line = line.rstrip()
52 | if line == "":
53 | inst = Instance(Sentence(words, heads, deps, tags), labels)
54 | inst.prediction = preds
55 | insts.append(inst)
56 | words = []
57 | heads = []
58 | deps = []
59 | labels = []
60 | tags = []
61 | preds = []
62 |
63 | if len(insts) == number:
64 | break
65 | continue
66 | vals = line.split()
67 | word = vals[1]
68 | pos = vals[2]
69 | head = int(vals[3])
70 | dep_label = vals[4]
71 |
72 | label = vals[5]
73 | pred_label = vals[6]
74 |
75 | words.append(word)
76 |             heads.append(head)  ## heads in the results file are already 0-indexed (root is -1)
77 | deps.append(dep_label)
78 | tags.append(pos)
79 | labels.append(label)
80 | preds.append(pred_label)
81 | print("number of sentences: {}".format(len(insts)))
82 | return insts
83 |
84 |
85 | def get_spans(output):
86 | output_spans = set()
87 | start = -1
88 | for i in range(len(output)):
89 | if output[i].startswith("B-"):
90 | start = i
91 | if output[i].startswith("E-"):
92 | end = i
93 | output_spans.add(Span(start, end, output[i][2:]))
94 | if output[i].startswith("S-"):
95 | output_spans.add(Span(i, i, output[i][2:]))
96 | return output_spans
97 |
98 | def evaluate(insts, maximum_length=4):
99 |
100 | p = {}
101 | total_entity = {}
102 | total_predict = {}
103 |
104 | for inst in insts:
105 |
106 | output = inst.output
107 | prediction = inst.prediction
108 | #convert to span
109 | output_spans = set()
110 | start = -1
111 | for i in range(len(output)):
112 | if output[i].startswith("B-"):
113 | start = i
114 | if output[i].startswith("E-"):
115 | end = i
116 | output_spans.add(Span(start, end, output[i][2:]))
117 | if output[i].startswith("S-"):
118 | output_spans.add(Span(i, i, output[i][2:]))
119 | predict_spans = set()
120 | for i in range(len(prediction)):
121 | if prediction[i].startswith("B-"):
122 | start = i
123 | if prediction[i].startswith("E-"):
124 | end = i
125 | predict_spans.add(Span(start, end, prediction[i][2:]))
126 | if prediction[i].startswith("S-"):
127 | predict_spans.add(Span(i, i, prediction[i][2:]))
128 |
129 | # total_entity += len(output_spans)
130 | # total_predict += len(predict_spans)
131 | # p += len(predict_spans.intersection(output_spans))
132 |
133 | for span in output_spans:
134 | length = span.right - span.left + 1
135 | if length >= maximum_length:
136 | length = maximum_length
137 | if length in total_entity:
138 | total_entity[length] += 1
139 | else:
140 | total_entity[length] = 1
141 |
142 | for span in predict_spans:
143 | length = span.right - span.left + 1
144 | if length >= maximum_length:
145 | length = maximum_length
146 | if length in total_predict:
147 | total_predict[length] += 1
148 | else:
149 | total_predict[length] = 1
150 |
151 | for span in predict_spans.intersection(output_spans):
152 | length = span.right - span.left + 1
153 | if length >= maximum_length:
154 | length = maximum_length
155 | if length in p:
156 | p[length] += 1
157 | else:
158 | p[length] = 1
159 |
160 |     max_len = max(p.keys())
161 | # precision = p * 1.0 / total_predict * 100 if total_predict != 0 else 0
162 | # recall = p * 1.0 / total_entity * 100 if total_entity != 0 else 0
163 | # fscore = 2.0 * precision * recall / (precision + recall) if precision != 0 or recall != 0 else 0
164 |
165 | f = {}
166 | for length in range(1, max_len + 1):
167 | if length not in p:
168 | continue
169 | precision = p[length] * 1.0 / total_predict[length] * 100 if total_predict[length] != 0 else 0
170 | recall = p[length] * 1.0 / total_entity[length] * 100 if total_entity[length] != 0 else 0
171 | f[length] = 2.0 * precision * recall / (precision + recall) if precision != 0 or recall != 0 else 0
172 |
173 | return f
174 |
175 |
176 | def grand_child(insts1, insts2):
177 | num = 0
178 | gc_num = 0
179 | ld_num = 0
180 | for i in range(len(insts1)):
181 |
182 | first = insts1[i]
183 | second = insts2[i]
184 | inst = insts1[i]
185 | gold_spans = get_spans(first.output)
186 |
187 | pred_first = get_spans(first.prediction)
188 | pred_second = get_spans(second.prediction)
189 |
190 | # for span in pred_first:
191 | # if span in gold_spans and (span not in pred_second):
192 | for span in gold_spans:
193 | if span in pred_first and (span not in pred_second):
194 | if span.right - span.left < 2:
195 | continue
196 | num += 1
197 | # print(span.to_str(first.input.words))
198 | has_grand = False
199 | has_ld = False
200 | for k in range(span.left, span.right + 1):
201 | if inst.input.heads[k] >= span.left and inst.input.heads[k] <= span.right:
202 | head_i = inst.input.heads[k]
203 | if abs(head_i - k) > 1:
204 | has_ld = True
205 |                     if head_i != -1 and \
206 |                             span.left <= inst.input.heads[head_i] <= span.right:
207 |                         has_grand = True
208 |
209 | if has_grand:
210 | gc_num +=1
211 | if has_ld:
212 | ld_num += 1
213 | return gc_num, ld_num, num
214 |
215 |
216 | ## Chinese Comparison
217 | res1 = "./final_results/lstm_2_200_crf_ontonotes_chinese_sd_-1_dep_none_elmo_elmo_sgd_gate_0_base_-1_epoch_150_lr_0.01.results"
218 | insts1 = read_conll(res1)
219 |
220 | res2 = "./final_results/lstm_2_200_crf_ontonotes_chinese_sd_-1_dep_feat_emb_elmo_elmo_sgd_gate_0_base_-1_epoch_100_lr_0.01_doubledep_0_comb_3.results"
221 | insts2 = read_conll(res2)
222 |
223 | # Catalan Comparison
224 | # res1 = "./final_results/lstm_2_200_crf_semca_sd_-1_dep_none_elmo_elmo_sgd_gate_0_base_-1_epoch_150_lr_0.01.results"
225 | # insts1 = read_conll(res1)
226 | #
227 | # res2 = "./final_results/lstm_2_200_crf_semca_sd_-1_dep_feat_emb_elmo_elmo_sgd_gate_0_base_-1_epoch_300_lr_0.01_doubledep_0_comb_3.results"
228 | # insts2 = read_conll(res2)
229 |
230 |
231 | ## Spanish Comparison
232 | # res1 = "./final_results/lstm_2_200_crf_semes_sd_-1_dep_none_elmo_elmo_sgd_gate_0_base_-1_epoch_150_lr_0.01.results"
233 | # insts1 = read_conll(res1)
234 | #
235 | # res2 = "./final_results/lstm_2_200_crf_semes_sd_-1_dep_feat_emb_elmo_elmo_sgd_gate_0_base_-1_epoch_300_lr_0.01_doubledep_0_comb_3.results"
236 | # insts2 = read_conll(res2)
237 |
238 |
239 |
240 |
241 | maximum_length = 6
242 | print(evaluate(insts1, maximum_length))
243 | print(evaluate(insts2, maximum_length))
244 |
245 | print(grand_child(insts1, insts2))
246 |
--------------------------------------------------------------------------------
/analysis/significant.py:
--------------------------------------------------------------------------------
1 | #
2 | # @author: Allan
3 | #
4 |
5 |
6 | from tqdm import tqdm
7 | from common.sentence import Sentence
8 | from common.instance import Instance
9 | from typing import List
10 | from config.eval import evaluate
11 | import random
12 |
13 | def read_conll(res_file: str, number: int = -1) -> List[Instance]:
14 | print("Reading file: " + res_file)
15 | insts = []
16 | # vocab = set() ## build the vocabulary
17 | with open(res_file, 'r', encoding='utf-8') as f:
18 | words = []
19 | heads = []
20 | deps = []
21 | labels = []
22 | tags = []
23 | preds = []
24 | for line in tqdm(f.readlines()):
25 | line = line.rstrip()
26 | if line == "":
27 | inst = Instance(Sentence(words, heads, deps, tags), labels)
28 | inst.prediction = preds
29 | insts.append(inst)
30 | words = []
31 | heads = []
32 | deps = []
33 | labels = []
34 | tags = []
35 | preds = []
36 |
37 | if len(insts) == number:
38 | break
39 | continue
40 | vals = line.split()
41 | word = vals[1]
42 | pos = vals[2]
43 | head = int(vals[3])
44 | dep_label = vals[4]
45 |
46 | label = vals[5]
47 | pred_label = vals[6]
48 |
49 | words.append(word)
50 |             heads.append(head)  ## heads in the results file are already 0-indexed (root is -1)
51 | deps.append(dep_label)
52 | tags.append(pos)
53 | labels.append(label)
54 | preds.append(pred_label)
55 | print("number of sentences: {}".format(len(insts)))
56 | return insts
57 |
58 | res1 = "../final_results/lstm_2_200_crf_ontonotes_chinese_sd_-1_dep_feat_emb_elmo_elmo_sgd_gate_0_base_-1_epoch_150_lr_0.01.results"
59 | insts1 = read_conll(res1)
60 |
61 | res2 = "../final_results/lstm_1_200_crf_ontonotes_chinese_sd_-1_dep_none_elmo_elmo_sgd_gate_0_base_-1_epoch_150_lr_0.01.results"
62 | insts2 = read_conll(res2)
63 |
64 |
65 | sample_num = 10000
66 |
67 | p = 0
68 | for i in range(sample_num):
69 | sinsts = []
70 | sinsts_2 = []
71 | for _ in range(len(insts1)):
72 | n = random.randint(0, len(insts1) - 1)
73 | sinsts.append(insts1[n])
74 | sinsts_2.append(insts2[n])
75 |
76 | f1 = evaluate(sinsts)[2]
77 |     f2 = evaluate(sinsts_2)[2]
78 |
79 | if f1 > f2:
80 | p += 1
81 |
82 | p_val = (i + 1 - p) / (i+1)
83 | print("current p value: {}".format(p_val))
84 |
85 |
86 |
87 |
--------------------------------------------------------------------------------
/analysis/stator.py:
--------------------------------------------------------------------------------
1 | #
2 | # @author: Allan
3 | #
4 |
5 | from config.reader import Reader
6 | from collections import defaultdict
7 |
8 | file = "../data/ontonotes/dev.sd.conllx"
9 | digit2zero = False
10 | reader = Reader(digit2zero)
11 |
12 | insts = reader.read_conll(file, -1, True)
13 | # devs = reader.read_conll(conf.dev_file, conf.dev_num, False)
14 | # tests = reader.read_conll(conf.test_file, conf.test_num, False)
15 |
16 | out_dep_label2num = {}
17 |
18 | out_doubledep2num = {}
19 |
20 | out_word2num = {}
21 |
22 | label2idx = {}
23 |
24 | ent2num = defaultdict(int)
25 |
26 | def not_entity(label:str):
27 | if label.startswith("B-") or label.startswith("I-"):
28 | return False
29 | return True
30 |
31 | def is_entity(label:str):
32 | if label.startswith("B-") or label.startswith("I-"):
33 | return True
34 | return False
35 |
36 | for inst in insts:
37 | output = inst.output
38 | sent = inst.input
39 |
40 | for idx, (word, head_idx, ent, dep) in enumerate(zip(sent.words, sent.heads, output, sent.dep_labels)):
41 | if ent.startswith('B-'):
42 | ent2num[ent[2:]] += 1
43 |
44 | if dep not in label2idx:
45 | label2idx[dep] = len(label2idx)
46 | if is_entity(ent):
47 | if head_idx == -1 or not_entity(output[head_idx]):
48 | if dep in out_dep_label2num:
49 | out_dep_label2num[dep] +=1
50 | else:
51 | out_dep_label2num[dep] = 1
52 | head_word = "root" if head_idx == -1 else sent.words[head_idx]
53 | if head_word in out_word2num:
54 | out_word2num[head_word] += 1
55 | else:
56 | out_word2num[head_word] = 1
57 |
58 | if head_idx != -1:
59 | head_dep = sent.dep_labels[head_idx]
60 | if (head_dep, dep) in out_doubledep2num:
61 | out_doubledep2num[(head_dep, dep)] += 1
62 | else:
63 | out_doubledep2num[(head_dep, dep)] = 1
64 |
65 |
66 | counts = [(key, out_dep_label2num[key]) for key in out_dep_label2num]
67 | counts = sorted(counts, key=lambda vals: vals[1], reverse=True)
68 | total_ent_dep = sum([nums[1] for nums in counts])
69 | print(counts)
70 | print("total is {}".format(total_ent_dep))
71 |
72 |
73 | # counts = [(key, out_word2num[key]) for key in out_word2num]
74 | # counts = sorted(counts, key=lambda vals: vals[1], reverse=True)
75 | # total_ent_dep = sum([nums[1] for nums in counts])
76 | # print(counts)
77 | # print("total is {}".format(total_ent_dep))
78 |
79 |
80 | counts = [(key, out_doubledep2num[key]) for key in out_doubledep2num]
81 | counts = sorted(counts, key=lambda vals: vals[1], reverse=True)
82 | total_ent_dep = sum([nums[1] for nums in counts])
83 | print(counts)
84 | print("total is {}".format(total_ent_dep))
85 |
86 |
87 |
88 | print(f"entity2number: {ent2num}")
89 |
--------------------------------------------------------------------------------
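The repeated `if key in d: d[key] += 1 else: d[key] = 1` pattern in `stator.py` is exactly what `collections.defaultdict(int)` (already imported there for `ent2num`) replaces. A sketch with made-up dependency-label pairs:

```python
from collections import defaultdict

# Hypothetical (head_dep, dep) label pairs, as collected in out_doubledep2num.
pairs = [("nn", "nn"), ("nn", "amod"), ("nn", "nn")]

counts = defaultdict(int)
for head_dep, dep in pairs:
    counts[(head_dep, dep)] += 1   # no explicit "key in dict" check needed

# sort (key, count) pairs by count, descending, as stator.py does
ranked = sorted(counts.items(), key=lambda kv: kv[1], reverse=True)
print(ranked)  # [(('nn', 'nn'), 2), (('nn', 'amod'), 1)]
```

`defaultdict(int)` initializes missing keys to 0 on first access, which collapses each four-line counting branch into a single increment.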
/common/instance.py:
--------------------------------------------------------------------------------
1 | #
2 | # @author: Allan
3 | #
4 | from common.sentence import Sentence
5 | class Instance:
6 |
7 | def __init__(self, input: Sentence, output):
8 | self.input = input
9 | self.output = output
10 | self.elmo_vec = None
11 | self.word_ids = None
12 | self.char_ids = None
13 | self.dep_label_ids = None
14 | self.dep_head_ids = None
15 | self.output_ids = None
16 |
17 | def __len__(self):
18 | return len(self.input)
19 |
--------------------------------------------------------------------------------
/common/sentence.py:
--------------------------------------------------------------------------------
1 | #
2 | # @author: Allan
3 | #
4 |
5 | from typing import List
6 |
7 | class Sentence:
8 |
9 | def __init__(self, words: List[str], heads: List[int]=None , dep_labels: List[str]=None, pos_tags:List[str] = None):
10 | self.words = words
11 | self.heads = heads
12 | self.dep_labels = dep_labels
13 | self.pos_tags = pos_tags
14 |
15 | def __len__(self):
16 | return len(self.words)
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 | # if __name__ == "__main__":
26 | #
27 | # words = ["a" ,"sdfsdf"]
28 | # sent = Sentence(words)
29 | #
30 | # print(len(sent))
--------------------------------------------------------------------------------
/config/__init__.py:
--------------------------------------------------------------------------------
1 | from config.config import DepModelType, ContextEmb, InteractionFunction
--------------------------------------------------------------------------------
/config/config.py:
--------------------------------------------------------------------------------
1 | #
2 | # @author: Allan
3 | #
4 |
5 | import numpy as np
6 | from tqdm import tqdm
7 | from typing import List
8 | from common.instance import Instance
9 | from config.utils import PAD, START, STOP, ROOT, ROOT_DEP_LABEL, SELF_DEP_LABEL
10 | import torch
11 | from enum import Enum
12 | from termcolor import colored
13 |
14 | class DepModelType(Enum):
15 | none = 0
16 | dglstm = 1
17 | dggcn = 2
18 |
19 |
20 | class ContextEmb(Enum):
21 | none = 0
22 | elmo = 1
23 | bert = 2
24 | flair = 3
25 |
26 |
27 | class InteractionFunction(Enum):
28 | concatenation = 0
29 | addition = 1
30 | mlp = 2
31 |
32 |
33 |
34 | class Config:
35 | def __init__(self, args):
36 |
37 | self.PAD = PAD
38 | self.B = "B-"
39 | self.I = "I-"
40 | self.S = "S-"
41 | self.E = "E-"
42 | self.O = "O"
43 | self.START_TAG = START
44 | self.STOP_TAG = STOP
45 | self.ROOT = ROOT
46 | self.UNK = ""
47 | self.unk_id = -1
48 | self.root_dep_label = ROOT_DEP_LABEL
49 | self.self_label = SELF_DEP_LABEL
50 |
51 |         print(colored("[Info] remember to check the root dependency label if changing the data. current: {}".format(self.root_dep_label), "red"))
52 |
53 | # self.device = torch.device("cuda" if args.gpu else "cpu")
54 | self.embedding_file = args.embedding_file
55 | self.embedding_dim = args.embedding_dim
56 | self.context_emb = ContextEmb[args.context_emb]
57 | self.context_emb_size = 0
58 | self.embedding, self.embedding_dim = self.read_pretrain_embedding()
59 | self.word_embedding = None
60 | self.seed = args.seed
61 | self.digit2zero = args.digit2zero
62 |
63 | self.dataset = args.dataset
64 |
65 | self.affix = args.affix
66 | train_affix = self.affix.replace("pred", "") if "pred" in self.affix else self.affix
67 | self.train_file = "data/" + self.dataset + "/train."+train_affix+".conllx"
68 | self.dev_file = "data/" + self.dataset + "/dev."+train_affix+".conllx"
69 | self.test_file = "data/" + self.dataset + "/test."+self.affix+".conllx"
70 | self.label2idx = {}
71 | self.idx2labels = []
72 | self.char2idx = {}
73 | self.idx2char = []
74 | self.num_char = 0
75 |
76 |
77 | self.optimizer = args.optimizer.lower()
78 | self.learning_rate = args.learning_rate
79 | self.momentum = args.momentum
80 | self.l2 = args.l2
81 | self.num_epochs = args.num_epochs
82 | # self.lr_decay = 0.05
83 | self.use_dev = True
84 | self.train_num = args.train_num
85 | self.dev_num = args.dev_num
86 | self.test_num = args.test_num
87 | self.batch_size = args.batch_size
88 | self.clip = 5
89 | self.lr_decay = args.lr_decay
90 | self.device = torch.device(args.device)
91 |
92 | self.hidden_dim = args.hidden_dim
93 | self.num_lstm_layer = args.num_lstm_layer
94 | self.use_brnn = True
95 | self.num_layers = 1
96 | self.dropout = args.dropout
97 | self.char_emb_size = 30
98 | self.charlstm_hidden_dim = 50
99 | self.use_char_rnn = args.use_char_rnn
100 | # self.use_head = args.use_head
101 | self.dep_model = DepModelType[args.dep_model]
102 |
103 | self.dep_hidden_dim = args.dep_hidden_dim
104 | self.num_gcn_layers = args.num_gcn_layers
105 | self.gcn_mlp_layers = args.gcn_mlp_layers
106 | self.gcn_dropout = args.gcn_dropout
107 | self.adj_directed = args.gcn_adj_directed
108 | self.adj_self_loop = args.gcn_adj_selfloop
109 | self.edge_gate = args.gcn_gate
110 |
111 | self.dep_emb_size = args.dep_emb_size
112 | self.deplabel2idx = {}
113 | self.deplabels = []
114 |
115 |
116 | self.eval_epoch = args.eval_epoch
117 |
118 |
119 | self.interaction_func = InteractionFunction[args.inter_func] ## 0:concat, 1: addition, 2:gcn
120 |
121 |
122 | # def print(self):
123 | # print("")
124 | # print("\tuse gpu: " + )
125 |
126 | '''
127 | read all the pretrain embeddings
128 | '''
129 | def read_pretrain_embedding(self):
130 |         print("reading the pretrained embedding: %s" % (self.embedding_file))
131 |         if self.embedding_file is None:
132 |             print("pretrained embedding is None, using random embedding")
133 | return None, self.embedding_dim
134 | embedding_dim = -1
135 | embedding = dict()
136 | with open(self.embedding_file, 'r', encoding='utf-8') as file:
137 | for line in tqdm(file.readlines()):
138 | line = line.strip()
139 | if len(line) == 0:
140 | continue
141 | tokens = line.split()
142 |                 if len(tokens) == 2:
143 |                     ## the first line of word2vec-style files stores
144 |                     ## (vocab size, dim); skip it.
145 |                     continue
146 |                 if embedding_dim < 0:
147 |                     embedding_dim = len(tokens) - 1
148 |                 elif (embedding_dim + 1) != len(tokens):
149 |                     ## skip malformed rows whose dimension does not
150 |                     ## match that of the first embedding row.
151 |                     continue
152 |
153 | embedd = np.empty([1, embedding_dim])
154 | embedd[:] = tokens[1:]
155 | first_col = tokens[0]
156 | embedding[first_col] = embedd
157 | return embedding, embedding_dim
158 |
159 |
160 | def build_word_idx(self, train_insts, dev_insts, test_insts):
161 | self.word2idx = dict()
162 | self.idx2word = []
163 | self.word2idx[self.PAD] = 0
164 | self.idx2word.append(self.PAD)
165 | self.word2idx[self.UNK] = 1
166 | self.unk_id = 1
167 | self.idx2word.append(self.UNK)
168 |
169 | self.word2idx[self.ROOT] = 2
170 | self.idx2word.append(self.ROOT)
171 |
172 | self.char2idx[self.PAD] = 0
173 | self.idx2char.append(self.PAD)
174 | self.char2idx[self.UNK] = 1
175 | self.idx2char.append(self.UNK)
176 |
177 | ##extract char on train, dev, test
178 | for inst in train_insts + dev_insts + test_insts:
179 | for word in inst.input.words:
180 | if word not in self.word2idx:
181 | self.word2idx[word] = len(self.word2idx)
182 | self.idx2word.append(word)
183 | ##extract char only on train
184 | for inst in train_insts:
185 | for word in inst.input.words:
186 | for c in word:
187 | if c not in self.char2idx:
188 | self.char2idx[c] = len(self.idx2char)
189 | self.idx2char.append(c)
190 | self.num_char = len(self.idx2char)
191 | # print(self.idx2word)
192 | # print(self.idx2char)
193 | # for idx, char in enumerate(self.idx2char):
194 | # print(idx, ":", char)
195 | # print("separator")
196 | # for idx, word in enumerate(self.idx2word):
197 | # print(idx, ":", word)
198 | '''
199 | build the embedding table
200 | obtain the word2idx and idx2word as well.
201 | '''
202 | def build_emb_table(self):
203 | print("Building the embedding table for vocabulary...")
204 | scale = np.sqrt(3.0 / self.embedding_dim)
205 | if self.embedding is not None:
206 | print("[Info] Use the pretrained word embedding to initialize: %d x %d" % (len(self.word2idx), self.embedding_dim))
207 | word_found_in_emb_vocab = 0
208 | self.word_embedding = np.empty([len(self.word2idx), self.embedding_dim])
209 | for word in self.word2idx:
210 | if word in self.embedding:
211 | self.word_embedding[self.word2idx[word], :] = self.embedding[word]
212 | word_found_in_emb_vocab += 1
213 | elif word.lower() in self.embedding:
214 | self.word_embedding[self.word2idx[word], :] = self.embedding[word.lower()]
215 | word_found_in_emb_vocab += 1
216 | else:
217 | # self.word_embedding[self.word2idx[word], :] = self.embedding[self.UNK]
218 | self.word_embedding[self.word2idx[word], :] = np.random.uniform(-scale, scale, [1, self.embedding_dim])
219 | print(f"[Info] {word_found_in_emb_vocab} out of {len(self.word2idx)} found in the pretrained embedding.")
220 | self.embedding = None
221 | else:
222 | self.word_embedding = np.empty([len(self.word2idx), self.embedding_dim])
223 | for word in self.word2idx:
224 | self.word_embedding[self.word2idx[word], :] = np.random.uniform(-scale, scale, [1, self.embedding_dim])
225 |
226 | def build_deplabel_idx(self, insts):
227 | if self.self_label not in self.deplabel2idx:
228 | self.deplabels.append(self.self_label)
229 | self.deplabel2idx[self.self_label] = len(self.deplabel2idx)
230 | for inst in insts:
231 | for label in inst.input.dep_labels:
232 |                 if label not in self.deplabel2idx:
233 | self.deplabels.append(label)
234 | self.deplabel2idx[label] = len(self.deplabel2idx)
235 | self.root_dep_label_id = self.deplabel2idx[self.root_dep_label]
236 |
237 | def build_label_idx(self, insts):
238 | self.label2idx[self.PAD] = len(self.label2idx)
239 | self.idx2labels.append(self.PAD)
240 | for inst in insts:
241 | for label in inst.output:
242 | if label not in self.label2idx:
243 | self.idx2labels.append(label)
244 | self.label2idx[label] = len(self.label2idx)
245 |
246 | self.label2idx[self.START_TAG] = len(self.label2idx)
247 | self.idx2labels.append(self.START_TAG)
248 | self.label2idx[self.STOP_TAG] = len(self.label2idx)
249 | self.idx2labels.append(self.STOP_TAG)
250 | self.label_size = len(self.label2idx)
251 | print("#labels: " + str(self.label_size))
252 | print("label 2idx: " + str(self.label2idx))
253 |
254 | def use_iobes(self, insts):
255 | for inst in insts:
256 | output = inst.output
257 | for pos in range(len(inst)):
258 | curr_entity = output[pos]
259 | if pos == len(inst) - 1:
260 | if curr_entity.startswith(self.B):
261 | output[pos] = curr_entity.replace(self.B, self.S)
262 | elif curr_entity.startswith(self.I):
263 | output[pos] = curr_entity.replace(self.I, self.E)
264 | else:
265 | next_entity = output[pos + 1]
266 | if curr_entity.startswith(self.B):
267 | if next_entity.startswith(self.O) or next_entity.startswith(self.B):
268 | output[pos] = curr_entity.replace(self.B, self.S)
269 | elif curr_entity.startswith(self.I):
270 | if next_entity.startswith(self.O) or next_entity.startswith(self.B):
271 | output[pos] = curr_entity.replace(self.I, self.E)
272 |
273 | def map_insts_ids(self, insts: List[Instance]):
274 | insts_ids = []
275 | for inst in insts:
276 | words = inst.input.words
277 | inst.word_ids = []
278 | inst.char_ids = []
279 | inst.dep_label_ids = []
280 | inst.dep_head_ids = []
281 | inst.output_ids = []
282 | for word in words:
283 | if word in self.word2idx:
284 | inst.word_ids.append(self.word2idx[word])
285 | else:
286 | inst.word_ids.append(self.word2idx[self.UNK])
287 | char_id = []
288 | for c in word:
289 | if c in self.char2idx:
290 | char_id.append(self.char2idx[c])
291 | else:
292 | char_id.append(self.char2idx[self.UNK])
293 | inst.char_ids.append(char_id)
294 | for i, head in enumerate(inst.input.heads):
295 | if head == -1:
296 |                 inst.dep_head_ids.append(i)  ## the root token points to itself
297 | else:
298 | inst.dep_head_ids.append(head)
299 | for label in inst.input.dep_labels:
300 | inst.dep_label_ids.append(self.deplabel2idx[label])
301 | for label in inst.output:
302 | inst.output_ids.append(self.label2idx[label])
303 | insts_ids.append([inst.word_ids, inst.char_ids, inst.output_ids])
304 | return insts_ids
305 |
--------------------------------------------------------------------------------
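`build_emb_table` looks each vocabulary word up in the pretrained embeddings, falls back to the lowercased form, and otherwise initializes a random vector scaled by sqrt(3/dim). A minimal sketch of that fallback logic (the tiny embedding dict and vocabulary are made up):

```python
import numpy as np

np.random.seed(0)
dim = 4
scale = np.sqrt(3.0 / dim)

# Hypothetical pretrained vectors; only lowercase forms are present.
pretrained = {"bank": np.ones(dim), "river": np.full(dim, 2.0)}
vocab = ["Bank", "river", "OOV-word"]

table = np.empty((len(vocab), dim))
for row, word in enumerate(vocab):
    if word in pretrained:                 # exact match first
        table[row] = pretrained[word]
    elif word.lower() in pretrained:       # then the lowercased form
        table[row] = pretrained[word.lower()]
    else:                                  # otherwise random uniform in [-scale, scale]
        table[row] = np.random.uniform(-scale, scale, dim)

print(table[0])  # the pretrained vector for "bank", found via the lowercase fallback
```

The lowercase fallback matters for case-sensitive vocabularies built over cased corpora, since many pretrained embedding files only store lowercased forms.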
/config/eval.py:
--------------------------------------------------------------------------------
1 |
2 | import numpy as np
3 | from typing import Tuple
4 |
5 | from collections import defaultdict
6 | class Span:
7 |
8 | def __init__(self, left, right, type):
9 | self.left = left
10 | self.right = right
11 | self.type = type
12 |
13 | def __eq__(self, other):
14 | return self.left == other.left and self.right == other.right and self.type == other.type
15 |
16 | def __hash__(self):
17 | return hash((self.left, self.right, self.type))
18 |
19 | def to_str(self, sent):
20 | return str(sent[self.left: (self.right+1)]) + ","+self.type
21 |
22 | ## the instances passed to the evaluation should already
23 | ## have their predictions (inst.prediction) filled in.
24 | ## assumes the IOBES tagging scheme.
25 | ### NOTE: this function is used to evaluate the instances with prediction ready.
26 | def evaluate(insts):
27 |
28 | p = 0
29 | total_entity = 0
30 | total_predict = 0
31 |
32 | batch_p_dict = defaultdict(int)
33 | batch_total_entity_dict = defaultdict(int)
34 | batch_total_predict_dict = defaultdict(int)
35 |
36 | for inst in insts:
37 |
38 | output = inst.output
39 | prediction = inst.prediction
40 | #convert to span
41 | output_spans = set()
42 | start = -1
43 | for i in range(len(output)):
44 | if output[i].startswith("B-"):
45 | start = i
46 | if output[i].startswith("E-"):
47 | end = i
48 | output_spans.add(Span(start, end, output[i][2:]))
49 | batch_total_entity_dict[output[i][2:]] += 1
50 | if output[i].startswith("S-"):
51 | output_spans.add(Span(i, i, output[i][2:]))
52 | batch_total_entity_dict[output[i][2:]] += 1
53 | start = -1
54 | predict_spans = set()
55 | for i in range(len(prediction)):
56 | if prediction[i].startswith("B-"):
57 | start = i
58 | if prediction[i].startswith("E-"):
59 | end = i
60 | predict_spans.add(Span(start, end, prediction[i][2:]))
61 | batch_total_predict_dict[prediction[i][2:]] += 1
62 | if prediction[i].startswith("S-"):
63 | predict_spans.add(Span(i, i, prediction[i][2:]))
64 | batch_total_predict_dict[prediction[i][2:]] += 1
65 |
66 | total_entity += len(output_spans)
67 | total_predict += len(predict_spans)
68 | correct_spans = predict_spans.intersection(output_spans)
69 | p += len(correct_spans)
70 | for span in correct_spans:
71 | batch_p_dict[span.type] += 1
72 |
73 | for key in batch_total_entity_dict:
74 | precision_key, recall_key, fscore_key = get_metric(batch_p_dict[key], batch_total_entity_dict[key], batch_total_predict_dict[key])
75 | print("[%s] Prec.: %.2f, Rec.: %.2f, F1: %.2f" % (key, precision_key, recall_key, fscore_key))
76 |
77 | precision = p * 1.0 / total_predict * 100 if total_predict != 0 else 0
78 | recall = p * 1.0 / total_entity * 100 if total_entity != 0 else 0
79 | fscore = 2.0 * precision * recall / (precision + recall) if precision != 0 or recall != 0 else 0
80 |
81 | return [precision, recall, fscore]
82 |
83 | def get_metric(p_num: int, total_num: int, total_predicted_num: int) -> Tuple[float, float, float]:
84 |     """
85 |     Return precision, recall and F1 score based on the given counts.
86 |     (Factored into a small helper to avoid duplicated, typo-prone code.)
87 |     :param p_num: number of correctly predicted entities
88 |     :param total_num: number of gold entities
89 |     :param total_predicted_num: number of predicted entities
90 |     :return: (precision, recall, fscore)
91 |     """
92 | precision = p_num * 1.0 / total_predicted_num * 100 if total_predicted_num != 0 else 0
93 | recall = p_num * 1.0 / total_num * 100 if total_num != 0 else 0
94 | fscore = 2.0 * precision * recall / (precision + recall) if precision != 0 or recall != 0 else 0
95 | return precision, recall, fscore
96 |
97 |
98 |
99 | def evaluate_num(batch_insts, batch_pred_ids, batch_gold_ids, word_seq_lens, idx2label):
100 | """
101 | evaluate the batch of instances
102 | :param batch_insts:
103 | :param batch_pred_ids:
104 | :param batch_gold_ids:
105 | :param word_seq_lens:
106 | :param idx2label:
107 | :return:
108 | """
109 | p = 0
110 | total_entity = 0
111 | total_predict = 0
112 | word_seq_lens = word_seq_lens.tolist()
113 | for idx in range(len(batch_pred_ids)):
114 | length = word_seq_lens[idx]
115 | output = batch_gold_ids[idx][:length].tolist()
116 | prediction = batch_pred_ids[idx][:length].tolist()
117 |         prediction = prediction[::-1] ## the decoded path from the CRF backtracking comes out reversed
118 |         output = [idx2label[l] for l in output]
119 |         prediction = [idx2label[l] for l in prediction]
120 | batch_insts[idx].prediction = prediction
121 | #convert to span
122 | output_spans = set()
123 | start = -1
124 | for i in range(len(output)):
125 | if output[i].startswith("B-"):
126 | start = i
127 | if output[i].startswith("E-"):
128 | end = i
129 | output_spans.add(Span(start, end, output[i][2:]))
130 | if output[i].startswith("S-"):
131 | output_spans.add(Span(i, i, output[i][2:]))
132 | predict_spans = set()
133 | for i in range(len(prediction)):
134 | if prediction[i].startswith("B-"):
135 | start = i
136 | if prediction[i].startswith("E-"):
137 | end = i
138 | predict_spans.add(Span(start, end, prediction[i][2:]))
139 | if prediction[i].startswith("S-"):
140 | predict_spans.add(Span(i, i, prediction[i][2:]))
141 |
142 | total_entity += len(output_spans)
143 | total_predict += len(predict_spans)
144 | p += len(predict_spans.intersection(output_spans))
145 |
146 | # precision = p * 1.0 / total_predict * 100 if total_predict != 0 else 0
147 | # recall = p * 1.0 / total_entity * 100 if total_entity != 0 else 0
148 | # fscore = 2.0 * precision * recall / (precision + recall) if precision != 0 or recall != 0 else 0
149 |
150 | return np.asarray([p, total_predict, total_entity], dtype=int)
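As a quick illustration of the IOBES span extraction and exact-match metric logic in `evaluate()`/`evaluate_num()` above, here is a minimal standalone sketch. `Span` is a `namedtuple` stand-in for the `Span` class in this file, and the tag sequences are made-up test data, not from the repo:

```python
from collections import namedtuple

# Stand-in for the Span class above; equality/hash come for free.
Span = namedtuple("Span", ["left", "right", "type"])

def extract_spans(tags):
    """Collect entity spans from an IOBES tag sequence."""
    spans, start = set(), -1
    for i, tag in enumerate(tags):
        if tag.startswith("B-"):
            start = i
        if tag.startswith("E-"):
            spans.add(Span(start, i, tag[2:]))
        if tag.startswith("S-"):
            spans.add(Span(i, i, tag[2:]))
    return spans

gold = ["B-PER", "E-PER", "O", "S-LOC"]
pred = ["B-PER", "E-PER", "O", "S-ORG"]
correct = extract_spans(pred) & extract_spans(gold)  # exact-match spans only
precision = 100.0 * len(correct) / len(extract_spans(pred))
recall = 100.0 * len(correct) / len(extract_spans(gold))
f1 = 2 * precision * recall / (precision + recall)
```

Note that a predicted span with the right boundaries but the wrong type (here `S-ORG` vs. `S-LOC`) counts as an error for both precision and recall.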
--------------------------------------------------------------------------------
/config/reader.py:
--------------------------------------------------------------------------------
1 | #
2 | # @author: Allan
3 | #
4 |
5 | from tqdm import tqdm
6 | from common.sentence import Sentence
7 | from common.instance import Instance
8 | from typing import List
9 | import re
10 | import pickle
11 |
12 | class Reader:
13 |
14 |
15 | def __init__(self, digit2zero:bool=True):
16 | self.digit2zero = digit2zero
17 | self.vocab = set()
18 |
19 | def read_conll(self, file: str, number: int = -1, is_train: bool = True) -> List[Instance]:
20 | print("Reading file: " + file)
21 | insts = []
22 | num_entity = 0
23 | # vocab = set() ## build the vocabulary
24 | find_root = False
25 | with open(file, 'r', encoding='utf-8') as f:
26 | words = []
27 | heads = []
28 | deps = []
29 | labels = []
30 | tags = []
31 | for line in tqdm(f.readlines()):
32 | line = line.rstrip()
33 | if line == "":
34 | insts.append(Instance(Sentence(words, heads, deps, tags), labels))
35 | words = []
36 | heads = []
37 | deps = []
38 | labels = []
39 | tags = []
40 | find_root = False
41 | if len(insts) == number:
42 | break
43 | continue
44 | # if "conll2003" in file:
45 | # word, pos, head, dep_label, label = line.split()
46 | # else:
47 | vals = line.split()
48 | word = vals[1]
49 | head = int(vals[6])
50 | dep_label = vals[7]
51 | pos = vals[3]
52 | label = vals[10]
53 | if self.digit2zero:
54 |                     word = re.sub(r'\d', '0', word) # replace each digit with 0.
55 | words.append(word)
56 |                 if head == 0 and find_root: raise ValueError("already have a root")
57 |                 find_root = find_root or head == 0
58 |                 heads.append(head - 1) ## convert to 0-indexed; the root becomes -1.
59 | deps.append(dep_label)
60 | tags.append(pos)
61 | self.vocab.add(word)
62 | labels.append(label)
63 | if label.startswith("B-"):
64 | num_entity +=1
65 | print("number of sentences: {}, number of entities: {}".format(len(insts), num_entity))
66 | return insts
67 |
68 | def read_txt(self, file: str, number: int = -1, is_train: bool = True) -> List[Instance]:
69 | print("Reading file: " + file)
70 | insts = []
71 | # vocab = set() ## build the vocabulary
72 | with open(file, 'r', encoding='utf-8') as f:
73 | words = []
74 | labels = []
75 | tags = []
76 | for line in tqdm(f.readlines()):
77 | line = line.rstrip()
78 | if line == "":
79 | insts.append(Instance(Sentence(words, None, None, tags), labels))
80 | words = []
81 | labels = []
82 | tags = []
83 | if len(insts) == number:
84 | break
85 | continue
86 | if "conll2003" in file:
87 | word, pos, label = line.split()
88 | else:
89 | vals = line.split()
90 | word = vals[1]
91 | pos = vals[3]
92 | label = vals[10]
93 | if self.digit2zero:
94 |                     word = re.sub(r'\d', '0', word) # replace each digit with 0.
95 | words.append(word)
96 | tags.append(pos)
97 | self.vocab.add(word)
98 | labels.append(label)
99 | print("number of sentences: {}".format(len(insts)))
100 | return insts
101 |
102 | def load_elmo_vec(self, file, insts):
103 |         with open(file, 'rb') as f:
104 |             all_vecs = pickle.load(f) # vectors come out in the order they were stored
106 | size = 0
107 | for vec, inst in zip(all_vecs, insts):
108 | inst.elmo_vec = vec
109 | size = vec.shape[1]
110 | # print(str(vec.shape[0]) + ","+ str(len(inst.input.words)) + ", " + str(inst.input.words))
111 | assert(vec.shape[0] == len(inst.input.words))
112 | return size
113 |
114 |
115 |
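To make the column convention in `read_conll` concrete, here is a standalone sketch that parses one token line with the same indices the reader uses (the sentence and label values are invented for illustration; the column layout simply mirrors the `vals[...]` indexing above):

```python
import re

# A made-up CoNLL-X-style token line with the 11 columns read_conll indexes into
# (id, form, lemma, coarse POS, POS, feats, head, deprel, _, _, NER label).
line = "1 Madrid Madrid NNP NNP _ 2 nsubj _ _ S-LOC"
vals = line.split()

word = vals[1]       # surface form
pos = vals[3]        # POS tag
head = int(vals[6])  # 1-indexed head; 0 marks the root
dep_label = vals[7]  # dependency relation
label = vals[10]     # NER label

word = re.sub(r'\d', '0', word)        # the digit2zero normalisation
head_id = head - 1                     # shift to 0-indexed; a root would become -1
numeric = re.sub(r'\d', '0', "2019")   # digits collapse to "0000"
```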
--------------------------------------------------------------------------------
/config/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from typing import List
4 | from common.instance import Instance
5 | from config.eval import Span
6 |
START = "<START>"
STOP = "<STOP>"
PAD = "<PAD>"
ROOT = "<ROOT>"
11 | ROOT_DEP_LABEL = "root"
12 | SELF_DEP_LABEL = "self"
13 |
14 |
15 | def log_sum_exp_pytorch(vec):
16 | """
17 |
18 | :param vec: [batchSize * from_label * to_label]
19 | :return: [batchSize * to_label]
20 | """
21 | maxScores, idx = torch.max(vec, 1)
22 | maxScores[maxScores == -float("Inf")] = 0
23 | maxScoresExpanded = maxScores.view(vec.shape[0] ,1 , vec.shape[2]).expand(vec.shape[0], vec.shape[1], vec.shape[2])
24 | return maxScores + torch.log(torch.sum(torch.exp(vec - maxScoresExpanded), 1))
25 |
26 |
27 |
28 | def simple_batching(config, insts: List[Instance]):
29 | from config.config import DepModelType,ContextEmb
30 | """
31 |
32 | :param config:
33 | :param insts:
34 | :return:
35 | word_seq_tensor,
36 | word_seq_len,
37 | char_seq_tensor,
38 | char_seq_len,
39 | label_seq_tensor
40 | """
41 | batch_size = len(insts)
42 | batch_data = sorted(insts, key=lambda inst: len(inst.input.words), reverse=True) ##object-based not direct copy
43 | word_seq_len = torch.LongTensor(list(map(lambda inst: len(inst.input.words), batch_data)))
44 | max_seq_len = word_seq_len.max()
45 |     ### NOTE: pad the char lengths with 1 (not 0) and deduct it later where needed,
46 |     ####      because pack_padded_sequence in the CharBiLSTM cannot take zero-length sequences.
47 | char_seq_len = torch.LongTensor([list(map(len, inst.input.words)) + [1] * (int(max_seq_len) - len(inst.input.words)) for inst in batch_data])
48 | max_char_seq_len = char_seq_len.max()
49 |
50 | word_emb_tensor = None
51 | if config.context_emb != ContextEmb.none:
52 | emb_size = insts[0].elmo_vec.shape[1]
53 | word_emb_tensor = torch.zeros((batch_size, max_seq_len, emb_size))
54 |
55 | word_seq_tensor = torch.zeros((batch_size, max_seq_len), dtype=torch.long)
56 | label_seq_tensor = torch.zeros((batch_size, max_seq_len), dtype=torch.long)
57 | char_seq_tensor = torch.zeros((batch_size, max_seq_len, max_char_seq_len), dtype=torch.long)
58 | adjs = None
59 | adjs_in = None
60 | adjs_out = None
61 | dep_label_adj = None
62 | dep_label_tensor = None
63 | batch_dep_heads = None
64 | trees = None
65 | graphs = None
66 | if config.dep_model != DepModelType.none:
67 | if config.dep_model == DepModelType.dggcn:
68 | adjs = [ head_to_adj(max_seq_len, inst, config) for inst in batch_data]
69 | adjs = np.stack(adjs, axis=0)
70 | adjs = torch.from_numpy(adjs)
71 | dep_label_adj = [head_to_adj_label(max_seq_len, inst, config) for inst in batch_data]
72 | dep_label_adj = torch.from_numpy(np.stack(dep_label_adj, axis=0)).long()
73 |
74 | if config.dep_model == DepModelType.dglstm:
75 | batch_dep_heads = torch.zeros((batch_size, max_seq_len), dtype=torch.long)
76 | dep_label_tensor = torch.zeros((batch_size, max_seq_len), dtype=torch.long)
77 | # trees = [inst.tree for inst in batch_data]
78 | for idx in range(batch_size):
79 | word_seq_tensor[idx, :word_seq_len[idx]] = torch.LongTensor(batch_data[idx].word_ids)
80 | label_seq_tensor[idx, :word_seq_len[idx]] = torch.LongTensor(batch_data[idx].output_ids)
81 | if config.context_emb != ContextEmb.none:
82 | word_emb_tensor[idx, :word_seq_len[idx], :] = torch.from_numpy(batch_data[idx].elmo_vec)
83 |
84 | if config.dep_model == DepModelType.dglstm:
85 | batch_dep_heads[idx, :word_seq_len[idx]] = torch.LongTensor(batch_data[idx].dep_head_ids)
86 | dep_label_tensor[idx, :word_seq_len[idx]] = torch.LongTensor(batch_data[idx].dep_label_ids)
87 | for word_idx in range(word_seq_len[idx]):
88 | char_seq_tensor[idx, word_idx, :char_seq_len[idx, word_idx]] = torch.LongTensor(batch_data[idx].char_ids[word_idx])
89 | for wordIdx in range(word_seq_len[idx], max_seq_len):
90 |             char_seq_tensor[idx, wordIdx, 0: 1] = torch.LongTensor([config.char2idx[PAD]]) ### padded positions were given char length 1 above, so each needs one PAD char id (0 would also work).
91 |
92 | ### NOTE: make this step during forward if you have limited GPU resource.
93 | word_seq_tensor = word_seq_tensor.to(config.device)
94 | label_seq_tensor = label_seq_tensor.to(config.device)
95 | char_seq_tensor = char_seq_tensor.to(config.device)
96 | word_seq_len = word_seq_len.to(config.device)
97 | char_seq_len = char_seq_len.to(config.device)
98 | if config.dep_model != DepModelType.none:
99 | if config.dep_model == DepModelType.dglstm:
100 | batch_dep_heads = batch_dep_heads.to(config.device)
101 | dep_label_tensor = dep_label_tensor.to(config.device)
102 |
103 | return word_seq_tensor, word_seq_len, word_emb_tensor, char_seq_tensor, char_seq_len, adjs, adjs_in, adjs_out, graphs, dep_label_adj, batch_dep_heads, trees, label_seq_tensor, dep_label_tensor
104 |
105 |
106 |
107 | def lr_decay(config, optimizer, epoch):
108 | lr = config.learning_rate / (1 + config.lr_decay * (epoch - 1))
109 | for param_group in optimizer.param_groups:
110 | param_group['lr'] = lr
111 | print('learning rate is set to: ', lr)
112 | return optimizer
113 |
114 |
115 |
116 | def head_to_adj(max_len, inst, config):
117 | """
118 | Convert a tree object to an (numpy) adjacency matrix.
119 | """
120 | directed = config.adj_directed
121 | self_loop = False #config.adj_self_loop
122 | ret = np.zeros((max_len, max_len), dtype=np.float32)
123 |
124 | for i, head in enumerate(inst.input.heads):
125 | if head == -1:
126 | continue
127 | ret[head, i] = 1
128 |
129 | if not directed:
130 | ret = ret + ret.T
131 |
132 | if self_loop:
133 | for i in range(len(inst.input.words)):
134 | ret[i, i] = 1
135 |
136 | return ret
137 |
138 |
139 | def head_to_adj_label(max_len, inst, config):
140 | """
141 | Convert a tree object to an (numpy) adjacency matrix.
142 | """
143 | directed = config.adj_directed
144 | self_loop = config.adj_self_loop
145 |
146 |     dep_label_ret = np.zeros((max_len, max_len), dtype=np.int64)
147 |
148 | for i, head in enumerate(inst.input.heads):
149 | if head == -1:
150 | continue
151 | dep_label_ret[head, i] = inst.dep_label_ids[i]
152 |
153 | if not directed:
154 | dep_label_ret = dep_label_ret + dep_label_ret.T
155 |
156 | if self_loop:
157 | for i in range(len(inst.input.words)):
158 | dep_label_ret[i, i] = config.root_dep_label_id
159 |
160 | return dep_label_ret
161 |
162 |
163 | def get_spans(output):
164 | output_spans = set()
165 | start = -1
166 | for i in range(len(output)):
167 | if output[i].startswith("B-"):
168 | start = i
169 | if output[i].startswith("E-"):
170 | end = i
171 | output_spans.add(Span(start, end, output[i][2:]))
172 | if output[i].startswith("S-"):
173 | output_spans.add(Span(i, i, output[i][2:]))
174 | return output_spans
175 |
176 | def preprocess(conf, insts, file_type:str):
177 |     print("[Preprocess Info] Preprocessing the CoNLL-2003 dataset: {}.".format(file_type))
178 | for inst in insts:
179 | output = inst.output
180 | spans = get_spans(output)
181 | for span in spans:
182 | if span.right - span.left + 1 < 2:
183 | continue
184 | count_dep = 0
185 | for i in range(span.left, span.right + 1):
186 | if inst.input.heads[i] >= span.left and inst.input.heads[i] <= span.right:
187 | count_dep += 1
188 | if count_dep != (span.right - span.left):
189 |
190 | for i in range(span.left, span.right + 1):
191 | if inst.input.heads[i] < span.left or inst.input.heads[i] > span.right:
192 | if i != span.right:
193 | inst.input.heads[i] = span.right
194 | inst.input.dep_labels[i] = "nn" if "sd" in conf.affix else "compound"
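As a worked example of the adjacency construction in `head_to_adj`, the sketch below reproduces its core loop without the config object; `directed` stands in for `config.adj_directed`, and the toy sentence is hypothetical:

```python
import numpy as np

def head_to_adj_demo(max_len, heads, directed=False):
    """Minimal sketch of head_to_adj: heads are 0-indexed, -1 marks the root."""
    ret = np.zeros((max_len, max_len), dtype=np.float32)
    for i, head in enumerate(heads):
        if head == -1:  # the root has no incoming edge
            continue
        ret[head, i] = 1
    if not directed:
        ret = ret + ret.T  # symmetrize for an undirected graph
    return ret

# "The cat sleeps": tokens 0 and 1 attach to token 2 (the root);
# the matrix is padded to max_len = 4, as in simple_batching.
adj = head_to_adj_demo(4, [2, 2, -1])
```

Rows and columns beyond the sentence length stay all-zero, which is what lets batched sentences of different lengths share one `max_seq_len x max_seq_len` matrix.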
--------------------------------------------------------------------------------
/data/readme.txt:
--------------------------------------------------------------------------------
1 | For the OntoNotes English dataset:
2 | This is the standard train/dev/test split.
3 |
4 | train/dev are taken from the CoNLL-2012-processed data;
5 | test is taken from the Pradhan-processed data.
6 |
7 | train: 1,088,503 tokens, 81,828 entities
8 | dev: 147,724 tokens, 11,066 entities
9 | test: 152,728 tokens, 11,257 entities
10 |
11 | We use this split for a fair comparison with previous work.
12 |
13 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 |
2 | import argparse
3 | import random
4 | import numpy as np
5 | from config.reader import Reader
6 | from config import eval
7 | from config.config import Config, ContextEmb, DepModelType
8 | import time
9 | from model.lstmcrf import NNCRF
10 | import torch
11 | import torch.optim as optim
12 | import torch.nn as nn
13 | from config.utils import lr_decay, simple_batching, get_spans, preprocess
14 | from typing import List
15 | from common.instance import Instance
16 | from termcolor import colored
17 | import os
18 |
19 |
20 | def setSeed(opt, seed):
21 | random.seed(seed)
22 | np.random.seed(seed)
23 | torch.manual_seed(seed)
24 | if opt.device.startswith("cuda"):
25 | print("using GPU...", torch.cuda.current_device())
26 | torch.cuda.manual_seed(seed)
27 | torch.cuda.manual_seed_all(seed)
28 |
29 |
30 | def parse_arguments(parser):
31 | ###Training Hyperparameters
32 | parser.add_argument('--mode', type=str, default='train')
33 | parser.add_argument('--device', type=str, default="cpu")
34 | parser.add_argument('--seed', type=int, default=42)
35 | parser.add_argument('--digit2zero', action="store_true", default=True)
36 | parser.add_argument('--dataset', type=str, default="ontonotes")
37 | parser.add_argument('--affix', type=str, default="sd")
38 | parser.add_argument('--embedding_file', type=str, default="data/glove.6B.100d.txt")
39 | # parser.add_argument('--embedding_file', type=str, default=None)
40 | parser.add_argument('--embedding_dim', type=int, default=100)
41 | parser.add_argument('--optimizer', type=str, default="sgd")
42 | parser.add_argument('--learning_rate', type=float, default=0.01) ##only for sgd now
43 | parser.add_argument('--momentum', type=float, default=0.0)
44 | parser.add_argument('--l2', type=float, default=1e-8)
45 | parser.add_argument('--lr_decay', type=float, default=0)
46 | parser.add_argument('--batch_size', type=int, default=10)
47 | parser.add_argument('--num_epochs', type=int, default=100)
48 | parser.add_argument('--train_num', type=int, default=-1)
49 | parser.add_argument('--dev_num', type=int, default=-1)
50 | parser.add_argument('--test_num', type=int, default=-1)
51 |     parser.add_argument('--eval_freq', type=int, default=4000, help="evaluation frequency (in iterations)")
52 |     parser.add_argument('--eval_epoch', type=int, default=0, help="evaluate the dev set after this number of epochs")
53 |
54 | ## model hyperparameter
55 | parser.add_argument('--hidden_dim', type=int, default=200, help="hidden size of the LSTM")
56 | parser.add_argument('--num_lstm_layer', type=int, default=1, help="number of lstm layers")
57 | parser.add_argument('--dep_emb_size', type=int, default=50, help="embedding size of dependency")
58 | parser.add_argument('--dep_hidden_dim', type=int, default=200, help="hidden size of gcn, tree lstm")
59 |
60 | ### NOTE: GCN parameters, useless if we are not using GCN
61 | parser.add_argument('--num_gcn_layers', type=int, default=1, help="number of gcn layers")
62 | parser.add_argument('--gcn_mlp_layers', type=int, default=1, help="number of mlp layers after gcn")
63 | parser.add_argument('--gcn_dropout', type=float, default=0.5, help="GCN dropout")
64 |     parser.add_argument('--gcn_adj_directed', type=int, default=0, choices=[0, 1], help="whether the GCN adjacency matrix is directed")
65 |     parser.add_argument('--gcn_adj_selfloop', type=int, default=0, choices=[0, 1], help="self-loops in the adjacency matrix; always 0 now since they are added inside the model")
66 |     parser.add_argument('--gcn_gate', type=int, default=0, choices=[0, 1], help="add edge-wise gating")
67 |
68 | ##NOTE: this dropout applies to many places
69 | parser.add_argument('--dropout', type=float, default=0.5, help="dropout for embedding")
70 | parser.add_argument('--use_char_rnn', type=int, default=1, choices=[0, 1], help="use character-level lstm, 0 or 1")
71 | # parser.add_argument('--use_head', type=int, default=0, choices=[0, 1], help="not use dependency")
72 | parser.add_argument('--dep_model', type=str, default="none", choices=["none", "dggcn", "dglstm"], help="dependency method")
73 |     parser.add_argument('--inter_func', type=str, default="mlp", choices=["concatenation", "addition", "mlp"], help="interaction function to combine a word with its head: concatenation, addition, or mlp")
74 | parser.add_argument('--context_emb', type=str, default="none", choices=["none", "bert", "elmo", "flair"], help="contextual word embedding")
75 |
76 |
77 |
78 |
79 | args = parser.parse_args()
80 | for k in args.__dict__:
81 | print(k + ": " + str(args.__dict__[k]))
82 | return args
83 |
84 |
85 | def get_optimizer(config: Config, model: nn.Module):
86 | params = model.parameters()
87 | if config.optimizer.lower() == "sgd":
88 | print(colored("Using SGD: lr is: {}, L2 regularization is: {}".format(config.learning_rate, config.l2), 'yellow'))
89 | return optim.SGD(params, lr=config.learning_rate, weight_decay=float(config.l2))
90 | elif config.optimizer.lower() == "adam":
91 | print(colored("Using Adam", 'yellow'))
92 | return optim.Adam(params)
93 | else:
94 | print("Illegal optimizer: {}".format(config.optimizer))
95 | exit(1)
96 |
97 | def batching_list_instances(config: Config, insts:List[Instance]):
98 | train_num = len(insts)
99 | batch_size = config.batch_size
100 | total_batch = train_num // batch_size + 1 if train_num % batch_size != 0 else train_num // batch_size
101 | batched_data = []
102 | for batch_id in range(total_batch):
103 | one_batch_insts = insts[batch_id * batch_size:(batch_id + 1) * batch_size]
104 | batched_data.append(simple_batching(config, one_batch_insts))
105 |
106 | return batched_data
107 |
108 | def learn_from_insts(config:Config, epoch: int, train_insts, dev_insts, test_insts):
109 | # train_insts: List[Instance], dev_insts: List[Instance], test_insts: List[Instance], batch_size: int = 1
110 | model = NNCRF(config)
111 | optimizer = get_optimizer(config, model)
112 | train_num = len(train_insts)
113 | print("number of instances: %d" % (train_num))
114 | print(colored("[Shuffled] Shuffle the training instance ids", "red"))
115 | random.shuffle(train_insts)
116 |
117 |
118 |
119 | batched_data = batching_list_instances(config, train_insts)
120 | dev_batches = batching_list_instances(config, dev_insts)
121 | test_batches = batching_list_instances(config, test_insts)
122 |
123 | best_dev = [-1, 0]
124 | best_test = [-1, 0]
125 |
126 | dep_model_name = config.dep_model.name
127 | if config.dep_model == DepModelType.dggcn:
128 | dep_model_name += '(' + str(config.num_gcn_layers) + "," + str(config.gcn_dropout) + "," + str(
129 | config.gcn_mlp_layers) + ")"
130 | model_name = "model_files/lstm_{}_{}_crf_{}_{}_{}_dep_{}_elmo_{}_{}_gate_{}_epoch_{}_lr_{}_comb_{}.m".format(config.num_lstm_layer, config.hidden_dim, config.dataset, config.affix, config.train_num, dep_model_name, config.context_emb.name, config.optimizer.lower(), config.edge_gate, epoch, config.learning_rate, config.interaction_func)
131 | res_name = "results/lstm_{}_{}_crf_{}_{}_{}_dep_{}_elmo_{}_{}_gate_{}_epoch_{}_lr_{}_comb_{}.results".format(config.num_lstm_layer, config.hidden_dim, config.dataset, config.affix, config.train_num, dep_model_name, config.context_emb.name, config.optimizer.lower(), config.edge_gate, epoch, config.learning_rate, config.interaction_func)
132 |     print("[Info] The model will be saved to: %s" % (model_name))
133 | if not os.path.exists("model_files"):
134 | os.makedirs("model_files")
135 | if not os.path.exists("results"):
136 | os.makedirs("results")
137 |
138 | for i in range(1, epoch + 1):
139 | epoch_loss = 0
140 | start_time = time.time()
141 | model.zero_grad()
142 | if config.optimizer.lower() == "sgd":
143 | optimizer = lr_decay(config, optimizer, i)
144 | for index in np.random.permutation(len(batched_data)):
145 | # for index in range(len(batched_data)):
146 | model.train()
147 | batch_word, batch_wordlen, batch_context_emb, batch_char, batch_charlen, adj_matrixs, adjs_in, adjs_out, graphs, dep_label_adj, batch_dep_heads, trees, batch_label, batch_dep_label = batched_data[index]
148 | loss = model.neg_log_obj(batch_word, batch_wordlen, batch_context_emb,batch_char, batch_charlen, adj_matrixs, adjs_in, adjs_out, graphs, dep_label_adj, batch_dep_heads, batch_label, batch_dep_label, trees)
149 | epoch_loss += loss.item()
150 | loss.backward()
151 | if config.dep_model == DepModelType.dggcn:
152 | torch.nn.utils.clip_grad_norm_(model.parameters(), config.clip) ##clipping the gradient
153 | optimizer.step()
154 | model.zero_grad()
155 |
156 | end_time = time.time()
157 | print("Epoch %d: %.5f, Time is %.2fs" % (i, epoch_loss, end_time - start_time), flush=True)
158 |
159 | if i + 1 >= config.eval_epoch:
160 | model.eval()
161 | dev_metrics = evaluate(config, model, dev_batches, "dev", dev_insts)
162 | test_metrics = evaluate(config, model, test_batches, "test", test_insts)
163 | if dev_metrics[2] > best_dev[0]:
164 | print("saving the best model...")
165 | best_dev[0] = dev_metrics[2]
166 | best_dev[1] = i
167 | best_test[0] = test_metrics[2]
168 | best_test[1] = i
169 | torch.save(model.state_dict(), model_name)
170 | write_results(res_name, test_insts)
171 | model.zero_grad()
172 |
173 | print("The best dev: %.2f" % (best_dev[0]))
174 | print("The corresponding test: %.2f" % (best_test[0]))
175 | print("Final testing.")
176 | model.load_state_dict(torch.load(model_name))
177 | model.eval()
178 | evaluate(config, model, test_batches, "test", test_insts)
179 | write_results(res_name, test_insts)
180 |
181 |
182 |
183 | def evaluate(config:Config, model: NNCRF, batch_insts_ids, name:str, insts: List[Instance]):
184 | ## evaluation
185 | metrics = np.asarray([0, 0, 0], dtype=int)
186 | batch_id = 0
187 | batch_size = config.batch_size
188 | for batch in batch_insts_ids:
189 | one_batch_insts = insts[batch_id * batch_size:(batch_id + 1) * batch_size]
190 | sorted_batch_insts = sorted(one_batch_insts, key=lambda inst: len(inst.input.words), reverse=True)
191 | batch_max_scores, batch_max_ids = model.decode(batch)
192 | metrics += eval.evaluate_num(sorted_batch_insts, batch_max_ids, batch[-2], batch[1], config.idx2labels)
193 | batch_id += 1
194 | p, total_predict, total_entity = metrics[0], metrics[1], metrics[2]
195 | precision = p * 1.0 / total_predict * 100 if total_predict != 0 else 0
196 | recall = p * 1.0 / total_entity * 100 if total_entity != 0 else 0
197 | fscore = 2.0 * precision * recall / (precision + recall) if precision != 0 or recall != 0 else 0
198 | print("[%s set] Precision: %.2f, Recall: %.2f, F1: %.2f" % (name, precision, recall,fscore), flush=True)
199 | return [precision, recall, fscore]
200 |
201 |
202 | def test_model(config: Config, test_insts):
203 | dep_model_name = config.dep_model.name
204 | if config.dep_model == DepModelType.dggcn:
205 | dep_model_name += '(' + str(config.num_gcn_layers) + ","+str(config.gcn_dropout)+ ","+str(config.gcn_mlp_layers)+")"
206 | model_name = "model_files/lstm_{}_{}_crf_{}_{}_{}_dep_{}_elmo_{}_{}_gate_{}_epoch_{}_lr_{}_comb_{}.m".format(config.num_lstm_layer, config.hidden_dim,
207 | config.dataset, config.affix,
208 | config.train_num,
209 | dep_model_name,
210 | config.context_emb.name,
211 | config.optimizer.lower(),
212 | config.edge_gate,
213 | config.num_epochs,
214 | config.learning_rate, config.interaction_func)
215 | res_name = "results/lstm_{}_{}_crf_{}_{}_{}_dep_{}_elmo_{}_{}_gate_{}_epoch_{}_lr_{}_comb_{}.results".format(config.num_lstm_layer, config.hidden_dim,
216 | config.dataset, config.affix,
217 | config.train_num,
218 | dep_model_name,
219 | config.context_emb.name,
220 | config.optimizer.lower(),
221 | config.edge_gate,
222 | config.num_epochs,
223 | config.learning_rate, config.interaction_func)
224 | model = NNCRF(config)
225 | model.load_state_dict(torch.load(model_name))
226 | model.eval()
227 | test_batches = batching_list_instances(config, test_insts)
228 | evaluate(config, model, test_batches, "test", test_insts)
229 | write_results(res_name, test_insts)
230 |
231 | def write_results(filename: str, insts):
232 |     with open(filename, 'w', encoding='utf-8') as f:
233 |         for inst in insts:
234 |             words = inst.input.words
235 |             tags = inst.input.pos_tags
236 |             heads = inst.input.heads
237 |             dep_labels = inst.input.dep_labels
238 |             output = inst.output
239 |             prediction = inst.prediction
240 |             assert len(output) == len(prediction)
241 |             for i in range(len(inst.input)):
242 |                 f.write("{}\t{}\t{}\t{}\t{}\t{}\t{}\n".format(i, words[i], tags[i], heads[i], dep_labels[i], output[i], prediction[i]))
243 |             f.write("\n")
245 |
246 |
247 |
248 |
249 |
250 |
251 | def main():
252 | parser = argparse.ArgumentParser(description="Dependency-Guided LSTM CRF implementation")
253 | opt = parse_arguments(parser)
254 | conf = Config(opt)
255 |
256 | reader = Reader(conf.digit2zero)
257 | setSeed(opt, conf.seed)
258 |
259 | trains = reader.read_conll(conf.train_file, -1, True)
260 | devs = reader.read_conll(conf.dev_file, conf.dev_num, False)
261 | tests = reader.read_conll(conf.test_file, conf.test_num, False)
262 |
263 | if conf.context_emb != ContextEmb.none:
264 | print('Loading the {} vectors for all datasets.'.format(conf.context_emb.name))
265 | conf.context_emb_size = reader.load_elmo_vec(conf.train_file.replace(".sd", "").replace(".ud", "").replace(".sud", "").replace(".predsd", "").replace(".predud", "").replace(".stud", "").replace(".ssd", "") + "."+conf.context_emb.name+".vec", trains)
266 | reader.load_elmo_vec(conf.dev_file.replace(".sd", "").replace(".ud", "").replace(".sud", "").replace(".predsd", "").replace(".predud", "").replace(".stud", "").replace(".ssd", "") + "."+conf.context_emb.name+".vec", devs)
267 | reader.load_elmo_vec(conf.test_file.replace(".sd", "").replace(".ud", "").replace(".sud", "").replace(".predsd", "").replace(".predud", "").replace(".stud", "").replace(".ssd", "") + "."+conf.context_emb.name+".vec", tests)
268 |
269 | conf.use_iobes(trains + devs + tests)
270 | conf.build_label_idx(trains)
271 |
272 | conf.build_deplabel_idx(trains + devs + tests)
273 | print("# deplabels: ", len(conf.deplabels))
274 | print("dep label 2idx: ", conf.deplabel2idx)
275 |
276 |
277 | conf.build_word_idx(trains, devs, tests)
278 | conf.build_emb_table()
279 | conf.map_insts_ids(trains + devs + tests)
280 |
281 |
282 | print("num chars: " + str(conf.num_char))
283 | # print(str(config.char2idx))
284 |
285 | print("num words: " + str(len(conf.word2idx)))
286 | # print(config.word2idx)
287 | if opt.mode == "train":
288 | if conf.train_num != -1:
289 | random.shuffle(trains)
290 | trains = trains[:conf.train_num]
291 | learn_from_insts(conf, conf.num_epochs, trains, devs, tests)
292 | else:
293 | ## Load the trained model.
294 | test_model(conf, tests)
295 | # pass
296 |
297 | print(opt.mode)
298 |
299 | if __name__ == "__main__":
300 | main()
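For reference, the SGD learning-rate schedule that the training loop applies each epoch (via `lr_decay` in `config/utils.py`) can be reproduced standalone. The `0.05` decay value below is purely illustrative; the repo's default `--lr_decay` is 0, i.e. a constant rate:

```python
def decayed_lr(base_lr, lr_decay, epoch):
    """The schedule from lr_decay: epochs are 1-indexed, so epoch 1 is unchanged."""
    return base_lr / (1 + lr_decay * (epoch - 1))

lr_epoch1 = decayed_lr(0.01, 0.05, 1)  # first epoch keeps the base rate
lr_epoch5 = decayed_lr(0.01, 0.05, 5)  # 0.01 / (1 + 0.05 * 4) = 0.01 / 1.2
```

Because the decay divides the base rate rather than multiplying a running value, the schedule is stateless: the rate for any epoch can be computed directly.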
--------------------------------------------------------------------------------
/model/charbilstm.py:
--------------------------------------------------------------------------------
1 | #
2 | # @author: Allan
3 | #
4 | import torch
5 | import torch.nn as nn
6 | import numpy as np
7 | from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
8 |
9 |
10 | class CharBiLSTM(nn.Module):
11 |
12 | def __init__(self, config):
13 | super(CharBiLSTM, self).__init__()
14 | print("[Info] Building character-level LSTM")
15 | self.char_emb_size = config.char_emb_size
16 | self.char2idx = config.char2idx
17 | self.chars = config.idx2char
18 | self.char_size = len(self.chars)
19 | self.device = config.device
20 | self.hidden = config.charlstm_hidden_dim
21 | self.dropout = nn.Dropout(config.dropout).to(self.device)
22 | self.char_embeddings = nn.Embedding(self.char_size, self.char_emb_size)
23 | # self.char_embeddings.weight.data.copy_(torch.from_numpy(self.random_embedding(self.char_size, self.char_emb_size)))
24 | self.char_embeddings = self.char_embeddings.to(self.device)
25 |
26 |         self.char_lstm = nn.LSTM(self.char_emb_size, self.hidden, num_layers=1, batch_first=True, bidirectional=False).to(self.device)
27 |
28 |
29 | # def random_embedding(self, vocab_size, embedding_dim):
30 | # pretrain_emb = np.empty([vocab_size, embedding_dim])
31 | # scale = np.sqrt(3.0 / embedding_dim)
32 | # for index in range(vocab_size):
33 | # pretrain_emb[index, :] = np.random.uniform(-scale, scale, [1, embedding_dim])
34 | # return pretrain_emb
35 |
36 | def get_last_hiddens(self, char_seq_tensor, char_seq_len):
37 | """
38 | input:
39 | char_seq_tensor: (batch_size, sent_len, word_length)
40 | char_seq_len: (batch_size, sent_len)
41 | output:
42 |             tensor of shape (batch_size, sent_len, char_hidden_dim)
43 | """
44 | batch_size = char_seq_tensor.size(0)
45 | sent_len = char_seq_tensor.size(1)
46 | char_seq_tensor = char_seq_tensor.view(batch_size * sent_len, -1)
47 | char_seq_len = char_seq_len.view(batch_size * sent_len)
48 | sorted_seq_len, permIdx = char_seq_len.sort(0, descending=True)
49 | _, recover_idx = permIdx.sort(0, descending=False)
50 | sorted_seq_tensor = char_seq_tensor[permIdx]
51 |
52 | char_embeds = self.dropout(self.char_embeddings(sorted_seq_tensor))
53 | pack_input = pack_padded_sequence(char_embeds, sorted_seq_len, batch_first=True)
54 |
55 |         _, char_hidden = self.char_lstm(pack_input, None)
56 |         ## char_hidden = (h_t, c_t)
57 |         # char_hidden[0] = h_t, with shape (num_layers * num_directions, batch_size * sent_len, lstm_dimension)
58 |         # char_rnn_out, _ = pad_packed_sequence(char_rnn_out)
59 |         ## transpose so that the batch dimension comes first (the first dimension is num_layers * num_directions)
60 |         hidden = char_hidden[0].transpose(1, 0).contiguous().view(batch_size * sent_len, 1, -1)  ## before the view, the size is (batch_size * sent_len, num_layers * num_directions, lstm_dimension)
61 | return hidden[recover_idx].view(batch_size, sent_len, -1)
62 |
63 |
64 |
65 | def forward(self, char_input, seq_lengths):
66 | return self.get_last_hiddens(char_input, seq_lengths)
67 |
68 |
69 |
70 |
--------------------------------------------------------------------------------
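CharBiLSTM.get_last_hiddens above relies on the standard sort/pack/recover trick: sequences are sorted by length in descending order so that pack_padded_sequence can be used, and an argsort of that permutation (recover_idx) restores the original batch order afterwards. A minimal plain-Python sketch of just the index bookkeeping, with hypothetical lengths and no torch dependency:

```python
# hypothetical word lengths for one batch of characters
lens = [3, 5, 2, 4]

# permIdx: indices that sort the lengths in descending order
perm = sorted(range(len(lens)), key=lambda i: -lens[i])
sorted_lens = [lens[i] for i in perm]          # [5, 4, 3, 2]

# recover_idx: the argsort of permIdx undoes the permutation
recover = sorted(range(len(perm)), key=lambda i: perm[i])
restored = [sorted_lens[i] for i in recover]
assert restored == lens
```

The same permIdx/recover_idx pattern appears again at the word level in NNCRF.neural_scoring.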
/model/deplabel_gcn.py:
--------------------------------------------------------------------------------
1 | #
2 | # @author: Allan
3 | #
4 |
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 |
9 |
10 |
11 | class DepLabeledGCN(nn.Module):
12 | def __init__(self, config, input_dim):
13 | super().__init__()
14 |
15 | self.gcn_hidden_dim = config.dep_hidden_dim
16 | self.num_gcn_layers = config.num_gcn_layers
17 | self.gcn_mlp_layers = config.gcn_mlp_layers
18 | self.edge_gate = config.edge_gate
19 | # gcn layer
20 | self.layers = self.num_gcn_layers
21 | self.device = config.device
22 | self.mem_dim = self.gcn_hidden_dim
23 | # self.in_dim = config.hidden_dim + config.dep_emb_size ## lstm hidden dim
24 | self.in_dim = input_dim ## lstm hidden dim
25 | self.self_dep_label_id = torch.tensor(config.deplabel2idx[config.self_label]).long().to(self.device)
26 |
27 | print("[Model Info] GCN Input Size: {}, # GCN Layers: {}, #MLP: {}".format(self.in_dim, self.num_gcn_layers, config.gcn_mlp_layers))
28 | self.gcn_drop = nn.Dropout(config.gcn_dropout).to(self.device)
29 |
30 | # gcn layer
31 | self.W = nn.ModuleList()
32 | self.W_label = nn.ModuleList()
33 |
34 | if self.edge_gate:
35 |             print("[Info] Edge-wise gating will be added to the labeled GCN model.")
36 | self.gates = nn.ModuleList()
37 |
38 | for layer in range(self.layers):
39 | input_dim = self.in_dim if layer == 0 else self.mem_dim
40 | self.W.append(nn.Linear(input_dim, self.mem_dim).to(self.device))
41 | self.W_label.append(nn.Linear(input_dim, self.mem_dim).to(self.device))
42 | if self.edge_gate:
43 | self.gates.append(nn.Linear(input_dim, self.mem_dim).to(self.device))
44 |
45 | self.dep_emb = nn.Embedding(len(config.deplabels), 1).to(config.device)
46 |
47 | # output mlp layers
48 | in_dim = config.hidden_dim
49 | layers = [nn.Linear(in_dim, self.gcn_hidden_dim).to(self.device), nn.ReLU().to(self.device)]
50 | for _ in range(self.gcn_mlp_layers - 1):
51 | layers += [nn.Linear(self.gcn_hidden_dim, self.gcn_hidden_dim).to(self.device), nn.ReLU().to(self.device)]
52 |
53 | self.out_mlp = nn.Sequential(*layers).to(self.device)
54 |
55 |
56 |
57 | def forward(self, gcn_inputs, word_seq_len, adj_matrix, dep_label_matrix):
58 |
59 | """
60 |
61 | :param gcn_inputs:
62 | :param word_seq_len:
63 | :param adj_matrix: should already contain the self loop
64 | :param dep_label_matrix:
65 | :return:
66 | """
67 | adj_matrix = adj_matrix.to(self.device)
68 | dep_label_matrix = dep_label_matrix.to(self.device)
69 | batch_size, sent_len, input_dim = gcn_inputs.size()
70 |
71 | denom = adj_matrix.sum(2).unsqueeze(2) + 1
72 |
73 | ##dep_label_matrix: NxN
74 | ##dep_emb.
75 | dep_embs = self.dep_emb(dep_label_matrix) ## B x N x N x 1
76 | dep_embs = dep_embs.squeeze(3) * adj_matrix
77 | #
78 | self_val = self.dep_emb(self.self_dep_label_id)
79 | dep_denom = dep_embs.sum(2).unsqueeze(2) + self_val
80 |
81 | # gcn_biinput = gcn_inputs.view(batch_size, sent_len, 1, input_dim).expand(batch_size, sent_len, sent_len, input_dim) ## B x N x N x h
82 | # weighted_gcn_input = (dep_embs + gcn_biinput).sum(2)
83 |
84 | for l in range(self.layers):
85 |
86 |             Ax = adj_matrix.bmm(gcn_inputs) ## (N x N) times (N x h) = N x h
87 | AxW = self.W[l](Ax) ## N x m
88 | AxW = AxW + self.W[l](gcn_inputs) ## self loop N x h
89 | AxW = AxW / denom
90 |
91 | Bx = dep_embs.bmm(gcn_inputs)
92 | BxW = self.W_label[l](Bx)
93 | BxW = BxW + self.W_label[l](gcn_inputs * self_val)
94 | BxW = BxW / dep_denom
95 |
96 | if self.edge_gate:
97 | gx = adj_matrix.bmm(gcn_inputs)
98 | gxW = self.gates[l](gx) ## N x m
99 | gate_val = torch.sigmoid(gxW + self.gates[l](gcn_inputs)) ## self loop N x h
100 | gAxW = F.relu(gate_val * (AxW + BxW))
101 | else:
102 | gAxW = F.relu(AxW + BxW)
103 |
104 | gcn_inputs = self.gcn_drop(gAxW) if l < self.layers - 1 else gAxW
105 |
106 |
107 | outputs = self.out_mlp(gcn_inputs)
108 | return outputs
109 |
110 |
111 |
112 |
--------------------------------------------------------------------------------
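The AxW branch in DepLabeledGCN.forward above is a degree-normalized graph convolution with an explicit self-loop term: AxW = (A h W + h W) / (deg + 1), followed by a ReLU. A minimal NumPy sketch under hypothetical sizes (the real layer also adds the label-weighted BxW branch and the optional edge gating):

```python
import numpy as np

rng = np.random.default_rng(0)
N, d_in, d_out = 4, 6, 5
# adjacency of a 4-node chain; no self-loops here, because the self-loop
# is handled by the separate "h @ W" term and the "+ 1" in the denominator
A = np.array([[0, 1, 0, 0],
              [1, 0, 1, 0],
              [0, 1, 0, 1],
              [0, 0, 1, 0]], dtype=float)
h = rng.normal(size=(N, d_in))      # node representations (LSTM outputs in the model)
W = rng.normal(size=(d_in, d_out))  # layer weight matrix

denom = A.sum(axis=1, keepdims=True) + 1            # node degree + 1 for the self-loop
out = np.maximum(0.0, (A @ h @ W + h @ W) / denom)  # ReLU((AxW + self-loop term) / denom)
print(out.shape)  # (4, 5)
```

Stacking several such layers, with dropout between them, gives the multi-layer propagation in the forward loop.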
/model/lstmcrf.py:
--------------------------------------------------------------------------------
1 | #
2 | # @author: Allan
3 | #
4 |
5 | import torch
6 | import torch.nn as nn
7 |
8 | from config.utils import START, STOP, PAD, log_sum_exp_pytorch
9 | from model.charbilstm import CharBiLSTM
10 | from model.deplabel_gcn import DepLabeledGCN
11 | from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
12 | from config.config import DepModelType, ContextEmb, InteractionFunction
13 | import torch.nn.functional as F
14 |
15 | class NNCRF(nn.Module):
16 |
17 | def __init__(self, config):
18 | super(NNCRF, self).__init__()
19 |
20 | self.label_size = config.label_size
21 | self.device = config.device
22 | self.use_char = config.use_char_rnn
23 | self.dep_model = config.dep_model
24 | self.context_emb = config.context_emb
25 | self.interaction_func = config.interaction_func
26 |
27 |
28 | self.label2idx = config.label2idx
29 | self.labels = config.idx2labels
30 | self.start_idx = self.label2idx[START]
31 | self.end_idx = self.label2idx[STOP]
32 | self.pad_idx = self.label2idx[PAD]
33 |
34 |
35 |
36 | self.input_size = config.embedding_dim
37 |
38 | if self.use_char:
39 | self.char_feature = CharBiLSTM(config)
40 | self.input_size += config.charlstm_hidden_dim
41 |
42 |
43 | vocab_size = len(config.word2idx)
44 | self.word_embedding = nn.Embedding.from_pretrained(torch.FloatTensor(config.word_embedding), freeze=False).to(self.device)
45 | self.word_drop = nn.Dropout(config.dropout).to(self.device)
46 |
47 | if self.dep_model == DepModelType.dglstm and self.interaction_func == InteractionFunction.mlp:
48 | self.mlp_layers = nn.ModuleList()
49 | for i in range(config.num_lstm_layer - 1):
50 | self.mlp_layers.append(nn.Linear(config.hidden_dim, 2 * config.hidden_dim).to(self.device))
51 | self.mlp_head_linears = nn.ModuleList()
52 | for i in range(config.num_lstm_layer - 1):
53 | self.mlp_head_linears.append(nn.Linear(config.hidden_dim, 2 * config.hidden_dim).to(self.device))
54 |
55 | """
56 | Input size to LSTM description
57 | """
58 | self.charlstm_dim = config.charlstm_hidden_dim
59 | if self.dep_model == DepModelType.dglstm:
60 | self.input_size += config.embedding_dim + config.dep_emb_size
61 | if self.use_char:
62 | self.input_size += config.charlstm_hidden_dim
63 |
64 | if self.context_emb != ContextEmb.none:
65 | self.input_size += config.context_emb_size
66 |
67 | print("[Model Info] Input size to LSTM: {}".format(self.input_size))
68 | print("[Model Info] LSTM Hidden Size: {}".format(config.hidden_dim))
69 |
70 |
71 | num_layers = 1
72 | if config.num_lstm_layer > 1 and self.dep_model != DepModelType.dglstm:
73 | num_layers = config.num_lstm_layer
74 | if config.num_lstm_layer > 0:
75 | self.lstm = nn.LSTM(self.input_size, config.hidden_dim // 2, num_layers=num_layers, batch_first=True, bidirectional=True).to(self.device)
76 |
77 | self.num_lstm_layer = config.num_lstm_layer
78 | self.lstm_hidden_dim = config.hidden_dim
79 | self.embedding_dim = config.embedding_dim
80 | if config.num_lstm_layer > 1 and self.dep_model == DepModelType.dglstm:
81 | self.add_lstms = nn.ModuleList()
82 | if self.interaction_func == InteractionFunction.concatenation or \
83 | self.interaction_func == InteractionFunction.mlp:
84 | hidden_size = 2 * config.hidden_dim
85 | elif self.interaction_func == InteractionFunction.addition:
86 | hidden_size = config.hidden_dim
87 |
88 | print("[Model Info] Building {} more LSTMs, with size: {} x {} (without dep label highway connection)".format(config.num_lstm_layer-1, hidden_size, config.hidden_dim))
89 | for i in range(config.num_lstm_layer - 1):
90 | self.add_lstms.append(nn.LSTM(hidden_size, config.hidden_dim // 2, num_layers=1, batch_first=True, bidirectional=True).to(self.device))
91 |
92 | self.drop_lstm = nn.Dropout(config.dropout).to(self.device)
93 |
94 |
95 |         final_hidden_dim = config.hidden_dim if self.num_lstm_layer > 0 else self.input_size
96 | """
97 | Model description
98 | """
99 | print("[Model Info] Dep Method: {}, hidden size: {}".format(self.dep_model.name, config.dep_hidden_dim))
100 | if self.dep_model != DepModelType.none:
101 | print("Initializing the dependency label embedding")
102 | self.dep_label_embedding = nn.Embedding(len(config.deplabel2idx), config.dep_emb_size).to(self.device)
103 | if self.dep_model == DepModelType.dggcn:
104 | self.gcn = DepLabeledGCN(config, config.hidden_dim) ### lstm hidden size
105 | final_hidden_dim = config.dep_hidden_dim
106 |
107 | print("[Model Info] Final Hidden Size: {}".format(final_hidden_dim))
108 | self.hidden2tag = nn.Linear(final_hidden_dim, self.label_size).to(self.device)
109 |
110 | init_transition = torch.randn(self.label_size, self.label_size).to(self.device)
111 | init_transition[:, self.start_idx] = -10000.0
112 | init_transition[self.end_idx, :] = -10000.0
113 | init_transition[:, self.pad_idx] = -10000.0
114 | init_transition[self.pad_idx, :] = -10000.0
115 |
116 | self.transition = nn.Parameter(init_transition)
117 |
118 |
119 | def neural_scoring(self, word_seq_tensor, word_seq_lens, batch_context_emb, char_inputs, char_seq_lens, adj_matrixs, adjs_in, adjs_out, graphs, dep_label_adj, dep_head_tensor, dep_label_tensor, trees=None):
120 | """
121 |         :param word_seq_tensor: (batch_size, sent_len) NOTE: the word sequences are already ordered before reaching here.
122 |         :param word_seq_lens: (batch_size, 1)
123 |         :param char_inputs: (batch_size, sent_len, word_length)
124 |         :param char_seq_lens: numpy array (batch_size * sent_len, 1)
125 |         :param dep_label_tensor: (batch_size, max_sent_len)
126 |         :return: emission scores (batch_size, sent_len, label_size)
127 | """
128 | batch_size = word_seq_tensor.size(0)
129 | sent_len = word_seq_tensor.size(1)
130 |
131 | word_emb = self.word_embedding(word_seq_tensor)
132 | if self.use_char:
133 | if self.dep_model == DepModelType.dglstm:
134 | char_features = self.char_feature.get_last_hiddens(char_inputs, char_seq_lens)
135 | word_emb = torch.cat((word_emb, char_features), 2)
136 | if self.dep_model == DepModelType.dglstm:
137 | size = self.embedding_dim if not self.use_char else (self.embedding_dim + self.charlstm_dim)
138 | dep_head_emb = torch.gather(word_emb, 1, dep_head_tensor.view(batch_size, sent_len, 1).expand(batch_size, sent_len, size))
139 |
140 | if self.context_emb != ContextEmb.none:
141 | word_emb = torch.cat((word_emb, batch_context_emb.to(self.device)), 2)
142 |
143 | if self.use_char:
144 | if self.dep_model != DepModelType.dglstm:
145 | char_features = self.char_feature.get_last_hiddens(char_inputs, char_seq_lens)
146 | word_emb = torch.cat((word_emb, char_features), 2)
147 |
148 | """
149 | Word Representation
150 | """
151 | if self.dep_model == DepModelType.dglstm:
152 | dep_emb = self.dep_label_embedding(dep_label_tensor)
153 | word_emb = torch.cat((word_emb, dep_head_emb, dep_emb), 2)
154 |
155 | word_rep = self.word_drop(word_emb)
156 |
157 | sorted_seq_len, permIdx = word_seq_lens.sort(0, descending=True)
158 | _, recover_idx = permIdx.sort(0, descending=False)
159 | sorted_seq_tensor = word_rep[permIdx]
160 |
161 |
162 | if self.num_lstm_layer > 0:
163 | packed_words = pack_padded_sequence(sorted_seq_tensor, sorted_seq_len, True)
164 | lstm_out, _ = self.lstm(packed_words, None)
165 |             lstm_out, _ = pad_packed_sequence(lstm_out, batch_first=True) ## NOTE: batch_first must be set here; otherwise the output needs a transpose.
166 | feature_out = self.drop_lstm(lstm_out)
167 | else:
168 | feature_out = sorted_seq_tensor
169 |
170 | """
171 | Higher order interactions
172 | """
173 | if self.num_lstm_layer > 1 and (self.dep_model == DepModelType.dglstm):
174 | for l in range(self.num_lstm_layer-1):
175 | dep_head_emb = torch.gather(feature_out, 1, dep_head_tensor[permIdx].view(batch_size, sent_len, 1).expand(batch_size, sent_len, self.lstm_hidden_dim))
176 | if self.interaction_func == InteractionFunction.concatenation:
177 | feature_out = torch.cat((feature_out, dep_head_emb), 2)
178 | elif self.interaction_func == InteractionFunction.addition:
179 | feature_out = feature_out + dep_head_emb
180 | elif self.interaction_func == InteractionFunction.mlp:
181 | feature_out = F.relu(self.mlp_layers[l](feature_out) + self.mlp_head_linears[l](dep_head_emb))
182 |
183 | packed_words = pack_padded_sequence(feature_out, sorted_seq_len, True)
184 | lstm_out, _ = self.add_lstms[l](packed_words, None)
185 |                 lstm_out, _ = pad_packed_sequence(lstm_out, batch_first=True) ## NOTE: batch_first must be set here; otherwise the output needs a transpose.
186 | feature_out = self.drop_lstm(lstm_out)
187 |
188 | """
189 | Model forward if we have GCN
190 | """
191 | if self.dep_model == DepModelType.dggcn:
192 | feature_out = self.gcn(feature_out, sorted_seq_len, adj_matrixs[permIdx], dep_label_adj[permIdx])
193 |
194 | outputs = self.hidden2tag(feature_out)
195 |
196 | return outputs[recover_idx]
197 |
198 | def calculate_all_scores(self, features):
199 | batch_size = features.size(0)
200 | seq_len = features.size(1)
201 | scores = self.transition.view(1, 1, self.label_size, self.label_size).expand(batch_size, seq_len, self.label_size, self.label_size) + \
202 | features.view(batch_size, seq_len, 1, self.label_size).expand(batch_size,seq_len,self.label_size, self.label_size)
203 | return scores
204 |
205 | def forward_unlabeled(self, all_scores, word_seq_lens, masks):
206 | batch_size = all_scores.size(0)
207 | seq_len = all_scores.size(1)
208 | alpha = torch.zeros(batch_size, seq_len, self.label_size).to(self.device)
209 |
210 |         alpha[:, 0, :] = all_scores[:, 0, self.start_idx, :] ## score of each label at the first position = transition(START -> label) + emission at position 0.
211 |
212 | for word_idx in range(1, seq_len):
213 | ## batch_size, self.label_size, self.label_size
214 | before_log_sum_exp = alpha[:, word_idx-1, :].view(batch_size, self.label_size, 1).expand(batch_size, self.label_size, self.label_size) + all_scores[:, word_idx, :, :]
215 | alpha[:, word_idx, :] = log_sum_exp_pytorch(before_log_sum_exp)
216 |
217 | ### batch_size x label_size
218 | last_alpha = torch.gather(alpha, 1, word_seq_lens.view(batch_size, 1, 1).expand(batch_size, 1, self.label_size)-1).view(batch_size, self.label_size)
219 | last_alpha += self.transition[:, self.end_idx].view(1, self.label_size).expand(batch_size, self.label_size)
220 | last_alpha = log_sum_exp_pytorch(last_alpha.view(batch_size, self.label_size, 1)).view(batch_size)
221 |
222 | return torch.sum(last_alpha)
223 |
224 | def forward_labeled(self, all_scores, word_seq_lens, tags, masks):
225 | '''
226 | :param all_scores: (batch, seq_len, label_size, label_size)
227 |         :param word_seq_lens: (batch,)
228 |         :param tags: (batch, seq_len)
229 |         :param masks: (batch, seq_len)
230 |         :return: sum of the scores of the gold label sequences
231 | '''
232 | batchSize = all_scores.shape[0]
233 | sentLength = all_scores.shape[1]
234 |
235 | ## all the scores to current labels: batch, seq_len, all_from_label?
236 | currentTagScores = torch.gather(all_scores, 3, tags.view(batchSize, sentLength, 1, 1).expand(batchSize, sentLength, self.label_size, 1)).view(batchSize, -1, self.label_size)
237 | if sentLength != 1:
238 | tagTransScoresMiddle = torch.gather(currentTagScores[:, 1:, :], 2, tags[:, : sentLength - 1].view(batchSize, sentLength - 1, 1)).view(batchSize, -1)
239 | tagTransScoresBegin = currentTagScores[:, 0, self.start_idx]
240 | endTagIds = torch.gather(tags, 1, word_seq_lens.view(batchSize, 1) - 1)
241 | tagTransScoresEnd = torch.gather(self.transition[:, self.end_idx].view(1, self.label_size).expand(batchSize, self.label_size), 1, endTagIds).view(batchSize)
242 | score = torch.sum(tagTransScoresBegin) + torch.sum(tagTransScoresEnd)
243 | if sentLength != 1:
244 | score += torch.sum(tagTransScoresMiddle.masked_select(masks[:, 1:]))
245 | return score
246 |
247 | def neg_log_obj(self, words, word_seq_lens, batch_context_emb, chars, char_seq_lens, adj_matrixs, adjs_in, adjs_out, graphs, dep_label_adj, batch_dep_heads, tags, batch_dep_label, trees=None):
248 | features = self.neural_scoring(words, word_seq_lens, batch_context_emb, chars, char_seq_lens, adj_matrixs, adjs_in, adjs_out, graphs, dep_label_adj, batch_dep_heads, batch_dep_label, trees)
249 |
250 | all_scores = self.calculate_all_scores(features)
251 |
252 | batch_size = words.size(0)
253 | sent_len = words.size(1)
254 |
255 | maskTemp = torch.arange(1, sent_len + 1, dtype=torch.long).view(1, sent_len).expand(batch_size, sent_len).to(self.device)
256 | mask = torch.le(maskTemp, word_seq_lens.view(batch_size, 1).expand(batch_size, sent_len)).to(self.device)
257 |
258 |         unlabeled_score = self.forward_unlabeled(all_scores, word_seq_lens, mask)
259 |         labeled_score = self.forward_labeled(all_scores, word_seq_lens, tags, mask)
260 |         return unlabeled_score - labeled_score
261 |
262 |
263 | def viterbiDecode(self, all_scores, word_seq_lens):
264 | batchSize = all_scores.shape[0]
265 | sentLength = all_scores.shape[1]
266 | # sent_len =
267 | scoresRecord = torch.zeros([batchSize, sentLength, self.label_size]).to(self.device)
268 | idxRecord = torch.zeros([batchSize, sentLength, self.label_size], dtype=torch.int64).to(self.device)
269 | mask = torch.ones_like(word_seq_lens, dtype=torch.int64).to(self.device)
270 | startIds = torch.full((batchSize, self.label_size), self.start_idx, dtype=torch.int64).to(self.device)
271 | decodeIdx = torch.LongTensor(batchSize, sentLength).to(self.device)
272 |
273 | scores = all_scores
274 | # scoresRecord[:, 0, :] = self.getInitAlphaWithBatchSize(batchSize).view(batchSize, self.label_size)
275 |         scoresRecord[:, 0, :] = scores[:, 0, self.start_idx, :] ## best score for each label at position 0: transition from START plus the emission.
276 | idxRecord[:, 0, :] = startIds
277 | for wordIdx in range(1, sentLength):
278 | ### scoresIdx: batch x from_label x to_label at current index.
279 | scoresIdx = scoresRecord[:, wordIdx - 1, :].view(batchSize, self.label_size, 1).expand(batchSize, self.label_size,
280 | self.label_size) + scores[:, wordIdx, :, :]
281 |             idxRecord[:, wordIdx, :] = torch.argmax(scoresIdx, 1) ## the best previous label index for each current label
282 | scoresRecord[:, wordIdx, :] = torch.gather(scoresIdx, 1, idxRecord[:, wordIdx, :].view(batchSize, 1, self.label_size)).view(batchSize, self.label_size)
283 |
284 |         lastScores = torch.gather(scoresRecord, 1, word_seq_lens.view(batchSize, 1, 1).expand(batchSize, 1, self.label_size) - 1).view(batchSize, self.label_size) ## select the scores at each sequence's last valid position
285 | lastScores += self.transition[:, self.end_idx].view(1, self.label_size).expand(batchSize, self.label_size)
286 | decodeIdx[:, 0] = torch.argmax(lastScores, 1)
287 | bestScores = torch.gather(lastScores, 1, decodeIdx[:, 0].view(batchSize, 1))
288 |
289 | for distance2Last in range(sentLength - 1):
290 | lastNIdxRecord = torch.gather(idxRecord, 1, torch.where(word_seq_lens - distance2Last - 1 > 0, word_seq_lens - distance2Last - 1, mask).view(batchSize, 1, 1).expand(batchSize, 1, self.label_size)).view(batchSize, self.label_size)
291 | decodeIdx[:, distance2Last + 1] = torch.gather(lastNIdxRecord, 1, decodeIdx[:, distance2Last].view(batchSize, 1)).view(batchSize)
292 |
293 | return bestScores, decodeIdx
294 |
295 | def decode(self, batchInput):
296 | wordSeqTensor, wordSeqLengths, batch_context_emb, charSeqTensor, charSeqLengths, adj_matrixs, adjs_in, adjs_out, graphs, dep_label_adj, batch_dep_heads, trees, tagSeqTensor, batch_dep_label = batchInput
297 | features = self.neural_scoring(wordSeqTensor, wordSeqLengths, batch_context_emb,charSeqTensor,charSeqLengths, adj_matrixs, adjs_in, adjs_out, graphs, dep_label_adj, batch_dep_heads, batch_dep_label, trees)
298 | all_scores = self.calculate_all_scores(features)
299 | bestScores, decodeIdx = self.viterbiDecode(all_scores, wordSeqLengths)
300 | return bestScores, decodeIdx
301 |
--------------------------------------------------------------------------------
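forward_unlabeled above computes the CRF partition function with the forward algorithm: a log-sum-exp recursion over label transitions, which replaces the exponential sum over all label paths with a linear-time dynamic program. A small self-contained check on hypothetical toy scores (plain Python; the real model additionally handles START/STOP transitions and masking of padded positions) that the recursion matches brute-force enumeration:

```python
import math
from itertools import product

L, T = 3, 4  # number of labels, sequence length (toy sizes)
# deterministic toy emission and transition scores
emit = [[0.5 * (t + 1) * (l + 1) % 1.7 for l in range(L)] for t in range(T)]
trans = [[0.3 * (i - j) for j in range(L)] for i in range(L)]

# brute force: log of the sum of exp(score) over all L**T label paths
path_scores = []
for path in product(range(L), repeat=T):
    s = emit[0][path[0]]
    for t in range(1, T):
        s += trans[path[t - 1]][path[t]] + emit[t][path[t]]
    path_scores.append(s)
brute = math.log(sum(math.exp(s) for s in path_scores))

def logsumexp(xs):
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))

# forward algorithm: alpha[j] at step t = logsumexp_i(alpha[i] + trans[i][j]) + emit[t][j]
alpha = emit[0][:]
for t in range(1, T):
    alpha = [logsumexp([alpha[i] + trans[i][j] for i in range(L)]) + emit[t][j]
             for j in range(L)]
dp = logsumexp(alpha)
assert abs(dp - brute) < 1e-8
```

Replacing the logsumexp with a max (and recording the argmax) turns the same recursion into viterbiDecode.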
/preprocess/convert_sem_eng.py:
--------------------------------------------------------------------------------
1 | #
2 | # @author: Allan
3 | #
4 |
5 | ### This file converts the SemEval English data into our CoNLL-X format.
6 |
7 | def process(filename:str, out:str):
8 | fres = open(out, 'w', encoding='utf-8')
9 | print(filename)
10 | with open(filename, 'r', encoding='utf-8') as f:
11 | words = []
12 | heads = []
13 | deps =[]
14 | labels = []
15 | prev_label = "O"
16 | prev_raw_label = ""
17 | for line in f.readlines():
18 | line = line.rstrip()
19 | # print(line)
20 | if line.startswith("#"):
21 | prev_label = "O"
22 | prev_raw_label = ""
23 | continue
24 | if line == "":
25 | idx = 1
26 | for w, h, dep, label in zip(words, heads, deps, labels):
27 | if dep == "sentence":
28 | dep = "root"
29 | fres.write("{}\t{}\t_\t_\t_\t_\t{}\t{}\t_\t_\t{}\n".format(idx, w, h, dep, label))
30 | idx += 1
31 | fres.write('\n')
32 | words = []
33 | heads = []
34 | deps = []
35 | labels = []
36 | prev_label = "O"
37 | continue
38 | #1 West _ NNP NNP _ 5 compound _ _ B-MISC
39 | vals = line.split()
40 | idx = vals[0]
41 | word = vals[1]
42 | head = vals[8]
43 | dep_label = vals[10]
44 | label = vals[12]
45 |
46 | if label.startswith("("):
47 | if label.endswith(")"):
48 | label = "B-" + label[1:-1]
49 | else:
50 | label = "B-" + label[1:]
51 | elif label.startswith(")"):
52 | label = "I-" + label[:-1]
53 | else:
54 | if prev_label == "O":
55 | label = "O"
56 | else:
57 | if prev_raw_label.endswith(")"):
58 | label = "O"
59 | else:
60 | label = "I-" + prev_label[2:]
61 |
62 | words.append(word)
63 | heads.append(head)
64 | labels.append(label)
65 | deps.append(dep_label)
66 | prev_label = label
67 | prev_raw_label = vals[12]
68 | fres.close()
69 |
70 |
71 |
72 |
73 |
74 |
75 | # process("data/semeval10t1/en.train.txt", "data/semeval10t1/train.sd.conllx")
76 | # process("data/semeval10t1/en.devel.txt", "data/semeval10t1/dev.sd.conllx")
77 | # process("data/semeval10t1/en.test.txt", "data/semeval10t1/test.sd.conllx")
78 |
79 | lang = "it"
80 | folder="sem" + lang
81 | process("data/"+folder+"/"+lang+".train.txt", "data/"+folder+"/train.sd.conllx")
82 | process("data/"+folder+"/"+lang+".devel.txt", "data/"+folder+"/dev.sd.conllx")
83 | process("data/"+folder+"/"+lang+".test.txt", "data/"+folder+"/test.sd.conllx")
--------------------------------------------------------------------------------
/preprocess/convert_sem_other.py:
--------------------------------------------------------------------------------
1 | #
2 | # @author: Allan
3 | #
4 | from typing import List
5 |
6 | type2num = {}
7 |
8 | def extract(words: List[str], labels:List[str], heads:List[str], deps:List[str]):
9 | entity_pool = [] ## type, left, right
10 | completed_pool = []
11 | print(words)
12 | # print(labels)
13 | for i, label in enumerate(labels):
14 | if label == "_":
15 | continue
16 | if "|" in label:
17 | vals = label.split("|")
18 | for val in vals:
19 | if val.startswith("(") and val.endswith(")"):
20 | completed_pool.append((i, i, val[1:-1]))
21 | elif val.startswith("("):
22 | entity_pool.append((i, -1, val[1:]))
23 | elif val.endswith(")"):
24 | found = False
25 | for tup in entity_pool[::-1]:
26 | start, end, cur = tup
27 | if cur == val[:-1]:
28 | completed_pool.append((start, i, cur))
29 | entity_pool.remove(tup)
30 | found = True
31 | break
32 | if not found:
33 |                             raise Exception("entity not found: {}".format(val))
34 | else:
35 |                         raise Exception("invalid label type: {}".format(val))
36 | else:
37 | if label.startswith("(") and label.endswith(")"):
38 | completed_pool.append((i, i, label[1:-1]))
39 | elif label.startswith("("):
40 | entity_pool.append((i, -1, label[1:]))
41 | elif label.endswith(")"):
42 | found = False
43 | for tup in entity_pool[::-1]:
44 | start, end, cur = tup
45 | if cur == label[:-1]:
46 | completed_pool.append((start, i, cur))
47 | entity_pool.remove(tup)
48 | found = True
49 | break
50 | if not found:
51 |                     raise Exception("entity not found: {}".format(label))
52 | else:
53 |                 raise Exception("invalid label type: {}".format(label))
54 | assert (len(entity_pool) == 0)
55 |
56 |
57 | for i in range(len(words)):
58 | curr_pos = []
59 | for span in completed_pool:
60 | start, end, label = span
61 | if i >= start and i <= end:
62 | curr_pos.append(span)
63 | curr_pos = sorted(curr_pos, key=lambda span: span[1] - span[0])
64 | for span in curr_pos[1:]:
65 | completed_pool.remove(span)
66 |
67 | labels = ["O"] * len(words)
68 | visited = [False] * len(words)
69 | for span in completed_pool:
70 | start, end, label = span
71 |
72 | for check in visited[start:(end+1)]:
73 | if check:
74 |                     raise Exception("this position has already been assigned to another entity.")
75 |
76 | if label not in ('person', 'loc', 'org'):
77 | label = 'misc'
78 |
79 | labels[start] = "B-"+label
80 | labels[(start+1):end] = ["I-" + label] * (end - start)
81 | visited[start: (end+1)] = [True] * (end-start + 1)
82 |
83 | if label in type2num:
84 | type2num[label] += 1
85 | else:
86 | type2num[label] = 1
87 |
88 | # print(labels)
89 | return labels
90 |
91 |
92 | def read_all_sents(filename:str, out:str, use_gold_dep: bool = True):
93 | print(filename)
94 | fres = open(out, 'w', encoding='utf-8')
95 | sents = []
96 | with open(filename, 'r', encoding='utf-8') as f:
97 | words = []
98 | heads = []
99 | deps = []
100 | labels = []
101 | pos_tags = []
102 | for line in f.readlines():
103 | line = line.rstrip()
104 | # print(line)
105 | if line.startswith("#"):
106 | continue
107 | if line == "":
108 | idx = 1
109 | labels = extract(words, labels, heads, deps)
110 | idx = 1
111 | for w, h, dep, label, pos_tag in zip(words, heads, deps, labels, pos_tags):
112 | if dep == "sentence":
113 | dep = "root"
114 | fres.write("{}\t{}\t_\t{}\t{}\t_\t{}\t{}\t_\t_\t{}\n".format(idx, w, pos_tag, pos_tag, h, dep, label))
115 | idx += 1
116 | fres.write('\n')
117 |
118 | words = []
119 | heads = []
120 | deps = []
121 |                 labels, pos_tags = [], []  ## also reset pos_tags so stale tags do not leak into the next sentence
122 | continue
123 | # 1 West _ NNP NNP _ 5 compound _ _ B-MISC
124 | vals = line.split()
125 | idx = vals[0]
126 | word = vals[1]
127 | pos_tag = vals[4]
128 | head = vals[8] if use_gold_dep else vals[9]
129 | dep_label = vals[10] if use_gold_dep else vals[11]
130 | label = vals[12]
131 | words.append(word)
132 | pos_tags.append(pos_tag)
133 | heads.append(head)
134 | labels.append(label)
135 | deps.append(dep_label)
136 | fres.close()
137 |
138 | def process(filename:str, out:str):
139 | fres = open(out, 'w', encoding='utf-8')
140 | print(filename)
141 | with open(filename, 'r', encoding='utf-8') as f:
142 | words = []
143 | heads = []
144 | deps =[]
145 | labels = []
146 | prev_label = "O"
147 | prev_raw_label = ""
148 | for line in f.readlines():
149 | line = line.rstrip()
150 | # print(line)
151 | if line.startswith("#"):
152 | prev_label = "O"
153 | prev_raw_label = ""
154 | continue
155 | if line == "":
156 | idx = 1
157 | for w, h, dep, label in zip(words, heads, deps, labels):
158 | if dep == "sentence":
159 | dep = "root"
160 | fres.write("{}\t{}\t_\t_\t_\t_\t{}\t{}\t_\t_\t{}\n".format(idx, w, h, dep, label))
161 | idx += 1
162 | fres.write('\n')
163 | words = []
164 | heads = []
165 | deps = []
166 | labels = []
167 | prev_label = "O"
168 | continue
169 | #1 West _ NNP NNP _ 5 compound _ _ B-MISC
170 | vals = line.split()
171 | idx = vals[0]
172 | word = vals[1]
173 | head = vals[8]
174 | dep_label = vals[10]
175 | label = vals[12]
176 |
177 | if label.startswith("("):
178 | if label.endswith(")"):
179 | label = "B-" + label[1:-1]
180 | else:
181 | label = "B-" + label[1:]
182 | elif label.startswith(")"):
183 | label = "I-" + label[:-1]
184 | else:
185 | if prev_label == "O":
186 | label = "O"
187 | else:
188 | if prev_raw_label.endswith(")"):
189 | label = "O"
190 | else:
191 | label = "I-" + prev_label[2:]
192 |
193 | words.append(word)
194 | heads.append(head)
195 | labels.append(label)
196 | deps.append(dep_label)
197 | prev_label = label
198 | prev_raw_label = vals[12]
199 | fres.close()
200 |
201 |
202 |
203 |
204 | ### This file converts the SemEval Catalan and Spanish data into our CoNLL-X format.
205 |
206 | lang = "ca"
207 | folder="sem" + lang
208 | use_gold_dep = False
209 | affix = "sd" if use_gold_dep else "sud"
210 | read_all_sents("data/"+folder+"/"+lang+".train.txt", "data/"+folder+"/train."+affix+".conllx", use_gold_dep)
211 | print(type2num)
212 | type2num = {}
213 | read_all_sents("data/"+folder+"/"+lang+".devel.txt", "data/"+folder+"/dev."+affix+".conllx", use_gold_dep)
214 |
215 | print(type2num)
216 | type2num = {}
217 | read_all_sents("data/"+folder+"/"+lang+".test.txt", "data/"+folder+"/test."+affix+".conllx", use_gold_dep)
218 |
219 | print(type2num)
--------------------------------------------------------------------------------
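The process() functions above turn the SemEval bracket annotation into BIO tags: "(org" opens an entity, "org)" closes it, and a token that is both, like "(person)", is a single-token entity. A simplified sketch of the idea for the flat, non-overlapping case (hypothetical toy labels; the real code additionally tracks prev_raw_label and the corpus's exact bracket placement, including nested spans in extract()):

```python
def to_bio(raw):
    """Convert a flat bracket-label column into BIO tags (simplified sketch)."""
    out, open_type = [], None
    for lab in raw:
        if lab.startswith("(") and lab.endswith(")"):
            out.append("B-" + lab[1:-1])   # single-token entity
            open_type = None
        elif lab.startswith("("):
            open_type = lab[1:]            # entity opens here
            out.append("B-" + open_type)
        elif lab.endswith(")"):
            out.append("I-" + open_type)   # entity closes on this token
            open_type = None
        else:
            out.append("I-" + open_type if open_type else "O")
    return out

print(to_bio(["(org", "_", "org)", "_", "(person)"]))
# ['B-org', 'I-org', 'I-org', 'O', 'B-person']
```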
/preprocess/elmo_others.py:
--------------------------------------------------------------------------------
1 | from elmoformanylangs import Embedder
2 | import pickle
3 |
4 | """
5 | This file should be deprecated: the embeddings it produces differ between runs.
6 |
7 | """
8 |
9 |
10 | def read_conllx(filename:str):
11 | print(filename)
12 | sents = []
13 | with open(filename, 'r', encoding='utf-8') as f:
14 | words = []
15 | for line in f.readlines():
16 | line = line.rstrip()
17 | if line == "":
18 | if len(words) == 0:
19 | print("len is 0")
20 | sents.append(words)
21 | words = []
22 | continue
23 | vals = line.split()
24 | words.append(vals[1])
25 | return sents
26 |
27 |
28 | def context_emb(emb, sents):
29 |     ##  0: the word encoder
30 |     ##  1: the first LSTM hidden layer
31 |     ##  2: the second LSTM hidden layer
32 |     ## -1: the average of the 3 layers (default)
33 |     ## -2: all 3 layers
34 | return emb.sents2elmo(sents, -1)
35 |
36 |
37 | def read_parse_write(elmo, in_file, out_file):
38 | sents = read_conllx(in_file)
39 | print("number of sentences: {} in {}".format(len(sents), in_file))
40 | f = open(out_file, 'wb')
41 | batch_size = 1
42 | all_vecs = []
43 |     for start in range(0, len(sents), batch_size):
44 |         ## range() already steps by batch_size, so multiplying by batch_size again would skip sentences for batch_size > 1
45 |         end = min(start + batch_size, len(sents))
46 | batch_sents = sents[start: end]
47 | #print(batch_sents)
48 | embs = context_emb(elmo, batch_sents)
49 | for emb in embs:
50 | all_vecs.append(emb)
51 | pickle.dump(all_vecs, f)
52 | f.close()
53 |
54 |
55 | ## NOTE: Remember to download the model and change the path here
56 | elmo = Embedder('/data/allan/embeddings/Spanish_ELMo', batch_size=1)
57 |
58 | dataset = "spanish"
59 | read_parse_write(elmo, f"data/{dataset}/train.sd.conllx", f"data/{dataset}/train.conllx.elmo.vec")
60 | read_parse_write(elmo, f"data/{dataset}/dev.sd.conllx", f"data/{dataset}/dev.conllx.elmo.vec")
61 | read_parse_write(elmo, f"data/{dataset}/test.sd.conllx", f"data/{dataset}/test.conllx.elmo.vec")
62 |
63 |
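
Since each `.vec` file holds one pickled list of per-sentence arrays, a downstream consumer can read it back with a single load. A sketch (`load_vecs` is a name introduced here, not part of the repo):

```python
import pickle

def load_vecs(path):
    ## The writer above pickles the whole list of per-sentence
    ## embedding arrays as a single object, so one load suffices.
    with open(path, "rb") as f:
        return pickle.load(f)
```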
--------------------------------------------------------------------------------
/preprocess/prebert.py:
--------------------------------------------------------------------------------
1 | #
2 | # @author: Allan
3 | #
4 |
5 | from config.reader import Reader
6 | import numpy as np
7 | import pickle
8 | import torch
9 | from pytorch_pretrained_bert import BertTokenizer, BertModel, BertForMaskedLM
10 |
11 | # # OPTIONAL: if you want to have more information on what's happening, activate the logger as follows
12 | # import logging
13 | # logging.basicConfig(level=logging.INFO)
14 |
15 |
16 | def parse_sentence(tokenizer, model, words, mode:str="average"):  ## NOTE: mode is currently unused; all encoder layers are returned
17 | model.eval()
18 |     indexed_tokens = tokenizer.convert_tokens_to_ids(words)  ## words are not WordPiece-tokenized first, so unseen words map to [UNK]
19 | segments_ids = [0] * len(indexed_tokens)
20 | tokens_tensor = torch.LongTensor([indexed_tokens]).to(device)
21 | segments_tensors = torch.LongTensor([segments_ids]).to(device)
22 | with torch.no_grad():
23 | encoded_layers, _ = model(tokens_tensor, segments_tensors)
24 | return encoded_layers
25 |
26 | def read_parse_write(tokenizer, model, infile, outfile, mode):
27 | reader = Reader()
28 | insts = reader.read_conll(infile, -1, True)
29 | f = open(outfile, 'wb')
30 | all_vecs = []
31 | for inst in insts:
32 | vec = parse_sentence(tokenizer, model, inst.input.words, mode=mode)
33 | all_vecs.append(vec)
34 | pickle.dump(all_vecs, f)
35 | f.close()
36 |
37 |
38 | def load_bert():
39 | # Load pre-trained model tokenizer (vocabulary)
40 | tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
41 | model = BertModel.from_pretrained('bert-base-cased')
42 | model.eval()
43 | model.to(device)
44 | return tokenizer, model
45 |
46 |
47 | device = torch.device('cuda:0')
48 | tokenizer, bert_model = load_bert()
49 | mode = "average"
50 | dataset="conll2003"
51 | dep = ""
52 | file = "../data/"+dataset+"/train"+dep+".conllx"
53 | outfile = file + ".bert."+mode+".vec"
54 | read_parse_write(tokenizer, bert_model, file, outfile, mode)
55 | file = "../data/"+dataset+"/dev"+dep+".conllx"
56 | outfile = file + ".bert."+mode+".vec"
57 | read_parse_write(tokenizer, bert_model, file, outfile, mode)
58 | file = "../data/"+dataset+"/test"+dep+".conllx"
59 | outfile = file + ".bert."+mode+".vec"
60 | read_parse_write(tokenizer, bert_model, file, outfile, mode)
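
`parse_sentence` above returns the full list of encoder layers and never consumes its `mode` argument. One way the "average" mode could be realized downstream is sketched below (`average_layers` is an assumed helper name, not the repo's code):

```python
import torch

def average_layers(encoded_layers):
    ## encoded_layers: list of [1, seq_len, hidden] tensors, one per BERT layer
    ## (the shape pytorch_pretrained_bert's BertModel returns for a single sentence).
    stacked = torch.stack(encoded_layers, dim=0)  # [num_layers, 1, seq_len, hidden]
    return stacked.mean(dim=0).squeeze(0)         # [seq_len, hidden]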
--------------------------------------------------------------------------------
/preprocess/preelmo.py:
--------------------------------------------------------------------------------
1 | #
2 | # @author: Allan
3 | #
4 |
5 | from config.reader import Reader
6 | import numpy as np
7 | from allennlp.commands.elmo import ElmoEmbedder
8 | import pickle
9 |
10 |
11 | def parse_sentence(elmo, words, mode:str="average"):
12 | vectors = elmo.embed_sentence(words)
13 | if mode == "average":
14 | return np.average(vectors, 0)
15 | elif mode == 'weighted_average':
16 | return np.swapaxes(vectors, 0, 1)
17 | elif mode == 'last':
18 | return vectors[-1, :, :]
19 | elif mode == 'all':
20 | return vectors
21 | else:
22 | return vectors
23 |
24 |
25 | def load_elmo():
26 | return ElmoEmbedder(cuda_device=0)
27 |
28 |
29 |
30 | def read_parse_write(elmo, infile, outfile, mode):
31 | reader = Reader()
32 | insts = reader.read_conll(infile, -1, True)
33 | f = open(outfile, 'wb')
34 | all_vecs = []
35 | for inst in insts:
36 | vec = parse_sentence(elmo, inst.input.words, mode=mode)
37 | all_vecs.append(vec)
38 | pickle.dump(all_vecs, f)
39 | f.close()
40 |
41 |
42 | elmo = load_elmo()
43 | mode= "average"
44 | dataset="ontonotes"
45 | dep = ""
46 | file = "../data/"+dataset+"/train"+dep+".conllx"
47 | outfile = file + ".elmo."+mode+".vec"
48 | read_parse_write(elmo, file, outfile, mode)
49 | file = "../data/"+dataset+"/dev"+dep+".conllx"
50 | outfile = file + ".elmo."+mode+".vec"
51 | read_parse_write(elmo, file, outfile, mode)
52 | file = "../data/"+dataset+"/test"+dep+".conllx"
53 | outfile = file + ".elmo."+mode+".vec"
54 | read_parse_write(elmo, file, outfile, mode)
55 |
56 |
--------------------------------------------------------------------------------
/preprocess/preflair.py:
--------------------------------------------------------------------------------
1 | from flair.embeddings import WordEmbeddings, FlairEmbeddings, StackedEmbeddings, BertEmbeddings, PooledFlairEmbeddings
2 | import pickle
3 | from config.reader import Reader
4 | import numpy as np
5 | from flair.data import Sentence
6 |
7 | def load_flair(mode = 'flair'):
8 | if mode == 'flair':
9 | stacked_embeddings = StackedEmbeddings([
10 | WordEmbeddings('glove'),
11 | PooledFlairEmbeddings('news-forward', pooling='min'),
12 | PooledFlairEmbeddings('news-backward', pooling='min')
13 | ])
14 | else:##bert
15 | stacked_embeddings = BertEmbeddings('bert-base-uncased') ##concat last 4 layers give the best
16 | return stacked_embeddings
17 |
18 | def embed_sent(embeder, sent):
19 | sent = Sentence(' '.join(sent))
20 | embeder.embed(sent)
21 | return sent
22 |
23 |
24 | def read_parse_write(elmo, infile, outfile,):
25 | reader = Reader()
26 | insts = reader.read_conll(infile, -1, True)
27 | f = open(outfile, 'wb')
28 | all_vecs = []
29 | for inst in insts:
30 | sent = embed_sent(elmo, inst.input.words)
31 | # np.empty((len(sent)),dtype=np.float32)
32 | arr = []
33 | for token in sent:
34 | # print(token)
35 | # print(token.embedding)
36 | arr.append(np.expand_dims(token.embedding.numpy(), axis=0))
37 | # all_vecs.append(vec)
38 | all_vecs.append(np.concatenate(arr))
39 | pickle.dump(all_vecs, f)
40 | f.close()
41 |
42 |
43 | mode = 'flair'
44 | model = load_flair(mode=mode)
45 | # mode= "average"
46 | dataset="conll2003"
47 | dep = ".sd"
48 | file = "./data/"+dataset+"/train"+dep+".conllx"
49 | outfile = file.replace(".sd", "") + "."+mode+".vec"
50 | read_parse_write(model, file, outfile)
51 | file = "./data/"+dataset+"/dev"+dep+".conllx"
52 | outfile = file.replace(".sd", "") + "."+mode+".vec"
53 | read_parse_write(model, file, outfile)
54 | file = "./data/"+dataset+"/test"+dep+".conllx"
55 | outfile = file.replace(".sd", "") + "."+mode+".vec"
56 | read_parse_write(model, file, outfile)
57 |
--------------------------------------------------------------------------------
/scripts/run.bash:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 |
4 |
5 |
6 | datasets=(ontonotes ontonotes_chinese catalan spanish)
7 | context_emb=elmo
8 | num_epochs_all=(100 100 300 300)
9 | devices=(cuda:0 cuda:1 cuda:2 cuda:3) ##cpu, cuda:0, cuda:1
10 | dep_model=dglstm ## none, dglstm, dggcn means do not use head features
11 | embs=(data/glove.6B.100d.txt data/cc.zh.300.vec data/cc.ca.300.vec data/cc.es.300.vec)
12 | num_lstm_layer=2
13 | inter_func=mlp
14 |
15 | for (( d=0; d<${#datasets[@]}; d++ )) do
16 | dataset=${datasets[$d]}
17 | emb=${embs[$d]}
18 | device=${devices[$d]}
19 | num_epochs=${num_epochs_all[$d]}
20 | first_part=logs/hidden_${num_lstm_layer}_${dataset}_${dep_model}_asfeat_${context_emb}
21 | logfile=${first_part}_epoch_${num_epochs}_if_${inter_func}.log
22 | python3.6 main.py --context_emb ${context_emb} \
23 | --dataset ${dataset} --num_epochs ${num_epochs} --device ${device} --num_lstm_layer ${num_lstm_layer} \
24 | --dep_model ${dep_model} \
25 | --embedding_file ${emb} --inter_func ${inter_func} > ${logfile} 2>&1
26 |
27 | done
28 |
29 |
30 |
31 |
--------------------------------------------------------------------------------
/scripts/run_pytorch.bash:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | #autobatch=1
4 | #--dynet-autobatch
5 | #optimizer=adam
6 | #lr=1
7 | #batch=1
8 | #gpu=1
9 |
10 | #datasets=(ontonotes_chinese)
11 | datasets=(conll2003)
12 | #datasets=(bc bn mz nw tc wb)
13 | #heads=(1) ##1 means use GCN embedding.
14 | #datasets=(all)
15 | context_emb=elmo
16 | hidden=200
17 | optim=sgd
18 | batch=10
19 | num_epochs=100
20 | eval_freq=10000
21 | device=cuda:1 ##cpu, cuda:0, cuda:1
22 | gcn_layer=1
23 | gcn_dropout=0.5
24 | gcn_mlp_layers=1
25 | dep_model=dglstm ## none, dglstm, dggcn means do not use head features
26 | dep_hidden_dim=200
27 | affix=ssd
28 | gcn_adj_directed=0 ##bidirection
29 | gcn_adj_selfloop=0 ## keep to zero because we always add self loop in gcn
30 | embs=(data/glove.6B.100d.txt)
31 | #emb=data/cc.zh.300.vec
32 | lr=0.01
33 | gcn_gate=0 ##without gcn gate
34 | num_base=-1 ## number of bases in relational gcn
35 | num_lstm_layer=2
36 | dep_double_label=0
37 | inter_func=mlp
38 |
39 | for (( d=0; d<${#datasets[@]}; d++ )) do
40 | dataset=${datasets[$d]}
41 | emb=${embs[$d]}
42 | first_part=logs/hidden_${num_lstm_layer}_${hidden}_${dataset}_${affix}_head_${dep_model}_asfeat_${context_emb}_gcn_${gcn_layer}_${gcn_mlp_layers}_${gcn_dropout}_gate_${gcn_gate}
43 | logfile=${first_part}_dir_${gcn_adj_directed}_loop_${gcn_adj_selfloop}_base_${num_base}_epoch_${num_epochs}_lr_${lr}_dd_${dep_double_label}_if_${inter_func}.log
44 | python3.6 main.py --context_emb ${context_emb} --hidden_dim ${hidden} --optimizer ${optim} --gcn_adj_directed ${gcn_adj_directed} --gcn_adj_selfloop ${gcn_adj_selfloop} \
45 | --dataset ${dataset} --eval_freq ${eval_freq} --num_epochs ${num_epochs} --device ${device} --dep_hidden_dim ${dep_hidden_dim} --num_lstm_layer ${num_lstm_layer} \
46 | --batch_size ${batch} --num_gcn_layers ${gcn_layer} --gcn_mlp_layers ${gcn_mlp_layers} --dep_model ${dep_model} --gcn_gate ${gcn_gate} --dep_double_label ${dep_double_label} \
47 | --gcn_dropout ${gcn_dropout} --affix ${affix} --lr_decay 0 --learning_rate ${lr} --embedding_file ${emb} --inter_func ${inter_func} \
48 | --num_base ${num_base} > ${logfile} 2>&1
49 |
50 | done
51 |
52 |
53 |
54 |
--------------------------------------------------------------------------------
/scripts/run_pytorch_all.bash:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | #autobatch=1
4 | #--dynet-autobatch
5 | #optimizer=adam
6 | #lr=1
7 | #batch=1
8 | #gpu=1
9 |
10 | #datasets=(ontonotes_chinese)
11 | datasets=(ontonotes ontonotes_chinese catalan spanish)
12 | #datasets=(bc bn mz nw tc wb)
13 | #heads=(1) ##1 means use GCN embedding.
14 | #datasets=(all)
15 | context_emb=elmo
16 | hidden=200
17 | optim=sgd
18 | batch=10
19 | num_epochs_all=(100 100 300 300)
20 | devices=(cuda:0 cuda:1 cuda:2 cuda:3) ##cpu, cuda:0, cuda:1
21 | gcn_layer=1
22 | gcn_dropout=0.5
23 | gcn_mlp_layers=1
24 | dep_model=dglstm ## none, dglstm, dggcn means do not use head features
25 | dep_hidden_dim=200
26 | affix=sd
27 | gcn_adj_directed=0 ##bidirection
28 | gcn_adj_selfloop=0 ## keep to zero because we always add self loop in gcn
29 | embs=(data/glove.6B.100d.txt data/cc.zh.300.vec data/cc.ca.300.vec data/cc.es.300.vec)
30 | #emb=data/cc.zh.300.vec
31 | lr=0.01
32 | gcn_gate=0 ##without gcn gate
33 | num_base=-1 ## number of bases in relational gcn
34 | num_lstm_layer=2
35 | dep_double_label=0
36 | inter_func=mlp
37 |
38 | for (( d=0; d<${#datasets[@]}; d++ )) do
39 | dataset=${datasets[$d]}
40 | emb=${embs[$d]}
41 | device=${devices[$d]}
42 | num_epochs=${num_epochs_all[$d]}
43 | first_part=logs/hidden_${num_lstm_layer}_${hidden}_${dataset}_${affix}_head_${dep_model}_asfeat_${context_emb}_gcn_${gcn_layer}_${gcn_mlp_layers}_${gcn_dropout}_gate_${gcn_gate}
44 | logfile=${first_part}_dir_${gcn_adj_directed}_loop_${gcn_adj_selfloop}_base_${num_base}_epoch_${num_epochs}_lr_${lr}_dd_${dep_double_label}_if_${inter_func}.log
45 | python3.6 main.py --context_emb ${context_emb} --hidden_dim ${hidden} --optimizer ${optim} --gcn_adj_directed ${gcn_adj_directed} --gcn_adj_selfloop ${gcn_adj_selfloop} \
46 | --dataset ${dataset} --num_epochs ${num_epochs} --device ${device} --dep_hidden_dim ${dep_hidden_dim} --num_lstm_layer ${num_lstm_layer} \
47 | --batch_size ${batch} --num_gcn_layers ${gcn_layer} --gcn_mlp_layers ${gcn_mlp_layers} --dep_model ${dep_model} --gcn_gate ${gcn_gate} --dep_double_label ${dep_double_label} \
48 | --gcn_dropout ${gcn_dropout} --affix ${affix} --lr_decay 0 --learning_rate ${lr} --embedding_file ${emb} --inter_func ${inter_func} \
49 | --num_base ${num_base} > ${logfile} 2>&1
50 |
51 | done
52 |
53 |
54 |
55 |
--------------------------------------------------------------------------------