├── LICENSE
├── README.md
├── checkCN.py
├── dataset
│   ├── ICSD.zip
│   ├── ICSD_CN.zip
│   ├── ICSD_CN_oxide.zip
│   ├── ICSD_oxide.zip
│   └── README.md
├── formulas.csv
├── getOS.py
├── materials_icsd.py
├── materials_icsd_cn.py
├── materials_icsd_cno.py
├── materials_icsd_o.py
├── performances.png
├── random_config
│   └── config.json
├── requirements.txt
├── tokenizer
│   └── vocab.txt
├── train_BERTOS.py
├── train_BERTOS.sh
└── trained_models
    ├── ICSD.zip
    ├── ICSD_CN.zip
    ├── ICSD_CN_oxide.zip
    ├── ICSD_oxide.zip
    └── README.md
/LICENSE: -------------------------------------------------------------------------------- 1 | GNU GENERAL PUBLIC LICENSE 2 | Version 3, 29 June 2007 3 | 4 | Copyright (C) 2007 Free Software Foundation, Inc. 5 | Everyone is permitted to copy and distribute verbatim copies 6 | of this license document, but changing it is not allowed. 7 | 8 | Preamble 9 | 10 | The GNU General Public License is a free, copyleft license for 11 | software and other kinds of works. 12 | 13 | The licenses for most software and other practical works are designed 14 | to take away your freedom to share and change the works. By contrast, 15 | the GNU General Public License is intended to guarantee your freedom to 16 | share and change all versions of a program--to make sure it remains free 17 | software for all its users. We, the Free Software Foundation, use the 18 | GNU General Public License for most of our software; it applies also to 19 | any other work released this way by its authors. You can apply it to 20 | your programs, too. 21 | 22 | When we speak of free software, we are referring to freedom, not 23 | price. Our General Public Licenses are designed to make sure that you 24 | have the freedom to distribute copies of free software (and charge for 25 | them if you wish), that you receive source code or can get it if you 26 | want it, that you can change the software or use pieces of it in new 27 | free programs, and that you know you can do these things. 28 | 29 | To protect your rights, we need to prevent others from denying you 30 | these rights or asking you to surrender the rights. Therefore, you have 31 | certain responsibilities if you distribute copies of the software, or if 32 | you modify it: responsibilities to respect the freedom of others. 33 | 34 | For example, if you distribute copies of such a program, whether 35 | gratis or for a fee, you must pass on to the recipients the same 36 | freedoms that you received. You must make sure that they, too, receive 37 | or can get the source code. And you must show them these terms so they 38 | know their rights. 39 | 40 | Developers that use the GNU GPL protect your rights with two steps: 41 | (1) assert copyright on the software, and (2) offer you this License 42 | giving you legal permission to copy, distribute and/or modify it. 43 | 44 | For the developers' and authors' protection, the GPL clearly explains 45 | that there is no warranty for this free software. For both users' and 46 | authors' sake, the GPL requires that modified versions be marked as 47 | changed, so that their problems will not be attributed erroneously to 48 | authors of previous versions. 49 | 50 | Some devices are designed to deny users access to install or run 51 | modified versions of the software inside them, although the manufacturer 52 | can do so. This is fundamentally incompatible with the aim of 53 | protecting users' freedom to change the software.
The systematic 54 | pattern of such abuse occurs in the area of products for individuals to 55 | use, which is precisely where it is most unacceptable. Therefore, we 56 | have designed this version of the GPL to prohibit the practice for those 57 | products. If such problems arise substantially in other domains, we 58 | stand ready to extend this provision to those domains in future versions 59 | of the GPL, as needed to protect the freedom of users. 60 | 61 | Finally, every program is threatened constantly by software patents. 62 | States should not allow patents to restrict development and use of 63 | software on general-purpose computers, but in those that do, we wish to 64 | avoid the special danger that patents applied to a free program could 65 | make it effectively proprietary. To prevent this, the GPL assures that 66 | patents cannot be used to render the program non-free. 67 | 68 | The precise terms and conditions for copying, distribution and 69 | modification follow. 70 | 71 | TERMS AND CONDITIONS 72 | 73 | 0. Definitions. 74 | 75 | "This License" refers to version 3 of the GNU General Public License. 76 | 77 | "Copyright" also means copyright-like laws that apply to other kinds of 78 | works, such as semiconductor masks. 79 | 80 | "The Program" refers to any copyrightable work licensed under this 81 | License. Each licensee is addressed as "you". "Licensees" and 82 | "recipients" may be individuals or organizations. 83 | 84 | To "modify" a work means to copy from or adapt all or part of the work 85 | in a fashion requiring copyright permission, other than the making of an 86 | exact copy. The resulting work is called a "modified version" of the 87 | earlier work or a work "based on" the earlier work. 88 | 89 | A "covered work" means either the unmodified Program or a work based 90 | on the Program. 91 | 92 | To "propagate" a work means to do anything with it that, without 93 | permission, would make you directly or secondarily liable for 94 | infringement under applicable copyright law, except executing it on a 95 | computer or modifying a private copy. Propagation includes copying, 96 | distribution (with or without modification), making available to the 97 | public, and in some countries other activities as well. 98 | 99 | To "convey" a work means any kind of propagation that enables other 100 | parties to make or receive copies. Mere interaction with a user through 101 | a computer network, with no transfer of a copy, is not conveying. 102 | 103 | An interactive user interface displays "Appropriate Legal Notices" 104 | to the extent that it includes a convenient and prominently visible 105 | feature that (1) displays an appropriate copyright notice, and (2) 106 | tells the user that there is no warranty for the work (except to the 107 | extent that warranties are provided), that licensees may convey the 108 | work under this License, and how to view a copy of this License. If 109 | the interface presents a list of user commands or options, such as a 110 | menu, a prominent item in the list meets this criterion. 111 | 112 | 1. Source Code. 113 | 114 | The "source code" for a work means the preferred form of the work 115 | for making modifications to it. "Object code" means any non-source 116 | form of a work. 
117 | 118 | A "Standard Interface" means an interface that either is an official 119 | standard defined by a recognized standards body, or, in the case of 120 | interfaces specified for a particular programming language, one that 121 | is widely used among developers working in that language. 122 | 123 | The "System Libraries" of an executable work include anything, other 124 | than the work as a whole, that (a) is included in the normal form of 125 | packaging a Major Component, but which is not part of that Major 126 | Component, and (b) serves only to enable use of the work with that 127 | Major Component, or to implement a Standard Interface for which an 128 | implementation is available to the public in source code form. A 129 | "Major Component", in this context, means a major essential component 130 | (kernel, window system, and so on) of the specific operating system 131 | (if any) on which the executable work runs, or a compiler used to 132 | produce the work, or an object code interpreter used to run it. 133 | 134 | The "Corresponding Source" for a work in object code form means all 135 | the source code needed to generate, install, and (for an executable 136 | work) run the object code and to modify the work, including scripts to 137 | control those activities. However, it does not include the work's 138 | System Libraries, or general-purpose tools or generally available free 139 | programs which are used unmodified in performing those activities but 140 | which are not part of the work. For example, Corresponding Source 141 | includes interface definition files associated with source files for 142 | the work, and the source code for shared libraries and dynamically 143 | linked subprograms that the work is specifically designed to require, 144 | such as by intimate data communication or control flow between those 145 | subprograms and other parts of the work. 146 | 147 | The Corresponding Source need not include anything that users 148 | can regenerate automatically from other parts of the Corresponding 149 | Source. 150 | 151 | The Corresponding Source for a work in source code form is that 152 | same work. 153 | 154 | 2. Basic Permissions. 155 | 156 | All rights granted under this License are granted for the term of 157 | copyright on the Program, and are irrevocable provided the stated 158 | conditions are met. This License explicitly affirms your unlimited 159 | permission to run the unmodified Program. The output from running a 160 | covered work is covered by this License only if the output, given its 161 | content, constitutes a covered work. This License acknowledges your 162 | rights of fair use or other equivalent, as provided by copyright law. 163 | 164 | You may make, run and propagate covered works that you do not 165 | convey, without conditions so long as your license otherwise remains 166 | in force. You may convey covered works to others for the sole purpose 167 | of having them make modifications exclusively for you, or provide you 168 | with facilities for running those works, provided that you comply with 169 | the terms of this License in conveying all material for which you do 170 | not control copyright. Those thus making or running the covered works 171 | for you must do so exclusively on your behalf, under your direction 172 | and control, on terms that prohibit them from making any copies of 173 | your copyrighted material outside their relationship with you. 
174 | 175 | Conveying under any other circumstances is permitted solely under 176 | the conditions stated below. Sublicensing is not allowed; section 10 177 | makes it unnecessary. 178 | 179 | 3. Protecting Users' Legal Rights From Anti-Circumvention Law. 180 | 181 | No covered work shall be deemed part of an effective technological 182 | measure under any applicable law fulfilling obligations under article 183 | 11 of the WIPO copyright treaty adopted on 20 December 1996, or 184 | similar laws prohibiting or restricting circumvention of such 185 | measures. 186 | 187 | When you convey a covered work, you waive any legal power to forbid 188 | circumvention of technological measures to the extent such circumvention 189 | is effected by exercising rights under this License with respect to 190 | the covered work, and you disclaim any intention to limit operation or 191 | modification of the work as a means of enforcing, against the work's 192 | users, your or third parties' legal rights to forbid circumvention of 193 | technological measures. 194 | 195 | 4. Conveying Verbatim Copies. 196 | 197 | You may convey verbatim copies of the Program's source code as you 198 | receive it, in any medium, provided that you conspicuously and 199 | appropriately publish on each copy an appropriate copyright notice; 200 | keep intact all notices stating that this License and any 201 | non-permissive terms added in accord with section 7 apply to the code; 202 | keep intact all notices of the absence of any warranty; and give all 203 | recipients a copy of this License along with the Program. 204 | 205 | You may charge any price or no price for each copy that you convey, 206 | and you may offer support or warranty protection for a fee. 207 | 208 | 5. Conveying Modified Source Versions. 209 | 210 | You may convey a work based on the Program, or the modifications to 211 | produce it from the Program, in the form of source code under the 212 | terms of section 4, provided that you also meet all of these conditions: 213 | 214 | a) The work must carry prominent notices stating that you modified 215 | it, and giving a relevant date. 216 | 217 | b) The work must carry prominent notices stating that it is 218 | released under this License and any conditions added under section 219 | 7. This requirement modifies the requirement in section 4 to 220 | "keep intact all notices". 221 | 222 | c) You must license the entire work, as a whole, under this 223 | License to anyone who comes into possession of a copy. This 224 | License will therefore apply, along with any applicable section 7 225 | additional terms, to the whole of the work, and all its parts, 226 | regardless of how they are packaged. This License gives no 227 | permission to license the work in any other way, but it does not 228 | invalidate such permission if you have separately received it. 229 | 230 | d) If the work has interactive user interfaces, each must display 231 | Appropriate Legal Notices; however, if the Program has interactive 232 | interfaces that do not display Appropriate Legal Notices, your 233 | work need not make them do so. 
234 | 235 | A compilation of a covered work with other separate and independent 236 | works, which are not by their nature extensions of the covered work, 237 | and which are not combined with it such as to form a larger program, 238 | in or on a volume of a storage or distribution medium, is called an 239 | "aggregate" if the compilation and its resulting copyright are not 240 | used to limit the access or legal rights of the compilation's users 241 | beyond what the individual works permit. Inclusion of a covered work 242 | in an aggregate does not cause this License to apply to the other 243 | parts of the aggregate. 244 | 245 | 6. Conveying Non-Source Forms. 246 | 247 | You may convey a covered work in object code form under the terms 248 | of sections 4 and 5, provided that you also convey the 249 | machine-readable Corresponding Source under the terms of this License, 250 | in one of these ways: 251 | 252 | a) Convey the object code in, or embodied in, a physical product 253 | (including a physical distribution medium), accompanied by the 254 | Corresponding Source fixed on a durable physical medium 255 | customarily used for software interchange. 256 | 257 | b) Convey the object code in, or embodied in, a physical product 258 | (including a physical distribution medium), accompanied by a 259 | written offer, valid for at least three years and valid for as 260 | long as you offer spare parts or customer support for that product 261 | model, to give anyone who possesses the object code either (1) a 262 | copy of the Corresponding Source for all the software in the 263 | product that is covered by this License, on a durable physical 264 | medium customarily used for software interchange, for a price no 265 | more than your reasonable cost of physically performing this 266 | conveying of source, or (2) access to copy the 267 | Corresponding Source from a network server at no charge. 268 | 269 | c) Convey individual copies of the object code with a copy of the 270 | written offer to provide the Corresponding Source. This 271 | alternative is allowed only occasionally and noncommercially, and 272 | only if you received the object code with such an offer, in accord 273 | with subsection 6b. 274 | 275 | d) Convey the object code by offering access from a designated 276 | place (gratis or for a charge), and offer equivalent access to the 277 | Corresponding Source in the same way through the same place at no 278 | further charge. You need not require recipients to copy the 279 | Corresponding Source along with the object code. If the place to 280 | copy the object code is a network server, the Corresponding Source 281 | may be on a different server (operated by you or a third party) 282 | that supports equivalent copying facilities, provided you maintain 283 | clear directions next to the object code saying where to find the 284 | Corresponding Source. Regardless of what server hosts the 285 | Corresponding Source, you remain obligated to ensure that it is 286 | available for as long as needed to satisfy these requirements. 287 | 288 | e) Convey the object code using peer-to-peer transmission, provided 289 | you inform other peers where the object code and Corresponding 290 | Source of the work are being offered to the general public at no 291 | charge under subsection 6d. 292 | 293 | A separable portion of the object code, whose source code is excluded 294 | from the Corresponding Source as a System Library, need not be 295 | included in conveying the object code work. 
296 | 297 | A "User Product" is either (1) a "consumer product", which means any 298 | tangible personal property which is normally used for personal, family, 299 | or household purposes, or (2) anything designed or sold for incorporation 300 | into a dwelling. In determining whether a product is a consumer product, 301 | doubtful cases shall be resolved in favor of coverage. For a particular 302 | product received by a particular user, "normally used" refers to a 303 | typical or common use of that class of product, regardless of the status 304 | of the particular user or of the way in which the particular user 305 | actually uses, or expects or is expected to use, the product. A product 306 | is a consumer product regardless of whether the product has substantial 307 | commercial, industrial or non-consumer uses, unless such uses represent 308 | the only significant mode of use of the product. 309 | 310 | "Installation Information" for a User Product means any methods, 311 | procedures, authorization keys, or other information required to install 312 | and execute modified versions of a covered work in that User Product from 313 | a modified version of its Corresponding Source. The information must 314 | suffice to ensure that the continued functioning of the modified object 315 | code is in no case prevented or interfered with solely because 316 | modification has been made. 317 | 318 | If you convey an object code work under this section in, or with, or 319 | specifically for use in, a User Product, and the conveying occurs as 320 | part of a transaction in which the right of possession and use of the 321 | User Product is transferred to the recipient in perpetuity or for a 322 | fixed term (regardless of how the transaction is characterized), the 323 | Corresponding Source conveyed under this section must be accompanied 324 | by the Installation Information. But this requirement does not apply 325 | if neither you nor any third party retains the ability to install 326 | modified object code on the User Product (for example, the work has 327 | been installed in ROM). 328 | 329 | The requirement to provide Installation Information does not include a 330 | requirement to continue to provide support service, warranty, or updates 331 | for a work that has been modified or installed by the recipient, or for 332 | the User Product in which it has been modified or installed. Access to a 333 | network may be denied when the modification itself materially and 334 | adversely affects the operation of the network or violates the rules and 335 | protocols for communication across the network. 336 | 337 | Corresponding Source conveyed, and Installation Information provided, 338 | in accord with this section must be in a format that is publicly 339 | documented (and with an implementation available to the public in 340 | source code form), and must require no special password or key for 341 | unpacking, reading or copying. 342 | 343 | 7. Additional Terms. 344 | 345 | "Additional permissions" are terms that supplement the terms of this 346 | License by making exceptions from one or more of its conditions. 347 | Additional permissions that are applicable to the entire Program shall 348 | be treated as though they were included in this License, to the extent 349 | that they are valid under applicable law. 
If additional permissions 350 | apply only to part of the Program, that part may be used separately 351 | under those permissions, but the entire Program remains governed by 352 | this License without regard to the additional permissions. 353 | 354 | When you convey a copy of a covered work, you may at your option 355 | remove any additional permissions from that copy, or from any part of 356 | it. (Additional permissions may be written to require their own 357 | removal in certain cases when you modify the work.) You may place 358 | additional permissions on material, added by you to a covered work, 359 | for which you have or can give appropriate copyright permission. 360 | 361 | Notwithstanding any other provision of this License, for material you 362 | add to a covered work, you may (if authorized by the copyright holders of 363 | that material) supplement the terms of this License with terms: 364 | 365 | a) Disclaiming warranty or limiting liability differently from the 366 | terms of sections 15 and 16 of this License; or 367 | 368 | b) Requiring preservation of specified reasonable legal notices or 369 | author attributions in that material or in the Appropriate Legal 370 | Notices displayed by works containing it; or 371 | 372 | c) Prohibiting misrepresentation of the origin of that material, or 373 | requiring that modified versions of such material be marked in 374 | reasonable ways as different from the original version; or 375 | 376 | d) Limiting the use for publicity purposes of names of licensors or 377 | authors of the material; or 378 | 379 | e) Declining to grant rights under trademark law for use of some 380 | trade names, trademarks, or service marks; or 381 | 382 | f) Requiring indemnification of licensors and authors of that 383 | material by anyone who conveys the material (or modified versions of 384 | it) with contractual assumptions of liability to the recipient, for 385 | any liability that these contractual assumptions directly impose on 386 | those licensors and authors. 387 | 388 | All other non-permissive additional terms are considered "further 389 | restrictions" within the meaning of section 10. If the Program as you 390 | received it, or any part of it, contains a notice stating that it is 391 | governed by this License along with a term that is a further 392 | restriction, you may remove that term. If a license document contains 393 | a further restriction but permits relicensing or conveying under this 394 | License, you may add to a covered work material governed by the terms 395 | of that license document, provided that the further restriction does 396 | not survive such relicensing or conveying. 397 | 398 | If you add terms to a covered work in accord with this section, you 399 | must place, in the relevant source files, a statement of the 400 | additional terms that apply to those files, or a notice indicating 401 | where to find the applicable terms. 402 | 403 | Additional terms, permissive or non-permissive, may be stated in the 404 | form of a separately written license, or stated as exceptions; 405 | the above requirements apply either way. 406 | 407 | 8. Termination. 408 | 409 | You may not propagate or modify a covered work except as expressly 410 | provided under this License. Any attempt otherwise to propagate or 411 | modify it is void, and will automatically terminate your rights under 412 | this License (including any patent licenses granted under the third 413 | paragraph of section 11). 
414 | 415 | However, if you cease all violation of this License, then your 416 | license from a particular copyright holder is reinstated (a) 417 | provisionally, unless and until the copyright holder explicitly and 418 | finally terminates your license, and (b) permanently, if the copyright 419 | holder fails to notify you of the violation by some reasonable means 420 | prior to 60 days after the cessation. 421 | 422 | Moreover, your license from a particular copyright holder is 423 | reinstated permanently if the copyright holder notifies you of the 424 | violation by some reasonable means, this is the first time you have 425 | received notice of violation of this License (for any work) from that 426 | copyright holder, and you cure the violation prior to 30 days after 427 | your receipt of the notice. 428 | 429 | Termination of your rights under this section does not terminate the 430 | licenses of parties who have received copies or rights from you under 431 | this License. If your rights have been terminated and not permanently 432 | reinstated, you do not qualify to receive new licenses for the same 433 | material under section 10. 434 | 435 | 9. Acceptance Not Required for Having Copies. 436 | 437 | You are not required to accept this License in order to receive or 438 | run a copy of the Program. Ancillary propagation of a covered work 439 | occurring solely as a consequence of using peer-to-peer transmission 440 | to receive a copy likewise does not require acceptance. However, 441 | nothing other than this License grants you permission to propagate or 442 | modify any covered work. These actions infringe copyright if you do 443 | not accept this License. Therefore, by modifying or propagating a 444 | covered work, you indicate your acceptance of this License to do so. 445 | 446 | 10. Automatic Licensing of Downstream Recipients. 447 | 448 | Each time you convey a covered work, the recipient automatically 449 | receives a license from the original licensors, to run, modify and 450 | propagate that work, subject to this License. You are not responsible 451 | for enforcing compliance by third parties with this License. 452 | 453 | An "entity transaction" is a transaction transferring control of an 454 | organization, or substantially all assets of one, or subdividing an 455 | organization, or merging organizations. If propagation of a covered 456 | work results from an entity transaction, each party to that 457 | transaction who receives a copy of the work also receives whatever 458 | licenses to the work the party's predecessor in interest had or could 459 | give under the previous paragraph, plus a right to possession of the 460 | Corresponding Source of the work from the predecessor in interest, if 461 | the predecessor has it or can get it with reasonable efforts. 462 | 463 | You may not impose any further restrictions on the exercise of the 464 | rights granted or affirmed under this License. For example, you may 465 | not impose a license fee, royalty, or other charge for exercise of 466 | rights granted under this License, and you may not initiate litigation 467 | (including a cross-claim or counterclaim in a lawsuit) alleging that 468 | any patent claim is infringed by making, using, selling, offering for 469 | sale, or importing the Program or any portion of it. 470 | 471 | 11. Patents. 472 | 473 | A "contributor" is a copyright holder who authorizes use under this 474 | License of the Program or a work on which the Program is based. 
The 475 | work thus licensed is called the contributor's "contributor version". 476 | 477 | A contributor's "essential patent claims" are all patent claims 478 | owned or controlled by the contributor, whether already acquired or 479 | hereafter acquired, that would be infringed by some manner, permitted 480 | by this License, of making, using, or selling its contributor version, 481 | but do not include claims that would be infringed only as a 482 | consequence of further modification of the contributor version. For 483 | purposes of this definition, "control" includes the right to grant 484 | patent sublicenses in a manner consistent with the requirements of 485 | this License. 486 | 487 | Each contributor grants you a non-exclusive, worldwide, royalty-free 488 | patent license under the contributor's essential patent claims, to 489 | make, use, sell, offer for sale, import and otherwise run, modify and 490 | propagate the contents of its contributor version. 491 | 492 | In the following three paragraphs, a "patent license" is any express 493 | agreement or commitment, however denominated, not to enforce a patent 494 | (such as an express permission to practice a patent or covenant not to 495 | sue for patent infringement). To "grant" such a patent license to a 496 | party means to make such an agreement or commitment not to enforce a 497 | patent against the party. 498 | 499 | If you convey a covered work, knowingly relying on a patent license, 500 | and the Corresponding Source of the work is not available for anyone 501 | to copy, free of charge and under the terms of this License, through a 502 | publicly available network server or other readily accessible means, 503 | then you must either (1) cause the Corresponding Source to be so 504 | available, or (2) arrange to deprive yourself of the benefit of the 505 | patent license for this particular work, or (3) arrange, in a manner 506 | consistent with the requirements of this License, to extend the patent 507 | license to downstream recipients. "Knowingly relying" means you have 508 | actual knowledge that, but for the patent license, your conveying the 509 | covered work in a country, or your recipient's use of the covered work 510 | in a country, would infringe one or more identifiable patents in that 511 | country that you have reason to believe are valid. 512 | 513 | If, pursuant to or in connection with a single transaction or 514 | arrangement, you convey, or propagate by procuring conveyance of, a 515 | covered work, and grant a patent license to some of the parties 516 | receiving the covered work authorizing them to use, propagate, modify 517 | or convey a specific copy of the covered work, then the patent license 518 | you grant is automatically extended to all recipients of the covered 519 | work and works based on it. 520 | 521 | A patent license is "discriminatory" if it does not include within 522 | the scope of its coverage, prohibits the exercise of, or is 523 | conditioned on the non-exercise of one or more of the rights that are 524 | specifically granted under this License. 
You may not convey a covered 525 | work if you are a party to an arrangement with a third party that is 526 | in the business of distributing software, under which you make payment 527 | to the third party based on the extent of your activity of conveying 528 | the work, and under which the third party grants, to any of the 529 | parties who would receive the covered work from you, a discriminatory 530 | patent license (a) in connection with copies of the covered work 531 | conveyed by you (or copies made from those copies), or (b) primarily 532 | for and in connection with specific products or compilations that 533 | contain the covered work, unless you entered into that arrangement, 534 | or that patent license was granted, prior to 28 March 2007. 535 | 536 | Nothing in this License shall be construed as excluding or limiting 537 | any implied license or other defenses to infringement that may 538 | otherwise be available to you under applicable patent law. 539 | 540 | 12. No Surrender of Others' Freedom. 541 | 542 | If conditions are imposed on you (whether by court order, agreement or 543 | otherwise) that contradict the conditions of this License, they do not 544 | excuse you from the conditions of this License. If you cannot convey a 545 | covered work so as to satisfy simultaneously your obligations under this 546 | License and any other pertinent obligations, then as a consequence you may 547 | not convey it at all. For example, if you agree to terms that obligate you 548 | to collect a royalty for further conveying from those to whom you convey 549 | the Program, the only way you could satisfy both those terms and this 550 | License would be to refrain entirely from conveying the Program. 551 | 552 | 13. Use with the GNU Affero General Public License. 553 | 554 | Notwithstanding any other provision of this License, you have 555 | permission to link or combine any covered work with a work licensed 556 | under version 3 of the GNU Affero General Public License into a single 557 | combined work, and to convey the resulting work. The terms of this 558 | License will continue to apply to the part which is the covered work, 559 | but the special requirements of the GNU Affero General Public License, 560 | section 13, concerning interaction through a network will apply to the 561 | combination as such. 562 | 563 | 14. Revised Versions of this License. 564 | 565 | The Free Software Foundation may publish revised and/or new versions of 566 | the GNU General Public License from time to time. Such new versions will 567 | be similar in spirit to the present version, but may differ in detail to 568 | address new problems or concerns. 569 | 570 | Each version is given a distinguishing version number. If the 571 | Program specifies that a certain numbered version of the GNU General 572 | Public License "or any later version" applies to it, you have the 573 | option of following the terms and conditions either of that numbered 574 | version or of any later version published by the Free Software 575 | Foundation. If the Program does not specify a version number of the 576 | GNU General Public License, you may choose any version ever published 577 | by the Free Software Foundation. 578 | 579 | If the Program specifies that a proxy can decide which future 580 | versions of the GNU General Public License can be used, that proxy's 581 | public statement of acceptance of a version permanently authorizes you 582 | to choose that version for the Program. 
583 | 584 | Later license versions may give you additional or different 585 | permissions. However, no additional obligations are imposed on any 586 | author or copyright holder as a result of your choosing to follow a 587 | later version. 588 | 589 | 15. Disclaimer of Warranty. 590 | 591 | THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY 592 | APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT 593 | HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY 594 | OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, 595 | THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 596 | PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM 597 | IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF 598 | ALL NECESSARY SERVICING, REPAIR OR CORRECTION. 599 | 600 | 16. Limitation of Liability. 601 | 602 | IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING 603 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS 604 | THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY 605 | GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE 606 | USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF 607 | DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD 608 | PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), 609 | EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF 610 | SUCH DAMAGES. 611 | 612 | 17. Interpretation of Sections 15 and 16. 613 | 614 | If the disclaimer of warranty and limitation of liability provided 615 | above cannot be given local legal effect according to their terms, 616 | reviewing courts shall apply local law that most closely approximates 617 | an absolute waiver of all civil liability in connection with the 618 | Program, unless a warranty or assumption of liability accompanies a 619 | copy of the Program in return for a fee. 620 | 621 | END OF TERMS AND CONDITIONS 622 | 623 | How to Apply These Terms to Your New Programs 624 | 625 | If you develop a new program, and you want it to be of the greatest 626 | possible use to the public, the best way to achieve this is to make it 627 | free software which everyone can redistribute and change under these terms. 628 | 629 | To do so, attach the following notices to the program. It is safest 630 | to attach them to the start of each source file to most effectively 631 | state the exclusion of warranty; and each file should have at least 632 | the "copyright" line and a pointer to where the full notice is found. 633 | 634 | 635 | Copyright (C) 636 | 637 | This program is free software: you can redistribute it and/or modify 638 | it under the terms of the GNU General Public License as published by 639 | the Free Software Foundation, either version 3 of the License, or 640 | (at your option) any later version. 641 | 642 | This program is distributed in the hope that it will be useful, 643 | but WITHOUT ANY WARRANTY; without even the implied warranty of 644 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 645 | GNU General Public License for more details. 646 | 647 | You should have received a copy of the GNU General Public License 648 | along with this program. If not, see . 649 | 650 | Also add information on how to contact you by electronic and paper mail. 
651 | 652 | If the program does terminal interaction, make it output a short 653 | notice like this when it starts in an interactive mode: 654 | 655 | Copyright (C) 656 | This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'. 657 | This is free software, and you are welcome to redistribute it 658 | under certain conditions; type `show c' for details. 659 | 660 | The hypothetical commands `show w' and `show c' should show the appropriate 661 | parts of the General Public License. Of course, your program's commands 662 | might be different; for a GUI interface, you would use an "about box". 663 | 664 | You should also get your employer (if you work as a programmer) or school, 665 | if any, to sign a "copyright disclaimer" for the program, if necessary. 666 | For more information on this, and how to apply and follow the GNU GPL, see 667 | . 668 | 669 | The GNU General Public License does not permit incorporating your program 670 | into proprietary programs. If your program is a subroutine library, you 671 | may consider it more useful to permit linking proprietary applications with 672 | the library. If this is what you want to do, use the GNU Lesser General 673 | Public License instead of this License. But first, please read 674 | . 675 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # BERTOS 2 | BERTOS: transformer language model for oxidation state prediction 3 | 4 | Citation: Fu, Nihang, Jeffrey Hu, Ying Feng, Gregory Morrison, Hans‐Conrad zur Loye, and Jianjun Hu. "Composition Based Oxidation State Prediction of Materials Using Deep Learning Language Models." Advanced Science (2023): 2301011. [Link](https://onlinelibrary.wiley.com/doi/full/10.1002/advs.202301011) 5 | 6 | 7 | Nihang Fu, Jeffrey Hu, Ying Feng, Jianjun Hu*
8 | 9 | Machine Learning and Evolution Laboratory
10 | Department of Computer Science and Engineering
11 | University of South Carolina 12 | 13 | [Online Toolbox](http://www.materialsatlas.org/bertos) 14 | 15 | ## Table of Contents 16 | - [Installations](#Installations) 17 | 18 | - [Datasets](#Datasets) 19 | 20 | - [Usage](#Usage) 21 | 22 | - [Pretrained Models](#Pretrained-models) 23 | 24 | - [Performance](#Performance) 25 | 26 | - [Acknowledgement](#Acknowledgement) 27 | 28 | ## Installations 29 | 30 | 0. Set up a virtual environment 31 | ``` 32 | conda create -n bertos 33 | conda activate bertos 34 | ``` 35 | 36 | 1. PyTorch and transformers for computers with an Nvidia GPU. 37 | ``` 38 | conda install pytorch==1.8.0 torchvision==0.9.0 torchaudio==0.8.0 cudatoolkit=11.1 -c pytorch -c conda-forge 39 | conda install -c conda-forge transformers 40 | ``` 41 | If your computer only has a CPU, try this: 42 | ``` 43 | pip install transformers[torch] 44 | ``` 45 | If you are using a Mac with an M1 chip, follow [this tutorial](https://jamescalam.medium.com/hugging-face-and-sentence-transformers-on-m1-macs-4b12e40c21ce) or [this one](https://towardsdatascience.com/hugging-face-transformers-on-apple-m1-26f0705874d7) to install PyTorch and transformers. 46 | 47 | 2. Other packages 48 | ``` 49 | pip install -r requirements.txt 50 | ``` 51 | 52 | ## Datasets 53 | Training is carried out on our BERTOS datasets. After extracting the archives under the `dataset` folder, you will get four folders: `ICSD`, `ICSD_CN`, `ICSD_CN_oxide`, and `ICSD_oxide`. 54 | 55 | ## Usage 56 | ### A Quick Run 57 | Run the script below to quickly train a BERTOS model on the OS-ICSD-CN training set and save it into the `./model_icsdcn` folder. 58 | ``` 59 | bash train_BERTOS.sh 60 | ``` 61 | ### Training 62 | Use the following command to train a BERTOS model: 63 | ``` 64 | python train_BERTOS.py --config_name $CONFIG_NAME$ --dataset_name $DATASET_LOADER$ --max_length $MAX_LENGTH$ --per_device_train_batch_size $BATCH_SIZE$ --learning_rate $LEARNING_RATE$ --num_train_epochs $EPOCHS$ --output_dir $MODEL_OUTPUT_DIRECTORY$ 65 | ``` 66 | We use the `ICSD_CN` dataset as an example: 67 | ``` 68 | python train_BERTOS.py --config_name ./random_config --dataset_name materials_icsd_cn.py --max_length 100 --per_device_train_batch_size 256 --learning_rate 1e-3 --num_train_epochs 500 --output_dir ./model_icsdcn 69 | ``` 70 | To change the dataset, replace `$DATASET_LOADER$` with a different dataset loader file, such as `materials_icsd.py`, `materials_icsd_cn.py`, `materials_icsd_cno.py`, or `materials_icsd_o.py`. You can also follow the instructions of [huggingface]() to build your own custom dataset. 71 | 72 | ### Predict 73 | Run the `getOS.py` script to get predicted oxidation states for a single input formula or for a `formulas.csv` file containing multiple formulas.
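Based on the print statements and the `merge_os` helper in `getOS.py`, the script echoes the input and prints one `element(state:confidence)` term per element, merging repeated atoms of the same element; in batch mode the predictions are written to a `<input>_OS.csv` file instead. A single-formula run looks roughly like this (the confidence values are illustrative and depend on the trained model):
```
Input formula ------->  SrTiO3
Predicted Oxidation States:
  Sr(+2:1.00) Ti(+4:1.00) O3(-2:1.00)
```
Both modes are shown below.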
74 | Using the default pretrained model (trained on ICSD_CN): 75 | ``` 76 | python getOS.py --i SrTiO3 --model_name_or_path ./trained_models/ICSD_CN 77 | python getOS.py --f formulas.csv --model_name_or_path ./trained_models/ICSD_CN 78 | ``` 79 | Or using your own model: 80 | ``` 81 | python getOS.py --i SrTiO3 --model_name_or_path ./model_directory 82 | python getOS.py --f formulas.csv --model_name_or_path ./model_directory 83 | 84 | ``` 85 | 86 | ### Check charge neutrality for hypothetical formulas 87 | Run the `checkCN.py` script to check charge neutrality for a single input formula or for a `formulas.csv` file containing multiple formulas.
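`checkCN.py` makes the same per-element predictions as `getOS.py`, sums the predicted oxidation states over all atoms, and reports the formula as charge-neutral when the sum is zero; in batch mode it writes a `<input>_OS_CN.csv` file with `formula`, `predicted OS`, and `charge neutrality` columns. A single-formula run looks roughly like this (confidences are again illustrative):
```
Input formula ------->  SrTiO3
Predicted Oxidation States:
  Sr(+2:1.00) Ti(+4:1.00) O3(-2:1.00)
Charge Neutral? Yes
```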
88 | Using the default pretrained model (trained on ICSD_CN): 89 | ``` 90 | python checkCN.py --i SrTiO3 91 | python checkCN.py --f formulas.csv 92 | ``` 93 | Or using your own model: 94 | ``` 95 | python checkCN.py --i SrTiO3 --model_name_or_path ./model_directory 96 | python checkCN.py --f formulas.csv --model_name_or_path ./model_directory 97 | ``` 98 | 99 | ## Pretrained Models 100 | Our trained models can be downloaded from figshare ([BERTOS models](https://figshare.com/articles/online_resource/BERTOS_model/21554823)), and you can use them for testing or prediction. 101 | 102 | 103 | ## Performance 104 | 105 | ![Performance](performances.png) 106 | With the `OS-` prefix removed, the dataset names in the figure correspond to the dataset folders under the `dataset` folder. 107 | 108 | ## Acknowledgement 109 | We use the Transformer model implementation from Hugging Face. 110 | ``` 111 | @article{wolf2019huggingface, 112 | title={Huggingface's transformers: State-of-the-art natural language processing}, 113 | author={Wolf, Thomas and Debut, Lysandre and Sanh, Victor and Chaumond, Julien and Delangue, Clement and Moi, Anthony and Cistac, Pierric and Rault, Tim and Louf, R{\'e}mi and Funtowicz, Morgan and others}, 114 | journal={arXiv preprint arXiv:1910.03771}, 115 | year={2019} 116 | } 117 | ``` 118 | 119 | ## Cite our work 120 | ``` 121 | Fu, Nihang, Jeffrey Hu, Ying Feng, Gregory Morrison, Hans‐Conrad zur Loye, and Jianjun Hu. "Composition Based Oxidation State Prediction of Materials Using Deep Learning Language Models." Advanced Science (2023): 2301011. [PDF](https://arxiv.org/pdf/2211.15895) 122 | 123 | ``` 124 | 125 | ## Contact 126 | If you have any problems using BERTOS, feel free to contact us via [funihang@gmail.com](mailto:funihang@gmail.com). 127 | -------------------------------------------------------------------------------- /checkCN.py: -------------------------------------------------------------------------------- 1 | # for a formula: python checkCN.py --i SO2 2 | # for a CSV file containing multiple formulas: python checkCN.py --f formulas.csv 3 | 4 | import argparse 5 | import json 6 | import logging 7 | import os 8 | import torch 9 | 10 | import transformers 11 | from transformers import ( 12 | AutoConfig, 13 | AutoModelForTokenClassification, 14 | ) 15 | from transformers import BertTokenizerFast 16 | 17 | import numpy as np 18 | 19 | from pymatgen.io.cif import CifParser 20 | from pymatgen.core.composition import Composition 21 | from pymatgen.core.structure import Structure 22 | from pymatgen.core.periodic_table import Element 23 | 24 | import torch.nn.functional as F 25 | 26 | import pandas as pd 27 | 28 | def merge_os(osstr): 29 | # Merge repeated elements: "Sr(+2:1.00) Ti(+4:1.00) O(-2:1.00) O(-2:1.00) O(-2:1.00)" -> "Sr(+2:1.00) Ti(+4:1.00) O3(-2:1.00)" 30 | items = osstr.split(" ") 31 | elementos={} 32 | for x in items: 33 | if x in elementos: 34 | elementos[x]+=1 35 | else: 36 | elementos[x]=1 37 | out='' 38 | for x in elementos: 39 | if elementos[x]==1: 40 | out+=x+" " 41 | else: 42 | e=x.split('(')[0] 43 | out+=f'{e}{elementos[x]}({"".join(x.split("(")[1:])} ' 44 | return out.strip() 45 | 46 | #import pymatgen 47 | 48 | def parse_args(): 49 | parser = argparse.ArgumentParser( 50 | description="Check charge neutrality with a trained BERTOS model." 
51 | ) 52 | parser.add_argument( 53 | "--i", 54 | type=str, 55 | default=None, 56 | help="Input formula", 57 | ) 58 | 59 | parser.add_argument( 60 | "--f", 61 | type=str, 62 | default=None, 63 | help="Input file", 64 | ) 65 | 66 | parser.add_argument( 67 | "--max_length", 68 | type=int, 69 | default=50, 70 | help=( 71 | "The maximum total input sequence length after tokenization. Sequences longer than this will be truncated," 72 | " sequences shorter will be padded if `--pad_to_max_length` is passed." 73 | ), 74 | ) 75 | 76 | parser.add_argument( 77 | "--model_name_or_path", 78 | type=str, 79 | default='./trained_models/ICSD_CN/', 80 | help="Path to pretrained model or model identifier from huggingface.co/models.", 81 | required=False, 82 | ) 83 | 84 | parser.add_argument( 85 | "--tokenizer_name", 86 | type=str, 87 | default='./tokenizer', 88 | help="Pretrained tokenizer name or path if not the same as model_name", 89 | ) 90 | 91 | parser.add_argument( 92 | "--ignore_mismatched_sizes", 93 | action="store_true", 94 | default=True, 95 | help="ignore_mismatched_sizes set to True by default.", 96 | ) 97 | 98 | parser.add_argument( 99 | "--pad_to_max_length", 100 | action="store_true", 101 | help="If passed, pad all samples to `max_length`. Otherwise, dynamic padding is used.", 102 | ) 103 | args = parser.parse_args() 104 | return args 105 | 106 | def main(): 107 | args = parse_args() 108 | 109 | # Load tokenizer 110 | tokenizer_name_or_path = args.tokenizer_name 111 | tokenizer = BertTokenizerFast.from_pretrained(tokenizer_name_or_path, do_lower_case=False) 112 | 113 | padding = "max_length" if args.pad_to_max_length else False 114 | 115 | # Load model config 116 | config = AutoConfig.from_pretrained(args.model_name_or_path, num_labels=14) 117 | 118 | # Load model 119 | model = AutoModelForTokenClassification.from_pretrained( 120 | args.model_name_or_path, 121 | from_tf=bool(".ckpt" in args.model_name_or_path), 122 | config=config, 123 | ignore_mismatched_sizes=args.ignore_mismatched_sizes, 124 | ) 125 | model.eval() 126 | if (args.i is not None) and (args.f is not None): 127 | print("Please input a formula (using --i) or give the csv file with some formulas (using --f)") 128 | return 129 | 130 | if args.i is not None: 131 | print("Input formula -------> ", args.i) 132 | comp = Composition(args.i) 133 | comp_dict = comp.to_reduced_dict 134 | 135 | input_seq = "" 136 | for ele in comp_dict.keys(): 137 | for count in range(int(comp_dict[ele])): 138 | input_seq = input_seq + ele + " " 139 | 140 | 141 | tokenized_inputs = torch.tensor(tokenizer.encode(input_seq, add_special_tokens=True)).unsqueeze(0) # Batch size 1 142 | 143 | outputs = model(tokenized_inputs) 144 | predictions = outputs.logits.argmax(dim=-1) 145 | probs = torch.max(F.softmax(outputs[0], dim=-1), dim=-1) 146 | 147 | 148 | true_pred = predictions[0][1:-1] 149 | true_probs = probs[0][0][1:-1] 150 | 151 | 152 | tmp = input_seq.split() 153 | outstr = '' 154 | count_cn = 0 155 | for i, ele in enumerate(tmp): 156 | outstr += ele 157 | true_os = true_pred[i].item() - 5 158 | count_cn += true_os 159 | if true_os>0: 160 | true_os='+'+str(true_os) 161 | prob = true_probs[i].item() 162 | 163 | outstr = outstr +f'({true_os}:{prob:.2f}) ' 164 | outstr=merge_os(outstr) 165 | 166 | print("Predicted Oxidation States:\n ", outstr) 167 | 168 | if count_cn == 0: 169 | print("Charge Neutral? Yes") 170 | else: 171 | print("Charge Neutral? 
No") 172 | 173 | if args.f is not None: 174 | print("Input file ------->", args.f) 175 | df = pd.read_csv(args.f, header=None) 176 | formulas = df[0] 177 | 178 | all_outs = [] 179 | for item in formulas: 180 | comp = Composition(item) 181 | comp_dict = comp.to_reduced_dict 182 | 183 | input_seq = "" 184 | for ele in comp_dict.keys(): 185 | for count in range(int(comp_dict[ele])): 186 | input_seq = input_seq + ele + " " 187 | 188 | tokenized_inputs = torch.tensor(tokenizer.encode(input_seq, add_special_tokens=True)).unsqueeze(0) # Batch size 1 189 | 190 | 191 | outputs = model(tokenized_inputs) 192 | predictions = outputs.logits.argmax(dim=-1) 193 | probs = torch.max(F.softmax(outputs[0], dim=-1), dim=-1) 194 | 195 | 196 | true_pred = predictions[0][1:-1] 197 | true_probs = probs[0][0][1:-1] 198 | 199 | tmp = input_seq.split() 200 | 201 | cn_count = 0 202 | outstr = '' 203 | for i, ele in enumerate(tmp): 204 | outstr += ele 205 | true_os = true_pred[i].item() - 5 206 | 207 | cn_count += true_os 208 | 209 | if true_os>0: 210 | true_os='+'+str(true_os) 211 | prob = true_probs[i].item() 212 | 213 | outstr = outstr +f'({true_os}:{prob:.2f}) ' 214 | outstr=merge_os(outstr) 215 | 216 | if cn_count == 0: 217 | all_outs.append([item, outstr, "True"]) 218 | else: 219 | all_outs.append([item, outstr, "False"]) 220 | 221 | out_df = pd.DataFrame(all_outs) 222 | out_df.columns = ["formula", "predicted OS", "charge neutrality"] 223 | 224 | #add _OS to the input filename as output file 225 | outfile='.'.join(args.f.split(".")[0:-1])+"_OS_CN."+args.f.split(".")[-1] 226 | 227 | out_df.to_csv(outfile, index=None) 228 | print("Output file ------>",f"{outfile} <-- check for the predicted oxidation states") 229 | 230 | 231 | if __name__ == "__main__": 232 | main() 233 | -------------------------------------------------------------------------------- /dataset/ICSD.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/usccolumbia/BERTOS/ab4e4f31b09543a7949f36981f1c87b9ff41bb74/dataset/ICSD.zip -------------------------------------------------------------------------------- /dataset/ICSD_CN.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/usccolumbia/BERTOS/ab4e4f31b09543a7949f36981f1c87b9ff41bb74/dataset/ICSD_CN.zip -------------------------------------------------------------------------------- /dataset/ICSD_CN_oxide.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/usccolumbia/BERTOS/ab4e4f31b09543a7949f36981f1c87b9ff41bb74/dataset/ICSD_CN_oxide.zip -------------------------------------------------------------------------------- /dataset/ICSD_oxide.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/usccolumbia/BERTOS/ab4e4f31b09543a7949f36981f1c87b9ff41bb74/dataset/ICSD_oxide.zip -------------------------------------------------------------------------------- /dataset/README.md: -------------------------------------------------------------------------------- 1 | 2 | just double click the zip files to unzip the datasets. 
3 | -------------------------------------------------------------------------------- /formulas.csv: -------------------------------------------------------------------------------- 1 | SrTiO3 2 | LiMnO3 3 | Te2As 4 | CdP3Sr3 -------------------------------------------------------------------------------- /getOS.py: -------------------------------------------------------------------------------- 1 | # for a formula: python getOS.py --i SO2 2 | # for a CSV file containing multiple formulas: python getOS.py --f formulas.csv 3 | 4 | import argparse 5 | import json 6 | import logging 7 | import os 8 | import torch 9 | 10 | import transformers 11 | from transformers import ( 12 | AutoConfig, 13 | AutoModelForTokenClassification, 14 | ) 15 | from transformers import BertTokenizerFast 16 | 17 | import numpy as np 18 | 19 | from pymatgen.io.cif import CifParser 20 | from pymatgen.core.composition import Composition 21 | from pymatgen.core.structure import Structure 22 | from pymatgen.core.periodic_table import Element 23 | 24 | import torch.nn.functional as F 25 | 26 | import pandas as pd 27 | 28 | def merge_os(osstr): 29 | # Merge repeated elements: "Sr(+2:1.00) Ti(+4:1.00) O(-2:1.00) O(-2:1.00) O(-2:1.00)" -> "Sr(+2:1.00) Ti(+4:1.00) O3(-2:1.00)" 30 | items = osstr.split(" ") 31 | elementos={} 32 | for x in items: 33 | if x in elementos: 34 | elementos[x]+=1 35 | else: 36 | elementos[x]=1 37 | out='' 38 | for x in elementos: 39 | if elementos[x]==1: 40 | out+=x+" " 41 | else: 42 | e=x.split('(')[0] 43 | out+=f'{e}{elementos[x]}({"".join(x.split("(")[1:])} ' 44 | return out.strip() 45 | 46 | #import pymatgen 47 | 48 | def parse_args(): 49 | parser = argparse.ArgumentParser( 50 | description="Predict oxidation states with a trained BERTOS model." 51 | ) 52 | parser.add_argument( 53 | "--i", 54 | type=str, 55 | default=None, 56 | help="Input formula", 57 | ) 58 | 59 | parser.add_argument( 60 | "--f", 61 | type=str, 62 | default=None, 63 | help="Input file", 64 | ) 65 | 66 | parser.add_argument( 67 | "--max_length", 68 | type=int, 69 | default=50, 70 | help=( 71 | "The maximum total input sequence length after tokenization. Sequences longer than this will be truncated," 72 | " sequences shorter will be padded if `--pad_to_max_length` is passed." 73 | ), 74 | ) 75 | 76 | parser.add_argument( 77 | "--model_name_or_path", 78 | type=str, 79 | default='./trained_models/ICSD_CN/', 80 | help="Path to pretrained model or model identifier from huggingface.co/models.", 81 | required=False, 82 | ) 83 | 84 | parser.add_argument( 85 | "--tokenizer_name", 86 | type=str, 87 | default='./tokenizer', 88 | help="Pretrained tokenizer name or path if not the same as model_name", 89 | ) 90 | 91 | parser.add_argument( 92 | "--ignore_mismatched_sizes", 93 | action="store_true", 94 | default=True, 95 | help="ignore_mismatched_sizes set to True by default.", 96 | ) 97 | 98 | parser.add_argument( 99 | "--pad_to_max_length", 100 | action="store_true", 101 | help="If passed, pad all samples to `max_length`. 
Otherwise, dynamic padding is used.", 102 | ) 103 | args = parser.parse_args() 104 | return args 105 | 106 | def main(): 107 | args = parse_args() 108 | 109 | # Load tokenizer 110 | tokenizer_name_or_path = args.tokenizer_name 111 | tokenizer = BertTokenizerFast.from_pretrained(tokenizer_name_or_path, do_lower_case=False) 112 | 113 | padding = "max_length" if args.pad_to_max_length else False 114 | 115 | # Load model config 116 | config = AutoConfig.from_pretrained(args.model_name_or_path, num_labels=14) 117 | 118 | # Load model 119 | model = AutoModelForTokenClassification.from_pretrained( 120 | args.model_name_or_path, 121 | from_tf=bool(".ckpt" in args.model_name_or_path), 122 | config=config, 123 | ignore_mismatched_sizes=args.ignore_mismatched_sizes, 124 | ) 125 | 126 | if (args.i is not None) and (args.f is not None): 127 | print("Please input a formula (using --i) or give the csv file with some formulas (using --f)") 128 | return 129 | 130 | if args.i is not None: 131 | print("Input formula -------> ", args.i) 132 | comp = Composition(args.i) 133 | comp_dict = comp.to_reduced_dict 134 | 135 | input_seq = "" 136 | for ele in comp_dict.keys(): 137 | for count in range(int(comp_dict[ele])): 138 | input_seq = input_seq + ele + " " 139 | 140 | 141 | tokenized_inputs = torch.tensor(tokenizer.encode(input_seq, add_special_tokens=True)).unsqueeze(0) # Batch size 1 142 | 143 | outputs = model(tokenized_inputs) 144 | predictions = outputs.logits.argmax(dim=-1) 145 | probs = torch.max(F.softmax(outputs[0], dim=-1), dim=-1) 146 | 147 | 148 | true_pred = predictions[0][1:-1] 149 | true_probs = probs[0][0][1:-1] 150 | 151 | 152 | tmp = input_seq.split() 153 | outstr = '' 154 | for i, ele in enumerate(tmp): 155 | outstr += ele 156 | true_os = true_pred[i].item() - 5 157 | if true_os>0: 158 | true_os='+'+str(true_os) 159 | prob = true_probs[i].item() 160 | 161 | outstr = outstr +f'({true_os}:{prob:.2f}) ' 162 | outstr=merge_os(outstr) 163 | print("Predicted Oxidation States:\n ", outstr) 164 | 165 | if args.f is not None: 166 | print("Input file ------->", args.f) 167 | df = pd.read_csv(args.f, header=None) 168 | formulas = df[0] 169 | 170 | all_outs = [] 171 | for item in formulas: 172 | comp = Composition(item) 173 | comp_dict = comp.to_reduced_dict 174 | 175 | input_seq = "" 176 | for ele in comp_dict.keys(): 177 | for count in range(int(comp_dict[ele])): 178 | input_seq = input_seq + ele + " " 179 | 180 | tokenized_inputs = torch.tensor(tokenizer.encode(input_seq, add_special_tokens=True)).unsqueeze(0) # Batch size 1 181 | 182 | 183 | outputs = model(tokenized_inputs) 184 | predictions = outputs.logits.argmax(dim=-1) 185 | probs = torch.max(F.softmax(outputs[0], dim=-1), dim=-1) 186 | 187 | 188 | true_pred = predictions[0][1:-1] 189 | true_probs = probs[0][0][1:-1] 190 | 191 | tmp = input_seq.split() 192 | outstr = '' 193 | for i, ele in enumerate(tmp): 194 | outstr += ele 195 | true_os = true_pred[i].item() - 5 196 | if true_os>0: 197 | true_os='+'+str(true_os) 198 | prob = true_probs[i].item() 199 | 200 | outstr = outstr +f'({true_os}:{prob:.2f}) ' 201 | outstr=merge_os(outstr) 202 | 203 | 204 | all_outs.append(outstr) 205 | 206 | out_df = pd.DataFrame(all_outs) 207 | 208 | #add _OS to the input filename as output file 209 | outfile='.'.join(args.f.split(".")[0:-1])+"_OS."+args.f.split(".")[-1] 210 | 211 | out_df.to_csv(outfile, header=None, index=None) 212 | print("Output file ------>",f"{outfile} <-- check for the predicted oxidation states") 213 | 214 | 215 | if __name__ == "__main__": 216 
| main() 217 | -------------------------------------------------------------------------------- /materials_icsd.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Materials dataset""" 16 | 17 | import os 18 | import datasets 19 | 20 | 21 | logger = datasets.logging.get_logger(__name__) 22 | 23 | 24 | _CITATION = """ 25 | """ 26 | 27 | _DESCRIPTION = """ 28 | """ 29 | 30 | _ROOT = "./dataset/ICSD/" 31 | _TRAINING_FILE = "train.txt" 32 | _DEV_FILE = "validation.txt" 33 | _TEST_FILE = "test.txt" 34 | 35 | 36 | class Materials(datasets.GeneratorBasedBuilder): 37 | """Materials dataset""" 38 | 39 | VERSION = datasets.Version("1.0.0") 40 | 41 | BUILDER_CONFIGS = [ 42 | datasets.BuilderConfig(name="materials", version=VERSION, description="Materials dataset"), 43 | ] 44 | 45 | def _info(self): 46 | return datasets.DatasetInfo( 47 | description=_DESCRIPTION, 48 | features=datasets.Features( 49 | { 50 | "id": datasets.Value("string"), 51 | "tokens": datasets.Sequence(datasets.Value("string")), 52 | "ner_tags": datasets.Sequence( 53 | datasets.features.ClassLabel( 54 | names=[ 55 | "-5", 56 | "-4", 57 | "-3", 58 | "-2", 59 | "-1", 60 | "0", 61 | "1", 62 | "2", 63 | "3", 64 | "4", 65 | "5", 66 | "6", 67 | "7", 68 | "8", 69 | ] 70 | ) 71 | ), 72 | } 73 | ), 74 | supervised_keys=None, 75 | homepage="https://github.com/usccolumbia/BERTOS.git", 76 | citation=_CITATION, 77 | ) 78 | 79 | def _split_generators(self, dl_manager): 80 | """Returns SplitGenerators.""" 81 | 82 | data_files = { 83 | "train": os.path.join(_ROOT, _TRAINING_FILE), 84 | "validation": os.path.join(_ROOT, _DEV_FILE), 85 | "test": os.path.join(_ROOT, _TEST_FILE), 86 | } 87 | 88 | return [ 89 | datasets.SplitGenerator( 90 | name=datasets.Split.TRAIN, 91 | gen_kwargs={"filepath": data_files["train"], "split": "train"}, 92 | ), 93 | datasets.SplitGenerator( 94 | name=datasets.Split.VALIDATION, 95 | gen_kwargs={"filepath": data_files["validation"], "split": "validation"}, 96 | ), 97 | datasets.SplitGenerator( 98 | name=datasets.Split.TEST, 99 | gen_kwargs={"filepath": data_files["test"], "split": "test"}, 100 | ), 101 | ] 102 | 103 | def _generate_examples(self, filepath, split): 104 | """Yields examples.""" 105 | 106 | with open(filepath, encoding="utf-8") as f: 107 | 108 | guid = 0 109 | tokens = [] 110 | ner_tags = [] 111 | 112 | for line in f: 113 | if line == "" or line == "\n": 114 | if tokens: 115 | yield guid, { 116 | "id": str(guid), 117 | "tokens": tokens, 118 | "ner_tags": ner_tags, 119 | } 120 | guid += 1 121 | tokens = [] 122 | ner_tags = [] 123 | else: 124 | splits = line.split(" ") 125 | tokens.append(splits[0]) 126 | ner_tags.append(splits[1].rstrip()) 127 | 128 | # last example 129 | yield guid, { 130 | "id": str(guid), 131 | "tokens": tokens, 132 | "ner_tags": ner_tags, 133 | } 134 | 
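
Note: all four materials_icsd*.py builder scripts parse the same two-column text format: one element token and one oxidation-state tag per line, separated by a single space, with a blank line between formulas. (Also note that the final yield after the read loop is unguarded, so a data file ending in a blank line emits one empty trailing example.) The snippet below is a minimal sketch of that format reusing the same parsing logic as _generate_examples above; the SrTiO3 lines are illustrative, as the real train/validation/test files ship inside the dataset/*.zip archives.

# format_demo.py -- illustrative sketch of the two-column format parsed by _generate_examples
sample = "Sr 2\nTi 4\nO -2\nO -2\nO -2\n"  # element + oxidation-state tag per line; blank lines separate formulas
tokens, ner_tags = [], []
for line in sample.splitlines():
    if line.strip():
        splits = line.split(" ")
        tokens.append(splits[0])           # element symbol, e.g. "Sr"
        ner_tags.append(splits[1].rstrip())  # oxidation-state tag, e.g. "2"
print(tokens)    # ['Sr', 'Ti', 'O', 'O', 'O']
print(ner_tags)  # ['2', '4', '-2', '-2', '-2']

Once dataset/ICSD.zip is unzipped to ./dataset/ICSD/, the builder can be exercised directly with load_dataset("materials_icsd.py"), which is exactly how train_BERTOS.py consumes it via --dataset_name.
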
-------------------------------------------------------------------------------- /materials_icsd_cn.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Materials dataset""" 16 | 17 | import os 18 | import datasets 19 | 20 | 21 | logger = datasets.logging.get_logger(__name__) 22 | 23 | 24 | _CITATION = """ 25 | """ 26 | 27 | _DESCRIPTION = """ 28 | """ 29 | 30 | _ROOT = "./dataset/ICSD_CN/" 31 | _TRAINING_FILE = "train.txt" 32 | _DEV_FILE = "validation.txt" 33 | _TEST_FILE = "test.txt" 34 | 35 | 36 | class Materials(datasets.GeneratorBasedBuilder): 37 | """Materials dataset""" 38 | 39 | VERSION = datasets.Version("1.0.0") 40 | 41 | BUILDER_CONFIGS = [ 42 | datasets.BuilderConfig(name="materials", version=VERSION, description="Materials dataset"), 43 | ] 44 | 45 | def _info(self): 46 | return datasets.DatasetInfo( 47 | description=_DESCRIPTION, 48 | features=datasets.Features( 49 | { 50 | "id": datasets.Value("string"), 51 | "tokens": datasets.Sequence(datasets.Value("string")), 52 | "ner_tags": datasets.Sequence( 53 | datasets.features.ClassLabel( 54 | names=[ 55 | "-5", 56 | "-4", 57 | "-3", 58 | "-2", 59 | "-1", 60 | "0", 61 | "1", 62 | "2", 63 | "3", 64 | "4", 65 | "5", 66 | "6", 67 | "7", 68 | "8", 69 | ] 70 | ) 71 | ), 72 | } 73 | ), 74 | supervised_keys=None, 75 | homepage="https://github.com/usccolumbia/BERTOS.git", 76 | citation=_CITATION, 77 | ) 78 | 79 | def _split_generators(self, dl_manager): 80 | """Returns SplitGenerators.""" 81 | 82 | data_files = { 83 | "train": os.path.join(_ROOT, _TRAINING_FILE), 84 | "validation": os.path.join(_ROOT, _DEV_FILE), 85 | "test": os.path.join(_ROOT, _TEST_FILE), 86 | } 87 | 88 | return [ 89 | datasets.SplitGenerator( 90 | name=datasets.Split.TRAIN, 91 | gen_kwargs={"filepath": data_files["train"], "split": "train"}, 92 | ), 93 | datasets.SplitGenerator( 94 | name=datasets.Split.VALIDATION, 95 | gen_kwargs={"filepath": data_files["validation"], "split": "validation"}, 96 | ), 97 | datasets.SplitGenerator( 98 | name=datasets.Split.TEST, 99 | gen_kwargs={"filepath": data_files["test"], "split": "test"}, 100 | ), 101 | ] 102 | 103 | def _generate_examples(self, filepath, split): 104 | """Yields examples.""" 105 | 106 | with open(filepath, encoding="utf-8") as f: 107 | 108 | guid = 0 109 | tokens = [] 110 | ner_tags = [] 111 | 112 | for line in f: 113 | if line == "" or line == "\n": 114 | if tokens: 115 | yield guid, { 116 | "id": str(guid), 117 | "tokens": tokens, 118 | "ner_tags": ner_tags, 119 | } 120 | guid += 1 121 | tokens = [] 122 | ner_tags = [] 123 | else: 124 | splits = line.split(" ") 125 | tokens.append(splits[0]) 126 | ner_tags.append(splits[1].rstrip()) 127 | 128 | # last example 129 | yield guid, { 130 | "id": str(guid), 131 | "tokens": tokens, 132 | "ner_tags": ner_tags, 133 | } 134 | 
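
Note: every dataset variant declares the same 14 ClassLabel names, "-5" through "8", and the datasets library maps them to integer ids 0-13 in list order. That mapping is the origin of the `true_os = true_pred[i].item() - 5` offset in getOS.py and of num_labels=14 used throughout the repo. A small self-contained check (the names list below simply restates the one in the scripts above):

# label_demo.py -- why getOS.py subtracts 5 from each predicted class id
from datasets import ClassLabel

names = [str(i) for i in range(-5, 9)]  # the 14 tag names declared in every materials_icsd*.py script
labels = ClassLabel(names=names)
assert labels.num_classes == 14         # matches num_labels=14 in getOS.py and train_BERTOS.py
assert labels.str2int("2") == 7         # oxidation state +2 is class id 7
assert int(labels.int2str(7)) == 7 - 5  # hence: oxidation state = predicted class id - 5
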
-------------------------------------------------------------------------------- /materials_icsd_cno.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Materials dataset""" 16 | 17 | import os 18 | import datasets 19 | 20 | 21 | logger = datasets.logging.get_logger(__name__) 22 | 23 | 24 | _CITATION = """ 25 | """ 26 | 27 | _DESCRIPTION = """ 28 | """ 29 | 30 | _ROOT = "./dataset/ICSD_CN_oxide/" 31 | _TRAINING_FILE = "train.txt" 32 | _DEV_FILE = "validation.txt" 33 | _TEST_FILE = "test.txt" 34 | 35 | 36 | class Materials(datasets.GeneratorBasedBuilder): 37 | """Materials dataset""" 38 | 39 | VERSION = datasets.Version("1.0.0") 40 | 41 | BUILDER_CONFIGS = [ 42 | datasets.BuilderConfig(name="materials", version=VERSION, description="Materials dataset"), 43 | ] 44 | 45 | def _info(self): 46 | return datasets.DatasetInfo( 47 | description=_DESCRIPTION, 48 | features=datasets.Features( 49 | { 50 | "id": datasets.Value("string"), 51 | "tokens": datasets.Sequence(datasets.Value("string")), 52 | "ner_tags": datasets.Sequence( 53 | datasets.features.ClassLabel( 54 | names=[ 55 | "-5", 56 | "-4", 57 | "-3", 58 | "-2", 59 | "-1", 60 | "0", 61 | "1", 62 | "2", 63 | "3", 64 | "4", 65 | "5", 66 | "6", 67 | "7", 68 | "8", 69 | ] 70 | ) 71 | ), 72 | } 73 | ), 74 | supervised_keys=None, 75 | homepage="https://github.com/usccolumbia/BERTOS.git", 76 | citation=_CITATION, 77 | ) 78 | 79 | def _split_generators(self, dl_manager): 80 | """Returns SplitGenerators.""" 81 | 82 | data_files = { 83 | "train": os.path.join(_ROOT, _TRAINING_FILE), 84 | "validation": os.path.join(_ROOT, _DEV_FILE), 85 | "test": os.path.join(_ROOT, _TEST_FILE), 86 | } 87 | 88 | return [ 89 | datasets.SplitGenerator( 90 | name=datasets.Split.TRAIN, 91 | gen_kwargs={"filepath": data_files["train"], "split": "train"}, 92 | ), 93 | datasets.SplitGenerator( 94 | name=datasets.Split.VALIDATION, 95 | gen_kwargs={"filepath": data_files["validation"], "split": "validation"}, 96 | ), 97 | datasets.SplitGenerator( 98 | name=datasets.Split.TEST, 99 | gen_kwargs={"filepath": data_files["test"], "split": "test"}, 100 | ), 101 | ] 102 | 103 | def _generate_examples(self, filepath, split): 104 | """Yields examples.""" 105 | 106 | #logger.info("? 
Generating examples from = %s", filepath) 107 | 108 | with open(filepath, encoding="utf-8") as f: 109 | 110 | guid = 0 111 | tokens = [] 112 | ner_tags = [] 113 | 114 | for line in f: 115 | if line == "" or line == "\n": 116 | if tokens: 117 | yield guid, { 118 | "id": str(guid), 119 | "tokens": tokens, 120 | "ner_tags": ner_tags, 121 | } 122 | guid += 1 123 | tokens = [] 124 | ner_tags = [] 125 | else: 126 | splits = line.split(" ") 127 | tokens.append(splits[0]) 128 | ner_tags.append(splits[1].rstrip()) 129 | 130 | # last example 131 | yield guid, { 132 | "id": str(guid), 133 | "tokens": tokens, 134 | "ner_tags": ner_tags, 135 | } 136 | -------------------------------------------------------------------------------- /materials_icsd_o.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Materials dataset""" 16 | 17 | import os 18 | import datasets 19 | 20 | 21 | logger = datasets.logging.get_logger(__name__) 22 | 23 | 24 | _CITATION = """ 25 | """ 26 | 27 | _DESCRIPTION = """ 28 | """ 29 | 30 | _ROOT = "./dataset/ICSD_oxide/" 31 | _TRAINING_FILE = "train.txt" 32 | _DEV_FILE = "validation.txt" 33 | _TEST_FILE = "test.txt" 34 | 35 | 36 | class Materials(datasets.GeneratorBasedBuilder): 37 | """Materials dataset""" 38 | 39 | VERSION = datasets.Version("1.0.0") 40 | 41 | BUILDER_CONFIGS = [ 42 | datasets.BuilderConfig(name="materials", version=VERSION, description="Materials dataset"), 43 | ] 44 | 45 | def _info(self): 46 | return datasets.DatasetInfo( 47 | description=_DESCRIPTION, 48 | features=datasets.Features( 49 | { 50 | "id": datasets.Value("string"), 51 | "tokens": datasets.Sequence(datasets.Value("string")), 52 | "ner_tags": datasets.Sequence( 53 | datasets.features.ClassLabel( 54 | names=[ 55 | "-5", 56 | "-4", 57 | "-3", 58 | "-2", 59 | "-1", 60 | "0", 61 | "1", 62 | "2", 63 | "3", 64 | "4", 65 | "5", 66 | "6", 67 | "7", 68 | "8", 69 | ] 70 | ) 71 | ), 72 | } 73 | ), 74 | supervised_keys=None, 75 | homepage="https://github.com/usccolumbia/BERTOS.git", 76 | citation=_CITATION, 77 | ) 78 | 79 | def _split_generators(self, dl_manager): 80 | """Returns SplitGenerators.""" 81 | 82 | data_files = { 83 | "train": os.path.join(_ROOT, _TRAINING_FILE), 84 | "validation": os.path.join(_ROOT, _DEV_FILE), 85 | "test": os.path.join(_ROOT, _TEST_FILE), 86 | } 87 | 88 | return [ 89 | datasets.SplitGenerator( 90 | name=datasets.Split.TRAIN, 91 | gen_kwargs={"filepath": data_files["train"], "split": "train"}, 92 | ), 93 | datasets.SplitGenerator( 94 | name=datasets.Split.VALIDATION, 95 | gen_kwargs={"filepath": data_files["validation"], "split": "validation"}, 96 | ), 97 | datasets.SplitGenerator( 98 | name=datasets.Split.TEST, 99 | gen_kwargs={"filepath": data_files["test"], "split": "test"}, 100 | ), 101 | ] 102 | 103 | def _generate_examples(self, filepath, 
split): 104 | """Yields examples.""" 105 | 106 | with open(filepath, encoding="utf-8") as f: 107 | 108 | guid = 0 109 | tokens = [] 110 | ner_tags = [] 111 | 112 | for line in f: 113 | if line == "" or line == "\n": 114 | if tokens: 115 | yield guid, { 116 | "id": str(guid), 117 | "tokens": tokens, 118 | "ner_tags": ner_tags, 119 | } 120 | guid += 1 121 | tokens = [] 122 | ner_tags = [] 123 | else: 124 | splits = line.split(" ") 125 | tokens.append(splits[0]) 126 | ner_tags.append(splits[1].rstrip()) 127 | 128 | # last example 129 | yield guid, { 130 | "id": str(guid), 131 | "tokens": tokens, 132 | "ner_tags": ner_tags, 133 | } 134 | -------------------------------------------------------------------------------- /performances.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/usccolumbia/BERTOS/ab4e4f31b09543a7949f36981f1c87b9ff41bb74/performances.png -------------------------------------------------------------------------------- /random_config/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "BertForMaskedLM" 4 | ], 5 | "attention_probs_dropout_prob": 0.1, 6 | "classifier_dropout": null, 7 | "hidden_act": "gelu", 8 | "hidden_dropout_prob": 0.1, 9 | "hidden_size": 120, 10 | "initializer_range": 0.02, 11 | "intermediate_size": 512, 12 | "layer_norm_eps": 1e-12, 13 | "max_position_embeddings": 100, 14 | "model_type": "bert", 15 | "num_attention_heads": 4, 16 | "num_hidden_layers": 12, 17 | "pad_token_id": 0, 18 | "position_embedding_type": "absolute", 19 | "torch_dtype": "float32", 20 | "transformers_version": "4.23.0.dev0", 21 | "type_vocab_size": 2, 22 | "use_cache": true, 23 | "vocab_size": 123 24 | } 25 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pandas==1.3.5 2 | argparse==1.4.0 3 | pymatgen==2022.0.17 4 | datasets==2.5.1 5 | tqdm==4.64.1 6 | accelerate==0.12.0 7 | evaluate==0.2.2 8 | transformers 9 | seqeval 10 | tensorboard 11 | -------------------------------------------------------------------------------- /tokenizer/vocab.txt: -------------------------------------------------------------------------------- 1 | [PAD] 2 | [UNK] 3 | [CLS] 4 | [SEP] 5 | [MASK] 6 | H 7 | He 8 | Li 9 | Be 10 | B 11 | C 12 | N 13 | O 14 | F 15 | Ne 16 | Na 17 | Mg 18 | Al 19 | Si 20 | P 21 | S 22 | Cl 23 | Ar 24 | K 25 | Ca 26 | Sc 27 | Ti 28 | V 29 | Cr 30 | Mn 31 | Fe 32 | Co 33 | Ni 34 | Cu 35 | Zn 36 | Ga 37 | Ge 38 | As 39 | Se 40 | Br 41 | Kr 42 | Rb 43 | Sr 44 | Y 45 | Zr 46 | Nb 47 | Mo 48 | Tc 49 | Ru 50 | Rh 51 | Pd 52 | Ag 53 | Cd 54 | In 55 | Sn 56 | Sb 57 | Te 58 | I 59 | Xe 60 | Cs 61 | Ba 62 | La 63 | Ce 64 | Pr 65 | Nd 66 | Pm 67 | Sm 68 | Eu 69 | Gd 70 | Tb 71 | Dy 72 | Ho 73 | Er 74 | Tm 75 | Yb 76 | Lu 77 | Hf 78 | Ta 79 | W 80 | Re 81 | Os 82 | Ir 83 | Pt 84 | Au 85 | Hg 86 | Tl 87 | Pb 88 | Bi 89 | Po 90 | At 91 | Rn 92 | Fr 93 | Ra 94 | Ac 95 | Th 96 | Pa 97 | U 98 | Np 99 | Pu 100 | Am 101 | Cm 102 | Bk 103 | Cf 104 | Es 105 | Fm 106 | Md 107 | No 108 | Lr 109 | Rf 110 | Db 111 | Sg 112 | Bh 113 | Hs 114 | Mt 115 | Ds 116 | Rg 117 | Cn 118 | Nh 119 | Fl 120 | Mc 121 | Lv 122 | Ts 123 | Og -------------------------------------------------------------------------------- /train_BERTOS.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The 
HuggingFace Inc. team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | Train BERTOS 16 | """ 17 | 18 | import argparse 19 | import json 20 | import logging 21 | import math 22 | import os 23 | import random 24 | from pathlib import Path 25 | 26 | import datasets 27 | import torch 28 | from datasets import ClassLabel, load_dataset 29 | from torch.utils.data import DataLoader 30 | from tqdm.auto import tqdm 31 | 32 | import evaluate 33 | import transformers 34 | from accelerate import Accelerator 35 | from accelerate.logging import get_logger 36 | from accelerate.utils import set_seed 37 | from huggingface_hub import Repository 38 | from transformers import ( 39 | AutoConfig, 40 | AutoModelForTokenClassification, 41 | DataCollatorForTokenClassification, 42 | PretrainedConfig, 43 | SchedulerType, 44 | default_data_collator, 45 | get_scheduler, 46 | ) 47 | from transformers.utils import check_min_version, get_full_repo_name, send_example_telemetry 48 | from transformers.utils.versions import require_version 49 | 50 | from transformers import BertTokenizerFast 51 | 52 | # Will error if the minimal version of Transformers is not installed. Remove at your own risks. 53 | check_min_version("4.23.0.dev0") 54 | 55 | logger = get_logger(__name__) 56 | require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/token-classification/requirements.txt") 57 | 58 | 59 | def parse_args(): 60 | parser = argparse.ArgumentParser( 61 | description="Train BERTOS" 62 | ) 63 | parser.add_argument( 64 | "--dataset_name", 65 | type=str, 66 | default=None, 67 | help="The name of the dataset to use (via the datasets library).", 68 | ) 69 | parser.add_argument( 70 | "--text_column_name", 71 | type=str, 72 | default=None, 73 | help="The column name of text to input in the file (a csv or JSON file).", 74 | ) 75 | parser.add_argument( 76 | "--label_column_name", 77 | type=str, 78 | default=None, 79 | help="The column name of label to input in the file (a csv or JSON file).", 80 | ) 81 | parser.add_argument( 82 | "--max_length", 83 | type=int, 84 | default=128, 85 | help=( 86 | "The maximum total input sequence length after tokenization. Sequences longer than this will be truncated," 87 | " sequences shorter will be padded if `--pad_to_max_length` is passed." 88 | ), 89 | ) 90 | parser.add_argument( 91 | "--pad_to_max_length", 92 | action="store_true", 93 | help="If passed, pad all samples to `max_length`. 
Otherwise, dynamic padding is used.", 93 | ) 94 | parser.add_argument( 95 | "--model_name_or_path", 96 | type=str, 97 | help="Path to pretrained model or model identifier from huggingface.co/models.", 98 | required=False, 99 | ) 100 | parser.add_argument( 101 | "--config_name", 102 | type=str, 103 | default=None, 104 | help="Pretrained config name or path if not the same as model_name", 105 | ) 106 | parser.add_argument( 107 | "--tokenizer_name", 108 | type=str, 109 | default='./tokenizer', 110 | help="Pretrained tokenizer name or path if not the same as model_name", 111 | ) 112 | parser.add_argument( 113 | "--per_device_train_batch_size", 114 | type=int, 115 | default=8, 116 | help="Batch size (per device) for the training dataloader.", 117 | ) 118 | parser.add_argument( 119 | "--per_device_eval_batch_size", 120 | type=int, 121 | default=8, 122 | help="Batch size (per device) for the evaluation dataloader.", 123 | ) 124 | parser.add_argument( 125 | "--learning_rate", 126 | type=float, 127 | default=5e-5, 128 | help="Initial learning rate (after the potential warmup period) to use.", 129 | ) 130 | parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.") 131 | parser.add_argument("--num_train_epochs", type=int, default=3, help="Total number of training epochs to perform.") 132 | parser.add_argument( 133 | "--max_train_steps", 134 | type=int, 135 | default=None, 136 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.", 137 | ) 138 | parser.add_argument( 139 | "--gradient_accumulation_steps", 140 | type=int, 141 | default=1, 142 | help="Number of update steps to accumulate before performing a backward/update pass.", 143 | ) 144 | parser.add_argument( 145 | "--lr_scheduler_type", 146 | type=SchedulerType, 147 | default="linear", 148 | help="The scheduler type to use.", 149 | choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"], 150 | ) 151 | parser.add_argument( 152 | "--num_warmup_steps", type=int, default=0, help="Number of steps for the warmup in the lr scheduler." 153 | ) 154 | parser.add_argument("--output_dir", type=str, default=None, help="Where to store the final model.") 155 | parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") 156 | parser.add_argument( 157 | "--label_all_tokens", 158 | action="store_true", 159 | help="If passed, label every sub-token of a word with the word's label; otherwise only the first sub-token is labeled and the rest are set to -100.", 160 | ) 161 | parser.add_argument( 162 | "--return_entity_level_metrics", 163 | action="store_true", 164 | help="Whether entity-level metrics are to be returned.", 165 | ) 166 | parser.add_argument( 167 | "--task_name", 168 | type=str, 169 | default="ner", 170 | choices=["ner", "pos", "chunk"], 171 | help="The name of the task.", 172 | ) 173 | parser.add_argument( 174 | "--debug", 175 | action="store_true", 176 | help="Activate debug mode and run training only with a subset of data.", 177 | ) 178 | parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") 179 | parser.add_argument( 180 | "--hub_model_id", type=str, help="The name of the repository to keep in sync with the local `output_dir`."
182 | ) 183 | parser.add_argument("--hub_token", type=str, help="The token to use to push to the Model Hub.") 184 | parser.add_argument( 185 | "--checkpointing_steps", 186 | type=str, 187 | default=None, 188 | help="Whether the various states should be saved at the end of every n steps, or 'epoch' for each epoch.", 189 | ) 190 | parser.add_argument( 191 | "--resume_from_checkpoint", 192 | type=str, 193 | default=None, 194 | help="If the training should continue from a checkpoint folder.", 195 | ) 196 | parser.add_argument( 197 | "--with_tracking", 198 | action="store_true", 199 | help="Whether to enable experiment trackers for logging.", 200 | ) 201 | parser.add_argument( 202 | "--report_to", 203 | type=str, 204 | default="all", 205 | help=( 206 | 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,' 207 | ' `"wandb"` and `"comet_ml"`. Use `"all"` (default) to report to all integrations.' 208 | "Only applicable when `--with_tracking` is passed." 209 | ), 210 | ) 211 | parser.add_argument( 212 | "--ignore_mismatched_sizes", 213 | action="store_true", 214 | help="Whether or not to enable to load a pretrained model whose head dimensions are different.", 215 | ) 216 | args = parser.parse_args() 217 | 218 | # Sanity checks 219 | if args.push_to_hub: 220 | assert args.output_dir is not None, "Need an `output_dir` to create a repo when `--push_to_hub` is passed." 221 | 222 | return args 223 | 224 | 225 | def main(): 226 | args = parse_args() 227 | 228 | # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The 229 | # information sent is the one passed as arguments along with your Python/PyTorch versions. 230 | send_example_telemetry("run_ner_no_trainer", args) 231 | 232 | # Initialize the accelerator. We will let the accelerator handle device placement for us in this example. 233 | # If we're using tracking, we also need to initialize it here and it will by default pick up all supported trackers 234 | # in the environment 235 | accelerator = ( 236 | Accelerator(log_with=args.report_to, logging_dir=args.output_dir) if args.with_tracking else Accelerator() 237 | ) 238 | # Make one log on every process with the configuration for debugging. 239 | logging.basicConfig( 240 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 241 | datefmt="%m/%d/%Y %H:%M:%S", 242 | level=logging.INFO, 243 | ) 244 | logger.info(accelerator.state, main_process_only=False) 245 | if accelerator.is_local_main_process: 246 | datasets.utils.logging.set_verbosity_warning() 247 | transformers.utils.logging.set_verbosity_info() 248 | else: 249 | datasets.utils.logging.set_verbosity_error() 250 | transformers.utils.logging.set_verbosity_error() 251 | 252 | # If passed along, set the training seed now. 
253 | if args.seed is not None: 254 | set_seed(args.seed) 255 | 256 | # Handle the repository creation 257 | if accelerator.is_main_process: 258 | if args.push_to_hub: 259 | if args.hub_model_id is None: 260 | repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) 261 | else: 262 | repo_name = args.hub_model_id 263 | repo = Repository(args.output_dir, clone_from=repo_name) 264 | 265 | with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: 266 | if "step_*" not in gitignore: 267 | gitignore.write("step_*\n") 268 | if "epoch_*" not in gitignore: 269 | gitignore.write("epoch_*\n") 270 | elif args.output_dir is not None: 271 | os.makedirs(args.output_dir, exist_ok=True) 272 | accelerator.wait_for_everyone() 273 | 274 | 275 | ## load dataset 276 | if not args.dataset_name: 277 | raise ValueError( 278 | "Please give dataset file" 279 | ) 280 | 281 | raw_datasets = load_dataset(args.dataset_name) 282 | 283 | # Trim a number of training examples 284 | if args.debug: 285 | for split in raw_datasets.keys(): 286 | raw_datasets[split] = raw_datasets[split].select(range(100)) 287 | # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at 288 | # https://huggingface.co/docs/datasets/loading_datasets.html. 289 | 290 | if raw_datasets["train"] is not None: 291 | column_names = raw_datasets["train"].column_names 292 | features = raw_datasets["train"].features 293 | else: 294 | column_names = raw_datasets["validation"].column_names 295 | features = raw_datasets["validation"].features 296 | 297 | if args.text_column_name is not None: 298 | text_column_name = args.text_column_name 299 | elif "tokens" in column_names: 300 | text_column_name = "tokens" 301 | else: 302 | text_column_name = column_names[0] 303 | 304 | if args.label_column_name is not None: 305 | label_column_name = args.label_column_name 306 | elif f"{args.task_name}_tags" in column_names: 307 | label_column_name = f"{args.task_name}_tags" 308 | else: 309 | label_column_name = column_names[1] 310 | 311 | # In the event the labels are not a `Sequence[ClassLabel]`, we will need to go through the dataset to get the 312 | # unique labels. 313 | def get_label_list(labels): 314 | unique_labels = set() 315 | for label in labels: 316 | unique_labels = unique_labels | set(label) 317 | label_list = list(unique_labels) 318 | label_list.sort() 319 | return label_list 320 | 321 | # If the labels are of type ClassLabel, they are already integers and we have the map stored somewhere. 322 | # Otherwise, we have to get the list of labels manually. 323 | labels_are_int = isinstance(features[label_column_name].feature, ClassLabel) 324 | if labels_are_int: 325 | label_list = features[label_column_name].feature.names 326 | label_to_id = {i: i for i in range(len(label_list))} 327 | else: 328 | label_list = get_label_list(raw_datasets["train"][label_column_name]) 329 | label_to_id = {l: i for i, l in enumerate(label_list)} 330 | 331 | num_labels = len(label_list) 332 | 333 | # Load pretrained model and tokenizer 334 | ##prepare config file (BERT) 335 | config = AutoConfig.from_pretrained(args.config_name, num_labels=num_labels) 336 | 337 | ##load tokenizer 338 | tokenizer_name_or_path = args.tokenizer_name 339 | if not tokenizer_name_or_path: 340 | raise ValueError( 341 | "You are instantiating a new tokenizer from scratch. This is not supported by this script." 342 | "You can do it from another script, save it, and load it from here, using --tokenizer_name." 
343 | ) 344 | 345 | tokenizer = BertTokenizerFast.from_pretrained(tokenizer_name_or_path, do_lower_case=False) 346 | 347 | logger.info("Training new model from scratch") 348 | model = AutoModelForTokenClassification.from_config(config) 349 | 350 | model.resize_token_embeddings(len(tokenizer)) 351 | 352 | # Model has labels -> use them. 353 | if model.config.label2id != PretrainedConfig(num_labels=num_labels).label2id: 354 | if list(sorted(model.config.label2id.keys())) == list(sorted(label_list)): 355 | # Reorganize `label_list` to match the ordering of the model. 356 | if labels_are_int: 357 | label_to_id = {i: int(model.config.label2id[l]) for i, l in enumerate(label_list)} 358 | label_list = [model.config.id2label[i] for i in range(num_labels)] 359 | else: 360 | label_list = [model.config.id2label[i] for i in range(num_labels)] 361 | label_to_id = {l: i for i, l in enumerate(label_list)} 362 | else: 363 | logger.warning( 364 | "Your model seems to have been trained with labels, but they don't match the dataset: ", 365 | f"model labels: {list(sorted(model.config.label2id.keys()))}, dataset labels:" 366 | f" {list(sorted(label_list))}.\nIgnoring the model labels as a result.", 367 | ) 368 | 369 | # Set the correspondences label/ID inside the model config 370 | model.config.label2id = {l: i for i, l in enumerate(label_list)} 371 | model.config.id2label = {i: l for i, l in enumerate(label_list)} 372 | 373 | # Map that sends B-Xxx label to its I-Xxx counterpart 374 | b_to_i_label = [] 375 | 376 | # Preprocessing the datasets. 377 | # First we tokenize all the texts. 378 | padding = "max_length" if args.pad_to_max_length else False 379 | 380 | # Tokenize all texts and align the labels with them. 381 | 382 | def tokenize_and_align_labels(examples): 383 | tokenized_inputs = tokenizer( 384 | examples[text_column_name], 385 | max_length=args.max_length, 386 | padding=padding, 387 | truncation=True, 388 | # We use this argument because the texts in our dataset are lists of words (with a label for each word). 389 | is_split_into_words=True, 390 | ) 391 | 392 | labels = [] 393 | for i, label in enumerate(examples[label_column_name]): 394 | word_ids = tokenized_inputs.word_ids(batch_index=i) 395 | previous_word_idx = None 396 | label_ids = [] 397 | for word_idx in word_ids: 398 | # Special tokens have a word id that is None. We set the label to -100 so they are automatically 399 | # ignored in the loss function. 400 | if word_idx is None: 401 | label_ids.append(-100) 402 | # We set the label for the first token of each word. 403 | elif word_idx != previous_word_idx: 404 | label_ids.append(label_to_id[label[word_idx]]) 405 | # For the other tokens in a word, we set the label to either the current label or -100, depending on 406 | # the label_all_tokens flag. 
407 | else: 408 | if args.label_all_tokens: 409 | label_ids.append(b_to_i_label[label_to_id[label[word_idx]]]) 410 | else: 411 | label_ids.append(-100) 412 | previous_word_idx = word_idx 413 | 414 | labels.append(label_ids) 415 | tokenized_inputs["labels"] = labels 416 | return tokenized_inputs 417 | 418 | with accelerator.main_process_first(): 419 | processed_raw_datasets = raw_datasets.map( 420 | tokenize_and_align_labels, 421 | batched=True, 422 | remove_columns=raw_datasets["train"].column_names, 423 | desc="Running tokenizer on dataset", 424 | ) 425 | 426 | train_dataset = processed_raw_datasets["train"] 427 | eval_dataset = processed_raw_datasets["validation"] 428 | 429 | # Log a few random samples from the training set: 430 | for index in random.sample(range(len(train_dataset)), 3): 431 | logger.info(f"Sample {index} of the training set: {train_dataset[index]}.") 432 | 433 | # DataLoaders creation: 434 | if args.pad_to_max_length: 435 | # If padding was already done ot max length, we use the default data collator that will just convert everything 436 | # to tensors. 437 | data_collator = default_data_collator 438 | else: 439 | # Otherwise, `DataCollatorForTokenClassification` will apply dynamic padding for us (by padding to the maximum length of 440 | # the samples passed). When using mixed precision, we add `pad_to_multiple_of=8` to pad all tensors to multiple 441 | # of 8s, which will enable the use of Tensor Cores on NVIDIA hardware with compute capability >= 7.5 (Volta). 442 | data_collator = DataCollatorForTokenClassification( 443 | tokenizer, pad_to_multiple_of=(8 if accelerator.use_fp16 else None) 444 | ) 445 | 446 | train_dataloader = DataLoader( 447 | train_dataset, shuffle=True, collate_fn=data_collator, batch_size=args.per_device_train_batch_size 448 | ) 449 | eval_dataloader = DataLoader(eval_dataset, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size) 450 | 451 | # Optimizer 452 | # Split weights in two groups, one with weight decay and the other not. 453 | no_decay = ["bias", "LayerNorm.weight"] 454 | optimizer_grouped_parameters = [ 455 | { 456 | "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 457 | "weight_decay": args.weight_decay, 458 | }, 459 | { 460 | "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 461 | "weight_decay": 0.0, 462 | }, 463 | ] 464 | optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate) 465 | 466 | # Use the device given by the `accelerator` object. 467 | device = accelerator.device 468 | model.to(device) 469 | 470 | # Scheduler and math around the number of training steps. 471 | overrode_max_train_steps = False 472 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) 473 | if args.max_train_steps is None: 474 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch 475 | overrode_max_train_steps = True 476 | 477 | lr_scheduler = get_scheduler( 478 | name=args.lr_scheduler_type, 479 | optimizer=optimizer, 480 | num_warmup_steps=args.num_warmup_steps, 481 | num_training_steps=args.max_train_steps, 482 | ) 483 | 484 | # Prepare everything with our `accelerator`. 485 | model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare( 486 | model, optimizer, train_dataloader, eval_dataloader, lr_scheduler 487 | ) 488 | 489 | # We need to recalculate our total training steps as the size of the training dataloader may have changed. 
490 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) 491 | if overrode_max_train_steps: 492 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch 493 | # Afterwards we recalculate our number of training epochs 494 | args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) 495 | 496 | # Figure out how many steps we should save the Accelerator states 497 | checkpointing_steps = args.checkpointing_steps 498 | if checkpointing_steps is not None and checkpointing_steps.isdigit(): 499 | checkpointing_steps = int(checkpointing_steps) 500 | 501 | # We need to initialize the trackers we use, and also store our configuration. 502 | # The trackers initializes automatically on the main process. 503 | if args.with_tracking: 504 | experiment_config = vars(args) 505 | # TensorBoard cannot log Enums, need the raw value 506 | experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value 507 | accelerator.init_trackers("ner_no_trainer", experiment_config) 508 | 509 | # Metrics 510 | metric = evaluate.load("seqeval") 511 | 512 | def get_labels(predictions, references): 513 | # Transform predictions and references tensos to numpy arrays 514 | if device.type == "cpu": 515 | y_pred = predictions.detach().clone().numpy() 516 | y_true = references.detach().clone().numpy() 517 | else: 518 | y_pred = predictions.detach().cpu().clone().numpy() 519 | y_true = references.detach().cpu().clone().numpy() 520 | 521 | # Remove ignored index (special tokens) 522 | true_predictions = [ 523 | [label_list[p] for (p, l) in zip(pred, gold_label) if l != -100] 524 | for pred, gold_label in zip(y_pred, y_true) 525 | ] 526 | true_labels = [ 527 | [label_list[l] for (p, l) in zip(pred, gold_label) if l != -100] 528 | for pred, gold_label in zip(y_pred, y_true) 529 | ] 530 | return true_predictions, true_labels 531 | 532 | def compute_metrics(): 533 | results = metric.compute() 534 | if args.return_entity_level_metrics: 535 | # Unpack nested dictionaries 536 | final_results = {} 537 | for key, value in results.items(): 538 | if isinstance(value, dict): 539 | for n, v in value.items(): 540 | final_results[f"{key}_{n}"] = v 541 | else: 542 | final_results[key] = value 543 | return final_results 544 | else: 545 | return { 546 | "precision": results["overall_precision"], 547 | "recall": results["overall_recall"], 548 | "f1": results["overall_f1"], 549 | "accuracy": results["overall_accuracy"], 550 | } 551 | 552 | # Train! 553 | total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps 554 | 555 | logger.info("***** Running training *****") 556 | logger.info(f" Num examples = {len(train_dataset)}") 557 | logger.info(f" Num Epochs = {args.num_train_epochs}") 558 | logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}") 559 | logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") 560 | logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") 561 | logger.info(f" Total optimization steps = {args.max_train_steps}") 562 | # Only show the progress bar once on each machine. 
563 | progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process) 564 | completed_steps = 0 565 | starting_epoch = 0 566 | # Potentially load in the weights and states from a previous save 567 | if args.resume_from_checkpoint: 568 | if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "": 569 | accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}") 570 | accelerator.load_state(args.resume_from_checkpoint) 571 | path = os.path.basename(args.resume_from_checkpoint) 572 | else: 573 | # Get the most recent checkpoint 574 | dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()] 575 | dirs.sort(key=os.path.getctime) 576 | path = dirs[-1] # Sorts folders by date modified, most recent checkpoint is the last 577 | # Extract `epoch_{i}` or `step_{i}` 578 | training_difference = os.path.splitext(path)[0] 579 | 580 | if "epoch" in training_difference: 581 | starting_epoch = int(training_difference.replace("epoch_", "")) + 1 582 | resume_step = None 583 | else: 584 | resume_step = int(training_difference.replace("step_", "")) 585 | starting_epoch = resume_step // len(train_dataloader) 586 | resume_step -= starting_epoch * len(train_dataloader) 587 | 588 | for epoch in range(starting_epoch, args.num_train_epochs): 589 | model.train() 590 | if args.with_tracking: 591 | total_loss = 0 592 | for step, batch in enumerate(train_dataloader): 593 | # We need to skip steps until we reach the resumed step 594 | if args.resume_from_checkpoint and epoch == starting_epoch: 595 | if resume_step is not None and step < resume_step: 596 | completed_steps += 1 597 | continue 598 | outputs = model(**batch) 599 | loss = outputs.loss 600 | # We keep track of the loss at each epoch 601 | if args.with_tracking: 602 | total_loss += loss.detach().float() 603 | loss = loss / args.gradient_accumulation_steps 604 | accelerator.backward(loss) 605 | if step % args.gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1: 606 | optimizer.step() 607 | lr_scheduler.step() 608 | optimizer.zero_grad() 609 | progress_bar.update(1) 610 | completed_steps += 1 611 | 612 | if isinstance(checkpointing_steps, int): 613 | if completed_steps % checkpointing_steps == 0: 614 | output_dir = f"step_{completed_steps }" 615 | if args.output_dir is not None: 616 | output_dir = os.path.join(args.output_dir, output_dir) 617 | accelerator.save_state(output_dir) 618 | 619 | if completed_steps >= args.max_train_steps: 620 | break 621 | 622 | model.eval() 623 | samples_seen = 0 624 | 625 | outputs4save = [] 626 | for step, batch in enumerate(eval_dataloader): 627 | with torch.no_grad(): 628 | outputs = model(**batch) 629 | predictions = outputs.logits.argmax(dim=-1) 630 | labels = batch["labels"] 631 | if not args.pad_to_max_length: # necessary to pad predictions and labels for being gathered 632 | predictions = accelerator.pad_across_processes(predictions, dim=1, pad_index=-100) 633 | labels = accelerator.pad_across_processes(labels, dim=1, pad_index=-100) 634 | predictions_gathered, labels_gathered = accelerator.gather((predictions, labels)) 635 | # If we are in a multiprocess environment, the last batch has duplicates 636 | if accelerator.num_processes > 1: 637 | if step == len(eval_dataloader) - 1: 638 | predictions_gathered = predictions_gathered[: len(eval_dataloader.dataset) - samples_seen] 639 | labels_gathered = labels_gathered[: len(eval_dataloader.dataset) - samples_seen] 640 | else: 641 | samples_seen += labels_gathered.shape[0] 642 | preds, 
refs = get_labels(predictions_gathered, labels_gathered) 643 | metric.add_batch( 644 | predictions=preds, 645 | references=refs, 646 | ) # predictions and references are expected to be nested lists of labels, not label_ids 647 | 648 | if epoch == (args.num_train_epochs - 1): 649 | outputs4save.append([preds, refs]) 650 | 651 | eval_metric = compute_metrics() 652 | accelerator.print(f"epoch {epoch}:", eval_metric) 653 | if args.with_tracking: 654 | accelerator.log( 655 | { 656 | "seqeval": eval_metric, 657 | "train_loss": total_loss.item() / len(train_dataloader), 658 | "epoch": epoch, 659 | "step": completed_steps, 660 | }, 661 | step=completed_steps, 662 | ) 663 | 664 | if args.push_to_hub and epoch < args.num_train_epochs - 1: 665 | accelerator.wait_for_everyone() 666 | unwrapped_model = accelerator.unwrap_model(model) 667 | unwrapped_model.save_pretrained( 668 | args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save 669 | ) 670 | if accelerator.is_main_process: 671 | tokenizer.save_pretrained(args.output_dir) 672 | repo.push_to_hub( 673 | commit_message=f"Training in progress epoch {epoch}", blocking=False, auto_lfs_prune=True 674 | ) 675 | 676 | if args.checkpointing_steps == "epoch": 677 | output_dir = f"epoch_{epoch}" 678 | if args.output_dir is not None: 679 | output_dir = os.path.join(args.output_dir, output_dir) 680 | accelerator.save_state(output_dir) 681 | 682 | if args.with_tracking: 683 | accelerator.end_training() 684 | 685 | if args.output_dir is not None: 686 | accelerator.wait_for_everyone() 687 | unwrapped_model = accelerator.unwrap_model(model) 688 | unwrapped_model.save_pretrained( 689 | args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save 690 | ) 691 | if accelerator.is_main_process: 692 | tokenizer.save_pretrained(args.output_dir) 693 | if args.push_to_hub: 694 | repo.push_to_hub(commit_message="End of training", auto_lfs_prune=True) 695 | 696 | 697 | with open(os.path.join(args.output_dir, "all_results.json"), "w") as f: 698 | json.dump( 699 | {"eval_accuracy": eval_metric["accuracy"]}, f 700 | ) 701 | 702 | import pandas as pd 703 | # Save predictions 704 | out = pd.DataFrame(outputs4save) 705 | out.to_csv(os.path.join(args.output_dir, "predictions.csv"), header=None, index=None) 706 | 707 | if __name__ == "__main__": 708 | main() 709 | -------------------------------------------------------------------------------- /train_BERTOS.sh: -------------------------------------------------------------------------------- 1 | python train_BERTOS.py \ 2 | --config_name ./random_config/ \ 3 | --dataset_name materials_icsd_cn.py \ 4 | --max_length 100 \ 5 | --per_device_train_batch_size 256 \ 6 | --learning_rate 1e-3 \ 7 | --num_train_epochs 500 \ 8 | --output_dir ./model_icsdcn 9 | -------------------------------------------------------------------------------- /trained_models/ICSD.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/usccolumbia/BERTOS/ab4e4f31b09543a7949f36981f1c87b9ff41bb74/trained_models/ICSD.zip -------------------------------------------------------------------------------- /trained_models/ICSD_CN.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/usccolumbia/BERTOS/ab4e4f31b09543a7949f36981f1c87b9ff41bb74/trained_models/ICSD_CN.zip --------------------------------------------------------------------------------
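
Note on tokenization in train_BERTOS.py: tokenizer/vocab.txt assigns every element symbol its own vocabulary entry (5 special tokens + 118 elements = 123 entries, matching vocab_size=123 in random_config/config.json), so the word/label alignment performed by tokenize_and_align_labels is one-to-one and no element is ever split into sub-word pieces. Likewise, --max_length 100 in train_BERTOS.sh matches max_position_embeddings=100 in the config, capping inputs at 98 atoms plus [CLS] and [SEP]. A short sketch, assuming the repo's ./tokenizer folder is present:

# alignment_demo.py -- element symbols are single vocab entries, so labels align 1:1 with tokens
from transformers import BertTokenizerFast

tok = BertTokenizerFast.from_pretrained("./tokenizer", do_lower_case=False)
enc = tok(["Sr", "Ti", "O", "O", "O"], is_split_into_words=True)
print(enc.word_ids())  # [None, 0, 1, 2, 3, 4, None] -> [CLS], five atoms, [SEP]
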
/trained_models/ICSD_CN_oxide.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/usccolumbia/BERTOS/ab4e4f31b09543a7949f36981f1c87b9ff41bb74/trained_models/ICSD_CN_oxide.zip -------------------------------------------------------------------------------- /trained_models/ICSD_oxide.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/usccolumbia/BERTOS/ab4e4f31b09543a7949f36981f1c87b9ff41bb74/trained_models/ICSD_oxide.zip -------------------------------------------------------------------------------- /trained_models/README.md: -------------------------------------------------------------------------------- 1 | Download the pretrained models for oxidation state prediction from figshare.com at 2 | https://figshare.com/articles/online_resource/BERTOS_model/21554823 3 | and then unzip them (e.g., by double-clicking the archives). 4 | 5 | --------------------------------------------------------------------------------
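
To use a downloaded checkpoint outside of getOS.py, the sketch below mirrors getOS.py's inference path end to end. It assumes one of the figshare archives has been unzipped to ./trained_models/ICSD (the exact folder name depends on how the archive unpacks) and that the composition has already been expanded to one element token per atom, as getOS.py does with pymatgen's Composition.

# inference_sketch.py -- minimal prediction loop mirroring getOS.py (paths are assumptions)
import torch
from transformers import AutoConfig, AutoModelForTokenClassification, BertTokenizerFast

model_dir = "./trained_models/ICSD"  # assumed unzip location of the figshare archive
tokenizer = BertTokenizerFast.from_pretrained("./tokenizer", do_lower_case=False)
config = AutoConfig.from_pretrained(model_dir, num_labels=14)
model = AutoModelForTokenClassification.from_pretrained(model_dir, config=config)
model.eval()  # from_pretrained already returns eval mode; stated here for clarity

input_seq = "Sr Ti O O O"  # SrTiO3 expanded to one token per atom
inputs = torch.tensor(tokenizer.encode(input_seq, add_special_tokens=True)).unsqueeze(0)
with torch.no_grad():  # inference only; getOS.py omits this, but it avoids tracking gradients
    logits = model(inputs).logits
pred = logits.argmax(dim=-1)[0][1:-1]  # strip [CLS] and [SEP]
print([p.item() - 5 for p in pred])    # oxidation states, e.g. [2, 4, -2, -2, -2]
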