├── LICENSE
├── README.md
├── checkCN.py
├── dataset
│   ├── ICSD.zip
│   ├── ICSD_CN.zip
│   ├── ICSD_CN_oxide.zip
│   ├── ICSD_oxide.zip
│   └── README.md
├── formulas.csv
├── getOS.py
├── materials_icsd.py
├── materials_icsd_cn.py
├── materials_icsd_cno.py
├── materials_icsd_o.py
├── performances.png
├── random_config
│   └── config.json
├── requirements.txt
├── tokenizer
│   └── vocab.txt
├── train_BERTOS.py
├── train_BERTOS.sh
└── trained_models
    ├── ICSD.zip
    ├── ICSD_CN.zip
    ├── ICSD_CN_oxide.zip
    ├── ICSD_oxide.zip
    └── README.md
/LICENSE:
--------------------------------------------------------------------------------
1 | GNU GENERAL PUBLIC LICENSE
2 | Version 3, 29 June 2007
3 |
4 | Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
5 | Everyone is permitted to copy and distribute verbatim copies
6 | of this license document, but changing it is not allowed.
7 |
8 | Preamble
9 |
10 | The GNU General Public License is a free, copyleft license for
11 | software and other kinds of works.
12 |
13 | The licenses for most software and other practical works are designed
14 | to take away your freedom to share and change the works. By contrast,
15 | the GNU General Public License is intended to guarantee your freedom to
16 | share and change all versions of a program--to make sure it remains free
17 | software for all its users. We, the Free Software Foundation, use the
18 | GNU General Public License for most of our software; it applies also to
19 | any other work released this way by its authors. You can apply it to
20 | your programs, too.
21 |
22 | When we speak of free software, we are referring to freedom, not
23 | price. Our General Public Licenses are designed to make sure that you
24 | have the freedom to distribute copies of free software (and charge for
25 | them if you wish), that you receive source code or can get it if you
26 | want it, that you can change the software or use pieces of it in new
27 | free programs, and that you know you can do these things.
28 |
29 | To protect your rights, we need to prevent others from denying you
30 | these rights or asking you to surrender the rights. Therefore, you have
31 | certain responsibilities if you distribute copies of the software, or if
32 | you modify it: responsibilities to respect the freedom of others.
33 |
34 | For example, if you distribute copies of such a program, whether
35 | gratis or for a fee, you must pass on to the recipients the same
36 | freedoms that you received. You must make sure that they, too, receive
37 | or can get the source code. And you must show them these terms so they
38 | know their rights.
39 |
40 | Developers that use the GNU GPL protect your rights with two steps:
41 | (1) assert copyright on the software, and (2) offer you this License
42 | giving you legal permission to copy, distribute and/or modify it.
43 |
44 | For the developers' and authors' protection, the GPL clearly explains
45 | that there is no warranty for this free software. For both users' and
46 | authors' sake, the GPL requires that modified versions be marked as
47 | changed, so that their problems will not be attributed erroneously to
48 | authors of previous versions.
49 |
50 | Some devices are designed to deny users access to install or run
51 | modified versions of the software inside them, although the manufacturer
52 | can do so. This is fundamentally incompatible with the aim of
53 | protecting users' freedom to change the software. The systematic
54 | pattern of such abuse occurs in the area of products for individuals to
55 | use, which is precisely where it is most unacceptable. Therefore, we
56 | have designed this version of the GPL to prohibit the practice for those
57 | products. If such problems arise substantially in other domains, we
58 | stand ready to extend this provision to those domains in future versions
59 | of the GPL, as needed to protect the freedom of users.
60 |
61 | Finally, every program is threatened constantly by software patents.
62 | States should not allow patents to restrict development and use of
63 | software on general-purpose computers, but in those that do, we wish to
64 | avoid the special danger that patents applied to a free program could
65 | make it effectively proprietary. To prevent this, the GPL assures that
66 | patents cannot be used to render the program non-free.
67 |
68 | The precise terms and conditions for copying, distribution and
69 | modification follow.
70 |
71 | TERMS AND CONDITIONS
72 |
73 | 0. Definitions.
74 |
75 | "This License" refers to version 3 of the GNU General Public License.
76 |
77 | "Copyright" also means copyright-like laws that apply to other kinds of
78 | works, such as semiconductor masks.
79 |
80 | "The Program" refers to any copyrightable work licensed under this
81 | License. Each licensee is addressed as "you". "Licensees" and
82 | "recipients" may be individuals or organizations.
83 |
84 | To "modify" a work means to copy from or adapt all or part of the work
85 | in a fashion requiring copyright permission, other than the making of an
86 | exact copy. The resulting work is called a "modified version" of the
87 | earlier work or a work "based on" the earlier work.
88 |
89 | A "covered work" means either the unmodified Program or a work based
90 | on the Program.
91 |
92 | To "propagate" a work means to do anything with it that, without
93 | permission, would make you directly or secondarily liable for
94 | infringement under applicable copyright law, except executing it on a
95 | computer or modifying a private copy. Propagation includes copying,
96 | distribution (with or without modification), making available to the
97 | public, and in some countries other activities as well.
98 |
99 | To "convey" a work means any kind of propagation that enables other
100 | parties to make or receive copies. Mere interaction with a user through
101 | a computer network, with no transfer of a copy, is not conveying.
102 |
103 | An interactive user interface displays "Appropriate Legal Notices"
104 | to the extent that it includes a convenient and prominently visible
105 | feature that (1) displays an appropriate copyright notice, and (2)
106 | tells the user that there is no warranty for the work (except to the
107 | extent that warranties are provided), that licensees may convey the
108 | work under this License, and how to view a copy of this License. If
109 | the interface presents a list of user commands or options, such as a
110 | menu, a prominent item in the list meets this criterion.
111 |
112 | 1. Source Code.
113 |
114 | The "source code" for a work means the preferred form of the work
115 | for making modifications to it. "Object code" means any non-source
116 | form of a work.
117 |
118 | A "Standard Interface" means an interface that either is an official
119 | standard defined by a recognized standards body, or, in the case of
120 | interfaces specified for a particular programming language, one that
121 | is widely used among developers working in that language.
122 |
123 | The "System Libraries" of an executable work include anything, other
124 | than the work as a whole, that (a) is included in the normal form of
125 | packaging a Major Component, but which is not part of that Major
126 | Component, and (b) serves only to enable use of the work with that
127 | Major Component, or to implement a Standard Interface for which an
128 | implementation is available to the public in source code form. A
129 | "Major Component", in this context, means a major essential component
130 | (kernel, window system, and so on) of the specific operating system
131 | (if any) on which the executable work runs, or a compiler used to
132 | produce the work, or an object code interpreter used to run it.
133 |
134 | The "Corresponding Source" for a work in object code form means all
135 | the source code needed to generate, install, and (for an executable
136 | work) run the object code and to modify the work, including scripts to
137 | control those activities. However, it does not include the work's
138 | System Libraries, or general-purpose tools or generally available free
139 | programs which are used unmodified in performing those activities but
140 | which are not part of the work. For example, Corresponding Source
141 | includes interface definition files associated with source files for
142 | the work, and the source code for shared libraries and dynamically
143 | linked subprograms that the work is specifically designed to require,
144 | such as by intimate data communication or control flow between those
145 | subprograms and other parts of the work.
146 |
147 | The Corresponding Source need not include anything that users
148 | can regenerate automatically from other parts of the Corresponding
149 | Source.
150 |
151 | The Corresponding Source for a work in source code form is that
152 | same work.
153 |
154 | 2. Basic Permissions.
155 |
156 | All rights granted under this License are granted for the term of
157 | copyright on the Program, and are irrevocable provided the stated
158 | conditions are met. This License explicitly affirms your unlimited
159 | permission to run the unmodified Program. The output from running a
160 | covered work is covered by this License only if the output, given its
161 | content, constitutes a covered work. This License acknowledges your
162 | rights of fair use or other equivalent, as provided by copyright law.
163 |
164 | You may make, run and propagate covered works that you do not
165 | convey, without conditions so long as your license otherwise remains
166 | in force. You may convey covered works to others for the sole purpose
167 | of having them make modifications exclusively for you, or provide you
168 | with facilities for running those works, provided that you comply with
169 | the terms of this License in conveying all material for which you do
170 | not control copyright. Those thus making or running the covered works
171 | for you must do so exclusively on your behalf, under your direction
172 | and control, on terms that prohibit them from making any copies of
173 | your copyrighted material outside their relationship with you.
174 |
175 | Conveying under any other circumstances is permitted solely under
176 | the conditions stated below. Sublicensing is not allowed; section 10
177 | makes it unnecessary.
178 |
179 | 3. Protecting Users' Legal Rights From Anti-Circumvention Law.
180 |
181 | No covered work shall be deemed part of an effective technological
182 | measure under any applicable law fulfilling obligations under article
183 | 11 of the WIPO copyright treaty adopted on 20 December 1996, or
184 | similar laws prohibiting or restricting circumvention of such
185 | measures.
186 |
187 | When you convey a covered work, you waive any legal power to forbid
188 | circumvention of technological measures to the extent such circumvention
189 | is effected by exercising rights under this License with respect to
190 | the covered work, and you disclaim any intention to limit operation or
191 | modification of the work as a means of enforcing, against the work's
192 | users, your or third parties' legal rights to forbid circumvention of
193 | technological measures.
194 |
195 | 4. Conveying Verbatim Copies.
196 |
197 | You may convey verbatim copies of the Program's source code as you
198 | receive it, in any medium, provided that you conspicuously and
199 | appropriately publish on each copy an appropriate copyright notice;
200 | keep intact all notices stating that this License and any
201 | non-permissive terms added in accord with section 7 apply to the code;
202 | keep intact all notices of the absence of any warranty; and give all
203 | recipients a copy of this License along with the Program.
204 |
205 | You may charge any price or no price for each copy that you convey,
206 | and you may offer support or warranty protection for a fee.
207 |
208 | 5. Conveying Modified Source Versions.
209 |
210 | You may convey a work based on the Program, or the modifications to
211 | produce it from the Program, in the form of source code under the
212 | terms of section 4, provided that you also meet all of these conditions:
213 |
214 | a) The work must carry prominent notices stating that you modified
215 | it, and giving a relevant date.
216 |
217 | b) The work must carry prominent notices stating that it is
218 | released under this License and any conditions added under section
219 | 7. This requirement modifies the requirement in section 4 to
220 | "keep intact all notices".
221 |
222 | c) You must license the entire work, as a whole, under this
223 | License to anyone who comes into possession of a copy. This
224 | License will therefore apply, along with any applicable section 7
225 | additional terms, to the whole of the work, and all its parts,
226 | regardless of how they are packaged. This License gives no
227 | permission to license the work in any other way, but it does not
228 | invalidate such permission if you have separately received it.
229 |
230 | d) If the work has interactive user interfaces, each must display
231 | Appropriate Legal Notices; however, if the Program has interactive
232 | interfaces that do not display Appropriate Legal Notices, your
233 | work need not make them do so.
234 |
235 | A compilation of a covered work with other separate and independent
236 | works, which are not by their nature extensions of the covered work,
237 | and which are not combined with it such as to form a larger program,
238 | in or on a volume of a storage or distribution medium, is called an
239 | "aggregate" if the compilation and its resulting copyright are not
240 | used to limit the access or legal rights of the compilation's users
241 | beyond what the individual works permit. Inclusion of a covered work
242 | in an aggregate does not cause this License to apply to the other
243 | parts of the aggregate.
244 |
245 | 6. Conveying Non-Source Forms.
246 |
247 | You may convey a covered work in object code form under the terms
248 | of sections 4 and 5, provided that you also convey the
249 | machine-readable Corresponding Source under the terms of this License,
250 | in one of these ways:
251 |
252 | a) Convey the object code in, or embodied in, a physical product
253 | (including a physical distribution medium), accompanied by the
254 | Corresponding Source fixed on a durable physical medium
255 | customarily used for software interchange.
256 |
257 | b) Convey the object code in, or embodied in, a physical product
258 | (including a physical distribution medium), accompanied by a
259 | written offer, valid for at least three years and valid for as
260 | long as you offer spare parts or customer support for that product
261 | model, to give anyone who possesses the object code either (1) a
262 | copy of the Corresponding Source for all the software in the
263 | product that is covered by this License, on a durable physical
264 | medium customarily used for software interchange, for a price no
265 | more than your reasonable cost of physically performing this
266 | conveying of source, or (2) access to copy the
267 | Corresponding Source from a network server at no charge.
268 |
269 | c) Convey individual copies of the object code with a copy of the
270 | written offer to provide the Corresponding Source. This
271 | alternative is allowed only occasionally and noncommercially, and
272 | only if you received the object code with such an offer, in accord
273 | with subsection 6b.
274 |
275 | d) Convey the object code by offering access from a designated
276 | place (gratis or for a charge), and offer equivalent access to the
277 | Corresponding Source in the same way through the same place at no
278 | further charge. You need not require recipients to copy the
279 | Corresponding Source along with the object code. If the place to
280 | copy the object code is a network server, the Corresponding Source
281 | may be on a different server (operated by you or a third party)
282 | that supports equivalent copying facilities, provided you maintain
283 | clear directions next to the object code saying where to find the
284 | Corresponding Source. Regardless of what server hosts the
285 | Corresponding Source, you remain obligated to ensure that it is
286 | available for as long as needed to satisfy these requirements.
287 |
288 | e) Convey the object code using peer-to-peer transmission, provided
289 | you inform other peers where the object code and Corresponding
290 | Source of the work are being offered to the general public at no
291 | charge under subsection 6d.
292 |
293 | A separable portion of the object code, whose source code is excluded
294 | from the Corresponding Source as a System Library, need not be
295 | included in conveying the object code work.
296 |
297 | A "User Product" is either (1) a "consumer product", which means any
298 | tangible personal property which is normally used for personal, family,
299 | or household purposes, or (2) anything designed or sold for incorporation
300 | into a dwelling. In determining whether a product is a consumer product,
301 | doubtful cases shall be resolved in favor of coverage. For a particular
302 | product received by a particular user, "normally used" refers to a
303 | typical or common use of that class of product, regardless of the status
304 | of the particular user or of the way in which the particular user
305 | actually uses, or expects or is expected to use, the product. A product
306 | is a consumer product regardless of whether the product has substantial
307 | commercial, industrial or non-consumer uses, unless such uses represent
308 | the only significant mode of use of the product.
309 |
310 | "Installation Information" for a User Product means any methods,
311 | procedures, authorization keys, or other information required to install
312 | and execute modified versions of a covered work in that User Product from
313 | a modified version of its Corresponding Source. The information must
314 | suffice to ensure that the continued functioning of the modified object
315 | code is in no case prevented or interfered with solely because
316 | modification has been made.
317 |
318 | If you convey an object code work under this section in, or with, or
319 | specifically for use in, a User Product, and the conveying occurs as
320 | part of a transaction in which the right of possession and use of the
321 | User Product is transferred to the recipient in perpetuity or for a
322 | fixed term (regardless of how the transaction is characterized), the
323 | Corresponding Source conveyed under this section must be accompanied
324 | by the Installation Information. But this requirement does not apply
325 | if neither you nor any third party retains the ability to install
326 | modified object code on the User Product (for example, the work has
327 | been installed in ROM).
328 |
329 | The requirement to provide Installation Information does not include a
330 | requirement to continue to provide support service, warranty, or updates
331 | for a work that has been modified or installed by the recipient, or for
332 | the User Product in which it has been modified or installed. Access to a
333 | network may be denied when the modification itself materially and
334 | adversely affects the operation of the network or violates the rules and
335 | protocols for communication across the network.
336 |
337 | Corresponding Source conveyed, and Installation Information provided,
338 | in accord with this section must be in a format that is publicly
339 | documented (and with an implementation available to the public in
340 | source code form), and must require no special password or key for
341 | unpacking, reading or copying.
342 |
343 | 7. Additional Terms.
344 |
345 | "Additional permissions" are terms that supplement the terms of this
346 | License by making exceptions from one or more of its conditions.
347 | Additional permissions that are applicable to the entire Program shall
348 | be treated as though they were included in this License, to the extent
349 | that they are valid under applicable law. If additional permissions
350 | apply only to part of the Program, that part may be used separately
351 | under those permissions, but the entire Program remains governed by
352 | this License without regard to the additional permissions.
353 |
354 | When you convey a copy of a covered work, you may at your option
355 | remove any additional permissions from that copy, or from any part of
356 | it. (Additional permissions may be written to require their own
357 | removal in certain cases when you modify the work.) You may place
358 | additional permissions on material, added by you to a covered work,
359 | for which you have or can give appropriate copyright permission.
360 |
361 | Notwithstanding any other provision of this License, for material you
362 | add to a covered work, you may (if authorized by the copyright holders of
363 | that material) supplement the terms of this License with terms:
364 |
365 | a) Disclaiming warranty or limiting liability differently from the
366 | terms of sections 15 and 16 of this License; or
367 |
368 | b) Requiring preservation of specified reasonable legal notices or
369 | author attributions in that material or in the Appropriate Legal
370 | Notices displayed by works containing it; or
371 |
372 | c) Prohibiting misrepresentation of the origin of that material, or
373 | requiring that modified versions of such material be marked in
374 | reasonable ways as different from the original version; or
375 |
376 | d) Limiting the use for publicity purposes of names of licensors or
377 | authors of the material; or
378 |
379 | e) Declining to grant rights under trademark law for use of some
380 | trade names, trademarks, or service marks; or
381 |
382 | f) Requiring indemnification of licensors and authors of that
383 | material by anyone who conveys the material (or modified versions of
384 | it) with contractual assumptions of liability to the recipient, for
385 | any liability that these contractual assumptions directly impose on
386 | those licensors and authors.
387 |
388 | All other non-permissive additional terms are considered "further
389 | restrictions" within the meaning of section 10. If the Program as you
390 | received it, or any part of it, contains a notice stating that it is
391 | governed by this License along with a term that is a further
392 | restriction, you may remove that term. If a license document contains
393 | a further restriction but permits relicensing or conveying under this
394 | License, you may add to a covered work material governed by the terms
395 | of that license document, provided that the further restriction does
396 | not survive such relicensing or conveying.
397 |
398 | If you add terms to a covered work in accord with this section, you
399 | must place, in the relevant source files, a statement of the
400 | additional terms that apply to those files, or a notice indicating
401 | where to find the applicable terms.
402 |
403 | Additional terms, permissive or non-permissive, may be stated in the
404 | form of a separately written license, or stated as exceptions;
405 | the above requirements apply either way.
406 |
407 | 8. Termination.
408 |
409 | You may not propagate or modify a covered work except as expressly
410 | provided under this License. Any attempt otherwise to propagate or
411 | modify it is void, and will automatically terminate your rights under
412 | this License (including any patent licenses granted under the third
413 | paragraph of section 11).
414 |
415 | However, if you cease all violation of this License, then your
416 | license from a particular copyright holder is reinstated (a)
417 | provisionally, unless and until the copyright holder explicitly and
418 | finally terminates your license, and (b) permanently, if the copyright
419 | holder fails to notify you of the violation by some reasonable means
420 | prior to 60 days after the cessation.
421 |
422 | Moreover, your license from a particular copyright holder is
423 | reinstated permanently if the copyright holder notifies you of the
424 | violation by some reasonable means, this is the first time you have
425 | received notice of violation of this License (for any work) from that
426 | copyright holder, and you cure the violation prior to 30 days after
427 | your receipt of the notice.
428 |
429 | Termination of your rights under this section does not terminate the
430 | licenses of parties who have received copies or rights from you under
431 | this License. If your rights have been terminated and not permanently
432 | reinstated, you do not qualify to receive new licenses for the same
433 | material under section 10.
434 |
435 | 9. Acceptance Not Required for Having Copies.
436 |
437 | You are not required to accept this License in order to receive or
438 | run a copy of the Program. Ancillary propagation of a covered work
439 | occurring solely as a consequence of using peer-to-peer transmission
440 | to receive a copy likewise does not require acceptance. However,
441 | nothing other than this License grants you permission to propagate or
442 | modify any covered work. These actions infringe copyright if you do
443 | not accept this License. Therefore, by modifying or propagating a
444 | covered work, you indicate your acceptance of this License to do so.
445 |
446 | 10. Automatic Licensing of Downstream Recipients.
447 |
448 | Each time you convey a covered work, the recipient automatically
449 | receives a license from the original licensors, to run, modify and
450 | propagate that work, subject to this License. You are not responsible
451 | for enforcing compliance by third parties with this License.
452 |
453 | An "entity transaction" is a transaction transferring control of an
454 | organization, or substantially all assets of one, or subdividing an
455 | organization, or merging organizations. If propagation of a covered
456 | work results from an entity transaction, each party to that
457 | transaction who receives a copy of the work also receives whatever
458 | licenses to the work the party's predecessor in interest had or could
459 | give under the previous paragraph, plus a right to possession of the
460 | Corresponding Source of the work from the predecessor in interest, if
461 | the predecessor has it or can get it with reasonable efforts.
462 |
463 | You may not impose any further restrictions on the exercise of the
464 | rights granted or affirmed under this License. For example, you may
465 | not impose a license fee, royalty, or other charge for exercise of
466 | rights granted under this License, and you may not initiate litigation
467 | (including a cross-claim or counterclaim in a lawsuit) alleging that
468 | any patent claim is infringed by making, using, selling, offering for
469 | sale, or importing the Program or any portion of it.
470 |
471 | 11. Patents.
472 |
473 | A "contributor" is a copyright holder who authorizes use under this
474 | License of the Program or a work on which the Program is based. The
475 | work thus licensed is called the contributor's "contributor version".
476 |
477 | A contributor's "essential patent claims" are all patent claims
478 | owned or controlled by the contributor, whether already acquired or
479 | hereafter acquired, that would be infringed by some manner, permitted
480 | by this License, of making, using, or selling its contributor version,
481 | but do not include claims that would be infringed only as a
482 | consequence of further modification of the contributor version. For
483 | purposes of this definition, "control" includes the right to grant
484 | patent sublicenses in a manner consistent with the requirements of
485 | this License.
486 |
487 | Each contributor grants you a non-exclusive, worldwide, royalty-free
488 | patent license under the contributor's essential patent claims, to
489 | make, use, sell, offer for sale, import and otherwise run, modify and
490 | propagate the contents of its contributor version.
491 |
492 | In the following three paragraphs, a "patent license" is any express
493 | agreement or commitment, however denominated, not to enforce a patent
494 | (such as an express permission to practice a patent or covenant not to
495 | sue for patent infringement). To "grant" such a patent license to a
496 | party means to make such an agreement or commitment not to enforce a
497 | patent against the party.
498 |
499 | If you convey a covered work, knowingly relying on a patent license,
500 | and the Corresponding Source of the work is not available for anyone
501 | to copy, free of charge and under the terms of this License, through a
502 | publicly available network server or other readily accessible means,
503 | then you must either (1) cause the Corresponding Source to be so
504 | available, or (2) arrange to deprive yourself of the benefit of the
505 | patent license for this particular work, or (3) arrange, in a manner
506 | consistent with the requirements of this License, to extend the patent
507 | license to downstream recipients. "Knowingly relying" means you have
508 | actual knowledge that, but for the patent license, your conveying the
509 | covered work in a country, or your recipient's use of the covered work
510 | in a country, would infringe one or more identifiable patents in that
511 | country that you have reason to believe are valid.
512 |
513 | If, pursuant to or in connection with a single transaction or
514 | arrangement, you convey, or propagate by procuring conveyance of, a
515 | covered work, and grant a patent license to some of the parties
516 | receiving the covered work authorizing them to use, propagate, modify
517 | or convey a specific copy of the covered work, then the patent license
518 | you grant is automatically extended to all recipients of the covered
519 | work and works based on it.
520 |
521 | A patent license is "discriminatory" if it does not include within
522 | the scope of its coverage, prohibits the exercise of, or is
523 | conditioned on the non-exercise of one or more of the rights that are
524 | specifically granted under this License. You may not convey a covered
525 | work if you are a party to an arrangement with a third party that is
526 | in the business of distributing software, under which you make payment
527 | to the third party based on the extent of your activity of conveying
528 | the work, and under which the third party grants, to any of the
529 | parties who would receive the covered work from you, a discriminatory
530 | patent license (a) in connection with copies of the covered work
531 | conveyed by you (or copies made from those copies), or (b) primarily
532 | for and in connection with specific products or compilations that
533 | contain the covered work, unless you entered into that arrangement,
534 | or that patent license was granted, prior to 28 March 2007.
535 |
536 | Nothing in this License shall be construed as excluding or limiting
537 | any implied license or other defenses to infringement that may
538 | otherwise be available to you under applicable patent law.
539 |
540 | 12. No Surrender of Others' Freedom.
541 |
542 | If conditions are imposed on you (whether by court order, agreement or
543 | otherwise) that contradict the conditions of this License, they do not
544 | excuse you from the conditions of this License. If you cannot convey a
545 | covered work so as to satisfy simultaneously your obligations under this
546 | License and any other pertinent obligations, then as a consequence you may
547 | not convey it at all. For example, if you agree to terms that obligate you
548 | to collect a royalty for further conveying from those to whom you convey
549 | the Program, the only way you could satisfy both those terms and this
550 | License would be to refrain entirely from conveying the Program.
551 |
552 | 13. Use with the GNU Affero General Public License.
553 |
554 | Notwithstanding any other provision of this License, you have
555 | permission to link or combine any covered work with a work licensed
556 | under version 3 of the GNU Affero General Public License into a single
557 | combined work, and to convey the resulting work. The terms of this
558 | License will continue to apply to the part which is the covered work,
559 | but the special requirements of the GNU Affero General Public License,
560 | section 13, concerning interaction through a network will apply to the
561 | combination as such.
562 |
563 | 14. Revised Versions of this License.
564 |
565 | The Free Software Foundation may publish revised and/or new versions of
566 | the GNU General Public License from time to time. Such new versions will
567 | be similar in spirit to the present version, but may differ in detail to
568 | address new problems or concerns.
569 |
570 | Each version is given a distinguishing version number. If the
571 | Program specifies that a certain numbered version of the GNU General
572 | Public License "or any later version" applies to it, you have the
573 | option of following the terms and conditions either of that numbered
574 | version or of any later version published by the Free Software
575 | Foundation. If the Program does not specify a version number of the
576 | GNU General Public License, you may choose any version ever published
577 | by the Free Software Foundation.
578 |
579 | If the Program specifies that a proxy can decide which future
580 | versions of the GNU General Public License can be used, that proxy's
581 | public statement of acceptance of a version permanently authorizes you
582 | to choose that version for the Program.
583 |
584 | Later license versions may give you additional or different
585 | permissions. However, no additional obligations are imposed on any
586 | author or copyright holder as a result of your choosing to follow a
587 | later version.
588 |
589 | 15. Disclaimer of Warranty.
590 |
591 | THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
592 | APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
593 | HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
594 | OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
595 | THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
596 | PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
597 | IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
598 | ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
599 |
600 | 16. Limitation of Liability.
601 |
602 | IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
603 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
604 | THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
605 | GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
606 | USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
607 | DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
608 | PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
609 | EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
610 | SUCH DAMAGES.
611 |
612 | 17. Interpretation of Sections 15 and 16.
613 |
614 | If the disclaimer of warranty and limitation of liability provided
615 | above cannot be given local legal effect according to their terms,
616 | reviewing courts shall apply local law that most closely approximates
617 | an absolute waiver of all civil liability in connection with the
618 | Program, unless a warranty or assumption of liability accompanies a
619 | copy of the Program in return for a fee.
620 |
621 | END OF TERMS AND CONDITIONS
622 |
623 | How to Apply These Terms to Your New Programs
624 |
625 | If you develop a new program, and you want it to be of the greatest
626 | possible use to the public, the best way to achieve this is to make it
627 | free software which everyone can redistribute and change under these terms.
628 |
629 | To do so, attach the following notices to the program. It is safest
630 | to attach them to the start of each source file to most effectively
631 | state the exclusion of warranty; and each file should have at least
632 | the "copyright" line and a pointer to where the full notice is found.
633 |
634 | <one line to give the program's name and a brief idea of what it does.>
635 | Copyright (C) <year>  <name of author>
636 |
637 | This program is free software: you can redistribute it and/or modify
638 | it under the terms of the GNU General Public License as published by
639 | the Free Software Foundation, either version 3 of the License, or
640 | (at your option) any later version.
641 |
642 | This program is distributed in the hope that it will be useful,
643 | but WITHOUT ANY WARRANTY; without even the implied warranty of
644 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
645 | GNU General Public License for more details.
646 |
647 | You should have received a copy of the GNU General Public License
648 | along with this program. If not, see <https://www.gnu.org/licenses/>.
649 |
650 | Also add information on how to contact you by electronic and paper mail.
651 |
652 | If the program does terminal interaction, make it output a short
653 | notice like this when it starts in an interactive mode:
654 |
655 | <program>  Copyright (C) <year>  <name of author>
656 | This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
657 | This is free software, and you are welcome to redistribute it
658 | under certain conditions; type `show c' for details.
659 |
660 | The hypothetical commands `show w' and `show c' should show the appropriate
661 | parts of the General Public License. Of course, your program's commands
662 | might be different; for a GUI interface, you would use an "about box".
663 |
664 | You should also get your employer (if you work as a programmer) or school,
665 | if any, to sign a "copyright disclaimer" for the program, if necessary.
666 | For more information on this, and how to apply and follow the GNU GPL, see
667 | <https://www.gnu.org/licenses/>.
668 |
669 | The GNU General Public License does not permit incorporating your program
670 | into proprietary programs. If your program is a subroutine library, you
671 | may consider it more useful to permit linking proprietary applications with
672 | the library. If this is what you want to do, use the GNU Lesser General
673 | Public License instead of this License. But first, please read
674 | <https://www.gnu.org/philosophy/why-not-lgpl.html>.
675 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # BERTOS
2 | BERTOS: transformer language model for oxidation state prediction
3 |
4 | Citation: Fu, Nihang, Jeffrey Hu, Ying Feng, Gregory Morrison, Hans‐Conrad zur Loye, and Jianjun Hu. "Composition Based Oxidation State Prediction of Materials Using Deep Learning Language Models." Advanced Science (2023): 2301011. [Link](https://onlinelibrary.wiley.com/doi/full/10.1002/advs.202301011)
5 |
6 |
7 | Nihang Fu, Jeffrey Hu, Ying Feng, Jianjun Hu*
8 |
9 | Machine Learning and Evolution Laboratory
10 | Department of Computer Science and Engineering
11 | University of South Carolina
12 |
13 | [Online Toolbox](http://www.materialsatlas.org/bertos)
14 |
15 | ## Table of Contents
16 | - [Installations](#Installations)
17 |
18 | - [Datasets](#Datasets)
19 |
20 | - [Usage](#Usage)
21 |
22 | - [Pretrained Models](#Pretrained-models)
23 |
24 | - [Performance](#Performance)
25 |
26 | - [Acknowledgement](#Acknowledgement)
27 |
28 | ## Installations
29 |
30 | 0. Set up a virtual environment
31 | ```
32 | conda create -n bertos
33 | conda activate bertos
34 | ```
35 |
36 | 1. Install PyTorch and transformers. For computers with an Nvidia GPU:
37 | ```
38 | conda install pytorch==1.8.0 torchvision==0.9.0 torchaudio==0.8.0 cudatoolkit=11.1 -c pytorch -c conda-forge
39 | conda install -c conda-forge transformers
40 | ```
41 | If your computer only has a CPU, try this:
42 | ```
43 | pip install transformers[torch]
44 | ```
45 | If you are using a Mac with an M1 chip, follow [this tutorial](https://jamescalam.medium.com/hugging-face-and-sentence-transformers-on-m1-macs-4b12e40c21ce) or [this one](https://towardsdatascience.com/hugging-face-transformers-on-apple-m1-26f0705874d7) to install PyTorch and transformers.
46 |
47 | 2. Other packages
48 | ```
49 | pip install -r requirements.txt
50 | ```
51 |
52 | ## Datasets
53 | Our training is carried out on the BERTOS datasets. After extracting the zip files under the `dataset` folder, you will get the following four folders: `ICSD`, `ICSD_CN`, `ICSD_CN_oxide`, and `ICSD_oxide`. A snippet for extracting them programmatically is shown below.
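If you prefer not to unzip by hand, here is a minimal extraction sketch (standard library only; the paths assume the repository layout shown above):
```
import glob
import zipfile

# Extract every dataset archive in place, e.g. dataset/ICSD.zip -> dataset/ICSD/
for path in glob.glob("dataset/*.zip"):
    with zipfile.ZipFile(path) as archive:
        archive.extractall("dataset/")
```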
54 |
55 | ## Usage
56 | ### A Quick Run
57 | Run the script below to quickly train a BERTOS model using the OS-ICSD-CN training set and save it into the `./model_icsdcn` folder.
58 | ```
59 | bash train_BERTOS.sh
60 | ```
61 | ### Training
62 | Use the following command to train a BERTOS model:
63 | ```
64 | python train_BERTOS.py --config_name $CONFIG_NAME$ --dataset_name $DATASET_LOADER$ --max_length $MAX_LENGTH$ --per_device_train_batch_size $BATCH_SIZE$ --learning_rate $LEARNING_RATE$ --num_train_epochs $EPOCHS$ --output_dir $MODEL_OUTPUT_DIRECTORY$
65 | ```
66 | We use the `ICSD_CN` dataset as an example:
67 | ```
68 | python train_BERTOS.py --config_name ./random_config --dataset_name materials_icsd_cn.py --max_length 100 --per_device_train_batch_size 256 --learning_rate 1e-3 --num_train_epochs 500 --output_dir ./model_icsdcn
69 | ```
70 | If you want to change the dataset, replace `$DATASET_LOADER$` with a different dataset file, such as `materials_icsd.py`, `materials_icsd_cn.py`, `materials_icsd_cno.py`, or `materials_icsd_o.py`. You can also follow the instructions of [Hugging Face datasets](https://huggingface.co/docs/datasets) to build your own custom dataset, as sketched below.
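Each provided loader simply points at a folder containing `train.txt`, `validation.txt`, and `test.txt` splits, so a copied loader only needs its paths repointed. A minimal sketch (the folder name `my_dataset` is a hypothetical placeholder):
```
# In a copy of materials_icsd.py, repoint the module-level constants
# at your own dataset folder; "my_dataset" is hypothetical.
_ROOT = "./dataset/my_dataset/"
_TRAINING_FILE = "train.txt"
_DEV_FILE = "validation.txt"
_TEST_FILE = "test.txt"
```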
71 |
72 | ### Predict
73 | Run `getOS.py` to get predicted oxidation states for a single input formula or for an input `formulas.csv` file containing multiple formulas.
74 | Using the default pretrained model (trained on ICSD_CN):
75 | ```
76 | python getOS.py --i SrTiO3 --model_name_or_path ./trained_models/ICSD_CN
77 | python getOS.py --f formulas.csv --model_name_or_path ./trained_models/ICSD_CN
78 | ```
79 | Or using your own model:
80 | ```
81 | python getOS.py --i SrTiO3 --model_name_or_path ./model_directory
82 | python getOS.py --f formulas.csv --model_name_or_path ./model_directory
83 |
84 | ```
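Either command prints the predictions to the console. For a single formula, the output should look similar to the following (the probabilities come from the model's softmax and depend on the checkpoint used):
```
Input formula ------->  SrTiO3
Predicted Oxidation States:
  Sr(+2:1.00) Ti(+4:1.00) O3(-2:1.00)
```
For a csv input, the predictions are instead written to a new file with `_OS` appended to the input filename (e.g. `formulas_OS.csv`).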
85 |
86 | ### Check charge neutrality for hypothetical formulas
87 | Run `checkCN.py` to check charge neutrality for a single input formula or for an input `formulas.csv` file containing multiple formulas.
88 | Using the default pretrained model (trained on ICSD_CN):
89 | ```
90 | python checkCN.py --i SrTiO3
91 | python checkCN.py --f formulas.csv
92 | ```
93 | Or using your own model:
94 | ```
95 | python checkCN.py --i SrTiO3 --model_name_or_path ./model_directory
96 | python checkCN.py --f formulas.csv --model_name_or_path ./model_directory
97 | ```
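In addition to the predicted oxidation states, the script reports whether they sum to zero. For a single formula, the expected output looks like this (probabilities depend on the checkpoint):
```
Input formula ------->  SrTiO3
Predicted Oxidation States:
  Sr(+2:1.00) Ti(+4:1.00) O3(-2:1.00)
Charge Neutral? Yes
```
For a csv input, results are saved to a file with `_OS_CN` appended to the input filename, with columns `formula`, `predicted OS`, and `charge neutrality`.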
98 |
99 | ## Pretrained Models
100 | Our trained models can be downloaded from Figshare: [BERTOS models](https://figshare.com/articles/online_resource/BERTOS_model/21554823). You can use them for testing or prediction.
101 |
102 |
103 | ## Performance
104 |
105 | 
106 | The datasets under the `dataset` folder correspond to the datasets in the figure, with the `OS-` prefix removed.
107 |
108 | ## Acknowledgement
109 | We use the transformer model as implemented in the Hugging Face `transformers` library.
110 | ```
111 | @article{wolf2019huggingface,
112 | title={Huggingface's transformers: State-of-the-art natural language processing},
113 | author={Wolf, Thomas and Debut, Lysandre and Sanh, Victor and Chaumond, Julien and Delangue, Clement and Moi, Anthony and Cistac, Pierric and Rault, Tim and Louf, R{\'e}mi and Funtowicz, Morgan and others},
114 | journal={arXiv preprint arXiv:1910.03771},
115 | year={2019}
116 | }
117 | ```
118 |
119 | ## Cite our work
120 | ```
121 | Fu, Nihang, Jeffrey Hu, Ying Feng, Gregory Morrison, Hans‐Conrad zur Loye, and Jianjun Hu. "Composition Based Oxidation State Prediction of Materials Using Deep Learning Language Models." Advanced Science (2023): 2301011. [PDF](https://arxiv.org/pdf/2211.15895)
122 |
123 | ```
124 |
125 | # Contact
126 | If you have any problems using BERTOS, feel free to contact us at [funihang@gmail.com](mailto:funihang@gmail.com).
127 |
--------------------------------------------------------------------------------
/checkCN.py:
--------------------------------------------------------------------------------
1 | # for a formula: python checkCN.py --i SO2
2 | # for a csv file containing multiple formulas: python checkCN.py --f formulas.csv
3 |
4 | import argparse
5 | import json
6 | import logging
7 | import os
8 | import torch
9 |
10 | import transformers
11 | from transformers import (
12 | AutoConfig,
13 | AutoModelForTokenClassification,
14 | )
15 | from transformers import BertTokenizerFast
16 |
17 | import numpy as np
18 |
19 | from pymatgen.io.cif import CifParser
20 | from pymatgen.core.composition import Composition
21 | from pymatgen.core.structure import Structure
22 | from pymatgen.core.periodic_table import Element
23 |
24 | import torch.nn.functional as F
25 |
26 | import pandas as pd
27 |
28 | def merge_os(osstr):
29 | # Merge repeated element tokens, e.g. "Sr(+2:1.00) Ti(+4:1.00) O(-2:1.00) O(-2:1.00) O(-2:1.00)" -> "Sr(+2:1.00) Ti(+4:1.00) O3(-2:1.00)"
30 | items = osstr.split(" ")
31 | elementos={}
32 | for x in items:
33 | if x in elementos:
34 | elementos[x]+=1
35 | else:
36 | elementos[x]=1
37 | out=''
38 | for x in elementos:
39 | if elementos[x]==1:
40 | out+=x+" "
41 | else:
42 | e=x.split('(')[0]
43 | out+=f'{e}{elementos[x]}({"".join(x.split("(")[1:])} '
44 | return out.strip()
45 |
46 | #import pymatgen
47 |
48 | def parse_args():
49 | parser = argparse.ArgumentParser(
50 | description="Test trained model."
51 | )
52 | parser.add_argument(
53 | "--i",
54 | type=str,
55 | default=None,
56 | help="Input formula",
57 | )
58 |
59 | parser.add_argument(
60 | "--f",
61 | type=str,
62 | default=None,
63 | help="Input file",
64 | )
65 |
66 | parser.add_argument(
67 | "--max_length",
68 | type=int,
69 | default=50,
70 | help=(
71 | "The maximum total input sequence length after tokenization. Sequences longer than this will be truncated,"
72 | " sequences shorter will be padded if `--pad_to_max_length` is passed."
73 | ),
74 | )
75 |
76 | parser.add_argument(
77 | "--model_name_or_path",
78 | type=str,
79 | default='./trained_models/ICSD_CN/',
80 | help="Path to pretrained model or model identifier from huggingface.co/models.",
81 | required=False,
82 | )
83 |
84 | parser.add_argument(
85 | "--tokenizer_name",
86 | type=str,
87 | default='./tokenizer',
88 | help="Pretrained tokenizer name or path if not the same as model_name",
89 | )
90 |
91 | parser.add_argument(
92 | "--ignore_mismatched_sizes",
93 | action="store_true",
94 | default=True,
95 | help="ignore_mismatched_sizes set to True by default.",
96 | )
97 |
98 | parser.add_argument(
99 | "--pad_to_max_length",
100 | action="store_true",
101 | help="If passed, pad all samples to `max_length`. Otherwise, dynamic padding is used.",
102 | )
103 | args = parser.parse_args()
104 | return args
105 |
106 | def main():
107 | args = parse_args()
108 |
109 | # Load tokenizer
110 | tokenizer_name_or_path = args.tokenizer_name
111 | tokenizer = BertTokenizerFast.from_pretrained(tokenizer_name_or_path, do_lower_case=False)
112 |
113 | padding = "max_length" if args.pad_to_max_length else False
114 |
115 | # Load model config
116 | config = AutoConfig.from_pretrained(args.model_name_or_path, num_labels=14)  # 14 labels: oxidation states -5 to +8
117 |
118 | # Load model
119 | model = AutoModelForTokenClassification.from_pretrained(
120 | args.model_name_or_path,
121 | from_tf=bool(".ckpt" in args.model_name_or_path),
122 | config=config,
123 | ignore_mismatched_sizes=args.ignore_mismatched_sizes,
124 | )
125 | model.eval()
126 | if (args.i is not None) and (args.f is not None):
127 | print("Please input a formula (using --i) or give the csv file with some formulas (using --f)")
128 | return
129 |
130 | if args.i is not None:
131 | print("Input formula -------> ", args.i)
132 | comp = Composition(args.i)
133 | comp_dict = comp.to_reduced_dict
134 |
135 | input_seq = ""
136 | for ele in comp_dict.keys():
137 | for count in range(int(comp_dict[ele])):
138 | input_seq = input_seq + ele + " "
139 |
140 |
141 | tokenized_inputs = torch.tensor(tokenizer.encode(input_seq, add_special_tokens=True)).unsqueeze(0) # Batch size 1
142 |
143 | outputs = model(tokenized_inputs)
144 | predictions = outputs.logits.argmax(dim=-1)
145 | probs = torch.max(F.softmax(outputs[0], dim=-1), dim=-1)
146 |
147 |
148 | true_pred = predictions[0][1:-1]
149 | true_probs = probs[0][0][1:-1]
150 |
151 |
152 | tmp = input_seq.split()
153 | outstr = ''
154 | count_cn = 0
155 | for i, ele in enumerate(tmp):
156 | outstr += ele
157 | true_os = true_pred[i].item() - 5  # label indices 0..13 map to oxidation states -5..+8
158 | count_cn += true_os  # running sum of predicted oxidation states; zero means charge neutral
159 | if true_os>0:
160 | true_os='+'+str(true_os)
161 | prob = true_probs[i].item()
162 |
163 | outstr = outstr +f'({true_os}:{prob:.2f}) '
164 | outstr=merge_os(outstr)
165 |
166 | print("Predicted Oxidation States:\n ", outstr)
167 |
168 | if count_cn == 0:
169 | print("Charge Neutral? Yes")
170 | else:
171 | print("Charge Neutral? No")
172 |
173 | if args.f is not None:
174 | print("Input file ------->", args.f)
175 | df = pd.read_csv(args.f, header=None)
176 | formulas = df[0]
177 |
178 | all_outs = []
179 | for item in formulas:
180 | comp = Composition(item)
181 | comp_dict = comp.to_reduced_dict
182 |
183 | input_seq = ""
184 | for ele in comp_dict.keys():
185 | for count in range(int(comp_dict[ele])):
186 | input_seq = input_seq + ele + " "
187 |
188 | tokenized_inputs = torch.tensor(tokenizer.encode(input_seq, add_special_tokens=True)).unsqueeze(0) # Batch size 1
189 |
190 |
191 | outputs = model(tokenized_inputs)
192 | predictions = outputs.logits.argmax(dim=-1)
193 | probs = torch.max(F.softmax(outputs[0], dim=-1), dim=-1)
194 |
195 |
196 | true_pred = predictions[0][1:-1]
197 | true_probs = probs[0][0][1:-1]
198 |
199 | tmp = input_seq.split()
200 |
201 | cn_count = 0
202 | outstr = ''
203 | for i, ele in enumerate(tmp):
204 | outstr += ele
205 | true_os = true_pred[i].item() - 5
206 |
207 | cn_count += true_os
208 |
209 | if true_os>0:
210 | true_os='+'+str(true_os)
211 | prob = true_probs[i].item()
212 |
213 | outstr = outstr +f'({true_os}:{prob:.2f}) '
214 | outstr=merge_os(outstr)
215 |
216 | if cn_count == 0:
217 | all_outs.append([item, outstr, "True"])
218 | else:
219 | all_outs.append([item, outstr, "False"])
220 |
221 | out_df = pd.DataFrame(all_outs)
222 | out_df.columns = ["formula", "predicted OS", "charge neutrality"]
223 |
224 | #add _OS to the input filename as output file
225 | outfile='.'.join(args.f.split(".")[0:-1])+"_OS_CN."+args.f.split(".")[-1]
226 |
227 | out_df.to_csv(outfile, index=None)
228 | print("Output file ------>",f"{outfile} <-- check for the predicted oxidation states")
229 |
230 |
231 | if __name__ == "__main__":
232 | main()
233 |
--------------------------------------------------------------------------------
/dataset/ICSD.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/usccolumbia/BERTOS/ab4e4f31b09543a7949f36981f1c87b9ff41bb74/dataset/ICSD.zip
--------------------------------------------------------------------------------
/dataset/ICSD_CN.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/usccolumbia/BERTOS/ab4e4f31b09543a7949f36981f1c87b9ff41bb74/dataset/ICSD_CN.zip
--------------------------------------------------------------------------------
/dataset/ICSD_CN_oxide.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/usccolumbia/BERTOS/ab4e4f31b09543a7949f36981f1c87b9ff41bb74/dataset/ICSD_CN_oxide.zip
--------------------------------------------------------------------------------
/dataset/ICSD_oxide.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/usccolumbia/BERTOS/ab4e4f31b09543a7949f36981f1c87b9ff41bb74/dataset/ICSD_oxide.zip
--------------------------------------------------------------------------------
/dataset/README.md:
--------------------------------------------------------------------------------
1 |
2 | Just double-click the zip files to unzip the datasets.
3 |
--------------------------------------------------------------------------------
/formulas.csv:
--------------------------------------------------------------------------------
1 | SrTiO3
2 | LiMnO3
3 | Te2As
4 | CdP3Sr3
5 |
--------------------------------------------------------------------------------
/getOS.py:
--------------------------------------------------------------------------------
1 | # for a formula: python getOS.py --i SO2
2 | # for a csv file containing multiple formulas: python getOS.py --f formulas.csv
3 |
4 | import argparse
5 | import json
6 | import logging
7 | import os
8 | import torch
9 |
10 | import transformers
11 | from transformers import (
12 | AutoConfig,
13 | AutoModelForTokenClassification,
14 | )
15 | from transformers import BertTokenizerFast
16 |
17 | import numpy as np
18 |
19 | from pymatgen.io.cif import CifParser
20 | from pymatgen.core.composition import Composition
21 | from pymatgen.core.structure import Structure
22 | from pymatgen.core.periodic_table import Element
23 |
24 | import torch.nn.functional as F
25 |
26 | import pandas as pd
27 |
28 | def merge_os(osstr):
29 | # Merge repeated element tokens, e.g. "Sr(+2:1.00) Ti(+4:1.00) O(-2:1.00) O(-2:1.00) O(-2:1.00)" -> "Sr(+2:1.00) Ti(+4:1.00) O3(-2:1.00)"
30 | items = osstr.split(" ")
31 | elementos={}
32 | for x in items:
33 | if x in elementos:
34 | elementos[x]+=1
35 | else:
36 | elementos[x]=1
37 | out=''
38 | for x in elementos:
39 | if elementos[x]==1:
40 | out+=x+" "
41 | else:
42 | e=x.split('(')[0]
43 | out+=f'{e}{elementos[x]}({"".join(x.split("(")[1:])} '
44 | return out.strip()
45 |
46 | #import pymatgen
47 |
48 | def parse_args():
49 | parser = argparse.ArgumentParser(
50 | description="Test trained model."
51 | )
52 | parser.add_argument(
53 | "--i",
54 | type=str,
55 | default=None,
56 | help="Input formula",
57 | )
58 |
59 | parser.add_argument(
60 | "--f",
61 | type=str,
62 | default=None,
63 | help="Input file",
64 | )
65 |
66 | parser.add_argument(
67 | "--max_length",
68 | type=int,
69 | default=50,
70 | help=(
71 | "The maximum total input sequence length after tokenization. Sequences longer than this will be truncated,"
72 | " sequences shorter will be padded if `--pad_to_max_length` is passed."
73 | ),
74 | )
75 |
76 | parser.add_argument(
77 | "--model_name_or_path",
78 | type=str,
79 | default='./trained_models/ICSD_CN/',
80 | help="Path to pretrained model or model identifier from huggingface.co/models.",
81 | required=False,
82 | )
83 |
84 | parser.add_argument(
85 | "--tokenizer_name",
86 | type=str,
87 | default='./tokenizer',
88 | help="Pretrained tokenizer name or path if not the same as model_name",
89 | )
90 |
91 | parser.add_argument(
92 | "--ignore_mismatched_sizes",
93 | action="store_true",
94 | default=True,
95 | help="ignore_mismatched_sizes set to True by default.",
96 | )
97 |
98 | parser.add_argument(
99 | "--pad_to_max_length",
100 | action="store_true",
101 | help="If passed, pad all samples to `max_length`. Otherwise, dynamic padding is used.",
102 | )
103 | args = parser.parse_args()
104 | return args
105 |
106 | def main():
107 | args = parse_args()
108 |
109 | # Load tokenizer
110 | tokenizer_name_or_path = args.tokenizer_name
111 | tokenizer = BertTokenizerFast.from_pretrained(tokenizer_name_or_path, do_lower_case=False)
112 |
113 | padding = "max_length" if args.pad_to_max_length else False
114 |
115 | # Load model config
116 | config = AutoConfig.from_pretrained(args.model_name_or_path, num_labels=14)  # 14 labels: oxidation states -5 to +8
117 |
118 | # Load model
119 | model = AutoModelForTokenClassification.from_pretrained(
120 | args.model_name_or_path,
121 | from_tf=bool(".ckpt" in args.model_name_or_path),
122 | config=config,
123 | ignore_mismatched_sizes=args.ignore_mismatched_sizes,
124 | )
125 | model.eval()  # disable dropout for inference (as done in checkCN.py)
126 | if (args.i is not None) and (args.f is not None):
127 | print("Please input a formula (using --i) or give the csv file with some formulas (using --f)")
128 | return
129 |
130 | if args.i is not None:
131 | print("Input formula -------> ", args.i)
132 | comp = Composition(args.i)
133 | comp_dict = comp.to_reduced_dict
134 |
135 | input_seq = ""
136 | for ele in comp_dict.keys():
137 | for count in range(int(comp_dict[ele])):
138 | input_seq = input_seq + ele + " "
139 |
140 |
141 | tokenized_inputs = torch.tensor(tokenizer.encode(input_seq, add_special_tokens=True)).unsqueeze(0) # Batch size 1
142 |
143 | outputs = model(tokenized_inputs)
144 | predictions = outputs.logits.argmax(dim=-1)
145 | probs = torch.max(F.softmax(outputs[0], dim=-1), dim=-1)
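# torch.max over the softmax returns a (values, indices) pair; probs[0][0]
# below is therefore the per-token confidence of the argmax class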
146 |
147 |
148 | true_pred = predictions[0][1:-1]
149 | true_probs = probs[0][0][1:-1]
150 |
151 |
152 | tmp = input_seq.split()
153 | outstr = ''
154 | for i, ele in enumerate(tmp):
155 | outstr += ele
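# class index -> oxidation state: the 14 labels cover -5..+8 (see the
# materials_icsd*.py label lists), hence the fixed offset of 5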
156 | true_os = true_pred[i].item() - 5
157 | if true_os>0:
158 | true_os='+'+str(true_os)
159 | prob = true_probs[i].item()
160 |
161 | outstr = outstr +f'({true_os}:{prob:.2f}) '
162 | outstr=merge_os(outstr)
163 | print("Predicted Oxidation States:\n ", outstr)
164 |
165 | if args.f is not None:
166 | print("Input file ------->", args.f)
167 | df = pd.read_csv(args.f, header=None)
168 | formulas = df[0]
169 |
170 | all_outs = []
171 | for item in formulas:
172 | comp = Composition(item)
173 | comp_dict = comp.to_reduced_dict
174 |
175 | input_seq = ""
176 | for ele in comp_dict.keys():
177 | for count in range(int(comp_dict[ele])):
178 | input_seq = input_seq + ele + " "
179 |
180 | tokenized_inputs = torch.tensor(tokenizer.encode(input_seq, add_special_tokens=True)).unsqueeze(0) # Batch size 1
181 |
182 |
183 | outputs = model(tokenized_inputs)
184 | predictions = outputs.logits.argmax(dim=-1)
185 | probs = torch.max(F.softmax(outputs[0], dim=-1), dim=-1)
186 |
187 |
188 | true_pred = predictions[0][1:-1]
189 | true_probs = probs[0][0][1:-1]
190 |
191 | tmp = input_seq.split()
192 | outstr = ''
193 | for i, ele in enumerate(tmp):
194 | outstr += ele
195 | true_os = true_pred[i].item() - 5
196 | if true_os>0:
197 | true_os='+'+str(true_os)
198 | prob = true_probs[i].item()
199 |
200 | outstr = outstr +f'({true_os}:{prob:.2f}) '
201 | outstr=merge_os(outstr)
202 |
203 |
204 | all_outs.append(outstr)
205 |
206 | out_df = pd.DataFrame(all_outs)
207 |
208 | #add _OS to the input filename as output file
209 | outfile='.'.join(args.f.split(".")[0:-1])+"_OS."+args.f.split(".")[-1]
210 |
211 | out_df.to_csv(outfile, header=None, index=None)
212 | print("Output file ------>",f"{outfile} <-- check for the predicted oxidation states")
213 |
214 |
215 | if __name__ == "__main__":
216 | main()
217 |
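# Example invocations (unzip the matching model under ./trained_models/ first,
# see trained_models/README.md; the default model path is ./trained_models/ICSD_CN/):
#   python getOS.py --i SrTiO3
#   python getOS.py --f formulas.csv --model_name_or_path ./trained_models/ICSD/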
--------------------------------------------------------------------------------
/materials_icsd.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Materials dataset"""
16 |
17 | import os
18 | import datasets
19 |
20 |
21 | logger = datasets.logging.get_logger(__name__)
22 |
23 |
24 | _CITATION = """
25 | """
26 |
27 | _DESCRIPTION = """
28 | """
29 |
30 | _ROOT = "./dataset/ICSD/"
31 | _TRAINING_FILE = "train.txt"
32 | _DEV_FILE = "validation.txt"
33 | _TEST_FILE = "test.txt"
34 |
35 |
36 | class Materials(datasets.GeneratorBasedBuilder):
37 | """Materials dataset"""
38 |
39 | VERSION = datasets.Version("1.0.0")
40 |
41 | BUILDER_CONFIGS = [
42 | datasets.BuilderConfig(name="materials", version=VERSION, description="Materials dataset"),
43 | ]
44 |
45 | def _info(self):
46 | return datasets.DatasetInfo(
47 | description=_DESCRIPTION,
48 | features=datasets.Features(
49 | {
50 | "id": datasets.Value("string"),
51 | "tokens": datasets.Sequence(datasets.Value("string")),
52 | "ner_tags": datasets.Sequence(
53 | datasets.features.ClassLabel(
54 | names=[
55 | "-5",
56 | "-4",
57 | "-3",
58 | "-2",
59 | "-1",
60 | "0",
61 | "1",
62 | "2",
63 | "3",
64 | "4",
65 | "5",
66 | "6",
67 | "7",
68 | "8",
69 | ]
70 | )
71 | ),
72 | }
73 | ),
74 | supervised_keys=None,
75 | homepage="https://github.com/usccolumbia/BERTOS.git",
76 | citation=_CITATION,
77 | )
78 |
79 | def _split_generators(self, dl_manager):
80 | """Returns SplitGenerators."""
81 |
82 | data_files = {
83 | "train": os.path.join(_ROOT, _TRAINING_FILE),
84 | "validation": os.path.join(_ROOT, _DEV_FILE),
85 | "test": os.path.join(_ROOT, _TEST_FILE),
86 | }
87 |
88 | return [
89 | datasets.SplitGenerator(
90 | name=datasets.Split.TRAIN,
91 | gen_kwargs={"filepath": data_files["train"], "split": "train"},
92 | ),
93 | datasets.SplitGenerator(
94 | name=datasets.Split.VALIDATION,
95 | gen_kwargs={"filepath": data_files["validation"], "split": "validation"},
96 | ),
97 | datasets.SplitGenerator(
98 | name=datasets.Split.TEST,
99 | gen_kwargs={"filepath": data_files["test"], "split": "test"},
100 | ),
101 | ]
102 |
103 | def _generate_examples(self, filepath, split):
104 | """Yields examples."""
105 |
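# The split files are expected in a CoNLL-style format: one "Element tag" pair
# per line, where the tag is the oxidation-state string ("-5".."8"), and a
# blank line separates formulas. A hypothetical SrTiO3 entry:
#   Sr 2
#   Ti 4
#   O -2
#   O -2
#   O -2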
106 | with open(filepath, encoding="utf-8") as f:
107 |
108 | guid = 0
109 | tokens = []
110 | ner_tags = []
111 |
112 | for line in f:
113 | if line == "" or line == "\n":
114 | if tokens:
115 | yield guid, {
116 | "id": str(guid),
117 | "tokens": tokens,
118 | "ner_tags": ner_tags,
119 | }
120 | guid += 1
121 | tokens = []
122 | ner_tags = []
123 | else:
124 | splits = line.split(" ")
125 | tokens.append(splits[0])
126 | ner_tags.append(splits[1].rstrip())
127 |
128 | # last example (skipped if the file ends with a blank line)
129 | if tokens:
130 | yield guid, {
131 | "id": str(guid),
132 | "tokens": tokens,
133 | "ner_tags": ner_tags,
134 | }
--------------------------------------------------------------------------------
/materials_icsd_cn.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Materials dataset"""
16 |
17 | import os
18 | import datasets
19 |
20 |
21 | logger = datasets.logging.get_logger(__name__)
22 |
23 |
24 | _CITATION = """
25 | """
26 |
27 | _DESCRIPTION = """
28 | """
29 |
30 | _ROOT = "./dataset/ICSD_CN/"
31 | _TRAINING_FILE = "train.txt"
32 | _DEV_FILE = "validation.txt"
33 | _TEST_FILE = "test.txt"
34 |
35 |
36 | class Materials(datasets.GeneratorBasedBuilder):
37 | """Materials dataset"""
38 |
39 | VERSION = datasets.Version("1.0.0")
40 |
41 | BUILDER_CONFIGS = [
42 | datasets.BuilderConfig(name="materials", version=VERSION, description="Materials dataset"),
43 | ]
44 |
45 | def _info(self):
46 | return datasets.DatasetInfo(
47 | description=_DESCRIPTION,
48 | features=datasets.Features(
49 | {
50 | "id": datasets.Value("string"),
51 | "tokens": datasets.Sequence(datasets.Value("string")),
52 | "ner_tags": datasets.Sequence(
53 | datasets.features.ClassLabel(
54 | names=[
55 | "-5",
56 | "-4",
57 | "-3",
58 | "-2",
59 | "-1",
60 | "0",
61 | "1",
62 | "2",
63 | "3",
64 | "4",
65 | "5",
66 | "6",
67 | "7",
68 | "8",
69 | ]
70 | )
71 | ),
72 | }
73 | ),
74 | supervised_keys=None,
75 | homepage="https://github.com/usccolumbia/BERTOS.git",
76 | citation=_CITATION,
77 | )
78 |
79 | def _split_generators(self, dl_manager):
80 | """Returns SplitGenerators."""
81 |
82 | data_files = {
83 | "train": os.path.join(_ROOT, _TRAINING_FILE),
84 | "validation": os.path.join(_ROOT, _DEV_FILE),
85 | "test": os.path.join(_ROOT, _TEST_FILE),
86 | }
87 |
88 | return [
89 | datasets.SplitGenerator(
90 | name=datasets.Split.TRAIN,
91 | gen_kwargs={"filepath": data_files["train"], "split": "train"},
92 | ),
93 | datasets.SplitGenerator(
94 | name=datasets.Split.VALIDATION,
95 | gen_kwargs={"filepath": data_files["validation"], "split": "validation"},
96 | ),
97 | datasets.SplitGenerator(
98 | name=datasets.Split.TEST,
99 | gen_kwargs={"filepath": data_files["test"], "split": "test"},
100 | ),
101 | ]
102 |
103 | def _generate_examples(self, filepath, split):
104 | """Yields examples."""
105 |
106 | with open(filepath, encoding="utf-8") as f:
107 |
108 | guid = 0
109 | tokens = []
110 | ner_tags = []
111 |
112 | for line in f:
113 | if line == "" or line == "\n":
114 | if tokens:
115 | yield guid, {
116 | "id": str(guid),
117 | "tokens": tokens,
118 | "ner_tags": ner_tags,
119 | }
120 | guid += 1
121 | tokens = []
122 | ner_tags = []
123 | else:
124 | splits = line.split(" ")
125 | tokens.append(splits[0])
126 | ner_tags.append(splits[1].rstrip())
127 |
128 | # last example (skipped if the file ends with a blank line)
129 | if tokens:
130 | yield guid, {
131 | "id": str(guid),
132 | "tokens": tokens,
133 | "ner_tags": ner_tags,
134 | }
--------------------------------------------------------------------------------
/materials_icsd_cno.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Materials dataset"""
16 |
17 | import os
18 | import datasets
19 |
20 |
21 | logger = datasets.logging.get_logger(__name__)
22 |
23 |
24 | _CITATION = """
25 | """
26 |
27 | _DESCRIPTION = """
28 | """
29 |
30 | _ROOT = "./dataset/ICSD_CN_oxide/"
31 | _TRAINING_FILE = "train.txt"
32 | _DEV_FILE = "validation.txt"
33 | _TEST_FILE = "test.txt"
34 |
35 |
36 | class Materials(datasets.GeneratorBasedBuilder):
37 | """Materials dataset"""
38 |
39 | VERSION = datasets.Version("1.0.0")
40 |
41 | BUILDER_CONFIGS = [
42 | datasets.BuilderConfig(name="materials", version=VERSION, description="Materials dataset"),
43 | ]
44 |
45 | def _info(self):
46 | return datasets.DatasetInfo(
47 | description=_DESCRIPTION,
48 | features=datasets.Features(
49 | {
50 | "id": datasets.Value("string"),
51 | "tokens": datasets.Sequence(datasets.Value("string")),
52 | "ner_tags": datasets.Sequence(
53 | datasets.features.ClassLabel(
54 | names=[
55 | "-5",
56 | "-4",
57 | "-3",
58 | "-2",
59 | "-1",
60 | "0",
61 | "1",
62 | "2",
63 | "3",
64 | "4",
65 | "5",
66 | "6",
67 | "7",
68 | "8",
69 | ]
70 | )
71 | ),
72 | }
73 | ),
74 | supervised_keys=None,
75 | homepage="https://github.com/usccolumbia/BERTOS.git",
76 | citation=_CITATION,
77 | )
78 |
79 | def _split_generators(self, dl_manager):
80 | """Returns SplitGenerators."""
81 |
82 | data_files = {
83 | "train": os.path.join(_ROOT, _TRAINING_FILE),
84 | "validation": os.path.join(_ROOT, _DEV_FILE),
85 | "test": os.path.join(_ROOT, _TEST_FILE),
86 | }
87 |
88 | return [
89 | datasets.SplitGenerator(
90 | name=datasets.Split.TRAIN,
91 | gen_kwargs={"filepath": data_files["train"], "split": "train"},
92 | ),
93 | datasets.SplitGenerator(
94 | name=datasets.Split.VALIDATION,
95 | gen_kwargs={"filepath": data_files["validation"], "split": "validation"},
96 | ),
97 | datasets.SplitGenerator(
98 | name=datasets.Split.TEST,
99 | gen_kwargs={"filepath": data_files["test"], "split": "test"},
100 | ),
101 | ]
102 |
103 | def _generate_examples(self, filepath, split):
104 | """Yields examples."""
105 |
106 | # logger.info("Generating examples from = %s", filepath)
107 |
108 | with open(filepath, encoding="utf-8") as f:
109 |
110 | guid = 0
111 | tokens = []
112 | ner_tags = []
113 |
114 | for line in f:
115 | if line == "" or line == "\n":
116 | if tokens:
117 | yield guid, {
118 | "id": str(guid),
119 | "tokens": tokens,
120 | "ner_tags": ner_tags,
121 | }
122 | guid += 1
123 | tokens = []
124 | ner_tags = []
125 | else:
126 | splits = line.split(" ")
127 | tokens.append(splits[0])
128 | ner_tags.append(splits[1].rstrip())
129 |
130 | # last example (skipped if the file ends with a blank line)
131 | if tokens:
132 | yield guid, {
133 | "id": str(guid),
134 | "tokens": tokens,
135 | "ner_tags": ner_tags,
136 | }
--------------------------------------------------------------------------------
/materials_icsd_o.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Materials dataset"""
16 |
17 | import os
18 | import datasets
19 |
20 |
21 | logger = datasets.logging.get_logger(__name__)
22 |
23 |
24 | _CITATION = """
25 | """
26 |
27 | _DESCRIPTION = """
28 | """
29 |
30 | _ROOT = "./dataset/ICSD_oxide/"
31 | _TRAINING_FILE = "train.txt"
32 | _DEV_FILE = "validation.txt"
33 | _TEST_FILE = "test.txt"
34 |
35 |
36 | class Materials(datasets.GeneratorBasedBuilder):
37 | """Materials dataset"""
38 |
39 | VERSION = datasets.Version("1.0.0")
40 |
41 | BUILDER_CONFIGS = [
42 | datasets.BuilderConfig(name="materials", version=VERSION, description="Materials dataset"),
43 | ]
44 |
45 | def _info(self):
46 | return datasets.DatasetInfo(
47 | description=_DESCRIPTION,
48 | features=datasets.Features(
49 | {
50 | "id": datasets.Value("string"),
51 | "tokens": datasets.Sequence(datasets.Value("string")),
52 | "ner_tags": datasets.Sequence(
53 | datasets.features.ClassLabel(
54 | names=[
55 | "-5",
56 | "-4",
57 | "-3",
58 | "-2",
59 | "-1",
60 | "0",
61 | "1",
62 | "2",
63 | "3",
64 | "4",
65 | "5",
66 | "6",
67 | "7",
68 | "8",
69 | ]
70 | )
71 | ),
72 | }
73 | ),
74 | supervised_keys=None,
75 | homepage="https://github.com/usccolumbia/BERTOS.git",
76 | citation=_CITATION,
77 | )
78 |
79 | def _split_generators(self, dl_manager):
80 | """Returns SplitGenerators."""
81 |
82 | data_files = {
83 | "train": os.path.join(_ROOT, _TRAINING_FILE),
84 | "validation": os.path.join(_ROOT, _DEV_FILE),
85 | "test": os.path.join(_ROOT, _TEST_FILE),
86 | }
87 |
88 | return [
89 | datasets.SplitGenerator(
90 | name=datasets.Split.TRAIN,
91 | gen_kwargs={"filepath": data_files["train"], "split": "train"},
92 | ),
93 | datasets.SplitGenerator(
94 | name=datasets.Split.VALIDATION,
95 | gen_kwargs={"filepath": data_files["validation"], "split": "validation"},
96 | ),
97 | datasets.SplitGenerator(
98 | name=datasets.Split.TEST,
99 | gen_kwargs={"filepath": data_files["test"], "split": "test"},
100 | ),
101 | ]
102 |
103 | def _generate_examples(self, filepath, split):
104 | """Yields examples."""
105 |
106 | with open(filepath, encoding="utf-8") as f:
107 |
108 | guid = 0
109 | tokens = []
110 | ner_tags = []
111 |
112 | for line in f:
113 | if line == "" or line == "\n":
114 | if tokens:
115 | yield guid, {
116 | "id": str(guid),
117 | "tokens": tokens,
118 | "ner_tags": ner_tags,
119 | }
120 | guid += 1
121 | tokens = []
122 | ner_tags = []
123 | else:
124 | splits = line.split(" ")
125 | tokens.append(splits[0])
126 | ner_tags.append(splits[1].rstrip())
127 |
128 | # last example (skipped if the file ends with a blank line)
129 | if tokens:
130 | yield guid, {
131 | "id": str(guid),
132 | "tokens": tokens,
133 | "ner_tags": ner_tags,
134 | }
--------------------------------------------------------------------------------
/performances.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/usccolumbia/BERTOS/ab4e4f31b09543a7949f36981f1c87b9ff41bb74/performances.png
--------------------------------------------------------------------------------
/random_config/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "architectures": [
3 | "BertForMaskedLM"
4 | ],
5 | "attention_probs_dropout_prob": 0.1,
6 | "classifier_dropout": null,
7 | "hidden_act": "gelu",
8 | "hidden_dropout_prob": 0.1,
9 | "hidden_size": 120,
10 | "initializer_range": 0.02,
11 | "intermediate_size": 512,
12 | "layer_norm_eps": 1e-12,
13 | "max_position_embeddings": 100,
14 | "model_type": "bert",
15 | "num_attention_heads": 4,
16 | "num_hidden_layers": 12,
17 | "pad_token_id": 0,
18 | "position_embedding_type": "absolute",
19 | "torch_dtype": "float32",
20 | "transformers_version": "4.23.0.dev0",
21 | "type_vocab_size": 2,
22 | "use_cache": true,
23 | "vocab_size": 123
24 | }
25 |
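Note: this config defines the randomly initialized BERT encoder that train_BERTOS.sh
passes via --config_name. vocab_size 123 matches tokenizer/vocab.txt (5 special
tokens + 118 element symbols), and max_position_embeddings 100 matches the
--max_length 100 used for training. A minimal sketch of how train_BERTOS.py
consumes it (num_labels=14 comes from the dataset's 14 oxidation-state classes):

    from transformers import AutoConfig, AutoModelForTokenClassification

    config = AutoConfig.from_pretrained("./random_config/", num_labels=14)
    model = AutoModelForTokenClassification.from_config(config)  # fresh, untrained weights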
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | pandas==1.3.5
2 | argparse==1.4.0
3 | pymatgen==2022.0.17
4 | datasets==2.5.1
5 | tqdm==4.64.1
6 | accelerate==0.12.0
7 | evaluate==0.2.2
8 | transformers
9 | seqeval
10 | tensorboard
11 |
--------------------------------------------------------------------------------
/tokenizer/vocab.txt:
--------------------------------------------------------------------------------
1 | [PAD]
2 | [UNK]
3 | [CLS]
4 | [SEP]
5 | [MASK]
6 | H
7 | He
8 | Li
9 | Be
10 | B
11 | C
12 | N
13 | O
14 | F
15 | Ne
16 | Na
17 | Mg
18 | Al
19 | Si
20 | P
21 | S
22 | Cl
23 | Ar
24 | K
25 | Ca
26 | Sc
27 | Ti
28 | V
29 | Cr
30 | Mn
31 | Fe
32 | Co
33 | Ni
34 | Cu
35 | Zn
36 | Ga
37 | Ge
38 | As
39 | Se
40 | Br
41 | Kr
42 | Rb
43 | Sr
44 | Y
45 | Zr
46 | Nb
47 | Mo
48 | Tc
49 | Ru
50 | Rh
51 | Pd
52 | Ag
53 | Cd
54 | In
55 | Sn
56 | Sb
57 | Te
58 | I
59 | Xe
60 | Cs
61 | Ba
62 | La
63 | Ce
64 | Pr
65 | Nd
66 | Pm
67 | Sm
68 | Eu
69 | Gd
70 | Tb
71 | Dy
72 | Ho
73 | Er
74 | Tm
75 | Yb
76 | Lu
77 | Hf
78 | Ta
79 | W
80 | Re
81 | Os
82 | Ir
83 | Pt
84 | Au
85 | Hg
86 | Tl
87 | Pb
88 | Bi
89 | Po
90 | At
91 | Rn
92 | Fr
93 | Ra
94 | Ac
95 | Th
96 | Pa
97 | U
98 | Np
99 | Pu
100 | Am
101 | Cm
102 | Bk
103 | Cf
104 | Es
105 | Fm
106 | Md
107 | No
108 | Lr
109 | Rf
110 | Db
111 | Sg
112 | Bh
113 | Hs
114 | Mt
115 | Ds
116 | Rg
117 | Cn
118 | Nh
119 | Fl
120 | Mc
121 | Lv
122 | Ts
123 | Og
--------------------------------------------------------------------------------
/train_BERTOS.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 The HuggingFace Inc. team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | Train BERTOS
16 | """
17 |
18 | import argparse
19 | import json
20 | import logging
21 | import math
22 | import os
23 | import random
24 | from pathlib import Path
25 |
26 | import datasets
27 | import torch
28 | from datasets import ClassLabel, load_dataset
29 | from torch.utils.data import DataLoader
30 | from tqdm.auto import tqdm
31 |
32 | import evaluate
33 | import transformers
34 | from accelerate import Accelerator
35 | from accelerate.logging import get_logger
36 | from accelerate.utils import set_seed
37 | from huggingface_hub import Repository
38 | from transformers import (
39 | AutoConfig,
40 | AutoModelForTokenClassification,
41 | DataCollatorForTokenClassification,
42 | PretrainedConfig,
43 | SchedulerType,
44 | default_data_collator,
45 | get_scheduler,
46 | )
47 | from transformers.utils import check_min_version, get_full_repo_name, send_example_telemetry
48 | from transformers.utils.versions import require_version
49 |
50 | from transformers import BertTokenizerFast
51 |
52 | # Will error if the minimal version of Transformers is not installed. Remove at your own risk.
53 | check_min_version("4.23.0.dev0")
54 |
55 | logger = get_logger(__name__)
56 | require_version("datasets>=1.8.0", "To fix: pip install -r requirements.txt")
57 |
58 |
59 | def parse_args():
60 | parser = argparse.ArgumentParser(
61 | description="Train BERTOS"
62 | )
63 | parser.add_argument(
64 | "--dataset_name",
65 | type=str,
66 | default=None,
67 | help="The name of the dataset to use (via the datasets library).",
68 | )
69 | parser.add_argument(
70 | "--text_column_name",
71 | type=str,
72 | default=None,
73 | help="The column name of text to input in the file (a csv or JSON file).",
74 | )
75 | parser.add_argument(
76 | "--label_column_name",
77 | type=str,
78 | default=None,
79 | help="The column name of label to input in the file (a csv or JSON file).",
80 | )
81 | parser.add_argument(
82 | "--max_length",
83 | type=int,
84 | default=128,
85 | help=(
86 | "The maximum total input sequence length after tokenization. Sequences longer than this will be truncated,"
87 | " sequences shorter will be padded if `--pad_to_max_length` is passed."
88 | ),
89 | )
90 | parser.add_argument(
91 | "--pad_to_max_length",
92 | action="store_true",
93 | help="If passed, pad all samples to `max_length`. Otherwise, dynamic padding is used.",
94 | )
95 | parser.add_argument(
96 | "--model_name_or_path",
97 | type=str,
98 | help="Path to pretrained model or model identifier from huggingface.co/models.",
99 | required=False,
100 | )
101 | parser.add_argument(
102 | "--config_name",
103 | type=str,
104 | default=None,
105 | help="Pretrained config name or path if not the same as model_name",
106 | )
107 | parser.add_argument(
108 | "--tokenizer_name",
109 | type=str,
110 | default='./tokenizer',
111 | help="Pretrained tokenizer name or path if not the same as model_name",
112 | )
113 | parser.add_argument(
114 | "--per_device_train_batch_size",
115 | type=int,
116 | default=8,
117 | help="Batch size (per device) for the training dataloader.",
118 | )
119 | parser.add_argument(
120 | "--per_device_eval_batch_size",
121 | type=int,
122 | default=8,
123 | help="Batch size (per device) for the evaluation dataloader.",
124 | )
125 | parser.add_argument(
126 | "--learning_rate",
127 | type=float,
128 | default=5e-5,
129 | help="Initial learning rate (after the potential warmup period) to use.",
130 | )
131 | parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.")
132 | parser.add_argument("--num_train_epochs", type=int, default=3, help="Total number of training epochs to perform.")
133 | parser.add_argument(
134 | "--max_train_steps",
135 | type=int,
136 | default=None,
137 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
138 | )
139 | parser.add_argument(
140 | "--gradient_accumulation_steps",
141 | type=int,
142 | default=1,
143 | help="Number of updates steps to accumulate before performing a backward/update pass.",
144 | )
145 | parser.add_argument(
146 | "--lr_scheduler_type",
147 | type=SchedulerType,
148 | default="linear",
149 | help="The scheduler type to use.",
150 | choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"],
151 | )
152 | parser.add_argument(
153 | "--num_warmup_steps", type=int, default=0, help="Number of steps for the warmup in the lr scheduler."
154 | )
155 | parser.add_argument("--output_dir", type=str, default=None, help="Where to store the final model.")
156 | parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
157 | parser.add_argument(
158 | "--label_all_tokens",
159 | action="store_true",
160 | help="Setting labels of all special tokens to -100 and thus PyTorch will ignore them.",
161 | )
162 | parser.add_argument(
163 | "--return_entity_level_metrics",
164 | action="store_true",
165 | help="Indication whether entity level metrics are to be returner.",
166 | )
167 | parser.add_argument(
168 | "--task_name",
169 | type=str,
170 | default="ner",
171 | choices=["ner", "pos", "chunk"],
172 | help="The name of the task.",
173 | )
174 | parser.add_argument(
175 | "--debug",
176 | action="store_true",
177 | help="Activate debug mode and run training only with a subset of data.",
178 | )
179 | parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
180 | parser.add_argument(
181 | "--hub_model_id", type=str, help="The name of the repository to keep in sync with the local `output_dir`."
182 | )
183 | parser.add_argument("--hub_token", type=str, help="The token to use to push to the Model Hub.")
184 | parser.add_argument(
185 | "--checkpointing_steps",
186 | type=str,
187 | default=None,
188 | help="Whether the various states should be saved at the end of every n steps, or 'epoch' for each epoch.",
189 | )
190 | parser.add_argument(
191 | "--resume_from_checkpoint",
192 | type=str,
193 | default=None,
194 | help="If the training should continue from a checkpoint folder.",
195 | )
196 | parser.add_argument(
197 | "--with_tracking",
198 | action="store_true",
199 | help="Whether to enable experiment trackers for logging.",
200 | )
201 | parser.add_argument(
202 | "--report_to",
203 | type=str,
204 | default="all",
205 | help=(
206 | 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,'
207 | ' `"wandb"` and `"comet_ml"`. Use `"all"` (default) to report to all integrations.'
208 | "Only applicable when `--with_tracking` is passed."
209 | ),
210 | )
211 | parser.add_argument(
212 | "--ignore_mismatched_sizes",
213 | action="store_true",
214 | help="Whether or not to enable to load a pretrained model whose head dimensions are different.",
215 | )
216 | args = parser.parse_args()
217 |
218 | # Sanity checks
219 | if args.push_to_hub:
220 | assert args.output_dir is not None, "Need an `output_dir` to create a repo when `--push_to_hub` is passed."
221 |
222 | return args
223 |
224 |
225 | def main():
226 | args = parse_args()
227 |
228 | # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
229 | # information sent is the one passed as arguments along with your Python/PyTorch versions.
230 | send_example_telemetry("run_ner_no_trainer", args)
231 |
232 | # Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
233 | # If we're using tracking, we also need to initialize it here and it will by default pick up all supported trackers
234 | # in the environment
235 | accelerator = (
236 | Accelerator(log_with=args.report_to, logging_dir=args.output_dir) if args.with_tracking else Accelerator()
237 | )
238 | # Make one log on every process with the configuration for debugging.
239 | logging.basicConfig(
240 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
241 | datefmt="%m/%d/%Y %H:%M:%S",
242 | level=logging.INFO,
243 | )
244 | logger.info(accelerator.state, main_process_only=False)
245 | if accelerator.is_local_main_process:
246 | datasets.utils.logging.set_verbosity_warning()
247 | transformers.utils.logging.set_verbosity_info()
248 | else:
249 | datasets.utils.logging.set_verbosity_error()
250 | transformers.utils.logging.set_verbosity_error()
251 |
252 | # If passed along, set the training seed now.
253 | if args.seed is not None:
254 | set_seed(args.seed)
255 |
256 | # Handle the repository creation
257 | if accelerator.is_main_process:
258 | if args.push_to_hub:
259 | if args.hub_model_id is None:
260 | repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
261 | else:
262 | repo_name = args.hub_model_id
263 | repo = Repository(args.output_dir, clone_from=repo_name)
264 |
265 | with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
266 | if "step_*" not in gitignore:
267 | gitignore.write("step_*\n")
268 | if "epoch_*" not in gitignore:
269 | gitignore.write("epoch_*\n")
270 | elif args.output_dir is not None:
271 | os.makedirs(args.output_dir, exist_ok=True)
272 | accelerator.wait_for_everyone()
273 |
274 |
275 | ## load dataset
276 | if not args.dataset_name:
277 | raise ValueError(
278 | "Please give dataset file"
279 | )
280 |
281 | raw_datasets = load_dataset(args.dataset_name)
282 |
283 | # Trim a number of training examples
284 | if args.debug:
285 | for split in raw_datasets.keys():
286 | raw_datasets[split] = raw_datasets[split].select(range(100))
287 | # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
288 | # https://huggingface.co/docs/datasets/loading_datasets.html.
289 |
290 | if raw_datasets["train"] is not None:
291 | column_names = raw_datasets["train"].column_names
292 | features = raw_datasets["train"].features
293 | else:
294 | column_names = raw_datasets["validation"].column_names
295 | features = raw_datasets["validation"].features
296 |
297 | if args.text_column_name is not None:
298 | text_column_name = args.text_column_name
299 | elif "tokens" in column_names:
300 | text_column_name = "tokens"
301 | else:
302 | text_column_name = column_names[0]
303 |
304 | if args.label_column_name is not None:
305 | label_column_name = args.label_column_name
306 | elif f"{args.task_name}_tags" in column_names:
307 | label_column_name = f"{args.task_name}_tags"
308 | else:
309 | label_column_name = column_names[1]
310 |
311 | # In the event the labels are not a `Sequence[ClassLabel]`, we will need to go through the dataset to get the
312 | # unique labels.
313 | def get_label_list(labels):
314 | unique_labels = set()
315 | for label in labels:
316 | unique_labels = unique_labels | set(label)
317 | label_list = list(unique_labels)
318 | label_list.sort()
319 | return label_list
320 |
321 | # If the labels are of type ClassLabel, they are already integers and we have the map stored somewhere.
322 | # Otherwise, we have to get the list of labels manually.
323 | labels_are_int = isinstance(features[label_column_name].feature, ClassLabel)
324 | if labels_are_int:
325 | label_list = features[label_column_name].feature.names
326 | label_to_id = {i: i for i in range(len(label_list))}
327 | else:
328 | label_list = get_label_list(raw_datasets["train"][label_column_name])
329 | label_to_id = {l: i for i, l in enumerate(label_list)}
330 |
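# 14 for the BERTOS datasets: one class per oxidation state from -5 to +8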
331 | num_labels = len(label_list)
332 |
333 | # Load pretrained model and tokenizer
334 | ##prepare config file (BERT)
335 | config = AutoConfig.from_pretrained(args.config_name, num_labels=num_labels)
336 |
337 | ##load tokenizer
338 | tokenizer_name_or_path = args.tokenizer_name
339 | if not tokenizer_name_or_path:
340 | raise ValueError(
341 | "You are instantiating a new tokenizer from scratch. This is not supported by this script."
342 | "You can do it from another script, save it, and load it from here, using --tokenizer_name."
343 | )
344 |
345 | tokenizer = BertTokenizerFast.from_pretrained(tokenizer_name_or_path, do_lower_case=False)
346 |
347 | logger.info("Training new model from scratch")
348 | model = AutoModelForTokenClassification.from_config(config)
349 |
350 | model.resize_token_embeddings(len(tokenizer))
351 |
352 | # Model has labels -> use them.
353 | if model.config.label2id != PretrainedConfig(num_labels=num_labels).label2id:
354 | if list(sorted(model.config.label2id.keys())) == list(sorted(label_list)):
355 | # Reorganize `label_list` to match the ordering of the model.
356 | if labels_are_int:
357 | label_to_id = {i: int(model.config.label2id[l]) for i, l in enumerate(label_list)}
358 | label_list = [model.config.id2label[i] for i in range(num_labels)]
359 | else:
360 | label_list = [model.config.id2label[i] for i in range(num_labels)]
361 | label_to_id = {l: i for i, l in enumerate(label_list)}
362 | else:
363 | logger.warning(
364 | "Your model seems to have been trained with labels, but they don't match the dataset: "
365 | f"model labels: {list(sorted(model.config.label2id.keys()))}, dataset labels:"
366 | f" {list(sorted(label_list))}.\nIgnoring the model labels as a result."
367 | )
368 |
369 | # Set the correspondences label/ID inside the model config
370 | model.config.label2id = {l: i for i, l in enumerate(label_list)}
371 | model.config.id2label = {i: l for i, l in enumerate(label_list)}
372 |
373 | # Map that sends a B-Xxx label to its I-Xxx counterpart; the oxidation-state labels carry no B-/I- prefixes, so this reduces to the identity map (an empty list here would crash when --label_all_tokens is passed)
374 | b_to_i_label = list(range(len(label_list)))
375 |
376 | # Preprocessing the datasets.
377 | # First we tokenize all the texts.
378 | padding = "max_length" if args.pad_to_max_length else False
379 |
380 | # Tokenize all texts and align the labels with them.
381 |
382 | def tokenize_and_align_labels(examples):
383 | tokenized_inputs = tokenizer(
384 | examples[text_column_name],
385 | max_length=args.max_length,
386 | padding=padding,
387 | truncation=True,
388 | # We use this argument because the texts in our dataset are lists of words (with a label for each word).
389 | is_split_into_words=True,
390 | )
391 |
392 | labels = []
393 | for i, label in enumerate(examples[label_column_name]):
394 | word_ids = tokenized_inputs.word_ids(batch_index=i)
395 | previous_word_idx = None
396 | label_ids = []
397 | for word_idx in word_ids:
398 | # Special tokens have a word id that is None. We set the label to -100 so they are automatically
399 | # ignored in the loss function.
400 | if word_idx is None:
401 | label_ids.append(-100)
402 | # We set the label for the first token of each word.
403 | elif word_idx != previous_word_idx:
404 | label_ids.append(label_to_id[label[word_idx]])
405 | # For the other tokens in a word, we set the label to either the current label or -100, depending on
406 | # the label_all_tokens flag.
407 | else:
408 | if args.label_all_tokens:
409 | label_ids.append(b_to_i_label[label_to_id[label[word_idx]]])
410 | else:
411 | label_ids.append(-100)
412 | previous_word_idx = word_idx
413 |
414 | labels.append(label_ids)
415 | tokenized_inputs["labels"] = labels
416 | return tokenized_inputs
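# Note: with the element-level vocabulary in tokenizer/vocab.txt, every
# space-separated element symbol maps to a single sub-token, so in practice
# only the [CLS]/[SEP] positions receive the -100 label here.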
417 |
418 | with accelerator.main_process_first():
419 | processed_raw_datasets = raw_datasets.map(
420 | tokenize_and_align_labels,
421 | batched=True,
422 | remove_columns=raw_datasets["train"].column_names,
423 | desc="Running tokenizer on dataset",
424 | )
425 |
426 | train_dataset = processed_raw_datasets["train"]
427 | eval_dataset = processed_raw_datasets["validation"]
428 |
429 | # Log a few random samples from the training set:
430 | for index in random.sample(range(len(train_dataset)), 3):
431 | logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")
432 |
433 | # DataLoaders creation:
434 | if args.pad_to_max_length:
435 | # If padding was already done to max length, we use the default data collator that will just convert everything
436 | # to tensors.
437 | data_collator = default_data_collator
438 | else:
439 | # Otherwise, `DataCollatorForTokenClassification` will apply dynamic padding for us (by padding to the maximum length of
440 | # the samples passed). When using mixed precision, we add `pad_to_multiple_of=8` to pad all tensors to multiple
441 | # of 8s, which will enable the use of Tensor Cores on NVIDIA hardware with compute capability >= 7.5 (Volta).
442 | data_collator = DataCollatorForTokenClassification(
443 | tokenizer, pad_to_multiple_of=(8 if accelerator.use_fp16 else None)
444 | )
445 |
446 | train_dataloader = DataLoader(
447 | train_dataset, shuffle=True, collate_fn=data_collator, batch_size=args.per_device_train_batch_size
448 | )
449 | eval_dataloader = DataLoader(eval_dataset, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size)
450 |
451 | # Optimizer
452 | # Split weights in two groups, one with weight decay and the other not.
453 | no_decay = ["bias", "LayerNorm.weight"]
454 | optimizer_grouped_parameters = [
455 | {
456 | "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
457 | "weight_decay": args.weight_decay,
458 | },
459 | {
460 | "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
461 | "weight_decay": 0.0,
462 | },
463 | ]
464 | optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
465 |
466 | # Use the device given by the `accelerator` object.
467 | device = accelerator.device
468 | model.to(device)
469 |
470 | # Scheduler and math around the number of training steps.
471 | overrode_max_train_steps = False
472 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
473 | if args.max_train_steps is None:
474 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
475 | overrode_max_train_steps = True
476 |
477 | lr_scheduler = get_scheduler(
478 | name=args.lr_scheduler_type,
479 | optimizer=optimizer,
480 | num_warmup_steps=args.num_warmup_steps,
481 | num_training_steps=args.max_train_steps,
482 | )
483 |
484 | # Prepare everything with our `accelerator`.
485 | model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(
486 | model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
487 | )
488 |
489 | # We need to recalculate our total training steps as the size of the training dataloader may have changed.
490 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
491 | if overrode_max_train_steps:
492 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
493 | # Afterwards we recalculate our number of training epochs
494 | args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
495 |
496 | # Figure out how many steps we should save the Accelerator states
497 | checkpointing_steps = args.checkpointing_steps
498 | if checkpointing_steps is not None and checkpointing_steps.isdigit():
499 | checkpointing_steps = int(checkpointing_steps)
500 |
501 | # We need to initialize the trackers we use, and also store our configuration.
502 | # The trackers initialize automatically on the main process.
503 | if args.with_tracking:
504 | experiment_config = vars(args)
505 | # TensorBoard cannot log Enums, need the raw value
506 | experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value
507 | accelerator.init_trackers("ner_no_trainer", experiment_config)
508 |
509 | # Metrics
510 | metric = evaluate.load("seqeval")
511 |
512 | def get_labels(predictions, references):
513 | # Transform prediction and reference tensors to numpy arrays
514 | if device.type == "cpu":
515 | y_pred = predictions.detach().clone().numpy()
516 | y_true = references.detach().clone().numpy()
517 | else:
518 | y_pred = predictions.detach().cpu().clone().numpy()
519 | y_true = references.detach().cpu().clone().numpy()
520 |
521 | # Remove ignored index (special tokens)
522 | true_predictions = [
523 | [label_list[p] for (p, l) in zip(pred, gold_label) if l != -100]
524 | for pred, gold_label in zip(y_pred, y_true)
525 | ]
526 | true_labels = [
527 | [label_list[l] for (p, l) in zip(pred, gold_label) if l != -100]
528 | for pred, gold_label in zip(y_pred, y_true)
529 | ]
530 | return true_predictions, true_labels
531 |
532 | def compute_metrics():
533 | results = metric.compute()
534 | if args.return_entity_level_metrics:
535 | # Unpack nested dictionaries
536 | final_results = {}
537 | for key, value in results.items():
538 | if isinstance(value, dict):
539 | for n, v in value.items():
540 | final_results[f"{key}_{n}"] = v
541 | else:
542 | final_results[key] = value
543 | return final_results
544 | else:
545 | return {
546 | "precision": results["overall_precision"],
547 | "recall": results["overall_recall"],
548 | "f1": results["overall_f1"],
549 | "accuracy": results["overall_accuracy"],
550 | }
551 |
552 | # Train!
553 | total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
554 |
555 | logger.info("***** Running training *****")
556 | logger.info(f" Num examples = {len(train_dataset)}")
557 | logger.info(f" Num Epochs = {args.num_train_epochs}")
558 | logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}")
559 | logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
560 | logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
561 | logger.info(f" Total optimization steps = {args.max_train_steps}")
562 | # Only show the progress bar once on each machine.
563 | progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
564 | completed_steps = 0
565 | starting_epoch = 0
566 | # Potentially load in the weights and states from a previous save
567 | if args.resume_from_checkpoint:
568 | if args.resume_from_checkpoint is not None and args.resume_from_checkpoint != "":
569 | accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}")
570 | accelerator.load_state(args.resume_from_checkpoint)
571 | path = os.path.basename(args.resume_from_checkpoint)
572 | else:
573 | # Get the most recent checkpoint
574 | dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()]
575 | dirs.sort(key=os.path.getctime)
576 | path = dirs[-1] # Sorts folders by date modified, most recent checkpoint is the last
577 | # Extract `epoch_{i}` or `step_{i}`
578 | training_difference = os.path.splitext(path)[0]
579 |
580 | if "epoch" in training_difference:
581 | starting_epoch = int(training_difference.replace("epoch_", "")) + 1
582 | resume_step = None
583 | else:
584 | resume_step = int(training_difference.replace("step_", ""))
585 | starting_epoch = resume_step // len(train_dataloader)
586 | resume_step -= starting_epoch * len(train_dataloader)
587 |
588 | for epoch in range(starting_epoch, args.num_train_epochs):
589 | model.train()
590 | if args.with_tracking:
591 | total_loss = 0
592 | for step, batch in enumerate(train_dataloader):
593 | # We need to skip steps until we reach the resumed step
594 | if args.resume_from_checkpoint and epoch == starting_epoch:
595 | if resume_step is not None and step < resume_step:
596 | completed_steps += 1
597 | continue
598 | outputs = model(**batch)
599 | loss = outputs.loss
600 | # We keep track of the loss at each epoch
601 | if args.with_tracking:
602 | total_loss += loss.detach().float()
603 | loss = loss / args.gradient_accumulation_steps
604 | accelerator.backward(loss)
605 | if step % args.gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
606 | optimizer.step()
607 | lr_scheduler.step()
608 | optimizer.zero_grad()
609 | progress_bar.update(1)
610 | completed_steps += 1
611 |
612 | if isinstance(checkpointing_steps, int):
613 | if completed_steps % checkpointing_steps == 0:
614 | output_dir = f"step_{completed_steps }"
615 | if args.output_dir is not None:
616 | output_dir = os.path.join(args.output_dir, output_dir)
617 | accelerator.save_state(output_dir)
618 |
619 | if completed_steps >= args.max_train_steps:
620 | break
621 |
622 | model.eval()
623 | samples_seen = 0
624 |
625 | outputs4save = []
626 | for step, batch in enumerate(eval_dataloader):
627 | with torch.no_grad():
628 | outputs = model(**batch)
629 | predictions = outputs.logits.argmax(dim=-1)
630 | labels = batch["labels"]
631 | if not args.pad_to_max_length: # necessary to pad predictions and labels for being gathered
632 | predictions = accelerator.pad_across_processes(predictions, dim=1, pad_index=-100)
633 | labels = accelerator.pad_across_processes(labels, dim=1, pad_index=-100)
634 | predictions_gathered, labels_gathered = accelerator.gather((predictions, labels))
635 | # If we are in a multiprocess environment, the last batch has duplicates
636 | if accelerator.num_processes > 1:
637 | if step == len(eval_dataloader) - 1:
638 | predictions_gathered = predictions_gathered[: len(eval_dataloader.dataset) - samples_seen]
639 | labels_gathered = labels_gathered[: len(eval_dataloader.dataset) - samples_seen]
640 | else:
641 | samples_seen += labels_gathered.shape[0]
642 | preds, refs = get_labels(predictions_gathered, labels_gathered)
643 | metric.add_batch(
644 | predictions=preds,
645 | references=refs,
646 | ) # predictions and references are expected to be nested lists of labels, not label ids
647 |
648 | if epoch == (args.num_train_epochs - 1):
649 | outputs4save.append([preds, refs])
650 |
651 | eval_metric = compute_metrics()
652 | accelerator.print(f"epoch {epoch}:", eval_metric)
653 | if args.with_tracking:
654 | accelerator.log(
655 | {
656 | "seqeval": eval_metric,
657 | "train_loss": total_loss.item() / len(train_dataloader),
658 | "epoch": epoch,
659 | "step": completed_steps,
660 | },
661 | step=completed_steps,
662 | )
663 |
664 | if args.push_to_hub and epoch < args.num_train_epochs - 1:
665 | accelerator.wait_for_everyone()
666 | unwrapped_model = accelerator.unwrap_model(model)
667 | unwrapped_model.save_pretrained(
668 | args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save
669 | )
670 | if accelerator.is_main_process:
671 | tokenizer.save_pretrained(args.output_dir)
672 | repo.push_to_hub(
673 | commit_message=f"Training in progress epoch {epoch}", blocking=False, auto_lfs_prune=True
674 | )
675 |
676 | if args.checkpointing_steps == "epoch":
677 | output_dir = f"epoch_{epoch}"
678 | if args.output_dir is not None:
679 | output_dir = os.path.join(args.output_dir, output_dir)
680 | accelerator.save_state(output_dir)
681 |
682 | if args.with_tracking:
683 | accelerator.end_training()
684 |
685 | if args.output_dir is not None:
686 | accelerator.wait_for_everyone()
687 | unwrapped_model = accelerator.unwrap_model(model)
688 | unwrapped_model.save_pretrained(
689 | args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save
690 | )
691 | if accelerator.is_main_process:
692 | tokenizer.save_pretrained(args.output_dir)
693 | if args.push_to_hub:
694 | repo.push_to_hub(commit_message="End of training", auto_lfs_prune=True)
695 |
696 |
697 | with open(os.path.join(args.output_dir, "all_results.json"), "w") as f:
698 | json.dump(
699 | {"eval_accuracy": eval_metric["accuracy"]}, f
700 | )
701 |
702 | import pandas as pd
703 | # Save predictions
704 | out = pd.DataFrame(outputs4save)
705 | out.to_csv(os.path.join(args.output_dir, "predictions.csv"), header=None, index=None)
706 |
707 | if __name__ == "__main__":
708 | main()
709 |
--------------------------------------------------------------------------------
/train_BERTOS.sh:
--------------------------------------------------------------------------------
1 | python train_BERTOS.py \
2 | --config_name ./random_config/ \
3 | --dataset_name materials_icsd_cn.py \
4 | --max_length 100 \
5 | --per_device_train_batch_size 256 \
6 | --learning_rate 1e-3 \
7 | --num_train_epochs 500 \
8 | --output_dir ./model_icsdcn
9 |
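# To train on one of the other label sets, swap in the matching dataset script
# (materials_icsd.py, materials_icsd_o.py, or materials_icsd_cno.py) and output
# dir, e.g. (hypothetical output path):
#
# python train_BERTOS.py \
#     --config_name ./random_config/ \
#     --dataset_name materials_icsd.py \
#     --max_length 100 \
#     --per_device_train_batch_size 256 \
#     --learning_rate 1e-3 \
#     --num_train_epochs 500 \
#     --output_dir ./model_icsd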
--------------------------------------------------------------------------------
/trained_models/ICSD.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/usccolumbia/BERTOS/ab4e4f31b09543a7949f36981f1c87b9ff41bb74/trained_models/ICSD.zip
--------------------------------------------------------------------------------
/trained_models/ICSD_CN.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/usccolumbia/BERTOS/ab4e4f31b09543a7949f36981f1c87b9ff41bb74/trained_models/ICSD_CN.zip
--------------------------------------------------------------------------------
/trained_models/ICSD_CN_oxide.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/usccolumbia/BERTOS/ab4e4f31b09543a7949f36981f1c87b9ff41bb74/trained_models/ICSD_CN_oxide.zip
--------------------------------------------------------------------------------
/trained_models/ICSD_oxide.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/usccolumbia/BERTOS/ab4e4f31b09543a7949f36981f1c87b9ff41bb74/trained_models/ICSD_oxide.zip
--------------------------------------------------------------------------------
/trained_models/README.md:
--------------------------------------------------------------------------------
1 | Download the pretrained models for oxidation state prediction from figshare at
2 | https://figshare.com/articles/online_resource/BERTOS_model/21554823
3 | and unzip them into this folder so that getOS.py can find them (its default model path is ./trained_models/ICSD_CN/).
4 |
5 |
--------------------------------------------------------------------------------