├── LICENSE
├── README.md
└── 决赛提交
    ├── Dockerfile
    ├── README.md
    └── sohu_matching
        ├── data
        │   ├── dummy_bert
        │   │   ├── config.json
        │   │   └── vocab.txt
        │   ├── dummy_ernie
        │   │   ├── config.json
        │   │   └── vocab.txt
        │   └── dummy_nezha
        │       ├── config.json
        │       └── vocab.txt
        ├── results
        │   └── rematch
        │       └── merge_final.csv
        └── src
            ├── NEZHA
            │   ├── __pycache__
            │   │   ├── model_nezha.cpython-36.pyc
            │   │   └── nezha_utils.cpython-36.pyc
            │   ├── model_nezha.py
            │   └── nezha_utils.py
            ├── config.py
            ├── data.py
            ├── infer.py
            ├── infer_final.py
            ├── merge_result.py
            ├── model.py
            ├── search_better_merge.py
            ├── train.py
            ├── train_old.py
            └── utils.py

/LICENSE:
--------------------------------------------------------------------------------
1 | GNU GENERAL PUBLIC LICENSE 2 | Version 3, 29 June 2007 3 | 4 | Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/> 5 | Everyone is permitted to copy and distribute verbatim copies 6 | of this license document, but changing it is not allowed. 7 | 8 | Preamble 9 | 10 | The GNU General Public License is a free, copyleft license for 11 | software and other kinds of works. 12 | 13 | The licenses for most software and other practical works are designed 14 | to take away your freedom to share and change the works. By contrast, 15 | the GNU General Public License is intended to guarantee your freedom to 16 | share and change all versions of a program--to make sure it remains free 17 | software for all its users. We, the Free Software Foundation, use the 18 | GNU General Public License for most of our software; it applies also to 19 | any other work released this way by its authors. You can apply it to 20 | your programs, too. 21 | 22 | When we speak of free software, we are referring to freedom, not 23 | price. Our General Public Licenses are designed to make sure that you 24 | have the freedom to distribute copies of free software (and charge for 25 | them if you wish), that you receive source code or can get it if you 26 | want it, that you can change the software or use pieces of it in new 27 | free programs, and that you know you can do these things. 28 | 29 | To protect your rights, we need to prevent others from denying you 30 | these rights or asking you to surrender the rights. Therefore, you have 31 | certain responsibilities if you distribute copies of the software, or if 32 | you modify it: responsibilities to respect the freedom of others. 33 | 34 | For example, if you distribute copies of such a program, whether 35 | gratis or for a fee, you must pass on to the recipients the same 36 | freedoms that you received. You must make sure that they, too, receive 37 | or can get the source code. And you must show them these terms so they 38 | know their rights. 39 | 40 | Developers that use the GNU GPL protect your rights with two steps: 41 | (1) assert copyright on the software, and (2) offer you this License 42 | giving you legal permission to copy, distribute and/or modify it. 43 | 44 | For the developers' and authors' protection, the GPL clearly explains 45 | that there is no warranty for this free software. For both users' and 46 | authors' sake, the GPL requires that modified versions be marked as 47 | changed, so that their problems will not be attributed erroneously to 48 | authors of previous versions. 49 | 50 | Some devices are designed to deny users access to install or run 51 | modified versions of the software inside them, although the manufacturer 52 | can do so. This is fundamentally incompatible with the aim of 53 | protecting users' freedom to change the software.
The systematic 54 | pattern of such abuse occurs in the area of products for individuals to 55 | use, which is precisely where it is most unacceptable. Therefore, we 56 | have designed this version of the GPL to prohibit the practice for those 57 | products. If such problems arise substantially in other domains, we 58 | stand ready to extend this provision to those domains in future versions 59 | of the GPL, as needed to protect the freedom of users. 60 | 61 | Finally, every program is threatened constantly by software patents. 62 | States should not allow patents to restrict development and use of 63 | software on general-purpose computers, but in those that do, we wish to 64 | avoid the special danger that patents applied to a free program could 65 | make it effectively proprietary. To prevent this, the GPL assures that 66 | patents cannot be used to render the program non-free. 67 | 68 | The precise terms and conditions for copying, distribution and 69 | modification follow. 70 | 71 | TERMS AND CONDITIONS 72 | 73 | 0. Definitions. 74 | 75 | "This License" refers to version 3 of the GNU General Public License. 76 | 77 | "Copyright" also means copyright-like laws that apply to other kinds of 78 | works, such as semiconductor masks. 79 | 80 | "The Program" refers to any copyrightable work licensed under this 81 | License. Each licensee is addressed as "you". "Licensees" and 82 | "recipients" may be individuals or organizations. 83 | 84 | To "modify" a work means to copy from or adapt all or part of the work 85 | in a fashion requiring copyright permission, other than the making of an 86 | exact copy. The resulting work is called a "modified version" of the 87 | earlier work or a work "based on" the earlier work. 88 | 89 | A "covered work" means either the unmodified Program or a work based 90 | on the Program. 91 | 92 | To "propagate" a work means to do anything with it that, without 93 | permission, would make you directly or secondarily liable for 94 | infringement under applicable copyright law, except executing it on a 95 | computer or modifying a private copy. Propagation includes copying, 96 | distribution (with or without modification), making available to the 97 | public, and in some countries other activities as well. 98 | 99 | To "convey" a work means any kind of propagation that enables other 100 | parties to make or receive copies. Mere interaction with a user through 101 | a computer network, with no transfer of a copy, is not conveying. 102 | 103 | An interactive user interface displays "Appropriate Legal Notices" 104 | to the extent that it includes a convenient and prominently visible 105 | feature that (1) displays an appropriate copyright notice, and (2) 106 | tells the user that there is no warranty for the work (except to the 107 | extent that warranties are provided), that licensees may convey the 108 | work under this License, and how to view a copy of this License. If 109 | the interface presents a list of user commands or options, such as a 110 | menu, a prominent item in the list meets this criterion. 111 | 112 | 1. Source Code. 113 | 114 | The "source code" for a work means the preferred form of the work 115 | for making modifications to it. "Object code" means any non-source 116 | form of a work. 
117 | 118 | A "Standard Interface" means an interface that either is an official 119 | standard defined by a recognized standards body, or, in the case of 120 | interfaces specified for a particular programming language, one that 121 | is widely used among developers working in that language. 122 | 123 | The "System Libraries" of an executable work include anything, other 124 | than the work as a whole, that (a) is included in the normal form of 125 | packaging a Major Component, but which is not part of that Major 126 | Component, and (b) serves only to enable use of the work with that 127 | Major Component, or to implement a Standard Interface for which an 128 | implementation is available to the public in source code form. A 129 | "Major Component", in this context, means a major essential component 130 | (kernel, window system, and so on) of the specific operating system 131 | (if any) on which the executable work runs, or a compiler used to 132 | produce the work, or an object code interpreter used to run it. 133 | 134 | The "Corresponding Source" for a work in object code form means all 135 | the source code needed to generate, install, and (for an executable 136 | work) run the object code and to modify the work, including scripts to 137 | control those activities. However, it does not include the work's 138 | System Libraries, or general-purpose tools or generally available free 139 | programs which are used unmodified in performing those activities but 140 | which are not part of the work. For example, Corresponding Source 141 | includes interface definition files associated with source files for 142 | the work, and the source code for shared libraries and dynamically 143 | linked subprograms that the work is specifically designed to require, 144 | such as by intimate data communication or control flow between those 145 | subprograms and other parts of the work. 146 | 147 | The Corresponding Source need not include anything that users 148 | can regenerate automatically from other parts of the Corresponding 149 | Source. 150 | 151 | The Corresponding Source for a work in source code form is that 152 | same work. 153 | 154 | 2. Basic Permissions. 155 | 156 | All rights granted under this License are granted for the term of 157 | copyright on the Program, and are irrevocable provided the stated 158 | conditions are met. This License explicitly affirms your unlimited 159 | permission to run the unmodified Program. The output from running a 160 | covered work is covered by this License only if the output, given its 161 | content, constitutes a covered work. This License acknowledges your 162 | rights of fair use or other equivalent, as provided by copyright law. 163 | 164 | You may make, run and propagate covered works that you do not 165 | convey, without conditions so long as your license otherwise remains 166 | in force. You may convey covered works to others for the sole purpose 167 | of having them make modifications exclusively for you, or provide you 168 | with facilities for running those works, provided that you comply with 169 | the terms of this License in conveying all material for which you do 170 | not control copyright. Those thus making or running the covered works 171 | for you must do so exclusively on your behalf, under your direction 172 | and control, on terms that prohibit them from making any copies of 173 | your copyrighted material outside their relationship with you. 
174 | 175 | Conveying under any other circumstances is permitted solely under 176 | the conditions stated below. Sublicensing is not allowed; section 10 177 | makes it unnecessary. 178 | 179 | 3. Protecting Users' Legal Rights From Anti-Circumvention Law. 180 | 181 | No covered work shall be deemed part of an effective technological 182 | measure under any applicable law fulfilling obligations under article 183 | 11 of the WIPO copyright treaty adopted on 20 December 1996, or 184 | similar laws prohibiting or restricting circumvention of such 185 | measures. 186 | 187 | When you convey a covered work, you waive any legal power to forbid 188 | circumvention of technological measures to the extent such circumvention 189 | is effected by exercising rights under this License with respect to 190 | the covered work, and you disclaim any intention to limit operation or 191 | modification of the work as a means of enforcing, against the work's 192 | users, your or third parties' legal rights to forbid circumvention of 193 | technological measures. 194 | 195 | 4. Conveying Verbatim Copies. 196 | 197 | You may convey verbatim copies of the Program's source code as you 198 | receive it, in any medium, provided that you conspicuously and 199 | appropriately publish on each copy an appropriate copyright notice; 200 | keep intact all notices stating that this License and any 201 | non-permissive terms added in accord with section 7 apply to the code; 202 | keep intact all notices of the absence of any warranty; and give all 203 | recipients a copy of this License along with the Program. 204 | 205 | You may charge any price or no price for each copy that you convey, 206 | and you may offer support or warranty protection for a fee. 207 | 208 | 5. Conveying Modified Source Versions. 209 | 210 | You may convey a work based on the Program, or the modifications to 211 | produce it from the Program, in the form of source code under the 212 | terms of section 4, provided that you also meet all of these conditions: 213 | 214 | a) The work must carry prominent notices stating that you modified 215 | it, and giving a relevant date. 216 | 217 | b) The work must carry prominent notices stating that it is 218 | released under this License and any conditions added under section 219 | 7. This requirement modifies the requirement in section 4 to 220 | "keep intact all notices". 221 | 222 | c) You must license the entire work, as a whole, under this 223 | License to anyone who comes into possession of a copy. This 224 | License will therefore apply, along with any applicable section 7 225 | additional terms, to the whole of the work, and all its parts, 226 | regardless of how they are packaged. This License gives no 227 | permission to license the work in any other way, but it does not 228 | invalidate such permission if you have separately received it. 229 | 230 | d) If the work has interactive user interfaces, each must display 231 | Appropriate Legal Notices; however, if the Program has interactive 232 | interfaces that do not display Appropriate Legal Notices, your 233 | work need not make them do so. 
234 | 235 | A compilation of a covered work with other separate and independent 236 | works, which are not by their nature extensions of the covered work, 237 | and which are not combined with it such as to form a larger program, 238 | in or on a volume of a storage or distribution medium, is called an 239 | "aggregate" if the compilation and its resulting copyright are not 240 | used to limit the access or legal rights of the compilation's users 241 | beyond what the individual works permit. Inclusion of a covered work 242 | in an aggregate does not cause this License to apply to the other 243 | parts of the aggregate. 244 | 245 | 6. Conveying Non-Source Forms. 246 | 247 | You may convey a covered work in object code form under the terms 248 | of sections 4 and 5, provided that you also convey the 249 | machine-readable Corresponding Source under the terms of this License, 250 | in one of these ways: 251 | 252 | a) Convey the object code in, or embodied in, a physical product 253 | (including a physical distribution medium), accompanied by the 254 | Corresponding Source fixed on a durable physical medium 255 | customarily used for software interchange. 256 | 257 | b) Convey the object code in, or embodied in, a physical product 258 | (including a physical distribution medium), accompanied by a 259 | written offer, valid for at least three years and valid for as 260 | long as you offer spare parts or customer support for that product 261 | model, to give anyone who possesses the object code either (1) a 262 | copy of the Corresponding Source for all the software in the 263 | product that is covered by this License, on a durable physical 264 | medium customarily used for software interchange, for a price no 265 | more than your reasonable cost of physically performing this 266 | conveying of source, or (2) access to copy the 267 | Corresponding Source from a network server at no charge. 268 | 269 | c) Convey individual copies of the object code with a copy of the 270 | written offer to provide the Corresponding Source. This 271 | alternative is allowed only occasionally and noncommercially, and 272 | only if you received the object code with such an offer, in accord 273 | with subsection 6b. 274 | 275 | d) Convey the object code by offering access from a designated 276 | place (gratis or for a charge), and offer equivalent access to the 277 | Corresponding Source in the same way through the same place at no 278 | further charge. You need not require recipients to copy the 279 | Corresponding Source along with the object code. If the place to 280 | copy the object code is a network server, the Corresponding Source 281 | may be on a different server (operated by you or a third party) 282 | that supports equivalent copying facilities, provided you maintain 283 | clear directions next to the object code saying where to find the 284 | Corresponding Source. Regardless of what server hosts the 285 | Corresponding Source, you remain obligated to ensure that it is 286 | available for as long as needed to satisfy these requirements. 287 | 288 | e) Convey the object code using peer-to-peer transmission, provided 289 | you inform other peers where the object code and Corresponding 290 | Source of the work are being offered to the general public at no 291 | charge under subsection 6d. 292 | 293 | A separable portion of the object code, whose source code is excluded 294 | from the Corresponding Source as a System Library, need not be 295 | included in conveying the object code work. 
296 | 297 | A "User Product" is either (1) a "consumer product", which means any 298 | tangible personal property which is normally used for personal, family, 299 | or household purposes, or (2) anything designed or sold for incorporation 300 | into a dwelling. In determining whether a product is a consumer product, 301 | doubtful cases shall be resolved in favor of coverage. For a particular 302 | product received by a particular user, "normally used" refers to a 303 | typical or common use of that class of product, regardless of the status 304 | of the particular user or of the way in which the particular user 305 | actually uses, or expects or is expected to use, the product. A product 306 | is a consumer product regardless of whether the product has substantial 307 | commercial, industrial or non-consumer uses, unless such uses represent 308 | the only significant mode of use of the product. 309 | 310 | "Installation Information" for a User Product means any methods, 311 | procedures, authorization keys, or other information required to install 312 | and execute modified versions of a covered work in that User Product from 313 | a modified version of its Corresponding Source. The information must 314 | suffice to ensure that the continued functioning of the modified object 315 | code is in no case prevented or interfered with solely because 316 | modification has been made. 317 | 318 | If you convey an object code work under this section in, or with, or 319 | specifically for use in, a User Product, and the conveying occurs as 320 | part of a transaction in which the right of possession and use of the 321 | User Product is transferred to the recipient in perpetuity or for a 322 | fixed term (regardless of how the transaction is characterized), the 323 | Corresponding Source conveyed under this section must be accompanied 324 | by the Installation Information. But this requirement does not apply 325 | if neither you nor any third party retains the ability to install 326 | modified object code on the User Product (for example, the work has 327 | been installed in ROM). 328 | 329 | The requirement to provide Installation Information does not include a 330 | requirement to continue to provide support service, warranty, or updates 331 | for a work that has been modified or installed by the recipient, or for 332 | the User Product in which it has been modified or installed. Access to a 333 | network may be denied when the modification itself materially and 334 | adversely affects the operation of the network or violates the rules and 335 | protocols for communication across the network. 336 | 337 | Corresponding Source conveyed, and Installation Information provided, 338 | in accord with this section must be in a format that is publicly 339 | documented (and with an implementation available to the public in 340 | source code form), and must require no special password or key for 341 | unpacking, reading or copying. 342 | 343 | 7. Additional Terms. 344 | 345 | "Additional permissions" are terms that supplement the terms of this 346 | License by making exceptions from one or more of its conditions. 347 | Additional permissions that are applicable to the entire Program shall 348 | be treated as though they were included in this License, to the extent 349 | that they are valid under applicable law. 
If additional permissions 350 | apply only to part of the Program, that part may be used separately 351 | under those permissions, but the entire Program remains governed by 352 | this License without regard to the additional permissions. 353 | 354 | When you convey a copy of a covered work, you may at your option 355 | remove any additional permissions from that copy, or from any part of 356 | it. (Additional permissions may be written to require their own 357 | removal in certain cases when you modify the work.) You may place 358 | additional permissions on material, added by you to a covered work, 359 | for which you have or can give appropriate copyright permission. 360 | 361 | Notwithstanding any other provision of this License, for material you 362 | add to a covered work, you may (if authorized by the copyright holders of 363 | that material) supplement the terms of this License with terms: 364 | 365 | a) Disclaiming warranty or limiting liability differently from the 366 | terms of sections 15 and 16 of this License; or 367 | 368 | b) Requiring preservation of specified reasonable legal notices or 369 | author attributions in that material or in the Appropriate Legal 370 | Notices displayed by works containing it; or 371 | 372 | c) Prohibiting misrepresentation of the origin of that material, or 373 | requiring that modified versions of such material be marked in 374 | reasonable ways as different from the original version; or 375 | 376 | d) Limiting the use for publicity purposes of names of licensors or 377 | authors of the material; or 378 | 379 | e) Declining to grant rights under trademark law for use of some 380 | trade names, trademarks, or service marks; or 381 | 382 | f) Requiring indemnification of licensors and authors of that 383 | material by anyone who conveys the material (or modified versions of 384 | it) with contractual assumptions of liability to the recipient, for 385 | any liability that these contractual assumptions directly impose on 386 | those licensors and authors. 387 | 388 | All other non-permissive additional terms are considered "further 389 | restrictions" within the meaning of section 10. If the Program as you 390 | received it, or any part of it, contains a notice stating that it is 391 | governed by this License along with a term that is a further 392 | restriction, you may remove that term. If a license document contains 393 | a further restriction but permits relicensing or conveying under this 394 | License, you may add to a covered work material governed by the terms 395 | of that license document, provided that the further restriction does 396 | not survive such relicensing or conveying. 397 | 398 | If you add terms to a covered work in accord with this section, you 399 | must place, in the relevant source files, a statement of the 400 | additional terms that apply to those files, or a notice indicating 401 | where to find the applicable terms. 402 | 403 | Additional terms, permissive or non-permissive, may be stated in the 404 | form of a separately written license, or stated as exceptions; 405 | the above requirements apply either way. 406 | 407 | 8. Termination. 408 | 409 | You may not propagate or modify a covered work except as expressly 410 | provided under this License. Any attempt otherwise to propagate or 411 | modify it is void, and will automatically terminate your rights under 412 | this License (including any patent licenses granted under the third 413 | paragraph of section 11). 
414 | 415 | However, if you cease all violation of this License, then your 416 | license from a particular copyright holder is reinstated (a) 417 | provisionally, unless and until the copyright holder explicitly and 418 | finally terminates your license, and (b) permanently, if the copyright 419 | holder fails to notify you of the violation by some reasonable means 420 | prior to 60 days after the cessation. 421 | 422 | Moreover, your license from a particular copyright holder is 423 | reinstated permanently if the copyright holder notifies you of the 424 | violation by some reasonable means, this is the first time you have 425 | received notice of violation of this License (for any work) from that 426 | copyright holder, and you cure the violation prior to 30 days after 427 | your receipt of the notice. 428 | 429 | Termination of your rights under this section does not terminate the 430 | licenses of parties who have received copies or rights from you under 431 | this License. If your rights have been terminated and not permanently 432 | reinstated, you do not qualify to receive new licenses for the same 433 | material under section 10. 434 | 435 | 9. Acceptance Not Required for Having Copies. 436 | 437 | You are not required to accept this License in order to receive or 438 | run a copy of the Program. Ancillary propagation of a covered work 439 | occurring solely as a consequence of using peer-to-peer transmission 440 | to receive a copy likewise does not require acceptance. However, 441 | nothing other than this License grants you permission to propagate or 442 | modify any covered work. These actions infringe copyright if you do 443 | not accept this License. Therefore, by modifying or propagating a 444 | covered work, you indicate your acceptance of this License to do so. 445 | 446 | 10. Automatic Licensing of Downstream Recipients. 447 | 448 | Each time you convey a covered work, the recipient automatically 449 | receives a license from the original licensors, to run, modify and 450 | propagate that work, subject to this License. You are not responsible 451 | for enforcing compliance by third parties with this License. 452 | 453 | An "entity transaction" is a transaction transferring control of an 454 | organization, or substantially all assets of one, or subdividing an 455 | organization, or merging organizations. If propagation of a covered 456 | work results from an entity transaction, each party to that 457 | transaction who receives a copy of the work also receives whatever 458 | licenses to the work the party's predecessor in interest had or could 459 | give under the previous paragraph, plus a right to possession of the 460 | Corresponding Source of the work from the predecessor in interest, if 461 | the predecessor has it or can get it with reasonable efforts. 462 | 463 | You may not impose any further restrictions on the exercise of the 464 | rights granted or affirmed under this License. For example, you may 465 | not impose a license fee, royalty, or other charge for exercise of 466 | rights granted under this License, and you may not initiate litigation 467 | (including a cross-claim or counterclaim in a lawsuit) alleging that 468 | any patent claim is infringed by making, using, selling, offering for 469 | sale, or importing the Program or any portion of it. 470 | 471 | 11. Patents. 472 | 473 | A "contributor" is a copyright holder who authorizes use under this 474 | License of the Program or a work on which the Program is based. 
The 475 | work thus licensed is called the contributor's "contributor version". 476 | 477 | A contributor's "essential patent claims" are all patent claims 478 | owned or controlled by the contributor, whether already acquired or 479 | hereafter acquired, that would be infringed by some manner, permitted 480 | by this License, of making, using, or selling its contributor version, 481 | but do not include claims that would be infringed only as a 482 | consequence of further modification of the contributor version. For 483 | purposes of this definition, "control" includes the right to grant 484 | patent sublicenses in a manner consistent with the requirements of 485 | this License. 486 | 487 | Each contributor grants you a non-exclusive, worldwide, royalty-free 488 | patent license under the contributor's essential patent claims, to 489 | make, use, sell, offer for sale, import and otherwise run, modify and 490 | propagate the contents of its contributor version. 491 | 492 | In the following three paragraphs, a "patent license" is any express 493 | agreement or commitment, however denominated, not to enforce a patent 494 | (such as an express permission to practice a patent or covenant not to 495 | sue for patent infringement). To "grant" such a patent license to a 496 | party means to make such an agreement or commitment not to enforce a 497 | patent against the party. 498 | 499 | If you convey a covered work, knowingly relying on a patent license, 500 | and the Corresponding Source of the work is not available for anyone 501 | to copy, free of charge and under the terms of this License, through a 502 | publicly available network server or other readily accessible means, 503 | then you must either (1) cause the Corresponding Source to be so 504 | available, or (2) arrange to deprive yourself of the benefit of the 505 | patent license for this particular work, or (3) arrange, in a manner 506 | consistent with the requirements of this License, to extend the patent 507 | license to downstream recipients. "Knowingly relying" means you have 508 | actual knowledge that, but for the patent license, your conveying the 509 | covered work in a country, or your recipient's use of the covered work 510 | in a country, would infringe one or more identifiable patents in that 511 | country that you have reason to believe are valid. 512 | 513 | If, pursuant to or in connection with a single transaction or 514 | arrangement, you convey, or propagate by procuring conveyance of, a 515 | covered work, and grant a patent license to some of the parties 516 | receiving the covered work authorizing them to use, propagate, modify 517 | or convey a specific copy of the covered work, then the patent license 518 | you grant is automatically extended to all recipients of the covered 519 | work and works based on it. 520 | 521 | A patent license is "discriminatory" if it does not include within 522 | the scope of its coverage, prohibits the exercise of, or is 523 | conditioned on the non-exercise of one or more of the rights that are 524 | specifically granted under this License. 
You may not convey a covered 525 | work if you are a party to an arrangement with a third party that is 526 | in the business of distributing software, under which you make payment 527 | to the third party based on the extent of your activity of conveying 528 | the work, and under which the third party grants, to any of the 529 | parties who would receive the covered work from you, a discriminatory 530 | patent license (a) in connection with copies of the covered work 531 | conveyed by you (or copies made from those copies), or (b) primarily 532 | for and in connection with specific products or compilations that 533 | contain the covered work, unless you entered into that arrangement, 534 | or that patent license was granted, prior to 28 March 2007. 535 | 536 | Nothing in this License shall be construed as excluding or limiting 537 | any implied license or other defenses to infringement that may 538 | otherwise be available to you under applicable patent law. 539 | 540 | 12. No Surrender of Others' Freedom. 541 | 542 | If conditions are imposed on you (whether by court order, agreement or 543 | otherwise) that contradict the conditions of this License, they do not 544 | excuse you from the conditions of this License. If you cannot convey a 545 | covered work so as to satisfy simultaneously your obligations under this 546 | License and any other pertinent obligations, then as a consequence you may 547 | not convey it at all. For example, if you agree to terms that obligate you 548 | to collect a royalty for further conveying from those to whom you convey 549 | the Program, the only way you could satisfy both those terms and this 550 | License would be to refrain entirely from conveying the Program. 551 | 552 | 13. Use with the GNU Affero General Public License. 553 | 554 | Notwithstanding any other provision of this License, you have 555 | permission to link or combine any covered work with a work licensed 556 | under version 3 of the GNU Affero General Public License into a single 557 | combined work, and to convey the resulting work. The terms of this 558 | License will continue to apply to the part which is the covered work, 559 | but the special requirements of the GNU Affero General Public License, 560 | section 13, concerning interaction through a network will apply to the 561 | combination as such. 562 | 563 | 14. Revised Versions of this License. 564 | 565 | The Free Software Foundation may publish revised and/or new versions of 566 | the GNU General Public License from time to time. Such new versions will 567 | be similar in spirit to the present version, but may differ in detail to 568 | address new problems or concerns. 569 | 570 | Each version is given a distinguishing version number. If the 571 | Program specifies that a certain numbered version of the GNU General 572 | Public License "or any later version" applies to it, you have the 573 | option of following the terms and conditions either of that numbered 574 | version or of any later version published by the Free Software 575 | Foundation. If the Program does not specify a version number of the 576 | GNU General Public License, you may choose any version ever published 577 | by the Free Software Foundation. 578 | 579 | If the Program specifies that a proxy can decide which future 580 | versions of the GNU General Public License can be used, that proxy's 581 | public statement of acceptance of a version permanently authorizes you 582 | to choose that version for the Program. 
583 | 584 | Later license versions may give you additional or different 585 | permissions. However, no additional obligations are imposed on any 586 | author or copyright holder as a result of your choosing to follow a 587 | later version. 588 | 589 | 15. Disclaimer of Warranty. 590 | 591 | THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY 592 | APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT 593 | HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY 594 | OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, 595 | THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 596 | PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM 597 | IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF 598 | ALL NECESSARY SERVICING, REPAIR OR CORRECTION. 599 | 600 | 16. Limitation of Liability. 601 | 602 | IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING 603 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS 604 | THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY 605 | GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE 606 | USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF 607 | DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD 608 | PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), 609 | EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF 610 | SUCH DAMAGES. 611 | 612 | 17. Interpretation of Sections 15 and 16. 613 | 614 | If the disclaimer of warranty and limitation of liability provided 615 | above cannot be given local legal effect according to their terms, 616 | reviewing courts shall apply local law that most closely approximates 617 | an absolute waiver of all civil liability in connection with the 618 | Program, unless a warranty or assumption of liability accompanies a 619 | copy of the Program in return for a fee. 620 | 621 | END OF TERMS AND CONDITIONS 622 | 623 | How to Apply These Terms to Your New Programs 624 | 625 | If you develop a new program, and you want it to be of the greatest 626 | possible use to the public, the best way to achieve this is to make it 627 | free software which everyone can redistribute and change under these terms. 628 | 629 | To do so, attach the following notices to the program. It is safest 630 | to attach them to the start of each source file to most effectively 631 | state the exclusion of warranty; and each file should have at least 632 | the "copyright" line and a pointer to where the full notice is found. 633 | 634 | <one line to give the program's name and a brief idea of what it does.> 635 | Copyright (C) <year> <name of author> 636 | 637 | This program is free software: you can redistribute it and/or modify 638 | it under the terms of the GNU General Public License as published by 639 | the Free Software Foundation, either version 3 of the License, or 640 | (at your option) any later version. 641 | 642 | This program is distributed in the hope that it will be useful, 643 | but WITHOUT ANY WARRANTY; without even the implied warranty of 644 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 645 | GNU General Public License for more details. 646 | 647 | You should have received a copy of the GNU General Public License 648 | along with this program. If not, see <https://www.gnu.org/licenses/>. 649 | 650 | Also add information on how to contact you by electronic and paper mail.
651 | 652 | If the program does terminal interaction, make it output a short 653 | notice like this when it starts in an interactive mode: 654 | 655 | <program> Copyright (C) <year> <name of author> 656 | This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'. 657 | This is free software, and you are welcome to redistribute it 658 | under certain conditions; type `show c' for details. 659 | 660 | The hypothetical commands `show w' and `show c' should show the appropriate 661 | parts of the General Public License. Of course, your program's commands 662 | might be different; for a GUI interface, you would use an "about box". 663 | 664 | You should also get your employer (if you work as a programmer) or school, 665 | if any, to sign a "copyright disclaimer" for the program, if necessary. 666 | For more information on this, and how to apply and follow the GNU GPL, see 667 | <https://www.gnu.org/licenses/>. 668 | 669 | The GNU General Public License does not permit incorporating your program 670 | into proprietary programs. If your program is a subroutine library, you 671 | may consider it more useful to permit linking proprietary applications with 672 | the library. If this is what you want to do, use the GNU Lesser General 673 | Public License instead of this License. But first, please read 674 | <https://www.gnu.org/licenses/why-not-lgpl.html>. 675 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------

# sohu_text_matching

Top-2 finish in the 2021 Sohu Campus Text Matching Algorithm Competition, by team 分比我们低的都是帅哥 (roughly, "anyone who scores lower than us is handsome").

This repo contains the code files submitted in the final round of the competition. The submitted model files can be obtained from Baidu Netdisk (link: https://pan.baidu.com/s/1T9FtwiGFZhuC8qqwXKZSNA, extraction code: 2333).

The five models in the final submission (limited to 2 GB in total) score an F1 of 0.78921 on the semi-final test set and 0.78123 on the final test set, ranking second among the ten finalist teams and taking the runner-up prize.

To reproduce the semi-final test-set results, download the models into `checkpoints/rematch`, merge the test sets into the final-round format, then go to the `src` folder and run `infer_final.py`, specifying the input file and output location. See the Dockerfile for the dependencies.
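A minimal reproduction sketch of the steps above. The `--input`/`--output` flags match the Docker entrypoint shown in the next section; the concrete file paths are only placeholders for wherever your merged test set and desired output live:

```bash
# Assumes the downloaded models sit in sohu_matching/checkpoints/rematch
# and the test sets have been merged into the final-round format (paths illustrative).
cd 决赛提交/sohu_matching/src
python infer_final.py --input ../data/test_final.txt --output ../results/rematch/pred.csv
```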
--------------------------------------------------------------------------------
/决赛提交/Dockerfile:
--------------------------------------------------------------------------------
FROM pytorch/pytorch:1.6.0-cuda10.1-cudnn7-devel

# prepare your environment here
ENV LANG "en_US.UTF-8"
COPY . /app
WORKDIR /app/sohu_matching/src

# RUN pip install ...
RUN pip install transformers && pip install pandas && pip install scikit-learn

ENTRYPOINT ["python","infer_final.py"]
--------------------------------------------------------------------------------
/决赛提交/README.md:
--------------------------------------------------------------------------------

### sohu_matching

#### Team: **分比我们低的都是帅哥**

#### Final-round Docker instructions

The Docker build for this project follows the submission guidelines; running the official test command performs inference:

```bash
docker run --rm -it --gpus all \
    -v ${TestInputDir}:/data/input \
    -v ${TestOutputDir}:/data/output \
    ${MyImageName} \
    --input /data/input/test.txt \
    --output /data/output/pred.csv
```

The base image is `pytorch/pytorch:1.6.0-cuda10.1-cudnn7-devel`; installing the dependencies at build time via `pip install transformers && pip install pandas && pip install scikit-learn` is all that is needed. The container's WORKDIR is set to `/app/sohu_matching/src` (the code uses relative paths, so running from any other directory will fail; if the test command errors, enter that directory and run `python infer_final.py` directly, specifying the input and output file locations). The image is about 10 GB; our models are stored in `sohu_matching/checkpoints/rematch` and total under 2 GB, within the competition limit.

#### Overview

This project contains our PyTorch code for the **semi-final round** of the 2021 Sohu Campus Text Matching competition. It ranked third on the semi-final public leaderboard with an online F1 score of 0.791075301658579 (0.8548399419359769 on type-A tasks and 0.727310661381181 on type-B tasks).

We trained the two task types jointly: a single encoder based on a pretrained language model is shared between the A and B tasks, followed by simple fully connected classifiers, one set per task (a minimal sketch of this layout appears at the end of this README). We used different pretrained models (NEZHA, MacBERT, RoBERTa, ERNIE, etc.), implemented two text-matching routes (concatenating source and target with [SEP] as a single input, and SBERT-style sentence-vector encoding followed by comparison), and tried a range of score-boosting strategies (continued MLM pretraining on the competition corpus, a focal-loss objective, different pooling strategies, adding TextCNN, FGM adversarial training, data augmentation, etc.). Our best score comes from ensembling the outputs of several markedly different models by voting.

#### Project structure

```bash
│  README.md                 # this README
│  test.yaml                 # conda environment spec
│                            # (installing pytorch>=1.6 and transformers is basically enough to reproduce)
├─checkpoints                # saved model checkpoints
├─data
│  └─dummy_bert              # tokenizer vocab and config.json for BERT / ERNIE / NEZHA;
│  └─dummy_ernie             # used at inference time to define the model from the config file
│  └─dummy_nezha             # without loading the original pretrained weights
│  └─sohu2021_open_data      # train, valid and test data from the preliminary and semi-final rounds
│    ├─短短匹配A类           # each folder contains train.txt, train_r2.txt, train_r3.txt, train_rematch.txt,
│    ├─短短匹配B类           # valid.txt, valid_rematch.txt, test_with_id_rematch.txt
│    ├─短长匹配A类
│    ├─短长匹配B类
│    ├─长长匹配A类
│    └─长长匹配B类
├─logs                       # training logs, e.g. python train.py > log_dir
├─results                    # test-set inference results
├─valid_output               # model outputs on the validation set, used to compute per-type F1
└─src                        # main source folder
   │  config.py              # model and training parameters are all set in config.py
   │  data.py                # data reading, DataLoader, etc.
   │  infer.py               # test-set inference
   │  merge_result.py        # voting ensemble
   │  model.py               # model definitions
   │  search_better_merge.py # searches validation outputs for the best voting combination
   │  train.py               # training code; supports the multi-task setting (change num_task in model)
   │  train_old.py           # training code for the two-task (A/B) setting only; our main trainer in the semi-final
   │  utils.py               # miscellaneous helpers
   │
   ├─new_runs                # tensorboard event files, for visualizing losses and other metrics
   ├─NEZHA                   # NEZHA model definition and utilities
   │  │  model_nezha.py
   │  │  nezha_utils.py
   └─__pycache__
```

#### Running example

(Note: for the final submission, `data.py` was slightly modified to handle A- and B-type test samples arriving in a single file, so running `train_old.py` directly may raise errors.)

After adding the training data, set the training parameters in `config.py`, go to the `src` folder, and run `train_old.py` to train. (In the semi-final we tried giving each of the six sub-tasks its own classification network, unified in `train.py`, but for the two-task A/B setting the preliminary-round training scheme seemed to work better, so we kept that approach in `train_old.py` and used it as the main training script. Multi-GPU training is the default; adjust the number of devices in `train_old.py`.) The output can be redirected to a log file. After training, set the inference parameters in `config.py`, go to the `src` folder, and run `infer.py` for inference (multi-GPU inference by default; adjust the number of devices in `infer.py`).

```bash
python train_old.py > ../logs/0523/0523_roberta_80k.log  # train and save the log
python infer.py                                          # inference
```
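To make the shared-encoder design concrete, here is a minimal sketch, not the actual `model.py` (whose class and parameter names may differ), of a joint A/B matcher. It also shows the intended role of the dummy configs: at inference time the encoder is built from `config.json` alone, randomly initialized with nothing downloaded, and the fine-tuned weights are then loaded from a checkpoint (the checkpoint filename below is illustrative).

```python
import torch
import torch.nn as nn
from transformers import BertConfig, BertModel

class JointMatcher(nn.Module):
    """Shared BERT-style encoder with one classification head per task type."""

    def __init__(self, config_path='../data/dummy_bert/config.json', num_labels=2):
        super().__init__()
        config = BertConfig.from_json_file(config_path)
        self.encoder = BertModel(config)  # built from config only, no pretrained weights
        self.head_a = nn.Linear(config.hidden_size, num_labels)  # classifier for A-type tasks
        self.head_b = nn.Linear(config.hidden_size, num_labels)  # classifier for B-type tasks

    def forward(self, input_ids, attention_mask, token_type_ids, task='a'):
        # source and target are packed into one sequence separated by [SEP]
        pooled = self.encoder(input_ids,
                              attention_mask=attention_mask,
                              token_type_ids=token_type_ids).pooler_output
        return self.head_a(pooled) if task == 'a' else self.head_b(pooled)

# model = JointMatcher()
# model.load_state_dict(torch.load('../checkpoints/rematch/some_model.pth', map_location='cpu'))
```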
| "directionality": "bidi", 7 | "hidden_act": "gelu", 8 | "hidden_dropout_prob": 0.1, 9 | "hidden_size": 768, 10 | "initializer_range": 0.02, 11 | "intermediate_size": 3072, 12 | "layer_norm_eps": 1e-12, 13 | "max_position_embeddings": 512, 14 | "model_type": "bert", 15 | "num_attention_heads": 12, 16 | "num_hidden_layers": 12, 17 | "output_past": true, 18 | "pad_token_id": 0, 19 | "pooler_fc_size": 768, 20 | "pooler_num_attention_heads": 12, 21 | "pooler_num_fc_layers": 3, 22 | "pooler_size_per_head": 128, 23 | "pooler_type": "first_token_transform", 24 | "type_vocab_size": 2, 25 | "vocab_size": 21128 26 | } 27 | -------------------------------------------------------------------------------- /决赛提交/sohu_matching/data/dummy_ernie/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "attention_probs_dropout_prob": 0.1, 3 | "gradient_checkpointing": false, 4 | "hidden_act": "relu", 5 | "hidden_dropout_prob": 0.1, 6 | "hidden_size": 768, 7 | "initializer_range": 0.02, 8 | "intermediate_size": 3072, 9 | "layer_norm_eps": 1e-05, 10 | "max_position_embeddings": 513, 11 | "model_type": "bert", 12 | "num_attention_heads": 12, 13 | "num_hidden_layers": 12, 14 | "pad_token_id": 0, 15 | "type_vocab_size": 2, 16 | "vocab_size": 18000 17 | } 18 | -------------------------------------------------------------------------------- /决赛提交/sohu_matching/data/dummy_nezha/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "attention_probs_dropout_prob": 0.1, 3 | "hidden_act": "gelu", 4 | "hidden_dropout_prob": 0.1, 5 | "hidden_size": 768, 6 | "initializer_range": 0.02, 7 | "intermediate_size": 3072, 8 | "max_position_embeddings": 512, 9 | "num_attention_heads": 12, 10 | "num_hidden_layers": 12, 11 | "type_vocab_size": 2, 12 | "vocab_size": 21128, 13 | "use_relative_position": true, 14 | "model_type": "bert" 15 | } 16 | -------------------------------------------------------------------------------- /决赛提交/sohu_matching/src/NEZHA/__pycache__/model_nezha.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Decem-Y/sohu_text_matching_Rank2/4d87c85b6de65fda777b15f1d7e37af74a3033b9/决赛提交/sohu_matching/src/NEZHA/__pycache__/model_nezha.cpython-36.pyc -------------------------------------------------------------------------------- /决赛提交/sohu_matching/src/NEZHA/__pycache__/nezha_utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Decem-Y/sohu_text_matching_Rank2/4d87c85b6de65fda777b15f1d7e37af74a3033b9/决赛提交/sohu_matching/src/NEZHA/__pycache__/nezha_utils.cpython-36.pyc -------------------------------------------------------------------------------- /决赛提交/sohu_matching/src/NEZHA/model_nezha.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors, The HuggingFace Inc. team and Huawei Noah's Ark Lab. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 
7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | from __future__ import absolute_import, division, print_function, unicode_literals 18 | 19 | import copy 20 | import json 21 | import logging 22 | import math 23 | import sys 24 | from io import open 25 | 26 | import numpy as np 27 | 28 | import torch 29 | from torch import nn 30 | from torch.nn import CrossEntropyLoss 31 | 32 | logger = logging.getLogger(__name__) 33 | 34 | 35 | def gelu(x): 36 | """Implementation of the gelu activation function. 37 | For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): 38 | 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) 39 | Also see https://arxiv.org/abs/1606.08415 40 | """ 41 | return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) 42 | 43 | 44 | def swish(x): 45 | return x * torch.sigmoid(x) 46 | 47 | 48 | ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish} 49 | 50 | 51 | class NezhaConfig(object): 52 | """Configuration class to store the configuration of a `BertModel`. 53 | """ 54 | 55 | def __init__(self, 56 | vocab_size_or_config_json_file, 57 | hidden_size=768, 58 | num_hidden_layers=12, 59 | num_attention_heads=12, 60 | intermediate_size=3072, 61 | hidden_act="gelu", 62 | hidden_dropout_prob=0.1, 63 | attention_probs_dropout_prob=0.1, 64 | max_position_embeddings=512, 65 | max_relative_position=64, 66 | type_vocab_size=2, 67 | initializer_range=0.02, 68 | layer_norm_eps=1e-12): 69 | """Constructs NezhaConfig. 70 | 71 | Args: 72 | vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `BertModel`. 73 | hidden_size: Size of the encoder layers and the pooler layer. 74 | num_hidden_layers: Number of hidden layers in the Transformer encoder. 75 | num_attention_heads: Number of attention heads for each attention layer in 76 | the Transformer encoder. 77 | intermediate_size: The size of the "intermediate" (i.e., feed-forward) 78 | layer in the Transformer encoder. 79 | hidden_act: The non-linear activation function (function or string) in the 80 | encoder and pooler. If string, "gelu", "relu" and "swish" are supported. 81 | hidden_dropout_prob: The dropout probabilitiy for all fully connected 82 | layers in the embeddings, encoder, and pooler. 83 | attention_probs_dropout_prob: The dropout ratio for the attention 84 | probabilities. 85 | max_position_embeddings: The maximum sequence length that this model might 86 | ever be used with. Typically set this to something large just in case 87 | (e.g., 512 or 1024 or 2048). 88 | type_vocab_size: The vocabulary size of the `token_type_ids` passed into 89 | `BertModel`. 90 | initializer_range: The sttdev of the truncated_normal_initializer for 91 | initializing all weight matrices. 92 | layer_norm_eps: The epsilon used by LayerNorm. 
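max_relative_position: The maximum relative distance used by NEZHA's relative position encoding; distances beyond it are clipped to this value (see _generate_relative_positions_matrix below).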
93 | """ 94 | if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2 95 | and isinstance(vocab_size_or_config_json_file, unicode)): 96 | with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader: 97 | json_config = json.loads(reader.read()) 98 | for key, value in json_config.items(): 99 | self.__dict__[key] = value 100 | elif isinstance(vocab_size_or_config_json_file, int): 101 | self.vocab_size = vocab_size_or_config_json_file 102 | self.hidden_size = hidden_size 103 | self.num_hidden_layers = num_hidden_layers 104 | self.num_attention_heads = num_attention_heads 105 | self.hidden_act = hidden_act 106 | self.intermediate_size = intermediate_size 107 | self.hidden_dropout_prob = hidden_dropout_prob 108 | self.attention_probs_dropout_prob = attention_probs_dropout_prob 109 | self.max_position_embeddings = max_position_embeddings 110 | self.max_relative_position = max_relative_position 111 | self.type_vocab_size = type_vocab_size 112 | self.initializer_range = initializer_range 113 | self.layer_norm_eps = layer_norm_eps 114 | else: 115 | raise ValueError("First argument must be either a vocabulary size (int)" 116 | "or the path to a pretrained model config file (str)") 117 | 118 | @classmethod 119 | def from_dict(cls, json_object): 120 | """Constructs a `NezhaConfig` from a Python dictionary of parameters.""" 121 | config = NezhaConfig(vocab_size_or_config_json_file=-1) 122 | for key, value in json_object.items(): 123 | config.__dict__[key] = value 124 | return config 125 | 126 | @classmethod 127 | def from_json_file(cls, json_file): 128 | """Constructs a `NezhaConfig` from a json file of parameters.""" 129 | with open(json_file, "r", encoding='utf-8') as reader: 130 | text = reader.read() 131 | return cls.from_dict(json.loads(text)) 132 | 133 | def __repr__(self): 134 | return str(self.to_json_string()) 135 | 136 | def to_dict(self): 137 | """Serializes this instance to a Python dictionary.""" 138 | output = copy.deepcopy(self.__dict__) 139 | return output 140 | 141 | def to_json_string(self): 142 | """Serializes this instance to a JSON string.""" 143 | return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" 144 | 145 | def to_json_file(self, json_file_path): 146 | """ Save this instance to a json file.""" 147 | with open(json_file_path, "w", encoding='utf-8') as writer: 148 | writer.write(self.to_json_string()) 149 | 150 | 151 | try: 152 | from apex.normalization.fused_layer_norm import FusedLayerNorm as BertLayerNorm 153 | except ImportError: 154 | logger.info("Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex .") 155 | 156 | 157 | class BertLayerNorm(nn.Module): 158 | def __init__(self, hidden_size, eps=1e-12): 159 | """Construct a layernorm module in the TF style (epsilon inside the square root). 160 | """ 161 | super(BertLayerNorm, self).__init__() 162 | self.weight = nn.Parameter(torch.ones(hidden_size)) 163 | self.bias = nn.Parameter(torch.zeros(hidden_size)) 164 | self.variance_epsilon = eps 165 | 166 | def forward(self, x): 167 | u = x.mean(-1, keepdim=True) 168 | s = (x - u).pow(2).mean(-1, keepdim=True) 169 | x = (x - u) / torch.sqrt(s + self.variance_epsilon) 170 | return self.weight * x + self.bias 171 | 172 | 173 | class BertEmbeddings(nn.Module): 174 | """Construct the embeddings from word, position and token_type embeddings. 
175 | """ 176 | 177 | def __init__(self, config): 178 | super(BertEmbeddings, self).__init__() 179 | self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=0) 180 | try: 181 | self.use_relative_position = config.use_relative_position 182 | except: 183 | self.use_relative_position = False 184 | if not self.use_relative_position: 185 | self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) 186 | self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) 187 | 188 | # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load 189 | # any TensorFlow checkpoint file 190 | self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps) 191 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 192 | 193 | def forward(self, input_ids, token_type_ids=None): 194 | seq_length = input_ids.size(1) 195 | position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device) 196 | position_ids = position_ids.unsqueeze(0).expand_as(input_ids) 197 | if token_type_ids is None: 198 | token_type_ids = torch.zeros_like(input_ids) 199 | 200 | words_embeddings = self.word_embeddings(input_ids) 201 | embeddings = words_embeddings 202 | if not self.use_relative_position: 203 | position_embeddings = self.position_embeddings(position_ids) 204 | embeddings += position_embeddings 205 | token_type_embeddings = self.token_type_embeddings(token_type_ids) 206 | embeddings += token_type_embeddings 207 | embeddings = self.LayerNorm(embeddings) 208 | embeddings = self.dropout(embeddings) 209 | return embeddings 210 | 211 | 212 | class BertSelfAttention(nn.Module): 213 | def __init__(self, config): 214 | super(BertSelfAttention, self).__init__() 215 | if config.hidden_size % config.num_attention_heads != 0: 216 | raise ValueError( 217 | "The hidden size (%d) is not a multiple of the number of attention " 218 | "heads (%d)" % (config.hidden_size, config.num_attention_heads)) 219 | self.num_attention_heads = config.num_attention_heads 220 | self.attention_head_size = int(config.hidden_size / config.num_attention_heads) 221 | self.all_head_size = self.num_attention_heads * self.attention_head_size 222 | 223 | self.query = nn.Linear(config.hidden_size, self.all_head_size) 224 | self.key = nn.Linear(config.hidden_size, self.all_head_size) 225 | self.value = nn.Linear(config.hidden_size, self.all_head_size) 226 | 227 | self.dropout = nn.Dropout(config.attention_probs_dropout_prob) 228 | 229 | def transpose_for_scores(self, x): 230 | new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) 231 | x = x.view(*new_x_shape) 232 | return x.permute(0, 2, 1, 3) 233 | 234 | def forward(self, hidden_states, attention_mask): 235 | mixed_query_layer = self.query(hidden_states) 236 | mixed_key_layer = self.key(hidden_states) 237 | mixed_value_layer = self.value(hidden_states) 238 | 239 | query_layer = self.transpose_for_scores(mixed_query_layer) 240 | key_layer = self.transpose_for_scores(mixed_key_layer) 241 | value_layer = self.transpose_for_scores(mixed_value_layer) 242 | 243 | # Take the dot product between "query" and "key" to get the raw attention scores. 
244 | attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) 245 | attention_scores = attention_scores / math.sqrt(self.attention_head_size) 246 | # Apply the attention mask is (precomputed for all layers in BertModel forward() function) 247 | attention_scores = attention_scores + attention_mask 248 | 249 | # Normalize the attention scores to probabilities. 250 | attention_probs = nn.Softmax(dim=-1)(attention_scores) 251 | 252 | # This is actually dropping out entire tokens to attend to, which might 253 | # seem a bit unusual, but is taken from the original Transformer paper. 254 | attention_probs = self.dropout(attention_probs) 255 | 256 | context_layer = torch.matmul(attention_probs, value_layer) 257 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous() 258 | new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) 259 | context_layer = context_layer.view(*new_context_layer_shape) 260 | return context_layer, attention_scores 261 | 262 | 263 | def _generate_relative_positions_matrix(length, max_relative_position, 264 | cache=False): 265 | """Generates matrix of relative positions between inputs.""" 266 | if not cache: 267 | range_vec = torch.arange(length) 268 | range_mat = range_vec.repeat(length).view(length, length) 269 | distance_mat = range_mat - torch.t(range_mat) 270 | else: 271 | distance_mat = torch.arange(-length + 1, 1, 1).unsqueeze(0) 272 | 273 | distance_mat_clipped = torch.clamp(distance_mat, -max_relative_position, max_relative_position) 274 | final_mat = distance_mat_clipped + max_relative_position 275 | 276 | return final_mat 277 | 278 | 279 | def _generate_relative_positions_embeddings(length, depth, max_relative_position=127): 280 | vocab_size = max_relative_position * 2 + 1 281 | range_vec = torch.arange(length) 282 | range_mat = range_vec.repeat(length).view(length, length) 283 | distance_mat = range_mat - torch.t(range_mat) 284 | distance_mat_clipped = torch.clamp(distance_mat, -max_relative_position, max_relative_position) 285 | final_mat = distance_mat_clipped + max_relative_position 286 | embeddings_table = np.zeros([vocab_size, depth]) 287 | for pos in range(vocab_size): 288 | for i in range(depth // 2): 289 | embeddings_table[pos, 2 * i] = np.sin(pos / np.power(10000, 2 * i / depth)) 290 | embeddings_table[pos, 2 * i + 1] = np.cos(pos / np.power(10000, 2 * i / depth)) 291 | 292 | embeddings_table_tensor = torch.tensor(embeddings_table).float() 293 | flat_relative_positions_matrix = final_mat.view(-1) 294 | one_hot_relative_positions_matrix = torch.nn.functional.one_hot(flat_relative_positions_matrix, 295 | num_classes=vocab_size).float() 296 | embeddings = torch.matmul(one_hot_relative_positions_matrix, embeddings_table_tensor) 297 | my_shape = list(final_mat.size()) 298 | my_shape.append(depth) 299 | embeddings = embeddings.view(my_shape) 300 | return embeddings 301 | 302 | 303 | class NeZhaSelfAttention(nn.Module): 304 | def __init__(self, config): 305 | super(NeZhaSelfAttention, self).__init__() 306 | if config.hidden_size % config.num_attention_heads != 0: 307 | raise ValueError( 308 | "The hidden size (%d) is not a multiple of the number of attention " 309 | "heads (%d)" % (config.hidden_size, config.num_attention_heads)) 310 | self.num_attention_heads = config.num_attention_heads 311 | self.attention_head_size = int(config.hidden_size / config.num_attention_heads) 312 | self.all_head_size = self.num_attention_heads * self.attention_head_size 313 | 314 | self.query = nn.Linear(config.hidden_size, 
self.all_head_size)
315 |         self.key = nn.Linear(config.hidden_size, self.all_head_size)
316 |         self.value = nn.Linear(config.hidden_size, self.all_head_size)
317 |         # Register the sinusoidal relative-position table as a non-persistent buffer
318 |         # rather than calling .cuda() at construction time, so the table follows the
319 |         # module's device (model.to(device), DataParallel) and CPU-only inference works.
320 |         self.register_buffer(
321 |             'relative_positions_embeddings',
322 |             _generate_relative_positions_embeddings(
323 |                 length=512, depth=self.attention_head_size,
324 |                 max_relative_position=config.max_relative_position),
325 |             persistent=False)  # persistent=False requires PyTorch >= 1.6
326 |         self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
327 | 
328 |     def transpose_for_scores(self, x):
329 |         new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
330 |         x = x.view(*new_x_shape)
331 |         return x.permute(0, 2, 1, 3)
332 | 
333 |     def forward(self, hidden_states, attention_mask):
334 |         device = hidden_states.device
335 |         mixed_query_layer = self.query(hidden_states)
336 |         mixed_key_layer = self.key(hidden_states)
337 |         mixed_value_layer = self.value(hidden_states)
338 | 
339 |         query_layer = self.transpose_for_scores(mixed_query_layer)
340 |         key_layer = self.transpose_for_scores(mixed_key_layer)
341 |         value_layer = self.transpose_for_scores(mixed_value_layer)
342 | 
343 |         # Take the dot product between "query" and "key" to get the raw attention scores.
344 |         attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
345 |         batch_size, num_attention_heads, from_seq_length, to_seq_length = attention_scores.size()
346 | 
347 |         # NeZha's content-to-position term (Shaw et al., 2018): each query position i is
348 |         # also dotted with the relative-position key embedding a^K_ij for key position j.
349 |         relations_keys = self.relative_positions_embeddings.detach().clone()[:to_seq_length, :to_seq_length, :].to(device)
350 |         query_layer_t = query_layer.permute(2, 0, 1, 3)
351 |         query_layer_r = query_layer_t.contiguous().view(from_seq_length, batch_size * num_attention_heads,
352 |                                                         self.attention_head_size)
353 |         key_position_scores = torch.matmul(query_layer_r, relations_keys.permute(0, 2, 1))
354 |         key_position_scores_r = key_position_scores.view(from_seq_length, batch_size,
355 |                                                          num_attention_heads, from_seq_length)
356 |         key_position_scores_r_t = key_position_scores_r.permute(1, 2, 0, 3)
357 |         attention_scores = attention_scores + key_position_scores_r_t
358 |         attention_scores = attention_scores / math.sqrt(self.attention_head_size)
359 |         # Apply the attention mask (precomputed for all layers in NEZHAModel.forward()).
360 |         attention_scores = attention_scores + attention_mask
361 | 
362 |         # Normalize the attention scores to probabilities.
363 |         attention_probs = nn.Softmax(dim=-1)(attention_scores)
364 |         # This is actually dropping out entire tokens to attend to, as in the original Transformer paper.
365 | attention_probs = self.dropout(attention_probs) 366 | 367 | context_layer = torch.matmul(attention_probs, value_layer) 368 | 369 | relations_values = self.relative_positions_embeddings.clone()[:to_seq_length, :to_seq_length, :].to( 370 | device) 371 | attention_probs_t = attention_probs.permute(2, 0, 1, 3) 372 | attentions_probs_r = attention_probs_t.contiguous().view(from_seq_length, batch_size * num_attention_heads, 373 | to_seq_length) 374 | value_position_scores = torch.matmul(attentions_probs_r, relations_values) 375 | value_position_scores_r = value_position_scores.view(from_seq_length, batch_size, 376 | num_attention_heads, self.attention_head_size) 377 | value_position_scores_r_t = value_position_scores_r.permute(1, 2, 0, 3) 378 | context_layer = context_layer + value_position_scores_r_t 379 | 380 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous() 381 | new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) 382 | context_layer = context_layer.view(*new_context_layer_shape) 383 | return context_layer, attention_scores 384 | 385 | 386 | class BertSelfOutput(nn.Module): 387 | def __init__(self, config): 388 | super(BertSelfOutput, self).__init__() 389 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 390 | self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps) 391 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 392 | 393 | def forward(self, hidden_states, input_tensor): 394 | hidden_states = self.dense(hidden_states) 395 | hidden_states = self.dropout(hidden_states) 396 | hidden_states = self.LayerNorm(hidden_states + input_tensor) 397 | return hidden_states 398 | 399 | 400 | class BertAttention(nn.Module): 401 | def __init__(self, config): 402 | super(BertAttention, self).__init__() 403 | try: 404 | self.use_relative_position = config.use_relative_position 405 | except: 406 | self.use_relative_position = False 407 | if self.use_relative_position: 408 | self.self = NeZhaSelfAttention(config) 409 | else: 410 | self.self = BertSelfAttention(config) 411 | 412 | self.output = BertSelfOutput(config) 413 | 414 | def forward(self, input_tensor, attention_mask): 415 | self_output = self.self(input_tensor, attention_mask) 416 | self_output, layer_att = self_output 417 | attention_output = self.output(self_output, input_tensor) 418 | return attention_output, layer_att 419 | 420 | 421 | class BertIntermediate(nn.Module): 422 | def __init__(self, config): 423 | super(BertIntermediate, self).__init__() 424 | self.dense = nn.Linear(config.hidden_size, config.intermediate_size) 425 | if isinstance(config.hidden_act, str) or (sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode)): 426 | self.intermediate_act_fn = ACT2FN[config.hidden_act] 427 | else: 428 | self.intermediate_act_fn = config.hidden_act 429 | 430 | def forward(self, hidden_states): 431 | hidden_states = self.dense(hidden_states) 432 | hidden_states = self.intermediate_act_fn(hidden_states) 433 | return hidden_states 434 | 435 | 436 | class BertOutput(nn.Module): 437 | def __init__(self, config): 438 | super(BertOutput, self).__init__() 439 | self.dense = nn.Linear(config.intermediate_size, config.hidden_size) 440 | self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps) 441 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 442 | 443 | def forward(self, hidden_states, input_tensor): 444 | hidden_states = self.dense(hidden_states) 445 | hidden_states = self.dropout(hidden_states) 446 | hidden_states = 
self.LayerNorm(hidden_states + input_tensor) 447 | return hidden_states 448 | 449 | 450 | class BertLayer(nn.Module): 451 | def __init__(self, config): 452 | super(BertLayer, self).__init__() 453 | self.attention = BertAttention(config) 454 | self.intermediate = BertIntermediate(config) 455 | self.output = BertOutput(config) 456 | 457 | def forward(self, hidden_states, attention_mask): 458 | attention_output = self.attention(hidden_states, attention_mask) 459 | attention_output, layer_att = attention_output 460 | intermediate_output = self.intermediate(attention_output) 461 | layer_output = self.output(intermediate_output, attention_output) 462 | return layer_output, layer_att 463 | 464 | 465 | class BertEncoder(nn.Module): 466 | def __init__(self, config): 467 | super(BertEncoder, self).__init__() 468 | layer = BertLayer(config) 469 | self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)]) 470 | 471 | def forward(self, hidden_states, attention_mask): 472 | all_encoder_layers = [] 473 | all_encoder_att = [] 474 | for i, layer_module in enumerate(self.layer): 475 | all_encoder_layers.append(hidden_states) 476 | hidden_states = layer_module(all_encoder_layers[i], attention_mask) 477 | hidden_states, layer_att = hidden_states 478 | all_encoder_att.append(layer_att) 479 | all_encoder_layers.append(hidden_states) 480 | return all_encoder_layers, all_encoder_att 481 | 482 | 483 | class BertPooler(nn.Module): 484 | def __init__(self, config): 485 | super(BertPooler, self).__init__() 486 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 487 | self.activation = nn.Tanh() 488 | 489 | def forward(self, hidden_states): 490 | # We "pool" the model by simply taking the hidden state corresponding 491 | # to the first token. 492 | first_token_tensor = hidden_states[:, 0] 493 | pooled_output = self.dense(first_token_tensor) 494 | pooled_output = self.activation(pooled_output) 495 | return pooled_output 496 | 497 | 498 | class BertPreTrainedModel(nn.Module): 499 | """ An abstract class to handle weights initialization and 500 | a simple interface for dowloading and loading pretrained models. 501 | """ 502 | 503 | def __init__(self, config, *inputs, **kwargs): 504 | super(BertPreTrainedModel, self).__init__() 505 | if not isinstance(config, NezhaConfig): 506 | raise ValueError( 507 | "Parameter config in `{}(config)` should be an instance of class `NezhaConfig`. " 508 | "To create a model from a Google pretrained model use " 509 | "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format( 510 | self.__class__.__name__, self.__class__.__name__ 511 | )) 512 | self.config = config 513 | 514 | def init_bert_weights(self, module): 515 | """ Initialize the weights. 
516 | """ 517 | if isinstance(module, (nn.Linear, nn.Embedding)): 518 | # Slightly different from the TF version which uses truncated_normal for initialization 519 | # cf https://github.com/pytorch/pytorch/pull/5617 520 | module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) 521 | elif isinstance(module, BertLayerNorm): 522 | module.bias.data.zero_() 523 | module.weight.data.fill_(1.0) 524 | if isinstance(module, nn.Linear) and module.bias is not None: 525 | module.bias.data.zero_() 526 | 527 | 528 | class NEZHAModel(BertPreTrainedModel): 529 | def __init__(self, config): 530 | super(NEZHAModel, self).__init__(config) 531 | self.embeddings = BertEmbeddings(config) 532 | self.encoder = BertEncoder(config) 533 | self.pooler = BertPooler(config) 534 | 535 | self.apply(self.init_bert_weights) 536 | 537 | def forward(self, input_ids, token_type_ids=None, attention_mask=None, output_attention_mask=False, 538 | model_distillation=False, output_all_encoded_layers=False): 539 | if attention_mask is None: 540 | attention_mask = torch.ones_like(input_ids) 541 | if token_type_ids is None: 542 | token_type_ids = torch.zeros_like(input_ids) 543 | extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) 544 | # extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility 545 | extended_attention_mask = extended_attention_mask.to(dtype=torch.float32) # fp16 compatibility 546 | extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 547 | 548 | embedding_output = self.embeddings(input_ids, token_type_ids) 549 | encoded_layers = self.encoder(embedding_output, 550 | extended_attention_mask) 551 | encoded_layers, attention_layers = encoded_layers 552 | sequence_output = encoded_layers[-1] 553 | pooled_output = self.pooler(sequence_output) 554 | if output_attention_mask: 555 | return encoded_layers, attention_layers, pooled_output, extended_attention_mask 556 | if model_distillation: 557 | return encoded_layers, attention_layers 558 | if not output_all_encoded_layers: 559 | encoded_layers = encoded_layers[-1] 560 | return encoded_layers, pooled_output 561 | 562 | 563 | class BertPredictionHeadTransform(nn.Module): 564 | def __init__(self, config): 565 | super(BertPredictionHeadTransform, self).__init__() 566 | # Need to unty it when we separate the dimensions of hidden and emb 567 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 568 | if isinstance(config.hidden_act, str) or (sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode)): 569 | self.transform_act_fn = ACT2FN[config.hidden_act] 570 | else: 571 | self.transform_act_fn = config.hidden_act 572 | self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12) 573 | 574 | def forward(self, hidden_states): 575 | hidden_states = self.dense(hidden_states) 576 | hidden_states = self.transform_act_fn(hidden_states) 577 | hidden_states = self.LayerNorm(hidden_states) 578 | return hidden_states 579 | 580 | 581 | class BertLMPredictionHead(nn.Module): 582 | def __init__(self, config, bert_model_embedding_weights): 583 | super(BertLMPredictionHead, self).__init__() 584 | self.transform = BertPredictionHeadTransform(config) 585 | 586 | # The output weights are the same as the input embeddings, but there is 587 | # an output-only bias for each token. 
588 | self.decoder = nn.Linear(bert_model_embedding_weights.size(1), 589 | bert_model_embedding_weights.size(0), 590 | bias=False) 591 | self.decoder.weight = bert_model_embedding_weights 592 | self.bias = nn.Parameter(torch.zeros(bert_model_embedding_weights.size(0))) 593 | 594 | def forward(self, hidden_states): 595 | hidden_states = self.transform(hidden_states) 596 | hidden_states = self.decoder(hidden_states) + self.bias 597 | return hidden_states 598 | 599 | 600 | class BertOnlyMLMHead(nn.Module): 601 | def __init__(self, config, bert_model_embedding_weights): 602 | super(BertOnlyMLMHead, self).__init__() 603 | self.predictions = BertLMPredictionHead(config, bert_model_embedding_weights) 604 | 605 | def forward(self, sequence_output): 606 | prediction_scores = self.predictions(sequence_output) 607 | return prediction_scores 608 | 609 | 610 | class BertOnlyNSPHead(nn.Module): 611 | def __init__(self, config): 612 | super(BertOnlyNSPHead, self).__init__() 613 | self.seq_relationship = nn.Linear(config.hidden_size, 2) 614 | 615 | def forward(self, pooled_output): 616 | seq_relationship_score = self.seq_relationship(pooled_output) 617 | return seq_relationship_score 618 | 619 | 620 | class BertPreTrainingHeads(nn.Module): 621 | def __init__(self, config, bert_model_embedding_weights): 622 | super(BertPreTrainingHeads, self).__init__() 623 | self.predictions = BertLMPredictionHead(config, bert_model_embedding_weights) 624 | self.seq_relationship = nn.Linear(config.hidden_size, 2) 625 | 626 | def forward(self, sequence_output, pooled_output): 627 | prediction_scores = self.predictions(sequence_output) 628 | seq_relationship_score = self.seq_relationship(pooled_output) 629 | return prediction_scores, seq_relationship_score 630 | 631 | 632 | class BertForPreTraining(BertPreTrainedModel): 633 | """BERT model with pre-training heads. 634 | This module comprises the BERT model followed by the two pre-training heads: 635 | - the masked language modeling head, and 636 | - the next sentence classification head. 637 | 638 | Params: 639 | config: a NezhaConfig class instance with the configuration to build a new model. 640 | 641 | Inputs: 642 | `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] 643 | with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts 644 | `extract_features.py`, `run_classifier.py` and `run_squad.py`) 645 | `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token 646 | types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to 647 | a `sentence B` token (see BERT paper for more details). 648 | `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices 649 | selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max 650 | input sequence length in the current batch. It's the mask that we typically use for attention when 651 | a batch has varying length sentences. 652 | `masked_lm_labels`: optional masked language modeling labels: torch.LongTensor of shape [batch_size, sequence_length] 653 | with indices selected in [-1, 0, ..., vocab_size]. All labels set to -1 are ignored (masked), the loss 654 | is only computed for the labels set in [0, ..., vocab_size] 655 | `next_sentence_label`: optional next sentence classification loss: torch.LongTensor of shape [batch_size] 656 | with indices selected in [0, 1]. 
657 | 0 => next sentence is the continuation, 1 => next sentence is a random sentence. 658 | 659 | Outputs: 660 | if `masked_lm_labels` and `next_sentence_label` are not `None`: 661 | Outputs the total_loss which is the sum of the masked language modeling loss and the next 662 | sentence classification loss. 663 | if `masked_lm_labels` or `next_sentence_label` is `None`: 664 | Outputs a tuple comprising 665 | - the masked language modeling logits of shape [batch_size, sequence_length, vocab_size], and 666 | - the next sentence classification logits of shape [batch_size, 2]. 667 | 668 | Example usage: 669 | ```python 670 | # Already been converted into WordPiece token ids 671 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) 672 | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) 673 | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) 674 | 675 | config = NezhaConfig(vocab_size_or_config_json_file=32000, hidden_size=768, 676 | num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) 677 | 678 | model = BertForPreTraining(config) 679 | masked_lm_logits_scores, seq_relationship_logits = model(input_ids, token_type_ids, input_mask) 680 | ``` 681 | """ 682 | 683 | def __init__(self, config): 684 | super(BertForPreTraining, self).__init__(config) 685 | self.bert = NEZHAModel(config) 686 | self.cls = BertPreTrainingHeads(config, self.bert.embeddings.word_embeddings.weight) 687 | self.apply(self.init_bert_weights) 688 | 689 | def forward(self, input_ids, token_type_ids=None, attention_mask=None, 690 | masked_lm_labels=None, next_sentence_label=None): 691 | sequence_output, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, 692 | output_all_encoded_layers=False) 693 | prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output) 694 | 695 | if masked_lm_labels is not None and next_sentence_label is not None: 696 | loss_fct = CrossEntropyLoss(ignore_index=-1) 697 | masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1)) 698 | next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1)) 699 | total_loss = masked_lm_loss + next_sentence_loss 700 | return total_loss 701 | elif masked_lm_labels is not None: 702 | loss_fct = CrossEntropyLoss(ignore_index=-1) 703 | masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1)) 704 | total_loss = masked_lm_loss 705 | return total_loss 706 | else: 707 | return prediction_scores, seq_relationship_score 708 | 709 | 710 | class BertForMaskedLM(BertPreTrainedModel): 711 | """BERT model with the masked language modeling head. 712 | This module comprises the BERT model followed by the masked language modeling head. 713 | 714 | Params: 715 | config: a NezhaConfig class instance with the configuration to build a new model. 716 | 717 | Inputs: 718 | `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] 719 | with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts 720 | `extract_features.py`, `run_classifier.py` and `run_squad.py`) 721 | `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token 722 | types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to 723 | a `sentence B` token (see BERT paper for more details). 
724 |         `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
725 |             selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
726 |             input sequence length in the current batch. It's the mask that we typically use for attention when
727 |             a batch has varying length sentences.
728 |         `masked_lm_labels`: masked language modeling labels: torch.LongTensor of shape [batch_size, sequence_length]
729 |             with indices selected in [-1, 0, ..., vocab_size]. All labels set to -1 are ignored (masked), the loss
730 |             is only computed for the labels set in [0, ..., vocab_size]
731 | 
732 |     Outputs:
733 |         if `masked_lm_labels` is not `None`:
734 |             Outputs the masked language modeling loss.
735 |         if `masked_lm_labels` is `None`:
736 |             Outputs the masked language modeling logits of shape [batch_size, sequence_length, vocab_size].
737 | 
738 |     Example usage:
739 |     ```python
740 |     # Already been converted into WordPiece token ids
741 |     input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
742 |     input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
743 |     token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
744 | 
745 |     config = NezhaConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
746 |         num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
747 | 
748 |     model = BertForMaskedLM(config)
749 |     masked_lm_logits_scores = model(input_ids, token_type_ids, input_mask)
750 |     ```
751 |     """
752 | 
753 |     def __init__(self, config):
754 |         super(BertForMaskedLM, self).__init__(config)
755 |         self.bert = NEZHAModel(config)
756 |         self.cls = BertOnlyMLMHead(config, self.bert.embeddings.word_embeddings.weight)
757 |         self.apply(self.init_bert_weights)
758 | 
759 |     def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None,
760 |                 output_att=False, infer=False):
761 |         # NEZHAModel.forward() has no `output_att` argument; `model_distillation=True` makes it
762 |         # return (encoded_layers, attention_layers) instead of (encoded_layers, pooled_output).
763 |         outputs = self.bert(input_ids, token_type_ids, attention_mask,
764 |                             output_all_encoded_layers=True, model_distillation=output_att)
765 |         if output_att:
766 |             sequence_output, att_output = outputs
767 |         else:
768 |             sequence_output, _ = outputs
769 |         prediction_scores = self.cls(sequence_output[-1])
770 | 
771 |         if masked_lm_labels is not None:
772 |             loss_fct = CrossEntropyLoss(ignore_index=-1)
773 |             masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
774 |             if not output_att:
775 |                 return masked_lm_loss
776 |             else:
777 |                 return masked_lm_loss, att_output
778 |         else:
779 |             if not output_att:
780 |                 return prediction_scores
781 |             else:
782 |                 return prediction_scores, att_output
783 | 
784 | 
785 | class BertForSequenceClassification(BertPreTrainedModel):
786 |     """BERT model for classification.
787 |     This module is composed of the BERT model with a linear layer on top of
788 |     the pooled output.
789 | 
790 |     Params:
791 |         `config`: a NezhaConfig class instance with the configuration to build a new model.
792 |         `num_labels`: the number of classes for the classifier. Default = 2.
793 | 
794 |     Inputs:
795 |         `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
796 |             with the word token indices in the vocabulary (see the tokens preprocessing logic in the scripts
797 |             `extract_features.py`, `run_classifier.py` and `run_squad.py`)
798 | `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices 799 | selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max 800 | input sequence length in the current batch. It's the mask that we typically use for attention when 801 | a batch has varying length sentences. 802 | `labels`: labels for the classification output: torch.LongTensor of shape [batch_size] 803 | with indices selected in [0, ..., num_labels]. 804 | 805 | Outputs: 806 | if `labels` is not `None`: 807 | Outputs the CrossEntropy classification loss of the output with the labels. 808 | if `labels` is `None`: 809 | Outputs the classification logits of shape [batch_size, num_labels]. 810 | 811 | Example usage: 812 | ```python 813 | # Already been converted into WordPiece token ids 814 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) 815 | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) 816 | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) 817 | 818 | config = NezhaConfig(vocab_size_or_config_json_file=32000, hidden_size=768, 819 | num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) 820 | 821 | num_labels = 2 822 | 823 | model = BertForSequenceClassification(config, num_labels) 824 | logits = model(input_ids, token_type_ids, input_mask) 825 | ``` 826 | """ 827 | 828 | def __init__(self, config, num_labels): 829 | super(BertForSequenceClassification, self).__init__(config) 830 | self.num_labels = num_labels 831 | self.bert = NEZHAModel(config) 832 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 833 | self.classifier = nn.Linear(config.hidden_size, num_labels) 834 | self.apply(self.init_bert_weights) 835 | 836 | def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None): 837 | _, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, 838 | output_all_encoded_layers=False) 839 | task_output = self.dropout(pooled_output) 840 | logits = self.classifier(task_output) 841 | if labels is not None: 842 | loss_fct = CrossEntropyLoss() 843 | loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) 844 | return loss 845 | else: 846 | return logits 847 | 848 | 849 | class NeZhaForMultipleChoice(BertPreTrainedModel): 850 | def __init__(self, config, num_choices=2): 851 | super(NeZhaForMultipleChoice, self).__init__(config) 852 | self.num_choices = num_choices 853 | self.bert = NEZHAModel(config) 854 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 855 | self.classifier = nn.Linear(config.hidden_size, 1) 856 | self.apply(self.init_bert_weights) 857 | 858 | def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None, return_logits=False): 859 | # input_ids: [bs,num_choice,seq_l] 860 | flat_input_ids = input_ids.view(-1, input_ids.size(-1)) # flat_input_ids: [bs*num_choice,seq_l] 861 | flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) 862 | flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) 863 | _, pooled_output = self.bert(flat_input_ids, flat_token_type_ids, flat_attention_mask, 864 | output_all_encoded_layers=False) 865 | pooled_output = self.dropout(pooled_output) 866 | logits = self.classifier(pooled_output) # logits: (bs*num_choice,1) 867 | reshaped_logits = logits.view(-1, self.num_choices) # logits: (bs, num_choice) 868 | 869 | if labels is not None: 870 | loss_fct = CrossEntropyLoss() 871 | loss = loss_fct(reshaped_logits, labels) 872 | if return_logits: 873 | return loss, 
reshaped_logits
874 |             else:
875 |                 return loss
876 |         else:
877 |             return reshaped_logits
878 | 
879 | 
880 | class NeZhaForQuestionAnswering(BertPreTrainedModel):
881 |     def __init__(self, config):
882 |         super(NeZhaForQuestionAnswering, self).__init__(config)
883 |         self.bert = NEZHAModel(config)
884 |         # TODO check with Google if it's normal there is no dropout on the token classifier of SQuAD in the TF version
885 |         # self.dropout = nn.Dropout(config.hidden_dropout_prob)
886 |         self.qa_outputs = nn.Linear(config.hidden_size, 2)
887 |         self.apply(self.init_bert_weights)
888 | 
889 |     def forward(self, input_ids, token_type_ids=None, attention_mask=None, start_positions=None, end_positions=None):
890 |         sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
891 |         logits = self.qa_outputs(sequence_output)
892 |         start_logits, end_logits = logits.split(1, dim=-1)
893 |         start_logits = start_logits.squeeze(-1)
894 |         end_logits = end_logits.squeeze(-1)
895 | 
896 |         if start_positions is not None and end_positions is not None:
897 |             # If we are on multi-GPU, squeeze the extra dimension
898 |             if len(start_positions.size()) > 1:
899 |                 start_positions = start_positions.squeeze(-1)
900 |             if len(end_positions.size()) > 1:
901 |                 end_positions = end_positions.squeeze(-1)
902 |             # sometimes the start/end positions are outside our model inputs, we ignore these terms
903 |             ignored_index = start_logits.size(1)
904 |             start_positions.clamp_(0, ignored_index)
905 |             end_positions.clamp_(0, ignored_index)
906 | 
907 |             loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
908 |             start_loss = loss_fct(start_logits, start_positions)
909 |             end_loss = loss_fct(end_logits, end_positions)
910 |             total_loss = (start_loss + end_loss) / 2
911 |             return total_loss
912 |         else:
913 |             return start_logits, end_logits
914 | 
915 | 
916 | class BertForJointLSTM(BertPreTrainedModel):
917 |     def __init__(self, config, num_intent_labels, num_slot_labels):
918 |         super(BertForJointLSTM, self).__init__(config)
919 |         self.num_intent_labels = num_intent_labels
920 |         self.num_slot_labels = num_slot_labels
921 |         self.bert = NEZHAModel(config)
922 |         self.dropout = nn.Dropout(config.hidden_dropout_prob)
923 |         self.intent_classifier = nn.Linear(config.hidden_size, num_intent_labels)
924 |         self.lstm = nn.LSTM(
925 |             input_size=config.hidden_size,
926 |             hidden_size=300,
927 |             batch_first=True,
928 |             bidirectional=True
929 |         )
930 |         self.slot_classifier = nn.Linear(300 * 2, num_slot_labels)
931 |         self.apply(self.init_bert_weights)
932 | 
933 |     def forward(self, input_ids, token_type_ids=None,
934 |                 attention_mask=None, intent_labels=None, slot_labels=None):
935 |         # NEZHAModel returns two values (encoded_layers, pooled_output), not three
936 |         encoded_layers, pooled_output = self.bert(input_ids, token_type_ids, attention_mask,
937 |                                                   output_all_encoded_layers=True)
938 |         intent_logits = self.intent_classifier(self.dropout(pooled_output))
939 | 
940 |         last_encoded_layer = encoded_layers[-1]
941 |         slot_logits, _ = self.lstm(last_encoded_layer)
942 |         slot_logits = self.slot_classifier(slot_logits)
943 |         if intent_labels is not None and slot_labels is not None:
944 |             loss_fct = CrossEntropyLoss()
945 |             intent_loss = loss_fct(intent_logits.view(-1, self.num_intent_labels), intent_labels.view(-1))
946 |             if attention_mask is not None:
947 |                 active_slot_loss = attention_mask.view(-1) == 1
948 |                 active_slot_logits = slot_logits.view(-1, self.num_slot_labels)[active_slot_loss]
949 |                 active_slot_labels = slot_labels.view(-1)[active_slot_loss]
950 |                 slot_loss = loss_fct(active_slot_logits, active_slot_labels)
951 | else: 952 | slot_loss = loss_fct(slot_logits.view(-1, self.num_slot_labels), slot_labels.view(-1)) 953 | 954 | return intent_loss, slot_loss 955 | else: 956 | return intent_logits, slot_logits 957 | -------------------------------------------------------------------------------- /决赛提交/sohu_matching/src/NEZHA/nezha_utils.py: -------------------------------------------------------------------------------- 1 | # /usr/bin/env python 2 | # coding=utf-8 3 | import os 4 | from glob import glob 5 | 6 | import torch 7 | 8 | 9 | def check_args(args): 10 | args.setting_file = os.path.join(args.checkpoint_dir, args.setting_file) 11 | args.log_file = os.path.join(args.checkpoint_dir, args.log_file) 12 | os.makedirs(args.checkpoint_dir, exist_ok=True) 13 | with open(args.setting_file, 'wt') as opt_file: 14 | opt_file.write('------------ Options -------------\n') 15 | print('------------ Options -------------') 16 | for k in args.__dict__: 17 | v = args.__dict__[k] 18 | opt_file.write('%s: %s\n' % (str(k), str(v))) 19 | print('%s: %s' % (str(k), str(v))) 20 | opt_file.write('-------------- End ----------------\n') 21 | print('------------ End -------------') 22 | 23 | return args 24 | 25 | 26 | def torch_show_all_params(model, rank=0): 27 | params = list(model.parameters()) 28 | k = 0 29 | for i in params: 30 | l = 1 31 | for j in i.size(): 32 | l *= j 33 | k = k + l 34 | if rank == 0: 35 | print("Total param num:" + str(k)) 36 | 37 | 38 | def torch_init_model(model, init_checkpoint, delete_module=False): 39 | state_dict = torch.load(init_checkpoint, map_location='cpu') 40 | state_dict_new = {} 41 | # delete module. 42 | if delete_module: 43 | for key in state_dict.keys(): 44 | v = state_dict[key] 45 | state_dict_new[key.replace('module.', '')] = v 46 | state_dict = state_dict_new 47 | missing_keys = [] 48 | unexpected_keys = [] 49 | error_msgs = [] 50 | # copy state_dict so _load_from_state_dict can modify it 51 | metadata = getattr(state_dict, '_metadata', None) 52 | state_dict = state_dict.copy() 53 | if metadata is not None: 54 | state_dict._metadata = metadata 55 | 56 | def load(module, prefix=''): 57 | local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) 58 | 59 | module._load_from_state_dict( 60 | state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs) 61 | for name, child in module._modules.items(): 62 | if child is not None: 63 | load(child, prefix + name + '.') 64 | 65 | load(model, prefix='' if hasattr(model, 'bert') else 'bert.') 66 | 67 | print("missing keys:{}".format(missing_keys)) 68 | print('unexpected keys:{}'.format(unexpected_keys)) 69 | print('error msgs:{}'.format(error_msgs)) 70 | 71 | 72 | def torch_save_model(model, output_dir, scores, max_save_num=1): 73 | # Save model checkpoint 74 | if not os.path.exists(output_dir): 75 | os.makedirs(output_dir) 76 | model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training 77 | saved_pths = glob(os.path.join(output_dir, '*.pth')) 78 | saved_pths.sort() 79 | while len(saved_pths) >= max_save_num: 80 | if os.path.exists(saved_pths[0].replace('//', '/')): 81 | os.remove(saved_pths[0].replace('//', '/')) 82 | del saved_pths[0] 83 | 84 | save_prex = "checkpoint_score" 85 | for k in scores: 86 | save_prex += ('_' + k + '-' + str(scores[k])[:6]) 87 | save_prex += '.pth' 88 | 89 | torch.save(model_to_save.state_dict(), 90 | os.path.join(output_dir, save_prex)) 91 | print("Saving model checkpoint to %s", output_dir) 92 | 
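# Usage sketch for the helpers above (an illustration, not part of the original
# file); the paths, the score dict, and NezhaConfig.from_json_file are assumptions
# made for the example:
#
#     from NEZHA.model_nezha import NEZHAModel, NezhaConfig
#     from NEZHA.nezha_utils import torch_init_model, torch_save_model
#
#     config = NezhaConfig.from_json_file('nezha-base-wwm/config.json')
#     model = NEZHAModel(config)
#     # non-strict load: missing/unexpected keys are printed, not raised
#     torch_init_model(model, 'nezha-base-wwm/pytorch_model.bin')
#     ...
#     # keeps at most max_save_num checkpoints, encoding scores in the filename,
#     # e.g. checkpoint_score_f1-0.7345.pth
#     torch_save_model(model, 'checkpoints/', {'f1': 0.7345}, max_save_num=1)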
--------------------------------------------------------------------------------
/决赛提交/sohu_matching/src/config.py:
--------------------------------------------------------------------------------
1 | class Config():
2 |     def __init__(self):
3 |         self.device = 'cuda'
4 |         self.model_type = '0523_roberta_80k_6tasks'
5 |         self.task_type = 'ab'
6 | 
7 |         self.save_dir = '/data1/wangchenyue/sohu_matching/checkpoints/rematch/'
8 |         self.data_dir = '/data1/wangchenyue/sohu_matching/data/sohu2021_open_data/'
9 |         self.load_toy_dataset = False
10 | 
11 |         # self.pretrained = '/data1/wangchenyue/Downloads/chinese-macbert-base/'
12 |         # self.pretrained = '/data1/wangchenyue/Downloads/nezha-base-wwm/'
13 |         # self.pretrained = '/data1/wangchenyue/Downloads/chinese-roberta-wwm-ext/'
14 |         # self.pretrained = '/data1/wangchenyue/Downloads/roberta-base-finetuned-chinanews-chinese/'
15 |         # self.pretrained = '/data1/wangchenyue/Downloads/chinese-bert-wwm-ext/'
16 |         # self.pretrained = '/data1/wangchenyue/Downloads/roberta-base-word-chinese'
17 |         # self.pretrained = '/data1/wangchenyue/Downloads/ernie-1.0/'
18 |         self.pretrained = '/data1/wangchenyue/Downloads/DSP/roberta-wwm-rematch/checkpoint-80000/'
19 | 
20 |         self.epochs = 3
21 |         self.lr = 2e-5
22 |         self.classifier_lr = 1e-3
23 |         self.use_scheduler = True
24 |         self.weight_decay = 1e-3
25 |         self.num_warmup_steps = 2000
26 | 
27 |         # for larger models, e.g. roberta-large, hidden_size = 1024; otherwise 768
28 |         self.hidden_size = 768
29 |         # for sbert, train_bs = 16, eval_bs = 32; otherwise 32/64
30 |         self.train_bs = 32
31 |         self.eval_bs = 64
32 |         self.criterion = 'CE'
33 |         self.print_every = 50
34 |         self.eval_every = 500
35 | 
36 |         # whether to shuffle the sentence order in training data as augmentation
37 |         self.shuffle_order = False
38 |         self.aug_data = False
39 |         # how to clip long sequences: 'head' keeps the first sentences, 'tail' keeps the last sentences
40 |         # 'head' is reportedly better than 'tail'
41 |         self.clip_method = 'head'
42 | 
43 |         # whether to use FGM for adversarial attack in training
44 |         self.use_fgm = False
45 | 
46 |         # settings for inference
47 |         # self.infer_model_dir = '../checkpoints/0502/'
48 |         self.infer_model_dir = '/data1/wangchenyue/sohu_matching/checkpoints/rematch/'
49 |         self.infer_model_name = '0525_roberta_6tasks_epoch_1_ab_loss'
50 |         # fake pretrained model dir containing config.json and vocab.txt, for tokenizer and model initialization
51 |         self.dummy_pretrained = '../data/dummy_bert/'
52 |         # self.dummy_pretrained = '../data/dummy_ernie/'
53 |         # self.dummy_pretrained = '../data/dummy_nezha/'
54 |         # infer_task_type should match the task field in infer_model_name, e.g. 'ab' in '..._ab_loss'
55 |         self.infer_task_type = self.infer_model_name.split('_')[-2]
56 |         self.infer_output_dir = '/data1/wangchenyue/sohu_matching/results/rematch/'
57 |         self.infer_output_filename = '{}.csv'.format(self.infer_model_name)
58 |         self.infer_clip_method = 'head'
59 |         # for NEZHA, infer_bs = 64; otherwise 256
60 |         self.infer_bs = 256
61 |         self.infer_fixed_thres_a = 0.45
62 |         self.infer_fixed_thres_b = 0.35
63 |         self.infer_search_thres = True
64 | 
--------------------------------------------------------------------------------
/决赛提交/sohu_matching/src/data.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import Dataset, DataLoader
3 | from transformers import BertTokenizer, AutoTokenizer
4 | from utils import pad_to_maxlen, augment_data
5 | import pandas as pd
6 | 
7 | from tqdm import tqdm
8 | import json
9 | 
10 | # the main
difference between the two datasets is 11 | # the length limit (512 for one sentence in SBERT 12 | # but for the two concated sentences in BERT setting) 13 | class SentencePairDatasetForSBERT(Dataset): 14 | def __init__(self, file_dir, is_train, tokenizer_config, shuffle_order=False, aug_data=False, len_limit=512, clip='head'): 15 | self.is_train = is_train 16 | self.shuffle_order = shuffle_order 17 | self.aug_data = aug_data 18 | self.total_source_input_ids = [] 19 | # token_types are no longer neccessary if not concat into one text 20 | # self.total_source_input_types = [] 21 | self.total_target_input_ids = [] 22 | # self.total_target_input_types = [] 23 | self.sample_types = [] 24 | 25 | # use AutoTokenzier instead of BertTokenizer to support speice.model (AlbertTokenizer-like) 26 | tokenizer = AutoTokenizer.from_pretrained(tokenizer_config) 27 | lines = [] 28 | for single_file_dir in file_dir: 29 | with open(single_file_dir, 'r', encoding='utf-8') as f_in: 30 | content = f_in.readlines() 31 | for item in content: 32 | line = json.loads(item.strip()) 33 | if not is_train: 34 | line['type'] = 0 if 'a' in line['id'] else 1 35 | lines.append(line) 36 | 37 | content = pd.DataFrame(lines) 38 | content.columns = ['source', 'target', 'label', 'type'] 39 | 40 | # utilize labelB=1-->A positive, labelA=0-->B negative 41 | if self.is_train and self.aug_data: 42 | print("augmenting data...") 43 | content = augment_data(content) 44 | 45 | sources = content['source'].values.tolist() 46 | targets = content['target'].values.tolist() 47 | 48 | self.sample_types = content['type'].values.tolist() 49 | if self.is_train: 50 | self.labels = content['label'].values.tolist() 51 | else: 52 | self.ids = content['label'].values.tolist() 53 | 54 | # shuffle_order is only allowed for training mode 55 | if self.shuffle_order and self.is_train: 56 | sources += content['target'].values.tolist() 57 | targets += content['source'].values.tolist() 58 | self.labels += self.labels 59 | self.sample_types += self.sample_types 60 | 61 | for source, target in tqdm(zip(sources, targets), total=len(sources)): 62 | # tokenize before clipping 63 | source = tokenizer.encode(source)[1:-1] 64 | target = tokenizer.encode(target)[1:-1] 65 | 66 | # clip the sentences if too long 67 | # TODO: different strategies to clip long sequences 68 | if clip == 'head': 69 | if len(source)+2 > len_limit: 70 | source = source[0: len_limit-2] 71 | if len(target)+2 > len_limit: 72 | target = target[0: len_limit-2] 73 | 74 | if clip == 'tail': 75 | if len(source)+2 > len_limit: 76 | source = source[-len_limit+2:] 77 | if len(target)+2 > len_limit: 78 | target = target[-len_limit+2:] 79 | 80 | # check if the length is within the limit 81 | assert len(source)+2 <= len_limit and len(target)+2 <= len_limit 82 | 83 | # [CLS]:101, [SEP]:102 84 | source_input_ids = [101] + source + [102] 85 | target_input_ids = [101] + target + [102] 86 | 87 | assert len(source_input_ids) <= len_limit and len(target_input_ids) <= len_limit 88 | 89 | self.total_source_input_ids.append(source_input_ids) 90 | self.total_target_input_ids.append(target_input_ids) 91 | 92 | self.max_source_input_len = max([len(s) for s in self.total_source_input_ids]) 93 | self.max_target_input_len = max([len(s) for s in self.total_target_input_ids]) 94 | print("max source length: ", self.max_source_input_len) 95 | print("max target length: ", self.max_target_input_len) 96 | 97 | def __len__(self): 98 | return len(self.total_target_input_ids) 99 | 100 | def __getitem__(self, idx): 101 | 
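        # Pads the pre-tokenized source/target id lists to the dataset-wide max
        # lengths computed in __init__ and returns LongTensors, together with the
        # label (training mode) or the sample id (inference mode) and the task type.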
source_input_ids = pad_to_maxlen(self.total_source_input_ids[idx], self.max_source_input_len) 102 | target_input_ids = pad_to_maxlen(self.total_target_input_ids[idx], self.max_target_input_len) 103 | sample_type = int(self.sample_types[idx]) 104 | 105 | if self.is_train: 106 | label = int(self.labels[idx]) 107 | return torch.LongTensor(source_input_ids), torch.LongTensor(target_input_ids), torch.LongTensor([label]), sample_type 108 | 109 | else: 110 | index = self.ids[idx] 111 | return torch.LongTensor(source_input_ids), torch.LongTensor(target_input_ids), index, sample_type 112 | 113 | class SentencePairDatasetWithType(Dataset): 114 | def __init__(self, file_dir, is_train, tokenizer_config, shuffle_order=False, aug_data=False, len_limit=512, clip='head'): 115 | self.is_train = is_train 116 | self.shuffle_order = shuffle_order 117 | self.aug_data = aug_data 118 | self.total_input_ids = [] 119 | self.total_input_types = [] 120 | self.sample_types = [] 121 | 122 | # use AutoTokenzier instead of BertTokenizer to support speice.model (AlbertTokenizer-like) 123 | tokenizer = AutoTokenizer.from_pretrained(tokenizer_config) 124 | 125 | # read json lines and convert to dict / df 126 | lines = [] 127 | for single_file_dir in file_dir: 128 | with open(single_file_dir, 'r', encoding='utf-8') as f_in: 129 | content = f_in.readlines() 130 | for item in content: 131 | line = json.loads(item.strip()) 132 | 133 | # for final stage, task a and b are included in the same file 134 | if not is_train: 135 | line['type'] = 0 if 'a' in line['id'] else 1 136 | lines.append(line) 137 | print(single_file_dir, len(lines)) 138 | content = pd.DataFrame(lines) 139 | # print(content.head()) 140 | content.columns = ['source', 'target', 'label', 'type'] 141 | 142 | # utilize labelB=1-->A positive, labelA=0-->B negative 143 | if self.is_train and self.aug_data: 144 | print("augmenting data...") 145 | content = augment_data(content) 146 | 147 | sources = content['source'].values.tolist() 148 | targets = content['target'].values.tolist() 149 | 150 | self.sample_types = content['type'].values.tolist() 151 | if self.is_train: 152 | self.labels = content['label'].values.tolist() 153 | else: 154 | self.ids = content['label'].values.tolist() 155 | 156 | # shuffle_order is only allowed for training mode 157 | if self.shuffle_order and self.is_train: 158 | sources += content['target'].values.tolist() 159 | targets += content['source'].values.tolist() 160 | self.labels += self.labels 161 | self.sample_types += self.sample_types 162 | 163 | len_limit_s = (len_limit-3)//2 164 | len_limit_t = (len_limit-3)-len_limit_s 165 | # print('len_limit_s: ', len_limit_s) 166 | # print('len_limit_t: ', len_limit_t) 167 | for source, target in tqdm(zip(sources, targets), total=len(sources)): 168 | # tokenize before clipping 169 | source = tokenizer.encode(source)[1:-1] 170 | target = tokenizer.encode(target)[1:-1] 171 | 172 | # clip the sentences if too long 173 | # TODO: different strategies to clip long sequences 174 | if clip == 'head' and len(source)+len(target)+3 > len_limit: 175 | if len(source)>len_limit_s and len(target)>len_limit_t: 176 | source = source[0:len_limit_s] 177 | target = target[0:len_limit_t] 178 | elif len(source)>len_limit_s: 179 | source = source[0:len_limit-3-len(target)] 180 | elif len(target)>len_limit_t: 181 | target = target[0:len_limit-3-len(source)] 182 | 183 | if clip == 'tail' and len(source)+len(target)+3 > len_limit: 184 | if len(source)>len_limit_s and len(target)>len_limit_t: 185 | source = 
source[-len_limit_s:] 186 | target = target[-len_limit_t:] 187 | elif len(source)>len_limit_s: 188 | source = source[-(len_limit-3-len(target)):] 189 | elif len(target)>len_limit_t: 190 | target = target[-(len_limit-3-len(source)):] 191 | 192 | # check if the total length is within the limit 193 | assert len(source)+len(target)+3 <= len_limit 194 | 195 | # [CLS]:101, [SEP]:102 196 | input_ids = [101] + source + [102] + target + [102] 197 | input_types = [0]*(len(source)+2) + [1]*(len(target)+1) 198 | 199 | assert len(input_ids) <= len_limit and len(input_types) <= len_limit 200 | self.total_input_ids.append(input_ids) 201 | self.total_input_types.append(input_types) 202 | 203 | self.max_input_len = max([len(s) for s in self.total_input_ids]) 204 | print("max length: ", self.max_input_len) 205 | 206 | def __len__(self): 207 | return len(self.total_input_ids) 208 | 209 | def __getitem__(self, idx): 210 | if self.is_train: 211 | input_ids = pad_to_maxlen(self.total_input_ids[idx], self.max_input_len) 212 | input_types = pad_to_maxlen(self.total_input_types[idx], self.max_input_len) 213 | label = int(self.labels[idx]) 214 | sample_type = int(self.sample_types[idx]) 215 | # print(len(input_ids), len(input_types), label) 216 | return torch.LongTensor(input_ids), torch.LongTensor(input_types), torch.LongTensor([label]), sample_type 217 | 218 | else: 219 | input_ids = pad_to_maxlen(self.total_input_ids[idx], self.max_input_len) 220 | input_types = pad_to_maxlen(self.total_input_types[idx], self.max_input_len) 221 | index = self.ids[idx] 222 | sample_type = int(self.sample_types[idx]) 223 | return torch.LongTensor(input_ids), torch.LongTensor(input_types), index, sample_type 224 | 225 | # NOT CURRENTLY IN USE 226 | # template for the dataset of multiple task types 227 | # compatible with training code by changing task_num 228 | class SentencePairDatasetWithMultiType(Dataset): 229 | def __init__(self, file_dir, is_train, tokenizer_config, shuffle_order=False, aug_data=False, len_limit=512, clip='head'): 230 | self.is_train = is_train 231 | self.shuffle_order = shuffle_order 232 | self.aug_data = aug_data 233 | self.total_input_ids = [] 234 | self.total_input_types = [] 235 | self.sample_types = [] 236 | 237 | # use AutoTokenzier instead of BertTokenizer to support speice.model (AlbertTokenizer-like) 238 | tokenizer = AutoTokenizer.from_pretrained(tokenizer_config) 239 | 240 | # read json lines and convert to dict / df 241 | lines = [] 242 | for single_file_dir in file_dir: 243 | with open(single_file_dir, 'r', encoding='utf-8') as f_in: 244 | content = f_in.readlines() 245 | for item in content: 246 | line = json.loads(item.strip()) 247 | # BUG FIXED, order MATTERS! 
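                    # resulting task ids: 短短 A/B -> 0/1, 短长 A/B -> 2/3, 长长 A/B -> 4/5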
248 | # mannually add key 'type' to distinguish the origin of samples 249 | # 0 for A, 1 for B 250 | if 'A' in single_file_dir: 251 | if self.is_train: 252 | line['label'] = line.pop('labelA') 253 | # assign type according to task names 254 | if '短短' in single_file_dir: 255 | line['type'] = 0 256 | elif '短长' in single_file_dir: 257 | line['type'] = 2 258 | else: 259 | line['type'] = 4 260 | else: 261 | if self.is_train: 262 | line['label'] = line.pop('labelB') 263 | # assign type according to task names 264 | if '短短' in single_file_dir: 265 | line['type'] = 1 266 | elif '短长' in single_file_dir: 267 | line['type'] = 3 268 | else: 269 | line['type'] = 5 270 | lines.append(line) 271 | print(single_file_dir, len(lines)) 272 | content = pd.DataFrame(lines) 273 | # print(content.head()) 274 | content.columns = ['source', 'target', 'label', 'type'] 275 | 276 | # utilize labelB=1-->A positive, labelA=0-->B negative 277 | if self.is_train and self.aug_data: 278 | print("augmenting data...") 279 | content = augment_data(content) 280 | 281 | sources = content['source'].values.tolist() 282 | targets = content['target'].values.tolist() 283 | 284 | self.sample_types = content['type'].values.tolist() 285 | if self.is_train: 286 | self.labels = content['label'].values.tolist() 287 | else: 288 | self.ids = content['label'].values.tolist() 289 | 290 | # shuffle_order is only allowed for training mode 291 | if self.shuffle_order and self.is_train: 292 | sources += content['target'].values.tolist() 293 | targets += content['source'].values.tolist() 294 | self.labels += self.labels 295 | self.sample_types += self.sample_types 296 | 297 | len_limit_s = (len_limit-3)//2 298 | len_limit_t = (len_limit-3)-len_limit_s 299 | # print('len_limit_s: ', len_limit_s) 300 | # print('len_limit_t: ', len_limit_t) 301 | for source, target in tqdm(zip(sources, targets), total=len(sources)): 302 | # tokenize before clipping 303 | source = tokenizer.encode(source)[1:-1] 304 | target = tokenizer.encode(target)[1:-1] 305 | 306 | # clip the sentences if too long 307 | # TODO: different strategies to clip long sequences 308 | if clip == 'head' and len(source)+len(target)+3 > len_limit: 309 | if len(source)>len_limit_s and len(target)>len_limit_t: 310 | source = source[0:len_limit_s] 311 | target = target[0:len_limit_t] 312 | elif len(source)>len_limit_s: 313 | source = source[0:len_limit-3-len(target)] 314 | elif len(target)>len_limit_t: 315 | target = target[0:len_limit-3-len(source)] 316 | 317 | if clip == 'tail' and len(source)+len(target)+3 > len_limit: 318 | if len(source)>len_limit_s and len(target)>len_limit_t: 319 | source = source[-len_limit_s:] 320 | target = target[-len_limit_t:] 321 | elif len(source)>len_limit_s: 322 | source = source[-(len_limit-3-len(target)):] 323 | elif len(target)>len_limit_t: 324 | target = target[-(len_limit-3-len(source)):] 325 | 326 | # check if the total length is within the limit 327 | assert len(source)+len(target)+3 <= len_limit 328 | 329 | # [CLS]:101, [SEP]:102 330 | input_ids = [101] + source + [102] + target + [102] 331 | input_types = [0]*(len(source)+2) + [1]*(len(target)+1) 332 | 333 | assert len(input_ids) <= len_limit and len(input_types) <= len_limit 334 | self.total_input_ids.append(input_ids) 335 | self.total_input_types.append(input_types) 336 | 337 | self.max_input_len = max([len(s) for s in self.total_input_ids]) 338 | print("max length: ", self.max_input_len) 339 | 340 | def __len__(self): 341 | return len(self.total_input_ids) 342 | 343 | def __getitem__(self, idx): 344 | 
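        # same contract as SentencePairDatasetWithType.__getitem__: pad the input ids
        # and token-type ids to max_input_len, then return (ids, types, label-or-id, task type)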
if self.is_train: 345 | input_ids = pad_to_maxlen(self.total_input_ids[idx], self.max_input_len) 346 | input_types = pad_to_maxlen(self.total_input_types[idx], self.max_input_len) 347 | label = int(self.labels[idx]) 348 | sample_type = int(self.sample_types[idx]) 349 | # print(len(input_ids), len(input_types), label) 350 | return torch.LongTensor(input_ids), torch.LongTensor(input_types), torch.LongTensor([label]), sample_type 351 | 352 | else: 353 | input_ids = pad_to_maxlen(self.total_input_ids[idx], self.max_input_len) 354 | input_types = pad_to_maxlen(self.total_input_types[idx], self.max_input_len) 355 | index = self.ids[idx] 356 | sample_type = int(self.sample_types[idx]) 357 | return torch.LongTensor(input_ids), torch.LongTensor(input_types), index, sample_type -------------------------------------------------------------------------------- /决赛提交/sohu_matching/src/infer.py: -------------------------------------------------------------------------------- 1 | from model import BertClassifierSingleModel, NezhaClassifierSingleModel, SBERTSingleModel, SNEZHASingleModel, BertClassifierTextCNNSingleModel 2 | 3 | from data import SentencePairDatasetWithType, SentencePairDatasetForSBERT 4 | from config import Config 5 | 6 | import torch 7 | import torch.nn as nn 8 | from torch.utils.data import DataLoader 9 | 10 | import numpy as np 11 | from sklearn import metrics 12 | from tqdm import tqdm 13 | 14 | import os 15 | os.environ['CUDA_VISIBLE_DEVICES'] = '1,2' 16 | 17 | def infer(model, device, dev_dataloader, test_dataloader, search_thres=True, threshold_fixed_a=0.5, threshold_fixed_b=0.5, save_valid=True): 18 | print("Inferring") 19 | model.eval() 20 | 21 | if torch.cuda.device_count() > 1: 22 | model = torch.nn.DataParallel(model) 23 | 24 | total_gt_a, total_preds_a, total_probs_a = [], [], [] 25 | total_gt_b, total_preds_b, total_probs_b = [], [], [] 26 | 27 | print("Model running on dev set...") 28 | for idx, batch in enumerate(tqdm(dev_dataloader)): 29 | input_ids, input_types, labels, types = batch 30 | input_ids = input_ids.to(device) 31 | input_types = input_types.to(device) 32 | # labels should be flattened 33 | labels = labels.to(device).view(-1) 34 | 35 | with torch.no_grad(): 36 | all_probs = model(input_ids, input_types) 37 | num_tasks = len(all_probs) 38 | 39 | all_masks = [(types==task_id).numpy() for task_id in range(num_tasks)] 40 | all_output = [all_probs[task_id][all_masks[task_id]] for task_id in range(num_tasks)] 41 | all_labels = [labels[all_masks[task_id]] for task_id in range(num_tasks)] 42 | 43 | all_gt = [all_labels[task_id].cpu().numpy().tolist() if all_masks[task_id].sum()!=0 else [] for task_id in range(num_tasks)] 44 | all_preds = [all_output[task_id].argmax(axis=1).cpu().numpy().tolist() if all_masks[task_id].sum()!=0 else [] for task_id in range(num_tasks)] 45 | 46 | gt_a, preds_a, probs_a = [], [], [] 47 | for task_id in range(0, num_tasks, 2): 48 | gt_a += all_gt[task_id] 49 | preds_a += all_preds[task_id] 50 | probs_a += [prob[-1] for prob in nn.functional.softmax(all_output[task_id], dim=1).cpu().numpy().tolist()] 51 | 52 | gt_b, preds_b, probs_b = [], [], [] 53 | for task_id in range(1, num_tasks, 2): 54 | gt_b += all_gt[task_id] 55 | preds_b += all_preds[task_id] 56 | probs_b += [prob[-1] for prob in nn.functional.softmax(all_output[task_id], dim=1).cpu().numpy().tolist()] 57 | 58 | total_gt_a += gt_a 59 | total_preds_a += preds_a 60 | total_probs_a += probs_a 61 | 62 | total_gt_b += gt_b 63 | total_preds_b += preds_b 64 | total_probs_b += probs_b 65 | 66 
| if search_thres: 67 | # search for the optimal threshold 68 | print("Searching for the best threshold on valid dataset...") 69 | thresholds = np.arange(0.2, 0.9, 0.01) 70 | fscore_a = np.zeros(shape=(len(thresholds))) 71 | fscore_b = np.zeros(shape=(len(thresholds))) 72 | print('Length of sequence: {}'.format(len(thresholds))) 73 | 74 | print("Original F1 Score for Task A: {}".format(str(metrics.f1_score(total_gt_a, total_preds_a, zero_division=0)))) 75 | if len(total_gt_a) != 0: 76 | print("\tClassification Report\n") 77 | print(metrics.classification_report(total_gt_a, total_preds_a)) 78 | 79 | print("Original F1 Score for Task B: {}".format(str(metrics.f1_score(total_gt_b, total_preds_b, zero_division=0)))) 80 | if len(total_gt_b) != 0: 81 | print("\tClassification Report\n") 82 | print(metrics.classification_report(total_gt_b, total_preds_b)) 83 | 84 | for index, thres in enumerate(tqdm(thresholds)): 85 | y_pred_prob_a = (np.array(total_probs_a) > thres).astype('int') 86 | fscore_a[index] = metrics.f1_score(total_gt_a, y_pred_prob_a.tolist(), zero_division=0) 87 | 88 | y_pred_prob_b = (np.array(total_probs_b) > thres).astype('int') 89 | fscore_b[index] = metrics.f1_score(total_gt_b, y_pred_prob_b.tolist(), zero_division=0) 90 | 91 | # record the optimal threshold for task A 92 | # print(fscore_a) 93 | index_a = np.argmax(fscore_a) 94 | threshold_opt_a = round(thresholds[index_a], ndigits=4) 95 | f1_score_opt_a = round(fscore_a[index_a], ndigits=6) 96 | print('Best Threshold for Task A: {} with F-Score: {}'.format(threshold_opt_a, f1_score_opt_a)) 97 | # print("\nThreshold Classification Report\n") 98 | # print(metrics.classification_report(total_gt_a, (np.array(total_probs_a) > threshold_opt_a).astype('int').tolist())) 99 | 100 | # record the optimal threshold for task B 101 | index_b = np.argmax(fscore_b) 102 | threshold_opt_b = round(thresholds[index_b], ndigits=4) 103 | f1_score_opt_b = round(fscore_b[index_b], ndigits=6) 104 | print('Best Threshold for Task B: {} with F-Score: {}'.format(threshold_opt_b, f1_score_opt_b)) 105 | # print("\nThreshold Classification Report\n") 106 | # print(metrics.classification_report(total_gt_b, (np.array(total_probs_b) > threshold_opt_b).astype('int').tolist())) 107 | 108 | if save_valid: 109 | y_pred_prob_a = (np.array(total_probs_a) > threshold_opt_a).astype('int') 110 | y_pred_prob_b = (np.array(total_probs_b) > threshold_opt_b).astype('int') 111 | # index of valid and valid_rematch 112 | # ssa, sla, lla = y_pred_prob_a[0:3395], y_pred_prob_a[3395:7681], y_pred_prob_a[7681:] 113 | # gt_ssa, gt_sla, gt_lla = total_gt_a[0:3395], total_gt_a[3395:7681], total_gt_a[7681:] 114 | # ssb, slb, llb = y_pred_prob_b[0:3393], y_pred_prob_b[3393:7684], y_pred_prob_b[7684:] 115 | # gt_ssb, gt_slb, gt_llb = total_gt_b[0:3393], total_gt_b[3393:7684], total_gt_b[7684:] 116 | 117 | # valid_rematch only 118 | ssa, sla, lla = y_pred_prob_a[0:1750], y_pred_prob_a[1750:4380], y_pred_prob_a[4380:] 119 | gt_ssa, gt_sla, gt_lla = total_gt_a[0:1750], total_gt_a[1750:4380], total_gt_a[4380:] 120 | ssb, slb, llb = y_pred_prob_b[0:1750], y_pred_prob_b[1750:4385], y_pred_prob_b[4385:] 121 | gt_ssb, gt_slb, gt_llb = total_gt_b[0:1750], total_gt_b[1750:4385], total_gt_b[4385:] 122 | print("f1 on ssa: ", metrics.f1_score(gt_ssa, ssa)) 123 | print("f1 on sla: ", metrics.f1_score(gt_sla, sla)) 124 | print("f1 on lla: ", metrics.f1_score(gt_lla, lla)) 125 | print("f1 on ssb: ", metrics.f1_score(gt_ssb, ssb)) 126 | print("f1 on slb: ", metrics.f1_score(gt_slb, slb)) 127 | 
print("f1 on llb: ", metrics.f1_score(gt_llb, llb)) 128 | 129 | np.save('../valid_output/{}_pred_a.npy'.format(model_type), y_pred_prob_a) 130 | np.save('../valid_output/{}_pred_b.npy'.format(model_type), y_pred_prob_b) 131 | np.save('../valid_output/gt_a.npy', np.array(total_gt_a)) 132 | np.save('../valid_output/gt_b.npy', np.array(total_gt_b)) 133 | 134 | total_ids_a, total_probs_a = [], [] 135 | total_ids_b, total_probs_b = [], [] 136 | for idx, batch in enumerate(tqdm(test_dataloader)): 137 | input_ids, input_types, ids, types = batch 138 | input_ids = input_ids.to(device) 139 | input_types = input_types.to(device) 140 | 141 | # the probs given by the model, without grads 142 | with torch.no_grad(): 143 | # probs_a, probs_b = model(input_ids, input_types) 144 | # mask_a, mask_b = (types==0).numpy(), (types==1).numpy() 145 | 146 | all_probs = model(input_ids, input_types) 147 | num_tasks = len(all_probs) 148 | # mask_a, mask_b = (types==0).numpy(), (types==1).numpy() 149 | all_masks = [(types==task_id).numpy() for task_id in range(num_tasks)] 150 | all_output = [all_probs[task_id][all_masks[task_id]] for task_id in range(num_tasks)] 151 | 152 | total_ids_a += [id for id in ids if id.endswith('a')] 153 | total_ids_b += [id for id in ids if id.endswith('b')] 154 | 155 | gt_a, preds_a, probs_a = [], [], [] 156 | for task_id in range(0, num_tasks, 2): 157 | probs_a += [prob[-1] for prob in nn.functional.softmax(all_output[task_id], dim=1).cpu().numpy().tolist()] 158 | 159 | gt_b, preds_b, probs_b = [], [], [] 160 | for task_id in range(1, num_tasks, 2): 161 | probs_b += [prob[-1] for prob in nn.functional.softmax(all_output[task_id], dim=1).cpu().numpy().tolist()] 162 | 163 | total_probs_a += probs_a 164 | total_probs_b += probs_b 165 | 166 | # positive if the prob passes the original threshold of 0.5 167 | total_fixed_preds_a = (np.array(total_probs_a) > threshold_fixed_a).astype('int').tolist() 168 | total_fixed_preds_b = (np.array(total_probs_b) > threshold_fixed_b).astype('int').tolist() 169 | 170 | if search_thres: 171 | # positive if the prob passes the optimal threshold 172 | total_preds_a = (np.array(total_probs_a) > threshold_opt_a).astype('int').tolist() 173 | total_preds_b = (np.array(total_probs_b) > threshold_opt_b).astype('int').tolist() 174 | else: 175 | total_preds_a = None 176 | total_preds_b = None 177 | 178 | return total_ids_a, total_preds_a, total_fixed_preds_a, \ 179 | total_ids_b, total_preds_b, total_fixed_preds_b 180 | 181 | if __name__=='__main__': 182 | config = Config() 183 | device = config.device 184 | dummy_pretrained = config.dummy_pretrained 185 | model_type = config.infer_model_name 186 | 187 | save_dir = config.infer_model_dir 188 | model_name = config.infer_model_name 189 | hidden_size = config.hidden_size 190 | output_dir= config.infer_output_dir 191 | output_filename = config.infer_output_filename 192 | data_dir = config.data_dir 193 | task_a = ['短短匹配A类', '短长匹配A类', '长长匹配A类'] 194 | task_b = ['短短匹配B类', '短长匹配B类', '长长匹配B类'] 195 | task_type = config.infer_task_type 196 | 197 | infer_bs = config.infer_bs 198 | search_thres = config.infer_search_thres 199 | threshold_fixed_a = config.infer_fixed_thres_a 200 | threshold_fixed_b = config.infer_fixed_thres_b 201 | # method for clipping long seqeunces, 'head' or 'tail' 202 | clip_method = config.infer_clip_method 203 | 204 | dev_data_dir, test_data_dir = [], [] 205 | if 'a' in task_type: 206 | for task in task_a: 207 | # dev_data_dir.append(data_dir + task + '/valid.txt') 208 | dev_data_dir.append(data_dir + task 
+ '/valid_rematch.txt') 209 | test_data_dir.append(data_dir + task + '/test_with_id_rematch.txt') 210 | if 'b' in task_type: 211 | for task in task_b: 212 | # dev_data_dir.append(data_dir + task + '/valid.txt') 213 | dev_data_dir.append(data_dir + task + '/valid_rematch.txt') 214 | test_data_dir.append(data_dir + task + '/test_with_id_rematch.txt') 215 | 216 | print("Loading Bert Model from {}...".format(save_dir + model_name)) 217 | # distinguish model architectures or pretrained models according to model_type 218 | if 'sbert' in model_type.lower(): 219 | print("Using SentenceBERT model and dataset") 220 | if 'nezha' in model_type.lower(): 221 | model = SNEZHASingleModel(bert_dir=dummy_pretrained, from_pretrained=False, hidden_size=hidden_size) 222 | else: 223 | model = SBERTSingleModel(bert_dir=dummy_pretrained, from_pretrained=False, hidden_size=hidden_size) 224 | 225 | model_dict = torch.load(save_dir + model_name) 226 | 227 | # weights are saved under the 'module.' prefix when trained with DataParallel 228 | model.load_state_dict({k.replace('module.','') : v for k, v in model_dict.items()}) 229 | model.to(device) 230 | 231 | print("Loading Dev Data...") 232 | dev_dataset = SentencePairDatasetForSBERT(dev_data_dir, True, dummy_pretrained, clip=clip_method) 233 | dev_dataloader = DataLoader(dev_dataset, batch_size=infer_bs, shuffle=False) 234 | 235 | print("Loading Test Data...") 236 | # for the test dataset, is_train should be set to False, so the dataset returns ids instead of labels 237 | test_dataset = SentencePairDatasetForSBERT(test_data_dir, False, dummy_pretrained, clip=clip_method) 238 | test_dataloader = DataLoader(test_dataset, batch_size=infer_bs, shuffle=False) 239 | 240 | else: 241 | print("Using BERT model and dataset") 242 | if 'nezha' in model_type.lower(): 243 | print("Using NEZHA pretrained model") 244 | model = NezhaClassifierSingleModel(bert_dir=dummy_pretrained, from_pretrained=False, hidden_size=hidden_size) 245 | elif 'cnn' in model_type.lower(): 246 | print("Adding TextCNN after BERT output") 247 | model = BertClassifierTextCNNSingleModel(bert_dir=dummy_pretrained, from_pretrained=False, hidden_size=hidden_size) 248 | else: 249 | print("Using conventional BERT model with linears") 250 | model = BertClassifierSingleModel(bert_dir=dummy_pretrained, from_pretrained=False, hidden_size=hidden_size) 251 | 252 | model_dict = torch.load(save_dir + model_name) 253 | # weights are saved under the 'module.' prefix when trained with DataParallel 254 | model.load_state_dict({k.replace('module.','') : v for k, v in model_dict.items()}) 255 | model.to(device) 256 | 257 | # model_dict = torch.load(save_dir + model_name).module.state_dict() 258 | # model.load_state_dict(model_dict) 259 | # model.to(device) 260 | 261 | print("Loading Dev Data...") 262 | dev_dataset = SentencePairDatasetWithType(dev_data_dir, True, dummy_pretrained, clip=clip_method) 263 | dev_dataloader = DataLoader(dev_dataset, batch_size=infer_bs, shuffle=False) 264 | 265 | print("Loading Test Data...") 266 | # for the test dataset, is_train should be set to False, so the dataset returns ids instead of labels 267 | test_dataset = SentencePairDatasetWithType(test_data_dir, False, dummy_pretrained, clip=clip_method) 268 | test_dataloader = DataLoader(test_dataset, batch_size=infer_bs, shuffle=False) 269 | 270 | total_ids_a, total_preds_a, total_fixed_preds_a, total_ids_b, total_preds_b, total_fixed_preds_b = infer(model, device, dev_dataloader, test_dataloader, search_thres, threshold_fixed_a, threshold_fixed_b) 271 | 272 | with open(output_dir + 
'fixed_' + output_filename, 'w') as f_out: 273 | for id, pred in zip(total_ids_a, total_fixed_preds_a): 274 | f_out.writelines(str(id) + ',' + str(pred) + '\n') 275 | for id, pred in zip(total_ids_b, total_fixed_preds_b): 276 | f_out.writelines(str(id) + ',' + str(pred) + '\n') 277 | 278 | if total_preds_a is not None: 279 | with open(output_dir + output_filename, 'w') as f_out: 280 | for id, pred in zip(total_ids_a, total_preds_a): 281 | f_out.writelines(str(id) + ',' + str(pred) + '\n') 282 | for id, pred in zip(total_ids_b, total_preds_b): 283 | f_out.writelines(str(id) + ',' + str(pred) + '\n') -------------------------------------------------------------------------------- /决赛提交/sohu_matching/src/infer_final.py: -------------------------------------------------------------------------------- 1 | from model import BertClassifierSingleModel, NezhaClassifierSingleModel, SBERTSingleModel, SNEZHASingleModel, BertClassifierTextCNNSingleModel 2 | from data import SentencePairDatasetWithType, SentencePairDatasetForSBERT 3 | 4 | import torch 5 | import torch.nn as nn 6 | from torch.utils.data import DataLoader 7 | import numpy as np 8 | from tqdm import tqdm 9 | from argparse import ArgumentParser 10 | 11 | import os 12 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 13 | import time 14 | 15 | def infer_final(model, device, test_dataloader, threshold_fixed_a=0.5, threshold_fixed_b=0.5): 16 | print("Inferring for final stage") 17 | model.eval() 18 | 19 | # as only one GPU is available for the final stage 20 | # if torch.cuda.device_count() > 1: 21 | # model = torch.nn.DataParallel(model) 22 | 23 | total_ids_a, total_probs_a = [], [] 24 | total_ids_b, total_probs_b = [], [] 25 | for idx, batch in enumerate(tqdm(test_dataloader)): 26 | input_ids, input_types, ids, types = batch 27 | input_ids = input_ids.to(device) 28 | input_types = input_types.to(device) 29 | 30 | # the probs given by the model, without grads 31 | with torch.no_grad(): 32 | all_probs = model(input_ids, input_types) 33 | num_tasks = len(all_probs) 34 | 35 | all_masks = [(types==task_id).numpy() for task_id in range(num_tasks)] 36 | all_output = [all_probs[task_id][all_masks[task_id]] for task_id in range(num_tasks)] 37 | 38 | total_ids_a += [id for id in ids if id.endswith('a')] 39 | total_ids_b += [id for id in ids if id.endswith('b')] 40 | 41 | probs_a, probs_b = [], [] 42 | for task_id in range(0, num_tasks, 2): 43 | probs_a += [prob[-1] for prob in nn.functional.softmax(all_output[task_id], dim=1).cpu().numpy().tolist()] 44 | 45 | for task_id in range(1, num_tasks, 2): 46 | probs_b += [prob[-1] for prob in nn.functional.softmax(all_output[task_id], dim=1).cpu().numpy().tolist()] 47 | 48 | total_probs_a += probs_a 49 | total_probs_b += probs_b 50 | 51 | # positive if the prob passes the fixed threshold (0.5 by default) 52 | total_fixed_preds_a = (np.array(total_probs_a) > threshold_fixed_a).astype('int').tolist() 53 | total_fixed_preds_b = (np.array(total_probs_b) > threshold_fixed_b).astype('int').tolist() 54 | 55 | total_preds_a = None 56 | total_preds_b = None 57 | 58 | return total_ids_a, total_preds_a, total_fixed_preds_a, \ 59 | total_ids_b, total_preds_b, total_fixed_preds_b 60 | 61 | if __name__=='__main__': 62 | s_time = time.time() 63 | 64 | parser = ArgumentParser() 65 | parser.add_argument("-i","--input", type=str, required=True, help="input file") 66 | parser.add_argument("-o","--output", type=str, required=True, help="output file") 67 | args = parser.parse_args() 68 | input_dir = args.input 69 | output_dir = args.output 70 | 71 | device 
= 'cuda' 72 | data_dir = '../data/sohu2021_open_data/' 73 | save_dir = '../checkpoints/rematch/' 74 | result_dir = '../results/final/' 75 | bert_tokenizer_config = '../data/dummy_bert/' # as NEZHA, MACBERT and ROBERTA share the same tokenizer vocabulary 76 | ernie_tokenizer_config = '../data/dummy_ernie/' # unfortunately, ERNIE has its unique vocabulary, should load dataset again 77 | 78 | # only use test dataloader for final stage 79 | # the test file will be in one file 80 | test_data_dir = [input_dir] 81 | bert_model_configs = [ 82 | # model_name, dummy_pretrained, threshold_fixed_a, threshold_fixed_b, infer_bs 83 | ('0520_roberta_80k_same_lr_zy_epoch_1_ab_loss', '../data/dummy_bert/', 0.4, 0.3, 128), 84 | ('0518_macbert_same_lr_epoch_1_ab_loss', '../data/dummy_bert/', 0.37, 0.39, 128), 85 | ('0523_roberta_dataaug_epoch_0_ab_loss', '../data/dummy_bert/', 0.41, 0.48, 128) 86 | ] 87 | 88 | ernie_model_configs = [ 89 | ('0523_ernie_epoch_1_ab_loss', '../data/dummy_ernie/', 0.42, 0.39, 128), 90 | ] 91 | 92 | sbert_model_configs = [ 93 | ('0520_roberta_sbert_same_lr_epoch_1_ab_loss', '../data/dummy_bert/', 0.4, 0.36, 128) 94 | ] 95 | 96 | # We will first infer for the bert-style models 97 | if len(bert_model_configs) != 0: 98 | print("Loading Test Data for BERT models...") 99 | test_dataset = SentencePairDatasetWithType(test_data_dir, False, bert_tokenizer_config) 100 | 101 | for model_config in bert_model_configs: 102 | model_name, dummy_pretrained, threshold_fixed_a, threshold_fixed_b, infer_bs = model_config 103 | test_dataloader = DataLoader(test_dataset, batch_size=infer_bs, shuffle=False) 104 | print("Loading Bert Model from {}...".format(save_dir + model_name)) 105 | # distinguish model architectures or pretrained models according to model_type 106 | if 'nezha' in model_name.lower(): 107 | print("Using NEZHA pretrained model") 108 | model = NezhaClassifierSingleModel(bert_dir=dummy_pretrained, from_pretrained=False) 109 | elif 'cnn' in model_name.lower(): 110 | print("Adding TextCNN after BERT output") 111 | model = BertClassifierTextCNNSingleModel(bert_dir=dummy_pretrained, from_pretrained=False) 112 | else: 113 | print("Using conventional BERT model with linears") 114 | model = BertClassifierSingleModel(bert_dir=dummy_pretrained, from_pretrained=False) 115 | 116 | model_dict = torch.load(save_dir + model_name) 117 | # weights will be saved in module when DataParallel 118 | model.load_state_dict({k.replace('module.','') : v for k, v in model_dict.items()}) 119 | model.to(device) 120 | 121 | # total_ids_a, total_preds_a, total_fixed_preds_a, total_ids_b, total_preds_b, total_fixed_preds_b = infer(model, device, dev_dataloader, test_dataloader, search_thres, threshold_fixed_a, threshold_fixed_b) 122 | total_ids_a, total_preds_a, total_fixed_preds_a, total_ids_b, total_preds_b, total_fixed_preds_b = infer_final(model, device, test_dataloader, threshold_fixed_a, threshold_fixed_b) 123 | 124 | with open(result_dir + 'final_' + '{}.csv'.format(model_name), 'w') as f_out: 125 | for id, pred in zip(total_ids_a, total_fixed_preds_a): 126 | f_out.writelines(str(id) + ',' + str(pred) + '\n') 127 | for id, pred in zip(total_ids_b, total_fixed_preds_b): 128 | f_out.writelines(str(id) + ',' + str(pred) + '\n') 129 | 130 | # infer for the ernie models, dataset should be reloaded for ernie's vocabulary 131 | if len(ernie_model_configs) != 0: 132 | print("Loading Test Data for ERNIE models...") 133 | test_dataset = SentencePairDatasetWithType(test_data_dir, False, ernie_tokenizer_config) 134 | 135 | 
for model_config in ernie_model_configs: 136 | model_name, dummy_pretrained, threshold_fixed_a, threshold_fixed_b, infer_bs = model_config 137 | test_dataloader = DataLoader(test_dataset, batch_size=infer_bs, shuffle=False) 138 | print("Loading Bert Model from {}...".format(save_dir + model_name)) 139 | # distinguish model architectures or pretrained models according to model_type 140 | if 'nezha' in model_name.lower(): 141 | print("Using NEZHA pretrained model") 142 | model = NezhaClassifierSingleModel(bert_dir=dummy_pretrained, from_pretrained=False) 143 | elif 'cnn' in model_name.lower(): 144 | print("Adding TextCNN after BERT output") 145 | model = BertClassifierTextCNNSingleModel(bert_dir=dummy_pretrained, from_pretrained=False) 146 | else: 147 | print("Using conventional BERT model with linears") 148 | model = BertClassifierSingleModel(bert_dir=dummy_pretrained, from_pretrained=False) 149 | 150 | model_dict = torch.load(save_dir + model_name) 151 | # weights will be saved in module when DataParallel 152 | model.load_state_dict({k.replace('module.','') : v for k, v in model_dict.items()}) 153 | model.to(device) 154 | 155 | # total_ids_a, total_preds_a, total_fixed_preds_a, total_ids_b, total_preds_b, total_fixed_preds_b = infer(model, device, dev_dataloader, test_dataloader, search_thres, threshold_fixed_a, threshold_fixed_b) 156 | total_ids_a, total_preds_a, total_fixed_preds_a, total_ids_b, total_preds_b, total_fixed_preds_b = infer_final(model, device, test_dataloader, threshold_fixed_a, threshold_fixed_b) 157 | 158 | with open(result_dir + 'final_' + '{}.csv'.format(model_name), 'w') as f_out: 159 | for id, pred in zip(total_ids_a, total_fixed_preds_a): 160 | f_out.writelines(str(id) + ',' + str(pred) + '\n') 161 | for id, pred in zip(total_ids_b, total_fixed_preds_b): 162 | f_out.writelines(str(id) + ',' + str(pred) + '\n') 163 | 164 | # infer for SBERT models 165 | if len(sbert_model_configs) != 0: 166 | print("Loading Test Data for SBERT models...") 167 | # for test dataset, is_train should be set to False, thus get ids instead of labels 168 | test_dataset = SentencePairDatasetForSBERT(test_data_dir, False, bert_tokenizer_config) 169 | 170 | for model_config in sbert_model_configs: 171 | model_name, dummy_pretrained, threshold_fixed_a, threshold_fixed_b, infer_bs = model_config 172 | test_dataloader = DataLoader(test_dataset, batch_size=infer_bs, shuffle=False) 173 | print("Loading SentenceBert Model from {}...".format(save_dir + model_name)) 174 | # distinguish model architectures or pretrained models according to model_type 175 | if 'nezha' in model_name.lower(): 176 | model = SNEZHASingleModel(bert_dir=dummy_pretrained, from_pretrained=False) 177 | else: 178 | model = SBERTSingleModel(bert_dir=dummy_pretrained, from_pretrained=False) 179 | 180 | model_dict = torch.load(save_dir + model_name) 181 | # weights will be saved in module when training on multiple GPUs with DataParallel 182 | model.load_state_dict({k.replace('module.','') : v for k, v in model_dict.items()}) 183 | model.to(device) 184 | 185 | # total_ids_a, total_preds_a, total_fixed_preds_a, total_ids_b, total_preds_b, total_fixed_preds_b = infer(model, device, dev_dataloader, test_dataloader, search_thres, threshold_fixed_a, threshold_fixed_b) 186 | total_ids_a, total_preds_a, total_fixed_preds_a, total_ids_b, total_preds_b, total_fixed_preds_b = infer_final(model, device, test_dataloader, threshold_fixed_a, threshold_fixed_b) 187 | 188 | with open(result_dir + 'final_' + '{}.csv'.format(model_name), 'w') as 
f_out: 189 | for id, pred in zip(total_ids_a, total_fixed_preds_a): 190 | f_out.writelines(str(id) + ',' + str(pred) + '\n') 191 | for id, pred in zip(total_ids_b, total_fixed_preds_b): 192 | f_out.writelines(str(id) + ',' + str(pred) + '\n') 193 | 194 | # finally, merge all the per-model result files in result_dir by majority vote 195 | print("Merging the model outputs...") 196 | result_list = [filename for filename in os.listdir(result_dir) if filename.endswith('.csv')] 197 | result_dict = {} 198 | for name in result_list: 199 | with open(result_dir + name, "r", encoding="utf-8") as fr: 200 | for line in fr: 201 | words = line.strip().split(",") 202 | if words[0] == "id": 203 | continue 204 | if words[0] not in result_dict: 205 | result_dict[words[0]] = [words[1]] 206 | else: 207 | result_dict[words[0]].append(words[1]) 208 | 209 | # write the final csv file, keeping the most frequent label per id 210 | with open(output_dir, "w", encoding="utf-8") as fw: 211 | fw.write("id,label"+"\n") 212 | for k, v in result_dict.items(): 213 | tmp = {} 214 | for ele in v: 215 | if ele in tmp: 216 | tmp[ele] += 1 217 | else: 218 | tmp[ele] = 1 219 | tmp = sorted(tmp.items(), key=lambda d: d[1], reverse=True) 220 | fw.write(",".join([k, tmp[0][0]]) + "\n") 221 | 222 | e_time = time.time() 223 | print("Time taken: ", e_time - s_time) -------------------------------------------------------------------------------- /决赛提交/sohu_matching/src/merge_result.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | if __name__ == '__main__': 4 | result_dir = '../results/rematch/' 5 | 6 | target_files = [ 7 | '0520_roberta_80k_same_lr_zy_epoch_1_ab_loss', 8 | '0518_macbert_same_lr_epoch_1_ab_loss', 9 | '0520_roberta_sbert_same_lr_epoch_1_ab_loss', 10 | '0523_roberta_dataaug_epoch_0_ab_loss', 11 | '0523_ernie_epoch_1_ab_loss' 12 | ] 13 | # 0.7931380664848722 14 | 15 | # target_files = [ 16 | # '0518_roberta_same_lr_epoch_1_ab_loss', 17 | # '0519_nezha_same_lr_epoch_1_ab_f1', 18 | # '0518_nezha_diff_lr_zy_epoch_1_ab_loss', 19 | # '0518_macbert_same_lr_epoch_1_ab_los', 20 | # '0523_roberta_dataaug_epoch_0_ab_loss' 21 | # ] 22 | # # 0.7930518678397445 23 | 24 | result_list = [file_name+'.csv' for file_name in target_files] 25 | result_dict = {} 26 | for name in result_list: 27 | with open(result_dir + name, "r", encoding="utf-8") as fr: 28 | for line in fr: 29 | words = line.strip().split(",") 30 | if words[0] == "id": 31 | continue 32 | if words[0] not in result_dict: 33 | result_dict[words[0]] = [words[1]] 34 | else: 35 | result_dict[words[0]].append(words[1]) 36 | 37 | with open(result_dir+"merge.csv", "w", encoding="utf-8") as fw: 38 | fw.write("id,label"+"\n") 39 | for k, v in result_dict.items(): 40 | tmp = {} 41 | for ele in v: 42 | if ele in tmp: 43 | tmp[ele] += 1 44 | else: 45 | tmp[ele] = 1 46 | tmp = sorted(tmp.items(), key=lambda d: d[1], reverse=True) 47 | # print(tmp) 48 | fw.write(",".join([k, tmp[0][0]]) + "\n") -------------------------------------------------------------------------------- /决赛提交/sohu_matching/src/model.py: -------------------------------------------------------------------------------- 1 | from transformers import BertModel, BertConfig 2 | import torch 3 | import torch.nn as nn 4 | from torch.utils.data import DataLoader 5 | import math 6 | from data import SentencePairDatasetWithType 7 | 8 | # import files for NEZHA models 9 | from NEZHA.model_nezha import NezhaConfig, NEZHAModel 10 | from NEZHA import nezha_utils 11 | 12 | import os 13 | os.environ['CUDA_VISIBLE_DEVICES'] = '3' 14 | 
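# All classifiers below share one routing pattern: a shared encoder pools the
# sentence pair into a single vector, and an nn.ModuleList of small per-task
# heads maps that vector to two logits each. A minimal sketch of the routing,
# for illustration only (_route_to_task_heads is a hypothetical helper that is
# not used elsewhere in this repo):
def _route_to_task_heads(pooled_output, task_heads):
    # pooled_output: (batch, hidden) tensor from the shared encoder
    # task_heads: nn.ModuleList holding one small classifier per task
    # returns one (batch, 2) logit tensor per task, as the forward() methods below do
    return [head(pooled_output) for head in task_heads]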
15 | # basic BERT-like models 16 | class BertClassifierSingleModel(nn.Module): 17 | def __init__(self, bert_dir, from_pretrained=True, task_num=2, hidden_size = 768, mid_size=512, freeze = False): 18 | super(BertClassifierSingleModel, self).__init__() 19 | self.hidden_size = hidden_size 20 | # could be extended to a multi-task setting, e.g. 6 classifiers for 6 subtasks 21 | self.task_num = task_num 22 | 23 | if from_pretrained: 24 | print("Initialize BERT from pretrained weights") 25 | self.bert = BertModel.from_pretrained(bert_dir) 26 | else: 27 | print("Initialize BERT from config.json, weight NOT loaded") 28 | self.bert_config = BertConfig.from_json_file(bert_dir+'config.json') 29 | self.bert = BertModel(self.bert_config) 30 | self.dropout = nn.Dropout(0.5) 31 | 32 | self.all_classifier = nn.ModuleList([ 33 | nn.Sequential( 34 | nn.Linear(hidden_size, mid_size), 35 | nn.BatchNorm1d(mid_size), 36 | nn.ReLU(), 37 | nn.Dropout(0.5), 38 | nn.Linear(mid_size, 2) 39 | ) 40 | for _ in range(self.task_num) 41 | ]) 42 | 43 | def forward(self, input_ids, input_types): 44 | # get shared BERT model output 45 | mask = torch.ne(input_ids, 0) 46 | bert_output = self.bert(input_ids, token_type_ids=input_types, attention_mask=mask) 47 | cls_embed = bert_output[1] 48 | output = self.dropout(cls_embed) 49 | 50 | # get probs for two tasks A and B 51 | all_probs = [classifier(output) for classifier in self.all_classifier] 52 | return all_probs 53 | 54 | class NezhaClassifierSingleModel(nn.Module): 55 | def __init__(self, bert_dir, from_pretrained=True, task_num=2, hidden_size = 768, mid_size=512, freeze = False): 56 | super(NezhaClassifierSingleModel, self).__init__() 57 | self.hidden_size = hidden_size 58 | self.task_num = task_num 59 | 60 | self.bert_config = NezhaConfig.from_json_file(bert_dir+'config.json') 61 | self.bert = NEZHAModel(config=self.bert_config) 62 | if from_pretrained: 63 | print("Initialize NEZHA from pretrained weights") 64 | nezha_utils.torch_init_model(self.bert, bert_dir+'pytorch_model.bin') 65 | 66 | self.dropout = nn.Dropout(0.5) 67 | self.all_classifier = nn.ModuleList([ 68 | nn.Sequential( 69 | nn.Linear(hidden_size, mid_size), 70 | nn.BatchNorm1d(mid_size), 71 | nn.ReLU(), 72 | nn.Dropout(0.5), 73 | nn.Linear(mid_size, 2) 74 | ) 75 | for _ in range(self.task_num) 76 | ]) 77 | 78 | def forward(self, input_ids, input_types): 79 | # get shared BERT model output 80 | mask = torch.ne(input_ids, 0) 81 | bert_output = self.bert(input_ids, token_type_ids=input_types, attention_mask=mask) 82 | cls_embed = bert_output[1] 83 | output = self.dropout(cls_embed) 84 | 85 | # get probs for two tasks A and B 86 | all_probs = [classifier(output) for classifier in self.all_classifier] 87 | return all_probs 88 | 89 | class SBERTSingleModel(nn.Module): 90 | def __init__(self, bert_dir, from_pretrained=True, task_num=2, hidden_size=768, mid_size=512, freeze = False): 91 | super(SBERTSingleModel, self).__init__() 92 | self.hidden_size = hidden_size 93 | self.task_num = task_num 94 | 95 | if from_pretrained: 96 | print("Initialize BERT from pretrained weights") 97 | self.bert = BertModel.from_pretrained(bert_dir) 98 | else: 99 | print("Initialize BERT from config.json, weight NOT loaded") 100 | self.bert_config = BertConfig.from_json_file(bert_dir+'config.json') 101 | self.bert = BertModel(self.bert_config) 102 | 103 | self.dropout = nn.Dropout(0.5) 104 | self.all_classifier = nn.ModuleList([ 105 | nn.Sequential( 106 | nn.Linear(hidden_size*3, mid_size), 107 | nn.BatchNorm1d(mid_size), 108 | 
nn.ReLU(), 109 | nn.Dropout(0.5), 110 | nn.Linear(mid_size, 2) 111 | ) 112 | for _ in range(self.task_num) 113 | ]) 114 | 115 | def forward(self, source_input_ids, target_input_ids): 116 | # 0 for [PAD], mask out the padded values 117 | source_attention_mask = torch.ne(source_input_ids, 0) 118 | target_attention_mask = torch.ne(target_input_ids, 0) 119 | 120 | # get bert output 121 | source_embedding = self.bert(source_input_ids, attention_mask=source_attention_mask) 122 | target_embedding = self.bert(target_input_ids, attention_mask=target_attention_mask) 123 | 124 | # simply take out the [CLS] representation 125 | # TODO: try different pooling strategies 126 | source_embedding = source_embedding[1] 127 | target_embedding = target_embedding[1] 128 | 129 | # concat the source embedding, target embedding and abs embedding as in the original SBERT paper 130 | abs_embedding = torch.abs(source_embedding-target_embedding) 131 | context_embedding = torch.cat([source_embedding, target_embedding, abs_embedding], -1) 132 | context_embedding = self.dropout(context_embedding) 133 | 134 | # get probs for two tasks A and B 135 | all_probs = [classifier(context_embedding) for classifier in self.all_classifier] 136 | return all_probs 137 | 138 | class SNEZHASingleModel(nn.Module): 139 | def __init__(self, bert_dir, from_pretrained=True, task_num=2, hidden_size=768, mid_size=512, freeze = False): 140 | super(SNEZHASingleModel, self).__init__() 141 | self.hidden_size = hidden_size 142 | self.task_num = task_num 143 | 144 | self.bert_config = NezhaConfig.from_json_file(bert_dir+'config.json') 145 | self.bert = NEZHAModel(config=self.bert_config) 146 | if from_pretrained: 147 | print("Initialize NEZHA from pretrained weights") 148 | nezha_utils.torch_init_model(self.bert, bert_dir+'pytorch_model.bin') 149 | 150 | self.dropout = nn.Dropout(0.5) 151 | self.all_classifier = nn.ModuleList([ 152 | nn.Sequential( 153 | nn.Linear(hidden_size*3, mid_size), 154 | nn.BatchNorm1d(mid_size), 155 | nn.ReLU(), 156 | nn.Dropout(0.5), 157 | nn.Linear(mid_size, 2) 158 | ) 159 | for _ in range(self.task_num) 160 | ]) 161 | 162 | def forward(self, source_input_ids, target_input_ids): 163 | # 0 for [PAD], mask out the padded values 164 | source_attention_mask = torch.ne(source_input_ids, 0) 165 | target_attention_mask = torch.ne(target_input_ids, 0) 166 | 167 | # get bert output 168 | source_embedding = self.bert(source_input_ids, attention_mask=source_attention_mask) 169 | target_embedding = self.bert(target_input_ids, attention_mask=target_attention_mask) 170 | 171 | # simply take out the [CLS] representation 172 | # TODO: try different pooling strategies 173 | source_embedding = source_embedding[1] 174 | target_embedding = target_embedding[1] 175 | 176 | # concat the source embedding, target embedding and abs embedding as in the original SBERT paper 177 | abs_embedding = torch.abs(source_embedding-target_embedding) 178 | context_embedding = torch.cat([source_embedding, target_embedding, abs_embedding], -1) 179 | context_embedding = self.dropout(context_embedding) 180 | 181 | # get probs for two tasks A and B 182 | all_probs = [classifier(context_embedding) for classifier in self.all_classifier] 183 | return all_probs 184 | 185 | class BertClassifierTextCNNSingleModel(nn.Module): 186 | def __init__(self, bert_dir, from_pretrained=True, task_num=2, hidden_size = 768, mid_size=512, freeze = False): 187 | super(BertClassifierTextCNNSingleModel, self).__init__() 188 | self.hidden_size = hidden_size 189 | self.task_num = 
task_num 190 | 191 | if from_pretrained: 192 | print("Initialize BERT from pretrained weights") 193 | self.bert = BertModel.from_pretrained(bert_dir) 194 | else: 195 | print("Initialize BERT from config.json, weight NOT loaded") 196 | self.bert_config = BertConfig.from_json_file(bert_dir+'config.json') 197 | self.bert = BertModel(self.bert_config) 198 | 199 | self.dropout = nn.Dropout(0.5) 200 | 201 | # for TextCNN 202 | filter_num = 128 203 | filter_sizes = [2,3,4] 204 | self.convs = nn.ModuleList( 205 | [nn.Conv2d(1, filter_num, (size, hidden_size)) for size in filter_sizes]) 206 | 207 | self.all_classifier = nn.ModuleList([ 208 | nn.Sequential( 209 | nn.Linear(len(filter_sizes) * filter_num, mid_size), 210 | nn.BatchNorm1d(mid_size), 211 | nn.ReLU(), 212 | nn.Dropout(0.5), 213 | nn.Linear(mid_size, 2) 214 | ) 215 | for _ in range(self.task_num) 216 | ]) 217 | 218 | def forward(self, input_ids, input_types): 219 | # get shared BERT model output 220 | mask = torch.ne(input_ids, 0) 221 | bert_output = self.bert(input_ids, token_type_ids=input_types, attention_mask=mask) 222 | bert_hidden = bert_output[0] 223 | output = self.dropout(bert_hidden) 224 | 225 | tcnn_input = output.unsqueeze(1) 226 | tcnn_output = [nn.functional.relu(conv(tcnn_input)).squeeze(3) for conv in self.convs] 227 | # max pooling in TextCNN 228 | # TODO: support avg pooling 229 | tcnn_output = [nn.functional.max_pool1d(item, item.size(2)).squeeze(2) for item in tcnn_output] 230 | tcnn_output = torch.cat(tcnn_output, 1) 231 | tcnn_output = self.dropout(tcnn_output) 232 | 233 | # get probs for two tasks A and B 234 | all_probs = [classifier(tcnn_output) for classifier in self.all_classifier] 235 | return all_probs 236 | 237 | class BertCoAttention(nn.Module): 238 | def __init__(self, config): 239 | super(BertCoAttention, self).__init__() 240 | if config.hidden_size % config.num_attention_heads != 0: 241 | raise ValueError( 242 | "The hidden size (%d) is not a multiple of the number of attention " 243 | "heads (%d)" % (config.hidden_size, config.num_attention_heads)) 244 | self.output_attentions = config.output_attentions 245 | 246 | self.num_attention_heads = config.num_attention_heads 247 | self.attention_head_size = int(config.hidden_size / config.num_attention_heads) 248 | self.all_head_size = self.num_attention_heads * self.attention_head_size 249 | 250 | self.query = nn.Linear(config.hidden_size, self.all_head_size) 251 | self.key = nn.Linear(config.hidden_size, self.all_head_size) 252 | self.value = nn.Linear(config.hidden_size, self.all_head_size) 253 | 254 | self.dropout = nn.Dropout(config.attention_probs_dropout_prob) 255 | 256 | def transpose_for_scores(self, x): 257 | new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) 258 | x = x.view(*new_x_shape) 259 | return x.permute(0, 2, 1, 3) 260 | 261 | def forward(self, context_states, query_states, attention_mask=None, head_mask=None, encoder_hidden_states=None, 262 | encoder_attention_mask=None): 263 | mixed_query_layer = self.query(query_states) 264 | 265 | extended_attention_mask = attention_mask[:, None, None, :] 266 | extended_attention_mask = extended_attention_mask.float() # fp16 compatibility 267 | extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 268 | attention_mask = extended_attention_mask 269 | 270 | # If this is instantiated as a cross-attention module, the keys 271 | # and values come from an encoder; the attention mask needs to be 272 | # such that the encoder's padding tokens are not attended to. 
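# Note on the additive mask built above: it is summed with the raw attention
# scores before softmax, so real tokens are shifted by 0.0 and padded positions
# by -10000.0, driving their post-softmax attention weight to effectively zero.
# For a key mask of [1, 1, 0], the per-key score shift is [0.0, 0.0, -10000.0].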
273 | if encoder_hidden_states is not None: 274 | mixed_key_layer = self.key(encoder_hidden_states) 275 | mixed_value_layer = self.value(encoder_hidden_states) 276 | attention_mask = encoder_attention_mask 277 | else: 278 | mixed_key_layer = self.key(context_states) 279 | mixed_value_layer = self.value(context_states) 280 | 281 | query_layer = self.transpose_for_scores(mixed_query_layer) 282 | key_layer = self.transpose_for_scores(mixed_key_layer) 283 | value_layer = self.transpose_for_scores(mixed_value_layer) 284 | 285 | # Take the dot product between "query" and "key" to get the raw attention scores. 286 | attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) 287 | attention_scores = attention_scores / math.sqrt(self.attention_head_size) 288 | if attention_mask is not None: 289 | # Apply the attention mask (precomputed for all layers in the BertModel forward() function) 290 | attention_scores = attention_scores + attention_mask 291 | 292 | # Normalize the attention scores to probabilities. 293 | attention_probs = nn.Softmax(dim=-1)(attention_scores) 294 | 295 | # This is actually dropping out entire tokens to attend to, which might 296 | # seem a bit unusual, but is taken from the original Transformer paper. 297 | attention_probs = self.dropout(attention_probs) 298 | 299 | # Mask heads if we want to 300 | if head_mask is not None: 301 | attention_probs = attention_probs * head_mask 302 | 303 | context_layer = torch.matmul(attention_probs, value_layer) 304 | 305 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous() 306 | new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) 307 | context_layer = context_layer.view(*new_context_layer_shape) 308 | 309 | # outputs = (context_layer, attention_probs) if self.output_attentions else (context_layer,) 310 | outputs = context_layer 311 | return outputs 312 | 313 | class SBERTCoAttentionModel(nn.Module): 314 | def __init__(self, bert_dir, from_pretrained=True, task_num=2, hidden_size=768, mid_size=512, freeze = False): 315 | super(SBERTCoAttentionModel, self).__init__() 316 | self.hidden_size = hidden_size 317 | self.task_num = task_num 318 | 319 | if from_pretrained: 320 | print("Initialize BERT from pretrained weights") 321 | self.bert = BertModel.from_pretrained(bert_dir) 322 | else: 323 | print("Initialize BERT from config.json, weight NOT loaded") 324 | self.bert_config = BertConfig.from_json_file(bert_dir+'config.json') 325 | self.bert = BertModel(self.bert_config) 326 | 327 | self.dropout = nn.Dropout(0.5) 328 | self.co_attention = BertCoAttention(self.bert.config) 329 | self.all_classifier = nn.ModuleList([ 330 | nn.Sequential( 331 | nn.Linear(hidden_size * 3, mid_size), 332 | nn.BatchNorm1d(mid_size), 333 | nn.ReLU(), 334 | nn.Dropout(0.5), 335 | nn.Linear(mid_size, 2) 336 | ) 337 | for _ in range(self.task_num) 338 | ]) 339 | 340 | def forward(self, source_input_ids, target_input_ids): 341 | # 0 for [PAD], mask out the padded values 342 | source_attention_mask = torch.ne(source_input_ids, 0) 343 | target_attention_mask = torch.ne(target_input_ids, 0) 344 | 345 | # get bert output 346 | source_embedding = self.bert(source_input_ids, attention_mask=source_attention_mask) 347 | target_embedding = self.bert(target_input_ids, attention_mask=target_attention_mask) 348 | 349 | source_coattention_outputs = self.co_attention(target_embedding[0], source_embedding[0], source_attention_mask) 350 | target_coattention_outputs = self.co_attention(source_embedding[0], target_embedding[0], 
target_attention_mask) 351 | source_coattention_embedding = source_coattention_outputs[:, 0, :] 352 | target_coattention_embedding = target_coattention_outputs[:, 0, :] 353 | 354 | # simply take out the [CLS] representation 355 | # TODO: try different pooling strategies 356 | # source_embedding = source_embedding[1] 357 | # target_embedding = target_embedding[1] 358 | 359 | # concat the two co-attention embeddings and their absolute difference, 360 | # mirroring the [u, v, |u-v|] scheme of the original SBERT paper 361 | abs_embedding = torch.abs(source_coattention_embedding - target_coattention_embedding) 362 | context_embedding = torch.cat([source_coattention_embedding, target_coattention_embedding, abs_embedding], -1) 363 | context_embedding = self.dropout(context_embedding) 364 | 365 | # get probs for two tasks A and B 366 | all_probs = [classifier(context_embedding) for classifier in self.all_classifier] 367 | return all_probs -------------------------------------------------------------------------------- /决赛提交/sohu_matching/src/search_better_merge.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn.metrics import f1_score 3 | from itertools import combinations 4 | from tqdm import tqdm 5 | # from collections import defaultdict 6 | 7 | def merge_on_valid(model_names, verbose=False): 8 | # len of valid_a: 4971, len of valid_b: 4969 9 | # total_preds_a, total_preds_b = [0]*4971, [0]*4969 10 | total_preds_a, total_preds_b = [0]*6911, [0]*6914 11 | # positive if the vote exceeds the threshold (>) 12 | threshold = len(model_names)/2 13 | for model in model_names: 14 | # print("processing model {}".format(model)) 15 | preds_a, preds_b = np.load('{}_pred_a.npy'.format(model)), np.load('{}_pred_b.npy'.format(model)) 16 | preds_a, preds_b = preds_a.tolist(), preds_b.tolist() 17 | assert len(total_preds_a)==len(preds_a) and len(total_preds_b)==len(preds_b) 18 | for idx, pred_a in enumerate(preds_a): 19 | total_preds_a[idx] += pred_a 20 | for idx, pred_b in enumerate(preds_b): 21 | total_preds_b[idx] += pred_b 22 | # print(len(preds_a), len(preds_b)) 23 | # print(type(preds_b)) 24 | 25 | total_preds_a, total_preds_b = np.array(total_preds_a), np.array(total_preds_b) 26 | vote_a, vote_b = (total_preds_a>threshold).astype('int'), (total_preds_b>threshold).astype('int') 27 | gt_a, gt_b = np.load(valid_dir + 'gt_a.npy'), np.load(valid_dir + 'gt_b.npy') 28 | # print(len(vote_a), len(vote_b)) 29 | # print(len(gt_a), len(gt_b)) 30 | 31 | f1a, f1b = f1_score(gt_a, vote_a), f1_score(gt_b, vote_b) 32 | ssa, ssb = f1_score(gt_a[:1750], vote_a[:1750]), f1_score(gt_b[:1750], vote_b[:1750]) 33 | sla, slb = f1_score(gt_a[1750:4380], vote_a[1750:4380]), f1_score(gt_b[1750:4385], vote_b[1750:4385]) 34 | lla, llb = f1_score(gt_a[4380:], vote_a[4380:]), f1_score(gt_b[4385:], vote_b[4385:]) 35 | 36 | if verbose: 37 | print("f1a: {}, f1b: {}".format(f1a, f1b)) 38 | print("ssa: {}, ssb: {}".format(ssa, ssb)) 39 | print("sla: {}, slb: {}".format(sla, slb)) 40 | print("lla: {}, llb: {}".format(lla, llb)) 41 | 42 | return f1a, f1b, ssa, ssb, sla, slb, lla, llb 43 | 44 | if __name__ == '__main__': 45 | valid_dir = '../valid_output/' 46 | total_model_names = [ 47 | '0518_roberta_same_lr_epoch_1_ab_loss', 48 | '0520_roberta_diff_lr_epoch_1_ab_loss', 49 | '0520_roberta_tcnn_diff_lr_epoch_1_ab_loss', 50 | '0520_roberta_80k_same_lr_zy_epoch_1_ab_loss', 51 | '0522_roberta_80k_fl_epoch_1_ab_loss', 52 | '0519_nezha_same_lr_epoch_1_ab_f1', 
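# checkpoint names follow {model_type}_epoch_{n}_{task_type}_{metric}, as written
# by the torch.save calls in train.py; the trailing 'loss' / 'f1' records which
# dev criterion selected that checkpoint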
53 | '0519_nezha_same_lr_epoch_0_ab_loss', 54 | '0518_nezha_diff_lr_zy_epoch_1_ab_loss', 55 | '0518_macbert_same_lr_epoch_1_ab_loss', 56 | '0520_macbert_sbert_same_lr_epoch_1_ab_loss', 57 | '0520_roberta_sbert_same_lr_epoch_1_ab_loss', 58 | '0522_roberta_80k_tcnn_epoch_1_ab_loss', 59 | '0523_roberta_dataaug_epoch_0_ab_loss', 60 | '0523_ernie_epoch_1_ab_loss' 61 | ] 62 | total_model_dir = [valid_dir + model_name for model_name in total_model_names] 63 | f1a, f1b, *_ = merge_on_valid(total_model_dir) 64 | print("total merge: f1 {}, f1a {}, f1b {}".format(((f1a+f1b)/2), f1a, f1b)) 65 | print() 66 | 67 | for size in [3,5,7,9,11]: 68 | print("searching for the best merge of {} models".format(size)) 69 | records = [] 70 | combs = combinations(total_model_dir, size) 71 | best_f1 = 0 72 | best_comb = None 73 | for comb in tqdm(combs): 74 | f1a, f1b, *_ = merge_on_valid(list(comb)) 75 | if (f1a + f1b)/2 > best_f1: 76 | best_f1 = (f1a + f1b)/2 77 | best_comb = comb 78 | records.append((list(comb), (f1a+f1b)/2)) 79 | print("best f1 and model list:") 80 | print(best_f1, best_comb) 81 | merge_on_valid(list(best_comb), True) 82 | 83 | print("top5 candidates list:") 84 | records.sort(key=lambda x:x[-1], reverse=True) 85 | for i in range(5): 86 | print(records[i]) 87 | print() -------------------------------------------------------------------------------- /决赛提交/sohu_matching/src/train.py: -------------------------------------------------------------------------------- 1 | from model import BertClassifierSingleModel, NezhaClassifierSingleModel, SBERTSingleModel, SNEZHASingleModel, BertClassifierTextCNNSingleModel 2 | from data import SentencePairDatasetWithType, SentencePairDatasetForSBERT, SentencePairDatasetWithMultiType 3 | from utils import focal_loss, FGM 4 | from transformers import AdamW, get_linear_schedule_with_warmup 5 | from config import Config 6 | 7 | import torch 8 | import torch.nn as nn 9 | from torch.utils.data import DataLoader 10 | 11 | import numpy as np 12 | from sklearn import metrics 13 | from tensorboardX import SummaryWriter 14 | 15 | import os 16 | # os.environ['CUDA_VISIBLE_DEVICES'] = '2,3,4,5,6,7' # recommended for NEZHA 17 | # os.environ['CUDA_VISIBLE_DEVICES'] = '4,5,6,7' 18 | os.environ['CUDA_VISIBLE_DEVICES'] = '5,6,7' 19 | 20 | def train(model, device, epoch, train_dataloader, test_dataloader, save_dir, optimizer, scheduler=None, criterion_type='CE', model_type='bert', print_every=100, eval_every=500, writer=None, use_fgm=False): 21 | print("Training at epoch {}".format(epoch)) 22 | if use_fgm: 23 | print("Using FGM for adversarial attack") 24 | 25 | est_batch = len(train_dataloader.dataset) / (train_dataloader.batch_size) 26 | model.train() 27 | 28 | # for multiple GPU support 29 | model = torch.nn.DataParallel(model) 30 | 31 | assert criterion_type == 'CE' or criterion_type == 'FL' 32 | if criterion_type == 'CE': 33 | criterion = nn.CrossEntropyLoss() 34 | elif criterion_type == 'FL': 35 | criterion = focal_loss() 36 | 37 | if use_fgm: 38 | fgm = FGM(model) 39 | 40 | total_loss = [] 41 | total_gt_a, total_preds_a = [], [] 42 | total_gt_b, total_preds_b = [], [] 43 | for idx, batch in enumerate(train_dataloader): 44 | # for SentencePairDatasetWithType, types would be returned 45 | input_ids, input_types, labels, types = batch 46 | input_ids = input_ids.to(device) 47 | input_types = input_types.to(device) 48 | # labels should be flattened 49 | labels = labels.to(device).view(-1) 50 | 51 | optimizer.zero_grad() 52 | 53 | # the probs given by the model 54 | all_probs = 
model(input_ids, input_types) 55 | num_tasks = len(all_probs) 56 | 57 | all_masks = [(types==task_id).numpy() for task_id in range(num_tasks)] 58 | all_output = [all_probs[task_id][all_masks[task_id]] for task_id in range(num_tasks)] 59 | all_labels = [labels[all_masks[task_id]] for task_id in range(num_tasks)] 60 | 61 | # calculate the loss and BP 62 | # TODO: different weights for each task? 63 | all_loss = None 64 | for task_id in range(num_tasks): 65 | if all_masks[task_id].sum() != 0: 66 | if all_loss is None: 67 | all_loss = criterion(all_output[task_id], all_labels[task_id]) 68 | else: 69 | all_loss += criterion(all_output[task_id], all_labels[task_id]) 70 | all_loss.backward() 71 | 72 | # code for FGM adversarial training 73 | if use_fgm: 74 | fgm.attack() 75 | # adv_probs_a, adv_probs_b = model(input_ids, input_types) 76 | adv_all_probs = model(input_ids, input_types) 77 | adv_all_output = [adv_all_probs[task_id][all_masks[task_id]] for task_id in range(num_tasks)] 78 | # calculate the loss and BP 79 | adv_all_loss = None 80 | for task_id in range(num_tasks): 81 | if all_masks[task_id].sum() != 0: 82 | if adv_all_loss is None: 83 | adv_all_loss = criterion(adv_all_output[task_id], all_labels[task_id]) 84 | else: 85 | adv_all_loss += criterion(adv_all_output[task_id], all_labels[task_id]) 86 | adv_all_loss.backward() 87 | fgm.restore() 88 | 89 | optimizer.step() 90 | if scheduler is not None: 91 | scheduler.step() 92 | 93 | 94 | all_gt = [all_labels[task_id].cpu().numpy().tolist() if all_masks[task_id].sum()!=0 else [] for task_id in range(num_tasks)] 95 | all_preds = [all_output[task_id].argmax(axis=1).cpu().numpy().tolist() if all_masks[task_id].sum()!=0 else [] for task_id in range(num_tasks)] 96 | 97 | gt_a, preds_a = [], [] 98 | for task_id in range(0, num_tasks, 2): 99 | gt_a += all_gt[task_id] 100 | preds_a += all_preds[task_id] 101 | 102 | gt_b, preds_b = [], [] 103 | for task_id in range(1, num_tasks, 2): 104 | gt_b += all_gt[task_id] 105 | preds_b += all_preds[task_id] 106 | 107 | total_preds_a += preds_a 108 | total_gt_a += gt_a 109 | total_preds_b += preds_b 110 | total_gt_b += gt_b 111 | total_loss.append(all_loss.item()) 112 | # print('a', preds_a, gt_a) 113 | # print('b', preds_b, gt_b) 114 | 115 | acc_a = metrics.accuracy_score(gt_a, preds_a) if len(gt_a)!=0 else 0 116 | f1_a = metrics.f1_score(gt_a, preds_a, zero_division=0) 117 | acc_b = metrics.accuracy_score(gt_b, preds_b) if len(gt_b)!=0 else 0 118 | f1_b = metrics.f1_score(gt_b, preds_b, zero_division=0) 119 | 120 | # learning rate for bert is the second (the last) parameter group 121 | writer.add_scalar('train/learning_rate', optimizer.param_groups[-1]['lr'], global_step=epoch*est_batch+idx) 122 | writer.add_scalar('train/loss', all_loss.item(), global_step=epoch*est_batch+idx) 123 | writer.add_scalar('train/acc_a', acc_a, global_step=epoch*est_batch+idx) 124 | writer.add_scalar('train/acc_b', acc_b, global_step=epoch*est_batch+idx) 125 | writer.add_scalar('train/f1_a', f1_a, global_step=epoch*est_batch+idx) 126 | writer.add_scalar('train/f1_b', f1_b, global_step=epoch*est_batch+idx) 127 | 128 | # print the loss and accuracy scores every print_every batches 129 | if (idx+1) % print_every == 0: 130 | print("\tBatch: {} / {:.0f}, Loss: {:.6f}".format(idx, est_batch, all_loss.item())) 131 | print("\t\t Task A\tAcc: {:.6f}, F1: {:.6f}".format(acc_a, f1_a)) 132 | print("\t\t Task B\tAcc: {:.6f}, F1: {:.6f}".format(acc_b, f1_b)) 133 | 134 | # evaluate the model every eval_every batches, instead of only after the whole epoch 
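# best_dev_loss and best_dev_f1 are module-level variables initialised in
# __main__ below, so they must be declared global here for the in-training
# evaluation to update them across batches and epochs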
135 | global best_dev_loss, best_dev_f1 136 | if (idx+1) % eval_every == 0: 137 | dev_loss, dev_acc_a, dev_acc_b, dev_f1_a, dev_f1_b = eval(model, device, test_dataloader, criterion_type) 138 | dev_f1 = (dev_f1_a + dev_f1_b) / 2 139 | writer.add_scalar('eval/loss', dev_loss, global_step=epoch*est_batch+idx) 140 | writer.add_scalar('eval/acc_a', dev_acc_a, global_step=epoch*est_batch+idx) 141 | writer.add_scalar('eval/acc_b', dev_acc_b, global_step=epoch*est_batch+idx) 142 | writer.add_scalar('eval/f1_a', dev_f1_a, global_step=epoch*est_batch+idx) 143 | writer.add_scalar('eval/f1_b', dev_f1_b, global_step=epoch*est_batch+idx) 144 | # in practice, a better loss is preferred over a better f1 score, 145 | # which could result from random overfitting on the valid set 146 | # 0517: save the model's state_dict instead of the whole model (mainly for NEZHA's sake) 147 | if (dev_loss < best_dev_loss or dev_f1 > best_dev_f1): 148 | if dev_loss < best_dev_loss: 149 | best_dev_loss = dev_loss 150 | torch.save(model.state_dict(), save_dir + model_type + '_epoch_{}_{}_'.format(epoch, task_type) + 'loss') 151 | print("----------BETTER LOSS, MODEL SAVED-----------") 152 | if dev_f1 > best_dev_f1: 153 | best_dev_f1 = dev_f1 154 | torch.save(model.state_dict(), save_dir + model_type + '_epoch_{}_{}_'.format(epoch, task_type) + 'f1') 155 | print("----------BETTER F1, MODEL SAVED-----------") 156 | 157 | loss = np.array(total_loss).mean() 158 | # Setting average=None to return class-specific scores 159 | # 0502 BUG FIXED: do not use 'macro', DO NOT require class-specific metrics! 160 | # macro_f1 = metrics.f1_score(total_gt, total_preds, average='macro') 161 | f1_a = metrics.f1_score(total_gt_a, total_preds_a, zero_division=0) 162 | f1_b = metrics.f1_score(total_gt_b, total_preds_b, zero_division=0) 163 | f1 = (f1_a + f1_b) / 2 164 | print("Average f1 on training set: {:.6f}, f1_a: {:.6f}, f1_b: {:.6f}".format(f1, f1_a, f1_b)) 165 | 166 | return loss, f1, f1_a, f1_b 167 | 168 | 169 | def eval(model, device, test_dataloader, criterion_type='CE'): 170 | print("Evaluating") 171 | model.eval() 172 | # if called while training, then model parallel is already done 173 | # model = torch.nn.DataParallel(model) 174 | 175 | assert criterion_type == 'CE' or criterion_type == 'FL' 176 | if criterion_type == 'CE': 177 | criterion = nn.CrossEntropyLoss() 178 | elif criterion_type == 'FL': 179 | criterion = focal_loss() 180 | 181 | total_loss = [] 182 | total_gt_a, total_preds_a = [], [] 183 | total_gt_b, total_preds_b = [], [] 184 | 185 | for idx, batch in enumerate(test_dataloader): 186 | input_ids, input_types, labels, types = batch 187 | input_ids = input_ids.to(device) 188 | input_types = input_types.to(device) 189 | # labels should be flattened 190 | labels = labels.to(device).view(-1) 191 | 192 | # the probs given by the model, without grads 193 | with torch.no_grad(): 194 | # the probs given by the model 195 | # probs_a, probs_b = model(input_ids, input_types) 196 | # mask_a, mask_b = (types==0).numpy(), (types==1).numpy() 197 | # output_a, labels_a = probs_a[mask_a], labels[mask_a] 198 | # output_b, labels_b = probs_b[mask_b], labels[mask_b] 199 | 200 | all_probs = model(input_ids, input_types) 201 | num_tasks = len(all_probs) 202 | 203 | # mask_a, mask_b = (types==0).numpy(), (types==1).numpy() 204 | all_masks = [(types==task_id).numpy() for task_id in range(num_tasks)] 205 | all_output = [all_probs[task_id][all_masks[task_id]] for task_id in range(num_tasks)] 206 | all_labels = [labels[all_masks[task_id]] 
for task_id in range(num_tasks)] 207 | 208 | all_loss = None 209 | for task_id in range(num_tasks): 210 | # print(task_id, all_masks[task_id]) 211 | if all_masks[task_id].sum() != 0: 212 | if all_loss is None: 213 | all_loss = criterion(all_output[task_id], all_labels[task_id]) 214 | else: 215 | all_loss += criterion(all_output[task_id], all_labels[task_id]) 216 | 217 | all_gt = [all_labels[task_id].cpu().numpy().tolist() if all_masks[task_id].sum()!=0 else [] for task_id in range(num_tasks)] 218 | all_preds = [all_output[task_id].argmax(axis=1).cpu().numpy().tolist() if all_masks[task_id].sum()!=0 else [] for task_id in range(num_tasks)] 219 | 220 | gt_a, preds_a = [], [] 221 | for task_id in range(0, num_tasks, 2): 222 | gt_a += all_gt[task_id] 223 | preds_a += all_preds[task_id] 224 | 225 | gt_b, preds_b = [], [] 226 | for task_id in range(1, num_tasks, 2): 227 | gt_b += all_gt[task_id] 228 | preds_b += all_preds[task_id] 229 | 230 | total_preds_a += preds_a 231 | total_gt_a += gt_a 232 | total_preds_b += preds_b 233 | total_gt_b += gt_b 234 | total_loss.append(all_loss.item()) 235 | 236 | loss = np.array(total_loss).mean() 237 | acc_a = metrics.accuracy_score(total_gt_a, total_preds_a) if len(total_gt_a)!=0 else 0 238 | f1_a = metrics.f1_score(total_gt_a, total_preds_a, zero_division=0) 239 | if (f1_a == 0): 240 | print("F1_a = 0, checking precision, recall, fscore and support...") 241 | print(metrics.precision_recall_fscore_support(total_gt_a, total_preds_a, zero_division=0)) 242 | 243 | acc_b = metrics.accuracy_score(total_gt_b, total_preds_b) if len(total_gt_b)!=0 else 0 244 | f1_b = metrics.f1_score(total_gt_b, total_preds_b, zero_division=0) 245 | if (f1_b == 0): 246 | print("F1_b = 0, checking precision, recall, fscore and support...") 247 | print(metrics.precision_recall_fscore_support(total_gt_b, total_preds_b, zero_division=0)) 248 | 249 | # Setting average=None to return class-specific scores 250 | # macro_f1 = metrics.f1_score(total_gt, total_preds, average='macro') 251 | # f1 = metrics.f1_score(total_gt, total_preds) 252 | 253 | # print loss and classification report 254 | print("Loss on dev set: ", loss) 255 | print("F1 on dev set: {:.6f}, f1_a: {:.6f}, f1_b: {:.6f}".format((f1_a+f1_b)/2, f1_a, f1_b)) 256 | 257 | # return loss, acc, macro_f1 258 | return loss, acc_a, acc_b, f1_a, f1_b 259 | 260 | 261 | if __name__ == '__main__': 262 | config = Config() 263 | device = config.device 264 | pretrained = config.pretrained 265 | model_type = config.model_type 266 | use_fgm = config.use_fgm 267 | 268 | save_dir = config.save_dir 269 | data_dir = config.data_dir 270 | # whether to shuffle the positions of source and target to augment data 271 | shuffle_order = config.shuffle_order 272 | # whether to use the positive cases in task b for task a (positives) 273 | # and the negative cases in task a for task b (negatives) 274 | aug_data = config.aug_data 275 | # method for clipping long sequences, 'head' or 'tail' 276 | clip_method = config.clip_method 277 | 278 | task_type = config.task_type 279 | task_a = ['短短匹配A类', '短长匹配A类', '长长匹配A类'] 280 | task_b = ['短短匹配B类', '短长匹配B类', '长长匹配B类'] 281 | 282 | # hyperparameters here 283 | epochs = config.epochs 284 | lr = config.lr 285 | classifer_lr = config.classifier_lr 286 | weight_decay = config.weight_decay 287 | hidden_size = config.hidden_size 288 | train_bs = config.train_bs 289 | eval_bs = config.eval_bs 290 | 291 | print_every = config.print_every 292 | eval_every = config.eval_every 293 | 294 | train_data_dir, dev_data_dir = [], [] 295 
| # integrate the two tasks into one dataset using task_type = 'ab' 296 | if 'a' in task_type: 297 | for task in task_a: 298 | train_data_dir.append(data_dir + task + '/train.txt') 299 | train_data_dir.append(data_dir + task + '/train_r2.txt') 300 | train_data_dir.append(data_dir + task + '/train_r3.txt') 301 | train_data_dir.append(data_dir + task + '/train_rematch.txt') # training file in rematch 302 | dev_data_dir.append(data_dir + task + '/valid.txt') 303 | dev_data_dir.append(data_dir + task + '/valid_rematch.txt') # valid file in rematch 304 | 305 | if 'b' in task_type: 306 | for task in task_b: 307 | train_data_dir.append(data_dir + task + '/train.txt') 308 | train_data_dir.append(data_dir + task + '/train_r2.txt') 309 | train_data_dir.append(data_dir + task + '/train_r3.txt') 310 | train_data_dir.append(data_dir + task + '/train_rematch.txt') # training file in rematch 311 | dev_data_dir.append(data_dir + task + '/valid.txt') 312 | dev_data_dir.append(data_dir + task + '/valid_rematch.txt') # valid file in rematch 313 | 314 | # toy dataset for testing 315 | if config.load_toy_dataset: 316 | # train_data_dir = ['../data/sohu2021_open_data/短短匹配A类/train.txt', 317 | # '../data/sohu2021_open_data/短短匹配B类/train.txt'] 318 | train_data_dir = [ 319 | '../data/sohu2021_open_data/短短匹配A类/valid.txt', 320 | '../data/sohu2021_open_data/短短匹配B类/valid.txt', 321 | '../data/sohu2021_open_data/短长匹配A类/valid.txt', 322 | '../data/sohu2021_open_data/短长匹配B类/valid.txt', 323 | '../data/sohu2021_open_data/长长匹配A类/valid.txt', 324 | '../data/sohu2021_open_data/长长匹配B类/valid.txt'] 325 | dev_data_dir = ['../data/sohu2021_open_data/短短匹配A类/valid.txt', 326 | '../data/sohu2021_open_data/短短匹配B类/valid.txt',] 327 | dev_data_dir = train_data_dir 328 | 329 | # if config.load_toy_dataset: 330 | # train_data_dir = ['../data/sohu2021_open_data/长长匹配A类/train.txt'] 331 | # dev_data_dir = ['../data/sohu2021_open_data/长长匹配A类/valid.txt'] 332 | 333 | print("Loading pretrained Model from {}...".format(pretrained)) 334 | # integrating SBERT model into a unified training framework 335 | if 'sbert' in model_type.lower(): 336 | print("Using SentenceBERT model and dataset") 337 | if 'nezha' in model_type.lower(): 338 | model = SNEZHASingleModel(bert_dir=pretrained, hidden_size=hidden_size) 339 | else: 340 | model = SBERTSingleModel(bert_dir=pretrained, hidden_size=hidden_size) 341 | model.to(device) 342 | print("Loading Training Data...") 343 | print(train_data_dir) 344 | # augment the data with shuffle_order=True (changing order of source and target) 345 | # or with aug_data=True (neg cases in A -> B, pos cases in B -> A) 346 | train_dataset = SentencePairDatasetForSBERT(train_data_dir, True, pretrained, shuffle_order, aug_data=aug_data, clip=clip_method) 347 | train_dataloader = DataLoader(train_dataset, batch_size=train_bs, shuffle=True) 348 | 349 | print("Loading Dev Data...") 350 | test_dataset = SentencePairDatasetForSBERT(dev_data_dir, True, pretrained, clip=clip_method) 351 | test_dataloader = DataLoader(test_dataset, batch_size=eval_bs, shuffle=False) 352 | 353 | # 0517: for training, load weights from pretrained with from_pretrained=True (by default) 354 | # for larger model, adjust the hidden_size according to its config 355 | # distinguish model architectures or pretrained models according to model_type 356 | else: 357 | print("Using BERT model and dataset") 358 | if 'nezha' in model_type.lower(): 359 | print("Using NEZHA pretrained model") 360 | model = NezhaClassifierSingleModel(bert_dir=pretrained, hidden_size=hidden_size) 361 
| elif 'cnn' in model_type.lower(): 362 | print("Adding TextCNN after BERT output") 363 | model = BertClassifierTextCNNSingleModel(bert_dir=pretrained, hidden_size=hidden_size) 364 | else: 365 | print("Using conventional BERT model with linears") 366 | # model = BertClassifierSingleModel(bert_dir=pretrained, hidden_size=hidden_size) 367 | model = BertClassifierSingleModel(bert_dir=pretrained, task_num=6, hidden_size=hidden_size) 368 | model.to(device) 369 | 370 | print("Loading Training Data...") 371 | print(train_data_dir) 372 | # augment the data with shuffle_order=True (changing order of source and target) 373 | # or with aug_data=True (neg cases in A -> B, pos cases in B -> A) 374 | # train_dataset = SentencePairDatasetWithType(train_data_dir, True, pretrained, shuffle_order, aug_data=aug_data, clip=clip_method) 375 | train_dataset = SentencePairDatasetWithMultiType(train_data_dir, True, pretrained, shuffle_order, aug_data=aug_data, clip=clip_method) 376 | train_dataloader = DataLoader(train_dataset, batch_size=train_bs, shuffle=True) 377 | 378 | print("Loading Dev Data...") 379 | # test_dataset = SentencePairDatasetWithType(dev_data_dir, True, pretrained, clip=clip_method) 380 | test_dataset = SentencePairDatasetWithMultiType(dev_data_dir, True, pretrained, clip=clip_method) 381 | test_dataloader = DataLoader(test_dataset, batch_size=eval_bs, shuffle=False) 382 | 383 | optimizer = AdamW(model.parameters(), lr=lr, weight_decay=weight_decay, correct_bias=False) 384 | # 0514 setting different lr for bert encoder and classifier 385 | # TODO: verify the large lr works for classifiers 386 | # optimizer = AdamW([ 387 | # {"params": model.all_classifier.parameters(), "lr": classifer_lr}, 388 | # {"params": model.bert.parameters()}], 389 | # lr=lr) 390 | 391 | # for p in optimizer.param_groups: 392 | # outputs = '' 393 | # for k, v in p.items(): 394 | # if k is 'params': 395 | # outputs += (k + ': ' + str(v[0].shape).ljust(30) + ' ') 396 | # else: 397 | # outputs += (k + ': ' + str(v).ljust(10) + ' ') 398 | # print(outputs) 399 | 400 | total_steps = len(train_dataloader) * epochs 401 | 402 | # TODO: using ReduceLROnPlateau instead of linear scheduler 403 | if config.use_scheduler: 404 | scheduler = get_linear_schedule_with_warmup( 405 | optimizer, 406 | num_training_steps = total_steps, 407 | num_warmup_steps = config.num_warmup_steps, 408 | ) 409 | else: 410 | scheduler = None 411 | 412 | print("Training on Task {}...".format(task_type)) 413 | writer = SummaryWriter('runs/{}'.format(model_type + '_' + task_type)) 414 | 415 | best_dev_loss = 999 416 | best_dev_f1 = 0 417 | for epoch in range(epochs): 418 | train_loss, train_f1, train_f1_a, train_f1_b = train(model, device, epoch, train_dataloader, test_dataloader, \ 419 | save_dir, optimizer, scheduler=scheduler, model_type=model_type, \ 420 | print_every=print_every, eval_every=eval_every, writer=writer, use_fgm=use_fgm) 421 | -------------------------------------------------------------------------------- /决赛提交/sohu_matching/src/train_old.py: -------------------------------------------------------------------------------- 1 | from model import BertClassifierSingleModel, NezhaClassifierSingleModel, SBERTSingleModel, SNEZHASingleModel, BertClassifierTextCNNSingleModel 2 | from data import SentencePairDatasetWithType, SentencePairDatasetForSBERT 3 | from utils import focal_loss, FGM 4 | from transformers import AdamW, get_linear_schedule_with_warmup 5 | from config import Config 6 | 7 | import torch 8 | import torch.nn as nn 9 | from 
torch.utils.data import DataLoader 10 | 11 | import numpy as np 12 | from sklearn import metrics 13 | from tensorboardX import SummaryWriter 14 | 15 | import os 16 | # os.environ['CUDA_VISIBLE_DEVICES'] = '2,3,4,5,6,7' # recommended for NEZHA 17 | # os.environ['CUDA_VISIBLE_DEVICES'] = '4,5,6,7' 18 | # os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3' 19 | os.environ['CUDA_VISIBLE_DEVICES'] = '2,3,4' 20 | 21 | def train(model, device, epoch, train_dataloader, test_dataloader, save_dir, optimizer, scheduler=None, criterion_type='CE', model_type='bert', print_every=100, eval_every=500, writer=None, use_fgm=False): 22 | print("Training at epoch {}".format(epoch)) 23 | if use_fgm: 24 | print("Using FGM for adversarial attack") 25 | 26 | est_batch = len(train_dataloader.dataset) / (train_dataloader.batch_size) 27 | model.train() 28 | 29 | # for multiple GPU support 30 | model = torch.nn.DataParallel(model) 31 | 32 | assert criterion_type == 'CE' or criterion_type == 'FL' 33 | if criterion_type == 'CE': 34 | criterion = nn.CrossEntropyLoss() 35 | elif criterion_type == 'FL': 36 | criterion = focal_loss() 37 | 38 | if use_fgm: 39 | fgm = FGM(model) 40 | 41 | total_loss = [] 42 | total_gt_a, total_preds_a = [], [] 43 | total_gt_b, total_preds_b = [], [] 44 | 45 | # the following commented code is compatible with multi-task training (e.g. 6 subtasks, each with a designated dataset) 46 | # however, it seems to cost the model about 1 percent on task B 47 | # for idx, batch in enumerate(train_dataloader): 48 | # # for SentencePairDatasetWithType, types would be returned 49 | # input_ids, input_types, labels, types = batch 50 | # input_ids = input_ids.to(device) 51 | # input_types = input_types.to(device) 52 | # # labels should be flattened 53 | # labels = labels.to(device).view(-1) 54 | 55 | # optimizer.zero_grad() 56 | 57 | # # the probs given by the model 58 | # all_probs = model(input_ids, input_types) 59 | # num_tasks = len(all_probs) 60 | 61 | # all_masks = [(types==task_id).numpy() for task_id in range(num_tasks)] 62 | # all_output = [all_probs[task_id][all_masks[task_id]] for task_id in range(num_tasks)] 63 | # all_labels = [labels[all_masks[task_id]] for task_id in range(num_tasks)] 64 | 65 | # # calculate the loss and BP 66 | # # TODO: different weights for each task? 
67 | # all_loss = None 68 | # for task_id in range(num_tasks): 69 | # if all_masks[task_id].sum() != 0: 70 | # if all_loss is None: 71 | # all_loss = criterion(all_output[task_id], all_labels[task_id]) 72 | # else: 73 | # all_loss += criterion(all_output[task_id], all_labels[task_id]) 74 | # all_loss.backward() 75 | 76 | # # code for fgm adversarial training 77 | # if use_fgm: 78 | # fgm.attack() 79 | # # adv_probs_a, adv_probs_b = model(input_ids, input_types) 80 | # adv_all_probs = model(input_ids, input_types) 81 | # adv_all_output = [adv_all_probs[task_id][all_masks[task_id]] for task_id in range(num_tasks)] 82 | # # calculate the loss and BP 83 | # adv_all_loss = None 84 | # for task_id in range(num_tasks): 85 | # if all_masks[task_id].sum() != 0: 86 | # if adv_all_loss is None: 87 | # adv_all_loss = criterion(adv_all_output[task_id], all_labels[task_id]) 88 | # else: 89 | # adv_all_loss += criterion(adv_all_output[task_id], all_labels[task_id]) 90 | # adv_all_loss.backward() 91 | # fgm.restore() 92 | 93 | # optimizer.step() 94 | # if scheduler is not None: 95 | # scheduler.step() 96 | 97 | 98 | # all_gt = [all_labels[task_id].cpu().numpy().tolist() if all_masks[task_id].sum()!=0 else [] for task_id in range(num_tasks)] 99 | # all_preds = [all_output[task_id].argmax(axis=1).cpu().numpy().tolist() if all_masks[task_id].sum()!=0 else [] for task_id in range(num_tasks)] 100 | 101 | # gt_a, preds_a = [], [] 102 | # for task_id in range(0, num_tasks, 2): 103 | # gt_a += all_gt[task_id] 104 | # preds_a += all_preds[task_id] 105 | 106 | # gt_b, preds_b = [], [] 107 | # for task_id in range(1, num_tasks, 2): 108 | # gt_b += all_gt[task_id] 109 | # preds_b += all_preds[task_id] 110 | 111 | # total_preds_a += preds_a 112 | # total_gt_a += gt_a 113 | # total_preds_b += preds_b 114 | # total_gt_b += gt_b 115 | # total_loss.append(all_loss.item()) 116 | # # print('a', preds_a, gt_a) 117 | # # print('b', preds_b, gt_b) 118 | 119 | # acc_a = metrics.accuracy_score(gt_a, preds_a) if len(gt_a)!=0 else 0 120 | # f1_a = metrics.f1_score(gt_a, preds_a, zero_division=0) 121 | # acc_b = metrics.accuracy_score(gt_b, preds_b) if len(gt_b)!=0 else 0 122 | # f1_b = metrics.f1_score(gt_b, preds_b, zero_division=0) 123 | 124 | # # learning rate for bert is the second (the last) parameter group 125 | # writer.add_scalar('train/learning_rate', optimizer.param_groups[-1]['lr'], global_step=epoch*est_batch+idx) 126 | # writer.add_scalar('train/loss', all_loss.item(), global_step=epoch*est_batch+idx) 127 | # writer.add_scalar('train/acc_a', acc_a, global_step=epoch*est_batch+idx) 128 | # writer.add_scalar('train/acc_b', acc_b, global_step=epoch*est_batch+idx) 129 | # writer.add_scalar('train/f1_a', f1_a, global_step=epoch*est_batch+idx) 130 | # writer.add_scalar('train/f1_b', f1_b, global_step=epoch*est_batch+idx) 131 | 132 | # # print the loss and accuracy score once print_every is reached 133 | # if (idx+1) % print_every == 0: 134 | # print("\tBatch: {} / {:.0f}, Loss: {:.6f}".format(idx, est_batch, all_loss.item())) 135 | # print("\t\t Task A\tAcc: {:.6f}, F1: {:.6f}".format(acc_a, f1_a)) 136 | # print("\t\t Task B\tAcc: {:.6f}, F1: {:.6f}".format(acc_b, f1_b)) 137 | 138 | # # evaluate the model once eval_every is reached, instead of evaluating only after the whole epoch 139 | # global best_dev_loss, best_dev_f1 140 | # if (idx+1) % eval_every == 0: 141 | # dev_loss, dev_acc_a, dev_acc_b, dev_f1_a, dev_f1_b = eval(model, device, test_dataloader, criterion_type) 142 | # dev_f1 = (dev_f1_a + dev_f1_b) / 2 143 | # 
writer.add_scalar('eval/loss', dev_loss, global_step=epoch*est_batch+idx) 144 | # writer.add_scalar('eval/acc_a', dev_acc_a, global_step=epoch*est_batch+idx) 145 | # writer.add_scalar('eval/acc_b', dev_acc_b, global_step=epoch*est_batch+idx) 146 | # writer.add_scalar('eval/f1_a', dev_f1_a, global_step=epoch*est_batch+idx) 147 | # writer.add_scalar('eval/f1_b', dev_f1_b, global_step=epoch*est_batch+idx) 148 | # # in practice, a better loss is preferred over a better f1 score, 149 | # # which could result from random overfitting on the valid set 150 | # # 0517: save the model's state_dict instead of the whole model (mainly for NEZHA's sake) 151 | # if (dev_loss < best_dev_loss or dev_f1 > best_dev_f1): 152 | # if dev_loss < best_dev_loss: 153 | # best_dev_loss = dev_loss 154 | # torch.save(model.state_dict(), save_dir + model_type + '_epoch_{}_{}_'.format(epoch, task_type) + 'loss') 155 | # print("----------BETTER LOSS, MODEL SAVED-----------") 156 | # if dev_f1 > best_dev_f1: 157 | # best_dev_f1 = dev_f1 158 | # torch.save(model.state_dict(), save_dir + model_type + '_epoch_{}_{}_'.format(epoch, task_type) + 'f1') 159 | # print("----------BETTER F1, MODEL SAVED-----------") 160 | 161 | for idx, batch in enumerate(train_dataloader): 162 | # for SentencePairDatasetWithType, types would be returned 163 | input_ids, input_types, labels, types = batch 164 | input_ids = input_ids.to(device) 165 | input_types = input_types.to(device) 166 | # labels should be flattened 167 | labels = labels.to(device).view(-1) 168 | 169 | optimizer.zero_grad() 170 | 171 | # the probs given by the model 172 | probs_a, probs_b = model(input_ids, input_types) 173 | 174 | mask_a, mask_b = (types==0).numpy(), (types==1).numpy() 175 | output_a, labels_a = probs_a[mask_a], labels[mask_a] 176 | output_b, labels_b = probs_b[mask_b], labels[mask_b] 177 | 178 | # calculate the loss and BP 179 | # loss_a = criterion(output_a, labels_a) if mask_a.sum()!=0 else None 180 | # loss_b = criterion(output_b, labels_b) if mask_b.sum()!=0 else None 181 | # so-called multi-task training 182 | # TODO: different weights for each task? 
183 | if mask_a.sum()==0: 184 | loss = criterion(output_b, labels_b) 185 | elif mask_b.sum()==0: 186 | loss = criterion(output_a, labels_a) 187 | else: 188 | loss = criterion(output_a, labels_a) + criterion(output_b, labels_b) 189 | # print(loss.item()) 190 | loss.backward() 191 | 192 | # code for fgm adversarial training 193 | if use_fgm: 194 | fgm.attack() 195 | adv_probs_a, adv_probs_b = model(input_ids, input_types) 196 | # calculate the loss and BP 197 | adv_output_a, adv_output_b = adv_probs_a[mask_a], adv_probs_b[mask_b] 198 | if mask_a.sum()==0: 199 | adv_loss = criterion(adv_output_b, labels_b) 200 | elif mask_b.sum()==0: 201 | adv_loss = criterion(adv_output_a, labels_a) 202 | else: 203 | adv_loss = criterion(adv_output_a, labels_a) + criterion(adv_output_b, labels_b) 204 | adv_loss.backward() 205 | fgm.restore() 206 | 207 | optimizer.step() 208 | if scheduler is not None: 209 | scheduler.step() 210 | 211 | gt_a = labels_a.cpu().numpy().tolist() 212 | preds_a = output_a.argmax(axis=1).cpu().numpy().tolist() if len(gt_a)!=0 else [] 213 | 214 | gt_b = labels_b.cpu().numpy().tolist() 215 | preds_b = output_b.argmax(axis=1).cpu().numpy().tolist() if len(gt_b)!=0 else [] 216 | 217 | total_preds_a += preds_a 218 | total_gt_a += gt_a 219 | total_preds_b += preds_b 220 | total_gt_b += gt_b 221 | total_loss.append(loss.item()) 222 | # print('a', preds_a, gt_a) 223 | # print('b', preds_b, gt_b) 224 | 225 | acc_a = metrics.accuracy_score(gt_a, preds_a) if len(gt_a)!=0 else 0 226 | f1_a = metrics.f1_score(gt_a, preds_a, zero_division=0) 227 | acc_b = metrics.accuracy_score(gt_b, preds_b) if len(gt_b)!=0 else 0 228 | f1_b = metrics.f1_score(gt_b, preds_b, zero_division=0) 229 | 230 | writer.add_scalar('train/learning_rate', optimizer.param_groups[-1]['lr'], global_step=epoch*est_batch+idx) 231 | writer.add_scalar('train/loss', loss.item(), global_step=epoch*est_batch+idx) 232 | writer.add_scalar('train/acc_a', acc_a, global_step=epoch*est_batch+idx) 233 | writer.add_scalar('train/acc_b', acc_b, global_step=epoch*est_batch+idx) 234 | writer.add_scalar('train/f1_a', f1_a, global_step=epoch*est_batch+idx) 235 | writer.add_scalar('train/f1_b', f1_b, global_step=epoch*est_batch+idx) 236 | 237 | # print the loss and accuracy score once print_every is reached 238 | if (idx+1) % print_every == 0: 239 | print("\tBatch: {} / {:.0f}, Loss: {:.6f}".format(idx, est_batch, loss.item())) 240 | print("\t\t Task A\tAcc: {:.6f}, F1: {:.6f}".format(acc_a, f1_a)) 241 | # if (f1_a == 0): 242 | # print(metrics.precision_recall_fscore_support(gt_a, preds_a, zero_division=0)) 243 | 244 | print("\t\t Task B\tAcc: {:.6f}, F1: {:.6f}".format(acc_b, f1_b)) 245 | # if (f1_b == 0): 246 | # print(metrics.precision_recall_fscore_support(gt_b, preds_b, zero_division=0)) 247 | # evaluate the model once eval_every is reached, instead of evaluating only after the whole epoch 248 | global best_dev_loss, best_dev_f1 249 | if (idx+1) % eval_every == 0: 250 | dev_loss, dev_acc_a, dev_acc_b, dev_f1_a, dev_f1_b = eval(model, device, test_dataloader, criterion_type) 251 | dev_f1 = (dev_f1_a + dev_f1_b) / 2 252 | writer.add_scalar('eval/loss', dev_loss, global_step=epoch*est_batch+idx) 253 | writer.add_scalar('eval/acc_a', dev_acc_a, global_step=epoch*est_batch+idx) 254 | writer.add_scalar('eval/acc_b', dev_acc_b, global_step=epoch*est_batch+idx) 255 | writer.add_scalar('eval/f1_a', dev_f1_a, global_step=epoch*est_batch+idx) 256 | writer.add_scalar('eval/f1_b', dev_f1_b, global_step=epoch*est_batch+idx) 257 | if (dev_loss < best_dev_loss or dev_f1 > 
best_dev_f1): 258 | if dev_loss < best_dev_loss: 259 | best_dev_loss = dev_loss 260 | # torch.save(model, save_dir + model_type + '_epoch_{}_{}_'.format(epoch, task_type) + 'loss') 261 | torch.save(model.state_dict(), save_dir + model_type + '_epoch_{}_{}_'.format(epoch, task_type) + 'loss') 262 | print("----------BETTER LOSS, MODEL SAVED-----------") 263 | if dev_f1 > best_dev_f1: 264 | best_dev_f1 = dev_f1 265 | # torch.save(model, save_dir + model_type + '_epoch_{}_{}_'.format(epoch, task_type) + 'f1') 266 | torch.save(model.state_dict(), save_dir + model_type + '_epoch_{}_{}_'.format(epoch, task_type) + 'f1') 267 | print("----------BETTER F1, MODEL SAVED-----------") 268 | 269 | loss = np.array(total_loss).mean() 270 | # Setting average=None to return class-specific scores 271 | # 0502 BUG FIXED: do not use 'macro', DO NOT require class-specific metrics! 272 | # macro_f1 = metrics.f1_score(total_gt, total_preds, average='macro') 273 | f1_a = metrics.f1_score(total_gt_a, total_preds_a, zero_division=0) 274 | f1_b = metrics.f1_score(total_gt_b, total_preds_b, zero_division=0) 275 | f1 = (f1_a + f1_b) / 2 276 | print("Average f1 on training set: {:.6f}, f1_a: {:.6f}, f1_b: {:.6f}".format(f1, f1_a, f1_b)) 277 | 278 | return loss, f1, f1_a, f1_b 279 | 280 | 281 | def eval(model, device, test_dataloader, criterion_type='CE'): 282 | print("Evaluating") 283 | model.eval() 284 | # if called while training, then model parallel is already done 285 | # model = torch.nn.DataParallel(model) 286 | 287 | assert criterion_type == 'CE' or criterion_type == 'FL' 288 | if criterion_type == 'CE': 289 | criterion = nn.CrossEntropyLoss() 290 | elif criterion_type == 'FL': 291 | criterion = focal_loss() 292 | 293 | total_loss = [] 294 | total_gt_a, total_preds_a = [], [] 295 | total_gt_b, total_preds_b = [], [] 296 | 297 | for idx, batch in enumerate(test_dataloader): 298 | input_ids, input_types, labels, types = batch 299 | input_ids = input_ids.to(device) 300 | input_types = input_types.to(device) 301 | # labels should be flattened 302 | labels = labels.to(device).view(-1) 303 | 304 | # the probs given by the model, without grads 305 | with torch.no_grad(): 306 | # the probs given by the model 307 | probs_a, probs_b = model(input_ids, input_types) 308 | mask_a, mask_b = (types==0).numpy(), (types==1).numpy() 309 | output_a, labels_a = probs_a[mask_a], labels[mask_a] 310 | output_b, labels_b = probs_b[mask_b], labels[mask_b] 311 | 312 | if mask_a.sum()==0: 313 | loss = criterion(output_b, labels_b) 314 | elif mask_b.sum()==0: 315 | loss = criterion(output_a, labels_a) 316 | else: 317 | loss = criterion(output_a, labels_a) + criterion(output_b, labels_b) 318 | 319 | gt_a = labels_a.cpu().numpy().tolist() 320 | preds_a = output_a.argmax(axis=1).cpu().numpy().tolist() if len(gt_a)!=0 else [] 321 | 322 | gt_b = labels_b.cpu().numpy().tolist() 323 | preds_b = output_b.argmax(axis=1).cpu().numpy().tolist() if len(gt_b)!=0 else [] 324 | 325 | total_preds_a += preds_a 326 | total_gt_a += gt_a 327 | total_preds_b += preds_b 328 | total_gt_b += gt_b 329 | total_loss.append(loss.item()) 330 | 331 | loss = np.array(total_loss).mean() 332 | acc_a = metrics.accuracy_score(total_gt_a, total_preds_a) if len(total_gt_a)!=0 else 0 333 | f1_a = metrics.f1_score(total_gt_a, total_preds_a, zero_division=0) 334 | if (f1_a == 0): 335 | print("F1_a = 0, checking precision, recall, fscore and support...") 336 | print(metrics.precision_recall_fscore_support(total_gt_a, total_preds_a, zero_division=0)) 337 | 338 | acc_b = 
metrics.accuracy_score(total_gt_b, total_preds_b) if len(total_gt_b)!=0 else 0 339 | f1_b = metrics.f1_score(total_gt_b, total_preds_b, zero_division=0) 340 | if (f1_b == 0): 341 | print("F1_b = 0, checking precision, recall, fscore and support...") 342 | print(metrics.precision_recall_fscore_support(total_gt_b, total_preds_b, zero_division=0)) 343 | 344 | # Setting average=None to return class-specific scores 345 | # macro_f1 = metrics.f1_score(total_gt, total_preds, average='macro') 346 | # f1 = metrics.f1_score(total_gt, total_preds) 347 | 348 | # print loss and classification report 349 | print("Loss on dev set: ", loss) 350 | print("F1 on dev set: {:.6f}, f1_a: {:.6f}, f1_b: {:.6f}".format((f1_a+f1_b)/2, f1_a, f1_b)) 351 | 352 | # return loss, acc, macro_f1 353 | return loss, acc_a, acc_b, f1_a, f1_b 354 | 355 | 356 | if __name__ == '__main__': 357 | config = Config() 358 | device = config.device 359 | pretrained = config.pretrained 360 | model_type = config.model_type 361 | use_fgm = config.use_fgm 362 | 363 | save_dir = config.save_dir 364 | data_dir = config.data_dir 365 | # whether to shuffle the positions of source and target to augment data 366 | shuffle_order = config.shuffle_order 367 | # whether to use the positive cases in task B for task A (positives) 368 | # and the negative cases in task A for task B (negatives) 369 | aug_data = config.aug_data 370 | # method for clipping long sequences, 'head' or 'tail' 371 | clip_method = config.clip_method 372 | 373 | task_type = config.task_type 374 | task_a = ['短短匹配A类', '短长匹配A类', '长长匹配A类'] 375 | task_b = ['短短匹配B类', '短长匹配B类', '长长匹配B类'] 376 | 377 | # hyper-parameters here 378 | epochs = config.epochs 379 | lr = config.lr 380 | classifier_lr = config.classifier_lr 381 | weight_decay = config.weight_decay 382 | hidden_size = config.hidden_size 383 | train_bs = config.train_bs 384 | eval_bs = config.eval_bs 385 | 386 | print_every = config.print_every 387 | eval_every = config.eval_every 388 | 389 | train_data_dir, dev_data_dir = [], [] 390 | # integrate the two tasks into one dataset using task_type = 'ab' 391 | if 'a' in task_type: 392 | for task in task_a: 393 | train_data_dir.append(data_dir + task + '/train.txt') 394 | train_data_dir.append(data_dir + task + '/train_r2.txt') 395 | train_data_dir.append(data_dir + task + '/train_r3.txt') 396 | train_data_dir.append(data_dir + task + '/train_rematch.txt') # training file in rematch 397 | dev_data_dir.append(data_dir + task + '/valid.txt') 398 | dev_data_dir.append(data_dir + task + '/valid_rematch.txt') # valid file in rematch 399 | 400 | if 'b' in task_type: 401 | for task in task_b: 402 | train_data_dir.append(data_dir + task + '/train.txt') 403 | train_data_dir.append(data_dir + task + '/train_r2.txt') 404 | train_data_dir.append(data_dir + task + '/train_r3.txt') 405 | train_data_dir.append(data_dir + task + '/train_rematch.txt') # training file in rematch 406 | dev_data_dir.append(data_dir + task + '/valid.txt') 407 | dev_data_dir.append(data_dir + task + '/valid_rematch.txt') # valid file in rematch 408 | 409 | # toy dataset for testing 410 | if config.load_toy_dataset: 411 | train_data_dir = [ 412 | '../data/sohu2021_open_data/短短匹配A类/valid.txt', 413 | '../data/sohu2021_open_data/短短匹配B类/valid.txt', 414 | '../data/sohu2021_open_data/短长匹配A类/valid.txt', 415 | '../data/sohu2021_open_data/短长匹配B类/valid.txt', 416 | '../data/sohu2021_open_data/长长匹配A类/valid.txt', 417 | '../data/sohu2021_open_data/长长匹配B类/valid.txt'] 418 | dev_data_dir = ['../data/sohu2021_open_data/短短匹配A类/valid.txt', 419
| '../data/sohu2021_open_data/短短匹配B类/valid.txt',] 420 | 421 | 422 | print("Loading pretrained Model from {}...".format(pretrained)) 423 | # integrating SBERT model into a unified training framework 424 | if 'sbert' in model_type.lower(): 425 | print("Using SentenceBERT model and dataset") 426 | if 'nezha' in model_type.lower(): 427 | model = SNEZHASingleModel(bert_dir=pretrained, hidden_size=hidden_size) 428 | else: 429 | model = SBERTSingleModel(bert_dir=pretrained, hidden_size=hidden_size) 430 | model.to(device) 431 | print("Loading Training Data...") 432 | print(train_data_dir) 433 | # augment the data with shuffle_order=True (changing order of source and target) 434 | # or with aug_data=True (neg cases in A -> B, pos cases in B -> A) 435 | train_dataset = SentencePairDatasetForSBERT(train_data_dir, True, pretrained, shuffle_order, aug_data=aug_data, clip=clip_method) 436 | train_dataloader = DataLoader(train_dataset, batch_size=train_bs, shuffle=True) 437 | 438 | print("Loading Dev Data...") 439 | test_dataset = SentencePairDatasetForSBERT(dev_data_dir, True, pretrained, clip=clip_method) 440 | test_dataloader = DataLoader(test_dataset, batch_size=eval_bs, shuffle=False) 441 | 442 | # 0517: for training, load weights from pretrained with from_pretrained=True (by default) 443 | # for larger model, adjust the hidden_size according to its config 444 | # distinguish model architectures or pretrained models according to model_type 445 | else: 446 | print("Using BERT model and dataset") 447 | if 'nezha' in model_type.lower(): 448 | print("Using NEZHA pretrained model") 449 | model = NezhaClassifierSingleModel(bert_dir=pretrained, hidden_size=hidden_size) 450 | elif 'cnn' in model_type.lower(): 451 | print("Adding TextCNN after BERT output") 452 | model = BertClassifierTextCNNSingleModel(bert_dir=pretrained, hidden_size=hidden_size) 453 | else: 454 | print("Using conventional BERT model with linears") 455 | model = BertClassifierSingleModel(bert_dir=pretrained, hidden_size=hidden_size) 456 | model.to(device) 457 | 458 | print("Loading Training Data...") 459 | print(train_data_dir) 460 | # augment the data with shuffle_order=True (changing order of source and target) 461 | # or with aug_data=True (neg cases in A -> B, pos cases in B -> A) 462 | train_dataset = SentencePairDatasetWithType(train_data_dir, True, pretrained, shuffle_order, aug_data=aug_data, clip=clip_method) 463 | train_dataloader = DataLoader(train_dataset, batch_size=train_bs, shuffle=True) 464 | 465 | print("Loading Dev Data...") 466 | test_dataset = SentencePairDatasetWithType(dev_data_dir, True, pretrained, clip=clip_method) 467 | test_dataloader = DataLoader(test_dataset, batch_size=eval_bs, shuffle=False) 468 | 469 | optimizer = AdamW(model.parameters(), lr=lr, weight_decay=weight_decay, correct_bias=False) 470 | # 0514 setting different lr for bert encoder and classifier 471 | # TODO: verify that the larger lr works for the classifiers 472 | # optimizer = AdamW([ 473 | # {"params": model.all_classifier.parameters(), "lr": classifier_lr}, 474 | # {"params": model.bert.parameters()}], 475 | # lr=lr) 476 | 477 | # for p in optimizer.param_groups: 478 | # outputs = '' 479 | # for k, v in p.items(): 480 | # if k == 'params': 481 | # outputs += (k + ': ' + str(v[0].shape).ljust(30) + ' ') 482 | # else: 483 | # outputs += (k + ': ' + str(v).ljust(10) + ' ') 484 | # print(outputs) 485 | 486 | total_steps = len(train_dataloader) * epochs 487 | 488 | # TODO: use ReduceLROnPlateau instead of the linear scheduler 489 | if config.use_scheduler: 
490 | scheduler = get_linear_schedule_with_warmup( 491 | optimizer, 492 | num_training_steps = total_steps, 493 | num_warmup_steps = config.num_warmup_steps, 494 | ) 495 | else: 496 | scheduler = None 497 | 498 | print("Training on Task {}...".format(task_type)) 499 | writer = SummaryWriter('runs/{}'.format(model_type + '_' + task_type)) 500 | 501 | best_dev_loss = 999 502 | best_dev_f1 = 0 503 | for epoch in range(epochs): 504 | train_loss, train_f1, train_f1_a, train_f1_b = train(model, device, epoch, train_dataloader, test_dataloader, \ 505 | save_dir, optimizer, scheduler=scheduler, model_type=model_type, \ 506 | print_every=print_every, eval_every=eval_every, writer=writer, use_fgm=use_fgm) 507 | -------------------------------------------------------------------------------- /决赛提交/sohu_matching/src/utils.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | import torch 3 | import torch.nn as nn 4 | import pandas as pd 5 | 6 | # imports needed for RAdam and Lookahead 7 | import math 8 | from torch.optim.optimizer import Optimizer, required 9 | from collections import defaultdict 10 | import itertools as it 11 | 12 | # from model import * 13 | 14 | def pad_to_maxlen(input_ids, max_len, pad_value=0): 15 | if len(input_ids) >= max_len: 16 | input_ids = input_ids[:max_len] 17 | else: 18 | input_ids = input_ids + [pad_value] * (max_len-len(input_ids)) 19 | return input_ids 20 | 21 | 22 | def augment_data(data): 23 | B2A = pd.DataFrame() 24 | A2B = pd.DataFrame() 25 | 26 | train_A = data[(data['type'] == 0)] 27 | train_B = data[(data['type'] == 1)] 28 | 29 | B2A = B2A.append( 30 | train_B.loc[train_B['label'] == '1'], ignore_index=True) 31 | A2B = A2B.append( 32 | train_A.loc[train_A['label'] == '0'], ignore_index=True) 33 | 34 | train_aug_A = pd.concat([train_A, B2A], axis=0, ignore_index=True) 35 | train_aug_A.drop_duplicates( 36 | subset=['source', 'target'], keep='first', inplace=True, ignore_index=True) 37 | 38 | train_aug_B = pd.concat([train_B, A2B], axis=0, ignore_index=True) 39 | train_aug_B.drop_duplicates( 40 | subset=['source', 'target'], keep='first', inplace=True, ignore_index=True) 41 | 42 | train_all = pd.concat([train_aug_A, train_aug_B], axis=0, ignore_index=True) 43 | return train_all 44 | 45 | 46 | class focal_loss(nn.Module): 47 | def __init__(self, alpha=0.25, gamma=2, num_classes = 2, size_average=False): 48 | """ 49 | the focal loss, -α(1-yi)**γ * ce_loss(xi, yi) 50 | a step-by-step implementation of the focal loss. 51 | :param alpha: class weight. When alpha is a list, it gives per-class weights; when alpha is a scalar, the class weights become [α, 1-α, 1-α, ....], which is commonly used in object detection to suppress the background class; set to 0.25 in RetinaNet 52 | :param gamma: modulating factor for hard vs. easy samples; set to 2 in RetinaNet 53 | :param num_classes: number of classes 54 | :param size_average: loss reduction; if True return the mean, otherwise the sum (default: False) 55 | """ 56 | super(focal_loss,self).__init__() 57 | self.size_average = size_average 58 | if isinstance(alpha,list): 59 | assert len(alpha)==num_classes # alpha can be passed as a list of size [num_classes] to assign fine-grained per-class weights 60 | # print(" --- Focal_loss alpha = {}, assigning fine-grained weights to each class --- ".format(alpha)) 61 | self.alpha = torch.Tensor(alpha) 62 | else: 63 | assert alpha<1 # if alpha is a scalar, down-weight the first class (the background class in object detection) 64 | # print(" --- Focal_loss alpha = {}, down-weighting the background class; intended for object detection --- ".format(alpha)) 65 | self.alpha = torch.zeros(num_classes) 66 | self.alpha[0] += alpha 67 | self.alpha[1:] += (1-alpha) # alpha finally becomes [α, 1-α, 1-α, 1-α, 1-α, ...] with size [num_classes] 68 | 69 | self.gamma = gamma 70 | 71 | def forward(self, preds, labels): 72 | """ 73 | computes the focal loss 74 | :param preds: predicted logits. 
size: [B,N,C] or [B,C], for detection and classification respectively (B: batch size, N: number of boxes, C: number of classes) 75 | :param labels: ground-truth classes. size: [B,N] or [B] 76 | :return: 77 | """ 78 | # assert preds.dim()==2 and labels.dim()==1 79 | preds = preds.view(-1,preds.size(-1)) 80 | self.alpha = self.alpha.to(preds.device) 81 | preds_logsoft = F.log_softmax(preds, dim=1) # log_softmax 82 | preds_softmax = torch.exp(preds_logsoft) # softmax 83 | 84 | preds_softmax = preds_softmax.gather(1,labels.view(-1,1)) # this implements nll_loss (cross entropy = log_softmax + nll) 85 | preds_logsoft = preds_logsoft.gather(1,labels.view(-1,1)) 86 | self.alpha = self.alpha.gather(0,labels.view(-1)) 87 | loss = -torch.mul(torch.pow((1-preds_softmax), self.gamma), preds_logsoft) # torch.pow((1-preds_softmax), self.gamma) is the (1-pt)**γ term of the focal loss 88 | 89 | loss = torch.mul(self.alpha, loss.t()) 90 | if self.size_average: 91 | loss = loss.mean() 92 | else: 93 | loss = loss.sum() 94 | return loss 95 | 96 | 97 | class FGM(): 98 | def __init__(self, model): 99 | self.model = model 100 | self.backup = {} 101 | 102 | def attack(self, epsilon=1., emb_name='bert.embeddings.'): 103 | # emb_name should be set to the name of the embedding parameters in your model 104 | for name, param in self.model.named_parameters(): 105 | if param.requires_grad and emb_name in name: 106 | self.backup[name] = param.data.clone() 107 | norm = torch.norm(param.grad) 108 | if norm != 0 and not torch.isnan(norm): 109 | r_at = epsilon * param.grad / norm 110 | param.data.add_(r_at) 111 | 112 | def restore(self, emb_name='bert.embeddings.'): 113 | # emb_name should be set to the name of the embedding parameters in your model 114 | for name, param in self.model.named_parameters(): 115 | if param.requires_grad and emb_name in name: 116 | assert name in self.backup 117 | param.data = self.backup[name] 118 | self.backup = {} 119 | 120 | 121 | class PGD(): 122 | def __init__(self, model): 123 | self.model = model 124 | self.emb_backup = {} 125 | self.grad_backup = {} 126 | 127 | def attack(self, epsilon=1., alpha=0.3, emb_name='bert.embeddings.', is_first_attack=False): 128 | # emb_name should be set to the name of the embedding parameters in your model 129 | for name, param in self.model.named_parameters(): 130 | if param.requires_grad and emb_name in name: 131 | if is_first_attack: 132 | self.emb_backup[name] = param.data.clone() 133 | norm = torch.norm(param.grad) 134 | if norm != 0 and not torch.isnan(norm): 135 | r_at = alpha * param.grad / norm 136 | param.data.add_(r_at) 137 | param.data = self.project(name, param.data, epsilon) 138 | 139 | def restore(self, emb_name='bert.embeddings.'): 140 | # emb_name should be set to the name of the embedding parameters in your model 141 | for name, param in self.model.named_parameters(): 142 | if param.requires_grad and emb_name in name: 143 | assert name in self.emb_backup 144 | param.data = self.emb_backup[name] 145 | self.emb_backup = {} 146 | 147 | def project(self, param_name, param_data, epsilon): 148 | r = param_data - self.emb_backup[param_name] 149 | if torch.norm(r) > epsilon: 150 | r = epsilon * r / torch.norm(r) 151 | return self.emb_backup[param_name] + r 152 | 153 | def backup_grad(self): 154 | for name, param in self.model.named_parameters(): 155 | if param.requires_grad: 156 | # do not apply adversarial training to the final bert.pooler and linear1 layers 157 | if 'encoder' in name or 'bert.embeddings.' in name: 158 | self.grad_backup[name] = param.grad.clone() 159 | 160 | def restore_grad(self): 161 | for name, param in self.model.named_parameters(): 162 | if param.requires_grad: 163 | if 'encoder' in name or 'bert.embeddings.' 
in name: 164 | param.grad = self.grad_backup[name] 165 | 166 | # Lookahead implementation from https://github.com/lonePatient/lookahead_pytorch/blob/master/optimizer.py 167 | class Lookahead(Optimizer): 168 | def __init__(self, optimizer, alpha=0.5, k=6): 169 | 170 | if not 0.0 <= alpha <= 1.0: 171 | raise ValueError(f'Invalid slow update rate: {alpha}') 172 | if not 1 <= k: 173 | raise ValueError(f'Invalid lookahead steps: {k}') 174 | 175 | self.optimizer = optimizer 176 | self.param_groups = self.optimizer.param_groups 177 | self.alpha = alpha 178 | self.k = k 179 | for group in self.param_groups: 180 | group["step_counter"] = 0 181 | 182 | self.slow_weights = [ 183 | [p.clone().detach() for p in group['params']] 184 | for group in self.param_groups] 185 | 186 | for w in it.chain(*self.slow_weights): 187 | w.requires_grad = False 188 | self.state = optimizer.state 189 | 190 | def step(self, closure=None): 191 | loss = None 192 | if closure is not None: 193 | loss = closure() 194 | loss = self.optimizer.step() 195 | 196 | for group,slow_weights in zip(self.param_groups,self.slow_weights): 197 | group['step_counter'] += 1 198 | if group['step_counter'] % self.k != 0: 199 | continue 200 | for p,q in zip(group['params'],slow_weights): 201 | if p.grad is None: 202 | continue 203 | q.data.add_(p.data - q.data, alpha=self.alpha ) 204 | p.data.copy_(q.data) 205 | return loss 206 | 207 | 208 | class RAdam(Optimizer): 209 | 210 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0): 211 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) 212 | self.buffer = [[None, None, None] for ind in range(10)] 213 | super(RAdam, self).__init__(params, defaults) 214 | 215 | def __setstate__(self, state): 216 | super(RAdam, self).__setstate__(state) 217 | 218 | def step(self, closure=None): 219 | 220 | loss = None 221 | if closure is not None: 222 | loss = closure() 223 | 224 | for group in self.param_groups: 225 | 226 | for p in group['params']: 227 | if p.grad is None: 228 | continue 229 | grad = p.grad.data.float() 230 | if grad.is_sparse: 231 | raise RuntimeError('RAdam does not support sparse gradients') 232 | 233 | p_data_fp32 = p.data.float() 234 | 235 | state = self.state[p] 236 | 237 | if len(state) == 0: 238 | state['step'] = 0 239 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 240 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) 241 | else: 242 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 243 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) 244 | 245 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 246 | beta1, beta2 = group['betas'] 247 | 248 | exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value = 1 - beta2) 249 | exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) 250 | 251 | state['step'] += 1 252 | buffered = self.buffer[int(state['step'] % 10)] 253 | if state['step'] == buffered[0]: 254 | N_sma, step_size = buffered[1], buffered[2] 255 | else: 256 | buffered[0] = state['step'] 257 | beta2_t = beta2 ** state['step'] 258 | N_sma_max = 2 / (1 - beta2) - 1 259 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) 260 | buffered[1] = N_sma 261 | 262 | # more conservative since it's an approximated value 263 | if N_sma >= 5: 264 | step_size = math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step']) 265 | else: 266 | step_size = 1.0 / (1 - beta1 ** state['step']) 267 | buffered[2] = step_size 268 | 269 | if 
group['weight_decay'] != 0: 270 | p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr']) 271 | 272 | # more conservative since it's an approximated value 273 | if N_sma >= 5: 274 | denom = exp_avg_sq.sqrt().add_(group['eps']) 275 | p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size * group['lr']) 276 | else: 277 | p_data_fp32.add_(exp_avg, alpha=-step_size * group['lr']) 278 | 279 | p.data.copy_(p_data_fp32) 280 | 281 | return loss --------------------------------------------------------------------------------
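The utilities above are meant to plug into the training loops in train.py / train_old.py: FGM perturbs the embedding weights after the normal backward pass, a second backward pass accumulates the adversarial gradient, and restore() puts the original embeddings back before the optimizer step, while Lookahead can wrap RAdam (or any inner optimizer) and syncs its slow weights every k steps. A minimal sketch of one such step, assuming a generic single-head classifier model and the (input_ids, input_types, labels, types) batches used in this repo; the hyper-parameters are illustrative, and the repo's two-head models would additionally apply the task masks exactly as in train():

    import torch.nn as nn
    from utils import FGM, Lookahead, RAdam, focal_loss

    criterion = nn.CrossEntropyLoss()   # or focal_loss() for criterion_type == 'FL'
    optimizer = Lookahead(RAdam(model.parameters(), lr=2e-5), alpha=0.5, k=6)
    fgm = FGM(model)                    # perturbs parameters whose name contains 'bert.embeddings.'

    for input_ids, input_types, labels, types in train_dataloader:   # device transfer omitted
        model.zero_grad()
        loss = criterion(model(input_ids, input_types), labels.view(-1))
        loss.backward()                 # gradients on the clean inputs
        fgm.attack()                    # add r_at = epsilon * grad / ||grad|| to the embeddings
        adv_loss = criterion(model(input_ids, input_types), labels.view(-1))
        adv_loss.backward()             # accumulate the adversarial gradients on top
        fgm.restore()                   # put the backed-up embedding weights back
        optimizer.step()                # Lookahead syncs slow weights every k steps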