├── LICENSE
├── README.md
└── 决赛提交
    ├── Dockerfile
    ├── README.md
    └── sohu_matching
        ├── data
        │   ├── dummy_bert
        │   │   ├── config.json
        │   │   └── vocab.txt
        │   ├── dummy_ernie
        │   │   ├── config.json
        │   │   └── vocab.txt
        │   └── dummy_nezha
        │       ├── config.json
        │       └── vocab.txt
        ├── results
        │   └── rematch
        │       └── merge_final.csv
        └── src
            ├── NEZHA
            │   ├── __pycache__
            │   │   ├── model_nezha.cpython-36.pyc
            │   │   └── nezha_utils.cpython-36.pyc
            │   ├── model_nezha.py
            │   └── nezha_utils.py
            ├── config.py
            ├── data.py
            ├── infer.py
            ├── infer_final.py
            ├── merge_result.py
            ├── model.py
            ├── search_better_merge.py
            ├── train.py
            ├── train_old.py
            └── utils.py
/LICENSE:
--------------------------------------------------------------------------------
1 | GNU GENERAL PUBLIC LICENSE
2 | Version 3, 29 June 2007
3 |
4 | Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
5 | Everyone is permitted to copy and distribute verbatim copies
6 | of this license document, but changing it is not allowed.
7 |
8 | Preamble
9 |
10 | The GNU General Public License is a free, copyleft license for
11 | software and other kinds of works.
12 |
13 | The licenses for most software and other practical works are designed
14 | to take away your freedom to share and change the works. By contrast,
15 | the GNU General Public License is intended to guarantee your freedom to
16 | share and change all versions of a program--to make sure it remains free
17 | software for all its users. We, the Free Software Foundation, use the
18 | GNU General Public License for most of our software; it applies also to
19 | any other work released this way by its authors. You can apply it to
20 | your programs, too.
21 |
22 | When we speak of free software, we are referring to freedom, not
23 | price. Our General Public Licenses are designed to make sure that you
24 | have the freedom to distribute copies of free software (and charge for
25 | them if you wish), that you receive source code or can get it if you
26 | want it, that you can change the software or use pieces of it in new
27 | free programs, and that you know you can do these things.
28 |
29 | To protect your rights, we need to prevent others from denying you
30 | these rights or asking you to surrender the rights. Therefore, you have
31 | certain responsibilities if you distribute copies of the software, or if
32 | you modify it: responsibilities to respect the freedom of others.
33 |
34 | For example, if you distribute copies of such a program, whether
35 | gratis or for a fee, you must pass on to the recipients the same
36 | freedoms that you received. You must make sure that they, too, receive
37 | or can get the source code. And you must show them these terms so they
38 | know their rights.
39 |
40 | Developers that use the GNU GPL protect your rights with two steps:
41 | (1) assert copyright on the software, and (2) offer you this License
42 | giving you legal permission to copy, distribute and/or modify it.
43 |
44 | For the developers' and authors' protection, the GPL clearly explains
45 | that there is no warranty for this free software. For both users' and
46 | authors' sake, the GPL requires that modified versions be marked as
47 | changed, so that their problems will not be attributed erroneously to
48 | authors of previous versions.
49 |
50 | Some devices are designed to deny users access to install or run
51 | modified versions of the software inside them, although the manufacturer
52 | can do so. This is fundamentally incompatible with the aim of
53 | protecting users' freedom to change the software. The systematic
54 | pattern of such abuse occurs in the area of products for individuals to
55 | use, which is precisely where it is most unacceptable. Therefore, we
56 | have designed this version of the GPL to prohibit the practice for those
57 | products. If such problems arise substantially in other domains, we
58 | stand ready to extend this provision to those domains in future versions
59 | of the GPL, as needed to protect the freedom of users.
60 |
61 | Finally, every program is threatened constantly by software patents.
62 | States should not allow patents to restrict development and use of
63 | software on general-purpose computers, but in those that do, we wish to
64 | avoid the special danger that patents applied to a free program could
65 | make it effectively proprietary. To prevent this, the GPL assures that
66 | patents cannot be used to render the program non-free.
67 |
68 | The precise terms and conditions for copying, distribution and
69 | modification follow.
70 |
71 | TERMS AND CONDITIONS
72 |
73 | 0. Definitions.
74 |
75 | "This License" refers to version 3 of the GNU General Public License.
76 |
77 | "Copyright" also means copyright-like laws that apply to other kinds of
78 | works, such as semiconductor masks.
79 |
80 | "The Program" refers to any copyrightable work licensed under this
81 | License. Each licensee is addressed as "you". "Licensees" and
82 | "recipients" may be individuals or organizations.
83 |
84 | To "modify" a work means to copy from or adapt all or part of the work
85 | in a fashion requiring copyright permission, other than the making of an
86 | exact copy. The resulting work is called a "modified version" of the
87 | earlier work or a work "based on" the earlier work.
88 |
89 | A "covered work" means either the unmodified Program or a work based
90 | on the Program.
91 |
92 | To "propagate" a work means to do anything with it that, without
93 | permission, would make you directly or secondarily liable for
94 | infringement under applicable copyright law, except executing it on a
95 | computer or modifying a private copy. Propagation includes copying,
96 | distribution (with or without modification), making available to the
97 | public, and in some countries other activities as well.
98 |
99 | To "convey" a work means any kind of propagation that enables other
100 | parties to make or receive copies. Mere interaction with a user through
101 | a computer network, with no transfer of a copy, is not conveying.
102 |
103 | An interactive user interface displays "Appropriate Legal Notices"
104 | to the extent that it includes a convenient and prominently visible
105 | feature that (1) displays an appropriate copyright notice, and (2)
106 | tells the user that there is no warranty for the work (except to the
107 | extent that warranties are provided), that licensees may convey the
108 | work under this License, and how to view a copy of this License. If
109 | the interface presents a list of user commands or options, such as a
110 | menu, a prominent item in the list meets this criterion.
111 |
112 | 1. Source Code.
113 |
114 | The "source code" for a work means the preferred form of the work
115 | for making modifications to it. "Object code" means any non-source
116 | form of a work.
117 |
118 | A "Standard Interface" means an interface that either is an official
119 | standard defined by a recognized standards body, or, in the case of
120 | interfaces specified for a particular programming language, one that
121 | is widely used among developers working in that language.
122 |
123 | The "System Libraries" of an executable work include anything, other
124 | than the work as a whole, that (a) is included in the normal form of
125 | packaging a Major Component, but which is not part of that Major
126 | Component, and (b) serves only to enable use of the work with that
127 | Major Component, or to implement a Standard Interface for which an
128 | implementation is available to the public in source code form. A
129 | "Major Component", in this context, means a major essential component
130 | (kernel, window system, and so on) of the specific operating system
131 | (if any) on which the executable work runs, or a compiler used to
132 | produce the work, or an object code interpreter used to run it.
133 |
134 | The "Corresponding Source" for a work in object code form means all
135 | the source code needed to generate, install, and (for an executable
136 | work) run the object code and to modify the work, including scripts to
137 | control those activities. However, it does not include the work's
138 | System Libraries, or general-purpose tools or generally available free
139 | programs which are used unmodified in performing those activities but
140 | which are not part of the work. For example, Corresponding Source
141 | includes interface definition files associated with source files for
142 | the work, and the source code for shared libraries and dynamically
143 | linked subprograms that the work is specifically designed to require,
144 | such as by intimate data communication or control flow between those
145 | subprograms and other parts of the work.
146 |
147 | The Corresponding Source need not include anything that users
148 | can regenerate automatically from other parts of the Corresponding
149 | Source.
150 |
151 | The Corresponding Source for a work in source code form is that
152 | same work.
153 |
154 | 2. Basic Permissions.
155 |
156 | All rights granted under this License are granted for the term of
157 | copyright on the Program, and are irrevocable provided the stated
158 | conditions are met. This License explicitly affirms your unlimited
159 | permission to run the unmodified Program. The output from running a
160 | covered work is covered by this License only if the output, given its
161 | content, constitutes a covered work. This License acknowledges your
162 | rights of fair use or other equivalent, as provided by copyright law.
163 |
164 | You may make, run and propagate covered works that you do not
165 | convey, without conditions so long as your license otherwise remains
166 | in force. You may convey covered works to others for the sole purpose
167 | of having them make modifications exclusively for you, or provide you
168 | with facilities for running those works, provided that you comply with
169 | the terms of this License in conveying all material for which you do
170 | not control copyright. Those thus making or running the covered works
171 | for you must do so exclusively on your behalf, under your direction
172 | and control, on terms that prohibit them from making any copies of
173 | your copyrighted material outside their relationship with you.
174 |
175 | Conveying under any other circumstances is permitted solely under
176 | the conditions stated below. Sublicensing is not allowed; section 10
177 | makes it unnecessary.
178 |
179 | 3. Protecting Users' Legal Rights From Anti-Circumvention Law.
180 |
181 | No covered work shall be deemed part of an effective technological
182 | measure under any applicable law fulfilling obligations under article
183 | 11 of the WIPO copyright treaty adopted on 20 December 1996, or
184 | similar laws prohibiting or restricting circumvention of such
185 | measures.
186 |
187 | When you convey a covered work, you waive any legal power to forbid
188 | circumvention of technological measures to the extent such circumvention
189 | is effected by exercising rights under this License with respect to
190 | the covered work, and you disclaim any intention to limit operation or
191 | modification of the work as a means of enforcing, against the work's
192 | users, your or third parties' legal rights to forbid circumvention of
193 | technological measures.
194 |
195 | 4. Conveying Verbatim Copies.
196 |
197 | You may convey verbatim copies of the Program's source code as you
198 | receive it, in any medium, provided that you conspicuously and
199 | appropriately publish on each copy an appropriate copyright notice;
200 | keep intact all notices stating that this License and any
201 | non-permissive terms added in accord with section 7 apply to the code;
202 | keep intact all notices of the absence of any warranty; and give all
203 | recipients a copy of this License along with the Program.
204 |
205 | You may charge any price or no price for each copy that you convey,
206 | and you may offer support or warranty protection for a fee.
207 |
208 | 5. Conveying Modified Source Versions.
209 |
210 | You may convey a work based on the Program, or the modifications to
211 | produce it from the Program, in the form of source code under the
212 | terms of section 4, provided that you also meet all of these conditions:
213 |
214 | a) The work must carry prominent notices stating that you modified
215 | it, and giving a relevant date.
216 |
217 | b) The work must carry prominent notices stating that it is
218 | released under this License and any conditions added under section
219 | 7. This requirement modifies the requirement in section 4 to
220 | "keep intact all notices".
221 |
222 | c) You must license the entire work, as a whole, under this
223 | License to anyone who comes into possession of a copy. This
224 | License will therefore apply, along with any applicable section 7
225 | additional terms, to the whole of the work, and all its parts,
226 | regardless of how they are packaged. This License gives no
227 | permission to license the work in any other way, but it does not
228 | invalidate such permission if you have separately received it.
229 |
230 | d) If the work has interactive user interfaces, each must display
231 | Appropriate Legal Notices; however, if the Program has interactive
232 | interfaces that do not display Appropriate Legal Notices, your
233 | work need not make them do so.
234 |
235 | A compilation of a covered work with other separate and independent
236 | works, which are not by their nature extensions of the covered work,
237 | and which are not combined with it such as to form a larger program,
238 | in or on a volume of a storage or distribution medium, is called an
239 | "aggregate" if the compilation and its resulting copyright are not
240 | used to limit the access or legal rights of the compilation's users
241 | beyond what the individual works permit. Inclusion of a covered work
242 | in an aggregate does not cause this License to apply to the other
243 | parts of the aggregate.
244 |
245 | 6. Conveying Non-Source Forms.
246 |
247 | You may convey a covered work in object code form under the terms
248 | of sections 4 and 5, provided that you also convey the
249 | machine-readable Corresponding Source under the terms of this License,
250 | in one of these ways:
251 |
252 | a) Convey the object code in, or embodied in, a physical product
253 | (including a physical distribution medium), accompanied by the
254 | Corresponding Source fixed on a durable physical medium
255 | customarily used for software interchange.
256 |
257 | b) Convey the object code in, or embodied in, a physical product
258 | (including a physical distribution medium), accompanied by a
259 | written offer, valid for at least three years and valid for as
260 | long as you offer spare parts or customer support for that product
261 | model, to give anyone who possesses the object code either (1) a
262 | copy of the Corresponding Source for all the software in the
263 | product that is covered by this License, on a durable physical
264 | medium customarily used for software interchange, for a price no
265 | more than your reasonable cost of physically performing this
266 | conveying of source, or (2) access to copy the
267 | Corresponding Source from a network server at no charge.
268 |
269 | c) Convey individual copies of the object code with a copy of the
270 | written offer to provide the Corresponding Source. This
271 | alternative is allowed only occasionally and noncommercially, and
272 | only if you received the object code with such an offer, in accord
273 | with subsection 6b.
274 |
275 | d) Convey the object code by offering access from a designated
276 | place (gratis or for a charge), and offer equivalent access to the
277 | Corresponding Source in the same way through the same place at no
278 | further charge. You need not require recipients to copy the
279 | Corresponding Source along with the object code. If the place to
280 | copy the object code is a network server, the Corresponding Source
281 | may be on a different server (operated by you or a third party)
282 | that supports equivalent copying facilities, provided you maintain
283 | clear directions next to the object code saying where to find the
284 | Corresponding Source. Regardless of what server hosts the
285 | Corresponding Source, you remain obligated to ensure that it is
286 | available for as long as needed to satisfy these requirements.
287 |
288 | e) Convey the object code using peer-to-peer transmission, provided
289 | you inform other peers where the object code and Corresponding
290 | Source of the work are being offered to the general public at no
291 | charge under subsection 6d.
292 |
293 | A separable portion of the object code, whose source code is excluded
294 | from the Corresponding Source as a System Library, need not be
295 | included in conveying the object code work.
296 |
297 | A "User Product" is either (1) a "consumer product", which means any
298 | tangible personal property which is normally used for personal, family,
299 | or household purposes, or (2) anything designed or sold for incorporation
300 | into a dwelling. In determining whether a product is a consumer product,
301 | doubtful cases shall be resolved in favor of coverage. For a particular
302 | product received by a particular user, "normally used" refers to a
303 | typical or common use of that class of product, regardless of the status
304 | of the particular user or of the way in which the particular user
305 | actually uses, or expects or is expected to use, the product. A product
306 | is a consumer product regardless of whether the product has substantial
307 | commercial, industrial or non-consumer uses, unless such uses represent
308 | the only significant mode of use of the product.
309 |
310 | "Installation Information" for a User Product means any methods,
311 | procedures, authorization keys, or other information required to install
312 | and execute modified versions of a covered work in that User Product from
313 | a modified version of its Corresponding Source. The information must
314 | suffice to ensure that the continued functioning of the modified object
315 | code is in no case prevented or interfered with solely because
316 | modification has been made.
317 |
318 | If you convey an object code work under this section in, or with, or
319 | specifically for use in, a User Product, and the conveying occurs as
320 | part of a transaction in which the right of possession and use of the
321 | User Product is transferred to the recipient in perpetuity or for a
322 | fixed term (regardless of how the transaction is characterized), the
323 | Corresponding Source conveyed under this section must be accompanied
324 | by the Installation Information. But this requirement does not apply
325 | if neither you nor any third party retains the ability to install
326 | modified object code on the User Product (for example, the work has
327 | been installed in ROM).
328 |
329 | The requirement to provide Installation Information does not include a
330 | requirement to continue to provide support service, warranty, or updates
331 | for a work that has been modified or installed by the recipient, or for
332 | the User Product in which it has been modified or installed. Access to a
333 | network may be denied when the modification itself materially and
334 | adversely affects the operation of the network or violates the rules and
335 | protocols for communication across the network.
336 |
337 | Corresponding Source conveyed, and Installation Information provided,
338 | in accord with this section must be in a format that is publicly
339 | documented (and with an implementation available to the public in
340 | source code form), and must require no special password or key for
341 | unpacking, reading or copying.
342 |
343 | 7. Additional Terms.
344 |
345 | "Additional permissions" are terms that supplement the terms of this
346 | License by making exceptions from one or more of its conditions.
347 | Additional permissions that are applicable to the entire Program shall
348 | be treated as though they were included in this License, to the extent
349 | that they are valid under applicable law. If additional permissions
350 | apply only to part of the Program, that part may be used separately
351 | under those permissions, but the entire Program remains governed by
352 | this License without regard to the additional permissions.
353 |
354 | When you convey a copy of a covered work, you may at your option
355 | remove any additional permissions from that copy, or from any part of
356 | it. (Additional permissions may be written to require their own
357 | removal in certain cases when you modify the work.) You may place
358 | additional permissions on material, added by you to a covered work,
359 | for which you have or can give appropriate copyright permission.
360 |
361 | Notwithstanding any other provision of this License, for material you
362 | add to a covered work, you may (if authorized by the copyright holders of
363 | that material) supplement the terms of this License with terms:
364 |
365 | a) Disclaiming warranty or limiting liability differently from the
366 | terms of sections 15 and 16 of this License; or
367 |
368 | b) Requiring preservation of specified reasonable legal notices or
369 | author attributions in that material or in the Appropriate Legal
370 | Notices displayed by works containing it; or
371 |
372 | c) Prohibiting misrepresentation of the origin of that material, or
373 | requiring that modified versions of such material be marked in
374 | reasonable ways as different from the original version; or
375 |
376 | d) Limiting the use for publicity purposes of names of licensors or
377 | authors of the material; or
378 |
379 | e) Declining to grant rights under trademark law for use of some
380 | trade names, trademarks, or service marks; or
381 |
382 | f) Requiring indemnification of licensors and authors of that
383 | material by anyone who conveys the material (or modified versions of
384 | it) with contractual assumptions of liability to the recipient, for
385 | any liability that these contractual assumptions directly impose on
386 | those licensors and authors.
387 |
388 | All other non-permissive additional terms are considered "further
389 | restrictions" within the meaning of section 10. If the Program as you
390 | received it, or any part of it, contains a notice stating that it is
391 | governed by this License along with a term that is a further
392 | restriction, you may remove that term. If a license document contains
393 | a further restriction but permits relicensing or conveying under this
394 | License, you may add to a covered work material governed by the terms
395 | of that license document, provided that the further restriction does
396 | not survive such relicensing or conveying.
397 |
398 | If you add terms to a covered work in accord with this section, you
399 | must place, in the relevant source files, a statement of the
400 | additional terms that apply to those files, or a notice indicating
401 | where to find the applicable terms.
402 |
403 | Additional terms, permissive or non-permissive, may be stated in the
404 | form of a separately written license, or stated as exceptions;
405 | the above requirements apply either way.
406 |
407 | 8. Termination.
408 |
409 | You may not propagate or modify a covered work except as expressly
410 | provided under this License. Any attempt otherwise to propagate or
411 | modify it is void, and will automatically terminate your rights under
412 | this License (including any patent licenses granted under the third
413 | paragraph of section 11).
414 |
415 | However, if you cease all violation of this License, then your
416 | license from a particular copyright holder is reinstated (a)
417 | provisionally, unless and until the copyright holder explicitly and
418 | finally terminates your license, and (b) permanently, if the copyright
419 | holder fails to notify you of the violation by some reasonable means
420 | prior to 60 days after the cessation.
421 |
422 | Moreover, your license from a particular copyright holder is
423 | reinstated permanently if the copyright holder notifies you of the
424 | violation by some reasonable means, this is the first time you have
425 | received notice of violation of this License (for any work) from that
426 | copyright holder, and you cure the violation prior to 30 days after
427 | your receipt of the notice.
428 |
429 | Termination of your rights under this section does not terminate the
430 | licenses of parties who have received copies or rights from you under
431 | this License. If your rights have been terminated and not permanently
432 | reinstated, you do not qualify to receive new licenses for the same
433 | material under section 10.
434 |
435 | 9. Acceptance Not Required for Having Copies.
436 |
437 | You are not required to accept this License in order to receive or
438 | run a copy of the Program. Ancillary propagation of a covered work
439 | occurring solely as a consequence of using peer-to-peer transmission
440 | to receive a copy likewise does not require acceptance. However,
441 | nothing other than this License grants you permission to propagate or
442 | modify any covered work. These actions infringe copyright if you do
443 | not accept this License. Therefore, by modifying or propagating a
444 | covered work, you indicate your acceptance of this License to do so.
445 |
446 | 10. Automatic Licensing of Downstream Recipients.
447 |
448 | Each time you convey a covered work, the recipient automatically
449 | receives a license from the original licensors, to run, modify and
450 | propagate that work, subject to this License. You are not responsible
451 | for enforcing compliance by third parties with this License.
452 |
453 | An "entity transaction" is a transaction transferring control of an
454 | organization, or substantially all assets of one, or subdividing an
455 | organization, or merging organizations. If propagation of a covered
456 | work results from an entity transaction, each party to that
457 | transaction who receives a copy of the work also receives whatever
458 | licenses to the work the party's predecessor in interest had or could
459 | give under the previous paragraph, plus a right to possession of the
460 | Corresponding Source of the work from the predecessor in interest, if
461 | the predecessor has it or can get it with reasonable efforts.
462 |
463 | You may not impose any further restrictions on the exercise of the
464 | rights granted or affirmed under this License. For example, you may
465 | not impose a license fee, royalty, or other charge for exercise of
466 | rights granted under this License, and you may not initiate litigation
467 | (including a cross-claim or counterclaim in a lawsuit) alleging that
468 | any patent claim is infringed by making, using, selling, offering for
469 | sale, or importing the Program or any portion of it.
470 |
471 | 11. Patents.
472 |
473 | A "contributor" is a copyright holder who authorizes use under this
474 | License of the Program or a work on which the Program is based. The
475 | work thus licensed is called the contributor's "contributor version".
476 |
477 | A contributor's "essential patent claims" are all patent claims
478 | owned or controlled by the contributor, whether already acquired or
479 | hereafter acquired, that would be infringed by some manner, permitted
480 | by this License, of making, using, or selling its contributor version,
481 | but do not include claims that would be infringed only as a
482 | consequence of further modification of the contributor version. For
483 | purposes of this definition, "control" includes the right to grant
484 | patent sublicenses in a manner consistent with the requirements of
485 | this License.
486 |
487 | Each contributor grants you a non-exclusive, worldwide, royalty-free
488 | patent license under the contributor's essential patent claims, to
489 | make, use, sell, offer for sale, import and otherwise run, modify and
490 | propagate the contents of its contributor version.
491 |
492 | In the following three paragraphs, a "patent license" is any express
493 | agreement or commitment, however denominated, not to enforce a patent
494 | (such as an express permission to practice a patent or covenant not to
495 | sue for patent infringement). To "grant" such a patent license to a
496 | party means to make such an agreement or commitment not to enforce a
497 | patent against the party.
498 |
499 | If you convey a covered work, knowingly relying on a patent license,
500 | and the Corresponding Source of the work is not available for anyone
501 | to copy, free of charge and under the terms of this License, through a
502 | publicly available network server or other readily accessible means,
503 | then you must either (1) cause the Corresponding Source to be so
504 | available, or (2) arrange to deprive yourself of the benefit of the
505 | patent license for this particular work, or (3) arrange, in a manner
506 | consistent with the requirements of this License, to extend the patent
507 | license to downstream recipients. "Knowingly relying" means you have
508 | actual knowledge that, but for the patent license, your conveying the
509 | covered work in a country, or your recipient's use of the covered work
510 | in a country, would infringe one or more identifiable patents in that
511 | country that you have reason to believe are valid.
512 |
513 | If, pursuant to or in connection with a single transaction or
514 | arrangement, you convey, or propagate by procuring conveyance of, a
515 | covered work, and grant a patent license to some of the parties
516 | receiving the covered work authorizing them to use, propagate, modify
517 | or convey a specific copy of the covered work, then the patent license
518 | you grant is automatically extended to all recipients of the covered
519 | work and works based on it.
520 |
521 | A patent license is "discriminatory" if it does not include within
522 | the scope of its coverage, prohibits the exercise of, or is
523 | conditioned on the non-exercise of one or more of the rights that are
524 | specifically granted under this License. You may not convey a covered
525 | work if you are a party to an arrangement with a third party that is
526 | in the business of distributing software, under which you make payment
527 | to the third party based on the extent of your activity of conveying
528 | the work, and under which the third party grants, to any of the
529 | parties who would receive the covered work from you, a discriminatory
530 | patent license (a) in connection with copies of the covered work
531 | conveyed by you (or copies made from those copies), or (b) primarily
532 | for and in connection with specific products or compilations that
533 | contain the covered work, unless you entered into that arrangement,
534 | or that patent license was granted, prior to 28 March 2007.
535 |
536 | Nothing in this License shall be construed as excluding or limiting
537 | any implied license or other defenses to infringement that may
538 | otherwise be available to you under applicable patent law.
539 |
540 | 12. No Surrender of Others' Freedom.
541 |
542 | If conditions are imposed on you (whether by court order, agreement or
543 | otherwise) that contradict the conditions of this License, they do not
544 | excuse you from the conditions of this License. If you cannot convey a
545 | covered work so as to satisfy simultaneously your obligations under this
546 | License and any other pertinent obligations, then as a consequence you may
547 | not convey it at all. For example, if you agree to terms that obligate you
548 | to collect a royalty for further conveying from those to whom you convey
549 | the Program, the only way you could satisfy both those terms and this
550 | License would be to refrain entirely from conveying the Program.
551 |
552 | 13. Use with the GNU Affero General Public License.
553 |
554 | Notwithstanding any other provision of this License, you have
555 | permission to link or combine any covered work with a work licensed
556 | under version 3 of the GNU Affero General Public License into a single
557 | combined work, and to convey the resulting work. The terms of this
558 | License will continue to apply to the part which is the covered work,
559 | but the special requirements of the GNU Affero General Public License,
560 | section 13, concerning interaction through a network will apply to the
561 | combination as such.
562 |
563 | 14. Revised Versions of this License.
564 |
565 | The Free Software Foundation may publish revised and/or new versions of
566 | the GNU General Public License from time to time. Such new versions will
567 | be similar in spirit to the present version, but may differ in detail to
568 | address new problems or concerns.
569 |
570 | Each version is given a distinguishing version number. If the
571 | Program specifies that a certain numbered version of the GNU General
572 | Public License "or any later version" applies to it, you have the
573 | option of following the terms and conditions either of that numbered
574 | version or of any later version published by the Free Software
575 | Foundation. If the Program does not specify a version number of the
576 | GNU General Public License, you may choose any version ever published
577 | by the Free Software Foundation.
578 |
579 | If the Program specifies that a proxy can decide which future
580 | versions of the GNU General Public License can be used, that proxy's
581 | public statement of acceptance of a version permanently authorizes you
582 | to choose that version for the Program.
583 |
584 | Later license versions may give you additional or different
585 | permissions. However, no additional obligations are imposed on any
586 | author or copyright holder as a result of your choosing to follow a
587 | later version.
588 |
589 | 15. Disclaimer of Warranty.
590 |
591 | THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
592 | APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
593 | HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
594 | OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
595 | THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
596 | PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
597 | IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
598 | ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
599 |
600 | 16. Limitation of Liability.
601 |
602 | IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
603 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
604 | THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
605 | GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
606 | USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
607 | DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
608 | PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
609 | EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
610 | SUCH DAMAGES.
611 |
612 | 17. Interpretation of Sections 15 and 16.
613 |
614 | If the disclaimer of warranty and limitation of liability provided
615 | above cannot be given local legal effect according to their terms,
616 | reviewing courts shall apply local law that most closely approximates
617 | an absolute waiver of all civil liability in connection with the
618 | Program, unless a warranty or assumption of liability accompanies a
619 | copy of the Program in return for a fee.
620 |
621 | END OF TERMS AND CONDITIONS
622 |
623 | How to Apply These Terms to Your New Programs
624 |
625 | If you develop a new program, and you want it to be of the greatest
626 | possible use to the public, the best way to achieve this is to make it
627 | free software which everyone can redistribute and change under these terms.
628 |
629 | To do so, attach the following notices to the program. It is safest
630 | to attach them to the start of each source file to most effectively
631 | state the exclusion of warranty; and each file should have at least
632 | the "copyright" line and a pointer to where the full notice is found.
633 |
634 | <one line to give the program's name and a brief idea of what it does.>
635 | Copyright (C) <year>  <name of author>
636 |
637 | This program is free software: you can redistribute it and/or modify
638 | it under the terms of the GNU General Public License as published by
639 | the Free Software Foundation, either version 3 of the License, or
640 | (at your option) any later version.
641 |
642 | This program is distributed in the hope that it will be useful,
643 | but WITHOUT ANY WARRANTY; without even the implied warranty of
644 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
645 | GNU General Public License for more details.
646 |
647 | You should have received a copy of the GNU General Public License
648 | along with this program.  If not, see <https://www.gnu.org/licenses/>.
649 |
650 | Also add information on how to contact you by electronic and paper mail.
651 |
652 | If the program does terminal interaction, make it output a short
653 | notice like this when it starts in an interactive mode:
654 |
655 | <program>  Copyright (C) <year>  <name of author>
656 | This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
657 | This is free software, and you are welcome to redistribute it
658 | under certain conditions; type `show c' for details.
659 |
660 | The hypothetical commands `show w' and `show c' should show the appropriate
661 | parts of the General Public License. Of course, your program's commands
662 | might be different; for a GUI interface, you would use an "about box".
663 |
664 | You should also get your employer (if you work as a programmer) or school,
665 | if any, to sign a "copyright disclaimer" for the program, if necessary.
666 | For more information on this, and how to apply and follow the GNU GPL, see
667 | <https://www.gnu.org/licenses/>.
668 |
669 | The GNU General Public License does not permit incorporating your program
670 | into proprietary programs. If your program is a subroutine library, you
671 | may consider it more useful to permit linking proprietary applications with
672 | the library. If this is what you want to do, use the GNU Lesser General
673 | Public License instead of this License. But first, please read
674 | <https://www.gnu.org/licenses/why-not-lgpl.html>.
675 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # sohu_text_matching
2 | Top-2 entry in the 2021 Sohu Campus Text Matching Algorithm Competition: team 分比我们低的都是帅哥
3 |
4 | This repo contains the code files submitted in the final round of the competition. The submitted model files can be downloaded from Baidu Netdisk (link: https://pan.baidu.com/s/1T9FtwiGFZhuC8qqwXKZSNA , access code: 2333).
5 |
6 | The five submitted models (capped at 2 GB in total) reached an F1 score of 0.78921 on the semifinal test set and 0.78123 on the final test set, placing second among the ten finalist teams and taking the runner-up prize.
7 |
8 | To reproduce the semifinal test-set results, download the models into `checkpoints/rematch`, merge the test sets into the final-round format, then enter the `src` folder and run `infer_final.py`, specifying the input file and the output location. See the Dockerfile for the dependencies.
9 |
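10 | For example, roughly (a sketch: the `--input`/`--output` flags follow the Docker entrypoint usage documented in `决赛提交/README.md`, while the exact paths depend on where the merged test file is placed):
11 |
12 | ```bash
13 | cd 决赛提交/sohu_matching/src
14 | python infer_final.py --input ../data/test.txt --output ../results/rematch/pred.csv
15 | ```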
--------------------------------------------------------------------------------
/决赛提交/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM pytorch/pytorch:1.6.0-cuda10.1-cudnn7-devel
2 |
3 | # prepare your environment here
4 | ENV LANG=en_US.UTF-8
5 | COPY . /app
6 | WORKDIR /app/sohu_matching/src
7 |
8 | # RUN pip install ...
9 | RUN pip install transformers && pip install pandas && pip install scikit-learn
10 |
11 | ENTRYPOINT ["python","infer_final.py"]
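12 |
13 | # e.g. build from the directory containing this Dockerfile (the image name is up to you):
14 | #   docker build -t ${MyImageName} .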
--------------------------------------------------------------------------------
/决赛提交/README.md:
--------------------------------------------------------------------------------
1 | ### sohu_matching
2 |
3 | #### Team: **分比我们低的都是帅哥**
4 |
5 | #### Final-round Docker instructions
6 |
7 | The Docker build for this project follows the submission guidelines; running the official test command below performs the inference:
8 |
9 | ```bash
10 | docker run --rm -it --gpus all \
11 |     -v ${TestInputDir}:/data/input \
12 |     -v ${TestOutputDir}:/data/output \
13 |     ${MyImageName} \
14 |     --input /data/input/test.txt \
15 |     --output /data/output/pred.csv
16 | ```
17 |
18 | The base image is `pytorch/pytorch:1.6.0-cuda10.1-cudnn7-devel`, and installing the dependencies at build time via `pip install transformers && pip install pandas && pip install scikit-learn` is all that is needed. The container's WORKDIR is set to `/app/sohu_matching/src` (the code uses relative paths, so it will fail when run from any other directory; if the official test command errors out, enter that directory and run `python infer_final.py` directly, specifying the input and output file locations, as sketched below). The image is about 10 GB; our models sit in `sohu_matching/checkpoints/rematch` and total under 2 GB, within the competition limit.
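19 |
20 | A rough sketch of that manual fallback (assuming the same volume mounts as the official command; `--entrypoint bash` overrides the image's default entrypoint):
21 |
22 | ```bash
23 | docker run --rm -it --gpus all \
24 |     -v ${TestInputDir}:/data/input \
25 |     -v ${TestOutputDir}:/data/output \
26 |     --entrypoint bash ${MyImageName}
27 | cd /app/sohu_matching/src
28 | python infer_final.py --input /data/input/test.txt --output /data/output/pred.csv
29 | ```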
30 |
31 | #### Overview
32 |
33 | This project contains our PyTorch code from the **semifinal round** of the 2021 Sohu campus text matching competition. It ranked third on the semifinal public leaderboard, with an online F1 score of 0.791075301658579 (0.8548399419359769 on the type-A tasks and 0.727310661381181 on the type-B tasks).
34 |
35 | We trained the two task types jointly: a single encoder based on a pretrained language model is shared across tasks A and B, followed by several simple fully connected classifiers, one set per task. We used a number of pretrained models (NEZHA, MacBERT, RoBERTa, ERNIE, etc.), implemented two text-matching designs (concatenating source and target with [SEP] into a single input, and SBERT-style sentence-vector encoding followed by comparison), and tried a range of score-boosting strategies (continued MLM pretraining on the competition corpus, a focal loss, different pooling strategies, adding TextCNN, FGM adversarial training, data augmentation, etc.). Our best score came from ensembling the outputs of several sufficiently different models by voting; a minimal sketch of the shared-encoder layout follows.
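36 |
37 | The sketch below is only for orientation (a shared encoder feeding one classifier head per task); the class and parameter names are illustrative, not the actual `model.py`:
38 |
39 | ```python
40 | import torch.nn as nn
41 | from transformers import BertModel
42 |
43 | class JointMatcher(nn.Module):
44 |     """Shared pretrained encoder with one lightweight classifier head per task."""
45 |     def __init__(self, encoder_name="bert-base-chinese", num_labels=2):
46 |         super().__init__()
47 |         self.encoder = BertModel.from_pretrained(encoder_name)  # shared across tasks A and B
48 |         hidden = self.encoder.config.hidden_size
49 |         self.head_a = nn.Linear(hidden, num_labels)  # classifier for task A
50 |         self.head_b = nn.Linear(hidden, num_labels)  # classifier for task B
51 |
52 |     def forward(self, input_ids, attention_mask, token_type_ids, task):
53 |         # source and target are already joined with [SEP], so one pass encodes the pair
54 |         sequence_output = self.encoder(input_ids=input_ids, attention_mask=attention_mask,
55 |                                        token_type_ids=token_type_ids)[0]
56 |         cls = sequence_output[:, 0]  # [CLS] pooling (one of several strategies we tried)
57 |         return self.head_a(cls) if task == "a" else self.head_b(cls)
58 | ```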
59 |
60 | #### Project structure
61 |
62 | ```bash
63 | │  README.md                 # README
64 | │  test.yaml                 # conda environment spec;
65 | │                            # installing pytorch>=1.6 and transformers is basically enough to reproduce
66 | ├─checkpoints                # saved model checkpoints
67 | ├─data
68 | │  └─dummy_bert              # tokenizer vocabs and config.json for BERT/ERNIE/NEZHA;
69 | │  └─dummy_ernie             # used at inference time to build the model from the config file
70 | │  └─dummy_nezha             # without loading the original pretrained weights (see the sketch below)
71 | │  └─sohu2021_open_data      # train/valid/test data from the preliminary and semifinal rounds
72 | │      ├─短短匹配A类          # each folder holds train.txt, train_r2.txt, train_r3.txt, train_rematch.txt,
73 | │      ├─短短匹配B类          # valid.txt, valid_rematch.txt, test_with_id_rematch.txt
74 | │      ├─短长匹配A类
75 | │      ├─短长匹配B类
76 | │      ├─长长匹配A类
77 | │      └─长长匹配B类
78 | ├─logs                       # training logs, e.g. python train.py > log_dir
79 | ├─results                    # test-set inference results
80 | ├─valid_output               # model outputs on valid, used to compute per-task f1
81 | └─src                        # main code folder
82 | │  config.py                # all model and training parameters are set in config.py
83 | │  data.py                  # data reading, DataLoader, etc.
84 | │  infer.py                 # test-set inference
85 | │  merge_result.py          # voting ensemble (sketched after the run example below)
86 | │  model.py                 # model definition
87 | │  search_better_merge.py   # searches the valid outputs for the best voting combination
88 | │  train.py                 # training code supporting the multi-task form (change num_task in model)
89 | │  train_old.py             # training code for tasks A/B only; our main trainer in the semifinal
90 | │  utils.py                 # miscellaneous helpers
91 | │
92 | ├─new_runs                   # tensorboard event files, for visualizing losses and other metrics
93 | ├─NEZHA                      # NEZHA model definition and utilities
94 | │  │  model_nezha.py
95 | │  │  nezha_utils.py
96 | └─__pycache__
97 | ```
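98 |
99 | As noted in the tree comments, the `dummy_*` folders carry only vocabs and configs: at inference time the model skeleton is built from the config alone and our fine-tuned weights are then loaded from `checkpoints/rematch`. Roughly (a sketch; the paths and checkpoint name are illustrative):
100 |
101 | ```python
102 | import torch
103 | from transformers import BertConfig, BertModel
104 |
105 | config = BertConfig.from_json_file("../data/dummy_bert/config.json")
106 | model = BertModel(config)  # randomly initialized skeleton, no pretrained download needed
107 | state_dict = torch.load("../checkpoints/rematch/some_model.bin", map_location="cpu")
108 | model.load_state_dict(state_dict, strict=False)  # strict=False: classifier heads live outside BertModel
109 | ```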
110 |
111 | #### Run example
112 | (Note: for the final-round submission, `data.py` was slightly modified to handle task A and task B test samples sharing a single file, so running `train_old.py` directly may raise errors.)
113 | After adding the extra training data, set the training parameters in `config.py`, enter the `src` folder, and run `train_old.py` to train. (In the semifinal we also tried giving each of the six subtasks its own classifier network, unified in `train.py`, but for the two-task A/B setup the preliminary-round training code seemed to work better, so we kept that approach in `train_old.py` and used it as the main training script. Multi-GPU training is the default; adjust the device count in `train_old.py`.) The output can be redirected to a log file. After training, set the inference parameters in `config.py`, go to the `src` folder again, and run `infer.py` (multi-GPU inference by default; adjust the device count in `infer.py`).
114 |
115 | ```bash
116 | python train_old.py > ../logs/0523/0523_roberta_80k.log  # train and save the log
117 | python infer.py  # inference
118 | ```
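119 |
120 | The prediction files from several models are then combined by majority vote (`merge_result.py`, with `search_better_merge.py` picking the combination on valid); the core idea is roughly this (a sketch assuming each CSV has `id` and `label` columns, not the actual script):
121 |
122 | ```python
123 | import pandas as pd
124 |
125 | def merge_by_vote(paths, out_path):
126 |     preds = [pd.read_csv(p) for p in paths]
127 |     votes = pd.concat([p["label"] for p in preds], axis=1)  # one column of labels per model
128 |     merged = preds[0][["id"]].copy()
129 |     merged["label"] = votes.mode(axis=1)[0].astype(int)  # majority label per row
130 |     merged.to_csv(out_path, index=False)
131 | ```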
132 |
133 |
--------------------------------------------------------------------------------
/决赛提交/sohu_matching/data/dummy_bert/config.json:
--------------------------------------------------------------------------------
1 | {
2 |   "architectures": [
3 |     "BertForMaskedLM"
4 |   ],
5 |   "attention_probs_dropout_prob": 0.1,
6 |   "directionality": "bidi",
7 |   "hidden_act": "gelu",
8 |   "hidden_dropout_prob": 0.1,
9 |   "hidden_size": 768,
10 |   "initializer_range": 0.02,
11 |   "intermediate_size": 3072,
12 |   "layer_norm_eps": 1e-12,
13 |   "max_position_embeddings": 512,
14 |   "model_type": "bert",
15 |   "num_attention_heads": 12,
16 |   "num_hidden_layers": 12,
17 |   "output_past": true,
18 |   "pad_token_id": 0,
19 |   "pooler_fc_size": 768,
20 |   "pooler_num_attention_heads": 12,
21 |   "pooler_num_fc_layers": 3,
22 |   "pooler_size_per_head": 128,
23 |   "pooler_type": "first_token_transform",
24 |   "type_vocab_size": 2,
25 |   "vocab_size": 21128
26 | }
27 |
--------------------------------------------------------------------------------
/决赛提交/sohu_matching/data/dummy_ernie/config.json:
--------------------------------------------------------------------------------
1 | {
2 |   "attention_probs_dropout_prob": 0.1,
3 |   "gradient_checkpointing": false,
4 |   "hidden_act": "relu",
5 |   "hidden_dropout_prob": 0.1,
6 |   "hidden_size": 768,
7 |   "initializer_range": 0.02,
8 |   "intermediate_size": 3072,
9 |   "layer_norm_eps": 1e-05,
10 |   "max_position_embeddings": 513,
11 |   "model_type": "bert",
12 |   "num_attention_heads": 12,
13 |   "num_hidden_layers": 12,
14 |   "pad_token_id": 0,
15 |   "type_vocab_size": 2,
16 |   "vocab_size": 18000
17 | }
18 |
--------------------------------------------------------------------------------
/决赛提交/sohu_matching/data/dummy_nezha/config.json:
--------------------------------------------------------------------------------
1 | {
2 |   "attention_probs_dropout_prob": 0.1,
3 |   "hidden_act": "gelu",
4 |   "hidden_dropout_prob": 0.1,
5 |   "hidden_size": 768,
6 |   "initializer_range": 0.02,
7 |   "intermediate_size": 3072,
8 |   "max_position_embeddings": 512,
9 |   "num_attention_heads": 12,
10 |   "num_hidden_layers": 12,
11 |   "type_vocab_size": 2,
12 |   "vocab_size": 21128,
13 |   "use_relative_position": true,
14 |   "model_type": "bert"
15 | }
16 |
--------------------------------------------------------------------------------
/决赛提交/sohu_matching/src/NEZHA/__pycache__/model_nezha.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Decem-Y/sohu_text_matching_Rank2/4d87c85b6de65fda777b15f1d7e37af74a3033b9/决赛提交/sohu_matching/src/NEZHA/__pycache__/model_nezha.cpython-36.pyc
--------------------------------------------------------------------------------
/决赛提交/sohu_matching/src/NEZHA/__pycache__/nezha_utils.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Decem-Y/sohu_text_matching_Rank2/4d87c85b6de65fda777b15f1d7e37af74a3033b9/决赛提交/sohu_matching/src/NEZHA/__pycache__/nezha_utils.cpython-36.pyc
--------------------------------------------------------------------------------
/决赛提交/sohu_matching/src/NEZHA/model_nezha.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors, The HuggingFace Inc. team and Huawei Noah's Ark Lab.
3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | from __future__ import absolute_import, division, print_function, unicode_literals
18 |
19 | import copy
20 | import json
21 | import logging
22 | import math
23 | import sys
24 | from io import open
25 |
26 | import numpy as np
27 |
28 | import torch
29 | from torch import nn
30 | from torch.nn import CrossEntropyLoss
31 |
32 | logger = logging.getLogger(__name__)
33 |
34 |
35 | def gelu(x):
36 |     """Implementation of the gelu activation function.
37 |         For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
38 |         0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
39 |         Also see https://arxiv.org/abs/1606.08415
40 |     """
41 |     return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
42 |
43 |
44 | def swish(x):
45 |     return x * torch.sigmoid(x)
46 |
47 |
48 | ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish}
49 |
50 |
51 | class NezhaConfig(object):
52 |     """Configuration class to store the configuration of a `BertModel`.
53 |     """
54 |
55 |     def __init__(self,
56 |                  vocab_size_or_config_json_file,
57 |                  hidden_size=768,
58 |                  num_hidden_layers=12,
59 |                  num_attention_heads=12,
60 |                  intermediate_size=3072,
61 |                  hidden_act="gelu",
62 |                  hidden_dropout_prob=0.1,
63 |                  attention_probs_dropout_prob=0.1,
64 |                  max_position_embeddings=512,
65 |                  max_relative_position=64,
66 |                  type_vocab_size=2,
67 |                  initializer_range=0.02,
68 |                  layer_norm_eps=1e-12):
69 |         """Constructs NezhaConfig.
70 |
71 |         Args:
72 |             vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `BertModel`.
73 |             hidden_size: Size of the encoder layers and the pooler layer.
74 |             num_hidden_layers: Number of hidden layers in the Transformer encoder.
75 |             num_attention_heads: Number of attention heads for each attention layer in
76 |                 the Transformer encoder.
77 |             intermediate_size: The size of the "intermediate" (i.e., feed-forward)
78 |                 layer in the Transformer encoder.
79 |             hidden_act: The non-linear activation function (function or string) in the
80 |                 encoder and pooler. If string, "gelu", "relu" and "swish" are supported.
81 |             hidden_dropout_prob: The dropout probability for all fully connected
82 |                 layers in the embeddings, encoder, and pooler.
83 |             attention_probs_dropout_prob: The dropout ratio for the attention
84 |                 probabilities.
85 |             max_position_embeddings: The maximum sequence length that this model might
86 |                 ever be used with. Typically set this to something large just in case
87 |                 (e.g., 512 or 1024 or 2048).
88 |             type_vocab_size: The vocabulary size of the `token_type_ids` passed into
89 |                 `BertModel`.
90 |             initializer_range: The stdev of the truncated_normal_initializer for
91 |                 initializing all weight matrices.
92 |             layer_norm_eps: The epsilon used by LayerNorm.
93 |         """
94 |         if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
95 |                 and isinstance(vocab_size_or_config_json_file, unicode)):
96 |             with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader:
97 |                 json_config = json.loads(reader.read())
98 |             for key, value in json_config.items():
99 |                 self.__dict__[key] = value
100 |         elif isinstance(vocab_size_or_config_json_file, int):
101 |             self.vocab_size = vocab_size_or_config_json_file
102 |             self.hidden_size = hidden_size
103 |             self.num_hidden_layers = num_hidden_layers
104 |             self.num_attention_heads = num_attention_heads
105 |             self.hidden_act = hidden_act
106 |             self.intermediate_size = intermediate_size
107 |             self.hidden_dropout_prob = hidden_dropout_prob
108 |             self.attention_probs_dropout_prob = attention_probs_dropout_prob
109 |             self.max_position_embeddings = max_position_embeddings
110 |             self.max_relative_position = max_relative_position
111 |             self.type_vocab_size = type_vocab_size
112 |             self.initializer_range = initializer_range
113 |             self.layer_norm_eps = layer_norm_eps
114 |         else:
115 |             raise ValueError("First argument must be either a vocabulary size (int)"
116 |                              " or the path to a pretrained model config file (str)")
117 |
118 |     @classmethod
119 |     def from_dict(cls, json_object):
120 |         """Constructs a `NezhaConfig` from a Python dictionary of parameters."""
121 |         config = NezhaConfig(vocab_size_or_config_json_file=-1)
122 |         for key, value in json_object.items():
123 |             config.__dict__[key] = value
124 |         return config
125 |
126 |     @classmethod
127 |     def from_json_file(cls, json_file):
128 |         """Constructs a `NezhaConfig` from a json file of parameters."""
129 |         with open(json_file, "r", encoding='utf-8') as reader:
130 |             text = reader.read()
131 |         return cls.from_dict(json.loads(text))
132 |
133 |     def __repr__(self):
134 |         return str(self.to_json_string())
135 |
136 |     def to_dict(self):
137 |         """Serializes this instance to a Python dictionary."""
138 |         output = copy.deepcopy(self.__dict__)
139 |         return output
140 |
141 |     def to_json_string(self):
142 |         """Serializes this instance to a JSON string."""
143 |         return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
144 |
145 |     def to_json_file(self, json_file_path):
146 |         """ Save this instance to a json file."""
147 |         with open(json_file_path, "w", encoding='utf-8') as writer:
148 |             writer.write(self.to_json_string())
149 |
150 |
151 | try:
152 |     from apex.normalization.fused_layer_norm import FusedLayerNorm as BertLayerNorm
153 | except ImportError:
154 |     logger.info("Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex .")
155 |
156 |
157 |     class BertLayerNorm(nn.Module):  # pure-PyTorch fallback, defined only when apex is unavailable
158 |         def __init__(self, hidden_size, eps=1e-12):
159 |             """Construct a layernorm module in the TF style (epsilon inside the square root).
160 |             """
161 |             super(BertLayerNorm, self).__init__()
162 |             self.weight = nn.Parameter(torch.ones(hidden_size))
163 |             self.bias = nn.Parameter(torch.zeros(hidden_size))
164 |             self.variance_epsilon = eps
165 |
166 |         def forward(self, x):
167 |             u = x.mean(-1, keepdim=True)
168 |             s = (x - u).pow(2).mean(-1, keepdim=True)
169 |             x = (x - u) / torch.sqrt(s + self.variance_epsilon)
170 |             return self.weight * x + self.bias
171 |
172 |
173 | class BertEmbeddings(nn.Module):
174 |     """Construct the embeddings from word, position and token_type embeddings.
175 |     """
176 |
177 |     def __init__(self, config):
178 |         super(BertEmbeddings, self).__init__()
179 |         self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=0)
180 |         try:
181 |             self.use_relative_position = config.use_relative_position  # when true (NEZHA), absolute position embeddings are skipped
182 |         except AttributeError:
183 |             self.use_relative_position = False
184 |         if not self.use_relative_position:
185 |             self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
186 |         self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
187 |
188 |         # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
189 |         # any TensorFlow checkpoint file
190 |         self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
191 |         self.dropout = nn.Dropout(config.hidden_dropout_prob)
192 |
193 |     def forward(self, input_ids, token_type_ids=None):
194 |         seq_length = input_ids.size(1)
195 |         position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
196 |         position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
197 |         if token_type_ids is None:
198 |             token_type_ids = torch.zeros_like(input_ids)
199 |
200 |         words_embeddings = self.word_embeddings(input_ids)
201 |         embeddings = words_embeddings
202 |         if not self.use_relative_position:
203 |             position_embeddings = self.position_embeddings(position_ids)
204 |             embeddings += position_embeddings
205 |         token_type_embeddings = self.token_type_embeddings(token_type_ids)
206 |         embeddings += token_type_embeddings
207 |         embeddings = self.LayerNorm(embeddings)
208 |         embeddings = self.dropout(embeddings)
209 |         return embeddings
210 |
211 |
212 | class BertSelfAttention(nn.Module):
213 |     def __init__(self, config):
214 |         super(BertSelfAttention, self).__init__()
215 |         if config.hidden_size % config.num_attention_heads != 0:
216 |             raise ValueError(
217 |                 "The hidden size (%d) is not a multiple of the number of attention "
218 |                 "heads (%d)" % (config.hidden_size, config.num_attention_heads))
219 |         self.num_attention_heads = config.num_attention_heads
220 |         self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
221 |         self.all_head_size = self.num_attention_heads * self.attention_head_size
222 |
223 |         self.query = nn.Linear(config.hidden_size, self.all_head_size)
224 |         self.key = nn.Linear(config.hidden_size, self.all_head_size)
225 |         self.value = nn.Linear(config.hidden_size, self.all_head_size)
226 |
227 |         self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
228 |
229 |     def transpose_for_scores(self, x):
230 |         new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
231 |         x = x.view(*new_x_shape)
232 |         return x.permute(0, 2, 1, 3)
233 |
234 |     def forward(self, hidden_states, attention_mask):
235 |         mixed_query_layer = self.query(hidden_states)
236 |         mixed_key_layer = self.key(hidden_states)
237 |         mixed_value_layer = self.value(hidden_states)
238 |
239 |         query_layer = self.transpose_for_scores(mixed_query_layer)
240 |         key_layer = self.transpose_for_scores(mixed_key_layer)
241 |         value_layer = self.transpose_for_scores(mixed_value_layer)
242 |
243 |         # Take the dot product between "query" and "key" to get the raw attention scores.
244 |         attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
245 |         attention_scores = attention_scores / math.sqrt(self.attention_head_size)
246 |         # Apply the attention mask (precomputed for all layers in BertModel forward() function)
247 |         attention_scores = attention_scores + attention_mask
248 |
249 |         # Normalize the attention scores to probabilities.
250 |         attention_probs = nn.Softmax(dim=-1)(attention_scores)
251 |
252 |         # This is actually dropping out entire tokens to attend to, which might
253 |         # seem a bit unusual, but is taken from the original Transformer paper.
254 |         attention_probs = self.dropout(attention_probs)
255 |
256 |         context_layer = torch.matmul(attention_probs, value_layer)
257 |         context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
258 |         new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
259 |         context_layer = context_layer.view(*new_context_layer_shape)
260 |         return context_layer, attention_scores
261 |
262 |
263 | def _generate_relative_positions_matrix(length, max_relative_position,
264 | cache=False):
265 | """Generates matrix of relative positions between inputs."""
266 | if not cache:
267 | range_vec = torch.arange(length)
268 | range_mat = range_vec.repeat(length).view(length, length)
269 | distance_mat = range_mat - torch.t(range_mat)
270 | else:
271 | distance_mat = torch.arange(-length + 1, 1, 1).unsqueeze(0)
272 |
273 | distance_mat_clipped = torch.clamp(distance_mat, -max_relative_position, max_relative_position)
274 | final_mat = distance_mat_clipped + max_relative_position
275 |
276 | return final_mat
277 |
278 |
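| # Builds a (length, length, depth) table of sinusoidal embeddings indexed by the
| # clipped relative distance i-j, as in NEZHA's functional relative position
| # encoding: distances are clipped to [-max_relative_position, max_relative_position]
| # and shifted into [0, 2*max_relative_position] before indexing the table.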
279 | def _generate_relative_positions_embeddings(length, depth, max_relative_position=127):
280 | vocab_size = max_relative_position * 2 + 1
281 | range_vec = torch.arange(length)
282 | range_mat = range_vec.repeat(length).view(length, length)
283 | distance_mat = range_mat - torch.t(range_mat)
284 | distance_mat_clipped = torch.clamp(distance_mat, -max_relative_position, max_relative_position)
285 | final_mat = distance_mat_clipped + max_relative_position
286 | embeddings_table = np.zeros([vocab_size, depth])
287 | for pos in range(vocab_size):
288 | for i in range(depth // 2):
289 | embeddings_table[pos, 2 * i] = np.sin(pos / np.power(10000, 2 * i / depth))
290 | embeddings_table[pos, 2 * i + 1] = np.cos(pos / np.power(10000, 2 * i / depth))
291 |
292 | embeddings_table_tensor = torch.tensor(embeddings_table).float()
293 | flat_relative_positions_matrix = final_mat.view(-1)
294 | one_hot_relative_positions_matrix = torch.nn.functional.one_hot(flat_relative_positions_matrix,
295 | num_classes=vocab_size).float()
296 | embeddings = torch.matmul(one_hot_relative_positions_matrix, embeddings_table_tensor)
297 | my_shape = list(final_mat.size())
298 | my_shape.append(depth)
299 | embeddings = embeddings.view(my_shape)
300 | return embeddings
301 |
302 |
303 | class NeZhaSelfAttention(nn.Module):
304 | def __init__(self, config):
305 | super(NeZhaSelfAttention, self).__init__()
306 | if config.hidden_size % config.num_attention_heads != 0:
307 | raise ValueError(
308 | "The hidden size (%d) is not a multiple of the number of attention "
309 | "heads (%d)" % (config.hidden_size, config.num_attention_heads))
310 | self.num_attention_heads = config.num_attention_heads
311 | self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
312 | self.all_head_size = self.num_attention_heads * self.attention_head_size
313 |
314 | self.query = nn.Linear(config.hidden_size, self.all_head_size)
315 | self.key = nn.Linear(config.hidden_size, self.all_head_size)
316 | self.value = nn.Linear(config.hidden_size, self.all_head_size)
317 | # self.relative_positions_embeddings = _generate_relative_positions_embeddings(
318 | # length=512, depth=self.attention_head_size, max_relative_position=config.max_relative_position).to(
319 | # self.query.weight.device)
320 | self.relative_positions_embeddings = _generate_relative_positions_embeddings(
321 | length=512, depth=self.attention_head_size, max_relative_position=config.max_relative_position).cuda()
322 | self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
323 |
324 | def transpose_for_scores(self, x):
325 | new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
326 | x = x.view(*new_x_shape)
327 | return x.permute(0, 2, 1, 3)
328 |
329 | def forward(self, hidden_states, attention_mask):
330 | # follow the device of the incoming activations
331 | device = hidden_states.device
333 | mixed_query_layer = self.query(hidden_states)
334 | mixed_key_layer = self.key(hidden_states)
335 | mixed_value_layer = self.value(hidden_states)
336 |
337 | query_layer = self.transpose_for_scores(mixed_query_layer)
338 | key_layer = self.transpose_for_scores(mixed_key_layer)
339 | value_layer = self.transpose_for_scores(mixed_value_layer)
340 |
341 | # Take the dot product between "query" and "key" to get the raw attention scores.
342 | attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
343 | batch_size, num_attention_heads, from_seq_length, to_seq_length = attention_scores.size()
344 |
345 | relations_keys = self.relative_positions_embeddings.detach().clone()[:to_seq_length, :to_seq_length, :].to(
346 | device)
347 | # relations_keys = embeddings.clone().detach().to(device)
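| # relative-key term: add q_i . r_{i-j} to each raw attention score; the
| # permute/view below folds (batch, heads) into one dimension so a single
| # batched matmul against the (seq, seq, head_size) table computes all scores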
348 | query_layer_t = query_layer.permute(2, 0, 1, 3)
349 | query_layer_r = query_layer_t.contiguous().view(from_seq_length, batch_size * num_attention_heads,
350 | self.attention_head_size)
351 | key_position_scores = torch.matmul(query_layer_r, relations_keys.permute(0, 2, 1))
352 | key_position_scores_r = key_position_scores.view(from_seq_length, batch_size,
353 | num_attention_heads, from_seq_length)
354 | key_position_scores_r_t = key_position_scores_r.permute(1, 2, 0, 3)
355 | attention_scores = attention_scores + key_position_scores_r_t
356 | attention_scores = attention_scores / math.sqrt(self.attention_head_size)
357 | # Apply the attention mask (precomputed for all layers in BertModel forward())
358 | attention_scores = attention_scores + attention_mask
359 |
360 | # Normalize the attention scores to probabilities.
361 | attention_probs = nn.Softmax(dim=-1)(attention_scores)
362 |
363 | # This is actually dropping out entire tokens to attend to, which might
364 | # seem a bit unusual, but is taken from the original Transformer paper.
365 | attention_probs = self.dropout(attention_probs)
366 |
367 | context_layer = torch.matmul(attention_probs, value_layer)
368 |
369 | relations_values = self.relative_positions_embeddings.detach().clone()[:to_seq_length, :to_seq_length, :].to(
370 | device)
371 | attention_probs_t = attention_probs.permute(2, 0, 1, 3)
372 | attentions_probs_r = attention_probs_t.contiguous().view(from_seq_length, batch_size * num_attention_heads,
373 | to_seq_length)
374 | value_position_scores = torch.matmul(attentions_probs_r, relations_values)
375 | value_position_scores_r = value_position_scores.view(from_seq_length, batch_size,
376 | num_attention_heads, self.attention_head_size)
377 | value_position_scores_r_t = value_position_scores_r.permute(1, 2, 0, 3)
378 | context_layer = context_layer + value_position_scores_r_t
379 |
380 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
381 | new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
382 | context_layer = context_layer.view(*new_context_layer_shape)
383 | return context_layer, attention_scores
384 |
385 |
386 | class BertSelfOutput(nn.Module):
387 | def __init__(self, config):
388 | super(BertSelfOutput, self).__init__()
389 | self.dense = nn.Linear(config.hidden_size, config.hidden_size)
390 | self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
391 | self.dropout = nn.Dropout(config.hidden_dropout_prob)
392 |
393 | def forward(self, hidden_states, input_tensor):
394 | hidden_states = self.dense(hidden_states)
395 | hidden_states = self.dropout(hidden_states)
396 | hidden_states = self.LayerNorm(hidden_states + input_tensor)
397 | return hidden_states
398 |
399 |
400 | class BertAttention(nn.Module):
401 | def __init__(self, config):
402 | super(BertAttention, self).__init__()
403 | # older configs may not define the flag; default to absolute positions
404 | self.use_relative_position = getattr(config, 'use_relative_position', False)
407 | if self.use_relative_position:
408 | self.self = NeZhaSelfAttention(config)
409 | else:
410 | self.self = BertSelfAttention(config)
411 |
412 | self.output = BertSelfOutput(config)
413 |
414 | def forward(self, input_tensor, attention_mask):
415 | self_output = self.self(input_tensor, attention_mask)
416 | self_output, layer_att = self_output
417 | attention_output = self.output(self_output, input_tensor)
418 | return attention_output, layer_att
419 |
420 |
421 | class BertIntermediate(nn.Module):
422 | def __init__(self, config):
423 | super(BertIntermediate, self).__init__()
424 | self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
425 | if isinstance(config.hidden_act, str) or (sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode)):
426 | self.intermediate_act_fn = ACT2FN[config.hidden_act]
427 | else:
428 | self.intermediate_act_fn = config.hidden_act
429 |
430 | def forward(self, hidden_states):
431 | hidden_states = self.dense(hidden_states)
432 | hidden_states = self.intermediate_act_fn(hidden_states)
433 | return hidden_states
434 |
435 |
436 | class BertOutput(nn.Module):
437 | def __init__(self, config):
438 | super(BertOutput, self).__init__()
439 | self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
440 | self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
441 | self.dropout = nn.Dropout(config.hidden_dropout_prob)
442 |
443 | def forward(self, hidden_states, input_tensor):
444 | hidden_states = self.dense(hidden_states)
445 | hidden_states = self.dropout(hidden_states)
446 | hidden_states = self.LayerNorm(hidden_states + input_tensor)
447 | return hidden_states
448 |
449 |
450 | class BertLayer(nn.Module):
451 | def __init__(self, config):
452 | super(BertLayer, self).__init__()
453 | self.attention = BertAttention(config)
454 | self.intermediate = BertIntermediate(config)
455 | self.output = BertOutput(config)
456 |
457 | def forward(self, hidden_states, attention_mask):
458 | attention_output = self.attention(hidden_states, attention_mask)
459 | attention_output, layer_att = attention_output
460 | intermediate_output = self.intermediate(attention_output)
461 | layer_output = self.output(intermediate_output, attention_output)
462 | return layer_output, layer_att
463 |
464 |
465 | class BertEncoder(nn.Module):
466 | def __init__(self, config):
467 | super(BertEncoder, self).__init__()
468 | layer = BertLayer(config)
469 | self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)])
470 |
471 | def forward(self, hidden_states, attention_mask):
472 | all_encoder_layers = []
473 | all_encoder_att = []
474 | for i, layer_module in enumerate(self.layer):
475 | all_encoder_layers.append(hidden_states)
476 | hidden_states = layer_module(all_encoder_layers[i], attention_mask)
477 | hidden_states, layer_att = hidden_states
478 | all_encoder_att.append(layer_att)
479 | all_encoder_layers.append(hidden_states)
480 | return all_encoder_layers, all_encoder_att
481 |
482 |
483 | class BertPooler(nn.Module):
484 | def __init__(self, config):
485 | super(BertPooler, self).__init__()
486 | self.dense = nn.Linear(config.hidden_size, config.hidden_size)
487 | self.activation = nn.Tanh()
488 |
489 | def forward(self, hidden_states):
490 | # We "pool" the model by simply taking the hidden state corresponding
491 | # to the first token.
492 | first_token_tensor = hidden_states[:, 0]
493 | pooled_output = self.dense(first_token_tensor)
494 | pooled_output = self.activation(pooled_output)
495 | return pooled_output
496 |
497 |
498 | class BertPreTrainedModel(nn.Module):
499 | """ An abstract class to handle weights initialization and
500 | a simple interface for downloading and loading pretrained models.
501 | """
502 |
503 | def __init__(self, config, *inputs, **kwargs):
504 | super(BertPreTrainedModel, self).__init__()
505 | if not isinstance(config, NezhaConfig):
506 | raise ValueError(
507 | "Parameter config in `{}(config)` should be an instance of class `NezhaConfig`. "
508 | "To create a model from a Google pretrained model use "
509 | "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
510 | self.__class__.__name__, self.__class__.__name__
511 | ))
512 | self.config = config
513 |
514 | def init_bert_weights(self, module):
515 | """ Initialize the weights.
516 | """
517 | if isinstance(module, (nn.Linear, nn.Embedding)):
518 | # Slightly different from the TF version which uses truncated_normal for initialization
519 | # cf https://github.com/pytorch/pytorch/pull/5617
520 | module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
521 | elif isinstance(module, BertLayerNorm):
522 | module.bias.data.zero_()
523 | module.weight.data.fill_(1.0)
524 | if isinstance(module, nn.Linear) and module.bias is not None:
525 | module.bias.data.zero_()
526 |
527 |
528 | class NEZHAModel(BertPreTrainedModel):
529 | def __init__(self, config):
530 | super(NEZHAModel, self).__init__(config)
531 | self.embeddings = BertEmbeddings(config)
532 | self.encoder = BertEncoder(config)
533 | self.pooler = BertPooler(config)
534 |
535 | self.apply(self.init_bert_weights)
536 |
537 | def forward(self, input_ids, token_type_ids=None, attention_mask=None, output_attention_mask=False,
538 | model_distillation=False, output_all_encoded_layers=False):
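| # the return value depends on the flags:
| #   output_attention_mask=True -> (all layers, attention maps, pooled output, extended mask)
| #   model_distillation=True    -> (all layers, attention maps)
| #   otherwise                  -> (last layer or all layers, pooled output)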
539 | if attention_mask is None:
540 | attention_mask = torch.ones_like(input_ids)
541 | if token_type_ids is None:
542 | token_type_ids = torch.zeros_like(input_ids)
543 | extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
544 | # extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
545 | extended_attention_mask = extended_attention_mask.to(dtype=torch.float32)  # fixed to fp32; use the line above to follow the model dtype under fp16
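| # turn the 0/1 mask into additive logits: padded positions get -10000.0 so
| # softmax assigns them near-zero attention weight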
546 | extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
547 |
548 | embedding_output = self.embeddings(input_ids, token_type_ids)
549 | encoded_layers = self.encoder(embedding_output,
550 | extended_attention_mask)
551 | encoded_layers, attention_layers = encoded_layers
552 | sequence_output = encoded_layers[-1]
553 | pooled_output = self.pooler(sequence_output)
554 | if output_attention_mask:
555 | return encoded_layers, attention_layers, pooled_output, extended_attention_mask
556 | if model_distillation:
557 | return encoded_layers, attention_layers
558 | if not output_all_encoded_layers:
559 | encoded_layers = encoded_layers[-1]
560 | return encoded_layers, pooled_output
561 |
562 |
563 | class BertPredictionHeadTransform(nn.Module):
564 | def __init__(self, config):
565 | super(BertPredictionHeadTransform, self).__init__()
566 | # Need to untie it when we separate the dimensions of hidden and emb
567 | self.dense = nn.Linear(config.hidden_size, config.hidden_size)
568 | if isinstance(config.hidden_act, str) or (sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode)):
569 | self.transform_act_fn = ACT2FN[config.hidden_act]
570 | else:
571 | self.transform_act_fn = config.hidden_act
572 | self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
573 |
574 | def forward(self, hidden_states):
575 | hidden_states = self.dense(hidden_states)
576 | hidden_states = self.transform_act_fn(hidden_states)
577 | hidden_states = self.LayerNorm(hidden_states)
578 | return hidden_states
579 |
580 |
581 | class BertLMPredictionHead(nn.Module):
582 | def __init__(self, config, bert_model_embedding_weights):
583 | super(BertLMPredictionHead, self).__init__()
584 | self.transform = BertPredictionHeadTransform(config)
585 |
586 | # The output weights are the same as the input embeddings, but there is
587 | # an output-only bias for each token.
588 | self.decoder = nn.Linear(bert_model_embedding_weights.size(1),
589 | bert_model_embedding_weights.size(0),
590 | bias=False)
591 | self.decoder.weight = bert_model_embedding_weights
592 | self.bias = nn.Parameter(torch.zeros(bert_model_embedding_weights.size(0)))
593 |
594 | def forward(self, hidden_states):
595 | hidden_states = self.transform(hidden_states)
596 | hidden_states = self.decoder(hidden_states) + self.bias
597 | return hidden_states
598 |
599 |
600 | class BertOnlyMLMHead(nn.Module):
601 | def __init__(self, config, bert_model_embedding_weights):
602 | super(BertOnlyMLMHead, self).__init__()
603 | self.predictions = BertLMPredictionHead(config, bert_model_embedding_weights)
604 |
605 | def forward(self, sequence_output):
606 | prediction_scores = self.predictions(sequence_output)
607 | return prediction_scores
608 |
609 |
610 | class BertOnlyNSPHead(nn.Module):
611 | def __init__(self, config):
612 | super(BertOnlyNSPHead, self).__init__()
613 | self.seq_relationship = nn.Linear(config.hidden_size, 2)
614 |
615 | def forward(self, pooled_output):
616 | seq_relationship_score = self.seq_relationship(pooled_output)
617 | return seq_relationship_score
618 |
619 |
620 | class BertPreTrainingHeads(nn.Module):
621 | def __init__(self, config, bert_model_embedding_weights):
622 | super(BertPreTrainingHeads, self).__init__()
623 | self.predictions = BertLMPredictionHead(config, bert_model_embedding_weights)
624 | self.seq_relationship = nn.Linear(config.hidden_size, 2)
625 |
626 | def forward(self, sequence_output, pooled_output):
627 | prediction_scores = self.predictions(sequence_output)
628 | seq_relationship_score = self.seq_relationship(pooled_output)
629 | return prediction_scores, seq_relationship_score
630 |
631 |
632 | class BertForPreTraining(BertPreTrainedModel):
633 | """BERT model with pre-training heads.
634 | This module comprises the BERT model followed by the two pre-training heads:
635 | - the masked language modeling head, and
636 | - the next sentence classification head.
637 |
638 | Params:
639 | config: a NezhaConfig class instance with the configuration to build a new model.
640 |
641 | Inputs:
642 | `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
643 | with the word token indices in the vocabulary (see the tokens preprocessing logic in the scripts
644 | `extract_features.py`, `run_classifier.py` and `run_squad.py`)
645 | `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
646 | types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
647 | a `sentence B` token (see BERT paper for more details).
648 | `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
649 | selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
650 | input sequence length in the current batch. It's the mask that we typically use for attention when
651 | a batch has varying length sentences.
652 | `masked_lm_labels`: optional masked language modeling labels: torch.LongTensor of shape [batch_size, sequence_length]
653 | with indices selected in [-1, 0, ..., vocab_size]. All labels set to -1 are ignored (masked), the loss
654 | is only computed for the labels set in [0, ..., vocab_size]
655 | `next_sentence_label`: optional next sentence classification loss: torch.LongTensor of shape [batch_size]
656 | with indices selected in [0, 1].
657 | 0 => next sentence is the continuation, 1 => next sentence is a random sentence.
658 |
659 | Outputs:
660 | if `masked_lm_labels` and `next_sentence_label` are not `None`:
661 | Outputs the total_loss which is the sum of the masked language modeling loss and the next
662 | sentence classification loss.
663 | if `masked_lm_labels` or `next_sentence_label` is `None`:
664 | Outputs a tuple comprising
665 | - the masked language modeling logits of shape [batch_size, sequence_length, vocab_size], and
666 | - the next sentence classification logits of shape [batch_size, 2].
667 |
668 | Example usage:
669 | ```python
670 | # Already been converted into WordPiece token ids
671 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
672 | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
673 | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
674 |
675 | config = NezhaConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
676 | num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
677 |
678 | model = BertForPreTraining(config)
679 | masked_lm_logits_scores, seq_relationship_logits = model(input_ids, token_type_ids, input_mask)
680 | ```
681 | """
682 |
683 | def __init__(self, config):
684 | super(BertForPreTraining, self).__init__(config)
685 | self.bert = NEZHAModel(config)
686 | self.cls = BertPreTrainingHeads(config, self.bert.embeddings.word_embeddings.weight)
687 | self.apply(self.init_bert_weights)
688 |
689 | def forward(self, input_ids, token_type_ids=None, attention_mask=None,
690 | masked_lm_labels=None, next_sentence_label=None):
691 | sequence_output, pooled_output = self.bert(input_ids, token_type_ids, attention_mask,
692 | output_all_encoded_layers=False)
693 | prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
694 |
695 | if masked_lm_labels is not None and next_sentence_label is not None:
696 | loss_fct = CrossEntropyLoss(ignore_index=-1)
697 | masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
698 | next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
699 | total_loss = masked_lm_loss + next_sentence_loss
700 | return total_loss
701 | elif masked_lm_labels is not None:
702 | loss_fct = CrossEntropyLoss(ignore_index=-1)
703 | masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
704 | total_loss = masked_lm_loss
705 | return total_loss
706 | else:
707 | return prediction_scores, seq_relationship_score
708 |
709 |
710 | class BertForMaskedLM(BertPreTrainedModel):
711 | """BERT model with the masked language modeling head.
712 | This module comprises the BERT model followed by the masked language modeling head.
713 |
714 | Params:
715 | config: a NezhaConfig class instance with the configuration to build a new model.
716 |
717 | Inputs:
718 | `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
719 | with the word token indices in the vocabulary (see the tokens preprocessing logic in the scripts
720 | `extract_features.py`, `run_classifier.py` and `run_squad.py`)
721 | `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
722 | types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
723 | a `sentence B` token (see BERT paper for more details).
724 | `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
725 | selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
726 | input sequence length in the current batch. It's the mask that we typically use for attention when
727 | a batch has varying length sentences.
728 | `masked_lm_labels`: masked language modeling labels: torch.LongTensor of shape [batch_size, sequence_length]
729 | with indices selected in [-1, 0, ..., vocab_size]. All labels set to -1 are ignored (masked), the loss
730 | is only computed for the labels set in [0, ..., vocab_size]
731 |
732 | Outputs:
733 | if `masked_lm_labels` is not `None`:
734 | Outputs the masked language modeling loss.
735 | if `masked_lm_labels` is `None`:
736 | Outputs the masked language modeling logits of shape [batch_size, sequence_length, vocab_size].
737 |
738 | Example usage:
739 | ```python
740 | # Already been converted into WordPiece token ids
741 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
742 | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
743 | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
744 |
745 | config = NezhaConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
746 | num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
747 |
748 | model = BertForMaskedLM(config)
749 | masked_lm_logits_scores = model(input_ids, token_type_ids, input_mask)
750 | ```
751 | """
752 |
753 | def __init__(self, config):
754 | super(BertForMaskedLM, self).__init__(config)
755 | self.bert = NEZHAModel(config)
756 | self.cls = BertOnlyMLMHead(config, self.bert.embeddings.word_embeddings.weight)
757 | self.apply(self.init_bert_weights)
758 |
759 | def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None,
760 | output_att=False, infer=False):
761 | # NEZHAModel.forward has no `output_att` argument; `model_distillation=True`
762 | # returns (all encoded layers, attention maps), covering both cases below
763 | sequence_output, att_output = self.bert(input_ids, token_type_ids, attention_mask,
764 | model_distillation=True)
765 |
766 | prediction_scores = self.cls(sequence_output[-1])
767 |
768 | if masked_lm_labels is not None:
769 | loss_fct = CrossEntropyLoss(ignore_index=-1)
770 | masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
771 | if not output_att:
772 | return masked_lm_loss
773 | else:
774 | return masked_lm_loss, att_output
775 | else:
776 | if not output_att:
777 | return prediction_scores
778 | else:
779 | return prediction_scores, att_output
780 |
781 |
782 | class BertForSequenceClassification(BertPreTrainedModel):
783 | """BERT model for classification.
784 | This module is composed of the BERT model with a linear layer on top of
785 | the pooled output.
786 |
787 | Params:
788 | `config`: a NezhaConfig class instance with the configuration to build a new model.
789 | `num_labels`: the number of classes for the classifier. Default = 2.
790 |
791 | Inputs:
792 | `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
793 | with the word token indices in the vocabulary (see the tokens preprocessing logic in the scripts
794 | `extract_features.py`, `run_classifier.py` and `run_squad.py`)
795 | `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
796 | types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
797 | a `sentence B` token (see BERT paper for more details).
798 | `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
799 | selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
800 | input sequence length in the current batch. It's the mask that we typically use for attention when
801 | a batch has varying length sentences.
802 | `labels`: labels for the classification output: torch.LongTensor of shape [batch_size]
803 | with indices selected in [0, ..., num_labels].
804 |
805 | Outputs:
806 | if `labels` is not `None`:
807 | Outputs the CrossEntropy classification loss of the output with the labels.
808 | if `labels` is `None`:
809 | Outputs the classification logits of shape [batch_size, num_labels].
810 |
811 | Example usage:
812 | ```python
813 | # Already been converted into WordPiece token ids
814 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
815 | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
816 | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
817 |
818 | config = NezhaConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
819 | num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
820 |
821 | num_labels = 2
822 |
823 | model = BertForSequenceClassification(config, num_labels)
824 | logits = model(input_ids, token_type_ids, input_mask)
825 | ```
826 | """
827 |
828 | def __init__(self, config, num_labels):
829 | super(BertForSequenceClassification, self).__init__(config)
830 | self.num_labels = num_labels
831 | self.bert = NEZHAModel(config)
832 | self.dropout = nn.Dropout(config.hidden_dropout_prob)
833 | self.classifier = nn.Linear(config.hidden_size, num_labels)
834 | self.apply(self.init_bert_weights)
835 |
836 | def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):
837 | _, pooled_output = self.bert(input_ids, token_type_ids, attention_mask,
838 | output_all_encoded_layers=False)
839 | task_output = self.dropout(pooled_output)
840 | logits = self.classifier(task_output)
841 | if labels is not None:
842 | loss_fct = CrossEntropyLoss()
843 | loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
844 | return loss
845 | else:
846 | return logits
847 |
848 |
849 | class NeZhaForMultipleChoice(BertPreTrainedModel):
850 | def __init__(self, config, num_choices=2):
851 | super(NeZhaForMultipleChoice, self).__init__(config)
852 | self.num_choices = num_choices
853 | self.bert = NEZHAModel(config)
854 | self.dropout = nn.Dropout(config.hidden_dropout_prob)
855 | self.classifier = nn.Linear(config.hidden_size, 1)
856 | self.apply(self.init_bert_weights)
857 |
858 | def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None, return_logits=False):
859 | # input_ids: [bs,num_choice,seq_l]
860 | flat_input_ids = input_ids.view(-1, input_ids.size(-1)) # flat_input_ids: [bs*num_choice,seq_l]
861 | flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1))
862 | flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1))
863 | _, pooled_output = self.bert(flat_input_ids, flat_token_type_ids, flat_attention_mask,
864 | output_all_encoded_layers=False)
865 | pooled_output = self.dropout(pooled_output)
866 | logits = self.classifier(pooled_output) # logits: (bs*num_choice,1)
867 | reshaped_logits = logits.view(-1, self.num_choices) # logits: (bs, num_choice)
868 |
869 | if labels is not None:
870 | loss_fct = CrossEntropyLoss()
871 | loss = loss_fct(reshaped_logits, labels)
872 | if return_logits:
873 | return loss, reshaped_logits
874 | else:
875 | return loss
876 | else:
877 | return reshaped_logits
878 |
879 |
880 | class NeZhaForQuestionAnswering(BertPreTrainedModel):
881 | def __init__(self, config):
882 | super(NeZhaForQuestionAnswering, self).__init__(config)
883 | self.bert = NEZHAModel(config)
884 | # TODO check with Google if it's normal there is no dropout on the token classifier of SQuAD in the TF version
885 | # self.dropout = nn.Dropout(config.hidden_dropout_prob)
886 | self.qa_outputs = nn.Linear(config.hidden_size, 2)
887 | self.apply(self.init_bert_weights)
888 |
889 | def forward(self, input_ids, token_type_ids=None, attention_mask=None, start_positions=None, end_positions=None):
890 | sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
891 | logits = self.qa_outputs(sequence_output)
892 | start_logits, end_logits = logits.split(1, dim=-1)
893 | start_logits = start_logits.squeeze(-1)
894 | end_logits = end_logits.squeeze(-1)
895 |
896 | if start_positions is not None and end_positions is not None:
897 | # If we are on multi-GPU, the positions may carry an extra dimension; squeeze it
898 | if len(start_positions.size()) > 1:
899 | start_positions = start_positions.squeeze(-1)
900 | if len(end_positions.size()) > 1:
901 | end_positions = end_positions.squeeze(-1)
902 | # sometimes the start/end positions fall outside our model inputs; ignore these terms
903 | ignored_index = start_logits.size(1)
904 | start_positions.clamp_(0, ignored_index)
905 | end_positions.clamp_(0, ignored_index)
906 |
907 | loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
908 | start_loss = loss_fct(start_logits, start_positions)
909 | end_loss = loss_fct(end_logits, end_positions)
910 | total_loss = (start_loss + end_loss) / 2
911 | return total_loss
912 | else:
913 | return start_logits, end_logits
914 |
915 |
916 | class BertForJointLSTM(BertPreTrainedModel):
917 | def __init__(self, config, num_intent_labels, num_slot_labels):
918 | super(BertForJointLSTM, self).__init__(config)
919 | self.num_intent_labels = num_intent_labels
920 | self.num_slot_labels = num_slot_labels
921 | self.bert = NEZHAModel(config)
922 | self.dropout = nn.Dropout(config.hidden_dropout_prob)
923 | self.intent_classifier = nn.Linear(config.hidden_size, num_intent_labels)
924 | self.lstm = nn.LSTM(
925 | input_size=config.hidden_size,
926 | hidden_size=300,
927 | batch_first=True,
928 | bidirectional=True
929 | )
931 | self.slot_classifier = nn.Linear(300 * 2, num_slot_labels)
932 | self.apply(self.init_bert_weights)
933 |
934 | def forward(self, input_ids, token_type_ids=None,
935 | attention_mask=None, intent_labels=None, slot_labels=None):
936 | # NEZHAModel returns (last layer, pooled output) by default; ask for all layers
| encoded_layers, pooled_output = self.bert(input_ids, token_type_ids, attention_mask,
| output_all_encoded_layers=True)
937 | intent_logits = self.intent_classifier(self.dropout(pooled_output))
938 |
939 | last_encoded_layer = encoded_layers[-1]
940 | slot_logits, _ = self.lstm(last_encoded_layer)
941 | slot_logits = self.slot_classifier(slot_logits)
943 | if intent_labels is not None and slot_labels is not None:
944 | loss_fct = CrossEntropyLoss()
945 | intent_loss = loss_fct(intent_logits.view(-1, self.num_intent_labels), intent_labels.view(-1))
946 | if attention_mask is not None:
947 | active_slot_loss = attention_mask.view(-1) == 1
948 | active_slot_logits = slot_logits.view(-1, self.num_slot_labels)[active_slot_loss]
949 | active_slot_labels = slot_labels.view(-1)[active_slot_loss]
950 | slot_loss = loss_fct(active_slot_logits, active_slot_labels)
951 | else:
952 | slot_loss = loss_fct(slot_logits.view(-1, self.num_slot_labels), slot_labels.view(-1))
953 |
954 | return intent_loss, slot_loss
955 | else:
956 | return intent_logits, slot_logits
957 |
--------------------------------------------------------------------------------
/决赛提交/sohu_matching/src/NEZHA/nezha_utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding=utf-8
3 | import os
4 | from glob import glob
5 |
6 | import torch
7 |
8 |
9 | def check_args(args):
10 | args.setting_file = os.path.join(args.checkpoint_dir, args.setting_file)
11 | args.log_file = os.path.join(args.checkpoint_dir, args.log_file)
12 | os.makedirs(args.checkpoint_dir, exist_ok=True)
13 | with open(args.setting_file, 'wt') as opt_file:
14 | opt_file.write('------------ Options -------------\n')
15 | print('------------ Options -------------')
16 | for k in args.__dict__:
17 | v = args.__dict__[k]
18 | opt_file.write('%s: %s\n' % (str(k), str(v)))
19 | print('%s: %s' % (str(k), str(v)))
20 | opt_file.write('-------------- End ----------------\n')
21 | print('------------ End -------------')
22 |
23 | return args
24 |
25 |
26 | def torch_show_all_params(model, rank=0):
27 | # p.numel() counts the elements of each parameter tensor
28 | total = sum(p.numel() for p in model.parameters())
29 | if rank == 0:
30 | print("Total param num: " + str(total))
36 |
37 |
38 | def torch_init_model(model, init_checkpoint, delete_module=False):
39 | state_dict = torch.load(init_checkpoint, map_location='cpu')
40 | state_dict_new = {}
41 | # delete module.
42 | if delete_module:
43 | for key in state_dict.keys():
44 | v = state_dict[key]
45 | state_dict_new[key.replace('module.', '')] = v
46 | state_dict = state_dict_new
47 | missing_keys = []
48 | unexpected_keys = []
49 | error_msgs = []
50 | # copy state_dict so _load_from_state_dict can modify it
51 | metadata = getattr(state_dict, '_metadata', None)
52 | state_dict = state_dict.copy()
53 | if metadata is not None:
54 | state_dict._metadata = metadata
55 |
56 | def load(module, prefix=''):
57 | local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
58 |
59 | module._load_from_state_dict(
60 | state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
61 | for name, child in module._modules.items():
62 | if child is not None:
63 | load(child, prefix + name + '.')
64 |
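| # checkpoints store parameters under a 'bert.' prefix; when loading into a
| # bare NEZHAModel (no .bert submodule), start from that prefix so keys match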
65 | load(model, prefix='' if hasattr(model, 'bert') else 'bert.')
66 |
67 | print("missing keys:{}".format(missing_keys))
68 | print('unexpected keys:{}'.format(unexpected_keys))
69 | print('error msgs:{}'.format(error_msgs))
70 |
71 |
72 | def torch_save_model(model, output_dir, scores, max_save_num=1):
73 | # Save model checkpoint
74 | if not os.path.exists(output_dir):
75 | os.makedirs(output_dir)
76 | model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training
77 | saved_pths = glob(os.path.join(output_dir, '*.pth'))
78 | saved_pths.sort()
79 | while len(saved_pths) >= max_save_num:
80 | if os.path.exists(saved_pths[0].replace('//', '/')):
81 | os.remove(saved_pths[0].replace('//', '/'))
82 | del saved_pths[0]
83 |
84 | save_prefix = "checkpoint_score"
85 | for k in scores:
86 | save_prefix += ('_' + k + '-' + str(scores[k])[:6])
87 | save_prefix += '.pth'
88 |
89 | torch.save(model_to_save.state_dict(),
90 | os.path.join(output_dir, save_prefix))
91 | print("Saving model checkpoint to %s" % output_dir)
92 |
--------------------------------------------------------------------------------
/决赛提交/sohu_matching/src/config.py:
--------------------------------------------------------------------------------
1 | class Config():
2 | def __init__(self):
3 | self.device= 'cuda'
4 | self.model_type = '0523_roberta_80k_6tasks'
5 | self.task_type = 'ab'
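| # 'ab' trains on both task A and task B samples (infer.py splits the two by task id parity)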
6 |
7 | self.save_dir = '/data1/wangchenyue/sohu_matching/checkpoints/rematch/'
8 | self.data_dir = '/data1/wangchenyue/sohu_matching/data/sohu2021_open_data/'
9 | self.load_toy_dataset = False
10 |
11 | # self.pretrained = '/data1/wangchenyue/Downloads/chinese-macbert-base/'
12 | # self.pretrained = '/data1/wangchenyue/Downloads/nezha-base-wwm/'
13 | # self.pretrained = '/data1/wangchenyue/Downloads/chinese-roberta-wwm-ext/'
14 | # self.pretrained = '/data1/wangchenyue/Downloads/roberta-base-finetuned-chinanews-chinese/'
15 | # self.pretrained = '/data1/wangchenyue/Downloads/chinese-bert-wwm-ext/'
16 | # self.pretrained = '/data1/wangchenyue/Downloads/roberta-base-word-chinese'
17 | # self.pretrained = '/data1/wangchenyue/Downloads/ernie-1.0/'
18 | self.pretrained = '/data1/wangchenyue/Downloads/DSP/roberta-wwm-rematch/checkpoint-80000/'
19 |
20 | self.epochs = 3
21 | self.lr = 2e-5
22 | self.classifier_lr = 1e-3
23 | self.use_scheduler = True
24 | self.weight_decay = 1e-3
25 | self.num_warmup_steps = 2000
26 |
27 | # for larger models, e.g. roberta-large (hidden_size = 1024); otherwise 768
28 | self.hidden_size = 768
29 | # for sbert, train_bs = 16, eval_bs = 32, otherwise 32/64
30 | self.train_bs = 32
31 | self.eval_bs = 64
32 | self.criterion = 'CE'
33 | self.print_every = 50
34 | self.eval_every = 500
35 |
36 | # whether to shuffle the sentence order in training data as augmentation
37 | self.shuffle_order = False
38 | self.aug_data = False
39 | # how to clip the long sequences, 'head': using the first sentences, 'tail': using the last sentences
40 | # 'head' is reportedly better than 'tail'
41 | self.clip_method = 'head'
42 |
43 | # whether to use FGM adversarial attack during training
44 | self.use_fgm = False
45 |
46 | # settings for inference
47 | # self.infer_model_dir = '../checkpoints/0502/'
48 | self.infer_model_dir = '/data1/wangchenyue/sohu_matching/checkpoints/rematch/'
49 | self.infer_model_name = '0525_roberta_6tasks_epoch_1_ab_loss'
50 | # fake pretrained model dir containing config.json and vocab.txt, for tokenizer and model initialization
51 | self.dummy_pretrained = '../data/dummy_bert/'
52 | # self.dummy_pretrained = '../data/dummy_ernie/'
53 | # self.dummy_pretrained = '../data/dummy_nezha/'
54 | # infer_task_type is parsed from infer_model_name (the field before the final suffix, e.g. 'ab' in '..._ab_loss')
55 | self.infer_task_type = self.infer_model_name.split('_')[-2]
56 | self.infer_output_dir = '/data1/wangchenyue/sohu_matching/results/rematch/'
57 | self.infer_output_filename = '{}.csv'.format(self.infer_model_name)
58 | self.infer_clip_method = 'head'
59 | # for NEZHA, infer_bs=64, otherwise 256
60 | self.infer_bs = 256
61 | self.infer_fixed_thres_a = 0.45
62 | self.infer_fixed_thres_b = 0.35
63 | self.infer_search_thres = True
64 |
--------------------------------------------------------------------------------
/决赛提交/sohu_matching/src/data.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import Dataset, DataLoader
3 | from transformers import BertTokenizer, AutoTokenizer
4 | from utils import pad_to_maxlen, augment_data
5 | import pandas as pd
6 |
7 | from tqdm import tqdm
8 | import json
9 |
10 | # the main difference between the two dataset classes is how the length
11 | # limit applies (512 per single sentence in the SBERT setting, but 512 for
12 | # the two concatenated sentences in the BERT setting)
13 | class SentencePairDatasetForSBERT(Dataset):
14 | def __init__(self, file_dir, is_train, tokenizer_config, shuffle_order=False, aug_data=False, len_limit=512, clip='head'):
15 | self.is_train = is_train
16 | self.shuffle_order = shuffle_order
17 | self.aug_data = aug_data
18 | self.total_source_input_ids = []
19 | # token_types are no longer necessary when the texts are not concatenated
20 | # self.total_source_input_types = []
21 | self.total_target_input_ids = []
22 | # self.total_target_input_types = []
23 | self.sample_types = []
24 |
25 | # use AutoTokenizer instead of BertTokenizer to support spiece.model (AlbertTokenizer-like)
26 | tokenizer = AutoTokenizer.from_pretrained(tokenizer_config)
27 | lines = []
28 | for single_file_dir in file_dir:
29 | with open(single_file_dir, 'r', encoding='utf-8') as f_in:
30 | content = f_in.readlines()
31 | for item in content:
32 | line = json.loads(item.strip())
33 | if not is_train:
34 | line['type'] = 0 if 'a' in line['id'] else 1
35 | lines.append(line)
36 |
37 | content = pd.DataFrame(lines)
38 | content.columns = ['source', 'target', 'label', 'type']
39 |
40 | # utilize labelB=1-->A positive, labelA=0-->B negative
41 | if self.is_train and self.aug_data:
42 | print("augmenting data...")
43 | content = augment_data(content)
44 |
45 | sources = content['source'].values.tolist()
46 | targets = content['target'].values.tolist()
47 |
48 | self.sample_types = content['type'].values.tolist()
49 | if self.is_train:
50 | self.labels = content['label'].values.tolist()
51 | else:
52 | self.ids = content['label'].values.tolist()
53 |
54 | # shuffle_order is only allowed for training mode
55 | if self.shuffle_order and self.is_train:
56 | sources += content['target'].values.tolist()
57 | targets += content['source'].values.tolist()
58 | self.labels += self.labels
59 | self.sample_types += self.sample_types
60 |
61 | for source, target in tqdm(zip(sources, targets), total=len(sources)):
62 | # tokenize before clipping
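| # encode() adds [CLS]/[SEP]; [1:-1] strips them so clipping operates on raw
| # tokens, and the special tokens are re-added after clipping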
63 | source = tokenizer.encode(source)[1:-1]
64 | target = tokenizer.encode(target)[1:-1]
65 |
66 | # clip the sentences if too long
67 | # TODO: different strategies to clip long sequences
68 | if clip == 'head':
69 | if len(source)+2 > len_limit:
70 | source = source[0: len_limit-2]
71 | if len(target)+2 > len_limit:
72 | target = target[0: len_limit-2]
73 |
74 | if clip == 'tail':
75 | if len(source)+2 > len_limit:
76 | source = source[-len_limit+2:]
77 | if len(target)+2 > len_limit:
78 | target = target[-len_limit+2:]
79 |
80 | # check if the length is within the limit
81 | assert len(source)+2 <= len_limit and len(target)+2 <= len_limit
82 |
83 | # [CLS]:101, [SEP]:102
84 | source_input_ids = [101] + source + [102]
85 | target_input_ids = [101] + target + [102]
86 |
87 | assert len(source_input_ids) <= len_limit and len(target_input_ids) <= len_limit
88 |
89 | self.total_source_input_ids.append(source_input_ids)
90 | self.total_target_input_ids.append(target_input_ids)
91 |
92 | self.max_source_input_len = max([len(s) for s in self.total_source_input_ids])
93 | self.max_target_input_len = max([len(s) for s in self.total_target_input_ids])
94 | print("max source length: ", self.max_source_input_len)
95 | print("max target length: ", self.max_target_input_len)
96 |
97 | def __len__(self):
98 | return len(self.total_target_input_ids)
99 |
100 | def __getitem__(self, idx):
101 | source_input_ids = pad_to_maxlen(self.total_source_input_ids[idx], self.max_source_input_len)
102 | target_input_ids = pad_to_maxlen(self.total_target_input_ids[idx], self.max_target_input_len)
103 | sample_type = int(self.sample_types[idx])
104 |
105 | if self.is_train:
106 | label = int(self.labels[idx])
107 | return torch.LongTensor(source_input_ids), torch.LongTensor(target_input_ids), torch.LongTensor([label]), sample_type
108 |
109 | else:
110 | index = self.ids[idx]
111 | return torch.LongTensor(source_input_ids), torch.LongTensor(target_input_ids), index, sample_type
112 |
113 | class SentencePairDatasetWithType(Dataset):
114 | def __init__(self, file_dir, is_train, tokenizer_config, shuffle_order=False, aug_data=False, len_limit=512, clip='head'):
115 | self.is_train = is_train
116 | self.shuffle_order = shuffle_order
117 | self.aug_data = aug_data
118 | self.total_input_ids = []
119 | self.total_input_types = []
120 | self.sample_types = []
121 |
122 | # use AutoTokenizer instead of BertTokenizer to support spiece.model (AlbertTokenizer-like)
123 | tokenizer = AutoTokenizer.from_pretrained(tokenizer_config)
124 |
125 | # read json lines and convert to dict / df
126 | lines = []
127 | for single_file_dir in file_dir:
128 | with open(single_file_dir, 'r', encoding='utf-8') as f_in:
129 | content = f_in.readlines()
130 | for item in content:
131 | line = json.loads(item.strip())
132 |
133 | # for final stage, task a and b are included in the same file
134 | if not is_train:
135 | line['type'] = 0 if 'a' in line['id'] else 1
136 | lines.append(line)
137 | print(single_file_dir, len(lines))
138 | content = pd.DataFrame(lines)
139 | # print(content.head())
140 | content.columns = ['source', 'target', 'label', 'type']
141 |
142 | # utilize labelB=1-->A positive, labelA=0-->B negative
143 | if self.is_train and self.aug_data:
144 | print("augmenting data...")
145 | content = augment_data(content)
146 |
147 | sources = content['source'].values.tolist()
148 | targets = content['target'].values.tolist()
149 |
150 | self.sample_types = content['type'].values.tolist()
151 | if self.is_train:
152 | self.labels = content['label'].values.tolist()
153 | else:
154 | self.ids = content['label'].values.tolist()
155 |
156 | # shuffle_order is only allowed for training mode
157 | if self.shuffle_order and self.is_train:
158 | sources += content['target'].values.tolist()
159 | targets += content['source'].values.tolist()
160 | self.labels += self.labels
161 | self.sample_types += self.sample_types
162 |
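| # reserve 3 positions for [CLS] and two [SEP] tokens, then split the remaining
| # budget evenly; a side that already fits lends its leftover budget to the
| # other (the elif branches below)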
163 | len_limit_s = (len_limit-3)//2
164 | len_limit_t = (len_limit-3)-len_limit_s
165 | # print('len_limit_s: ', len_limit_s)
166 | # print('len_limit_t: ', len_limit_t)
167 | for source, target in tqdm(zip(sources, targets), total=len(sources)):
168 | # tokenize before clipping
169 | source = tokenizer.encode(source)[1:-1]
170 | target = tokenizer.encode(target)[1:-1]
171 |
172 | # clip the sentences if too long
173 | # TODO: different strategies to clip long sequences
174 | if clip == 'head' and len(source)+len(target)+3 > len_limit:
175 | if len(source)>len_limit_s and len(target)>len_limit_t:
176 | source = source[0:len_limit_s]
177 | target = target[0:len_limit_t]
178 | elif len(source)>len_limit_s:
179 | source = source[0:len_limit-3-len(target)]
180 | elif len(target)>len_limit_t:
181 | target = target[0:len_limit-3-len(source)]
182 |
183 | if clip == 'tail' and len(source)+len(target)+3 > len_limit:
184 | if len(source)>len_limit_s and len(target)>len_limit_t:
185 | source = source[-len_limit_s:]
186 | target = target[-len_limit_t:]
187 | elif len(source)>len_limit_s:
188 | source = source[-(len_limit-3-len(target)):]
189 | elif len(target)>len_limit_t:
190 | target = target[-(len_limit-3-len(source)):]
191 |
192 | # check if the total length is within the limit
193 | assert len(source)+len(target)+3 <= len_limit
194 |
195 | # [CLS]:101, [SEP]:102
196 | input_ids = [101] + source + [102] + target + [102]
197 | input_types = [0]*(len(source)+2) + [1]*(len(target)+1)
198 |
199 | assert len(input_ids) <= len_limit and len(input_types) <= len_limit
200 | self.total_input_ids.append(input_ids)
201 | self.total_input_types.append(input_types)
202 |
203 | self.max_input_len = max([len(s) for s in self.total_input_ids])
204 | print("max length: ", self.max_input_len)
205 |
206 | def __len__(self):
207 | return len(self.total_input_ids)
208 |
209 | def __getitem__(self, idx):
210 | if self.is_train:
211 | input_ids = pad_to_maxlen(self.total_input_ids[idx], self.max_input_len)
212 | input_types = pad_to_maxlen(self.total_input_types[idx], self.max_input_len)
213 | label = int(self.labels[idx])
214 | sample_type = int(self.sample_types[idx])
215 | # print(len(input_ids), len(input_types), label)
216 | return torch.LongTensor(input_ids), torch.LongTensor(input_types), torch.LongTensor([label]), sample_type
217 |
218 | else:
219 | input_ids = pad_to_maxlen(self.total_input_ids[idx], self.max_input_len)
220 | input_types = pad_to_maxlen(self.total_input_types[idx], self.max_input_len)
221 | index = self.ids[idx]
222 | sample_type = int(self.sample_types[idx])
223 | return torch.LongTensor(input_ids), torch.LongTensor(input_types), index, sample_type
224 |
225 | # NOT CURRENTLY IN USE
226 | # template for the dataset of multiple task types
227 | # compatible with training code by changing task_num
228 | class SentencePairDatasetWithMultiType(Dataset):
229 | def __init__(self, file_dir, is_train, tokenizer_config, shuffle_order=False, aug_data=False, len_limit=512, clip='head'):
230 | self.is_train = is_train
231 | self.shuffle_order = shuffle_order
232 | self.aug_data = aug_data
233 | self.total_input_ids = []
234 | self.total_input_types = []
235 | self.sample_types = []
236 |
237 | # use AutoTokenizer instead of BertTokenizer to support spiece.model (AlbertTokenizer-like)
238 | tokenizer = AutoTokenizer.from_pretrained(tokenizer_config)
239 |
240 | # read json lines and convert to dict / df
241 | lines = []
242 | for single_file_dir in file_dir:
243 | with open(single_file_dir, 'r', encoding='utf-8') as f_in:
244 | content = f_in.readlines()
245 | for item in content:
246 | line = json.loads(item.strip())
247 | # BUG FIXED, order MATTERS!
248 | # manually add key 'type' to distinguish the origin of samples
249 | # 0 for A, 1 for B
250 | if 'A' in single_file_dir:
251 | if self.is_train:
252 | line['label'] = line.pop('labelA')
253 | # assign type according to task names
254 | if '短短' in single_file_dir:
255 | line['type'] = 0
256 | elif '短长' in single_file_dir:
257 | line['type'] = 2
258 | else:
259 | line['type'] = 4
260 | else:
261 | if self.is_train:
262 | line['label'] = line.pop('labelB')
263 | # assign type according to task names
264 | if '短短' in single_file_dir:
265 | line['type'] = 1
266 | elif '短长' in single_file_dir:
267 | line['type'] = 3
268 | else:
269 | line['type'] = 5
270 | lines.append(line)
271 | print(single_file_dir, len(lines))
272 | content = pd.DataFrame(lines)
273 | # print(content.head())
274 | content.columns = ['source', 'target', 'label', 'type']
275 |
276 | # utilize labelB=1-->A positive, labelA=0-->B negative
277 | if self.is_train and self.aug_data:
278 | print("augmenting data...")
279 | content = augment_data(content)
280 |
281 | sources = content['source'].values.tolist()
282 | targets = content['target'].values.tolist()
283 |
284 | self.sample_types = content['type'].values.tolist()
285 | if self.is_train:
286 | self.labels = content['label'].values.tolist()
287 | else:
288 | self.ids = content['label'].values.tolist()
289 |
290 | # shuffle_order is only allowed for training mode
291 | if self.shuffle_order and self.is_train:
292 | sources += content['target'].values.tolist()
293 | targets += content['source'].values.tolist()
294 | self.labels += self.labels
295 | self.sample_types += self.sample_types
296 |
297 | len_limit_s = (len_limit-3)//2
298 | len_limit_t = (len_limit-3)-len_limit_s
299 | # print('len_limit_s: ', len_limit_s)
300 | # print('len_limit_t: ', len_limit_t)
301 | for source, target in tqdm(zip(sources, targets), total=len(sources)):
302 | # tokenize before clipping
303 | source = tokenizer.encode(source)[1:-1]
304 | target = tokenizer.encode(target)[1:-1]
305 |
306 | # clip the sentences if too long
307 | # TODO: different strategies to clip long sequences
308 | if clip == 'head' and len(source)+len(target)+3 > len_limit:
309 | if len(source)>len_limit_s and len(target)>len_limit_t:
310 | source = source[0:len_limit_s]
311 | target = target[0:len_limit_t]
312 | elif len(source)>len_limit_s:
313 | source = source[0:len_limit-3-len(target)]
314 | elif len(target)>len_limit_t:
315 | target = target[0:len_limit-3-len(source)]
316 |
317 | if clip == 'tail' and len(source)+len(target)+3 > len_limit:
318 | if len(source)>len_limit_s and len(target)>len_limit_t:
319 | source = source[-len_limit_s:]
320 | target = target[-len_limit_t:]
321 | elif len(source)>len_limit_s:
322 | source = source[-(len_limit-3-len(target)):]
323 | elif len(target)>len_limit_t:
324 | target = target[-(len_limit-3-len(source)):]
325 |
326 | # check if the total length is within the limit
327 | assert len(source)+len(target)+3 <= len_limit
328 |
329 | # [CLS]:101, [SEP]:102
330 | input_ids = [101] + source + [102] + target + [102]
331 | input_types = [0]*(len(source)+2) + [1]*(len(target)+1)
332 |
333 | assert len(input_ids) <= len_limit and len(input_types) <= len_limit
334 | self.total_input_ids.append(input_ids)
335 | self.total_input_types.append(input_types)
336 |
337 | self.max_input_len = max([len(s) for s in self.total_input_ids])
338 | print("max length: ", self.max_input_len)
339 |
340 | def __len__(self):
341 | return len(self.total_input_ids)
342 |
343 | def __getitem__(self, idx):
344 | if self.is_train:
345 | input_ids = pad_to_maxlen(self.total_input_ids[idx], self.max_input_len)
346 | input_types = pad_to_maxlen(self.total_input_types[idx], self.max_input_len)
347 | label = int(self.labels[idx])
348 | sample_type = int(self.sample_types[idx])
349 | # print(len(input_ids), len(input_types), label)
350 | return torch.LongTensor(input_ids), torch.LongTensor(input_types), torch.LongTensor([label]), sample_type
351 |
352 | else:
353 | input_ids = pad_to_maxlen(self.total_input_ids[idx], self.max_input_len)
354 | input_types = pad_to_maxlen(self.total_input_types[idx], self.max_input_len)
355 | index = self.ids[idx]
356 | sample_type = int(self.sample_types[idx])
357 | return torch.LongTensor(input_ids), torch.LongTensor(input_types), index, sample_type
--------------------------------------------------------------------------------
/决赛提交/sohu_matching/src/infer.py:
--------------------------------------------------------------------------------
1 | from model import BertClassifierSingleModel, NezhaClassifierSingleModel, SBERTSingleModel, SNEZHASingleModel, BertClassifierTextCNNSingleModel
2 |
3 | from data import SentencePairDatasetWithType, SentencePairDatasetForSBERT
4 | from config import Config
5 |
6 | import torch
7 | import torch.nn as nn
8 | from torch.utils.data import DataLoader
9 |
10 | import numpy as np
11 | from sklearn import metrics
12 | from tqdm import tqdm
13 |
14 | import os
15 | os.environ['CUDA_VISIBLE_DEVICES'] = '1,2'
16 |
17 | def infer(model, device, dev_dataloader, test_dataloader, search_thres=True, threshold_fixed_a=0.5, threshold_fixed_b=0.5, save_valid=True):
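| # collects per-task probabilities on the dev set, optionally searches the best
| # decision thresholds there, and then predicts on the test set with them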
18 | print("Inferring")
19 | model.eval()
20 |
21 | if torch.cuda.device_count() > 1:
22 | model = torch.nn.DataParallel(model)
23 |
24 | total_gt_a, total_preds_a, total_probs_a = [], [], []
25 | total_gt_b, total_preds_b, total_probs_b = [], [], []
26 |
27 | print("Model running on dev set...")
28 | for idx, batch in enumerate(tqdm(dev_dataloader)):
29 | input_ids, input_types, labels, types = batch
30 | input_ids = input_ids.to(device)
31 | input_types = input_types.to(device)
32 | # labels should be flattened
33 | labels = labels.to(device).view(-1)
34 |
35 | with torch.no_grad():
36 | all_probs = model(input_ids, input_types)
37 | num_tasks = len(all_probs)
38 |
39 | all_masks = [(types==task_id).numpy() for task_id in range(num_tasks)]
40 | all_output = [all_probs[task_id][all_masks[task_id]] for task_id in range(num_tasks)]
41 | all_labels = [labels[all_masks[task_id]] for task_id in range(num_tasks)]
42 |
43 | all_gt = [all_labels[task_id].cpu().numpy().tolist() if all_masks[task_id].sum()!=0 else [] for task_id in range(num_tasks)]
44 | all_preds = [all_output[task_id].argmax(axis=1).cpu().numpy().tolist() if all_masks[task_id].sum()!=0 else [] for task_id in range(num_tasks)]
45 |
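| # even task ids (0, 2, 4) belong to task A, odd ids to task B;
| # probs keep P(positive), i.e. the last column of the softmax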
46 | gt_a, preds_a, probs_a = [], [], []
47 | for task_id in range(0, num_tasks, 2):
48 | gt_a += all_gt[task_id]
49 | preds_a += all_preds[task_id]
50 | probs_a += [prob[-1] for prob in nn.functional.softmax(all_output[task_id], dim=1).cpu().numpy().tolist()]
51 |
52 | gt_b, preds_b, probs_b = [], [], []
53 | for task_id in range(1, num_tasks, 2):
54 | gt_b += all_gt[task_id]
55 | preds_b += all_preds[task_id]
56 | probs_b += [prob[-1] for prob in nn.functional.softmax(all_output[task_id], dim=1).cpu().numpy().tolist()]
57 |
58 | total_gt_a += gt_a
59 | total_preds_a += preds_a
60 | total_probs_a += probs_a
61 |
62 | total_gt_b += gt_b
63 | total_preds_b += preds_b
64 | total_probs_b += probs_b
65 |
66 | if search_thres:
67 | # search for the optimal threshold
68 | print("Searching for the best threshold on valid dataset...")
69 | thresholds = np.arange(0.2, 0.9, 0.01)
70 | fscore_a = np.zeros(shape=(len(thresholds)))
71 | fscore_b = np.zeros(shape=(len(thresholds)))
72 |         print('Number of candidate thresholds: {}'.format(len(thresholds)))
73 |
74 | print("Original F1 Score for Task A: {}".format(str(metrics.f1_score(total_gt_a, total_preds_a, zero_division=0))))
75 | if len(total_gt_a) != 0:
76 | print("\tClassification Report\n")
77 | print(metrics.classification_report(total_gt_a, total_preds_a))
78 |
79 | print("Original F1 Score for Task B: {}".format(str(metrics.f1_score(total_gt_b, total_preds_b, zero_division=0))))
80 | if len(total_gt_b) != 0:
81 | print("\tClassification Report\n")
82 | print(metrics.classification_report(total_gt_b, total_preds_b))
83 |
84 | for index, thres in enumerate(tqdm(thresholds)):
85 | y_pred_prob_a = (np.array(total_probs_a) > thres).astype('int')
86 | fscore_a[index] = metrics.f1_score(total_gt_a, y_pred_prob_a.tolist(), zero_division=0)
87 |
88 | y_pred_prob_b = (np.array(total_probs_b) > thres).astype('int')
89 | fscore_b[index] = metrics.f1_score(total_gt_b, y_pred_prob_b.tolist(), zero_division=0)
90 |
91 | # record the optimal threshold for task A
92 | # print(fscore_a)
93 | index_a = np.argmax(fscore_a)
94 | threshold_opt_a = round(thresholds[index_a], ndigits=4)
95 | f1_score_opt_a = round(fscore_a[index_a], ndigits=6)
96 | print('Best Threshold for Task A: {} with F-Score: {}'.format(threshold_opt_a, f1_score_opt_a))
97 | # print("\nThreshold Classification Report\n")
98 | # print(metrics.classification_report(total_gt_a, (np.array(total_probs_a) > threshold_opt_a).astype('int').tolist()))
99 |
100 | # record the optimal threshold for task B
101 | index_b = np.argmax(fscore_b)
102 | threshold_opt_b = round(thresholds[index_b], ndigits=4)
103 | f1_score_opt_b = round(fscore_b[index_b], ndigits=6)
104 | print('Best Threshold for Task B: {} with F-Score: {}'.format(threshold_opt_b, f1_score_opt_b))
105 | # print("\nThreshold Classification Report\n")
106 | # print(metrics.classification_report(total_gt_b, (np.array(total_probs_b) > threshold_opt_b).astype('int').tolist()))
107 |
108 | if save_valid:
109 | y_pred_prob_a = (np.array(total_probs_a) > threshold_opt_a).astype('int')
110 | y_pred_prob_b = (np.array(total_probs_b) > threshold_opt_b).astype('int')
111 | # index of valid and valid_rematch
112 | # ssa, sla, lla = y_pred_prob_a[0:3395], y_pred_prob_a[3395:7681], y_pred_prob_a[7681:]
113 | # gt_ssa, gt_sla, gt_lla = total_gt_a[0:3395], total_gt_a[3395:7681], total_gt_a[7681:]
114 | # ssb, slb, llb = y_pred_prob_b[0:3393], y_pred_prob_b[3393:7684], y_pred_prob_b[7684:]
115 | # gt_ssb, gt_slb, gt_llb = total_gt_b[0:3393], total_gt_b[3393:7684], total_gt_b[7684:]
116 |
117 | # valid_rematch only
118 | ssa, sla, lla = y_pred_prob_a[0:1750], y_pred_prob_a[1750:4380], y_pred_prob_a[4380:]
119 | gt_ssa, gt_sla, gt_lla = total_gt_a[0:1750], total_gt_a[1750:4380], total_gt_a[4380:]
120 | ssb, slb, llb = y_pred_prob_b[0:1750], y_pred_prob_b[1750:4385], y_pred_prob_b[4385:]
121 | gt_ssb, gt_slb, gt_llb = total_gt_b[0:1750], total_gt_b[1750:4385], total_gt_b[4385:]
122 | print("f1 on ssa: ", metrics.f1_score(gt_ssa, ssa))
123 | print("f1 on sla: ", metrics.f1_score(gt_sla, sla))
124 | print("f1 on lla: ", metrics.f1_score(gt_lla, lla))
125 | print("f1 on ssb: ", metrics.f1_score(gt_ssb, ssb))
126 | print("f1 on slb: ", metrics.f1_score(gt_slb, slb))
127 | print("f1 on llb: ", metrics.f1_score(gt_llb, llb))
128 |
129 | np.save('../valid_output/{}_pred_a.npy'.format(model_type), y_pred_prob_a)
130 | np.save('../valid_output/{}_pred_b.npy'.format(model_type), y_pred_prob_b)
131 | np.save('../valid_output/gt_a.npy', np.array(total_gt_a))
132 | np.save('../valid_output/gt_b.npy', np.array(total_gt_b))
133 |
134 | total_ids_a, total_probs_a = [], []
135 | total_ids_b, total_probs_b = [], []
136 | for idx, batch in enumerate(tqdm(test_dataloader)):
137 | input_ids, input_types, ids, types = batch
138 | input_ids = input_ids.to(device)
139 | input_types = input_types.to(device)
140 |
141 | # the probs given by the model, without grads
142 | with torch.no_grad():
143 | # probs_a, probs_b = model(input_ids, input_types)
144 | # mask_a, mask_b = (types==0).numpy(), (types==1).numpy()
145 |
146 | all_probs = model(input_ids, input_types)
147 | num_tasks = len(all_probs)
148 | # mask_a, mask_b = (types==0).numpy(), (types==1).numpy()
149 | all_masks = [(types==task_id).numpy() for task_id in range(num_tasks)]
150 | all_output = [all_probs[task_id][all_masks[task_id]] for task_id in range(num_tasks)]
151 |
152 | total_ids_a += [id for id in ids if id.endswith('a')]
153 | total_ids_b += [id for id in ids if id.endswith('b')]
154 |
155 |         probs_a = []
156 | for task_id in range(0, num_tasks, 2):
157 | probs_a += [prob[-1] for prob in nn.functional.softmax(all_output[task_id], dim=1).cpu().numpy().tolist()]
158 |
159 |         probs_b = []
160 | for task_id in range(1, num_tasks, 2):
161 | probs_b += [prob[-1] for prob in nn.functional.softmax(all_output[task_id], dim=1).cpu().numpy().tolist()]
162 |
163 | total_probs_a += probs_a
164 | total_probs_b += probs_b
165 |
166 |     # positive if the prob passes the fixed threshold (default 0.5)
167 | total_fixed_preds_a = (np.array(total_probs_a) > threshold_fixed_a).astype('int').tolist()
168 | total_fixed_preds_b = (np.array(total_probs_b) > threshold_fixed_b).astype('int').tolist()
169 |
170 | if search_thres:
171 | # positive if the prob passes the optimal threshold
172 | total_preds_a = (np.array(total_probs_a) > threshold_opt_a).astype('int').tolist()
173 | total_preds_b = (np.array(total_probs_b) > threshold_opt_b).astype('int').tolist()
174 | else:
175 | total_preds_a = None
176 | total_preds_b = None
177 |
178 | return total_ids_a, total_preds_a, total_fixed_preds_a, \
179 | total_ids_b, total_preds_b, total_fixed_preds_b
180 |
181 | if __name__=='__main__':
182 | config = Config()
183 | device = config.device
184 | dummy_pretrained = config.dummy_pretrained
185 | model_type = config.infer_model_name
186 |
187 | save_dir = config.infer_model_dir
188 | model_name = config.infer_model_name
189 | hidden_size = config.hidden_size
190 | output_dir= config.infer_output_dir
191 | output_filename = config.infer_output_filename
192 | data_dir = config.data_dir
193 | task_a = ['短短匹配A类', '短长匹配A类', '长长匹配A类']
194 | task_b = ['短短匹配B类', '短长匹配B类', '长长匹配B类']
195 | task_type = config.infer_task_type
196 |
197 | infer_bs = config.infer_bs
198 | search_thres = config.infer_search_thres
199 | threshold_fixed_a = config.infer_fixed_thres_a
200 | threshold_fixed_b = config.infer_fixed_thres_b
201 |     # method for clipping long sequences, 'head' or 'tail'
202 | clip_method = config.infer_clip_method
203 |
204 | dev_data_dir, test_data_dir = [], []
205 | if 'a' in task_type:
206 | for task in task_a:
207 | # dev_data_dir.append(data_dir + task + '/valid.txt')
208 | dev_data_dir.append(data_dir + task + '/valid_rematch.txt')
209 | test_data_dir.append(data_dir + task + '/test_with_id_rematch.txt')
210 | if 'b' in task_type:
211 | for task in task_b:
212 | # dev_data_dir.append(data_dir + task + '/valid.txt')
213 | dev_data_dir.append(data_dir + task + '/valid_rematch.txt')
214 | test_data_dir.append(data_dir + task + '/test_with_id_rematch.txt')
215 |
216 | print("Loading Bert Model from {}...".format(save_dir + model_name))
217 | # distinguish model architectures or pretrained models according to model_type
218 | if 'sbert' in model_type.lower():
219 | print("Using SentenceBERT model and dataset")
220 | if 'nezha' in model_type.lower():
221 | model = SNEZHASingleModel(bert_dir=dummy_pretrained, from_pretrained=False, hidden_size=hidden_size)
222 | else:
223 | model = SBERTSingleModel(bert_dir=dummy_pretrained, from_pretrained=False, hidden_size=hidden_size)
224 |
225 | model_dict = torch.load(save_dir + model_name)
226 |
227 |         # weights are saved under a 'module.' prefix when trained with DataParallel
228 | model.load_state_dict({k.replace('module.','') : v for k, v in model_dict.items()})
229 | model.to(device)
230 |
231 | print("Loading Dev Data...")
232 | dev_dataset = SentencePairDatasetForSBERT(dev_data_dir, True, dummy_pretrained, clip=clip_method)
233 | dev_dataloader = DataLoader(dev_dataset, batch_size=infer_bs, shuffle=False)
234 |
235 | print("Loading Test Data...")
236 |         # for the test dataset, is_train is False, so the dataset returns ids instead of labels
237 | test_dataset = SentencePairDatasetForSBERT(test_data_dir, False, dummy_pretrained, clip=clip_method)
238 | test_dataloader = DataLoader(test_dataset, batch_size=infer_bs, shuffle=False)
239 |
240 | else:
241 | print("Using BERT model and dataset")
242 | if 'nezha' in model_type.lower():
243 | print("Using NEZHA pretrained model")
244 | model = NezhaClassifierSingleModel(bert_dir=dummy_pretrained, from_pretrained=False, hidden_size=hidden_size)
245 | elif 'cnn' in model_type.lower():
246 | print("Adding TextCNN after BERT output")
247 | model = BertClassifierTextCNNSingleModel(bert_dir=dummy_pretrained, from_pretrained=False, hidden_size=hidden_size)
248 | else:
249 | print("Using conventional BERT model with linears")
250 | model = BertClassifierSingleModel(bert_dir=dummy_pretrained, from_pretrained=False, hidden_size=hidden_size)
251 |
252 | model_dict = torch.load(save_dir + model_name)
253 |         # weights are saved under a 'module.' prefix when trained with DataParallel
254 | model.load_state_dict({k.replace('module.','') : v for k, v in model_dict.items()})
255 | model.to(device)
256 |
257 | # model_dict = torch.load(save_dir + model_name).module.state_dict()
258 | # model.load_state_dict(model_dict)
259 | # model.to(device)
260 |
261 | print("Loading Dev Data...")
262 | dev_dataset = SentencePairDatasetWithType(dev_data_dir, True, dummy_pretrained, clip=clip_method)
263 | dev_dataloader = DataLoader(dev_dataset, batch_size=infer_bs, shuffle=False)
264 |
265 | print("Loading Test Data...")
266 |         # for the test dataset, is_train is False, so the dataset returns ids instead of labels
267 | test_dataset = SentencePairDatasetWithType(test_data_dir, False, dummy_pretrained, clip=clip_method)
268 | test_dataloader = DataLoader(test_dataset, batch_size=infer_bs, shuffle=False)
269 |
270 | total_ids_a, total_preds_a, total_fixed_preds_a, total_ids_b, total_preds_b, total_fixed_preds_b = infer(model, device, dev_dataloader, test_dataloader, search_thres, threshold_fixed_a, threshold_fixed_b)
271 |
272 | with open(output_dir + 'fixed_' + output_filename, 'w') as f_out:
273 | for id, pred in zip(total_ids_a, total_fixed_preds_a):
274 | f_out.writelines(str(id) + ',' + str(pred) + '\n')
275 | for id, pred in zip(total_ids_b, total_fixed_preds_b):
276 | f_out.writelines(str(id) + ',' + str(pred) + '\n')
277 |
278 | if total_preds_a is not None:
279 | with open(output_dir + output_filename, 'w') as f_out:
280 | for id, pred in zip(total_ids_a, total_preds_a):
281 | f_out.writelines(str(id) + ',' + str(pred) + '\n')
282 | for id, pred in zip(total_ids_b, total_preds_b):
283 | f_out.writelines(str(id) + ',' + str(pred) + '\n')
--------------------------------------------------------------------------------
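The threshold search in infer() above sweeps cutoffs from 0.2 to 0.9 over the positive-class softmax probabilities and keeps the argmax-F1 cutoff per task. A self-contained sketch of the same routine (the helper name and defaults are illustrative, not part of this repo):

import numpy as np
from sklearn.metrics import f1_score

def search_threshold(gt, probs, lo=0.2, hi=0.9, step=0.01):
    # gt: 0/1 ground-truth labels; probs: P(positive) taken from the softmax column
    probs = np.asarray(probs)
    thresholds = np.arange(lo, hi, step)
    scores = [f1_score(gt, (probs > t).astype(int), zero_division=0) for t in thresholds]
    best = int(np.argmax(scores))
    return round(float(thresholds[best]), 4), round(float(scores[best]), 6)

# e.g. threshold_opt_a, f1_opt_a = search_threshold(total_gt_a, total_probs_a)
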
/决赛提交/sohu_matching/src/infer_final.py:
--------------------------------------------------------------------------------
1 | from model import BertClassifierSingleModel, NezhaClassifierSingleModel, SBERTSingleModel, SNEZHASingleModel, BertClassifierTextCNNSingleModel
2 | from data import SentencePairDatasetWithType, SentencePairDatasetForSBERT
3 |
4 | import torch
5 | import torch.nn as nn
6 | from torch.utils.data import DataLoader
7 | import numpy as np
8 | from tqdm import tqdm
9 | from argparse import ArgumentParser
10 |
11 | import os
12 | os.environ['CUDA_VISIBLE_DEVICES'] = '0'
13 | import time
14 |
15 | def infer_final(model, device, test_dataloader, threshold_fixed_a=0.5, threshold_fixed_b=0.5):
16 | print("Inferring for final stage")
17 | model.eval()
18 |
19 | # as only one GPU is available for the final stage
20 | # if torch.cuda.device_count() > 1:
21 | # model = torch.nn.DataParallel(model)
22 |
23 | total_ids_a, total_probs_a = [], []
24 | total_ids_b, total_probs_b = [], []
25 | for idx, batch in enumerate(tqdm(test_dataloader)):
26 | input_ids, input_types, ids, types = batch
27 | input_ids = input_ids.to(device)
28 | input_types = input_types.to(device)
29 |
30 | # the probs given by the model, without grads
31 | with torch.no_grad():
32 | all_probs = model(input_ids, input_types)
33 | num_tasks = len(all_probs)
34 |
35 | all_masks = [(types==task_id).numpy() for task_id in range(num_tasks)]
36 | all_output = [all_probs[task_id][all_masks[task_id]] for task_id in range(num_tasks)]
37 |
38 | total_ids_a += [id for id in ids if id.endswith('a')]
39 | total_ids_b += [id for id in ids if id.endswith('b')]
40 |
41 | probs_a, probs_b = [], []
42 | for task_id in range(0, num_tasks, 2):
43 | probs_a += [prob[-1] for prob in nn.functional.softmax(all_output[task_id], dim=1).cpu().numpy().tolist()]
44 |
45 | for task_id in range(1, num_tasks, 2):
46 | probs_b += [prob[-1] for prob in nn.functional.softmax(all_output[task_id], dim=1).cpu().numpy().tolist()]
47 |
48 | total_probs_a += probs_a
49 | total_probs_b += probs_b
50 |
51 |     # positive if the prob passes the fixed per-model threshold
52 | total_fixed_preds_a = (np.array(total_probs_a) > threshold_fixed_a).astype('int').tolist()
53 | total_fixed_preds_b = (np.array(total_probs_b) > threshold_fixed_b).astype('int').tolist()
54 |
55 | total_preds_a = None
56 | total_preds_b = None
57 |
58 | return total_ids_a, total_preds_a, total_fixed_preds_a, \
59 | total_ids_b, total_preds_b, total_fixed_preds_b
60 |
61 | if __name__=='__main__':
62 | s_time = time.time()
63 |
64 | parser = ArgumentParser()
65 |     parser.add_argument("-i","--input", type=str, required=True, help="input file")
66 |     parser.add_argument("-o","--output", type=str, required=True, help="output file")
67 | args = parser.parse_args()
68 | input_dir = args.input
69 | output_dir= args.output
70 |
71 | device = 'cuda'
72 | data_dir = '../data/sohu2021_open_data/'
73 | save_dir = '../checkpoints/rematch/'
74 | result_dir = '../results/final/'
75 | bert_tokenizer_config = '../data/dummy_bert/' # as NEZHA, MACBERT and ROBERTA share the same tokenizer vocabulary
76 |     ernie_tokenizer_config = '../data/dummy_ernie/' # ERNIE has its own vocabulary, so its dataset must be loaded separately
77 |
78 | # only use test dataloader for final stage
79 | # the test file will be in one file
80 | test_data_dir = [input_dir]
81 | bert_model_configs = [
82 | # model_name, dummy_pretrained, threshold_fixed_a, threshold_fixed_b, infer_bs
83 | ('0520_roberta_80k_same_lr_zy_epoch_1_ab_loss', '../data/dummy_bert/', 0.4, 0.3, 128),
84 | ('0518_macbert_same_lr_epoch_1_ab_loss', '../data/dummy_bert/', 0.37, 0.39, 128),
85 | ('0523_roberta_dataaug_epoch_0_ab_loss', '../data/dummy_bert/', 0.41, 0.48, 128)
86 | ]
87 |
88 | ernie_model_configs = [
89 | ('0523_ernie_epoch_1_ab_loss', '../data/dummy_ernie/', 0.42, 0.39, 128),
90 | ]
91 |
92 | sbert_model_configs = [
93 | ('0520_roberta_sbert_same_lr_epoch_1_ab_loss', '../data/dummy_bert/', 0.4, 0.36, 128)
94 | ]
95 |
96 | # We will first infer for the bert-style models
97 | if len(bert_model_configs) != 0:
98 | print("Loading Test Data for BERT models...")
99 | test_dataset = SentencePairDatasetWithType(test_data_dir, False, bert_tokenizer_config)
100 |
101 | for model_config in bert_model_configs:
102 | model_name, dummy_pretrained, threshold_fixed_a, threshold_fixed_b, infer_bs = model_config
103 | test_dataloader = DataLoader(test_dataset, batch_size=infer_bs, shuffle=False)
104 | print("Loading Bert Model from {}...".format(save_dir + model_name))
105 | # distinguish model architectures or pretrained models according to model_type
106 | if 'nezha' in model_name.lower():
107 | print("Using NEZHA pretrained model")
108 | model = NezhaClassifierSingleModel(bert_dir=dummy_pretrained, from_pretrained=False)
109 | elif 'cnn' in model_name.lower():
110 | print("Adding TextCNN after BERT output")
111 | model = BertClassifierTextCNNSingleModel(bert_dir=dummy_pretrained, from_pretrained=False)
112 | else:
113 | print("Using conventional BERT model with linears")
114 | model = BertClassifierSingleModel(bert_dir=dummy_pretrained, from_pretrained=False)
115 |
116 | model_dict = torch.load(save_dir + model_name)
117 |             # weights are saved under a 'module.' prefix when trained with DataParallel
118 | model.load_state_dict({k.replace('module.','') : v for k, v in model_dict.items()})
119 | model.to(device)
120 |
121 | # total_ids_a, total_preds_a, total_fixed_preds_a, total_ids_b, total_preds_b, total_fixed_preds_b = infer(model, device, dev_dataloader, test_dataloader, search_thres, threshold_fixed_a, threshold_fixed_b)
122 | total_ids_a, total_preds_a, total_fixed_preds_a, total_ids_b, total_preds_b, total_fixed_preds_b = infer_final(model, device, test_dataloader, threshold_fixed_a, threshold_fixed_b)
123 |
124 | with open(result_dir + 'final_' + '{}.csv'.format(model_name), 'w') as f_out:
125 | for id, pred in zip(total_ids_a, total_fixed_preds_a):
126 | f_out.writelines(str(id) + ',' + str(pred) + '\n')
127 | for id, pred in zip(total_ids_b, total_fixed_preds_b):
128 | f_out.writelines(str(id) + ',' + str(pred) + '\n')
129 |
130 | # infer for the ernie models, dataset should be reloaded for ernie's vocabulary
131 | if len(ernie_model_configs) != 0:
132 | print("Loading Test Data for ERNIE models...")
133 | test_dataset = SentencePairDatasetWithType(test_data_dir, False, ernie_tokenizer_config)
134 |
135 | for model_config in ernie_model_configs:
136 | model_name, dummy_pretrained, threshold_fixed_a, threshold_fixed_b, infer_bs = model_config
137 | test_dataloader = DataLoader(test_dataset, batch_size=infer_bs, shuffle=False)
138 | print("Loading Bert Model from {}...".format(save_dir + model_name))
139 | # distinguish model architectures or pretrained models according to model_type
140 | if 'nezha' in model_name.lower():
141 | print("Using NEZHA pretrained model")
142 | model = NezhaClassifierSingleModel(bert_dir=dummy_pretrained, from_pretrained=False)
143 | elif 'cnn' in model_name.lower():
144 | print("Adding TextCNN after BERT output")
145 | model = BertClassifierTextCNNSingleModel(bert_dir=dummy_pretrained, from_pretrained=False)
146 | else:
147 | print("Using conventional BERT model with linears")
148 | model = BertClassifierSingleModel(bert_dir=dummy_pretrained, from_pretrained=False)
149 |
150 | model_dict = torch.load(save_dir + model_name)
151 |             # weights are saved under a 'module.' prefix when trained with DataParallel
152 | model.load_state_dict({k.replace('module.','') : v for k, v in model_dict.items()})
153 | model.to(device)
154 |
155 | # total_ids_a, total_preds_a, total_fixed_preds_a, total_ids_b, total_preds_b, total_fixed_preds_b = infer(model, device, dev_dataloader, test_dataloader, search_thres, threshold_fixed_a, threshold_fixed_b)
156 | total_ids_a, total_preds_a, total_fixed_preds_a, total_ids_b, total_preds_b, total_fixed_preds_b = infer_final(model, device, test_dataloader, threshold_fixed_a, threshold_fixed_b)
157 |
158 | with open(result_dir + 'final_' + '{}.csv'.format(model_name), 'w') as f_out:
159 | for id, pred in zip(total_ids_a, total_fixed_preds_a):
160 | f_out.writelines(str(id) + ',' + str(pred) + '\n')
161 | for id, pred in zip(total_ids_b, total_fixed_preds_b):
162 | f_out.writelines(str(id) + ',' + str(pred) + '\n')
163 |
164 | # infer for SBERT models
165 | if len(sbert_model_configs) != 0:
166 | print("Loading Test Data for SBERT models...")
167 |         # for the test dataset, is_train is False, so the dataset returns ids instead of labels
168 | test_dataset = SentencePairDatasetForSBERT(test_data_dir, False, bert_tokenizer_config)
169 |
170 | for model_config in sbert_model_configs:
171 | model_name, dummy_pretrained, threshold_fixed_a, threshold_fixed_b, infer_bs = model_config
172 | test_dataloader = DataLoader(test_dataset, batch_size=infer_bs, shuffle=False)
173 | print("Loading SentenceBert Model from {}...".format(save_dir + model_name))
174 | # distinguish model architectures or pretrained models according to model_type
175 | if 'nezha' in model_name.lower():
176 | model = SNEZHASingleModel(bert_dir=dummy_pretrained, from_pretrained=False)
177 | else:
178 | model = SBERTSingleModel(bert_dir=dummy_pretrained, from_pretrained=False)
179 |
180 | model_dict = torch.load(save_dir + model_name)
181 |             # weights are saved under a 'module.' prefix when trained on multiple GPUs with DataParallel
182 | model.load_state_dict({k.replace('module.','') : v for k, v in model_dict.items()})
183 | model.to(device)
184 |
185 | # total_ids_a, total_preds_a, total_fixed_preds_a, total_ids_b, total_preds_b, total_fixed_preds_b = infer(model, device, dev_dataloader, test_dataloader, search_thres, threshold_fixed_a, threshold_fixed_b)
186 | total_ids_a, total_preds_a, total_fixed_preds_a, total_ids_b, total_preds_b, total_fixed_preds_b = infer_final(model, device, test_dataloader, threshold_fixed_a, threshold_fixed_b)
187 |
188 | with open(result_dir + 'final_' + '{}.csv'.format(model_name), 'w') as f_out:
189 | for id, pred in zip(total_ids_a, total_fixed_preds_a):
190 | f_out.writelines(str(id) + ',' + str(pred) + '\n')
191 | for id, pred in zip(total_ids_b, total_fixed_preds_b):
192 | f_out.writelines(str(id) + ',' + str(pred) + '\n')
193 |
194 | # finally, merge all the output files in output_dir
195 | print("Merging the model outputs...")
196 | result_list = [filename for filename in os.listdir(result_dir) if filename.endswith('.csv')]
197 | result_dict = {}
198 | for name in result_list:
199 | with open(result_dir + name, "r", encoding="utf-8") as fr:
200 | for line in fr:
201 | words = line.strip().split(",")
202 | if words[0] == "id":
203 | continue
204 | if words[0] not in result_dict:
205 | result_dict[words[0]] = [words[1]]
206 | else:
207 | result_dict[words[0]].append(words[1])
208 |
209 | # merging the outputs into final csv file
210 | with open(output_dir, "w", encoding="utf-8") as fw:
211 | fw.write("id,label"+"\n")
212 | for k, v in result_dict.items():
213 | tmp = {}
214 | for ele in v:
215 | if ele in tmp:
216 | tmp[ele] += 1
217 | else:
218 | tmp[ele] = 1
219 | tmp = sorted(tmp.items(), key=lambda d: d[1], reverse=True)
220 | fw.write(",".join([k, tmp[0][0]]) + "\n")
221 |
222 | e_time = time.time()
223 | print("Time taken: ", e_time - s_time)
--------------------------------------------------------------------------------
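The merge step at the end of infer_final.py (and merge_result.py below) is a per-id majority vote over the per-model CSV outputs. The sorted()-based counting is equivalent to collections.Counter; a compact sketch (function name illustrative):

from collections import Counter

def majority_vote(votes_by_id):
    # votes_by_id: {id: ['0', '1', '1', ...]} -> {id: winning label}
    # sorted() and most_common() are both stable, so ties resolve to the
    # label seen first, matching the behavior of the loop above
    return {k: Counter(v).most_common(1)[0][0] for k, v in votes_by_id.items()}

# e.g. merged = majority_vote(result_dict)
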
/决赛提交/sohu_matching/src/merge_result.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | if __name__ == '__main__':
4 | result_dir = '../results/rematch/'
5 |
6 | target_files = [
7 | '0520_roberta_80k_same_lr_zy_epoch_1_ab_loss',
8 | '0518_macbert_same_lr_epoch_1_ab_loss',
9 | '0520_roberta_sbert_same_lr_epoch_1_ab_loss',
10 | '0523_roberta_dataaug_epoch_0_ab_loss',
11 | '0523_ernie_epoch_1_ab_loss'
12 | ]
13 | # 0.7931380664848722
14 |
15 | # target_files = [
16 | # '0518_roberta_same_lr_epoch_1_ab_loss',
17 | # '0519_nezha_same_lr_epoch_1_ab_f1',
18 | # '0518_nezha_diff_lr_zy_epoch_1_ab_loss',
19 | # '0518_macbert_same_lr_epoch_1_ab_los',
20 | # '0523_roberta_dataaug_epoch_0_ab_loss'
21 | # ]
22 | # # 0.7930518678397445
23 |
24 | result_list = [file_name+'.csv' for file_name in target_files]
25 | result_dict = {}
26 | for name in result_list:
27 | with open(result_dir + name, "r", encoding="utf-8") as fr:
28 | for line in fr:
29 | words = line.strip().split(",")
30 | if words[0] == "id":
31 | continue
32 | if words[0] not in result_dict:
33 | result_dict[words[0]] = [words[1]]
34 | else:
35 | result_dict[words[0]].append(words[1])
36 |
37 | with open(result_dir+"merge.csv", "w", encoding="utf-8") as fw:
38 | fw.write("id,label"+"\n")
39 | for k, v in result_dict.items():
40 | tmp = {}
41 | for ele in v:
42 | if ele in tmp:
43 | tmp[ele] += 1
44 | else:
45 | tmp[ele] = 1
46 | tmp = sorted(tmp.items(), key=lambda d: d[1], reverse=True)
47 | # print(tmp)
48 | fw.write(",".join([k, tmp[0][0]]) + "\n")
--------------------------------------------------------------------------------
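Both inference scripts strip a 'module.' prefix when loading checkpoints. The reason: torch.nn.DataParallel registers the wrapped network under .module, so state_dicts saved during multi-GPU training carry keys like 'module.bert.encoder...'. A small sketch of the round trip on a toy module:

import torch.nn as nn

net = nn.Linear(4, 2)
wrapped = nn.DataParallel(net)
state = wrapped.state_dict()                  # keys: 'module.weight', 'module.bias'

# strip the prefix to load into an unwrapped model, as in infer.py / infer_final.py
clean = {k.replace('module.', '', 1): v for k, v in state.items()}
net.load_state_dict(clean)
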
/决赛提交/sohu_matching/src/model.py:
--------------------------------------------------------------------------------
1 | from transformers import BertModel, BertConfig
2 | import torch
3 | import torch.nn as nn
4 | from torch.utils.data import DataLoader
5 | import math
6 | from data import SentencePairDatasetWithType
7 |
8 | # import files for NEZHA models
9 | from NEZHA.model_nezha import NezhaConfig, NEZHAModel
10 | from NEZHA import nezha_utils
11 |
12 | import os
13 | os.environ['CUDA_VISIBLE_DEVICES'] = '3'
14 |
15 | # basic BERT-like models
16 | class BertClassifierSingleModel(nn.Module):
17 | def __init__(self, bert_dir, from_pretrained=True, task_num=2, hidden_size = 768, mid_size=512, freeze = False):
18 | super(BertClassifierSingleModel, self).__init__()
19 | self.hidden_size = hidden_size
20 |         # could be extended to a multi-task setting, e.g. 6 classifiers for 6 subtasks
21 | self.task_num = task_num
22 |
23 | if from_pretrained:
24 | print("Initialize BERT from pretrained weights")
25 | self.bert = BertModel.from_pretrained(bert_dir)
26 | else:
27 |             print("Initialize BERT from config.json, weights NOT loaded")
28 | self.bert_config = BertConfig.from_json_file(bert_dir+'config.json')
29 | self.bert = BertModel(self.bert_config)
30 | self.dropout = nn.Dropout(0.5)
31 |
32 | self.all_classifier = nn.ModuleList([
33 | nn.Sequential(
34 | nn.Linear(hidden_size, mid_size),
35 | nn.BatchNorm1d(mid_size),
36 | nn.ReLU(),
37 | nn.Dropout(0.5),
38 | nn.Linear(mid_size, 2)
39 | )
40 | for _ in range(self.task_num)
41 | ])
42 |
43 | def forward(self, input_ids, input_types):
44 | # get shared BERT model output
45 | mask = torch.ne(input_ids, 0)
46 | bert_output = self.bert(input_ids, token_type_ids=input_types, attention_mask=mask)
47 | cls_embed = bert_output[1]
48 | output = self.dropout(cls_embed)
49 |
50 | # get probs for two tasks A and B
51 | all_probs = [classifier(output) for classifier in self.all_classifier]
52 | return all_probs
53 |
54 | class NezhaClassifierSingleModel(nn.Module):
55 | def __init__(self, bert_dir, from_pretrained=True, task_num=2, hidden_size = 768, mid_size=512, freeze = False):
56 | super(NezhaClassifierSingleModel, self).__init__()
57 | self.hidden_size = hidden_size
58 | self.task_num = task_num
59 |
60 | self.bert_config = NezhaConfig.from_json_file(bert_dir+'config.json')
61 | self.bert = NEZHAModel(config=self.bert_config)
62 | if from_pretrained:
63 |             print("Initialize NEZHA from pretrained weights")
64 | nezha_utils.torch_init_model(self.bert, bert_dir+'pytorch_model.bin')
65 |
66 | self.dropout = nn.Dropout(0.5)
67 | self.all_classifier = nn.ModuleList([
68 | nn.Sequential(
69 | nn.Linear(hidden_size, mid_size),
70 | nn.BatchNorm1d(mid_size),
71 | nn.ReLU(),
72 | nn.Dropout(0.5),
73 | nn.Linear(mid_size, 2)
74 | )
75 | for _ in range(self.task_num)
76 | ])
77 |
78 | def forward(self, input_ids, input_types):
79 | # get shared BERT model output
80 | mask = torch.ne(input_ids, 0)
81 | bert_output = self.bert(input_ids, token_type_ids=input_types, attention_mask=mask)
82 | cls_embed = bert_output[1]
83 | output = self.dropout(cls_embed)
84 |
85 | # get probs for two tasks A and B
86 | all_probs = [classifier(output) for classifier in self.all_classifier]
87 | return all_probs
88 |
89 | class SBERTSingleModel(nn.Module):
90 | def __init__(self, bert_dir, from_pretrained=True, task_num=2, hidden_size=768, mid_size=512, freeze = False):
91 | super(SBERTSingleModel, self).__init__()
92 | self.hidden_size = hidden_size
93 | self.task_num = task_num
94 |
95 | if from_pretrained:
96 | print("Initialize BERT from pretrained weights")
97 | self.bert = BertModel.from_pretrained(bert_dir)
98 | else:
99 |             print("Initialize BERT from config.json, weights NOT loaded")
100 | self.bert_config = BertConfig.from_json_file(bert_dir+'config.json')
101 | self.bert = BertModel(self.bert_config)
102 |
103 | self.dropout = nn.Dropout(0.5)
104 | self.all_classifier = nn.ModuleList([
105 | nn.Sequential(
106 | nn.Linear(hidden_size*3, mid_size),
107 | nn.BatchNorm1d(mid_size),
108 | nn.ReLU(),
109 | nn.Dropout(0.5),
110 | nn.Linear(mid_size, 2)
111 | )
112 | for _ in range(self.task_num)
113 | ])
114 |
115 | def forward(self, source_input_ids, target_input_ids):
116 | # 0 for [PAD], mask out the padded values
117 | source_attention_mask = torch.ne(source_input_ids, 0)
118 | target_attention_mask = torch.ne(target_input_ids, 0)
119 |
120 | # get bert output
121 | source_embedding = self.bert(source_input_ids, attention_mask=source_attention_mask)
122 | target_embedding = self.bert(target_input_ids, attention_mask=target_attention_mask)
123 |
124 |         # simply take the [CLS] representation
125 | # TODO: try different pooling strategies
126 | source_embedding = source_embedding[1]
127 | target_embedding = target_embedding[1]
128 |
129 | # concat the source embedding, target embedding and abs embedding as in the original SBERT paper
130 | abs_embedding = torch.abs(source_embedding-target_embedding)
131 | context_embedding = torch.cat([source_embedding, target_embedding, abs_embedding], -1)
132 | context_embedding = self.dropout(context_embedding)
133 |
134 | # get probs for two tasks A and B
135 | all_probs = [classifier(context_embedding) for classifier in self.all_classifier]
136 | return all_probs
137 |
138 | class SNEZHASingleModel(nn.Module):
139 | def __init__(self, bert_dir, from_pretrained=True, task_num=2, hidden_size=768, mid_size=512, freeze = False):
140 | super(SNEZHASingleModel, self).__init__()
141 | self.hidden_size = hidden_size
142 | self.task_num = task_num
143 |
144 | self.bert_config = NezhaConfig.from_json_file(bert_dir+'config.json')
145 | self.bert = NEZHAModel(config=self.bert_config)
146 | if from_pretrained:
147 |             print("Initialize NEZHA from pretrained weights")
148 | nezha_utils.torch_init_model(self.bert, bert_dir+'pytorch_model.bin')
149 |
150 | self.dropout = nn.Dropout(0.5)
151 | self.all_classifier = nn.ModuleList([
152 | nn.Sequential(
153 | nn.Linear(hidden_size*3, mid_size),
154 | nn.BatchNorm1d(mid_size),
155 | nn.ReLU(),
156 | nn.Dropout(0.5),
157 | nn.Linear(mid_size, 2)
158 | )
159 | for _ in range(self.task_num)
160 | ])
161 |
162 | def forward(self, source_input_ids, target_input_ids):
163 | # 0 for [PAD], mask out the padded values
164 | source_attention_mask = torch.ne(source_input_ids, 0)
165 | target_attention_mask = torch.ne(target_input_ids, 0)
166 |
167 | # get bert output
168 | source_embedding = self.bert(source_input_ids, attention_mask=source_attention_mask)
169 | target_embedding = self.bert(target_input_ids, attention_mask=target_attention_mask)
170 |
171 |         # simply take the [CLS] representation
172 | # TODO: try different pooling strategies
173 | source_embedding = source_embedding[1]
174 | target_embedding = target_embedding[1]
175 |
176 | # concat the source embedding, target embedding and abs embedding as in the original SBERT paper
177 | abs_embedding = torch.abs(source_embedding-target_embedding)
178 | context_embedding = torch.cat([source_embedding, target_embedding, abs_embedding], -1)
179 | context_embedding = self.dropout(context_embedding)
180 |
181 | # get probs for two tasks A and B
182 | all_probs = [classifier(context_embedding) for classifier in self.all_classifier]
183 | return all_probs
184 |
185 | class BertClassifierTextCNNSingleModel(nn.Module):
186 | def __init__(self, bert_dir, from_pretrained=True, task_num=2, hidden_size = 768, mid_size=512, freeze = False):
187 | super(BertClassifierTextCNNSingleModel, self).__init__()
188 | self.hidden_size = hidden_size
189 | self.task_num = task_num
190 |
191 | if from_pretrained:
192 | print("Initialize BERT from pretrained weights")
193 | self.bert = BertModel.from_pretrained(bert_dir)
194 | else:
195 |             print("Initialize BERT from config.json, weights NOT loaded")
196 | self.bert_config = BertConfig.from_json_file(bert_dir+'config.json')
197 | self.bert = BertModel(self.bert_config)
198 |
199 | self.dropout = nn.Dropout(0.5)
200 |
201 | # for TextCNN
202 | filter_num = 128
203 | filter_sizes = [2,3,4]
204 | self.convs = nn.ModuleList(
205 | [nn.Conv2d(1, filter_num, (size, hidden_size)) for size in filter_sizes])
206 |
207 | self.all_classifier = nn.ModuleList([
208 | nn.Sequential(
209 | nn.Linear(len(filter_sizes) * filter_num, mid_size),
210 | nn.BatchNorm1d(mid_size),
211 | nn.ReLU(),
212 | nn.Dropout(0.5),
213 | nn.Linear(mid_size, 2)
214 | )
215 | for _ in range(self.task_num)
216 | ])
217 |
218 | def forward(self, input_ids, input_types):
219 | # get shared BERT model output
220 | mask = torch.ne(input_ids, 0)
221 | bert_output = self.bert(input_ids, token_type_ids=input_types, attention_mask=mask)
222 | bert_hidden = bert_output[0]
223 | output = self.dropout(bert_hidden)
224 |
225 | tcnn_input = output.unsqueeze(1)
226 | tcnn_output = [nn.functional.relu(conv(tcnn_input)).squeeze(3) for conv in self.convs]
227 | # max pooling in TextCNN
228 | # TODO: support avg pooling
229 | tcnn_output = [nn.functional.max_pool1d(item, item.size(2)).squeeze(2) for item in tcnn_output]
230 | tcnn_output = torch.cat(tcnn_output, 1)
231 | tcnn_output = self.dropout(tcnn_output)
232 |
233 | # get probs for two tasks A and B
234 | all_probs = [classifier(tcnn_output) for classifier in self.all_classifier]
235 | return all_probs
236 |
237 | class BertCoAttention(nn.Module):
238 | def __init__(self, config):
239 | super(BertCoAttention, self).__init__()
240 | if config.hidden_size % config.num_attention_heads != 0:
241 | raise ValueError(
242 | "The hidden size (%d) is not a multiple of the number of attention "
243 | "heads (%d)" % (config.hidden_size, config.num_attention_heads))
244 | self.output_attentions = config.output_attentions
245 |
246 | self.num_attention_heads = config.num_attention_heads
247 | self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
248 | self.all_head_size = self.num_attention_heads * self.attention_head_size
249 |
250 | self.query = nn.Linear(config.hidden_size, self.all_head_size)
251 | self.key = nn.Linear(config.hidden_size, self.all_head_size)
252 | self.value = nn.Linear(config.hidden_size, self.all_head_size)
253 |
254 | self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
255 |
256 | def transpose_for_scores(self, x):
257 | new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
258 | x = x.view(*new_x_shape)
259 | return x.permute(0, 2, 1, 3)
260 |
261 | def forward(self, context_states, query_states, attention_mask=None, head_mask=None, encoder_hidden_states=None,
262 | encoder_attention_mask=None):
263 | mixed_query_layer = self.query(query_states)
264 |
265 |         if attention_mask is not None:
266 |             # broadcast to (batch, heads, from_seq, to_seq) and convert to an additive mask
267 |             extended_attention_mask = attention_mask[:, None, None, :].float()  # fp16 compatibility
268 |             attention_mask = (1.0 - extended_attention_mask) * -10000.0
269 |
270 | # If this is instantiated as a cross-attention module, the keys
271 | # and values come from an encoder; the attention mask needs to be
272 | # such that the encoder's padding tokens are not attended to.
273 | if encoder_hidden_states is not None:
274 | mixed_key_layer = self.key(encoder_hidden_states)
275 | mixed_value_layer = self.value(encoder_hidden_states)
276 | attention_mask = encoder_attention_mask
277 | else:
278 | mixed_key_layer = self.key(context_states)
279 | mixed_value_layer = self.value(context_states)
280 |
281 | query_layer = self.transpose_for_scores(mixed_query_layer)
282 | key_layer = self.transpose_for_scores(mixed_key_layer)
283 | value_layer = self.transpose_for_scores(mixed_value_layer)
284 |
285 | # Take the dot product between "query" and "key" to get the raw attention scores.
286 | attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
287 | attention_scores = attention_scores / math.sqrt(self.attention_head_size)
288 | if attention_mask is not None:
289 | # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
290 | attention_scores = attention_scores + attention_mask
291 |
292 | # Normalize the attention scores to probabilities.
293 | attention_probs = nn.Softmax(dim=-1)(attention_scores)
294 |
295 | # This is actually dropping out entire tokens to attend to, which might
296 | # seem a bit unusual, but is taken from the original Transformer paper.
297 | attention_probs = self.dropout(attention_probs)
298 |
299 | # Mask heads if we want to
300 | if head_mask is not None:
301 | attention_probs = attention_probs * head_mask
302 |
303 | context_layer = torch.matmul(attention_probs, value_layer)
304 |
305 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
306 | new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
307 | context_layer = context_layer.view(*new_context_layer_shape)
308 |
309 | # outputs = (context_layer, attention_probs) if self.output_attentions else (context_layer,)
310 | outputs = context_layer
311 | return outputs
312 |
313 | class SBERTCoAttentionModel(nn.Module):
314 | def __init__(self, bert_dir, from_pretrained=True, task_num=2, hidden_size=768, mid_size=512, freeze = False):
315 | super(SBERTCoAttentionModel, self).__init__()
316 | self.hidden_size = hidden_size
317 | self.task_num = task_num
318 |
319 | if from_pretrained:
320 | print("Initialize BERT from pretrained weights")
321 | self.bert = BertModel.from_pretrained(bert_dir)
322 | else:
323 |             print("Initialize BERT from config.json, weights NOT loaded")
324 | self.bert_config = BertConfig.from_json_file(bert_dir+'config.json')
325 | self.bert = BertModel(self.bert_config)
326 |
327 | self.dropout = nn.Dropout(0.5)
328 |         self.co_attention = BertCoAttention(self.bert.config)  # BertCoAttention expects a config object, not hidden_size
329 | self.all_classifier = nn.ModuleList([
330 | nn.Sequential(
331 | nn.Linear(hidden_size * 3, mid_size),
332 | nn.BatchNorm1d(mid_size),
333 | nn.ReLU(),
334 | nn.Dropout(0.5),
335 | nn.Linear(mid_size, 2)
336 | )
337 | for _ in range(self.task_num)
338 | ])
339 |
340 | def forward(self, source_input_ids, target_input_ids):
341 | # 0 for [PAD], mask out the padded values
342 | source_attention_mask = torch.ne(source_input_ids, 0)
343 | target_attention_mask = torch.ne(target_input_ids, 0)
344 |
345 | # get bert output
346 | source_embedding = self.bert(source_input_ids, attention_mask=source_attention_mask)
347 | target_embedding = self.bert(target_input_ids, attention_mask=target_attention_mask)
348 |
349 | source_coattention_outputs = self.co_attention(target_embedding[0], source_embedding[0], source_attention_mask)
350 | target_coattention_outputs = self.co_attention(source_embedding[0], target_embedding[0], target_attention_mask)
351 | source_coattention_embedding = source_coattention_outputs[:, 0, :]
352 | target_coattention_embedding = target_coattention_outputs[:, 0, :]
353 |
354 |         # simply take the [CLS] representation
355 | # TODO: try different pooling strategies
356 | # source_embedding = source_embedding[1]
357 | # target_embedding = target_embedding[1]
358 |
359 |         # concat the two co-attention embeddings and their absolute difference,
360 |         # mirroring the [u; v; |u-v|] construction in the original SBERT paper
361 | abs_embedding = torch.abs(source_coattention_embedding - target_coattention_embedding)
362 | context_embedding = torch.cat([source_coattention_embedding, target_coattention_embedding, abs_embedding], -1)
363 | context_embedding = self.dropout(context_embedding)
364 |
365 | # get probs for two tasks A and B
366 | all_probs = [classifier(context_embedding) for classifier in self.all_classifier]
367 | return all_probs
--------------------------------------------------------------------------------
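BertCoAttention above is standard multi-head scaled dot-product attention, with queries drawn from one sentence and keys/values from the other. Stripped of the multi-head bookkeeping, the core computation is softmax(QK^T / sqrt(d) + additive_mask) V; a single-head sketch (shapes and names illustrative):

import math
import torch
import torch.nn as nn

def cross_attention(q, k, v, key_mask=None):
    # q: (B, Lq, d); k, v: (B, Lk, d); key_mask: (B, Lk), 1 = real token, 0 = [PAD]
    scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(q.size(-1))
    if key_mask is not None:
        # same additive trick as above: push padded keys to -10000 before softmax
        scores = scores + (1.0 - key_mask[:, None, :].float()) * -10000.0
    probs = nn.functional.softmax(scores, dim=-1)
    return torch.matmul(probs, v)             # (B, Lq, d)

# in SBERTCoAttentionModel, q from the source sentence with k, v from the target
# gives the source-attends-to-target direction (and vice versa for the target)
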
/决赛提交/sohu_matching/src/search_better_merge.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from sklearn.metrics import f1_score
3 | from itertools import combinations
4 | from tqdm import tqdm
5 | # from collections import defaultdict
6 |
7 | def merge_on_valid(model_names, verbose=False):
8 | # len of valid_a: 4971, len of valid_b: 4969
9 | # total_preds_a, total_preds_b = [0]*4971, [0]*4969
10 | total_preds_a, total_preds_b = [0]*6911, [0]*6914
11 | # positive if the vote exceeds the threshold (>)
12 | threshold = len(model_names)/2
13 | for model in model_names:
14 | # print("processing model {}".format(model))
15 | preds_a, preds_b = np.load('{}_pred_a.npy'.format(model)), np.load('{}_pred_b.npy'.format(model))
16 | preds_a, preds_b = preds_a.tolist(), preds_b.tolist()
17 | assert len(total_preds_a)==len(preds_a) and len(total_preds_b)==len(preds_b)
18 | for idx, pred_a in enumerate(preds_a):
19 | total_preds_a[idx] += pred_a
20 | for idx, pred_b in enumerate(preds_b):
21 | total_preds_b[idx] += pred_b
22 | # print(len(preds_a), len(preds_b))
23 | # print(type(preds_b))
24 |
25 | total_preds_a, total_preds_b = np.array(total_preds_a), np.array(total_preds_b)
26 | vote_a, vote_b = (total_preds_a>threshold).astype('int'), (total_preds_b>threshold).astype('int')
27 | gt_a, gt_b = np.load(valid_dir + 'gt_a.npy'), np.load(valid_dir + 'gt_b.npy')
28 | # print(len(vote_a), len(vote_b))
29 | # print(len(gt_a), len(gt_b))
30 |
31 | f1a, f1b = f1_score(gt_a, vote_a), f1_score(gt_b, vote_b)
32 | ssa, ssb = f1_score(gt_a[:1750], vote_a[:1750]), f1_score(gt_b[:1750], vote_b[:1750])
33 | sla, slb = f1_score(gt_a[1750:4380], vote_a[1750:4380]), f1_score(gt_b[1750:4385], vote_b[1750:4385])
34 | lla, llb = f1_score(gt_a[4380:], vote_a[4380:]), f1_score(gt_b[4385:], vote_b[4385:])
35 |
36 | if verbose:
37 | print("f1a: {}, f1b: {}".format(f1a, f1b))
38 | print("ssa: {}, ssb: {}".format(ssa, ssb))
39 | print("sla: {}, slb: {}".format(sla, slb))
40 | print("lla: {}, llb: {}".format(lla, llb))
41 |
42 | return f1a, f1b, ssa, ssb, sla, slb, lla, llb
43 |
44 | if __name__ == '__main__':
45 | valid_dir = '../valid_output/'
46 | total_model_names = [
47 | '0518_roberta_same_lr_epoch_1_ab_loss',
48 | '0520_roberta_diff_lr_epoch_1_ab_loss',
49 | '0520_roberta_tcnn_diff_lr_epoch_1_ab_loss',
50 | '0520_roberta_80k_same_lr_zy_epoch_1_ab_loss',
51 | '0522_roberta_80k_fl_epoch_1_ab_loss',
52 | '0519_nezha_same_lr_epoch_1_ab_f1',
53 | '0519_nezha_same_lr_epoch_0_ab_loss',
54 | '0518_nezha_diff_lr_zy_epoch_1_ab_loss',
55 | '0518_macbert_same_lr_epoch_1_ab_loss',
56 | '0520_macbert_sbert_same_lr_epoch_1_ab_loss',
57 | '0520_roberta_sbert_same_lr_epoch_1_ab_loss',
58 | '0522_roberta_80k_tcnn_epoch_1_ab_loss',
59 | '0523_roberta_dataaug_epoch_0_ab_loss',
60 | '0523_ernie_epoch_1_ab_loss'
61 | ]
62 | total_model_dir = [valid_dir + model_name for model_name in total_model_names]
63 | f1a, f1b, *_ = merge_on_valid(total_model_dir)
64 | print("total merge: f1 {}, f1a {}, f1b {}".format(((f1a+f1b)/2), f1a, f1b))
65 | print()
66 |
67 | for size in [3,5,7,9,11]:
68 | print("searching the best merge of {} models".format(size))
69 | records = []
70 | combs = combinations(total_model_dir, size)
71 | best_f1 = 0
72 | best_comb = None
73 | for comb in tqdm(combs):
74 | f1a, f1b, *_ = merge_on_valid(list(comb))
75 | if (f1a + f1b)/2 > best_f1:
76 | best_f1 = (f1a + f1b)/2
77 | best_comb = comb
78 | records.append((list(comb), (f1a+f1b)/2))
79 | print("best f1 and model list:")
80 | print(best_f1, best_comb)
81 | merge_on_valid(list(best_comb), True)
82 |
83 | print("top5 candidates list:")
84 | records.sort(key=lambda x:x[-1], reverse=True)
85 | for i in range(5):
86 | print(records[i])
87 | print()
--------------------------------------------------------------------------------
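The exhaustive search above evaluates every k-subset of the 14 candidate models, so the budget is C(14, k) merges per ensemble size. A quick sanity check of that cost (Python >= 3.8 for math.comb):

from math import comb

for k in [3, 5, 7, 9, 11]:
    print(k, comb(14, k))        # 364, 2002, 3432, 2002, 364
print(sum(comb(14, k) for k in [3, 5, 7, 9, 11]))  # 8164 merges in total
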
/决赛提交/sohu_matching/src/train.py:
--------------------------------------------------------------------------------
1 | from model import BertClassifierSingleModel, NezhaClassifierSingleModel, SBERTSingleModel, SNEZHASingleModel, BertClassifierTextCNNSingleModel
2 | from data import SentencePairDatasetWithType, SentencePairDatasetForSBERT, SentencePairDatasetWithMultiType
3 | from utils import focal_loss, FGM
4 | from transformers import AdamW, get_linear_schedule_with_warmup
5 | from config import Config
6 |
7 | import torch
8 | import torch.nn as nn
9 | from torch.utils.data import DataLoader
10 |
11 | import numpy as np
12 | from sklearn import metrics
13 | from tensorboardX import SummaryWriter
14 |
15 | import os
16 | # os.environ['CUDA_VISIBLE_DEVICES'] = '2,3,4,5,6,7' # recommended for NEZHA
17 | # os.environ['CUDA_VISIBLE_DEVICES'] = '4,5,6,7'
18 | os.environ['CUDA_VISIBLE_DEVICES'] = '5,6,7'
19 |
20 | def train(model, device, epoch, train_dataloader, test_dataloader, save_dir, optimizer, scheduler=None, criterion_type='CE', model_type='bert', print_every=100, eval_every=500, writer=None, use_fgm=False):
21 | print("Training at epoch {}".format(epoch))
22 | if use_fgm:
23 |         print("Using FGM for adversarial training")
24 |
25 | est_batch = len(train_dataloader.dataset) / (train_dataloader.batch_size)
26 | model.train()
27 |
28 | # for multiple GPU support
29 | model = torch.nn.DataParallel(model)
30 |
31 | assert criterion_type == 'CE' or criterion_type == 'FL'
32 | if criterion_type == 'CE':
33 | criterion = nn.CrossEntropyLoss()
34 | elif criterion_type == 'FL':
35 | criterion = focal_loss()
36 |
37 | if use_fgm:
38 | fgm = FGM(model)
39 |
40 | total_loss = []
41 | total_gt_a, total_preds_a = [], []
42 | total_gt_b, total_preds_b = [], []
43 | for idx, batch in enumerate(train_dataloader):
44 | # for SentencePairDatasetWithType, types would be returned
45 | input_ids, input_types, labels, types = batch
46 | input_ids = input_ids.to(device)
47 | input_types = input_types.to(device)
48 | # labels should be flattened
49 | labels = labels.to(device).view(-1)
50 |
51 | optimizer.zero_grad()
52 |
53 | # the probs given by the model
54 | all_probs = model(input_ids, input_types)
55 | num_tasks = len(all_probs)
56 |
57 | all_masks = [(types==task_id).numpy() for task_id in range(num_tasks)]
58 | all_output = [all_probs[task_id][all_masks[task_id]] for task_id in range(num_tasks)]
59 | all_labels = [labels[all_masks[task_id]] for task_id in range(num_tasks)]
60 |
61 | # calculate the loss and BP
62 | # TODO: different weights for each task?
63 | all_loss = None
64 | for task_id in range(num_tasks):
65 | if all_masks[task_id].sum() != 0:
66 | if all_loss is None:
67 | all_loss = criterion(all_output[task_id], all_labels[task_id])
68 | else:
69 | all_loss += criterion(all_output[task_id], all_labels[task_id])
70 | all_loss.backward()
71 |
72 |         # code for FGM adversarial training
73 | if use_fgm:
74 | fgm.attack()
75 | # adv_probs_a, adv_probs_b = model(input_ids, input_types)
76 | adv_all_probs = model(input_ids, input_types)
77 | adv_all_output = [adv_all_probs[task_id][all_masks[task_id]] for task_id in range(num_tasks)]
78 | # calculate the loss and BP
79 | adv_all_loss = None
80 | for task_id in range(num_tasks):
81 | if all_masks[task_id].sum() != 0:
82 | if adv_all_loss is None:
83 | adv_all_loss = criterion(adv_all_output[task_id], all_labels[task_id])
84 | else:
85 | adv_all_loss += criterion(adv_all_output[task_id], all_labels[task_id])
86 | adv_all_loss.backward()
87 | fgm.restore()
88 |
89 | optimizer.step()
90 | if scheduler is not None:
91 | scheduler.step()
92 |
93 |
94 | all_gt = [all_labels[task_id].cpu().numpy().tolist() if all_masks[task_id].sum()!=0 else [] for task_id in range(num_tasks)]
95 | all_preds = [all_output[task_id].argmax(axis=1).cpu().numpy().tolist() if all_masks[task_id].sum()!=0 else [] for task_id in range(num_tasks)]
96 |
97 | gt_a, preds_a = [], []
98 | for task_id in range(0, num_tasks, 2):
99 | gt_a += all_gt[task_id]
100 | preds_a += all_preds[task_id]
101 |
102 | gt_b, preds_b = [], []
103 | for task_id in range(1, num_tasks, 2):
104 | gt_b += all_gt[task_id]
105 | preds_b += all_preds[task_id]
106 |
107 | total_preds_a += preds_a
108 | total_gt_a += gt_a
109 | total_preds_b += preds_b
110 | total_gt_b += gt_b
111 | total_loss.append(all_loss.item())
112 | # print('a', preds_a, gt_a)
113 | # print('b', preds_b, gt_b)
114 |
115 | acc_a = metrics.accuracy_score(gt_a, preds_a) if len(gt_a)!=0 else 0
116 | f1_a = metrics.f1_score(gt_a, preds_a, zero_division=0)
117 | acc_b = metrics.accuracy_score(gt_b, preds_b) if len(gt_b)!=0 else 0
118 | f1_b = metrics.f1_score(gt_b, preds_b, zero_division=0)
119 |
120 |         # the learning rate for BERT is in the last parameter group
121 | writer.add_scalar('train/learning_rate', optimizer.param_groups[-1]['lr'], global_step=epoch*est_batch+idx)
122 | writer.add_scalar('train/loss', all_loss.item(), global_step=epoch*est_batch+idx)
123 | writer.add_scalar('train/acc_a', acc_a, global_step=epoch*est_batch+idx)
124 | writer.add_scalar('train/acc_b', acc_b, global_step=epoch*est_batch+idx)
125 | writer.add_scalar('train/f1_a', f1_a, global_step=epoch*est_batch+idx)
126 | writer.add_scalar('train/f1_b', f1_b, global_step=epoch*est_batch+idx)
127 |
128 |         # print the loss and accuracy scores every print_every batches
129 | if (idx+1) % print_every == 0:
130 | print("\tBatch: {} / {:.0f}, Loss: {:.6f}".format(idx, est_batch, all_loss.item()))
131 | print("\t\t Task A\tAcc: {:.6f}, F1: {:.6f}".format(acc_a, f1_a))
132 | print("\t\t Task B\tAcc: {:.6f}, F1: {:.6f}".format(acc_b, f1_b))
133 |
134 |         # evaluate the model every eval_every batches instead of only after the whole epoch
135 | global best_dev_loss, best_dev_f1
136 | if (idx+1) % eval_every == 0:
137 | dev_loss, dev_acc_a, dev_acc_b, dev_f1_a, dev_f1_b = eval(model, device, test_dataloader, criterion_type)
138 | dev_f1 = (dev_f1_a + dev_f1_b) / 2
139 | writer.add_scalar('eval/loss', dev_loss, global_step=epoch*est_batch+idx)
140 | writer.add_scalar('eval/acc_a', dev_acc_a, global_step=epoch*est_batch+idx)
141 | writer.add_scalar('eval/acc_b', dev_acc_b, global_step=epoch*est_batch+idx)
142 | writer.add_scalar('eval/f1_a', dev_f1_a, global_step=epoch*est_batch+idx)
143 | writer.add_scalar('eval/f1_b', dev_f1_b, global_step=epoch*est_batch+idx)
144 |             # in practice, a better loss is preferred over a better f1 score,
145 |             # since the latter can result from random overfitting on the valid set
146 | # 0517: save the model's state_dict instead of the whole model (mainly for NEZHA's sake)
147 | if (dev_loss < best_dev_loss or dev_f1 > best_dev_f1):
148 | if dev_loss < best_dev_loss:
149 | best_dev_loss = dev_loss
150 | torch.save(model.state_dict(), save_dir + model_type + '_epoch_{}_{}_'.format(epoch, task_type) + 'loss')
151 | print("----------BETTER LOSS, MODEL SAVED-----------")
152 | if dev_f1 > best_dev_f1:
153 | best_dev_f1 = dev_f1
154 | torch.save(model.state_dict(), save_dir + model_type + '_epoch_{}_{}_'.format(epoch, task_type) + 'f1')
155 | print("----------BETTER F1, MODEL SAVED-----------")
156 |
157 | loss = np.array(total_loss).mean()
158 | # Setting average=None to return class-specific scores
159 |     # 0502 BUG FIXED: do not use average='macro'; class-specific metrics are NOT needed here
160 | # macro_f1 = metrics.f1_score(total_gt, total_preds, average='macro')
161 | f1_a = metrics.f1_score(total_gt_a, total_preds_a, zero_division=0)
162 | f1_b = metrics.f1_score(total_gt_b, total_preds_b, zero_division=0)
163 | f1 = (f1_a + f1_b) / 2
164 | print("Average f1 on training set: {:.6f}, f1_a: {:.6f}, f1_b: {:.6f}".format(f1, f1_a, f1_b))
165 |
166 | return loss, f1, f1_a, f1_b
167 |
168 |
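The fgm.attack()/fgm.restore() pair in train() above wraps a second forward/backward pass on adversarially perturbed embeddings; the FGM class itself is imported from utils.py, which is outside this excerpt. A minimal sketch of the usual FGM recipe it presumably follows (perturb the word-embedding weights along the L2-normalized gradient, then roll back); epsilon and the embedding-name filter are illustrative, not the repo's exact values:

class FGMSketch:
    def __init__(self, model, epsilon=1.0, emb_name='word_embeddings'):
        self.model, self.epsilon, self.emb_name = model, epsilon, emb_name
        self.backup = {}

    def attack(self):
        # add epsilon * grad / ||grad|| to the embedding weights (uses torch imported above)
        for name, param in self.model.named_parameters():
            if param.requires_grad and self.emb_name in name and param.grad is not None:
                self.backup[name] = param.data.clone()
                norm = torch.norm(param.grad)
                if norm != 0 and not torch.isnan(norm):
                    param.data.add_(self.epsilon * param.grad / norm)

    def restore(self):
        # roll the embedding weights back to their saved values
        for name, param in self.model.named_parameters():
            if name in self.backup:
                param.data = self.backup[name]
        self.backup = {}
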
169 | def eval(model, device, test_dataloader, criterion_type='CE'):
170 | print("Evaluating")
171 | model.eval()
172 | # if called while training, then model parallel is already done
173 | # model = torch.nn.DataParallel(model)
174 |
175 | assert criterion_type == 'CE' or criterion_type == 'FL'
176 | if criterion_type == 'CE':
177 | criterion = nn.CrossEntropyLoss()
178 | elif criterion_type == 'FL':
179 | criterion = focal_loss()
180 |
181 | total_loss = []
182 | total_gt_a, total_preds_a = [], []
183 | total_gt_b, total_preds_b = [], []
184 |
185 | for idx, batch in enumerate(test_dataloader):
186 | input_ids, input_types, labels, types = batch
187 | input_ids = input_ids.to(device)
188 | input_types = input_types.to(device)
189 | # labels should be flattened
190 | labels = labels.to(device).view(-1)
191 |
192 | # the probs given by the model, without grads
193 | with torch.no_grad():
194 | # the probs given by the model
195 | # probs_a, probs_b = model(input_ids, input_types)
196 | # mask_a, mask_b = (types==0).numpy(), (types==1).numpy()
197 | # output_a, labels_a = probs_a[mask_a], labels[mask_a]
198 | # output_b, labels_b = probs_b[mask_b], labels[mask_b]
199 |
200 | all_probs = model(input_ids, input_types)
201 | num_tasks = len(all_probs)
202 |
203 | # mask_a, mask_b = (types==0).numpy(), (types==1).numpy()
204 | all_masks = [(types==task_id).numpy() for task_id in range(num_tasks)]
205 | all_output = [all_probs[task_id][all_masks[task_id]] for task_id in range(num_tasks)]
206 | all_labels = [labels[all_masks[task_id]] for task_id in range(num_tasks)]
207 |
208 | all_loss = None
209 | for task_id in range(num_tasks):
210 | # print(task_id, all_masks[task_id])
211 | if all_masks[task_id].sum() != 0:
212 | if all_loss is None:
213 | all_loss = criterion(all_output[task_id], all_labels[task_id])
214 | else:
215 | all_loss += criterion(all_output[task_id], all_labels[task_id])
216 |
217 | all_gt = [all_labels[task_id].cpu().numpy().tolist() if all_masks[task_id].sum()!=0 else [] for task_id in range(num_tasks)]
218 | all_preds = [all_output[task_id].argmax(axis=1).cpu().numpy().tolist() if all_masks[task_id].sum()!=0 else [] for task_id in range(num_tasks)]
219 |
220 | gt_a, preds_a = [], []
221 | for task_id in range(0, num_tasks, 2):
222 | gt_a += all_gt[task_id]
223 | preds_a += all_preds[task_id]
224 |
225 | gt_b, preds_b = [], []
226 | for task_id in range(1, num_tasks, 2):
227 | gt_b += all_gt[task_id]
228 | preds_b += all_preds[task_id]
229 |
230 | total_preds_a += preds_a
231 | total_gt_a += gt_a
232 | total_preds_b += preds_b
233 | total_gt_b += gt_b
234 | total_loss.append(all_loss.item())
235 |
236 | loss = np.array(total_loss).mean()
237 | acc_a = metrics.accuracy_score(total_gt_a, total_preds_a) if len(total_gt_a)!=0 else 0
238 | f1_a = metrics.f1_score(total_gt_a, total_preds_a, zero_division=0)
239 | if (f1_a == 0):
240 | print("F1_a = 0, checking precision, recall, fscore and support...")
241 | print(metrics.precision_recall_fscore_support(total_gt_a, total_preds_a, zero_division=0))
242 |
243 | acc_b = metrics.accuracy_score(total_gt_b, total_preds_b) if len(total_gt_b)!=0 else 0
244 | f1_b = metrics.f1_score(total_gt_b, total_preds_b, zero_division=0)
245 | if (f1_b == 0):
246 | print("F1_b = 0, checking precision, recall, fscore and support...")
247 | print(metrics.precision_recall_fscore_support(total_gt_b, total_preds_b, zero_division=0))
248 |
249 | # Setting average=None to return class-specific scores
250 | # macro_f1 = metrics.f1_score(total_gt, total_preds, average='macro')
251 | # f1 = metrics.f1_score(total_gt, total_preds)
252 |
253 | # print loss and classification report
254 | print("Loss on dev set: ", loss)
255 | print("F1 on dev set: {:.6f}, f1_a: {:.6f}, f1_b: {:.6f}".format((f1_a+f1_b)/2, f1_a, f1_b))
256 |
257 | # return loss, acc, macro_f1
258 | return loss, acc_a, acc_b, f1_a, f1_b
259 |
260 |
261 | if __name__ == '__main__':
262 | config = Config()
263 | device = config.device
264 | pretrained = config.pretrained
265 | model_type = config.model_type
266 | use_fgm = config.use_fgm
267 |
268 | save_dir = config.save_dir
269 | data_dir = config.data_dir
270 |     # whether to shuffle the positions of source and target to augment data
271 | shuffle_order = config.shuffle_order
272 |     # whether to reuse positive cases from task B for task A (positives)
273 |     # and negative cases from task A for task B (negatives)
274 | aug_data = config.aug_data
275 |     # method for clipping long sequences, 'head' or 'tail'
276 | clip_method = config.clip_method
277 |
278 | task_type = config.task_type
279 | task_a = ['短短匹配A类', '短长匹配A类', '长长匹配A类']
280 | task_b = ['短短匹配B类', '短长匹配B类', '长长匹配B类']
281 |
282 |     # hyperparameters here
283 | epochs = config.epochs
284 | lr = config.lr
285 | classifer_lr = config.classifier_lr
286 | weight_decay = config.weight_decay
287 | hidden_size = config.hidden_size
288 | train_bs = config.train_bs
289 | eval_bs = config.eval_bs
290 |
291 | print_every = config.print_every
292 | eval_every = config.eval_every
293 |
294 | train_data_dir, dev_data_dir = [], []
295 | # integrate the two tasks into one dataset using task_type = 'ab'
296 | if 'a' in task_type:
297 | for task in task_a:
298 | train_data_dir.append(data_dir + task + '/train.txt')
299 | train_data_dir.append(data_dir + task + '/train_r2.txt')
300 | train_data_dir.append(data_dir + task + '/train_r3.txt')
301 | train_data_dir.append(data_dir + task + '/train_rematch.txt') # training file in rematch
302 | dev_data_dir.append(data_dir + task + '/valid.txt')
303 | dev_data_dir.append(data_dir + task + '/valid_rematch.txt') # valid file in rematch
304 |
305 | if 'b' in task_type:
306 | for task in task_b:
307 | train_data_dir.append(data_dir + task + '/train.txt')
308 | train_data_dir.append(data_dir + task + '/train_r2.txt')
309 | train_data_dir.append(data_dir + task + '/train_r3.txt')
310 | train_data_dir.append(data_dir + task + '/train_rematch.txt') # training file in rematch
311 | dev_data_dir.append(data_dir + task + '/valid.txt')
312 | dev_data_dir.append(data_dir + task + '/valid_rematch.txt') # valid file in rematch
313 |
314 | # toy dataset for testing
315 | if config.load_toy_dataset:
316 | # train_data_dir = ['../data/sohu2021_open_data/短短匹配A类/train.txt',
317 | # '../data/sohu2021_open_data/短短匹配B类/train.txt']
318 | train_data_dir = [
319 | '../data/sohu2021_open_data/短短匹配A类/valid.txt',
320 | '../data/sohu2021_open_data/短短匹配B类/valid.txt',
321 | '../data/sohu2021_open_data/短长匹配A类/valid.txt',
322 | '../data/sohu2021_open_data/短长匹配B类/valid.txt',
323 | '../data/sohu2021_open_data/长长匹配A类/valid.txt',
324 | '../data/sohu2021_open_data/长长匹配B类/valid.txt']
325 | # dev_data_dir = ['../data/sohu2021_open_data/短短匹配A类/valid.txt',
326 | #                 '../data/sohu2021_open_data/短短匹配B类/valid.txt']
327 | dev_data_dir = train_data_dir
328 |
329 | # if config.load_toy_dataset:
330 | # train_data_dir = ['../data/sohu2021_open_data/长长匹配A类/train.txt']
331 | # dev_data_dir = ['../data/sohu2021_open_data/长长匹配A类/valid.txt']
332 |
333 | print("Loading pretrained Model from {}...".format(pretrained))
334 | # integrating SBERT model into a unified training framework
335 | if 'sbert' in model_type.lower():
336 | print("Using SentenceBERT model and dataset")
337 | if 'nezha' in model_type.lower():
338 | model = SNEZHASingleModel(bert_dir=pretrained, hidden_size=hidden_size)
339 | else:
340 | model = SBERTSingleModel(bert_dir=pretrained, hidden_size=hidden_size)
341 | model.to(device)
342 | print("Loading Training Data...")
343 | print(train_data_dir)
344 | # augment the data with shuffle_order=True (changing order of source and target)
345 | # or with aug_data=True (neg cases in A -> B, pos cases in B -> A)
346 | train_dataset = SentencePairDatasetForSBERT(train_data_dir, True, pretrained, shuffle_order, aug_data=aug_data, clip=clip_method)
347 | train_dataloader = DataLoader(train_dataset, batch_size=train_bs, shuffle=True)
348 |
349 | print("Loading Dev Data...")
350 | test_dataset = SentencePairDatasetForSBERT(dev_data_dir, True, pretrained, clip=clip_method)
351 | test_dataloader = DataLoader(test_dataset, batch_size=eval_bs, shuffle=False)
352 |
353 | # 0517: for training, load weights from pretrained with from_pretrained=True (by default)
354 | # for larger model, adjust the hidden_size according to its config
355 | # distinguish model architectures or pretrained models according to model_type
356 | else:
357 | print("Using BERT model and dataset")
358 | if 'nezha' in model_type.lower():
359 | print("Using NEZHA pretrained model")
360 | model = NezhaClassifierSingleModel(bert_dir=pretrained, hidden_size=hidden_size)
361 | elif 'cnn' in model_type.lower():
362 | print("Adding TextCNN after BERT output")
363 | model = BertClassifierTextCNNSingleModel(bert_dir=pretrained, hidden_size=hidden_size)
364 | else:
365 | print("Using conventional BERT model with linears")
366 | # model = BertClassifierSingleModel(bert_dir=pretrained, hidden_size=hidden_size)
367 | model = BertClassifierSingleModel(bert_dir=pretrained, task_num=6, hidden_size=hidden_size)
368 | model.to(device)
369 |
370 | print("Loading Training Data...")
371 | print(train_data_dir)
372 | # augment the data with shuffle_order=True (changing order of source and target)
373 | # or with aug_data=True (neg cases in A -> B, pos cases in B -> A)
374 | # train_dataset = SentencePairDatasetWithType(train_data_dir, True, pretrained, shuffle_order, aug_data=aug_data, clip=clip_method)
375 | train_dataset = SentencePairDatasetWithMultiType(train_data_dir, True, pretrained, shuffle_order, aug_data=aug_data, clip=clip_method)
376 | train_dataloader = DataLoader(train_dataset, batch_size=train_bs, shuffle=True)
377 |
378 | print("Loading Dev Data...")
379 | # test_dataset = SentencePairDatasetWithType(dev_data_dir, True, pretrained, clip=clip_method)
380 | test_dataset = SentencePairDatasetWithMultiType(dev_data_dir, True, pretrained, clip=clip_method)
381 | test_dataloader = DataLoader(test_dataset, batch_size=eval_bs, shuffle=False)
382 |
383 | optimizer = AdamW(model.parameters(), lr=lr, weight_decay=weight_decay, correct_bias=False)
384 | # 0514 setting different lr for bert encoder and classifier
385 | # TODO: verify the large lr works for classifiers
386 | # optimizer = AdamW([
387 | # {"params": model.all_classifier.parameters(), "lr": classifer_lr},
388 | # {"params": model.bert.parameters()}],
389 | # lr=lr)
390 |
391 | # for p in optimizer.param_groups:
392 | # outputs = ''
393 | # for k, v in p.items():
394 | # if k == 'params':
395 | # outputs += (k + ': ' + str(v[0].shape).ljust(30) + ' ')
396 | # else:
397 | # outputs += (k + ': ' + str(v).ljust(10) + ' ')
398 | # print(outputs)
399 |
400 | total_steps = len(train_dataloader) * epochs
401 |
402 | # TODO: using ReduceLROnPlateau instead of linear scheduler
403 | if config.use_scheduler:
404 | scheduler = get_linear_schedule_with_warmup(
405 | optimizer,
406 | num_training_steps = total_steps,
407 | num_warmup_steps = config.num_warmup_steps,
408 | )
409 | else:
410 | scheduler = None
411 |
412 | print("Training on Task {}...".format(task_type))
413 | writer = SummaryWriter('runs/{}'.format(model_type + '_' + task_type))
414 |
415 | best_dev_loss = 999
416 | best_dev_f1 = 0
417 | for epoch in range(epochs):
418 | train_loss, train_f1, train_f1_a, train_f1_b = train(model, device, epoch, train_dataloader, test_dataloader, \
419 | save_dir, optimizer, scheduler=scheduler, model_type=model_type, \
420 | print_every=print_every, eval_every=eval_every, writer=writer, use_fgm=use_fgm)
421 |
--------------------------------------------------------------------------------
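Note: the eval loop in train.py above slices one mixed batch into per-task subsets with boolean masks and sums a cross-entropy term for every task that actually occurs in the batch. Below is a minimal, self-contained sketch of that pattern on toy tensors; all names and shapes here are illustrative, not taken from the repository.

import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()
num_tasks = 6

# toy batch: 4 samples, each tagged with a task id in [0, num_tasks)
types = torch.tensor([0, 3, 3, 5])
labels = torch.tensor([1, 0, 1, 0])
# one (batch, 2) logit tensor per task head, as a multitask model would return
all_probs = [torch.randn(4, 2) for _ in range(num_tasks)]

all_masks = [types == task_id for task_id in range(num_tasks)]
all_output = [all_probs[t][all_masks[t]] for t in range(num_tasks)]
all_labels = [labels[all_masks[t]] for t in range(num_tasks)]

# sum one CE term per task that appears in the batch; empty slices are skipped
all_loss = None
for t in range(num_tasks):
    if all_masks[t].any():
        term = criterion(all_output[t], all_labels[t])
        all_loss = term if all_loss is None else all_loss + term
print(all_loss.item())
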
/决赛提交/sohu_matching/src/train_old.py:
--------------------------------------------------------------------------------
1 | from model import BertClassifierSingleModel, NezhaClassifierSingleModel, SBERTSingleModel, SNEZHASingleModel, BertClassifierTextCNNSingleModel
2 | from data import SentencePairDatasetWithType, SentencePairDatasetForSBERT
3 | from utils import focal_loss, FGM
4 | from transformers import AdamW, get_linear_schedule_with_warmup
5 | from config import Config
6 |
7 | import torch
8 | import torch.nn as nn
9 | from torch.utils.data import DataLoader
10 |
11 | import numpy as np
12 | from sklearn import metrics
13 | from tensorboardX import SummaryWriter
14 |
15 | import os
16 | # os.environ['CUDA_VISIBLE_DEVICES'] = '2,3,4,5,6,7' # recommended for NEZHA
17 | # os.environ['CUDA_VISIBLE_DEVICES'] = '4,5,6,7'
18 | # os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3'
19 | os.environ['CUDA_VISIBLE_DEVICES'] = '2,3,4'
20 |
21 | def train(model, device, epoch, train_dataloader, test_dataloader, save_dir, optimizer, scheduler=None, criterion_type='CE', model_type='bert', print_every=100, eval_every=500, writer=None, use_fgm=False):
22 | print("Training at epoch {}".format(epoch))
23 | if use_fgm:
24 | print("Using fgm for adversial attack")
25 |
26 | est_batch = len(train_dataloader.dataset) / (train_dataloader.batch_size)
27 | model.train()
28 |
29 | # for multiple GPU support
30 | model = torch.nn.DataParallel(model)
31 |
32 | assert criterion_type == 'CE' or criterion_type == 'FL'
33 | if criterion_type == 'CE':
34 | criterion = nn.CrossEntropyLoss()
35 | elif criterion_type == 'FL':
36 | criterion = focal_loss()
37 |
38 | if use_fgm:
39 | fgm = FGM(model)
40 |
41 | total_loss = []
42 | total_gt_a, total_preds_a = [], []
43 | total_gt_b, total_preds_b = [], []
44 |
45 | # the following commented code is compatible with multiple tasks (e.g. 6 subtasks with designated datasets)
46 | # however, it seems to cost the model about 1 percent of performance on task b
47 | # for idx, batch in enumerate(train_dataloader):
48 | # # for SentencePairDatasetWithType, types would be returned
49 | # input_ids, input_types, labels, types = batch
50 | # input_ids = input_ids.to(device)
51 | # input_types = input_types.to(device)
52 | # # labels should be flattened
53 | # labels = labels.to(device).view(-1)
54 |
55 | # optimizer.zero_grad()
56 |
57 | # # the probs given by the model
58 | # all_probs = model(input_ids, input_types)
59 | # num_tasks = len(all_probs)
60 |
61 | # all_masks = [(types==task_id).numpy() for task_id in range(num_tasks)]
62 | # all_output = [all_probs[task_id][all_masks[task_id]] for task_id in range(num_tasks)]
63 | # all_labels = [labels[all_masks[task_id]] for task_id in range(num_tasks)]
64 |
65 | # # calculate the loss and BP
66 | # # TODO: different weights for each task?
67 | # all_loss = None
68 | # for task_id in range(num_tasks):
69 | # if all_masks[task_id].sum() != 0:
70 | # if all_loss is None:
71 | # all_loss = criterion(all_output[task_id], all_labels[task_id])
72 | # else:
73 | # all_loss += criterion(all_output[task_id], all_labels[task_id])
74 | # all_loss.backward()
75 |
76 | # # code for fgm adversarial training
77 | # if use_fgm:
78 | # fgm.attack()
79 | # # adv_probs_a, adv_probs_b = model(input_ids, input_types)
80 | # adv_all_probs = model(input_ids, input_types)
81 | # adv_all_output = [adv_all_probs[task_id][all_masks[task_id]] for task_id in range(num_tasks)]
82 | # # calculate the loss and BP
83 | # adv_all_loss = None
84 | # for task_id in range(num_tasks):
85 | # if all_masks[task_id].sum() != 0:
86 | # if adv_all_loss is None:
87 | # adv_all_loss = criterion(adv_all_output[task_id], all_labels[task_id])
88 | # else:
89 | # adv_all_loss += criterion(adv_all_output[task_id], all_labels[task_id])
90 | # adv_all_loss.backward()
91 | # fgm.restore()
92 |
93 | # optimizer.step()
94 | # if scheduler is not None:
95 | # scheduler.step()
96 |
97 |
98 | # all_gt = [all_labels[task_id].cpu().numpy().tolist() if all_masks[task_id].sum()!=0 else [] for task_id in range(num_tasks)]
99 | # all_preds = [all_output[task_id].argmax(axis=1).cpu().numpy().tolist() if all_masks[task_id].sum()!=0 else [] for task_id in range(num_tasks)]
100 |
101 | # gt_a, preds_a = [], []
102 | # for task_id in range(0, num_tasks, 2):
103 | # gt_a += all_gt[task_id]
104 | # preds_a += all_preds[task_id]
105 |
106 | # gt_b, preds_b = [], []
107 | # for task_id in range(1, num_tasks, 2):
108 | # gt_b += all_gt[task_id]
109 | # preds_b += all_preds[task_id]
110 |
111 | # total_preds_a += preds_a
112 | # total_gt_a += gt_a
113 | # total_preds_b += preds_b
114 | # total_gt_b += gt_b
115 | # total_loss.append(all_loss.item())
116 | # # print('a', preds_a, gt_a)
117 | # # print('b', preds_b, gt_b)
118 |
119 | # acc_a = metrics.accuracy_score(gt_a, preds_a) if len(gt_a)!=0 else 0
120 | # f1_a = metrics.f1_score(gt_a, preds_a, zero_division=0)
121 | # acc_b = metrics.accuracy_score(gt_b, preds_b) if len(gt_b)!=0 else 0
122 | # f1_b = metrics.f1_score(gt_b, preds_b, zero_division=0)
123 |
124 | # # learning rate for bert is the second (the last) parameter group
125 | # writer.add_scalar('train/learning_rate', optimizer.param_groups[-1]['lr'], global_step=epoch*est_batch+idx)
126 | # writer.add_scalar('train/loss', all_loss.item(), global_step=epoch*est_batch+idx)
127 | # writer.add_scalar('train/acc_a', acc_a, global_step=epoch*est_batch+idx)
128 | # writer.add_scalar('train/acc_b', acc_b, global_step=epoch*est_batch+idx)
129 | # writer.add_scalar('train/f1_a', f1_a, global_step=epoch*est_batch+idx)
130 | # writer.add_scalar('train/f1_b', f1_b, global_step=epoch*est_batch+idx)
131 |
132 | # # print the loss and accuracy score if reach print_every
133 | # if (idx+1) % print_every == 0:
134 | # print("\tBatch: {} / {:.0f}, Loss: {:.6f}".format(idx, est_batch, all_loss.item()))
135 | # print("\t\t Task A\tAcc: {:.6f}, F1: {:.6f}".format(acc_a, f1_a))
136 | # print("\t\t Task B\tAcc: {:.6f}, F1: {:.6f}".format(acc_b, f1_b))
137 |
138 | # # evaluate the model if reach eval_every, instead of evaluate after the whole epoch
139 | # global best_dev_loss, best_dev_f1
140 | # if (idx+1) % eval_every == 0:
141 | # dev_loss, dev_acc_a, dev_acc_b, dev_f1_a, dev_f1_b = eval(model, device, test_dataloader, criterion_type)
142 | # dev_f1 = (dev_f1_a + dev_f1_b) / 2
143 | # writer.add_scalar('eval/loss', dev_loss, global_step=epoch*est_batch+idx)
144 | # writer.add_scalar('eval/acc_a', dev_acc_a, global_step=epoch*est_batch+idx)
145 | # writer.add_scalar('eval/acc_b', dev_acc_b, global_step=epoch*est_batch+idx)
146 | # writer.add_scalar('eval/f1_a', dev_f1_a, global_step=epoch*est_batch+idx)
147 | # writer.add_scalar('eval/f1_b', dev_f1_b, global_step=epoch*est_batch+idx)
148 | # # in practice, better loss is preferred instead of better f1 score,
149 | # # which could be resulted from random overfitting on the valid set
150 | # # 0517: save the model's state_dict instead of the whole model (mainly for NEZHA's sake)
151 | # if (dev_loss < best_dev_loss or dev_f1 > best_dev_f1):
152 | # if dev_loss < best_dev_loss:
153 | # best_dev_loss = dev_loss
154 | # torch.save(model.state_dict(), save_dir + model_type + '_epoch_{}_{}_'.format(epoch, task_type) + 'loss')
155 | # print("----------BETTER LOSS, MODEL SAVED-----------")
156 | # if dev_f1 > best_dev_f1:
157 | # best_dev_f1 = dev_f1
158 | # torch.save(model.state_dict(), save_dir + model_type + '_epoch_{}_{}_'.format(epoch, task_type) + 'f1')
159 | # print("----------BETTER F1, MODEL SAVED-----------")
160 |
161 | for idx, batch in enumerate(train_dataloader):
162 | # for SentencePairDatasetWithType, types would be returned
163 | input_ids, input_types, labels, types = batch
164 | input_ids = input_ids.to(device)
165 | input_types = input_types.to(device)
166 | # labels should be flattened
167 | labels = labels.to(device).view(-1)
168 |
169 | optimizer.zero_grad()
170 |
171 | # the probs given by the model
172 | probs_a, probs_b = model(input_ids, input_types)
173 |
174 | mask_a, mask_b = (types==0).numpy(), (types==1).numpy()
175 | output_a, labels_a = probs_a[mask_a], labels[mask_a]
176 | output_b, labels_b = probs_b[mask_b], labels[mask_b]
177 |
178 | # calculate the loss and BP
179 | # loss_a = criterion(output_a, labels_a) if mask_a.sum()!=0 else None
180 | # loss_b = criterion(output_b, labels_b) if mask_b.sum()!=0 else None
181 | # so-called multi-task training
182 | # TODO: different weights for each task?
183 | if mask_a.sum()==0:
184 | loss = criterion(output_b, labels_b)
185 | elif mask_b.sum()==0:
186 | loss = criterion(output_a, labels_a)
187 | else:
188 | loss = criterion(output_a, labels_a) + criterion(output_b, labels_b)
189 | # print(loss.item())
190 | loss.backward()
191 |
192 | # code for fgm adversarial training
193 | if use_fgm:
194 | fgm.attack()
195 | adv_probs_a, adv_probs_b = model(input_ids, input_types)
196 | # calculate the loss and BP
197 | adv_output_a, adv_output_b = adv_probs_a[mask_a], adv_probs_b[mask_b]
198 | if mask_a.sum()==0:
199 | adv_loss = criterion(adv_output_b, labels_b)
200 | elif mask_b.sum()==0:
201 | adv_loss = criterion(adv_output_a, labels_a)
202 | else:
203 | adv_loss = criterion(adv_output_a, labels_a) + criterion(adv_output_b, labels_b)
204 | adv_loss.backward()
205 | fgm.restore()
206 |
207 | optimizer.step()
208 | if scheduler is not None:
209 | scheduler.step()
210 |
211 | gt_a = labels_a.cpu().numpy().tolist()
212 | preds_a = output_a.argmax(axis=1).cpu().numpy().tolist() if len(gt_a)!=0 else []
213 |
214 | gt_b = labels_b.cpu().numpy().tolist()
215 | preds_b = output_b.argmax(axis=1).cpu().numpy().tolist() if len(gt_b)!=0 else []
216 |
217 | total_preds_a += preds_a
218 | total_gt_a += gt_a
219 | total_preds_b += preds_b
220 | total_gt_b += gt_b
221 | total_loss.append(loss.item())
222 | # print('a', preds_a, gt_a)
223 | # print('b', preds_b, gt_b)
224 |
225 | acc_a = metrics.accuracy_score(gt_a, preds_a) if len(gt_a)!=0 else 0
226 | f1_a = metrics.f1_score(gt_a, preds_a, zero_division=0)
227 | acc_b = metrics.accuracy_score(gt_b, preds_b) if len(gt_b)!=0 else 0
228 | f1_b = metrics.f1_score(gt_b, preds_b, zero_division=0)
229 |
230 | writer.add_scalar('train/learning_rate', optimizer.param_groups[-1]['lr'], global_step=epoch*est_batch+idx)
231 | writer.add_scalar('train/loss', loss.item(), global_step=epoch*est_batch+idx)
232 | writer.add_scalar('train/acc_a', acc_a, global_step=epoch*est_batch+idx)
233 | writer.add_scalar('train/acc_b', acc_b, global_step=epoch*est_batch+idx)
234 | writer.add_scalar('train/f1_a', f1_a, global_step=epoch*est_batch+idx)
235 | writer.add_scalar('train/f1_b', f1_b, global_step=epoch*est_batch+idx)
236 |
237 | # print the loss and accuracy score if reach print_every
238 | if (idx+1) % print_every == 0:
239 | print("\tBatch: {} / {:.0f}, Loss: {:.6f}".format(idx, est_batch, loss.item()))
240 | print("\t\t Task A\tAcc: {:.6f}, F1: {:.6f}".format(acc_a, f1_a))
241 | # if (f1_a == 0):
242 | # print(metrics.precision_recall_fscore_support(gt_a, preds_a, zero_division=0))
243 |
244 | print("\t\t Task B\tAcc: {:.6f}, F1: {:.6f}".format(acc_b, f1_b))
245 | # if (f1_b == 0):
246 | # print(metrics.precision_recall_fscore_support(gt_b, preds_b, zero_division=0))
247 | # evaluate the model if reach eval_every, instead of evaluate after the whole epoch
248 | global best_dev_loss, best_dev_f1
249 | if (idx+1) % eval_every == 0:
250 | dev_loss, dev_acc_a, dev_acc_b, dev_f1_a, dev_f1_b = eval(model, device, test_dataloader, criterion_type)
251 | dev_f1 = (dev_f1_a + dev_f1_b) / 2
252 | writer.add_scalar('eval/loss', dev_loss, global_step=epoch*est_batch+idx)
253 | writer.add_scalar('eval/acc_a', dev_acc_a, global_step=epoch*est_batch+idx)
254 | writer.add_scalar('eval/acc_b', dev_acc_b, global_step=epoch*est_batch+idx)
255 | writer.add_scalar('eval/f1_a', dev_f1_a, global_step=epoch*est_batch+idx)
256 | writer.add_scalar('eval/f1_b', dev_f1_b, global_step=epoch*est_batch+idx)
257 | if (dev_loss < best_dev_loss or dev_f1 > best_dev_f1):
258 | if dev_loss < best_dev_loss:
259 | best_dev_loss = dev_loss
260 | # torch.save(model, save_dir + model_type + '_epoch_{}_{}_'.format(epoch, task_type) + 'loss')
261 | torch.save(model.state_dict(), save_dir + model_type + '_epoch_{}_{}_'.format(epoch, task_type) + 'loss')
262 | print("----------BETTER LOSS, MODEL SAVED-----------")
263 | if dev_f1 > best_dev_f1:
264 | best_dev_f1 = dev_f1
265 | # torch.save(model, save_dir + model_type + '_epoch_{}_{}_'.format(epoch, task_type) + 'f1')
266 | torch.save(model.state_dict(), save_dir + model_type + '_epoch_{}_{}_'.format(epoch, task_type) + 'f1')
267 | print("----------BETTER F1, MODEL SAVED-----------")
268 |
269 | loss = np.array(total_loss).mean()
270 | # Setting average=None to return class-specific scores
271 | # 0502 BUG FIXED: do not use 'macro', DO NOT require class-specific metrics!
272 | # macro_f1 = metrics.f1_score(total_gt, total_preds, average='macro')
273 | f1_a = metrics.f1_score(total_gt_a, total_preds_a, zero_division=0)
274 | f1_b = metrics.f1_score(total_gt_b, total_preds_b, zero_division=0)
275 | f1 = (f1_a + f1_b) / 2
276 | print("Average f1 on training set: {:.6f}, f1_a: {:.6f}, f1_b: {:.6f}".format(f1, f1_a, f1_b))
277 |
278 | return loss, f1, f1_a, f1_b
279 |
280 |
281 | def eval(model, device, test_dataloader, criterion_type='CE'):
282 | print("Evaluating")
283 | model.eval()
284 | # if called while training, then model parallel is already done
285 | # model = torch.nn.DataParallel(model)
286 |
287 | assert criterion_type == 'CE' or criterion_type == 'FL'
288 | if criterion_type == 'CE':
289 | criterion = nn.CrossEntropyLoss()
290 | elif criterion_type == 'FL':
291 | criterion = focal_loss()
292 |
293 | total_loss = []
294 | total_gt_a, total_preds_a = [], []
295 | total_gt_b, total_preds_b = [], []
296 |
297 | for idx, batch in enumerate(test_dataloader):
298 | input_ids, input_types, labels, types = batch
299 | input_ids = input_ids.to(device)
300 | input_types = input_types.to(device)
301 | # labels should be flattened
302 | labels = labels.to(device).view(-1)
303 |
304 | # the probs given by the model, without grads
305 | with torch.no_grad():
306 | # the probs given by the model
307 | probs_a, probs_b = model(input_ids, input_types)
308 | mask_a, mask_b = (types==0).numpy(), (types==1).numpy()
309 | output_a, labels_a = probs_a[mask_a], labels[mask_a]
310 | output_b, labels_b = probs_b[mask_b], labels[mask_b]
311 |
312 | if mask_a.sum()==0:
313 | loss = criterion(output_b, labels_b)
314 | elif mask_b.sum()==0:
315 | loss = criterion(output_a, labels_a)
316 | else:
317 | loss = criterion(output_a, labels_a) + criterion(output_b, labels_b)
318 |
319 | gt_a = labels_a.cpu().numpy().tolist()
320 | preds_a = output_a.argmax(axis=1).cpu().numpy().tolist() if len(gt_a)!=0 else []
321 |
322 | gt_b = labels_b.cpu().numpy().tolist()
323 | preds_b = output_b.argmax(axis=1).cpu().numpy().tolist() if len(gt_b)!=0 else []
324 |
325 | total_preds_a += preds_a
326 | total_gt_a += gt_a
327 | total_preds_b += preds_b
328 | total_gt_b += gt_b
329 | total_loss.append(loss.item())
330 |
331 | loss = np.array(total_loss).mean()
332 | acc_a = metrics.accuracy_score(total_gt_a, total_preds_a) if len(total_gt_a)!=0 else 0
333 | f1_a = metrics.f1_score(total_gt_a, total_preds_a, zero_division=0)
334 | if (f1_a == 0):
335 | print("F1_a = 0, checking precision, recall, fscore and support...")
336 | print(metrics.precision_recall_fscore_support(total_gt_a, total_preds_a, zero_division=0))
337 |
338 | acc_b = metrics.accuracy_score(total_gt_b, total_preds_b) if len(total_gt_b)!=0 else 0
339 | f1_b = metrics.f1_score(total_gt_b, total_preds_b, zero_division=0)
340 | if (f1_b == 0):
341 | print("F1_b = 0, checking precision, recall, fscore and support...")
342 | print(metrics.precision_recall_fscore_support(total_gt_b, total_preds_b, zero_division=0))
343 |
344 | # Setting average=None to return class-specific scores
345 | # macro_f1 = metrics.f1_score(total_gt, total_preds, average='macro')
346 | # f1 = metrics.f1_score(total_gt, total_preds)
347 |
348 | # print loss and classification report
349 | print("Loss on dev set: ", loss)
350 | print("F1 on dev set: {:.6f}, f1_a: {:.6f}, f1_b: {:.6f}".format((f1_a+f1_b)/2, f1_a, f1_b))
351 |
352 | # return loss, acc, macro_f1
353 | return loss, acc_a, acc_b, f1_a, f1_b
354 |
355 |
356 | if __name__ == '__main__':
357 | config = Config()
358 | device = config.device
359 | pretrained = config.pretrained
360 | model_type = config.model_type
361 | use_fgm = config.use_fgm
362 |
363 | save_dir = config.save_dir
364 | data_dir = config.data_dir
365 | # whether to shuffle the positions of source and target to augment the data
366 | shuffle_order = config.shuffle_order
367 | # whether to reuse the positive cases of task b for task a (as positives)
368 | # and the negative cases of task a for task b (as negatives)
369 | aug_data = config.aug_data
370 | # method for clipping long sequences, 'head' or 'tail'
371 | clip_method = config.clip_method
372 |
373 | task_type = config.task_type
374 | task_a = ['短短匹配A类', '短长匹配A类', '长长匹配A类']
375 | task_b = ['短短匹配B类', '短长匹配B类', '长长匹配B类']
376 |
377 | # hyper-parameters here
378 | epochs = config.epochs
379 | lr = config.lr
380 | classifier_lr = config.classifier_lr
381 | weight_decay = config.weight_decay
382 | hidden_size = config.hidden_size
383 | train_bs = config.train_bs
384 | eval_bs = config.eval_bs
385 |
386 | print_every = config.print_every
387 | eval_every = config.eval_every
388 |
389 | train_data_dir, dev_data_dir = [], []
390 | # integrate the two tasks into one dataset using task_type = 'ab'
391 | if 'a' in task_type:
392 | for task in task_a:
393 | train_data_dir.append(data_dir + task + '/train.txt')
394 | train_data_dir.append(data_dir + task + '/train_r2.txt')
395 | train_data_dir.append(data_dir + task + '/train_r3.txt')
396 | train_data_dir.append(data_dir + task + '/train_rematch.txt') # training file in rematch
397 | dev_data_dir.append(data_dir + task + '/valid.txt')
398 | dev_data_dir.append(data_dir + task + '/valid_rematch.txt') # valid file in rematch
399 |
400 | if 'b' in task_type:
401 | for task in task_b:
402 | train_data_dir.append(data_dir + task + '/train.txt')
403 | train_data_dir.append(data_dir + task + '/train_r2.txt')
404 | train_data_dir.append(data_dir + task + '/train_r3.txt')
405 | train_data_dir.append(data_dir + task + '/train_rematch.txt') # training file in rematch
406 | dev_data_dir.append(data_dir + task + '/valid.txt')
407 | dev_data_dir.append(data_dir + task + '/valid_rematch.txt') # valid file in rematch
408 |
409 | # toy dataset for testing
410 | if config.load_toy_dataset:
411 | train_data_dir = [
412 | '../data/sohu2021_open_data/短短匹配A类/valid.txt',
413 | '../data/sohu2021_open_data/短短匹配B类/valid.txt',
414 | '../data/sohu2021_open_data/短长匹配A类/valid.txt',
415 | '../data/sohu2021_open_data/短长匹配B类/valid.txt',
416 | '../data/sohu2021_open_data/长长匹配A类/valid.txt',
417 | '../data/sohu2021_open_data/长长匹配B类/valid.txt']
418 | dev_data_dir = ['../data/sohu2021_open_data/短短匹配A类/valid.txt',
419 | '../data/sohu2021_open_data/短短匹配B类/valid.txt',]
420 |
421 |
422 | print("Loading pretrained Model from {}...".format(pretrained))
423 | # integrating SBERT model into a unified training framework
424 | if 'sbert' in model_type.lower():
425 | print("Using SentenceBERT model and dataset")
426 | if 'nezha' in model_type.lower():
427 | model = SNEZHASingleModel(bert_dir=pretrained, hidden_size=hidden_size)
428 | else:
429 | model = SBERTSingleModel(bert_dir=pretrained, hidden_size=hidden_size)
430 | model.to(device)
431 | print("Loading Training Data...")
432 | print(train_data_dir)
433 | # augment the data with shuffle_order=True (changing order of source and target)
434 | # or with aug_data=True (neg cases in A -> B, pos cases in B -> A)
435 | train_dataset = SentencePairDatasetForSBERT(train_data_dir, True, pretrained, shuffle_order, aug_data=aug_data, clip=clip_method)
436 | train_dataloader = DataLoader(train_dataset, batch_size=train_bs, shuffle=True)
437 |
438 | print("Loading Dev Data...")
439 | test_dataset = SentencePairDatasetForSBERT(dev_data_dir, True, pretrained, clip=clip_method)
440 | test_dataloader = DataLoader(test_dataset, batch_size=eval_bs, shuffle=False)
441 |
442 | # 0517: for training, load weights from pretrained with from_pretrained=True (by default)
443 | # for larger model, adjust the hidden_size according to its config
444 | # distinguish model architectures or pretrained models according to model_type
445 | else:
446 | print("Using BERT model and dataset")
447 | if 'nezha' in model_type.lower():
448 | print("Using NEZHA pretrained model")
449 | model = NezhaClassifierSingleModel(bert_dir=pretrained, hidden_size=hidden_size)
450 | elif 'cnn' in model_type.lower():
451 | print("Adding TextCNN after BERT output")
452 | model = BertClassifierTextCNNSingleModel(bert_dir=pretrained, hidden_size=hidden_size)
453 | else:
454 | print("Using conventional BERT model with linears")
455 | model = BertClassifierSingleModel(bert_dir=pretrained, hidden_size=hidden_size)
456 | model.to(device)
457 |
458 | print("Loading Training Data...")
459 | print(train_data_dir)
460 | # augment the data with shuffle_order=True (changing order of source and target)
461 | # or with aug_data=True (neg cases in A -> B, pos cases in B -> A)
462 | train_dataset = SentencePairDatasetWithType(train_data_dir, True, pretrained, shuffle_order, aug_data=aug_data, clip=clip_method)
463 | train_dataloader = DataLoader(train_dataset, batch_size=train_bs, shuffle=True)
464 |
465 | print("Loading Dev Data...")
466 | test_dataset = SentencePairDatasetWithType(dev_data_dir, True, pretrained, clip=clip_method)
467 | test_dataloader = DataLoader(test_dataset, batch_size=eval_bs, shuffle=False)
468 |
469 | optimizer = AdamW(model.parameters(), lr=lr, weight_decay=weight_decay, correct_bias=False)
470 | # 0514 setting different lr for bert encoder and classifier
471 | # TODO: verify the large lr works for classifiers
472 | # optimizer = AdamW([
473 | # {"params": model.all_classifier.parameters(), "lr": classifer_lr},
474 | # {"params": model.bert.parameters()}],
475 | # lr=lr)
476 |
477 | # for p in optimizer.param_groups:
478 | # outputs = ''
479 | # for k, v in p.items():
480 | # if k == 'params':
481 | # outputs += (k + ': ' + str(v[0].shape).ljust(30) + ' ')
482 | # else:
483 | # outputs += (k + ': ' + str(v).ljust(10) + ' ')
484 | # print(outputs)
485 |
486 | total_steps = len(train_dataloader) * epochs
487 |
488 | # TODO: using ReduceLROnPlateau instead of linear scheduler
489 | if config.use_scheduler:
490 | scheduler = get_linear_schedule_with_warmup(
491 | optimizer,
492 | num_training_steps = total_steps,
493 | num_warmup_steps = config.num_warmup_steps,
494 | )
495 | else:
496 | scheduler = None
497 |
498 | print("Training on Task {}...".format(task_type))
499 | writer = SummaryWriter('runs/{}'.format(model_type + '_' + task_type))
500 |
501 | best_dev_loss = 999
502 | best_dev_f1 = 0
503 | for epoch in range(epochs):
504 | train_loss, train_f1, train_f1_a, train_f1_b = train(model, device, epoch, train_dataloader, test_dataloader, \
505 | save_dir, optimizer, scheduler=scheduler, model_type=model_type, \
506 | print_every=print_every, eval_every=eval_every, writer=writer, use_fgm=use_fgm)
507 |
--------------------------------------------------------------------------------
/决赛提交/sohu_matching/src/utils.py:
--------------------------------------------------------------------------------
1 | import torch.nn.functional as F
2 | import torch
3 | import torch.nn as nn
4 | import pandas as pd
5 |
6 | # importing files for RAdam and lookahead
7 | import math
8 | from torch.optim.optimizer import Optimizer, required
9 | from collections import defaultdict
10 | import itertools as it
11 |
12 | # from model import *
13 |
14 | def pad_to_maxlen(input_ids, max_len, pad_value=0):
15 | if len(input_ids) >= max_len:
16 | input_ids = input_ids[:max_len]
17 | else:
18 | input_ids = input_ids + [pad_value] * (max_len-len(input_ids))
19 | return input_ids
20 |
21 |
22 | def augment_data(data):
23 | B2A = pd.DataFrame()
24 | A2B = pd.DataFrame()
25 |
26 | train_A = data[(data['type'] == 0)]
27 | train_B = data[(data['type'] == 1)]
28 |
29 | B2A = B2A.append(
30 | train_B.loc[train_B['label'] == '1'], ignore_index=True)
31 | A2B = A2B.append(
32 | train_A.loc[train_A['label'] == '0'], ignore_index=True)
33 |
34 | train_aug_A = pd.concat([train_A, B2A], axis=0, ignore_index=True)
35 | train_aug_A.drop_duplicates(
36 | subset=['source', 'target'], keep='first', inplace=True, ignore_index=True)
37 |
38 | train_aug_B = pd.concat([train_B, A2B], axis=0, ignore_index=True)
39 | train_aug_B.drop_duplicates(
40 | subset=['source', 'target'], keep='first', inplace=True, ignore_index=True)
41 |
42 | train_all = pd.concat([train_aug_A, train_aug_B], axis=0, ignore_index=True)
43 | return train_all
44 |
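
# A minimal usage sketch for augment_data on an assumed toy frame (not
# repository data): positives of task B are added to the task-A pool, negatives
# of task A to the task-B pool, then (source, target) duplicates are dropped.
# Note: relies on DataFrame.append, i.e. the pandas versions this project uses.
def _augment_data_demo():
    toy = pd.DataFrame({
        'source': ['s1', 's2', 's3', 's4'],
        'target': ['t1', 't2', 't3', 't4'],
        'label':  ['0', '1', '1', '0'],   # string labels, as in the raw files
        'type':   [0, 0, 1, 1],           # 0 -> task A, 1 -> task B
    })
    out = augment_data(toy)
    # 4 original rows + ('s3','t3') reused for A + ('s1','t1') reused for B
    assert len(out) == 6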
45 |
46 | class focal_loss(nn.Module):
47 | def __init__(self, alpha=0.25, gamma=2, num_classes = 2, size_average=False):
48 | """
49 | focal loss: -α(1-yi)**γ * ce_loss(xi, yi)
50 | A step-by-step implementation of the focal loss.
51 | :param alpha: class weight α. A list gives one weight per class; a scalar yields the weights [α, 1-α, 1-α, ...], commonly used to suppress the background class in object detection (RetinaNet uses 0.25)
52 | :param gamma: focusing parameter γ that down-weights easy samples; RetinaNet uses 2
53 | :param num_classes: number of classes
54 | :param size_average: if True, average the loss over samples; otherwise sum it (default)
55 | """
56 | super(focal_loss,self).__init__()
57 | self.size_average = size_average
58 | if isinstance(alpha,list):
59 | assert len(alpha)==num_classes # alpha may be given as a list of size [num_classes] to assign fine-grained per-class weights
60 | # print(" --- Focal_loss alpha = {}, assigning fine-grained per-class weights --- ".format(alpha))
61 | self.alpha = torch.Tensor(alpha)
62 | else:
63 | assert alpha<1 # if alpha is a scalar, down-weight the first class (the background class in object detection)
64 | # print(" --- Focal_loss alpha = {}, down-weighting the background class; intended for detection tasks --- ".format(alpha))
65 | self.alpha = torch.zeros(num_classes)
66 | self.alpha[0] += alpha
67 | self.alpha[1:] += (1-alpha) # alpha ends up as [α, 1-α, 1-α, ...], size: [num_classes]
68 |
69 | self.gamma = gamma
70 |
71 | def forward(self, preds, labels):
72 | """
73 | focal loss computation
74 | :param preds: predicted logits, size [B,N,C] or [B,C] for detection / classification respectively (B: batch size, N: number of boxes, C: number of classes)
75 | :param labels: ground-truth classes, size [B,N] or [B]
76 | :return:
77 | """
78 | # assert preds.dim()==2 and labels.dim()==1
79 | preds = preds.view(-1,preds.size(-1))
80 | self.alpha = self.alpha.to(preds.device)
81 | preds_logsoft = F.log_softmax(preds, dim=1) # log_softmax
82 | preds_softmax = torch.exp(preds_logsoft) # softmax
83 |
84 | preds_softmax = preds_softmax.gather(1,labels.view(-1,1)) # this implements nll_loss (cross-entropy = log_softmax + nll)
85 | preds_logsoft = preds_logsoft.gather(1,labels.view(-1,1))
86 | alpha = self.alpha.gather(0,labels.view(-1)) # gather into a local so self.alpha keeps its [num_classes] shape across calls
87 | loss = -torch.mul(torch.pow((1-preds_softmax), self.gamma), preds_logsoft) # torch.pow((1-preds_softmax), self.gamma) is the (1-pt)**γ modulating factor of focal loss
88 |
89 | loss = torch.mul(alpha, loss.t())
90 | if self.size_average:
91 | loss = loss.mean()
92 | else:
93 | loss = loss.sum()
94 | return loss
95 |
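
# A small usage sketch (toy logits, not from the repository): focal_loss is a
# drop-in replacement for nn.CrossEntropyLoss on [B, C] logits.
def _focal_loss_demo():
    criterion = focal_loss(alpha=0.25, gamma=2, num_classes=2, size_average=True)
    preds = torch.tensor([[2.0, -1.0], [0.1, 0.3]])   # [B, C] logits
    labels = torch.tensor([0, 1])
    loss = criterion(preds, labels)
    # confident, easy samples are down-weighted by the (1 - p_t) ** gamma factor
    print(loss.item())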
96 |
97 | class FGM():
98 | def __init__(self, model):
99 | self.model = model
100 | self.backup = {}
101 |
102 | def attack(self, epsilon=1., emb_name='bert.embeddings.'):
103 | # emb_name should match the name of the embedding parameters in your model
104 | for name, param in self.model.named_parameters():
105 | if param.requires_grad and emb_name in name:
106 | self.backup[name] = param.data.clone()
107 | norm = torch.norm(param.grad)
108 | if norm != 0 and not torch.isnan(norm):
109 | r_at = epsilon * param.grad / norm
110 | param.data.add_(r_at)
111 |
112 | def restore(self, emb_name='bert.embeddings.'):
113 | # emb_name should match the name of the embedding parameters in your model
114 | for name, param in self.model.named_parameters():
115 | if param.requires_grad and emb_name in name:
116 | assert name in self.backup
117 | param.data = self.backup[name]
118 | self.backup = {}
119 |
120 |
121 | class PGD():
122 | def __init__(self, model):
123 | self.model = model
124 | self.emb_backup = {}
125 | self.grad_backup = {}
126 |
127 | def attack(self, epsilon=1., alpha=0.3, emb_name='bert.embeddings.', is_first_attack=False):
128 | # emb_name should match the name of the embedding parameters in your model
129 | for name, param in self.model.named_parameters():
130 | if param.requires_grad and emb_name in name:
131 | if is_first_attack:
132 | self.emb_backup[name] = param.data.clone()
133 | norm = torch.norm(param.grad)
134 | if norm != 0 and not torch.isnan(norm):
135 | r_at = alpha * param.grad / norm
136 | param.data.add_(r_at)
137 | param.data = self.project(name, param.data, epsilon)
138 |
139 | def restore(self, emb_name='bert.embeddings.'):
140 | # emb_name should match the name of the embedding parameters in your model
141 | for name, param in self.model.named_parameters():
142 | if param.requires_grad and emb_name in name:
143 | assert name in self.emb_backup
144 | param.data = self.emb_backup[name]
145 | self.emb_backup = {}
146 |
147 | def project(self, param_name, param_data, epsilon):
148 | r = param_data - self.emb_backup[param_name]
149 | if torch.norm(r) > epsilon:
150 | r = epsilon * r / torch.norm(r)
151 | return self.emb_backup[param_name] + r
152 |
153 | def backup_grad(self):
154 | for name, param in self.model.named_parameters():
155 | if param.requires_grad:
156 | # skip adversarial training for the final bert.pooler layer and the linear1 layer
157 | if 'encoder' in name or 'bert.embeddings.' in name:
158 | self.grad_backup[name] = param.grad.clone()
159 |
160 | def restore_grad(self):
161 | for name, param in self.model.named_parameters():
162 | if param.requires_grad:
163 | if 'encoder' in name or 'bert.embeddings.' in name:
164 | param.grad = self.grad_backup[name]
165 |
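
# A usage sketch for PGD, mirroring the FGM step in train_old.py but with K
# projected attack steps (model, optimizer, criterion and the batch tensors
# are placeholders, not repository names):
def _pgd_training_step(model, optimizer, criterion, input_ids, input_types, labels, K=3):
    probs_a, _ = model(input_ids, input_types)
    criterion(probs_a, labels).backward()        # gradients on the clean batch
    pgd = PGD(model)
    pgd.backup_grad()
    for t in range(K):
        # perturb the embeddings along the gradient, projected into the eps-ball
        pgd.attack(is_first_attack=(t == 0))
        if t != K - 1:
            model.zero_grad()                    # intermediate steps: fresh grads only
        else:
            pgd.restore_grad()                   # last step: accumulate onto clean grads
        adv_probs_a, _ = model(input_ids, input_types)
        criterion(adv_probs_a, labels).backward()
    pgd.restore()                                # restore the original embeddings
    optimizer.step()
    model.zero_grad()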
166 | # Lookahead implementation from https://github.com/lonePatient/lookahead_pytorch/blob/master/optimizer.py
167 | class Lookahead(Optimizer):
168 | def __init__(self, optimizer, alpha=0.5, k=6):
169 |
170 | if not 0.0 <= alpha <= 1.0:
171 | raise ValueError(f'Invalid slow update rate: {alpha}')
172 | if not 1 <= k:
173 | raise ValueError(f'Invalid lookahead steps: {k}')
174 |
175 | self.optimizer = optimizer
176 | self.param_groups = self.optimizer.param_groups
177 | self.alpha = alpha
178 | self.k = k
179 | for group in self.param_groups:
180 | group["step_counter"] = 0
181 |
182 | self.slow_weights = [
183 | [p.clone().detach() for p in group['params']]
184 | for group in self.param_groups]
185 |
186 | for w in it.chain(*self.slow_weights):
187 | w.requires_grad = False
188 | self.state = optimizer.state
189 |
190 | def step(self, closure=None):
191 | loss = None
192 | if closure is not None:
193 | loss = closure()
194 | self.optimizer.step() # do not overwrite the closure's loss; step() returns None without a closure
195 |
196 | for group,slow_weights in zip(self.param_groups,self.slow_weights):
197 | group['step_counter'] += 1
198 | if group['step_counter'] % self.k != 0:
199 | continue
200 | for p,q in zip(group['params'],slow_weights):
201 | if p.grad is None:
202 | continue
203 | q.data.add_(p.data - q.data, alpha=self.alpha )
204 | p.data.copy_(q.data)
205 | return loss
206 |
207 |
208 | class RAdam(Optimizer):
209 |
210 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):
211 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
212 | self.buffer = [[None, None, None] for ind in range(10)]
213 | super(RAdam, self).__init__(params, defaults)
214 |
215 | def __setstate__(self, state):
216 | super(RAdam, self).__setstate__(state)
217 |
218 | def step(self, closure=None):
219 |
220 | loss = None
221 | if closure is not None:
222 | loss = closure()
223 |
224 | for group in self.param_groups:
225 |
226 | for p in group['params']:
227 | if p.grad is None:
228 | continue
229 | grad = p.grad.data.float()
230 | if grad.is_sparse:
231 | raise RuntimeError('RAdam does not support sparse gradients')
232 |
233 | p_data_fp32 = p.data.float()
234 |
235 | state = self.state[p]
236 |
237 | if len(state) == 0:
238 | state['step'] = 0
239 | state['exp_avg'] = torch.zeros_like(p_data_fp32)
240 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
241 | else:
242 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
243 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)
244 |
245 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
246 | beta1, beta2 = group['betas']
247 |
248 | exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value = 1 - beta2)
249 | exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
250 |
251 | state['step'] += 1
252 | buffered = self.buffer[int(state['step'] % 10)]
253 | if state['step'] == buffered[0]:
254 | N_sma, step_size = buffered[1], buffered[2]
255 | else:
256 | buffered[0] = state['step']
257 | beta2_t = beta2 ** state['step']
258 | N_sma_max = 2 / (1 - beta2) - 1
259 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
260 | buffered[1] = N_sma
261 |
262 | # more conservative since it's an approximated value
263 | if N_sma >= 5:
264 | step_size = math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step'])
265 | else:
266 | step_size = 1.0 / (1 - beta1 ** state['step'])
267 | buffered[2] = step_size
268 |
269 | if group['weight_decay'] != 0:
270 | p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr'])
271 |
272 | # more conservative since it's an approximated value
273 | if N_sma >= 5:
274 | denom = exp_avg_sq.sqrt().add_(group['eps'])
275 | p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size * group['lr'])
276 | else:
277 | p_data_fp32.add_(exp_avg, alpha=-step_size * group['lr'])
278 |
279 | p.data.copy_(p_data_fp32)
280 |
281 | return loss
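
# A combined usage sketch (model is a placeholder): RAdam as the inner
# optimizer, wrapped by Lookahead so slow weights sync every k fast steps.
def _lookahead_radam_demo(model):
    base = RAdam(model.parameters(), lr=1e-3, weight_decay=1e-2)
    optimizer = Lookahead(base, alpha=0.5, k=6)
    # then use it like any torch optimizer:
    #   loss.backward(); optimizer.step(); optimizer.zero_grad()
    return optimizer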
--------------------------------------------------------------------------------