├── .gitignore ├── CONTRIBUTING.md ├── LICENSE ├── Notice.txt ├── README.md ├── benchmark ├── benchmark.pt ├── in.data └── in.lammps ├── configs ├── alignment_config │ └── config.json ├── ensemble_config │ └── config.json └── train_config │ └── config.json ├── models ├── __init__.py ├── bamboo_base.py ├── bamboo_get.py └── modules │ ├── __init__.py │ └── dftd3 │ ├── __init__.py │ ├── dftd3.pt │ └── dftd3.py ├── pair ├── build.sh ├── init_compile.sh └── src │ ├── pair_bamboo.cpp │ ├── pair_bamboo.h │ ├── pair_bamboo_kokkos.cpp │ └── pair_bamboo_kokkos.h ├── train ├── alignment.py ├── ensemble.py └── train.py └── utils ├── __init__.py ├── batchify.py ├── constant.py ├── funcs.py ├── load_traj.py ├── log_helper.py ├── path.py └── rejit.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *.log 3 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing Guidelines 2 | 3 | We are thrilled that you are interested in contributing to BAMBOO! This document provides guidelines and instructions on how to contribute to the project in a way that is efficient and aligns with our goals. 4 | 5 | ## Our Goals 6 | 7 | * **Improve Model Quality:** We aim to enhance the predictive accuracy and reduce the variance of our models. 8 | 9 | * **Increase Computational Efficiency:** Any contributions that can optimize our processes and decrease computational time are highly welcome. 10 | 11 | ## How to Contribute 12 | 13 | Contributing to BAMBOO is easy. Here's how you can do it: 14 | 15 | ### Reporting Bugs or Issues 16 | 17 | 1. **Check the Issue Tracker:** Before you submit a new issue, please check if it is already reported. 18 | 2. **Create a New Issue:** If your issue is new, click on the *Issues* tab and then the *New Issue* button to create a new issue. 19 | 3. **Describe the Issue:** Provide a detailed description of the issue, including steps to reproduce it, expected outcomes, and actual outcomes. 20 | 21 | ### Proposing Changes 22 | 23 | If you have a proposal to improve the project, follow these steps: 24 | 25 | 1. **Fork the Repository:** Create your own fork of the project and clone it locally. 26 | 2. **Create a New Branch:** Always create a new branch for your changes. 27 | 3. **Develop:** Implement your changes in your branch. Make sure your changes adhere to our goals of improving model quality or computational efficiency. 28 | 4. **Write Tests:** We highly encourage writing tests to verify that your changes meet the intended effects and do not break existing functionality. 29 | 5. **Document Your Changes:** Update the README.md if your changes require it, and provide comments and documentation in the code as necessary. 30 | 6. **Submit a Pull Request:** Once you are ready, push your changes to your fork and then submit a pull request to the main repository. 31 | 32 | ### Code Review Process 33 | 34 | Once your pull request is submitted, it will be reviewed by the maintainers. Here’s what you can expect: 35 | 36 | - **Review:** One of the maintainers will review your code to ensure it aligns with our goals and meets our standards for quality and efficiency. 37 | - **Feedback:** You may receive feedback and requests for changes. This is a normal part of the review process, and we encourage contributors to be open to constructive discussion. 
38 | - **Acceptance:** If your contribution is accepted, it will be merged into the project. 39 | 40 | ## Guidelines 41 | 42 | - **Code Style:** Please adhere to the coding style that is prevalent in the project. 43 | - **Commit Messages:** Use clear and meaningful commit messages. 44 | 45 | ## Getting Help 46 | 47 | If you need help or have questions about contributing, feel free to contact us. 48 | 49 | Thank you for your interest in contributing to BAMBOO. We look forward to your contributions and are excited to see how together we can make BAMBOO better! -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | GNU GENERAL PUBLIC LICENSE 2 | Version 2, June 1991 3 | 4 | Copyright (C) 1989, 1991 Free Software Foundation, Inc., 5 | 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA 6 | Everyone is permitted to copy and distribute verbatim copies 7 | of this license document, but changing it is not allowed. 8 | 9 | Preamble 10 | 11 | The licenses for most software are designed to take away your 12 | freedom to share and change it. By contrast, the GNU General Public 13 | License is intended to guarantee your freedom to share and change free 14 | software--to make sure the software is free for all its users. This 15 | General Public License applies to most of the Free Software 16 | Foundation's software and to any other program whose authors commit to 17 | using it. (Some other Free Software Foundation software is covered by 18 | the GNU Lesser General Public License instead.) You can apply it to 19 | your programs, too. 20 | 21 | When we speak of free software, we are referring to freedom, not 22 | price. Our General Public Licenses are designed to make sure that you 23 | have the freedom to distribute copies of free software (and charge for 24 | this service if you wish), that you receive source code or can get it 25 | if you want it, that you can change the software or use pieces of it 26 | in new free programs; and that you know you can do these things. 27 | 28 | To protect your rights, we need to make restrictions that forbid 29 | anyone to deny you these rights or to ask you to surrender the rights. 30 | These restrictions translate to certain responsibilities for you if you 31 | distribute copies of the software, or if you modify it. 32 | 33 | For example, if you distribute copies of such a program, whether 34 | gratis or for a fee, you must give the recipients all the rights that 35 | you have. You must make sure that they, too, receive or can get the 36 | source code. And you must show them these terms so they know their 37 | rights. 38 | 39 | We protect your rights with two steps: (1) copyright the software, and 40 | (2) offer you this license which gives you legal permission to copy, 41 | distribute and/or modify the software. 42 | 43 | Also, for each author's protection and ours, we want to make certain 44 | that everyone understands that there is no warranty for this free 45 | software. If the software is modified by someone else and passed on, we 46 | want its recipients to know that what they have is not the original, so 47 | that any problems introduced by others will not reflect on the original 48 | authors' reputations. 49 | 50 | Finally, any free program is threatened constantly by software 51 | patents. 
We wish to avoid the danger that redistributors of a free 52 | program will individually obtain patent licenses, in effect making the 53 | program proprietary. To prevent this, we have made it clear that any 54 | patent must be licensed for everyone's free use or not licensed at all. 55 | 56 | The precise terms and conditions for copying, distribution and 57 | modification follow. 58 | 59 | GNU GENERAL PUBLIC LICENSE 60 | TERMS AND CONDITIONS FOR COPYING, DISTRIBUTION AND MODIFICATION 61 | 62 | 0. This License applies to any program or other work which contains 63 | a notice placed by the copyright holder saying it may be distributed 64 | under the terms of this General Public License. The "Program", below, 65 | refers to any such program or work, and a "work based on the Program" 66 | means either the Program or any derivative work under copyright law: 67 | that is to say, a work containing the Program or a portion of it, 68 | either verbatim or with modifications and/or translated into another 69 | language. (Hereinafter, translation is included without limitation in 70 | the term "modification".) Each licensee is addressed as "you". 71 | 72 | Activities other than copying, distribution and modification are not 73 | covered by this License; they are outside its scope. The act of 74 | running the Program is not restricted, and the output from the Program 75 | is covered only if its contents constitute a work based on the 76 | Program (independent of having been made by running the Program). 77 | Whether that is true depends on what the Program does. 78 | 79 | 1. You may copy and distribute verbatim copies of the Program's 80 | source code as you receive it, in any medium, provided that you 81 | conspicuously and appropriately publish on each copy an appropriate 82 | copyright notice and disclaimer of warranty; keep intact all the 83 | notices that refer to this License and to the absence of any warranty; 84 | and give any other recipients of the Program a copy of this License 85 | along with the Program. 86 | 87 | You may charge a fee for the physical act of transferring a copy, and 88 | you may at your option offer warranty protection in exchange for a fee. 89 | 90 | 2. You may modify your copy or copies of the Program or any portion 91 | of it, thus forming a work based on the Program, and copy and 92 | distribute such modifications or work under the terms of Section 1 93 | above, provided that you also meet all of these conditions: 94 | 95 | a) You must cause the modified files to carry prominent notices 96 | stating that you changed the files and the date of any change. 97 | 98 | b) You must cause any work that you distribute or publish, that in 99 | whole or in part contains or is derived from the Program or any 100 | part thereof, to be licensed as a whole at no charge to all third 101 | parties under the terms of this License. 102 | 103 | c) If the modified program normally reads commands interactively 104 | when run, you must cause it, when started running for such 105 | interactive use in the most ordinary way, to print or display an 106 | announcement including an appropriate copyright notice and a 107 | notice that there is no warranty (or else, saying that you provide 108 | a warranty) and that users may redistribute the program under 109 | these conditions, and telling the user how to view a copy of this 110 | License. 
(Exception: if the Program itself is interactive but 111 | does not normally print such an announcement, your work based on 112 | the Program is not required to print an announcement.) 113 | 114 | These requirements apply to the modified work as a whole. If 115 | identifiable sections of that work are not derived from the Program, 116 | and can be reasonably considered independent and separate works in 117 | themselves, then this License, and its terms, do not apply to those 118 | sections when you distribute them as separate works. But when you 119 | distribute the same sections as part of a whole which is a work based 120 | on the Program, the distribution of the whole must be on the terms of 121 | this License, whose permissions for other licensees extend to the 122 | entire whole, and thus to each and every part regardless of who wrote it. 123 | 124 | Thus, it is not the intent of this section to claim rights or contest 125 | your rights to work written entirely by you; rather, the intent is to 126 | exercise the right to control the distribution of derivative or 127 | collective works based on the Program. 128 | 129 | In addition, mere aggregation of another work not based on the Program 130 | with the Program (or with a work based on the Program) on a volume of 131 | a storage or distribution medium does not bring the other work under 132 | the scope of this License. 133 | 134 | 3. You may copy and distribute the Program (or a work based on it, 135 | under Section 2) in object code or executable form under the terms of 136 | Sections 1 and 2 above provided that you also do one of the following: 137 | 138 | a) Accompany it with the complete corresponding machine-readable 139 | source code, which must be distributed under the terms of Sections 140 | 1 and 2 above on a medium customarily used for software interchange; or, 141 | 142 | b) Accompany it with a written offer, valid for at least three 143 | years, to give any third party, for a charge no more than your 144 | cost of physically performing source distribution, a complete 145 | machine-readable copy of the corresponding source code, to be 146 | distributed under the terms of Sections 1 and 2 above on a medium 147 | customarily used for software interchange; or, 148 | 149 | c) Accompany it with the information you received as to the offer 150 | to distribute corresponding source code. (This alternative is 151 | allowed only for noncommercial distribution and only if you 152 | received the program in object code or executable form with such 153 | an offer, in accord with Subsection b above.) 154 | 155 | The source code for a work means the preferred form of the work for 156 | making modifications to it. For an executable work, complete source 157 | code means all the source code for all modules it contains, plus any 158 | associated interface definition files, plus the scripts used to 159 | control compilation and installation of the executable. However, as a 160 | special exception, the source code distributed need not include 161 | anything that is normally distributed (in either source or binary 162 | form) with the major components (compiler, kernel, and so on) of the 163 | operating system on which the executable runs, unless that component 164 | itself accompanies the executable. 
165 | 166 | If distribution of executable or object code is made by offering 167 | access to copy from a designated place, then offering equivalent 168 | access to copy the source code from the same place counts as 169 | distribution of the source code, even though third parties are not 170 | compelled to copy the source along with the object code. 171 | 172 | 4. You may not copy, modify, sublicense, or distribute the Program 173 | except as expressly provided under this License. Any attempt 174 | otherwise to copy, modify, sublicense or distribute the Program is 175 | void, and will automatically terminate your rights under this License. 176 | However, parties who have received copies, or rights, from you under 177 | this License will not have their licenses terminated so long as such 178 | parties remain in full compliance. 179 | 180 | 5. You are not required to accept this License, since you have not 181 | signed it. However, nothing else grants you permission to modify or 182 | distribute the Program or its derivative works. These actions are 183 | prohibited by law if you do not accept this License. Therefore, by 184 | modifying or distributing the Program (or any work based on the 185 | Program), you indicate your acceptance of this License to do so, and 186 | all its terms and conditions for copying, distributing or modifying 187 | the Program or works based on it. 188 | 189 | 6. Each time you redistribute the Program (or any work based on the 190 | Program), the recipient automatically receives a license from the 191 | original licensor to copy, distribute or modify the Program subject to 192 | these terms and conditions. You may not impose any further 193 | restrictions on the recipients' exercise of the rights granted herein. 194 | You are not responsible for enforcing compliance by third parties to 195 | this License. 196 | 197 | 7. If, as a consequence of a court judgment or allegation of patent 198 | infringement or for any other reason (not limited to patent issues), 199 | conditions are imposed on you (whether by court order, agreement or 200 | otherwise) that contradict the conditions of this License, they do not 201 | excuse you from the conditions of this License. If you cannot 202 | distribute so as to satisfy simultaneously your obligations under this 203 | License and any other pertinent obligations, then as a consequence you 204 | may not distribute the Program at all. For example, if a patent 205 | license would not permit royalty-free redistribution of the Program by 206 | all those who receive copies directly or indirectly through you, then 207 | the only way you could satisfy both it and this License would be to 208 | refrain entirely from distribution of the Program. 209 | 210 | If any portion of this section is held invalid or unenforceable under 211 | any particular circumstance, the balance of the section is intended to 212 | apply and the section as a whole is intended to apply in other 213 | circumstances. 214 | 215 | It is not the purpose of this section to induce you to infringe any 216 | patents or other property right claims or to contest validity of any 217 | such claims; this section has the sole purpose of protecting the 218 | integrity of the free software distribution system, which is 219 | implemented by public license practices. 
Many people have made 220 | generous contributions to the wide range of software distributed 221 | through that system in reliance on consistent application of that 222 | system; it is up to the author/donor to decide if he or she is willing 223 | to distribute software through any other system and a licensee cannot 224 | impose that choice. 225 | 226 | This section is intended to make thoroughly clear what is believed to 227 | be a consequence of the rest of this License. 228 | 229 | 8. If the distribution and/or use of the Program is restricted in 230 | certain countries either by patents or by copyrighted interfaces, the 231 | original copyright holder who places the Program under this License 232 | may add an explicit geographical distribution limitation excluding 233 | those countries, so that distribution is permitted only in or among 234 | countries not thus excluded. In such case, this License incorporates 235 | the limitation as if written in the body of this License. 236 | 237 | 9. The Free Software Foundation may publish revised and/or new versions 238 | of the General Public License from time to time. Such new versions will 239 | be similar in spirit to the present version, but may differ in detail to 240 | address new problems or concerns. 241 | 242 | Each version is given a distinguishing version number. If the Program 243 | specifies a version number of this License which applies to it and "any 244 | later version", you have the option of following the terms and conditions 245 | either of that version or of any later version published by the Free 246 | Software Foundation. If the Program does not specify a version number of 247 | this License, you may choose any version ever published by the Free Software 248 | Foundation. 249 | 250 | 10. If you wish to incorporate parts of the Program into other free 251 | programs whose distribution conditions are different, write to the author 252 | to ask for permission. For software which is copyrighted by the Free 253 | Software Foundation, write to the Free Software Foundation; we sometimes 254 | make exceptions for this. Our decision will be guided by the two goals 255 | of preserving the free status of all derivatives of our free software and 256 | of promoting the sharing and reuse of software generally. 257 | 258 | NO WARRANTY 259 | 260 | 11. BECAUSE THE PROGRAM IS LICENSED FREE OF CHARGE, THERE IS NO WARRANTY 261 | FOR THE PROGRAM, TO THE EXTENT PERMITTED BY APPLICABLE LAW. EXCEPT WHEN 262 | OTHERWISE STATED IN WRITING THE COPYRIGHT HOLDERS AND/OR OTHER PARTIES 263 | PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY OF ANY KIND, EITHER EXPRESSED 264 | OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF 265 | MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. THE ENTIRE RISK AS 266 | TO THE QUALITY AND PERFORMANCE OF THE PROGRAM IS WITH YOU. SHOULD THE 267 | PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF ALL NECESSARY SERVICING, 268 | REPAIR OR CORRECTION. 269 | 270 | 12. 
IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING 271 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MAY MODIFY AND/OR 272 | REDISTRIBUTE THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, 273 | INCLUDING ANY GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING 274 | OUT OF THE USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED 275 | TO LOSS OF DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY 276 | YOU OR THIRD PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER 277 | PROGRAMS), EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE 278 | POSSIBILITY OF SUCH DAMAGES. 279 | 280 | END OF TERMS AND CONDITIONS 281 | 282 | How to Apply These Terms to Your New Programs 283 | 284 | If you develop a new program, and you want it to be of the greatest 285 | possible use to the public, the best way to achieve this is to make it 286 | free software which everyone can redistribute and change under these terms. 287 | 288 | To do so, attach the following notices to the program. It is safest 289 | to attach them to the start of each source file to most effectively 290 | convey the exclusion of warranty; and each file should have at least 291 | the "copyright" line and a pointer to where the full notice is found. 292 | 293 | 294 | Copyright (C) 295 | 296 | This program is free software; you can redistribute it and/or modify 297 | it under the terms of the GNU General Public License as published by 298 | the Free Software Foundation; either version 2 of the License, or 299 | (at your option) any later version. 300 | 301 | This program is distributed in the hope that it will be useful, 302 | but WITHOUT ANY WARRANTY; without even the implied warranty of 303 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 304 | GNU General Public License for more details. 305 | 306 | You should have received a copy of the GNU General Public License along 307 | with this program; if not, write to the Free Software Foundation, Inc., 308 | 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. 309 | 310 | Also add information on how to contact you by electronic and paper mail. 311 | 312 | If the program is interactive, make it output a short notice like this 313 | when it starts in an interactive mode: 314 | 315 | Gnomovision version 69, Copyright (C) year name of author 316 | Gnomovision comes with ABSOLUTELY NO WARRANTY; for details type `show w'. 317 | This is free software, and you are welcome to redistribute it 318 | under certain conditions; type `show c' for details. 319 | 320 | The hypothetical commands `show w' and `show c' should show the appropriate 321 | parts of the General Public License. Of course, the commands you use may 322 | be called something other than `show w' and `show c'; they could even be 323 | mouse-clicks or menu items--whatever suits your program. 324 | 325 | You should also get your employer (if you work as a programmer) or your 326 | school, if any, to sign a "copyright disclaimer" for the program, if 327 | necessary. Here is a sample; alter the names: 328 | 329 | Yoyodyne, Inc., hereby disclaims all copyright interest in the program 330 | `Gnomovision' (which makes passes at compilers) written by James Hacker. 331 | 332 | , 1 April 1989 333 | Ty Coon, President of Vice 334 | 335 | This General Public License does not permit incorporating your program into 336 | proprietary programs. 
If your program is a subroutine library, you may 337 | consider it more useful to permit linking proprietary applications with the 338 | library. If this is what you want to do, use the GNU Lesser General 339 | Public License instead of this License. -------------------------------------------------------------------------------- /Notice.txt: -------------------------------------------------------------------------------- 1 | Copyright 2022-2024 Bytedance Ltd. and/or its affiliates -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # **B**yteDance **A**I **M**olecular Simulation **BOO**ster (BAMBOO) 2 | 3 | Welcome to the repository of BAMBOO! This repository hosts the source code for creating a machine learning-based force field (MLFF) for molecular dynamics (MD) simulations of lithium battery electrolytes. Whether you're interested in simulating lithium battery electrolytes or other types of liquids, BAMBOO provides a robust and versatile solution. 4 | 5 | Thank you for considering BAMBOO for your research needs. We are thrilled to be a part of your scientific journey and are eager to see how our project contributes to your outstanding results. 6 | 7 | ## 2025.05 Release 8 | In this release, we provide the following updates about data, checkpoints, and implementation of dispersion corrections: 9 | 10 | ### Dataset 11 | The complete training and validation DFT dataset is provided at: https://huggingface.co/datasets/mzl/bamboo 12 | 13 | ### Implementation of dispersion correction 14 | In the previous release, dispersion correction was constructed as element-dependent but independent of geometry. In this new release, we provide a new implementation of dispersion correction that depends on the coordination number (CN) of the atoms \[[2](#ref2), [3](#ref3)\]. This new implementation follows the original spirit of the DFT-D3(CSO) dispersion correction, with the value $a_4$ slightly adjusted for better zero-shot density prediction (see comment in `models/modules/dftd3/dftd3.py`). 15 | 16 | ### Checkpoints 17 | In this release, we also provide the model checkpoint with the old implementation of dispersion correction that can reproduce the results presented in the paper [[1]](#ref1), as well as the model checkpoint with the newly implemented dispersion correction, at `benchmark/paper_new_disp.pt`. These two checkpoints have similar performance in terms of prediction of density, viscosity, and ionic conductivity. Some benchmarks of the checkpoint with the newly implemented dispersion correction are listed below: 18 | 19 | | System | Predicted Density (g/ml) | Exp. Density (g/ml) | Predicted Viscosity (cP) | Exp. Viscosity (cP) | Predicted Conductivity by Mistry (mS/cm) | Predicted Conductivity by NE (mS/cm) | Exp. 
Conductivity(mS/cm) | 20 | | --- | --- | --- | --- | --- | --- | --- | --- | 21 | | DEC | 0.971 +- 0.003 | 0.97 | 0.769 +- 0.020 | 0.749 | 22 | | DMC | 1.053 +- 0.003 | 1.06 | 0.573 +- 0.011 | 0.585 | 23 | | EA | 0.904 +- 0.004 | 0.9 | 0.536 +- 0.013 | 0.43 | 24 | | EC | 1.328 +- 0.003 | 1.32 | 1.492 +- 0.021 | 1.93 | 25 | | FEC | 1.499 +- 0.002 | 1.477 | 2.295 +- 0.031 | 2.24-4.1 | 26 | | PC | 1.201 +- 0.003 | 1.2 | 1.755 +- 0.034 | 2.53 | 27 | | Novec7000 | 1.399 +- 0.007 | 1.4 | 0.480 +- 0.008 | 0.45 | 28 | | DMC_EC\|60_40\|LiPF6\|0.9 | 1.238 +- 0.003 | 1.239 | 2.942 +- 0.138 | 2.778 | 10.593 +- 0.512 | 13.560 +- 0.744 | 12.793 | 29 | | DMC_EC_EMC\|45_50_5\|LiPF6\|1.1 | 1.273 +- 0.004 | 1.273 | 4.207 +- 0.190 | 4.566 | 8.061 +- 0.392 | 10.596 +- 0.571 | 12.175 | 30 | | DMC_EC\|70_30\|LiPF6\|1.50 | 1.260 +- 0.003 | 1.268 | 4.916 +- 0.281 | 4.472 | 7.761 +- 0.282 | 10.795 +- 0.355 | 12.144 | 31 | | DMC\|LiFSI\|2.22 | 1.276 +- 0.003 | 1.27 | 5.406 +- 0.545 | 3.9 | 8.642 +- 0.489 | 13.338 +- 0.494 | 12.2 | 32 | | EC\|LiFSI\|0.49 | 1.375 +- 0.004 | 1.38 | 3.462 +- 0.129 | 4.1 | 5.913 +- 0.193 | 6.759 +- 0.241 | 8.7 | 33 | | EC\|LiFSI\|1.14 | 1.419 +- 0.003 | 1.43 | 7.072 +- 0.333 | 8.1 | 4.724 +- 0.279 | 6.070 +- 0.383 | 9.7 | 34 | | EMC | 1.018 +- 0.002 | 1 | 0.720 +- 0.009 | 0.65 | 35 | | VC | 1.326 +- 0.004 | 1.355 | 1.013 +- 0.010 | 1.78 | 36 | | ACT | 0.808 +- 0.001 | 0.79 | 0.469 +- 0.011 | 0.31 | 37 | | DMC\|LiFSI\|3.70 | 1.361 +- 0.003 | 1.36 | 15.443 +- 2.009 | 12.9 | 3.504 +- 0.311 | 5.524 +- 0.337 | 8.1 | 38 | | DMC\|LiFSI\|1.11 | 1.167 +- 0.001 | 1.18 | 1.903 +- 0.049 | 1.5 | 15.359 +- 0.504 | 20.113 +- 0.787 | 9.9 | 39 | | EC\|LiFSI\|3.78 | 1.555 +- 0.002 | 1.57 | 38.690 +- 7.094 | 0 | 1.434 +- 0.187 | 2.064 +- 0.164 | 2.3 | 40 | | EC\|LiFSI\|2.27 | 1.484 +- 0.001 | 1.5 | 18.041 +- 2.852 | 33.1 | 2.340 +- 0.354 | 3.138 +- 0.499 | 5.6 | 41 | | DMC_EC\|51_49\|LiFSI\|3.74 | 1.456 +- 0.002 | 1.46 | 27.497 +- 3.017 | 36.9 | 1.899 +- 0.153 | 2.772 +- 0.173 | 4.5 | 42 | | DMC_EC\|51_49\|LiFSI\|2.25 | 1.372 +- 0.002 | 1.38 | 10.376 +- 0.685 | 9.8 | 4.124 +- 0.214 | 5.889 +- 0.285 | 9.8 | 43 | | DMC_EC\|51_49\|LiFSI\|1.12 | 1.285 +- 0.002 | 1.3 | 3.777 +- 0.096 | 3.4 | 9.296 +- 0.114 | 12.336 +- 0.155 | 14 | 44 | | DMC_EC\|50_50\|LiPF6\|0.5 | 1.225 +- 0.002 | 1.235 | 2.120 +- 0.080 | 2.219 | 9.506 +- 0.453 | 11.011 +- 0.572 | 12.1420533 | 45 | | DMC_EC\|70_30\|LiPF6\|1.00 | 1.219 +- 0.003 | 1.228 | 2.877 +- 0.122 | 2.734 | 11.260 +- 0.506 | 14.655 +- 0.693 | 12.4265334 | 46 | | DMC_EC\|70_30\|LiPF6\|1.30 | 1.244 +- 0.002 | 1.252 | 4.111 +- 0.202 | 3.767 | 8.858 +- 0.378 | 12.098 +- 0.543 | 12.117719 | 47 | | MA | 1.027 +- 0.002 | 0.93 | 0.552 +- 0.018 | 0.36 | 48 | 49 | ### References 50 | 51 | \[1\] Gong, Sheng, et al. "A predictive machine learning force-field framework for liquid electrolyte development." Nature Machine Intelligence (2026): 1-10. \ 52 | \[2\] Grimme, Stefan, et al. "A consistent and accurate ab initio parametrization of density functional dispersion correction (DFT-D) for the 94 elements H-Pu." The Journal of chemical physics 132.15 (2010). \ 53 | \[3\] Schröder, Heiner, Anne Creon, and Tobias Schwabe. "Reformulation of the D3 (Becke–Johnson) dispersion correction without resorting to higher than C 6 dispersion coefficients." Journal of chemical theory and computation 11.7 (2015): 3163-3170. 54 | 55 | 56 | ## Getting Started 57 | 58 | This section will guide you on how to obtain and set up BAMBOO on your local machine for development and testing purposes. 
59 | 
60 | ### Prerequisites
61 | 
62 | To get started with BAMBOO, please ensure that you meet the following requirements:
63 | 
64 | - LAMMPS: stable_2Aug2023_update3 (tested branch)
65 | - CUDA: 12+
66 | - PyTorch: 2.0+
67 | 
68 | Once you have satisfied the above prerequisites, you are ready to proceed to the installation steps.
69 | 
70 | ### Installing
71 | 
72 | To get started, clone the BAMBOO repository to your local machine using the following command:
73 | 
74 | ```bash
75 | git clone https://github.com/bytedance/bamboo.git
76 | ```
77 | 
78 | This gives you a local copy of BAMBOO, ready for use.
79 | 
80 | To initialize the environment and retrieve the LAMMPS source code, follow these steps:
81 | 
82 | ```bash
83 | cd pair
84 | bash ./init_compile.sh
85 | cd lammps
86 | bash ./build.sh
87 | ```
88 | 
89 | > The build.sh script is pre-configured for the NVIDIA GeForce RTX 4090 GPU. If you are using a different GPU, you may need to adjust the ARCH variable within the script to match your specific hardware. Refer to the NVIDIA CUDA Toolkit documentation for details on selecting the correct architecture flags.
90 | 
91 | > The Libtorch version is currently specified in the init_compile.sh script. If you require a different version of Libtorch, you will need to update this script accordingly.
92 | 
93 | 
94 | ## User Manual
95 | 
96 | To demonstrate the capabilities and usage of BAMBOO, we have included a small but self-contained dataset featuring key components used in electrolytes for lithium batteries. This dataset includes:
97 | 
98 | - **Dimethyl carbonate (DMC)**
99 | - **Ethylene carbonate (EC)**
100 | - **Lithium ions (Li+)**
101 | - **Hexafluorophosphate ions (PF6-)**
102 | 
103 | To obtain the dataset:
104 | 
105 | 1. Visit the following link to download the dataset: [Demo data](https://huggingface.co/datasets/mzl/bamboo)
106 | 
107 | 2. After downloading, copy `train_data.pt` and `val_data.pt` into the `data` directory of the project. Once the datasets are properly placed, you can proceed with the following examples.
108 | 
109 | As we focus on simulating an electrolyte composed of DMC, EC, and LiPF6, we also provide:
110 | 
111 | - **Initial conformation file**: `in.data` in folder `benchmark`, which contains the starting structure for MD simulations.
112 | - **Input file for LAMMPS**: `in.lammps` in folder `benchmark`, which is prepared to start simulations using LAMMPS.
113 | 
114 | These resources are designed to help users quickly set up BAMBOO and run MLFF-based simulations to explore the behavior of lithium battery electrolytes.
115 | 
116 | ### Train an MLFF Model
117 | 
118 | Follow these steps to train an MLFF using BAMBOO:
119 | 
120 | 1. **Navigate to the project directory**
121 | 
122 | Replace `` with the actual path where you have installed BAMBOO, then execute the following command to move into that directory:
123 | 
124 | ```bash
125 | cd 
126 | ```
127 | 2. **Train a model**
128 | 
129 | Start the training process by running:
130 | 
131 | ```bash
132 | python3 -m train.train --config configs/train_config/config.json
133 | ```
134 | 
135 | This command uses a configuration file located at `configs/train_config/config.json`, where the parameters can be changed as needed. After training, a new folder named after the `job_name` variable in your configuration file will be created inside the `/train` directory. This folder will contain the training logs and checkpoint models saved as `.pt` files.
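Before moving on to MD simulations, it can be helpful to verify that a checkpoint loads cleanly outside LAMMPS. The snippet below is a minimal, optional sketch: it assumes the checkpoint (for example `benchmark/benchmark.pt`) is a TorchScript module, which is the form the Libtorch-based LAMMPS pair style expects; a checkpoint saved in another format would instead need `torch.load` (see also `utils/rejit.py`).

```python
# Optional sanity check: load a BAMBOO checkpoint outside LAMMPS.
# Assumption: the checkpoint is a TorchScript module (e.g. benchmark/benchmark.pt);
# if your file is a plain torch.save() checkpoint, use torch.load instead.
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.jit.load("benchmark/benchmark.pt", map_location=device)
model.eval()

# Count parameters just to confirm the module deserialized correctly.
n_params = sum(p.numel() for p in model.parameters())
print(f"Loaded TorchScript checkpoint with {n_params} parameters on {device}")
```

If the file loads without errors, the same `.pt` file can be referenced by `pair_coeff` in `in.lammps`, as described in the next section.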
136 | 
137 | ### Run an MD Simulation Using a BAMBOO MLFF Model
138 | 
139 | To perform an MD simulation using a BAMBOO model, follow these steps:
140 | 
141 | 1. **Create a folder for the MD simulation and prepare the necessary files**
142 | 
143 | Navigate to your BAMBOO directory and make a new folder for MD simulations. Copy the `in.data` and `in.lammps` files from the `benchmark` folder into this directory:
144 | 
145 | ```bash
146 | cd 
147 | mkdir simulation && cd simulation
148 | cp ../benchmark/* .
149 | ```
150 | 2. **Configure the simulation settings**
151 | 
152 | In the `in.lammps` file, replace `benchmark.pt` with the path to the `.pt` model file you want to use for the simulation.
153 | 
154 | 3. **Run the MD simulation**
155 | 
156 | Execute the MD simulation with LAMMPS:
157 | 
158 | ```bash
159 | lmp -k on g 1 -sf kk -in in.lammps -log log.lammps > out.log 2>&1
160 | ```
161 | 
162 | The `in.lammps` file can be adapted to your simulation needs. The `.pt` file of any MLFF produced by training, ensembling, or alignment can be used to run MD simulations.
163 | 
164 | ### Generate Frames for Ensemble and Alignment
165 | 
166 | To run the ensemble and alignment processes, frames from MD trajectories are required. Here's a guide to generating these frames:
167 | 
168 | 1. **Navigate to the project directory**
169 | 
170 | Execute the following command to move into that directory:
171 | 
172 | ```bash
173 | cd 
174 | ```
175 | 2. **Extract the frames from MD trajectories**
176 | 
177 | Here is an example command to extract frames from MD trajectories:
178 | 
179 | ```bash
180 | python3 -m utils.load_traj --job_folder --output_folder --mixture_name 
181 | ```
182 | 
183 | The `mixture_name` will be used during alignment to indicate which system is being aligned.
184 | 
185 | ### Ensemble Models
186 | 
187 | Averaging multiple replicate MLFF models into a single ensembled model can help reduce variance and improve prediction accuracy. Follow these steps to ensemble several models trained on your dataset:
188 | 
189 | 1. **Navigate to the project directory**
190 | 
191 | Execute the following command to move into that directory:
192 | 
193 | ```bash
194 | cd 
195 | ```
196 | 2. **Modify the config file**
197 | 
198 | To ensemble your models, modify the `config.json` file appropriately. This file should define the paths to the models you intend to ensemble, the base model whose parameters will be updated, and the directories containing the MD frames used for ensembling. Here is an example `config.json`:
199 | 
200 | ```json
201 | {
202 | "job_name": "ensemble_bamboo_community",
203 | "training_data_path": "/data/train_data.pt",
204 | "validation_data_path": "/data/val_data.pt",
205 | "batch_size": 512,
206 | "models": ["/.pt", "/.pt", "/.pt"],
207 | "frame_directories": [""],
208 | "ensemble_model": "/.pt",
209 | "validation_split_ratio": 0.1,
210 | "lr": 1e-6,
211 | "epochs": 50,
212 | "scheduler_gamma": 0.99,
213 | "validation_interval": 10,
214 | "energy_ratio": 0.3,
215 | "force_ratio": 1.0,
216 | "virial_ratio": 0.1,
217 | "bulk_energy_ratio": 0.01,
218 | "bulk_force_ratio": 3.0,
219 | "bulk_virial_ratio": 0.01,
220 | "max_frames_per_mixture": 960,
221 | "frame_validation_interval": 3
222 | }
223 | ```
224 | In this file, `models` is a list of paths to the models you intend to ensemble. `frame_directories` is a list of directories containing the MD frames used. `ensemble_model` is the path to the base model whose parameters will be updated.
225 | 3. **Ensemble the models**
226 | 
227 | Start the ensemble process by running:
228 | 
229 | ```bash
230 | python3 -m train.ensemble --config configs/ensemble_config/config.json
231 | ```
232 | 
233 | After ensembling, a new folder named after the `job_name` variable in your configuration file will be created inside the `/ensemble` directory. This folder will contain the training logs and checkpoint models saved as `.pt` files.
234 | 
235 | **Note**: To create an ensemble model, you need at least three different models.
236 | 
237 | ### Alignment
238 | 
239 | BAMBOO offers functionality to fine-tune the model's predictions by adjusting quantities such as pressure; this is referred to as the alignment process. For example, if you need to shift the model's predicted pressure by dP = -2000 Pa, follow these steps:
240 | 
241 | 1. **Navigate to the project directory**
242 | 
243 | Execute the following command to move into that directory:
244 | 
245 | ```bash
246 | cd 
247 | ```
248 | 2. **Modify the config file**
249 | 
250 | To fine-tune your model via alignment, modify the `config.json` file appropriately. This file should define the path to the model you intend to fine-tune and the directories containing the MD frames used for alignment. Here is an example `config.json`:
251 | 
252 | ```json
253 | {
254 | "job_name": "alignment_bamboo_community",
255 | "training_data_path": "/data/train_data.pt",
256 | "validation_data_path": "/data/val_data.pt",
257 | "model": "/.pt",
258 | "frame_directories": [""],
259 | "mixture_names": [""],
260 | "delta_pressure": [-2000],
261 | "energy_ratio": 0.3,
262 | "force_ratio": 1.0,
263 | "virial_ratio": 0.1,
264 | "dipole_ratio": 3.0,
265 | "bulk_energy_ratio": 1e2,
266 | "bulk_force_ratio": 1e6,
267 | "bulk_virial_ratio": 3e3,
268 | "batch_size": 512,
269 | "epochs": 30,
270 | "frame_val_interval": 3,
271 | "max_frame_per_mixture": 30,
272 | "lr": 1e-12,
273 | "scheduler_gamma": 0.99
274 | }
275 | ```
276 | `mixture_names` is a list of the mixture names corresponding to the frames, as set when generating the frames. `delta_pressure` is a list of the dP values for each mixture.
277 | 3. **Fine-tune the model via the alignment process**
278 | 
279 | Start the alignment process by running:
280 | 
281 | ```bash
282 | python3 -m train.alignment --config configs/alignment_config/config.json
283 | ```
284 | 
285 | After alignment, a new folder named after the `job_name` variable in your configuration file will be created inside the `/alignment` directory. This folder will contain the training logs and checkpoint models saved as `.pt` files.
286 | 
287 | ## Benchmark Model
288 | We have provided the model we trained and used for the data reported in our paper; it is located in the `benchmark` folder and named `benchmark.pt`. If you wish to reproduce the results mentioned in the paper, you can use this model.
289 | 
290 | ## Contributing
291 | 
292 | We welcome contributions to BAMBOO!
If you have suggestions or improvements, please refers to `CONTRIBUTING.md` 293 | 294 | ## Citing BAMBOO 295 | 296 | If you use BAMBOO in your research, please cite: 297 | 298 | ```bibtex 299 | @article{gong_predictive_2025, 300 | title = {A predictive machine learning force-field framework for liquid electrolyte development}, 301 | volume = {7}, 302 | issn = {2522-5839}, 303 | url = {https://doi.org/10.1038/s42256-025-01009-7}, 304 | doi = {10.1038/s42256-025-01009-7}, 305 | number = {4}, 306 | journal = {Nature Machine Intelligence}, 307 | author = {Gong, Sheng and Zhang, Yumin and Mu, Zhenliang and Pu, Zhichen and Wang, Hongyi and Han, Xu and Yu, Zhiao and Chen, Mengyi and Zheng, Tianze and Wang, Zhi and Chen, Lifei and Yang, Zhenze and Wu, Xiaojie and Shi, Shaochen and Gao, Weihao and Yan, Wen and Xiang, Liang}, 308 | month = apr, 309 | year = {2025}, 310 | pages = {543--552}, 311 | } 312 | ``` 313 | 314 | ## License 315 | 316 | This project is licensed under the [GNU General Public License, Version 2](https://www.gnu.org/licenses/old-licenses/gpl-2.0.html). 317 | -------------------------------------------------------------------------------- /benchmark/benchmark.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/bamboo/a42ed5c42e442e484b2e1aabb29cc8b3cd3b65e1/benchmark/benchmark.pt -------------------------------------------------------------------------------- /benchmark/in.lammps: -------------------------------------------------------------------------------- 1 | units real 2 | atom_style full 3 | atom_modify map yes 4 | newton off 5 | read_data in.data 6 | 7 | pair_style bamboo 5.0 5.0 10.0 1 8 | pair_coeff benchmark.pt H LI C O F P 9 | 10 | kspace_style pppm 1.0e-6 11 | kspace_modify mesh 64 64 64 12 | 13 | neighbor 3 bin 14 | neigh_modify delay 0 every 1 check yes 15 | timestep 1 16 | 17 | thermo 1000 18 | thermo_style custom step temp press vol density pe ke etotal evdwl ecoul spcpu 19 | 20 | variable pxy equal pxy 21 | variable pxz equal pxz 22 | variable pyz equal pyz 23 | fix pressure all ave/time 1 1 1 v_pxy v_pxz v_pyz file dump_pressure.out 24 | 25 | dump 1 all custom 1000 dump_npt.lammpstrj id type xu yu zu x y z ix iy iz vx vy vz fx fy fz q 26 | 27 | velocity all create 300.0 4928459 28 | velocity all zero linear 29 | 30 | minimize 0.0 0.0 1000 100000 31 | 32 | # Fix NPT 33 | fix 1 all npt temp 300.0 300.0 100 iso 0 0 1000 34 | 35 | run 1000000 36 | write_data npt.data 37 | unfix 1 38 | undump 1 39 | 40 | dump 2 all custom 1000 dump_nvt.lammpstrj id type xu yu zu x y z ix iy iz vx vy vz fx fy fz q 41 | 42 | # Fix NVT 43 | fix 2 all nvt temp 300.0 300.0 10 44 | 45 | run 5000000 46 | write_data nvt.data 47 | -------------------------------------------------------------------------------- /configs/alignment_config/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "job_name": "alignment_bamboo_community", 3 | "training_data_path": "/data/train_data.pt", 4 | "validation_data_path": "/data/val_data.pt", 5 | "model": "/.pt", 6 | "frame_directories": [""], 7 | "mixture_names": [""], 8 | "delta_pressure": [-2000], 9 | "energy_ratio": 0.3, 10 | "force_ratio": 1.0, 11 | "virial_ratio": 0.1, 12 | "dipole_ratio": 3.0, 13 | "bulk_energy_ratio": 1e2, 14 | "bulk_force_ratio": 1e6, 15 | "bulk_virial_ratio": 3e3, 16 | "batch_size": 512, 17 | "epochs": 30, 18 | "frame_val_interval": 3, 19 | "max_frame_per_mixture": 30, 20 | "lr": 1e-12, 21 | 
"scheduler_gamma": 0.99 22 | } 23 | -------------------------------------------------------------------------------- /configs/ensemble_config/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "job_name": "ensemble_bamboo_community", 3 | "training_data_path": "/data/train_data.pt", 4 | "validation_data_path": "/data/val_data.pt", 5 | "batch_size": 512, 6 | "models": ["/.pt", "/.pt", "/.pt"], 7 | "frame_directories": [""], 8 | "ensemble_model": "/.pt", 9 | "validation_split_ratio": 0.1, 10 | "lr": 1e-6, 11 | "epochs": 50, 12 | "scheduler_gamma": 0.99, 13 | "validation_interval": 10, 14 | "energy_ratio": 0.3, 15 | "force_ratio": 1.0, 16 | "virial_ratio": 0.1, 17 | "bulk_energy_ratio": 0.01, 18 | "bulk_force_ratio": 3.0, 19 | "bulk_virial_ratio": 0.01, 20 | "max_frames_per_mixture": 960, 21 | "frame_validation_interval": 3 22 | } 23 | -------------------------------------------------------------------------------- /configs/train_config/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "job_name": "bamboo_community_train", 3 | "data_training": "train_data.pt", 4 | "data_validation": "val_data.pt", 5 | "random_seed": 42, 6 | "train_batch_size": 128, 7 | "val_batch_size": 128, 8 | "num_epoch": 1000, 9 | "lr": 0.01, 10 | "weight_decay": 0.001, 11 | "scheduler_gamma": 0.99, 12 | "loss_charge_ratio": 10.0, 13 | "loss_dipole_ratio": 10.0, 14 | "loss_energy_ratio": 0.01, 15 | "loss_forces_ratio": 0.3, 16 | "loss_virial_ratio": 0.01, 17 | "charge_ub": 2.0, 18 | "qeq_force_regularizer": 300.0, 19 | "num_layers": 3, 20 | "num_rbf": 32, 21 | "emb_dim": 64, 22 | "num_heads": 16, 23 | "rcut": 5.0, 24 | "coul_damping_beta": 18.7, 25 | "coul_damping_r0": 2.2, 26 | "disp_cutoff": 10.0, 27 | "charge_mlp_layers": 2, 28 | "energy_mlp_layers": 2 29 | } 30 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | # ----- BAMBOO: Bytedance AI Molecular Booster ----- 2 | # Copyright 2022-2024 Bytedance Ltd. and/or its affiliates 3 | 4 | # This program is free software; you can redistribute it and/or modify 5 | # it under the terms of the GNU General Public License as published by 6 | # the Free Software Foundation; either version 2 of the License, or 7 | # (at your option) any later version. 8 | 9 | # This program is distributed in the hope that it will be useful, 10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | # GNU General Public License for more details. 13 | 14 | # You should have received a copy of the GNU General Public License 15 | # along with this program; if not, write to the Free Software 16 | # Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA 17 | -------------------------------------------------------------------------------- /models/bamboo_base.py: -------------------------------------------------------------------------------- 1 | # ----- BAMBOO: Bytedance AI Molecular Booster ----- 2 | # Copyright 2022-2024 Bytedance Ltd. and/or its affiliates 3 | 4 | # This program is free software; you can redistribute it and/or modify 5 | # it under the terms of the GNU General Public License as published by 6 | # the Free Software Foundation; either version 2 of the License, or 7 | # (at your option) any later version. 
8 | 9 | # This program is distributed in the hope that it will be useful, 10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | # GNU General Public License for more details. 13 | 14 | # You should have received a copy of the GNU General Public License 15 | # along with this program; if not, write to the Free Software 16 | # Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA 17 | 18 | from typing import Dict, List, Optional 19 | 20 | import torch 21 | import torch.nn as nn 22 | from torch_runstats.scatter import scatter 23 | 24 | from utils.constant import debye_ea, ele_factor, ewald_a, ewald_f, ewald_p, nelems 25 | from utils.funcs import CosineCutoff, ExpNormalSmearing 26 | from models.modules.dftd3 import DFTD3CSO 27 | 28 | 29 | class BambooBase(torch.nn.Module): 30 | def __init__(self, device, 31 | nn_params = { 32 | 'dim': 64, 33 | 'num_rbf': 32, 34 | 'rcut': 5.0, 35 | 'charge_ub': 2.0, 36 | 'act_fn': nn.SiLU(), 37 | 'charge_mlp_layers': 2, 38 | 'energy_mlp_layers': 2 39 | }, 40 | coul_disp_params = { 41 | 'coul_damping_beta': 18.7, 42 | 'coul_damping_r0': 2.2, 43 | 'disp_cutoff': 10.0 44 | }): 45 | super(BambooBase, self).__init__() 46 | self.device = device 47 | self.nelems = nelems 48 | self.coul_disp_params = coul_disp_params 49 | 50 | # constants for ewald computation of coulomb forces 51 | self.ewald_f = ewald_f 52 | self.ewald_p = ewald_p 53 | self.ewald_a = ewald_a 54 | self.ele_factor = ele_factor 55 | self.debye_ea = debye_ea 56 | 57 | self.dispersion = DFTD3CSO(disp_cutoff=coul_disp_params['disp_cutoff']).to(device) 58 | 59 | self.dim = nn_params['dim'] 60 | self.num_rbf = nn_params['num_rbf'] 61 | self.rcut = nn_params['rcut'] 62 | self.charge_ub = nn_params['charge_ub'] 63 | self.atom_embtab = nn.Embedding(self.nelems, self.dim) 64 | self.dis_rbf = ExpNormalSmearing(0.0, self.rcut, self.num_rbf, device=self.device) 65 | self.dis_rbf.reset_parameters() 66 | self.cutoff = CosineCutoff(0.0, self.rcut) 67 | 68 | self.charge_mlp_layers = nn_params['charge_mlp_layers'] 69 | self.energy_mlp_layers = nn_params['energy_mlp_layers'] 70 | 71 | def get_mlp_layers(layers: int, dim: int): 72 | mlp_layers = [] 73 | for i in range(layers): 74 | if i == 0: 75 | mlp_layers.append(nn.Linear(dim, dim//2)) 76 | else: 77 | mlp_layers.append(nn.Linear(dim//2, dim//2)) 78 | mlp_layers.append(nn_params['act_fn']) 79 | mlp_layers.append(nn.Linear(dim//2, 1)) 80 | return mlp_layers 81 | 82 | self.charge_mlp = nn.Sequential(*get_mlp_layers(self.charge_mlp_layers, self.dim)) 83 | self.energy_mlp = nn.Sequential(*get_mlp_layers(self.energy_mlp_layers, self.dim)) 84 | self.pred_electronegativity_mlp = nn.Sequential(*get_mlp_layers(self.charge_mlp_layers, self.dim)) 85 | self.pred_electronegativity_hardness_mlp = nn.Sequential(*get_mlp_layers(self.charge_mlp_layers, self.dim)) 86 | 87 | self.coul_softplus = nn.Softplus(beta = coul_disp_params['coul_damping_beta']) 88 | self.nmol = 1 89 | 90 | self.to(self.device) 91 | 92 | def get_coulomb(self, 93 | row: torch.Tensor, 94 | col: torch.Tensor, 95 | dij: torch.Tensor, 96 | pred_charge: torch.Tensor, 97 | g_ewald: Optional[torch.Tensor] = None, 98 | ) -> List[torch.Tensor]: 99 | ''' 100 | Compute Coulomb energy and pairwise Coulomb force from predicted charge 101 | ''' 102 | rij = torch.sqrt(torch.sum(torch.square(dij), dim=-1)) 103 | prefactor_coul = self.ele_factor * pred_charge[row] * pred_charge[col] / rij 104 | beta, r0 = 
self.coul_disp_params['coul_damping_beta'], self.coul_disp_params['coul_damping_r0'] 105 | damp_coul = torch.sigmoid(beta / r0 * (rij - r0)) 106 | 107 | # damping of Coulomb energy and force 108 | softplus_coul = self.coul_softplus((rij - r0) / r0) 109 | ecoul = prefactor_coul * rij / r0 / (1 + softplus_coul) 110 | fcoul = prefactor_coul * damp_coul * (rij / r0 / (1 + softplus_coul))**2 111 | 112 | # compute erfc correction in inference, not available in trianing 113 | if g_ewald is not None: 114 | grij = g_ewald * rij 115 | expm2 = torch.exp(-grij * grij) 116 | t = 1.0 / (1.0 + self.ewald_p * grij) 117 | erfc = t * (self.ewald_a[0] + t * (self.ewald_a[1] + t * (self.ewald_a[2] + t * (self.ewald_a[3] + t * self.ewald_a[4])))) * expm2 118 | ecoul += prefactor_coul * (erfc - 1.0) 119 | fcoul += prefactor_coul * (erfc + self.ewald_f * grij * expm2 - 1.0) 120 | 121 | coul_fij = dij * (fcoul / rij / rij).unsqueeze(-1) 122 | return ecoul, coul_fij 123 | 124 | def graph_nn(self, 125 | node_feat: torch.Tensor, 126 | edge_index: torch.Tensor, 127 | coord_diff: torch.Tensor, 128 | radial: torch.Tensor, 129 | weights_rbf: torch.Tensor) -> torch.Tensor: 130 | ''' 131 | Graph neural network to update node features. 132 | Implemented in models/bamboo_get.py 133 | ''' 134 | raise NotImplementedError('graph_nn is not implemented') 135 | 136 | def energy_nn(self, inputs: Dict[str, torch.Tensor]) -> List[torch.Tensor]: 137 | node_feat = self.atom_embtab(inputs['atom_types']) 138 | coord_diff = inputs['edge_cell_shift'] 139 | radial = torch.sqrt(torch.sum(coord_diff**2, 1)) 140 | coord_diff = coord_diff / radial.unsqueeze(-1) 141 | weights_rbf = self.dis_rbf(radial) 142 | radial = self.cutoff(radial) 143 | 144 | # compute electronegativity and hardness from atom embeddings 145 | pred_electronegativity = self.pred_electronegativity_mlp(node_feat).squeeze(-1) 146 | pred_electronegativity_hardness = self.pred_electronegativity_hardness_mlp(node_feat).squeeze(-1) 147 | 148 | # GNN message passing 149 | node_feat = self.graph_nn(node_feat, inputs['edge_index'], coord_diff, radial, weights_rbf) 150 | 151 | # predict charge from atom embeddings and normalize the charges 152 | charge = self.charge_mlp(node_feat).squeeze(-1) 153 | charge = self.charge_ub * torch.tanh(charge / self.charge_ub) # an upper bound of atomic partial charge 154 | sum_charge = scatter(charge, inputs['mol_ids'], dim=0, dim_size=self.nmol) 155 | natoms = scatter(torch.ones_like(inputs['mol_ids'], dtype=torch.float32), inputs['mol_ids'], dim=0, dim_size=self.nmol) 156 | diff_charge = (inputs['total_charge'] - sum_charge)/natoms 157 | pred_charge = charge + torch.gather(diff_charge, 0, inputs['mol_ids']) # make sure summation of charges is preserved 158 | 159 | # compute electronegativity energy 160 | electronegativity_energy = pred_electronegativity**2 * pred_charge + \ 161 | pred_electronegativity_hardness**2 * pred_charge * pred_charge #using physical electronegative value "en_value" as starting point 162 | electronegativity_energy = scatter(electronegativity_energy, inputs['mol_ids'], dim=0, dim_size=self.nmol) 163 | 164 | # predict NN energy 165 | energy = self.energy_mlp(node_feat).squeeze(-1) 166 | nn_energy = scatter(energy, inputs['mol_ids'], dim=0, dim_size=self.nmol) 167 | 168 | return nn_energy, pred_charge, electronegativity_energy 169 | 170 | def get_loss(self, inputs: Dict[str, torch.Tensor]) -> List[Dict[str, torch.Tensor]]: 171 | ''' 172 | Get MSE and MAE in training and validation. 
173 | ''' 174 | pred = self.predict(inputs) 175 | 176 | # compute mean square errors with removing batch average in energy 177 | mse = dict() 178 | pred_energy_ave, label_energy_ave = torch.mean(pred['energy']), torch.mean(inputs['energy']) 179 | mse['energy'] = torch.mean(torch.square(pred['energy'] - inputs['energy'] - pred_energy_ave + label_energy_ave)) 180 | mse['forces'] = torch.mean(torch.square(pred['forces'] - inputs['forces'])) 181 | mse['virial'] = torch.mean(torch.square(pred['virial'] - inputs['virial'])) 182 | mse['charge'] = torch.mean(torch.square(pred['charge'] - inputs['charge'])) 183 | mse['dipole'] = torch.mean(torch.square(pred['dipole'] - inputs['dipole'])) 184 | 185 | # compute mean sabsolute errors with removing batch average in energy 186 | mae = dict() 187 | mae['energy'] = torch.mean(torch.abs(pred['energy'] - inputs['energy'] - pred_energy_ave + label_energy_ave)) 188 | mae['forces'] = torch.mean(torch.abs(pred['forces'] - inputs['forces'])) 189 | mae['virial'] = torch.mean(torch.abs(pred['virial'] - inputs['virial'])) 190 | mae['charge'] = torch.mean(torch.abs(pred['charge'] - inputs['charge'])) 191 | mae['dipole'] = torch.mean(torch.abs(pred['dipole'] - inputs['dipole'])) 192 | 193 | # compute charge equilibrium force penalty 194 | penalty = {} 195 | penalty['qeq_force'] = torch.mean(torch.square(pred['qeq_force'])) 196 | 197 | return mse, mae, penalty 198 | 199 | @torch.jit.export 200 | def predict(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: 201 | ''' 202 | Used in training and validation 203 | Inputs units and shape: 204 | ------------------ 205 | edge_index: [2, Num_edges]. 206 | edge_cell_shift: [Num_edges, 3]. Unit: Angstrom 207 | all_edge_index: [2, Num_all_edges]. 208 | all_edge_cell_shift: [Num_all_edges, 3]. Unit: Angstrom 209 | atom_types: [Num_atoms]. torch.long 210 | total_charge: [Num_molecules] 211 | mol_ids: [Num_atoms]. 212 | 213 | Outputs units and shape: 214 | --------------------- 215 | energy: [Num_molecules]. Unit: Kcal/mol 216 | forces: [Num_atoms, 3]. Unit: Kcal/mol/Angstrom 217 | virial: [Num_molecules, 3, 3]. Unit: a.u. 218 | charge: [Num_atoms]. Unit: a.u. 219 | dipole: [Num_molecules, 3]. Unit: Debye 220 | qeq_forces: [Num_atoms, 3]. Unit: Kcal/mol/Angstrom. 
Residual in charge equilibrium which should be regularized to zero 221 | ''' 222 | 223 | # prepare data 224 | natoms = len(inputs['atom_types']) 225 | self.nmol = int(torch.max(inputs['mol_ids']).item()) + 1 226 | row, col = inputs['edge_index'][0], inputs['edge_index'][1] 227 | inputs['edge_cell_shift'].requires_grad_(True) 228 | 229 | # NN inference 230 | nn_energy, pred_charge, electronegativity_energy = self.energy_nn(inputs) 231 | 232 | # comute NN atom forces and virial 233 | grad_outputs : Optional[List[Optional[torch.Tensor]]] = [ torch.ones_like(nn_energy) ] 234 | nn_fij = torch.autograd.grad([nn_energy], [inputs['edge_cell_shift']], grad_outputs=grad_outputs, create_graph=True, allow_unused=True)[0] 235 | if nn_fij is None: # used for torch.jit.script 236 | nn_fij_cast = torch.zeros(size=inputs['edge_cell_shift'].size(), device=self.device) 237 | else: 238 | nn_fij_cast = -1.0 * nn_fij 239 | nn_forces = scatter(nn_fij_cast, row, dim=0, dim_size=natoms) - scatter(nn_fij_cast, col, dim=0, dim_size=natoms) 240 | nn_virial = nn_fij_cast.unsqueeze(-2) * inputs['edge_cell_shift'].unsqueeze(-1) 241 | nn_virial = scatter(scatter(nn_virial, row, dim=0, dim_size=natoms), inputs['mol_ids'], dim=0, dim_size=self.nmol) 242 | 243 | # compute Coulomb energy, forces and virial for all pairs 244 | row_all, col_all = inputs['all_edge_index'][0], inputs['all_edge_index'][1] 245 | ecoul, coul_fij = self.get_coulomb(row_all, col_all, inputs['all_edge_cell_shift'], pred_charge) 246 | coul_energy = 0.5 * scatter(scatter(ecoul, row_all, dim=0, dim_size=natoms), inputs['mol_ids'], dim=0, dim_size=self.nmol) 247 | coul_forces = scatter(coul_fij, row_all, dim=0, dim_size=natoms) 248 | coul_virial = 0.5 * scatter(scatter(coul_fij.unsqueeze(-2) * inputs['all_edge_cell_shift'].unsqueeze(-1), row_all, dim=0, dim_size=natoms), inputs['mol_ids'], dim=0, dim_size=self.nmol) 249 | 250 | # compute residual in charge equilibrium formula, which should be regularized to zero 251 | grad_outputs : Optional[List[Optional[torch.Tensor]]] = [ torch.ones_like(nn_energy) ] 252 | charge_energy = coul_energy + electronegativity_energy 253 | qeq_fij = torch.autograd.grad([charge_energy], [inputs['edge_cell_shift']], grad_outputs=grad_outputs, create_graph=True, allow_unused=True)[0] 254 | if qeq_fij is None: # used for torch.jit.script 255 | qeq_fij_cast = torch.zeros(size=inputs['edge_cell_shift'].size(), device=self.device) 256 | else: 257 | qeq_fij_cast = -1.0 * qeq_fij 258 | qeq_force = scatter(qeq_fij_cast, row, dim=0, dim_size=natoms) - scatter(qeq_fij_cast, col, dim=0, dim_size=natoms) 259 | 260 | # prepare output dictionary 261 | pred = dict() 262 | pred['energy'] = nn_energy + coul_energy + electronegativity_energy 263 | pred['forces'] = nn_forces + coul_forces 264 | pred['virial'] = nn_virial + coul_virial 265 | pred['charge'] = pred_charge 266 | pred['dipole'] = scatter(inputs['pos'] * pred_charge.unsqueeze(-1), inputs['mol_ids'], dim=0, dim_size=self.nmol) / self.debye_ea 267 | pred['qeq_force'] = qeq_force 268 | return pred 269 | 270 | def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: 271 | ''' 272 | Used in LAMMPS inference 273 | 274 | Inputs: always float64. 275 | ---------------- 276 | edge_index: [2, Num_edges]. 277 | edge_cell_shift: [Num_edges, 3]. Unit: Angstrom 278 | coul_edge_index: [2, Num_coul_edges]. 279 | coul_edge_cell_shift: [Num_coul_edges, 3]. Unit: Angstrom 280 | disp_edge_index: [2, Num_disp_edges]. 281 | disp_edge_cell_shift: [Num_disp_edges, 3]. 
Unit: Angstrom 282 | atom_types: [Num_atoms]. torch.long 283 | g_ewald: [1]. g_ewald parameter in LAMMPS 284 | 285 | Outputs: always float64 286 | ----------------- 287 | pred_energy: [1]. Unit: Kcal/mol 288 | pred_forces: [Num_atoms, 3]. Unit: Kcal/mol/Angstrom 289 | pred_virial: [3, 3]. Unit: Kcal/mol 290 | coul_energy: [1]. Unit: Kcal/mol 291 | pred_charge: [Num_atoms]. Unit: a.u. 292 | ''' 293 | 294 | # Prepare input data 295 | for k in inputs.keys(): 296 | if torch.is_floating_point(inputs[k]): 297 | inputs[k] = inputs[k].to(torch.float32) 298 | 299 | natoms = len(inputs['atom_types']) 300 | self.nmol = 1 301 | inputs['total_charge'] = torch.zeros(1, dtype=torch.float32, device=self.device) 302 | inputs['mol_ids'] = torch.zeros(natoms, dtype=torch.long, device=self.device) 303 | row, col = inputs['edge_index'][0], inputs['edge_index'][1] 304 | inputs['edge_cell_shift'].requires_grad_(True) # Ne 305 | 306 | # NN inference 307 | nn_energy, pred_charge, electronegativity_energy = self.energy_nn(inputs) # 1, Na 308 | 309 | # comute NN atom forces and virial 310 | grad_outputs : Optional[List[Optional[torch.Tensor]]] = [ torch.ones_like(nn_energy) ] 311 | nn_fij = torch.autograd.grad([nn_energy], [inputs['edge_cell_shift']], grad_outputs=grad_outputs, create_graph=True, allow_unused=True)[0] 312 | if nn_fij is None: # used for torch.jit.script 313 | nn_fij_cast = torch.zeros(size=inputs['edge_cell_shift'].size(), device=self.device) 314 | else: 315 | nn_fij_cast = -1.0 * nn_fij 316 | nn_forces = scatter(nn_fij_cast, row, dim=0, dim_size=natoms) - scatter(nn_fij_cast, col, dim=0, dim_size=natoms) 317 | nn_virial = torch.sum(nn_fij_cast.unsqueeze(-2) * inputs['edge_cell_shift'].unsqueeze(-1), dim=0) 318 | 319 | # Coulomb energy, force and virial within cutoff 320 | row_coul, col_coul = inputs['coul_edge_index'][0], inputs['coul_edge_index'][1] 321 | ecoul, coul_fij = self.get_coulomb(row_coul, col_coul, inputs['coul_edge_cell_shift'], pred_charge, g_ewald=inputs['g_ewald']) 322 | coul_energy = 0.5 * torch.sum(ecoul) 323 | coul_forces = scatter(coul_fij, row_coul, dim=0, dim_size=natoms) 324 | coul_virial = 0.5 * torch.sum(coul_fij.unsqueeze(-2) * inputs['coul_edge_cell_shift'].unsqueeze(-1), dim=0) 325 | 326 | # dispersion energy, force and virial within cutoff 327 | disp_energy, disp_forces, disp_virial = self.dispersion(inputs['atom_types'], inputs['disp_edge_cell_shift'], inputs['disp_edge_index']) 328 | 329 | # prepare output dictionary and convert back to float64 330 | outputs = dict() 331 | outputs['pred_energy'] = nn_energy + coul_energy + disp_energy + electronegativity_energy 332 | outputs['pred_forces'] = nn_forces + coul_forces + disp_forces 333 | outputs['pred_virial'] = nn_virial + coul_virial + disp_virial 334 | outputs['pred_coul_energy'] = coul_energy 335 | outputs['pred_charge'] = pred_charge 336 | 337 | if 'edge_outer_mask' in inputs.keys(): 338 | outputs['nn_virial_outer'] = torch.sum(torch.sum(nn_fij_cast * inputs['edge_cell_shift'], dim=-1) * inputs['edge_outer_mask']) 339 | 340 | for k, v in outputs.items(): 341 | outputs[k] = v.to(torch.float64) 342 | return outputs -------------------------------------------------------------------------------- /models/bamboo_get.py: -------------------------------------------------------------------------------- 1 | # ----- BAMBOO: Bytedance AI Molecular Booster ----- 2 | # Copyright 2022-2024 Bytedance Ltd. 
and/or its affiliates 3 | 4 | # This program is free software; you can redistribute it and/or modify 5 | # it under the terms of the GNU General Public License as published by 6 | # the Free Software Foundation; either version 2 of the License, or 7 | # (at your option) any later version. 8 | 9 | # This program is distributed in the hope that it will be useful, 10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | # GNU General Public License for more details. 13 | 14 | # You should have received a copy of the GNU General Public License 15 | # along with this program; if not, write to the Free Software 16 | # Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA 17 | 18 | from typing import List 19 | 20 | import torch 21 | import torch.nn as nn 22 | from torch_runstats.scatter import scatter 23 | 24 | from models.bamboo_base import BambooBase 25 | 26 | 27 | class LinearAttnFirst(nn.Module): 28 | """ 29 | Graph Equivariant Transformer First Layer 30 | No node_vec in the input compared to middle layers 31 | """ 32 | def __init__(self, 33 | dim = 64, 34 | num_heads = 16, 35 | act_fn = nn.GELU()): 36 | super(LinearAttnFirst, self).__init__() 37 | self.qkv_proj = nn.Linear(dim, dim * 3) 38 | self.output_proj = nn.Linear(dim, dim) 39 | 40 | self.layer_norm = nn.LayerNorm(dim) 41 | self.dim = dim 42 | self.dim_per_head = dim // num_heads 43 | self.num_heads = num_heads 44 | self.attn_act = act_fn 45 | 46 | def qkv_attn(self, 47 | node_feat: torch.Tensor, 48 | row: torch.Tensor, 49 | col: torch.Tensor, 50 | ) -> List[torch.Tensor]: 51 | node_feat = self.layer_norm(node_feat) 52 | qkv = self.qkv_proj(node_feat) 53 | qkv = qkv.reshape(qkv.shape[:-1]+(self.num_heads, self.dim_per_head * 3)) 54 | 55 | q, k, v = qkv[...,:self.dim_per_head], qkv[...,self.dim_per_head:2*self.dim_per_head], qkv[...,2*self.dim_per_head:] 56 | q_row, k_col, v_col = q[row], k[col], v[col] 57 | return q_row, k_col, v_col 58 | 59 | def forward(self, 60 | node_feat: torch.Tensor, 61 | edge_feat: torch.Tensor, 62 | edge_vec: torch.Tensor, 63 | row: torch.Tensor, 64 | col: torch.Tensor, 65 | radial: torch.Tensor, 66 | natoms: int, 67 | ) -> List[torch.Tensor]: 68 | # attention layer 69 | q_row, k_col, v_col = self.qkv_attn(node_feat, row, col) 70 | attn = self.attn_act(torch.sum(q_row * k_col, dim=-1)) * radial.unsqueeze(-1) 71 | 72 | # update scalar messages 73 | m_feat = v_col * edge_feat * attn.unsqueeze(-1) 74 | m_feat = scatter(m_feat, row, dim=0, dim_size=natoms) 75 | m_feat = m_feat.reshape(m_feat.shape[:-2]+(self.dim,)) 76 | 77 | # update vector messages 78 | m_vec = v_col.unsqueeze(-3) * edge_vec 79 | m_vec = scatter(m_vec, row, dim=0, dim_size=natoms) 80 | delta_node_vec = m_vec.reshape(m_vec.shape[:-2]+(self.dim,)) 81 | 82 | # update scalar node features 83 | delta_node_feat = self.output_proj(m_feat) 84 | return delta_node_feat, delta_node_vec 85 | 86 | 87 | class LinearAttn(nn.Module): 88 | """ 89 | Graph Equivariant Transformer Layer 90 | """ 91 | def __init__(self, 92 | dim = 64, 93 | num_heads = 16, 94 | act_fn = nn.GELU()): 95 | super(LinearAttn, self).__init__() 96 | self.qkv_proj = nn.Linear(dim, dim * 3) 97 | self.output_proj = nn.Linear(dim, dim * 3) 98 | self.vec_proj = nn.Linear(dim, dim * 3, bias=False) 99 | 100 | self.layer_norm = nn.LayerNorm(dim) 101 | self.dim = dim 102 | self.dim_per_head = dim // num_heads 103 | self.num_heads = num_heads 104 | self.attn_act = act_fn 105 | 106 | def 
qkv_attn(self, 107 | node_feat: torch.Tensor, 108 | row: torch.Tensor, 109 | col: torch.Tensor, 110 | ) -> List[torch.Tensor]: 111 | node_feat = self.layer_norm(node_feat) 112 | qkv = self.qkv_proj(node_feat) 113 | qkv = qkv.reshape(qkv.shape[:-1]+(self.num_heads, self.dim_per_head * 3)) 114 | 115 | q, k, v = qkv[...,:self.dim_per_head], qkv[...,self.dim_per_head:2*self.dim_per_head], qkv[...,2*self.dim_per_head:] 116 | q_row, k_col, v_col = q[row], k[col], v[col] 117 | return q_row, k_col, v_col 118 | 119 | def forward(self, 120 | node_feat: torch.Tensor, 121 | edge_feat: torch.Tensor, 122 | node_vec: torch.Tensor, 123 | edge_vec: torch.Tensor, 124 | row: torch.Tensor, 125 | col: torch.Tensor, 126 | radial: torch.Tensor, 127 | natoms: int, 128 | ) -> List[torch.Tensor]: 129 | # attention layer 130 | q_row, k_col, v_col = self.qkv_attn(node_feat, row, col) 131 | attn = self.attn_act(torch.sum(q_row * k_col, dim=-1)) * radial.unsqueeze(-1) 132 | 133 | # preprocess node vectors; user inner product to produce scalar feature to ensure equivariance 134 | input_vec = self.vec_proj(node_vec) 135 | input_1, input_2, input_3 = input_vec[...,:self.dim], input_vec[...,self.dim:2*self.dim], input_vec[...,2*self.dim:] 136 | input_dot = (input_1 * input_2).sum(dim=-2) 137 | 138 | # update scalar messages 139 | m_feat = v_col * edge_feat * attn.unsqueeze(-1) 140 | m_feat = scatter(m_feat, row, dim=0, dim_size=natoms) 141 | m_feat = m_feat.reshape(m_feat.shape[:-2]+(self.dim,)) 142 | 143 | # update vector messages 144 | m_vec = v_col.unsqueeze(-3) * edge_vec 145 | m_vec = scatter(m_vec, row, dim=0, dim_size=natoms) 146 | m_vec = m_vec.reshape(m_vec.shape[:-2]+(self.dim,)) 147 | 148 | # update scalar node features 149 | output_feat = self.output_proj(m_feat) 150 | output_1, output_2, output_3 = output_feat[...,:self.dim], output_feat[...,self.dim:2*self.dim], output_feat[...,2*self.dim:] 151 | delta_node_feat = input_dot * output_2 + output_3 152 | 153 | # update node vectors 154 | delta_node_vec = input_3 * output_1.unsqueeze(-2) + m_vec 155 | 156 | return delta_node_feat, delta_node_vec 157 | 158 | 159 | class LinearAttnLast(nn.Module): 160 | """ 161 | Graph Equivariant Transformer Last Layer 162 | No node_vec output compared to middle layers 163 | """ 164 | def __init__(self, 165 | dim = 64, 166 | num_heads = 16, 167 | act_fn = nn.GELU()): 168 | super(LinearAttnLast, self).__init__() 169 | self.qkv_proj = nn.Linear(dim, dim * 3) 170 | self.output_proj = nn.Linear(dim, dim * 2) 171 | self.vec_proj = nn.Linear(dim, dim * 2, bias=False) 172 | 173 | self.layer_norm = nn.LayerNorm(dim) 174 | self.dim = dim 175 | self.dim_per_head = dim // num_heads 176 | self.num_heads = num_heads 177 | self.attn_act = act_fn 178 | 179 | def qkv_attn(self, 180 | node_feat: torch.Tensor, 181 | row: torch.Tensor, 182 | col: torch.Tensor, 183 | ) -> List[torch.Tensor]: 184 | node_feat = self.layer_norm(node_feat) 185 | qkv = self.qkv_proj(node_feat) 186 | qkv = qkv.reshape(qkv.shape[:-1]+(self.num_heads, self.dim_per_head * 3)) 187 | 188 | q, k, v = qkv[...,:self.dim_per_head], qkv[...,self.dim_per_head:2*self.dim_per_head], qkv[...,2*self.dim_per_head:] 189 | q_row, k_col, v_col = q[row], k[col], v[col] 190 | return q_row, k_col, v_col 191 | 192 | def forward(self, 193 | node_feat: torch.Tensor, 194 | edge_feat: torch.Tensor, 195 | node_vec: torch.Tensor, 196 | row: torch.Tensor, 197 | col: torch.Tensor, 198 | radial: torch.Tensor, 199 | natoms: int, 200 | ) -> torch.Tensor: 201 | # attention layer 202 | q_row, k_col, v_col = 
self.qkv_attn(node_feat, row, col) 203 | attn = self.attn_act(torch.sum(q_row * k_col, dim=-1)) * radial.unsqueeze(-1) 204 | 205 | # preprocess node vectors; use the inner product to produce a scalar feature and ensure equivariance 206 | input_vec = self.vec_proj(node_vec) 207 | input_1, input_2 = input_vec[...,:self.dim], input_vec[...,self.dim:] 208 | input_dot = (input_1 * input_2).sum(dim=-2) 209 | 210 | # update scalar messages 211 | m_feat = v_col * edge_feat * attn.unsqueeze(-1) 212 | m_feat = scatter(m_feat, row, dim=0, dim_size=natoms) 213 | m_feat = m_feat.reshape(m_feat.shape[:-2]+(self.dim,)) 214 | 215 | # update scalar node features 216 | output_feat = self.output_proj(m_feat) 217 | output_2, output_3 = output_feat[...,:self.dim], output_feat[...,self.dim:] 218 | delta_node_feat = input_dot * output_2 + output_3 219 | 220 | return delta_node_feat 221 | 222 | 223 | class BambooGET(BambooBase): 224 | def __init__(self, device, coul_disp_params, nn_params, 225 | gnn_params = { 226 | 'n_layers': 3, 227 | 'num_heads': 16, 228 | 'act_fn': nn.GELU(), 229 | }): 230 | super(BambooGET, self).__init__(device=device, nn_params=nn_params, coul_disp_params=coul_disp_params) 231 | self.n_layers = gnn_params['n_layers'] 232 | self.num_heads = gnn_params['num_heads'] 233 | self.dim_per_head = self.dim // self.num_heads 234 | self.act_fn = gnn_params['act_fn'] 235 | self.rbf_proj = nn.Sequential( 236 | nn.Linear(self.num_rbf, self.dim, bias=False), 237 | self.act_fn 238 | ) 239 | self.first_attn = LinearAttnFirst(self.dim, self.num_heads, self.act_fn) 240 | self.attns = nn.ModuleList([ 241 | LinearAttn(self.dim, self.num_heads, self.act_fn) for _ in range(self.n_layers-2) 242 | ]) 243 | self.last_attn = LinearAttnLast(self.dim, self.num_heads, self.act_fn) 244 | self.apply(self._init_weights) 245 | self.to(self.device) 246 | 247 | def _init_weights(self, module): 248 | if isinstance(module, nn.Linear): 249 | nn.init.xavier_uniform_(module.weight) 250 | if module.bias is not None: 251 | module.bias.data.zero_() 252 | 253 | def graph_nn(self, 254 | node_feat: torch.Tensor, 255 | edge_index: torch.Tensor, 256 | coord_diff: torch.Tensor, 257 | radial: torch.Tensor, 258 | weights_rbf: torch.Tensor, 259 | ) -> torch.Tensor: 260 | # compute initial edge feature and edge vector 261 | edge_feat = self.rbf_proj(weights_rbf) 262 | edge_feat = edge_feat.reshape(edge_feat.shape[:-1]+(self.num_heads, self.dim_per_head)) 263 | edge_vec = edge_feat.unsqueeze(-3) * coord_diff.unsqueeze(-1).unsqueeze(-1) 264 | row, col = edge_index[0], edge_index[1] 265 | natoms = node_feat.shape[0] 266 | 267 | # first GET layer 268 | delta_node_feat, delta_node_vec = self.first_attn(node_feat, edge_feat, edge_vec, row, col, radial, natoms) 269 | node_feat = node_feat + delta_node_feat 270 | node_vec = delta_node_vec 271 | 272 | # middle GET layers 273 | for attn in self.attns: 274 | delta_node_feat, delta_node_vec = attn(node_feat, edge_feat, node_vec, edge_vec, row, col, radial, natoms) 275 | node_feat = node_feat + delta_node_feat 276 | node_vec = node_vec + delta_node_vec 277 | 278 | # last GET layer 279 | delta_node_feat = self.last_attn(node_feat, edge_feat, node_vec, row, col, radial, natoms) 280 | node_feat = node_feat + delta_node_feat 281 | return node_feat -------------------------------------------------------------------------------- /models/modules/__init__.py: -------------------------------------------------------------------------------- 1 | # ----- BAMBOO: Bytedance AI Molecular Booster ----- 2 | # Copyright
2022-2024 Bytedance Ltd. and/or its affiliates 3 | 4 | # This program is free software; you can redistribute it and/or modify 5 | # it under the terms of the GNU General Public License as published by 6 | # the Free Software Foundation; either version 2 of the License, or 7 | # (at your option) any later version. 8 | 9 | # This program is distributed in the hope that it will be useful, 10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | # GNU General Public License for more details. 13 | 14 | # You should have received a copy of the GNU General Public License 15 | # along with this program; if not, write to the Free Software 16 | # Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA 17 | -------------------------------------------------------------------------------- /models/modules/dftd3/__init__.py: -------------------------------------------------------------------------------- 1 | # ----- BAMBOO: Bytedance AI Molecular Booster ----- 2 | # Copyright 2022-2024 Bytedance Ltd. and/or its affiliates 3 | 4 | # This program is free software; you can redistribute it and/or modify 5 | # it under the terms of the GNU General Public License as published by 6 | # the Free Software Foundation; either version 2 of the License, or 7 | # (at your option) any later version. 8 | 9 | # This program is distributed in the hope that it will be useful, 10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | # GNU General Public License for more details. 13 | 14 | # You should have received a copy of the GNU General Public License 15 | # along with this program; if not, write to the Free Software 16 | # Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA 17 | 18 | from .dftd3 import DFTD3CSO 19 | -------------------------------------------------------------------------------- /models/modules/dftd3/dftd3.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/bamboo/a42ed5c42e442e484b2e1aabb29cc8b3cd3b65e1/models/modules/dftd3/dftd3.pt -------------------------------------------------------------------------------- /models/modules/dftd3/dftd3.py: -------------------------------------------------------------------------------- 1 | # ----- BAMBOO: Bytedance AI Molecular Booster ----- 2 | # Copyright 2022-2024 Bytedance Ltd. and/or its affiliates 3 | 4 | # This program is free software; you can redistribute it and/or modify 5 | # it under the terms of the GNU General Public License as published by 6 | # the Free Software Foundation; either version 2 of the License, or 7 | # (at your option) any later version. 8 | 9 | # This program is distributed in the hope that it will be useful, 10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | # GNU General Public License for more details. 
13 | 14 | # You should have received a copy of the GNU General Public License 15 | # along with this program; if not, write to the Free Software 16 | # Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA 17 | 18 | import os 19 | from typing import Optional 20 | 21 | import torch 22 | import torch.nn as nn 23 | 24 | from utils import constant 25 | from torch_runstats.scatter import scatter 26 | 27 | INITIAL_PARAMS_PATH = os.path.join(os.path.dirname(__file__), 'dftd3.pt') 28 | 29 | 30 | class DFTD3Params(nn.Module): 31 | def __init__(self, dtype=torch.double): 32 | super().__init__() 33 | self.default_params = torch.load(INITIAL_PARAMS_PATH, map_location='cpu') 34 | 35 | self.ntype = 90 36 | self.nweight = 5 37 | 38 | self.sqrt_z_r4r2 = nn.Parameter(torch.empty(self.ntype, dtype=dtype), requires_grad=False) 39 | self.cov_d3 = nn.Parameter(torch.empty(self.ntype, dtype=dtype), requires_grad=False) 40 | self.cn = nn.Parameter(torch.empty(self.ntype, self.nweight, dtype=dtype), requires_grad=False) 41 | self.c6 = nn.Parameter(torch.empty(self.ntype, self.ntype, self.nweight, self.nweight, dtype=dtype), requires_grad=False) 42 | 43 | self.reset_parameters() 44 | 45 | def reset_parameters(self): 46 | with torch.no_grad(): 47 | self.sqrt_z_r4r2.copy_(self.default_params['sqrt_z_r4_over_r2'].to(self.sqrt_z_r4r2)) 48 | self.cov_d3.copy_(self.default_params['cov_d3'].to(self.cov_d3)) 49 | self.cn.copy_(self.default_params['cn'].to(self.cn)) 50 | self.c6.copy_(self.default_params['c6'].to(self.c6)) 51 | 52 | 53 | class DFTD3Base(nn.Module): 54 | def __init__(self, dftd3_params=None, disp_cutoff=50.0): 55 | super().__init__() 56 | self.disp_cutoff = disp_cutoff / constant.bohr_angstrom # disp_cutoff is specified in angstrom 57 | self.params = dftd3_params or DFTD3Params() 58 | 59 | def reset_parameters(self): 60 | self.params.reset_parameters() 61 | 62 | def cn_d3(self, atom_types, row, col, rij): 63 | KCN = 16.0 64 | rcov = self.params.cov_d3[atom_types] 65 | rc = rcov[row] + rcov[col] 66 | cf = 1.0 / (1.0 + torch.exp(-KCN * (rc / rij - 1.0))) 67 | cn = scatter(cf, row, dim=0, dim_size=atom_types.shape[0]) + scatter(cf, col, dim=0, dim_size=atom_types.shape[0]) 68 | d_cn = -KCN * cf * (1.0 - cf) * rc / (rij ** 2) 69 | return cn, d_cn 70 | 71 | def weight_references(self, atom_types, cn): 72 | refcn = self.params.cn[atom_types] 73 | dcn = refcn - cn.unsqueeze(-1) 74 | 75 | dcn = dcn.double() 76 | 77 | factor = 4.0 78 | weights = torch.exp(-factor * dcn.pow(2)) 79 | 80 | norm = weights.sum(dim=-1, keepdim=True) 81 | normalized_weights = weights / norm 82 | 83 | d_weights = 2.0 * factor * dcn * weights 84 | norm_d = d_weights.sum(dim=-1, keepdim=True) 85 | d_1_over_norm = - norm_d / norm ** 2 86 | d_normalized_weights = d_weights / norm + weights * d_1_over_norm 87 | 88 | normalized_weights = normalized_weights.type(cn.dtype) 89 | d_normalized_weights = d_normalized_weights.type(cn.dtype) 90 | assert normalized_weights.isnan().sum() == 0 91 | assert d_normalized_weights.isnan().sum() == 0 92 | 93 | one = torch.tensor(1.0, device=cn.device, dtype=cn.dtype) 94 | zero = torch.tensor(0.0, device=cn.device, dtype=cn.dtype) 95 | maxcn = torch.max(refcn, dim=-1, keepdim=True)[0] 96 | is_exceptional = (norm == 0) | (normalized_weights > 1e50) 97 | normalized_weights = torch.where(is_exceptional, torch.where(refcn == maxcn, one, zero), normalized_weights) 98 | d_normalized_weights = torch.where(is_exceptional, zero, d_normalized_weights) 99 | 100 | return normalized_weights, 
d_normalized_weights 101 | 102 | def compute_c6(self, atom_types, weights, dweights, row, col): 103 | rc6 = self.params.c6[atom_types[col], atom_types[row]].contiguous() 104 | 105 | rc6_mask = rc6 != 0.0 106 | rc6 = rc6 * rc6_mask 107 | 108 | rc6_wc = (rc6 * weights[col].unsqueeze(-1)).sum(dim=-2) 109 | rc6_wr = (rc6 * weights[row].unsqueeze(-2)).sum(dim=-1) 110 | 111 | c6ij = (weights[row] * rc6_wc).sum(dim=-1) 112 | 113 | d_c6ij_dcni = (dweights[row] * rc6_wc).sum(dim=-1) 114 | d_c6ij_dcnj = (dweights[col] * rc6_wr).sum(dim=-1) 115 | return c6ij, d_c6ij_dcni, d_c6ij_dcnj 116 | 117 | def forward( 118 | self, 119 | atom_types: torch.Tensor, 120 | dij: torch.Tensor, 121 | edge_index: torch.Tensor, 122 | ): 123 | # note that cutoff is already applied in LAMMPS 124 | edge_mask = edge_index[1] > edge_index[0] 125 | edge_index = edge_index[:, edge_mask] 126 | dij = dij[edge_mask] 127 | 128 | # kernel 129 | row, col = edge_index.unbind(0) 130 | dij = dij / constant.bohr_angstrom 131 | rij = dij.norm(dim=-1, p=2) 132 | 133 | cn, d_cn_d_rij = self.cn_d3(atom_types, row, col, rij) 134 | weights, d_weights_d_cn = self.weight_references(atom_types, cn) 135 | c6ij, d_c6ij_d_cni, d_c6ij_d_cnj = self.compute_c6(atom_types, weights, d_weights_d_cn, row, col) 136 | 137 | num_r4r2 = self.params.sqrt_z_r4r2[atom_types] 138 | kij = num_r4r2[row] * num_r4r2[col] * 3 139 | r0ij = torch.sqrt(kij) 140 | 141 | # call the implementation of damping function in subclass 142 | edisp, gdisp = self.compute_dispersion(rij, r0ij, kij) 143 | 144 | # results w/o contributions from cn 145 | energy = scatter(- edisp * c6ij, row, dim=0, dim_size=atom_types.shape[0]) 146 | dE_drij = - c6ij * gdisp 147 | 148 | # add contributions from cn 149 | dE_dcn = ( 150 | scatter(-d_c6ij_d_cni * edisp, row, dim=0, dim_size=atom_types.shape[0]) 151 | + scatter(-d_c6ij_d_cnj * edisp, col, dim=0, dim_size=atom_types.shape[0]) 152 | ) 153 | d_E_d_cn_d_cn_d_rij = d_cn_d_rij * dE_dcn[edge_index].sum(dim=0) 154 | dE_drij = dE_drij + d_E_d_cn_d_cn_d_rij 155 | 156 | # calculate forces and virial 157 | dE_ddij = (dE_drij / rij).unsqueeze(-1) * dij 158 | gradients = scatter(dE_ddij, row, dim=0, dim_size=atom_types.shape[0]) + scatter(-dE_ddij, col, dim=0, dim_size=atom_types.shape[0]) 159 | forces = - gradients 160 | virial = - (dE_ddij.unsqueeze(-2) * dij.unsqueeze(-1)) 161 | 162 | # aggregate to mol level 163 | energy = energy.sum() 164 | virial = virial.sum(dim=0) 165 | 166 | # convert the units back to kcal/mol and angstrom 167 | energy = energy * constant.hartree_kcal_mol 168 | forces = forces * constant.hartree_kcal_mol / constant.bohr_angstrom 169 | virial = virial * constant.hartree_kcal_mol 170 | 171 | return energy, forces, virial 172 | 173 | 174 | class DFTD3CSO(DFTD3Base): 175 | def __init__(self, s6=1.0, a1=0.86, *args, **kwargs): 176 | super().__init__(*args, **kwargs) 177 | self.s6 = s6 178 | self.a1 = a1 179 | 180 | def compute_dispersion(self, rij, r0ij, kij): 181 | # note that the a4=6.25 is changed to 7.75 here 182 | e = torch.exp(rij - 2.5 * r0ij) 183 | e1 = 1 + e 184 | t = 1 / (rij ** 6 + 7.75 ** 6) 185 | m = self.s6 + self.a1 / e1 186 | dt = - 6 * rij ** 5 * (t ** 2) 187 | dm = - self.a1 * e / (e1**2) 188 | edisp = m * t 189 | fdisp = dm * t + m * dt 190 | return edisp, fdisp 191 | -------------------------------------------------------------------------------- /pair/build.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Exit immediately if a command exits with a non-zero 
status, and print each command. 3 | set -ex 4 | 5 | # Define the working directory based on the script's location. 6 | WORK_DIR="$(dirname "$(readlink -f "$0")")" 7 | cd "${WORK_DIR}" 8 | 9 | # CUDA settings 10 | CUDA_PATH="/usr/local/cuda" 11 | export PATH="$PATH:${CUDA_PATH}/bin" 12 | 13 | # Check CUDA compiler version 14 | nvcc --version 15 | 16 | # Define directories 17 | BUILD_DIR="${WORK_DIR}/build" 18 | OUT_DIR="${WORK_DIR}/output" 19 | CMAKE_DIR="${WORK_DIR}/cmake" 20 | 21 | # Clean up previous build and output directories, then recreate them. 22 | rm -rf "${OUT_DIR}" && mkdir "${OUT_DIR}" 23 | rm -rf "${BUILD_DIR}" && mkdir "${BUILD_DIR}" && cd "${BUILD_DIR}" 24 | 25 | # CMake configuration 26 | CMAKE_PRESETS_PATH="../cmake/presets" 27 | CMAKE_BASIC_PRESET="${CMAKE_PRESETS_PATH}/basic.cmake" 28 | CMAKE_KOKKOS_CUDA_PRESET="${CMAKE_PRESETS_PATH}/kokkos-cuda.cmake" 29 | 30 | cmake "${CMAKE_DIR}" \ 31 | -C "${CMAKE_BASIC_PRESET}" \ 32 | -C "${CMAKE_KOKKOS_CUDA_PRESET}" \ 33 | -DCMAKE_BUILD_TYPE=Release \ 34 | -DCMAKE_CXX_STANDARD=17 \ 35 | -DCMAKE_CUDA_ARCHITECTURES=89 \ 36 | -DCMAKE_LIBRARY_PATH="${CUDA_PATH}/lib64/" \ 37 | -DMKL_INCLUDE_DIR="/usr/include" \ 38 | -DBUILD_TESTING=OFF \ 39 | -DCUDAToolkit_ROOT="${CUDA_PATH}" \ 40 | -DKokkos_ARCH_PASCAL60=OFF \ 41 | -DKokkos_ARCH_ADA89=ON \ 42 | -DKokkos_CUDA_DIR="${CUDA_PATH}" \ 43 | -DKokkos_ENABLE_OPENMP=ON \ 44 | -DPKG_GPU=yes \ 45 | -DGPU_API=cuda \ 46 | -DGPU_ARCH=sm_89 \ 47 | -DTorch_DIR="/opt/libtorch/share/cmake/Torch" \ 48 | -DBIN2C="${CUDA_PATH}/bin/bin2c" \ 49 | -DFFT=FFTW3 \ 50 | -DPKG_KOKKOS=ON \ 51 | -DPKG_KSPACE=ON 52 | 53 | # Compile with all available cores 54 | make -j 55 | 56 | # Move the compiled binary to the output directory 57 | mv ./lmp "${OUT_DIR}/lmp" 58 | 59 | echo "Compile finished." 60 | -------------------------------------------------------------------------------- /pair/init_compile.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -xe 4 | 5 | # Step 0: Install required packages. 6 | python3 -m pip install -U --no-cache-dir pip setuptools cython cmake torch_runstats numpy pandas 7 | apt-get update -y 8 | apt-get install -y zip gfortran libgtest-dev libopenblas-dev libfftw3-dev libfftw3-double3 libfftw3-single3 libfftw3-3 libfftw3-bin 9 | 10 | # Define working directory 11 | WORK_DIR=$(dirname "$(readlink -f "$0")") 12 | cd ${WORK_DIR} 13 | 14 | # Determine PyTorch and CUDA versions 15 | PYTORCH_VERSION="2.1.0" 16 | CUDA_VERSION="12.1" 17 | 18 | # Define libtorch download URL 19 | LIBTORCH_URL="https://download.pytorch.org/libtorch/cu${CUDA_VERSION}/libtorch-cxx11-abi-shared-with-deps-${PYTORCH_VERSION}%2Bcu${CUDA_VERSION}.zip" 20 | echo "Libtorch url: ${LIBTORCH_URL}" 21 | 22 | # Download and extract libtorch if not already installed 23 | TORCH_CMAKE='/opt/libtorch/share/cmake/Torch/TorchConfig.cmake' 24 | if [ ! -f "${TORCH_CMAKE}" ]; then 25 | wget -O libtorch.zip "${LIBTORCH_URL}" && unzip libtorch.zip -d /opt && rm libtorch.zip 26 | echo "libtorch downloaded and extracted to /opt/libtorch." 27 | else 28 | echo "libtorch is already installed in /opt/libtorch." 29 | fi 30 | 31 | # Clone LAMMPS repository if CMakeLists.txt is not found 32 | CMAKE_PATH="./lammps/cmake/CMakeLists.txt" 33 | if [ ! -f "${CMAKE_PATH}" ]; then 34 | git clone --depth 1 https://github.com/lammps/lammps.git -b stable_2Aug2023_update3 35 | echo "LAMMPS cloned." 36 | else 37 | echo "LAMMPS CMakeLists.txt found. Skipping clone." 38 | fi 39 | 40 | if [ ! 
-f "${CMAKE_PATH}" ]; then 41 | echo "Clone failed." 42 | exit 1 43 | fi 44 | 45 | # Update build.sh in lammps directory 46 | cp -u ./build.sh ./lammps/ 47 | 48 | # Copy custom pair files into LAMMPS src directory if not already present 49 | for file in pair_bamboo.cpp pair_bamboo.h; do 50 | if [ ! -f "./lammps/src/${file}" ]; then 51 | cp "./src/${file}" ./lammps/src/ 52 | fi 53 | done 54 | 55 | # Copy custom KOKKOS pair files into LAMMPS KOKKOS directory if not already present 56 | for file in pair_bamboo_kokkos.cpp pair_bamboo_kokkos.h; do 57 | if [ ! -f "./lammps/src/KOKKOS/${file}" ]; then 58 | cp "./src/${file}" ./lammps/src/KOKKOS/ 59 | fi 60 | done 61 | 62 | # Append Torch configuration to CMakeLists.txt if not already done 63 | if ! grep -q "find_package(Torch REQUIRED)" "$CMAKE_PATH"; then 64 | cat >> "$CMAKE_PATH" << EOF 65 | 66 | # Find the Torch package 67 | find_package(Torch REQUIRED) 68 | 69 | # Add the Torch CXX flags to the compilation options 70 | set(CMAKE_CXX_FLAGS "\${CMAKE_CXX_FLAGS} \${TORCH_CXX_FLAGS}") 71 | 72 | # Link the target against the Torch libraries 73 | target_link_libraries(lammps PUBLIC "\${TORCH_LIBRARIES}") 74 | EOF 75 | echo "Torch configuration appended." 76 | else 77 | echo "Torch configuration already present." 78 | fi 79 | -------------------------------------------------------------------------------- /pair/src/pair_bamboo.cpp: -------------------------------------------------------------------------------- 1 | /* ------------------------------------------------------------------------- 2 | ----- BAMBOO: Bytedance AI Molecular Booster ----- 3 | Copyright 2022-2024 Bytedance Ltd. and/or its affiliates 4 | 5 | This program is free software; you can redistribute it and/or modify 6 | it under the terms of the GNU General Public License as published by 7 | the Free Software Foundation; either version 2 of the License, or 8 | (at your option) any later version. 9 | 10 | This program is distributed in the hope that it will be useful, 11 | but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 | GNU General Public License for more details. 
14 | 15 | You should have received a copy of the GNU General Public License 16 | along with this program; if not, write to the Free Software 17 | Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA 18 | ------------------------------------------------------------------------- */ 19 | 20 | #include 21 | #include "atom.h" 22 | #include "comm.h" 23 | #include "domain.h" 24 | #include "error.h" 25 | #include "force.h" 26 | #include "memory.h" 27 | #include "kspace.h" 28 | #include "neigh_list.h" 29 | #include "neigh_request.h" 30 | #include "neighbor.h" 31 | #include "potential_file_reader.h" 32 | #include "tokenizer.h" 33 | 34 | #include 35 | #include 36 | #include 37 | #include 38 | #include 39 | #include 40 | #include 41 | #include 42 | #include 43 | #include 44 | #include 45 | #include 46 | #include 47 | 48 | 49 | using namespace LAMMPS_NS; 50 | 51 | PairBAMBOO::PairBAMBOO(LAMMPS *lmp) : Pair(lmp) { 52 | restartinfo = 0; 53 | manybody_flag = 1; 54 | evflag = 1; 55 | msmflag = 1; 56 | ewaldflag = 1; 57 | pppmflag = 1; 58 | 59 | int device_count = torch::cuda::device_count(); 60 | if (device_count == 0) { 61 | error->all(FLERR,"pair_bamboo: no GPUs available"); 62 | } 63 | 64 | int cuda_device_id = -1; 65 | if (comm->nprocs > 1) { 66 | if (comm->nprocs > device_count) { 67 | error->all(FLERR,"pair_bamboo: mismatch between number of ranks and number of available GPUs"); 68 | } 69 | MPI_Comm shmcomm; 70 | MPI_Comm_split_type(MPI_COMM_WORLD, MPI_COMM_TYPE_SHARED, 0, 71 | MPI_INFO_NULL, &shmcomm); 72 | int shmrank; 73 | MPI_Comm_rank(shmcomm, &shmrank); 74 | cuda_device_id = shmrank; 75 | } 76 | 77 | device = c10::Device(torch::kCUDA, cuda_device_id); 78 | 79 | std::cout << "BAMBOO is using device " << device << "\n"; 80 | 81 | if(const char* env_p = std::getenv("BAMBOO_DEBUG")){ 82 | if (strcmp(env_p, "1") == 0) 83 | std::cout << "Found env BAMBOO_DEBUG=1, Pair BAMBOO is in DEBUG mode.\n"; 84 | debug_mode = 1; 85 | } 86 | 87 | if(const char* env_p = std::getenv("BAMBOO_TIMER")){ 88 | if (strcmp(env_p, "1") == 0) 89 | std::cout << "Found env BAMBOO_TIMER=1, Pair BAMBOO show timer.\n"; 90 | bamboo_timer = 1; 91 | t_start = std::chrono::high_resolution_clock::now(); 92 | } 93 | } 94 | 95 | PairBAMBOO::~PairBAMBOO(){ 96 | if (allocated) { 97 | memory->destroy(setflag); 98 | memory->destroy(cutsq); 99 | } 100 | } 101 | 102 | void PairBAMBOO::init_style(){ 103 | // Error check for parameters. 104 | if (atom->tag_enable == 0){ 105 | error->all(FLERR,"Pair style BAMBOO requires atom IDs"); 106 | } 107 | 108 | if (force->newton_pair == 1) { 109 | error->all(FLERR,"Pair style BAMBOO requires newton pair off"); 110 | } 111 | 112 | // need a full neighbor list 113 | int irequest = neighbor->request(this,instance_me); 114 | auto req = neighbor->requests[irequest]; 115 | req->enable_full(); 116 | 117 | // Safely access and store the Ewald summation accuracy parameter, if applicable 118 | if (force->kspace && std::isnormal(force->kspace->g_ewald)) { 119 | g_ewald = force->kspace->g_ewald; 120 | } else { 121 | g_ewald = 0.0; // Default to 0.0 if kspace is not used or g_ewald is not valid 122 | fmt::print("pair_bamboo: KSAPCE is not set or is not valid. 
Defaulting to 0.0 for g_ewald.\n"); 123 | } 124 | } 125 | 126 | void *PairBAMBOO::extract(const char *str, int &dim) 127 | { 128 | dim = 0; 129 | return (void *) &cutoff_coul; 130 | } 131 | 132 | double PairBAMBOO::init_one(int i, int j) 133 | { 134 | return cutoff_max; 135 | } 136 | 137 | void PairBAMBOO::allocate() 138 | { 139 | allocated = 1; 140 | int n = atom->ntypes; 141 | 142 | memory->create(setflag, n+1, n+1, "pair:setflag"); 143 | memory->create(cutsq, n+1, n+1, "pair:cutsq"); 144 | } 145 | 146 | void PairBAMBOO::settings(int narg, char **arg) { 147 | constexpr int MinArgs = 3; 148 | constexpr int MaxArgs = 4; 149 | constexpr int DirectEdgeMode = 1; 150 | constexpr int UndirectedEdgeMode = 0; 151 | 152 | // Validate the number of arguments 153 | if (narg < MinArgs || narg > MaxArgs) { 154 | error->all(FLERR, "Illegal pair_style command"); 155 | } 156 | 157 | // Parse cutoff values 158 | cutoff_net = utils::numeric(FLERR, arg[0], false, lmp); 159 | cutoff_coul = utils::numeric(FLERR, arg[1], false, lmp); 160 | cutoff_disp = utils::numeric(FLERR, arg[2], false, lmp); 161 | 162 | // Determine edge mode, defaulting to undirected if not specified 163 | edge_mode = (narg == MaxArgs) ? utils::inumeric(FLERR, arg[3], false, lmp) : DirectEdgeMode; 164 | 165 | switch (edge_mode) { 166 | case DirectEdgeMode: 167 | fmt::print("Using direct edge mode\n"); 168 | break; 169 | case UndirectedEdgeMode: 170 | fmt::print("Using undirected edge mode\n"); 171 | break; 172 | default: 173 | error->all(FLERR, "Illegal edge mode"); 174 | break; 175 | } 176 | 177 | // Compute the maximum cutoff 178 | cutoff_max = std::max({cutoff_net, cutoff_coul, cutoff_disp}); 179 | 180 | // Display cutoff values 181 | fmt::print("cutoff_net: {}, cutoff_coul: {}, cutoff_disp: {}\n", 182 | cutoff_net, cutoff_coul, cutoff_disp); 183 | 184 | // Compute squared cutoff values for later use 185 | cutoff_coul_sq = cutoff_coul * cutoff_coul; 186 | cutoff_disp_sq = cutoff_disp * cutoff_disp; 187 | cutoff_net_sq = cutoff_net * cutoff_net; 188 | cutoff_max_sq = cutoff_max * cutoff_max; 189 | 190 | } 191 | 192 | void PairBAMBOO::coeff(int narg, char **arg) { 193 | 194 | // Allocate memory if not already done 195 | if (!allocated) allocate(); 196 | 197 | const int ntypes = atom->ntypes; 198 | 199 | // Ensure there is exactly one argument for each type plus the model file name 200 | if (narg != ntypes + 1) { 201 | error->all(FLERR, "Incorrect args for pair coefficients"); 202 | } 203 | 204 | // Clear previous settings 205 | for (int i = 1; i <= ntypes; i++) { 206 | for (int j = i; j <= ntypes; j++) { 207 | setflag[i][j] = 0; 208 | } 209 | } 210 | 211 | std::vector elements(ntypes); 212 | for(int i = 0; i < ntypes; i++){ 213 | elements[i] = arg[i+1]; 214 | } 215 | 216 | // to construct a type mapper from LAMMPS type to Bamboo atom_types 217 | std::unordered_map symbol_to_index = { 218 | {"H", 1}, {"He", 2}, {"LI", 3}, {"Li", 3}, {"Be", 4}, {"B", 5}, {"C", 6}, {"N", 7}, 219 | {"O", 8}, {"F", 9}, {"Ne", 10}, {"Na", 11}, {"Mg", 12}, {"Al", 13}, {"Si", 14}, 220 | {"P", 15}, {"S", 16}, {"Cl", 17}, {"Ar", 18}, {"K", 19}, {"Ca", 20}, {"Sc", 21}, 221 | {"Ti", 22}, {"V", 23}, {"Cr", 24}, {"Mn", 25}, {"Fe", 26}, {"Co", 27}, {"Ni", 28}, 222 | {"Cu", 29}, {"Zn", 30}, {"Ga", 31}, {"Ge", 32}, {"As", 33}, {"Se", 34}, {"Br", 35}, 223 | {"Kr", 36}, {"Rb", 37}, {"Sr", 38}, {"Y", 39}, {"Zr", 40}, {"Nb", 41}, {"Mo", 42}, 224 | {"Tc", 43}, {"Ru", 44}, {"Rh", 45}, {"Pd", 46}, {"Ag", 47}, {"Cd", 48}, {"In", 49}, 225 | {"Sn", 50}, {"Sb", 51}, {"Te", 52}, 
{"I", 53}, {"Xe", 54}, {"Cs", 55}, {"Ba", 56}, 226 | {"La", 57}, {"Ce", 58}, {"Pr", 59}, {"Nd", 60}, {"Pm", 61}, {"Sm", 62}, {"Eu", 63}, 227 | {"Gd", 64}, {"Tb", 65}, {"Dy", 66}, {"Ho", 67}, {"Er", 68}, {"Tm", 69}, {"Yb", 70}, 228 | {"Lu", 71}, {"Hf", 72}, {"Ta", 73}, {"W", 74}, {"Re", 75}, {"Os", 76}, {"Ir", 77}, 229 | {"Pt", 78}, {"Au", 79}, {"Hg", 80}, {"Tl", 81}, {"Pb", 82}, {"Bi", 83}, {"Po", 84}, 230 | {"At", 85}, {"Rn", 86} 231 | }; 232 | std::cout << "Construct type mapper:" << "\n"; 233 | 234 | // Initiate type mapper 235 | for (int i = 0; i< ntypes; i++){ 236 | std::cout << "i: " << i << " symbol: " << elements[i] << " index: " << symbol_to_index[elements[i]] << "\n"; 237 | type_mapper.push_back(symbol_to_index[elements[i]]); 238 | } 239 | 240 | std::cout << "Loading model from " << arg[0] << "\n"; 241 | 242 | try { 243 | model = torch::jit::load(arg[0]); 244 | } catch (const c10::Error& e) { 245 | error->all(FLERR, "Failed to load the model"); 246 | } 247 | model.eval(); 248 | 249 | // disable fusion. 250 | torch::jit::setGraphExecutorOptimize(false); 251 | 252 | torch::set_default_dtype(caffe2::TypeMeta::Make()); 253 | 254 | // Set whether to allow TF32 255 | bool allow_tf32 = false; 256 | at::globalContext().setAllowTF32CuBLAS(allow_tf32); 257 | at::globalContext().setAllowTF32CuDNN(allow_tf32); 258 | 259 | // set setflag i,j for type pairs where both are mapped to elements 260 | for (int i = 1; i <= ntypes; i++) { 261 | for (int j = i; j <= ntypes; j++) { 262 | if ((type_mapper[i-1] >= 0) && (type_mapper[j-1] >= 0)) { 263 | setflag[i][j] = 1; 264 | } 265 | } 266 | } 267 | 268 | } 269 | 270 | 271 | void PairBAMBOO::click_timer(const char *str){ 272 | // Return immediately if the bamboo_timer is not enabled 273 | if (!bamboo_timer) { 274 | return; 275 | } 276 | 277 | // Capture the current time as the end point 278 | t_end = std::chrono::high_resolution_clock::now(); 279 | 280 | // Calculate the duration in milliseconds since the last checkpoint 281 | auto durationInMillis = std::chrono::duration_cast(t_end - t_start).count() / 1000.0; 282 | 283 | // Log the timer tag, counter, and the duration 284 | fmt::print("[tag: {}], [steps: {}], [Time used: {} ms]\n", str, timer_counter, durationInMillis); 285 | 286 | // Reset the start time for the next duration measurement 287 | t_start = t_end; 288 | } 289 | 290 | 291 | void PairBAMBOO::compute(int eflag, int vflag){ 292 | // Raise error, this method should not be called. 293 | error->all(FLERR, "Pair style BAMBOO NON-KOKKOS mode is not supported."); 294 | } 295 | -------------------------------------------------------------------------------- /pair/src/pair_bamboo.h: -------------------------------------------------------------------------------- 1 | /* ------------------------------------------------------------------------- 2 | ----- BAMBOO: Bytedance AI Molecular Booster ----- 3 | Copyright 2022-2024 Bytedance Ltd. and/or its affiliates 4 | 5 | This program is free software; you can redistribute it and/or modify 6 | it under the terms of the GNU General Public License as published by 7 | the Free Software Foundation; either version 2 of the License, or 8 | (at your option) any later version. 9 | 10 | This program is distributed in the hope that it will be useful, 11 | but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 | GNU General Public License for more details. 
14 | 15 | You should have received a copy of the GNU General Public License 16 | along with this program; if not, write to the Free Software 17 | Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA 18 | ------------------------------------------------------------------------- */ 19 | 20 | #ifdef PAIR_CLASS 21 | 22 | PairStyle(bamboo,PairBAMBOO) 23 | 24 | #else 25 | 26 | #ifndef LMP_PAIR_BAMBOO_H 27 | #define LMP_PAIR_BAMBOO_H 28 | 29 | #include "pair.h" 30 | #include 31 | #include 32 | 33 | namespace LAMMPS_NS { 34 | 35 | class PairBAMBOO : public Pair { 36 | public: 37 | PairBAMBOO(class LAMMPS *); 38 | virtual ~PairBAMBOO(); 39 | virtual void compute(int, int); 40 | void settings(int, char **); 41 | virtual void coeff(int, char **); 42 | virtual double init_one(int, int); 43 | virtual void init_style(); 44 | void allocate(); 45 | void *extract(const char *str, int &dim); 46 | void click_timer(const char *str); 47 | double cutoff_disp, cutoff_coul, cutoff_net, cutoff_max; 48 | double cutoff_disp_sq, cutoff_coul_sq, cutoff_net_sq, cutoff_max_sq; 49 | torch::jit::Module model; 50 | torch::Device device = torch::kCPU; 51 | 52 | protected: 53 | std::vector type_mapper; 54 | int debug_mode = 0; 55 | int bamboo_timer = 0; 56 | int timer_counter = 0; 57 | int edge_mode = 0; // 0: undirected edge, 1 directed edge. 58 | double g_ewald; 59 | std::chrono::high_resolution_clock::time_point t_start, t_end; // Start time for the timer 60 | }; 61 | 62 | } 63 | #endif 64 | #endif 65 | 66 | -------------------------------------------------------------------------------- /pair/src/pair_bamboo_kokkos.cpp: -------------------------------------------------------------------------------- 1 | /* ------------------------------------------------------------------------- 2 | ----- BAMBOO: Bytedance AI Molecular Booster ----- 3 | Copyright 2022-2024 Bytedance Ltd. and/or its affiliates 4 | 5 | This program is free software; you can redistribute it and/or modify 6 | it under the terms of the GNU General Public License as published by 7 | the Free Software Foundation; either version 2 of the License, or 8 | (at your option) any later version. 9 | 10 | This program is distributed in the hope that it will be useful, 11 | but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 | GNU General Public License for more details. 14 | 15 | You should have received a copy of the GNU General Public License 16 | along with this program; if not, write to the Free Software 17 | Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA 18 | ------------------------------------------------------------------------- */ 19 | 20 | 21 | 22 | #include 23 | #include 24 | #include 25 | 26 | 27 | using namespace LAMMPS_NS; 28 | 29 | #ifdef LMP_KOKKOS_GPU 30 | int vector_length = 32; 31 | #define TEAM_SIZE 4 32 | #define SINGLE_BOND_TEAM_SIZE 16 33 | #else 34 | int vector_length = 8; 35 | #define TEAM_SIZE Kokkos::AUTO() 36 | #define SINGLE_BOND_TEAM_SIZE Kokkos::AUTO() 37 | #endif 38 | 39 | // Buffer for edge size memory. 
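// Editor's note (assumption based on the usage below): views sized with this ratio are padded by
// roughly 10% over the current neighbor-list extent, so small step-to-step growth in the per-atom
// edge count does not force a reallocation of the edge buffers on every call to compute().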
40 | #define kEdgeRatio 1.1f 41 | 42 | 43 | template 44 | PairBAMBOOKokkos::PairBAMBOOKokkos(LAMMPS *lmp) : PairBAMBOO(lmp) 45 | { 46 | respa_enable = 0; 47 | 48 | atomKK = (AtomKokkos *) atom; 49 | domainKK = (DomainKokkos *) domain; 50 | execution_space = ExecutionSpaceFromDevice::space; 51 | datamask_read = X_MASK | F_MASK | TAG_MASK | TYPE_MASK | ENERGY_MASK | VIRIAL_MASK; 52 | datamask_modify = F_MASK | ENERGY_MASK | VIRIAL_MASK; 53 | } 54 | 55 | 56 | template 57 | PairBAMBOOKokkos::~PairBAMBOOKokkos() 58 | { 59 | if (!copymode) { 60 | memoryKK->destroy_kokkos(k_eatom,eatom); 61 | memoryKK->destroy_kokkos(k_vatom,vatom); 62 | eatom = NULL; 63 | vatom = NULL; 64 | } 65 | } 66 | 67 | template 68 | void PairBAMBOOKokkos::compute(int eflag_in, int vflag_in) 69 | { 70 | eflag = eflag_in; 71 | vflag = vflag_in; 72 | ev_init(eflag,vflag, 0); 73 | click_timer("Compute start"); 74 | 75 | // reallocate per-atom arrays if necessary 76 | 77 | if (eflag_atom) { 78 | memoryKK->destroy_kokkos(k_eatom,eatom); 79 | memoryKK->create_kokkos(k_eatom,eatom,maxeatom,"pair:eatom"); 80 | d_eatom = k_eatom.view(); 81 | } 82 | if (vflag_atom) { 83 | memoryKK->destroy_kokkos(k_vatom,vatom); 84 | memoryKK->create_kokkos(k_vatom,vatom,maxvatom,"pair:vatom"); 85 | d_vatom = k_vatom.view(); 86 | } 87 | 88 | atomKK->sync(execution_space,datamask_read); 89 | if (eflag || vflag) atomKK->modified(execution_space,datamask_modify); 90 | else atomKK->modified(execution_space,F_MASK); 91 | 92 | x = atomKK->k_x.view(); 93 | f = atomKK->k_f.view(); 94 | q = atomKK->k_q.view(); 95 | tag = atomKK->k_tag.view(); 96 | type = atomKK->k_type.view(); 97 | nlocal = atom->nlocal; 98 | 99 | nall = atom->nlocal + atom->nghost; 100 | const int inum = list->inum; 101 | 102 | NeighListKokkos* k_list = static_cast*>(list); 103 | d_ilist = k_list->d_ilist; 104 | d_numneigh = k_list->d_numneigh; 105 | d_neighbors = k_list->d_neighbors; 106 | 107 | copymode = 1; 108 | 109 | // build short neighbor list 110 | const int max_neighs = d_neighbors.extent(1); 111 | // Some extra buffer to decrease resize frequence. 
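// Editor's note: the d_numneigh_* views below hold per-atom edge counts for the network, Coulomb
// and dispersion cutoffs; the matching d_cumsum_numneigh_* views are later overwritten with prefix
// sums and serve as per-atom offsets when the edges are packed into contiguous arrays for the model.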
112 | const int max_neighs_memory = static_cast(kEdgeRatio * d_neighbors.extent(1)); 113 | 114 | if(d_numneigh_coul.extent(0) < inum){ 115 | // edges for Coulomb 116 | reallocateView(d_numneigh_coul, "BAMBOO::numneighs_coul", inum); 117 | reallocateView(d_cumsum_numneigh_coul, "BAMBOO::cumsum_numneighs_coul", inum); 118 | 119 | // edges for dispersion 120 | reallocateView(d_numneigh_disp, "BAMBOO::numneighs_disp", inum); 121 | reallocateView(d_cumsum_numneigh_disp, "BAMBOO::cumsum_numneighs_disp", inum); 122 | 123 | // edges for neural network 124 | reallocateView(d_numneigh_net, "BAMBOO::numneighs_net", inum); 125 | reallocateView(d_cumsum_numneigh_net, "BAMBOO::cumsum_numneighs_net", inum); 126 | } 127 | 128 | if(d_neighbors_coul.extent(0) < inum || d_neighbors_coul.extent(1) < max_neighs_memory){ 129 | reallocateView(d_neighbors_coul, "BAMBOO::neighbors_coul", inum, max_neighs_memory); 130 | reallocateView(d_neighbors_disp, "BAMBOO::neighbors_disp", inum, max_neighs_memory); 131 | reallocateView(d_neighbors_net, "BAMBOO::neighbors_net", inum, max_neighs_memory); 132 | reallocateView(d_edge_shift, "BAMBOO::edge_shift", inum, max_neighs_memory, 3); 133 | } 134 | 135 | click_timer("Init views"); 136 | 137 | // compute short neighbor list 138 | auto d_numneigh_coul = this->d_numneigh_coul; 139 | auto d_neighbors_coul = this->d_neighbors_coul; 140 | auto d_cumsum_numneigh_coul = this->d_cumsum_numneigh_coul; 141 | auto d_numneigh_disp = this->d_numneigh_disp; 142 | auto d_neighbors_disp = this->d_neighbors_disp; 143 | auto d_cumsum_numneigh_disp = this->d_cumsum_numneigh_disp; 144 | auto d_numneigh_net = this->d_numneigh_net; 145 | auto d_neighbors_net = this->d_neighbors_net; 146 | auto d_cumsum_numneigh_net = this->d_cumsum_numneigh_net; 147 | 148 | auto d_edge_shift = this->d_edge_shift; 149 | 150 | double cutoff_max_sq = this->cutoff_max_sq; 151 | double cutoff_coul_sq = this->cutoff_coul_sq; 152 | double cutoff_disp_sq = this->cutoff_disp_sq; 153 | double cutoff_net_sq = this->cutoff_net_sq; 154 | auto x = this->x; 155 | auto d_type = this->type; 156 | auto d_ilist = this->d_ilist; 157 | auto d_numneigh = this->d_numneigh; 158 | auto d_neighbors = this->d_neighbors; 159 | auto f = this->f; 160 | auto d_eatom = this->d_eatom; 161 | auto d_type_mapper = this->d_type_mapper; 162 | auto tag = this->tag; 163 | auto q = this->q; 164 | auto edge_mode = this->edge_mode; 165 | 166 | click_timer("Pre-Loop NeignborList"); 167 | Kokkos::parallel_for("BAMBOO: Loop NeignborList", Kokkos::RangePolicy(0,inum), KOKKOS_LAMBDA(const int ii){ 168 | const int i = d_ilist(ii); 169 | const X_FLOAT xtmp = x(i,0); 170 | const X_FLOAT ytmp = x(i,1); 171 | const X_FLOAT ztmp = x(i,2); 172 | 173 | const int jnum = d_numneigh(i); 174 | int coul_index_count = 0; 175 | int disp_index_count = 0; 176 | int net_index_count = 0; 177 | for (int jj = 0; jj < jnum; jj++) { 178 | int j = d_neighbors(i,jj); 179 | j &= NEIGHMASK; 180 | 181 | const X_FLOAT delx = xtmp - x(j,0); 182 | const X_FLOAT dely = ytmp - x(j,1); 183 | const X_FLOAT delz = ztmp - x(j,2); 184 | const F_FLOAT rsq = delx*delx + dely*dely + delz*delz; 185 | // default: disp is the max cutoff 186 | if (rsq < cutoff_max_sq && (edge_mode || tag(i) < tag(j))) { 187 | if (rsq < cutoff_disp_sq){ 188 | d_neighbors_disp(ii, disp_index_count) = jj; 189 | disp_index_count++; 190 | } 191 | if (rsq < cutoff_coul_sq){ 192 | d_neighbors_coul(ii, coul_index_count) = jj; 193 | coul_index_count++; 194 | } 195 | if (rsq < cutoff_net_sq){ 196 | d_neighbors_net(ii, 
net_index_count) = jj; 197 | net_index_count++; 198 | } 199 | d_edge_shift(ii, jj, 0) = delx; 200 | d_edge_shift(ii, jj, 1) = dely; 201 | d_edge_shift(ii, jj, 2) = delz; 202 | } 203 | } 204 | d_numneigh_net(ii) = net_index_count; 205 | d_numneigh_coul(ii) = coul_index_count; 206 | d_numneigh_disp(ii) = disp_index_count; 207 | }); 208 | 209 | click_timer("Loop NeignborList"); 210 | 211 | if(debug_mode){ 212 | std::cout << "index: " << "\n"; 213 | std::cout << "d_cumsum_numneigh_coul: " << d_cumsum_numneigh_coul.extent(0) << "\n"; 214 | std::cout << "d_numneigh_coul: " << d_numneigh_coul.extent(0) << "\n"; 215 | std::cout << "d_cumsum_numneigh_disp: " << d_cumsum_numneigh_disp.extent(0) << "\n"; 216 | std::cout << "d_numneigh_disp: " << d_numneigh_disp.extent(0) << "\n"; 217 | std::cout << "d_cumsum_numneigh_net: " << d_cumsum_numneigh_net.extent(0) << "\n"; 218 | std::cout << "d_numneigh_net: " << d_numneigh_net.extent(0) << "\n"; 219 | } 220 | 221 | Kokkos::deep_copy(d_cumsum_numneigh_coul, d_numneigh_coul); 222 | Kokkos::deep_copy(d_cumsum_numneigh_disp, d_numneigh_disp); 223 | Kokkos::deep_copy(d_cumsum_numneigh_net, d_numneigh_net); 224 | 225 | Kokkos::parallel_scan("BAMBOO: cumsum coul_neighs", Kokkos::RangePolicy(0,inum), KOKKOS_LAMBDA(const int ii, int& update, const bool is_final){ 226 | const int curr_val = d_cumsum_numneigh_coul(ii); 227 | update += curr_val; 228 | if(is_final) d_cumsum_numneigh_coul(ii) = update; 229 | }); 230 | Kokkos::parallel_scan("BAMBOO: cumsum disp_neighs", Kokkos::RangePolicy(0,inum), KOKKOS_LAMBDA(const int ii, int& update, const bool is_final){ 231 | const int curr_val = d_cumsum_numneigh_disp(ii); 232 | update += curr_val; 233 | if(is_final) d_cumsum_numneigh_disp(ii) = update; 234 | }); 235 | Kokkos::parallel_scan("BAMBOO: cumsum net_neighs", Kokkos::RangePolicy(0,inum), KOKKOS_LAMBDA(const int ii, int& update, const bool is_final){ 236 | const int curr_val = d_cumsum_numneigh_net(ii); 237 | update += curr_val; 238 | if(is_final) d_cumsum_numneigh_net(ii) = update; 239 | }); 240 | 241 | click_timer("Scan NeignborList"); 242 | 243 | int n_coul_edges = 0; 244 | int n_disp_edges = 0; 245 | int n_net_edges = 0; 246 | Kokkos::View n_coul_edges_view("BAMBOO: n_coul_edges_view",1); 247 | Kokkos::View n_disp_edges_view("BAMBOO: n_disp_edges_view",1); 248 | Kokkos::View n_net_edges_view("BAMBOO: n_net_edges_view",1); 249 | Kokkos::deep_copy(n_coul_edges_view, Kokkos::subview(d_cumsum_numneigh_coul, Kokkos::make_pair(inum-1, inum))); 250 | Kokkos::deep_copy(n_disp_edges_view, Kokkos::subview(d_cumsum_numneigh_disp, Kokkos::make_pair(inum-1, inum))); 251 | Kokkos::deep_copy(n_net_edges_view, Kokkos::subview(d_cumsum_numneigh_net, Kokkos::make_pair(inum-1, inum))); 252 | n_coul_edges = n_coul_edges_view(0); 253 | n_disp_edges = n_disp_edges_view(0); 254 | n_net_edges = n_net_edges_view(0); 255 | 256 | click_timer("Edge cumsum"); 257 | 258 | if(d_coul_edges.extent(1) < n_coul_edges){ 259 | reallocateView(d_coul_edges, "BAMBOO: coul_edges", 2, n_coul_edges); 260 | reallocateView(d_edge_shift_coul, "BAMBOO: coul_edge_shift", n_coul_edges, 3); 261 | } 262 | if(d_disp_edges.extent(1) < n_disp_edges){ 263 | reallocateView(d_disp_edges, "BAMBOO: disp_edges", 2, n_disp_edges); 264 | reallocateView(d_edge_shift_disp, "BAMBOO: disp_edge_shift", n_disp_edges, 3); 265 | 266 | } 267 | if(d_net_edges.extent(1) < n_net_edges){ 268 | reallocateView(d_net_edges, "BAMBOO: net_edges", 2, n_net_edges); 269 | reallocateView(d_edge_shift_net, "BAMBOO: net_edge_shift", n_net_edges, 3); 
270 | } 271 | if(d_atom_types.extent(0) < inum){ 272 | reallocateView(d_atom_types, "BAMBOO: atom_types", inum); 273 | reallocateView(d_atom_pos, "BAMBOO: atom_pos", inum, 3); 274 | } 275 | 276 | click_timer("Pre-Loop edge"); 277 | 278 | auto d_coul_edges = this->d_coul_edges; 279 | auto d_net_edges = this->d_net_edges; 280 | auto d_disp_edges = this->d_disp_edges; 281 | auto d_atom_types = this->d_atom_types; 282 | auto d_atom_pos = this->d_atom_pos; 283 | auto d_edge_shift_coul = this->d_edge_shift_coul; 284 | auto d_edge_shift_disp = this->d_edge_shift_disp; 285 | auto d_edge_shift_net = this->d_edge_shift_net; 286 | 287 | Kokkos::parallel_for("BAMBOO: atom type and pos", Kokkos::RangePolicy(0, inum), KOKKOS_LAMBDA(const int i){ 288 | const int itag = tag(i) - 1; 289 | // 1 based to 0 based index. 290 | d_atom_types(itag) = d_type_mapper(d_type(i) - 1); 291 | d_atom_pos(itag,0) = x(i,0); 292 | d_atom_pos(itag,1) = x(i,1); 293 | d_atom_pos(itag,2) = x(i,2); 294 | }); 295 | click_timer("Position"); 296 | 297 | Kokkos::parallel_for("BAMBOO: create coul edges", Kokkos::TeamPolicy(inum, Kokkos::AUTO()), KOKKOS_LAMBDA(const MemberType team_member){ 298 | const int ii = team_member.league_rank(); 299 | const int i = d_ilist(ii); 300 | const int startedge = ii==0 ? 0 : d_cumsum_numneigh_coul(ii-1); 301 | Kokkos::parallel_for(Kokkos::TeamVectorRange(team_member, d_numneigh_coul(ii)), [&] (const int jj){ 302 | const int jj_origin = d_neighbors_coul(ii,jj); 303 | const int j = d_neighbors(i, jj_origin); 304 | d_coul_edges(0, startedge + jj) = tag(i) - 1; 305 | d_coul_edges(1, startedge + jj) = tag(j) - 1; 306 | d_edge_shift_coul(startedge + jj, 0) = d_edge_shift(ii, jj_origin, 0); 307 | d_edge_shift_coul(startedge + jj, 1) = d_edge_shift(ii, jj_origin, 1); 308 | d_edge_shift_coul(startedge + jj, 2) = d_edge_shift(ii, jj_origin, 2); 309 | }); 310 | }); 311 | 312 | click_timer("Edge coul"); 313 | 314 | Kokkos::parallel_for("BAMBOO: create disp edges", Kokkos::TeamPolicy(inum, Kokkos::AUTO()), KOKKOS_LAMBDA(const MemberType team_member){ 315 | const int ii = team_member.league_rank(); 316 | const int i = d_ilist(ii); 317 | const int startedge = ii==0 ? 0 : d_cumsum_numneigh_disp(ii-1); 318 | Kokkos::parallel_for(Kokkos::TeamVectorRange(team_member, d_numneigh_disp(ii)), [&] (const int jj){ 319 | const int jj_origin = d_neighbors_disp(ii,jj); 320 | const int j = d_neighbors(i, jj_origin); 321 | d_disp_edges(0, startedge + jj) = tag(i) - 1; 322 | d_disp_edges(1, startedge + jj) = tag(j) - 1; 323 | d_edge_shift_disp(startedge + jj, 0) = d_edge_shift(ii, jj_origin, 0); 324 | d_edge_shift_disp(startedge + jj, 1) = d_edge_shift(ii, jj_origin, 1); 325 | d_edge_shift_disp(startedge + jj, 2) = d_edge_shift(ii, jj_origin, 2); 326 | }); 327 | }); 328 | 329 | click_timer("Edge disp"); 330 | 331 | Kokkos::parallel_for("BAMBOO: create net edges", Kokkos::TeamPolicy(inum, Kokkos::AUTO()), KOKKOS_LAMBDA(const MemberType team_member){ 332 | const int ii = team_member.league_rank(); 333 | const int i = d_ilist(ii); 334 | const int startedge = ii==0 ? 
0 : d_cumsum_numneigh_net(ii-1); 335 | Kokkos::parallel_for(Kokkos::TeamVectorRange(team_member, d_numneigh_net(ii)), [&] (const int jj){ 336 | const int jj_origin = d_neighbors_net(ii,jj); 337 | const int j = d_neighbors(i, jj_origin); 338 | d_net_edges(0, startedge + jj) = tag(i) - 1; 339 | d_net_edges(1, startedge + jj) = tag(j) - 1; 340 | d_edge_shift_net(startedge + jj, 0) = d_edge_shift(ii, jj_origin, 0); 341 | d_edge_shift_net(startedge + jj, 1) = d_edge_shift(ii, jj_origin, 1); 342 | d_edge_shift_net(startedge + jj, 2) = d_edge_shift(ii, jj_origin, 2); 343 | }); 344 | }); 345 | 346 | click_timer("Edge net"); 347 | 348 | DoubleView2D d_domain_box("domain_box", 3, 3); 349 | auto h_domain_box = Kokkos::create_mirror_view(d_domain_box); 350 | h_domain_box(0, 0) = domain->boxhi[0] - domain->boxlo[0]; 351 | 352 | h_domain_box(1, 0) = domain->xy; 353 | h_domain_box(1, 1) = domain->boxhi[1] - domain->boxlo[1]; 354 | 355 | h_domain_box(2, 0) = domain->xz; 356 | h_domain_box(2, 1) = domain->yz; 357 | h_domain_box(2, 2) = domain->boxhi[2] - domain->boxlo[2]; 358 | Kokkos::deep_copy(d_domain_box, h_domain_box); 359 | click_timer("Domain box"); 360 | 361 | torch::Tensor edge_shift_tensor = torch::from_blob(d_edge_shift.data(), {inum, max_neighs, 3}, torch::TensorOptions().dtype(torch::kFloat64).device(device)); 362 | torch::Tensor neighbors_net_tensor = torch::from_blob(d_neighbors_net.data(), {inum, max_neighs}, torch::TensorOptions().dtype(torch::kInt32).device(device)); 363 | torch::Tensor neighbors_coul_tensor = torch::from_blob(d_neighbors_coul.data(), {inum, max_neighs}, torch::TensorOptions().dtype(torch::kInt32).device(device)); 364 | torch::Tensor neighbors_disp_tensor = torch::from_blob(d_neighbors_disp.data(), {inum, max_neighs}, torch::TensorOptions().dtype(torch::kInt32).device(device)); 365 | 366 | torch::Tensor coul_edge_index = torch::from_blob(d_coul_edges.data(), {2,n_coul_edges}, {(long) d_coul_edges.extent(1),1}, torch::TensorOptions().dtype(torch::kInt64).device(device)); 367 | torch::Tensor net_edges_index = torch::from_blob(d_net_edges.data(), {2,n_net_edges }, {(long) d_net_edges.extent(1), 1}, torch::TensorOptions().dtype(torch::kInt64).device(device)); 368 | torch::Tensor disp_edge_index = torch::from_blob(d_disp_edges.data(), {2,n_disp_edges}, {(long) d_disp_edges.extent(1),1}, torch::TensorOptions().dtype(torch::kInt64).device(device)); 369 | 370 | torch::Tensor coul_edge_cell_shift = torch::from_blob(d_edge_shift_coul.data(), {n_coul_edges, 3}, {3 ,1}, torch::TensorOptions().dtype(torch::kFloat64).device(device)); 371 | torch::Tensor net_edge_cell_shift = torch::from_blob(d_edge_shift_net.data(), {n_net_edges, 3}, {3 ,1}, torch::TensorOptions().dtype(torch::kFloat64).device(device)); 372 | torch::Tensor disp_edge_cell_shift = torch::from_blob(d_edge_shift_disp.data(), {n_disp_edges, 3}, {3 ,1}, torch::TensorOptions().dtype(torch::kFloat64).device(device)); 373 | 374 | torch::Tensor ij2type_tensor = torch::from_blob(d_atom_types.data(), {inum}, torch::TensorOptions().dtype(torch::kInt64).device(device)); 375 | torch::Tensor pos_tensor = torch::from_blob(d_atom_pos.data(), {inum,3}, {3,1}, torch::TensorOptions().device(device)); 376 | torch::Tensor cell_tensor = torch::from_blob(d_domain_box.data(), {3,3}, {3, 1}, torch::TensorOptions().dtype(torch::kFloat64).device(device)); 377 | torch::Tensor ewald_tensor = torch::tensor({g_ewald}, torch::TensorOptions().dtype(torch::kFloat64).device(device)); 378 | 379 | if(debug_mode){ 380 | std::cout << "coul_edge_index: " << 
coul_edge_index.sizes() << "\n"; 381 | std::cout << "net_edges_index: " << net_edges_index.sizes() << "\n"; 382 | std::cout << "disp_edge_index: " << disp_edge_index.sizes() << "\n"; 383 | std::cout << "coul_edge_cell_shift: " << coul_edge_cell_shift.sizes() << "\n"; 384 | std::cout << "net_edge_cell_shift: " << net_edge_cell_shift.sizes() << "\n"; 385 | std::cout << "disp_edge_cell_shift: " << disp_edge_cell_shift.sizes() << "\n"; 386 | 387 | torch::save({ 388 | coul_edge_index, 389 | net_edges_index, 390 | disp_edge_index, 391 | coul_edge_cell_shift, 392 | net_edge_cell_shift, 393 | disp_edge_cell_shift, 394 | ij2type_tensor, pos_tensor, 395 | cell_tensor, 396 | ewald_tensor, 397 | edge_shift_tensor, 398 | neighbors_net_tensor, 399 | neighbors_coul_tensor, 400 | neighbors_disp_tensor 401 | }, "tensor.pt"); 402 | 403 | std::cout << "Save tensor to tensor.pt" << "\n"; 404 | } 405 | 406 | 407 | c10::Dict input; 408 | input.insert("pos", pos_tensor); 409 | input.insert("edge_index", net_edges_index); 410 | input.insert("edge_cell_shift", net_edge_cell_shift); 411 | input.insert("coul_edge_index", coul_edge_index); 412 | input.insert("coul_edge_cell_shift", coul_edge_cell_shift); 413 | input.insert("disp_edge_index", disp_edge_index); 414 | input.insert("disp_edge_cell_shift", disp_edge_cell_shift); 415 | input.insert("cell", cell_tensor); 416 | input.insert("atom_types", ij2type_tensor); 417 | input.insert("g_ewald", ewald_tensor); 418 | std::vector input_vector(1, input); 419 | click_timer("Pre-inference"); 420 | 421 | auto output = model.forward(input_vector).toGenericDict(); 422 | torch::Tensor pred_virial_tensor = output.at("pred_virial").toTensor().cpu(); 423 | click_timer("Inference"); 424 | 425 | torch::Tensor pred_coul_energy_tensor = output.at("pred_coul_energy").toTensor().cpu(); 426 | torch::Tensor pred_energy_tensor = output.at("pred_energy").toTensor().cpu(); 427 | torch::Tensor predict_charge_tensor = output.at("pred_charge").toTensor(); 428 | torch::Tensor predict_forces_tensor = output.at("pred_forces").toTensor(); 429 | 430 | UnmanagedDoubleView2D d_forces(predict_forces_tensor.data_ptr(), inum, 3); 431 | UnmanagedDoubleView1D d_charge(predict_charge_tensor.data_ptr(), inum); 432 | 433 | click_timer("Begin postprocess."); 434 | eng_vdwl = pred_energy_tensor.data_ptr()[0] - pred_coul_energy_tensor.data_ptr()[0]; 435 | eng_coul = pred_coul_energy_tensor.data_ptr()[0]; 436 | 437 | click_timer("Update fpair."); 438 | Kokkos::parallel_for("BAMBOO: update fpair", Kokkos::RangePolicy(0, inum), KOKKOS_LAMBDA(const int i){ 439 | // 1 based to 0 based index. 
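// NOTE (added annotation, not in the original file): pred_forces and pred_charge are laid
// out by global atom tag, the same 0-based ordering used to fill d_atom_pos earlier in this
// routine, so the lambda below gathers each local atom i's force and charge back out of the
// tag-indexed views before writing them into LAMMPS' per-atom arrays.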
440 | const int itag = tag(i) - 1; 441 | f(i,0) = d_forces(itag, 0); 442 | f(i,1) = d_forces(itag, 1); 443 | f(i,2) = d_forces(itag, 2); 444 | q(i) = d_charge(itag); 445 | }); 446 | 447 | click_timer("Update nn_virial."); 448 | auto predict_virial = pred_virial_tensor.accessor(); 449 | virial[0] += predict_virial[0][0]; 450 | virial[1] += predict_virial[1][1]; 451 | virial[2] += predict_virial[2][2]; 452 | virial[3] += 0.5 * (predict_virial[0][1] + predict_virial[1][0]); 453 | virial[4] += 0.5 * (predict_virial[0][2] + predict_virial[2][0]); 454 | virial[5] += 0.5 * (predict_virial[1][2] + predict_virial[2][1]); 455 | 456 | click_timer("Update charge."); 457 | auto h_q = Kokkos::create_mirror_view(q); 458 | Kokkos::deep_copy(h_q, q); 459 | double *q_atom = atom->q; 460 | for(int i=0; i 483 | void PairBAMBOOKokkos::coeff(int narg, char **arg) 484 | { 485 | PairBAMBOO::coeff(narg, arg); 486 | 487 | d_type_mapper = IntView1D("BAMBOO: type_mapper", type_mapper.size()); 488 | auto h_type_mapper = Kokkos::create_mirror_view(d_type_mapper); 489 | for(int i = 0; i < type_mapper.size(); i++){ 490 | h_type_mapper(i) = type_mapper[i]; 491 | } 492 | Kokkos::deep_copy(d_type_mapper, h_type_mapper); 493 | } 494 | 495 | 496 | template 497 | void PairBAMBOOKokkos::init_style() 498 | { 499 | PairBAMBOO::init_style(); 500 | 501 | // irequest = neigh request made by parent class 502 | neighflag = lmp->kokkos->neighflag; 503 | int irequest = neighbor->nrequest - 1; 504 | 505 | auto req = neighbor->requests[irequest]; 506 | 507 | req->set_kokkos_host(std::is_same::value && 508 | !std::is_same::value); 509 | 510 | req->set_kokkos_device(std::is_same::value); 511 | req->enable_full(); 512 | } 513 | 514 | namespace LAMMPS_NS { 515 | template class PairBAMBOOKokkos; 516 | #ifdef LMP_KOKKOS_GPU 517 | template class PairBAMBOOKokkos; 518 | #endif 519 | } 520 | 521 | -------------------------------------------------------------------------------- /pair/src/pair_bamboo_kokkos.h: -------------------------------------------------------------------------------- 1 | /* ------------------------------------------------------------------------- 2 | ----- BAMBOO: Bytedance AI Molecular Booster ----- 3 | Copyright 2022-2024 Bytedance Ltd. and/or its affiliates 4 | 5 | This program is free software; you can redistribute it and/or modify 6 | it under the terms of the GNU General Public License as published by 7 | the Free Software Foundation; either version 2 of the License, or 8 | (at your option) any later version. 9 | 10 | This program is distributed in the hope that it will be useful, 11 | but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 | GNU General Public License for more details. 
14 | 15 | You should have received a copy of the GNU General Public License 16 | along with this program; if not, write to the Free Software 17 | Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA 18 | ------------------------------------------------------------------------- */ 19 | 20 | 21 | #ifdef PAIR_CLASS 22 | 23 | PairStyle(bamboo/kk,PairBAMBOOKokkos); 24 | PairStyle(bamboo/kk/device,PairBAMBOOKokkos); 25 | PairStyle(bamboo/kk/host,PairBAMBOOKokkos); 26 | 27 | #else 28 | 29 | #ifndef LMP_PAIR_BAMBOO_KOKKOS_H 30 | #define LMP_PAIR_BAMBOO_KOKKOS_H 31 | 32 | 33 | #include "pair_bamboo.h" 34 | #include 35 | #include 36 | #include 37 | 38 | #include 39 | #include "kokkos.h" 40 | #include "atom_kokkos.h" 41 | #include "domain_kokkos.h" 42 | #include "neigh_request.h" 43 | #include "force.h" 44 | #include "comm.h" 45 | #include "memory_kokkos.h" 46 | #include "neighbor.h" 47 | #include "neigh_list_kokkos.h" 48 | #include "error.h" 49 | #include "atom_masks.h" 50 | 51 | namespace LAMMPS_NS { 52 | 53 | template 54 | class PairBAMBOOKokkos : public PairBAMBOO { 55 | public: 56 | using MemberType = typename Kokkos::TeamPolicy::member_type; 57 | typedef ArrayTypes AT; 58 | 59 | PairBAMBOOKokkos(class LAMMPS *); 60 | virtual ~PairBAMBOOKokkos(); 61 | // Override compute method from PairBAMBOO 62 | virtual void compute(int eflag, int vflag) override; 63 | 64 | // Override coeff method from PairBAMBOO 65 | virtual void coeff(int narg, char **arg) override; 66 | 67 | // Override init_style method from PairBAMBOO 68 | virtual void init_style() override; 69 | 70 | typename AT::t_efloat_1d d_eatom; 71 | typename AT::t_virial_array d_vatom; 72 | 73 | template 74 | static void reallocateView(ViewType& view, const std::string& name, const size_t dim1) { 75 | view = ViewType(); // Reset the view 76 | view = ViewType(Kokkos::ViewAllocateWithoutInitializing(name), dim1); 77 | } 78 | 79 | template 80 | static void reallocateView(ViewType& view, const std::string& name, const size_t dim1, const size_t dim2) { 81 | view = ViewType(); // Reset the view 82 | view = ViewType(Kokkos::ViewAllocateWithoutInitializing(name), dim1, dim2); 83 | } 84 | 85 | template 86 | static void reallocateView(ViewType& view, const std::string& name, const size_t dim1, const size_t dim2, const size_t dim3) { 87 | view = ViewType(); // Reset the view 88 | view = ViewType(Kokkos::ViewAllocateWithoutInitializing(name), dim1, dim2, dim3); 89 | } 90 | 91 | protected: 92 | 93 | 94 | class DomainKokkos *domainKK; 95 | 96 | using IntView1D = Kokkos::View; 97 | using IntView2D = Kokkos::View; 98 | using LongView1D = Kokkos::View; 99 | using LongView2D = Kokkos::View; 100 | using View1D = Kokkos::View; 101 | using View2D = Kokkos::View; 102 | using FloatView2D = Kokkos::View; 103 | using DoubleView2D = Kokkos::View; 104 | using DoubleView3D = Kokkos::View; 105 | using UnmanagedFloatView1D = Kokkos::View; 106 | using UnmanagedFloatView2D = Kokkos::View; 107 | using UnmanagedDoubleView1D = Kokkos::View; 108 | using UnmanagedDoubleView2D = Kokkos::View; 109 | 110 | typename AT::t_x_array_randomread x; 111 | typename AT::t_f_array f; 112 | typename AT::t_tagint_1d tag; 113 | typename AT::t_float_1d q; 114 | typename AT::t_int_1d_randomread type; 115 | typename AT::t_neighbors_2d d_neighbors; 116 | typename AT::t_int_1d_randomread d_ilist; 117 | typename AT::t_int_1d_randomread d_numneigh; 118 | 119 | DAT::tdual_efloat_1d k_eatom; 120 | DAT::tdual_virial_array k_vatom; 121 | 122 | View1D d_ewald; 123 | IntView1D d_type_mapper; 
124 | LongView1D d_atom_types; 125 | LongView2D d_disp_edges, d_coul_edges, d_net_edges; 126 | DoubleView2D d_atom_pos; 127 | 128 | IntView1D d_numneigh_net, d_numneigh_coul, d_numneigh_disp; 129 | IntView1D d_cumsum_numneigh_net, d_cumsum_numneigh_coul, d_cumsum_numneigh_disp; 130 | IntView2D d_neighbors_net, d_neighbors_coul, d_neighbors_disp; 131 | DoubleView3D d_edge_shift; 132 | DoubleView2D d_edge_shift_net, d_edge_shift_coul, d_edge_shift_disp; 133 | 134 | int neighflag, newton_pair; 135 | int nlocal, nall, eflag, vflag; 136 | }; 137 | 138 | } 139 | 140 | #endif 141 | #endif 142 | 143 | -------------------------------------------------------------------------------- /train/alignment.py: -------------------------------------------------------------------------------- 1 | # ----- BAMBOO: Bytedance AI Molecular Booster ----- 2 | # Copyright 2022-2024 Bytedance Ltd. and/or its affiliates 3 | 4 | # This program is free software; you can redistribute it and/or modify 5 | # it under the terms of the GNU General Public License as published by 6 | # the Free Software Foundation; either version 2 of the License, or 7 | # (at your option) any later version. 8 | 9 | # This program is distributed in the hope that it will be useful, 10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | # GNU General Public License for more details. 13 | 14 | # You should have received a copy of the GNU General Public License 15 | # along with this program; if not, write to the Free Software 16 | # Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA 17 | 18 | import argparse 19 | import json 20 | import os 21 | import random 22 | from typing import Any, Dict, List, Tuple 23 | 24 | import numpy as np 25 | import torch 26 | 27 | from models.bamboo_get import BambooGET 28 | from utils.batchify import batchify 29 | from utils.log_helper import create_logger 30 | from utils.path import ALIGNMENT_PATH, DATA_PATH 31 | from utils.rejit import convert 32 | from utils.constant import nktv2p 33 | 34 | 35 | def get_parser(): 36 | # Create the parser 37 | parser = argparse.ArgumentParser(description="Arguments for bamboo model alignment.") 38 | 39 | # Required arguments 40 | parser.add_argument('--config', default='', type=str, help="Path to a configuration file in JSON format.") 41 | parser.add_argument('--job_name', default='default', type=str) 42 | 43 | # Training and validation data paths 44 | parser.add_argument("--training_data_path", type=str, default="train_data.pt", help="Path to the training data file.") 45 | parser.add_argument("--validation_data_path", type=str, default="val_data.pt", help="Path to the validation data file.") 46 | 47 | # Data sources and model configuration 48 | parser.add_argument("--model", type=str, default=None, help="Specify the model's path for use in alignment.") 49 | parser.add_argument("--frame_directories", nargs="*", type=str, default=[], help="List of directories that contain frame data for processing.") 50 | parser.add_argument("--mixture_names", nargs="*", type=str, default=[], help="Names of mixtures to be considered during alignment.") 51 | parser.add_argument("--delta_pressure", nargs="*", type=float, default=[], help="Delta pressures for the respective mixtures, listed in the same order as mixture names.") 52 | 53 | # Data property weighting 54 | parser.add_argument("--energy_ratio", type=float, default=0.3, help="Weight of energy predictions in the loss function.") 55 | 
parser.add_argument("--force_ratio", type=float, default=1.0, help="Weight of force predictions in the loss function.") 56 | parser.add_argument("--virial_ratio", type=float, default=0.1, help="Weight of virial predictions in the loss function.") 57 | parser.add_argument("--dipole_ratio", type=float, default=3.0, help="Weight of dipole predictions in the loss function.") 58 | parser.add_argument("--bulk_energy_ratio", type=float, default=1e2, help="Weight of bulk energy predictions in the loss function.") 59 | parser.add_argument("--bulk_force_ratio", type=float, default=1e6, help="Weight of bulk force predictions in the loss function.") 60 | parser.add_argument("--bulk_virial_ratio", type=float, default=3e3, help="Weight of bulk virial predictions in the loss function.") 61 | 62 | # Training parameters 63 | parser.add_argument("--batch_size", type=int, default=512, help="Number of samples processed together in one pass.") 64 | parser.add_argument("--epochs", type=int, default=30, help="Total number of training cycles through the entire dataset.") 65 | parser.add_argument("--frame_val_interval", type=int, default=3, help="Interval for validating the model with the validation dataset.") 66 | parser.add_argument("--max_frame_per_mixture", type=int, default=30, help="Maximum number of frames allowed for each mixture.") 67 | parser.add_argument("--lr", type=float, default=1e-12, help="Initial learning rate for the optimization algorithm.") 68 | parser.add_argument("--scheduler_gamma", type=float, default=0.99, help="Decay rate for adjusting the learning rate across epochs.") 69 | 70 | args = parser.parse_args() 71 | 72 | # Load configuration from a JSON file if specified and the file exists 73 | if os.path.isfile(args.config): 74 | with open(args.config, 'r') as config_file: 75 | config_from_file = json.load(config_file) 76 | 77 | # Update the command line arguments with values from the JSON configuration 78 | for key, value in config_from_file.items(): 79 | # Skip updating args with None values from the configuration file 80 | if value is not None: 81 | setattr(args, key, value) 82 | 83 | return args 84 | 85 | 86 | class DensityAlignment: 87 | def __init__(self, args) -> None: 88 | 89 | self.args = args 90 | self.work_dir: str = os.path.join(ALIGNMENT_PATH, self.args.job_name) 91 | self.checkpoint_output = os.path.join(self.work_dir, "checkpoints") 92 | self.log_output = os.path.join(self.work_dir, "logs") 93 | os.makedirs(self.checkpoint_output, exist_ok=True) 94 | os.makedirs(self.log_output, exist_ok=True) 95 | 96 | self.log_file = os.path.join(self.log_output, "alignment.log") 97 | self.logger = create_logger("ALIGNMENT", self.log_file) 98 | 99 | # Training and validation data paths 100 | self.training_data_path = os.path.join(DATA_PATH, args.training_data_path) 101 | self.validation_data_path = os.path.join(DATA_PATH, args.validation_data_path) 102 | 103 | # Placeholder for cluster data 104 | self._train_cluster_data = None 105 | self._val_cluster_data = None 106 | 107 | self.cluster_loss_ratio: Dict[str, float] = { 108 | "energy": args.energy_ratio, 109 | "forces": args.force_ratio, 110 | "virial": args.virial_ratio, 111 | "dipole": args.dipole_ratio, 112 | } 113 | 114 | self.bulk_loss_ratios: Dict[str, float] = { 115 | "pred_energy": args.bulk_energy_ratio, 116 | "pred_forces": args.bulk_force_ratio, 117 | "pred_virial": args.bulk_virial_ratio, 118 | } 119 | 120 | # Initialize delta_pressure dictionary directly from zipped mixture names and delta pressures 121 | self.delta_pressure = 
dict(zip(args.mixture_names, args.delta_pressure)) 122 | 123 | # Log the current state of delta pressure 124 | self.logger.info(f"Delta Pressure: {self.delta_pressure}") 125 | 126 | # Determine if we should skip alignment based on the condition that all delta pressures are close to zero 127 | self.skip_alignment = np.allclose(list(self.delta_pressure.values()), 0, rtol=1e-1) 128 | 129 | # Log the decision on whether to skip finetuning 130 | if self.skip_alignment: 131 | self.logger.info("All delta pressures are nearly zero, skipping alignment.") 132 | 133 | self.model = args.model 134 | if not os.path.isfile(self.model): 135 | raise FileNotFoundError(f"Model file {self.model} not found.") 136 | self.train_model = convert(self.model) 137 | 138 | for arg in vars(args): 139 | val = getattr(args, arg) 140 | if isinstance(val, list): 141 | val = '\n\t\t\t' + '\n\t\t\t'.join(map(str, val)) 142 | self.logger.info(f"{arg} = {val}") 143 | else: 144 | self.logger.info(f"{arg} = {val}") 145 | 146 | self.frame_directories = args.frame_directories 147 | 148 | self.lr: float = args.lr 149 | self.scheduler_gamma: float = args.scheduler_gamma 150 | self.epochs: int = args.epochs 151 | self.batch_size: int = args.batch_size 152 | self.frame_val_interval: int = args.frame_val_interval 153 | self.max_frame_per_mixture: int = args.max_frame_per_mixture 154 | 155 | self.device = torch.device('cuda') 156 | self.result = {} 157 | 158 | @property 159 | def train_cluster_data(self) -> Dict[str, torch.Tensor]: 160 | if self._train_cluster_data is None: 161 | self._train_cluster_data = torch.load(self.training_data_path, map_location="cpu") 162 | return self._train_cluster_data 163 | 164 | @property 165 | def val_cluster_data(self) -> Dict[str, torch.Tensor]: 166 | if self._val_cluster_data is None: 167 | self._val_cluster_data = torch.load(self.validation_data_path, map_location=self.device) 168 | return self._val_cluster_data 169 | 170 | def load_data(self) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]: 171 | # *.pts -> key: val 172 | # name: str -> mixture_name: str 173 | frames = [] 174 | train_frames = [] 175 | val_frames = [] 176 | 177 | mixture_name_counter = {k: 0 for k in self.delta_pressure} 178 | 179 | # Recursively scan the frames folder to locate all data files. 180 | for frame_directory in self.frame_directories: 181 | if not os.path.isdir(frame_directory): 182 | raise NotADirectoryError(f"Frame directory {frame_directory} not found.") 183 | for root, _, files in os.walk(frame_directory): 184 | for file in files: 185 | if file.endswith(".pt"): 186 | frames.append(os.path.join(root, file)) 187 | 188 | for frame_path in frames: 189 | frame_data = torch.load(frame_path, map_location=self.device) 190 | mixture_name = frame_data["mixture_name"] 191 | 192 | if mixture_name not in self.delta_pressure: 193 | self.logger.warning(f"Skipping {frame_path}: {mixture_name} because it is not in delta_pressure.") 194 | continue 195 | 196 | if mixture_name_counter[mixture_name] >= self.max_frame_per_mixture: 197 | self.logger.warning(f"Skipping {frame_path}: {mixture_name} because it exceeds exceed max_frame_per_mixture.") 198 | continue 199 | 200 | # Process frame_data and prepare result template 201 | pred: Dict[str, torch.Tensor] = self.train_model.forward(frame_data["inputs"]) 202 | # detach the tensors to save memory. 
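# NOTE (added annotation, not in the original file): these predictions come from the model in
# its pre-alignment state; detaching them freezes nn_virial_outer (and, for training frames,
# the bulk pred_* values) as fixed reference targets, so later epochs penalize drift from the
# starting point instead of chasing outputs that change as the weights are updated.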
203 | for k, v in pred.items(): 204 | pred[k] = v.detach() 205 | pred[k].requires_grad = False 206 | 207 | result_tmp = { 208 | "delta_pressure": self.delta_pressure[mixture_name], 209 | "frame_path": frame_path, 210 | "nn_virial_outer": pred["nn_virial_outer"], 211 | "mixture_name": mixture_name, 212 | } 213 | 214 | # Decide on whether to add to validation or training frames 215 | if mixture_name_counter[mixture_name] % self.frame_val_interval == 0: 216 | val_frames.append(result_tmp) 217 | else: 218 | for k in self.bulk_loss_ratios: 219 | result_tmp[k] = pred[k] 220 | train_frames.append(result_tmp) 221 | mixture_name_counter[mixture_name] += 1 222 | 223 | # Explicit cleanup to assist garbage collection and reduce memory footprint 224 | del frame_data, pred 225 | 226 | # Log the count of train and validation frames 227 | num_train_frames = len(train_frames) 228 | num_val_frames = len(val_frames) 229 | self.logger.info(f"Loaded {num_train_frames} train frames and {num_val_frames} val frames.") 230 | 231 | # Log counts for each mixture name 232 | mixture_counts_log = "\n".join([f"mixture name: {name}, count: {count}" for name, count in mixture_name_counter.items()]) 233 | self.logger.info(f"Mixture counts:\n{mixture_counts_log}") 234 | 235 | return train_frames, val_frames 236 | 237 | def bulk_validation(self, model: BambooGET, val_frames: List[Dict[str, Any]]) -> Dict[str, float]: 238 | val_dp_outer = {mixture_name: [] for mixture_name in self.delta_pressure.keys()} 239 | 240 | # Process each validation frame 241 | for curr_frame in val_frames: 242 | frame_data = torch.load(curr_frame["frame_path"], map_location=self.device) 243 | inputs = frame_data["inputs"] 244 | pred = model.forward(inputs) 245 | nn_virial_outer_diff = pred['nn_virial_outer'] - curr_frame["nn_virial_outer"] 246 | volume = inputs['cell'][0][0] * inputs['cell'][1][1] * inputs['cell'][2][2] 247 | pred_outer_press = nktv2p * nn_virial_outer_diff / (3 * volume) - curr_frame["delta_pressure"] 248 | val_dp_outer[curr_frame["mixture_name"]].append(pred_outer_press.item()) 249 | del frame_data, pred, inputs 250 | 251 | # Compute means for each mixture and overall statistics 252 | val_dp_outer_mean: Dict[str, float] = {k: np.mean(v) for k, v in val_dp_outer.items()} 253 | all_means = list(val_dp_outer_mean.values()) 254 | 255 | dp_avg = np.mean(all_means) 256 | dp_std = np.std(all_means) 257 | val_dp_outer_mean.update({"AVG": dp_avg, "STD": dp_std}) 258 | return val_dp_outer_mean 259 | 260 | def cluster_validation(self, model: BambooGET, cluster: Dict[str, torch.Tensor]) -> Dict[str, float]: 261 | keys = ['energy', 'forces', 'virial', 'dipole'] 262 | val_rmse = {k: [] for k in keys} 263 | val_data_size = len(cluster['total_charge']) 264 | total_step = val_data_size // self.batch_size 265 | 266 | for step in range(total_step): 267 | start_idx = step * self.batch_size 268 | end_idx = (step + 1) * self.batch_size 269 | batch_data = batchify(cluster, start_idx, end_idx, device=self.device) 270 | mse, _, _ = model.get_loss(batch_data) 271 | 272 | for k in keys: 273 | val_rmse[k].append(mse[k].item() * self.batch_size) 274 | 275 | total_val_rmse = {f"cluster_{k}_rmse": np.sqrt(sum(val_rmse[k]) / val_data_size) for k in keys} 276 | return total_val_rmse 277 | 278 | def construct_log(self, info, name=None, baseline=None): 279 | log_parts = [] 280 | for k, v in info.items(): 281 | entry = f"{k}: {v:.2f}" 282 | if baseline is not None: 283 | diff = v - baseline.get(k, 0) # Safely get baseline value, defaulting to 0 284 | diff_sign = "+" 
if diff >= 0 else "-" 285 | entry += f" ({diff_sign} {abs(diff):.2f})" 286 | log_parts.append(entry) 287 | 288 | log = ", ".join(log_parts) 289 | if name: 290 | log = f"{name}: {log}" 291 | 292 | self.logger.info(log) 293 | 294 | def run(self): 295 | if self.skip_alignment: 296 | self.conclude() 297 | return 298 | 299 | params = list(self.train_model.energy_mlp.parameters()) 300 | optimizer = torch.optim.SGD(params, lr=self.lr) 301 | scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=self.scheduler_gamma) 302 | 303 | train_frames, val_frames = self.load_data() 304 | self.logger.info(f"Train frames: {len(train_frames)}") 305 | self.logger.info(f"Val frames: {len(val_frames)}") 306 | 307 | if not train_frames: 308 | raise ValueError("No train frames found.") 309 | if not val_frames: 310 | raise ValueError("No val frames found.") 311 | 312 | 313 | total_cluster_data = len(self.train_cluster_data['total_charge']) 314 | cluster_batch_num = total_cluster_data // self.batch_size 315 | cluster_random_index = list(range(cluster_batch_num)) 316 | random.shuffle(cluster_random_index) 317 | n_cluster_index = 0 318 | 319 | base_val_cluster_rmse = self.cluster_validation(model=self.train_model, cluster=self.val_cluster_data) 320 | val_dp_outer = self.bulk_validation(model=self.train_model, val_frames=val_frames) 321 | 322 | self.logger.info("Before alignment:") 323 | self.construct_log(base_val_cluster_rmse, name="[CLUSTER]") 324 | self.construct_log(val_dp_outer, name="[BULK]") 325 | self.logger.info(f"Alignment starts. frames: {len(train_frames)}") 326 | 327 | for epoch in range(self.epochs): 328 | train_order = list(range(len(train_frames))) 329 | random.shuffle(train_order) 330 | for idx in train_order: 331 | optimizer.zero_grad() 332 | 333 | curr_frame = train_frames[idx] 334 | 335 | frame_data = torch.load(curr_frame["frame_path"], map_location=self.device) 336 | inputs = frame_data["inputs"] 337 | 338 | natoms = len(inputs["atom_types"]) 339 | natoms_ratio = {"pred_energy": 1.0 / natoms, "pred_forces": 1.0, "pred_virial": 1.0 / natoms} 340 | 341 | pred = self.train_model.forward(frame_data["inputs"]) 342 | nn_virial_outer_diff = pred['nn_virial_outer'] - curr_frame["nn_virial_outer"] 343 | volume = inputs['cell'][0][0] * inputs['cell'][1][1] * inputs['cell'][2][2] 344 | loss: torch.Tensor = (nktv2p * nn_virial_outer_diff / (3*volume) - curr_frame["delta_pressure"])**2 345 | for k, v in self.bulk_loss_ratios.items(): 346 | loss += self.bulk_loss_ratios[k] * torch.mean((pred[k] - curr_frame[k])**2) * natoms_ratio[k] 347 | loss.backward() 348 | optimizer.step() 349 | del frame_data, pred, inputs 350 | 351 | # Train on cluster data. 
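# NOTE (added annotation, not in the original file): the frame loss above converts the change
# in nn_virial_outer into a pressure-like term, nktv2p * delta_W / (3 * V), and penalizes its
# squared deviation from the target delta_pressure, while the bulk_loss_ratios terms anchor
# pred_energy / pred_forces / pred_virial to their pre-alignment values. The block below then
# interleaves one mini-batch of cluster data per frame, weighted by cluster_loss_ratio, so
# cluster-level accuracy is retained while the bulk pressure is being shifted.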
352 | cluster_loss = torch.tensor(0.0, device=self.device) 353 | #training on cluster data 354 | optimizer.zero_grad() 355 | 356 | n_cluster_index = (n_cluster_index + 1) % cluster_batch_num 357 | start = cluster_random_index[n_cluster_index] * self.batch_size 358 | end = start + self.batch_size 359 | 360 | batch_data = batchify(self.train_cluster_data, start, end, device=self.device) 361 | mse, _, _ = self.train_model.get_loss(batch_data) 362 | for k, v in self.cluster_loss_ratio.items(): 363 | cluster_loss += v * mse[k] 364 | del batch_data 365 | 366 | cluster_loss.backward() 367 | optimizer.step() 368 | 369 | val_cluster_rmse = self.cluster_validation(model=self.train_model, cluster=self.val_cluster_data) 370 | val_dp_outer = self.bulk_validation(model=self.train_model, val_frames=val_frames) 371 | 372 | self.construct_log(val_cluster_rmse, name="[CLUSTER]", baseline=base_val_cluster_rmse) 373 | self.construct_log(val_dp_outer, name=f"[EPOCH: {epoch}]") 374 | 375 | scheduler.step() 376 | self.result = val_dp_outer 377 | 378 | self.conclude() 379 | 380 | def conclude(self): 381 | # Save the model. 382 | module = torch.jit.script(self.train_model) 383 | module_file = os.path.join(self.checkpoint_output, "alignment.pt") 384 | module.save(module_file) # type: ignore 385 | self.result["model"] = module_file 386 | 387 | # save result info. 388 | result_file = os.path.join(self.work_dir, "result.json") 389 | with open(result_file, "w") as f: 390 | json.dump(self.result, f, indent=4) 391 | 392 | 393 | def main(): 394 | args = get_parser() 395 | 396 | density_alignment = DensityAlignment(args) 397 | density_alignment.run() 398 | 399 | 400 | if __name__ == "__main__": 401 | # For local test. 402 | main() 403 | -------------------------------------------------------------------------------- /train/ensemble.py: -------------------------------------------------------------------------------- 1 | # ----- BAMBOO: Bytedance AI Molecular Booster ----- 2 | # Copyright 2022-2024 Bytedance Ltd. and/or its affiliates 3 | 4 | # This program is free software; you can redistribute it and/or modify 5 | # it under the terms of the GNU General Public License as published by 6 | # the Free Software Foundation; either version 2 of the License, or 7 | # (at your option) any later version. 8 | 9 | # This program is distributed in the hope that it will be useful, 10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | # GNU General Public License for more details. 
13 | 14 | # You should have received a copy of the GNU General Public License 15 | # along with this program; if not, write to the Free Software 16 | # Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA 17 | 18 | import argparse 19 | import json 20 | import os 21 | import random 22 | from datetime import datetime 23 | from typing import Dict, List, Optional, Union 24 | 25 | import numpy as np 26 | import pandas as pd 27 | import torch 28 | 29 | from models.bamboo_get import BambooGET 30 | from utils.batchify import batchify 31 | from utils.log_helper import create_logger 32 | from utils.path import DATA_PATH, ENSEMBLE_PATH 33 | from utils.rejit import convert 34 | 35 | 36 | def get_parser(): 37 | # Create the parser 38 | parser = argparse.ArgumentParser(description="Arguments for bamboo model ensembling.") 39 | 40 | # Required arguments 41 | parser.add_argument('--config', default='', type=str, help="Path to a configuration file in JSON format.") 42 | parser.add_argument('--job_name', default='default', type=str) 43 | 44 | # Training and validation data paths 45 | parser.add_argument("--training_data_path", type=str, default="train_data.pt", help="Path to the training data file.") 46 | parser.add_argument("--validation_data_path", type=str, default="val_data.pt", help="Path to the validation data file.") 47 | parser.add_argument("--batch_size", type=int, default=512, help="Batch size for training and validation.") 48 | 49 | # Data sources and model configuration 50 | parser.add_argument("--models", nargs="*", type=str, default=[], help="Paths to models for uncertainty calculation.") 51 | parser.add_argument("--frame_directories", nargs="*", type=str, default=[], help="Directories containing frame data.") 52 | parser.add_argument("--ensemble_model", type=str, default=None, help="Path to the model used for ensemble predictions.") 53 | 54 | # Training parameters 55 | parser.add_argument("--validation_split_ratio", type=float, default=0.1, help="Fraction of data to use for validation.") 56 | parser.add_argument("--lr", type=float, default=1e-6, help="Initial learning rate.") 57 | parser.add_argument("--epochs", type=int, default=50, help="Total number of training epochs.") 58 | parser.add_argument("--scheduler_gamma", type=float, default=0.99, help="Learning rate decay factor per epoch.") 59 | parser.add_argument("--validation_interval", type=int, default=10, help="Interval (in epochs) between validations.") 60 | 61 | # Data property weighting 62 | parser.add_argument("--energy_ratio", type=float, default=0.3, help="Weight of energy predictions in the loss function.") 63 | parser.add_argument("--force_ratio", type=float, default=1.0, help="Weight of force predictions in the loss function.") 64 | parser.add_argument("--virial_ratio", type=float, default=0.1, help="Weight of virial predictions in the loss function.") 65 | parser.add_argument("--bulk_energy_ratio", type=float, default=0.01, help="Weight of bulk energy predictions in the loss function.") 66 | parser.add_argument("--bulk_force_ratio", type=float, default=3.0, help="Weight of bulk force predictions in the loss function.") 67 | parser.add_argument("--bulk_virial_ratio", type=float, default=0.01, help="Weight of bulk virial predictions in the loss function.") 68 | 69 | # Additional training settings 70 | parser.add_argument("--max_frames_per_mixture", type=int, default=960, help="Maximum number of frames per mixture.") 71 | parser.add_argument("--frame_validation_interval", type=int, default=3, help="Interval for 
frame-level validation checks.") 72 | 73 | args = parser.parse_args() 74 | 75 | # Load configuration from a JSON file if specified and the file exists 76 | if os.path.isfile(args.config): 77 | with open(args.config, 'r') as config_file: 78 | config_from_file = json.load(config_file) 79 | 80 | # Update the command line arguments with values from the JSON configuration 81 | for key, value in config_from_file.items(): 82 | # Skip updating args with None values from the configuration file 83 | if value is not None: 84 | setattr(args, key, value) 85 | 86 | return args 87 | 88 | 89 | class DistillationEnsemble: 90 | def __init__(self, args) -> None: 91 | self.args = args 92 | 93 | # Validate required arguments 94 | if not self.args.frame_directories: 95 | raise ValueError("Frame folders must be provided.") 96 | 97 | if not self.args.models: 98 | raise ValueError("Models must be provided.") 99 | 100 | if not self.args.ensemble_model: 101 | raise ValueError("Uncertainty jobs must be provided.") 102 | 103 | self.work_dir = os.path.join(ENSEMBLE_PATH, self.args.job_name) 104 | 105 | self.frames_output = os.path.join(self.work_dir, "frame") 106 | self.checkpoint_output = os.path.join(self.work_dir, "checkpoints") 107 | self.log_output = os.path.join(self.work_dir, "logs") 108 | 109 | make_dirs = [self.frames_output, self.checkpoint_output, self.log_output] 110 | for dir_tmp in make_dirs: 111 | os.makedirs(dir_tmp, exist_ok=True) 112 | 113 | log_file = os.path.join(self.log_output, f"ensemble_{datetime.now().strftime('%m%d%H%M')}.log") 114 | 115 | self.logger = create_logger(name="ENSEMBLE", log_file=log_file) 116 | self.logger.info(f"Initializing.") 117 | 118 | for arg in vars(args): 119 | val = getattr(args, arg) 120 | if isinstance(val, list): 121 | val = '\n\t\t\t' + '\n\t\t\t'.join(map(str, val)) 122 | self.logger.info(f"{arg} = {val}") 123 | else: 124 | self.logger.info(f"{arg} = {val}") 125 | 126 | # Init device 127 | if torch.cuda.is_available(): 128 | self.device = torch.device(f"cuda") 129 | self.logger.info(f'device = cuda') 130 | else: 131 | raise RuntimeError("Cannot find CUDA device.") 132 | 133 | # Training and validation data paths 134 | self.training_data_path = os.path.join(DATA_PATH, args.training_data_path) 135 | self.validation_data_path = os.path.join(DATA_PATH, args.validation_data_path) 136 | 137 | # Placeholder for cluster data 138 | self._train_cluster_data = None 139 | self._val_cluster_data = None 140 | 141 | self.loss_ratios = { 142 | 'energy': args.energy_ratio, 143 | 'forces': args.force_ratio, 144 | 'virial': args.virial_ratio, 145 | } 146 | 147 | self.bulk_energy_ratio = args.bulk_energy_ratio 148 | self.bulk_force_ratio = args.bulk_force_ratio 149 | self.bulk_virial_ratio = args.bulk_virial_ratio 150 | 151 | self.batch_size = args.batch_size # Share batch size for both val and train. 
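# NOTE (added annotation, not in the original file): these attributes mirror the get_parser()
# arguments above; any key present in the JSON passed via --config overrides the corresponding
# argparse default, and keys set to null are ignored. A hypothetical
# configs/ensemble_config/config.json (contents invented for illustration) could look like
#   {"job_name": "ens_run", "models": ["m0.pt", "m1.pt", "m2.pt"], "lr": 1e-6, "epochs": 50}
# and be supplied as --config configs/ensemble_config/config.json when launching the script
# from the repository root.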
152 | 153 | self.lr = args.lr 154 | self.scheduler_gamma = args.scheduler_gamma 155 | 156 | self.validation_split_ratio = args.validation_split_ratio 157 | self.validation_interval = args.validation_interval 158 | 159 | self.epochs = args.epochs 160 | self.max_frames_per_mixture = args.max_frames_per_mixture 161 | self.frame_validation_interval = args.frame_validation_interval 162 | 163 | # Assign the list of frame directories from command line arguments 164 | self.frame_directories = args.frame_directories 165 | 166 | # Verify each specified frame directory exists 167 | for frame_dir in self.frame_directories: 168 | if not os.path.isdir(frame_dir): 169 | # Raise an error if a specified directory does not exist 170 | raise NotADirectoryError(f"Frame directory {frame_dir} not found.") 171 | 172 | 173 | self.logger.info("Initiating the loading of all models.") 174 | 175 | self.script_models = {} 176 | 177 | # Load all models into memory for efficiency, assuming a manageable total number. 178 | for model in self.args.models: 179 | if model not in self.script_models: 180 | self.script_models[model] = torch.jit.load(model, map_location=self.device) 181 | 182 | if self.args.ensemble_model is None: 183 | self.ensemble_model = self.args.models[0] 184 | self.logger.info(f"Ensemble model not specified, using the first model: {self.ensemble_model}") 185 | else: 186 | self.ensemble_model = self.args.ensemble_model 187 | if self.ensemble_model not in self.script_models: 188 | raise ValueError(f"Ensemble model {self.ensemble_model} not found in models.") 189 | 190 | self.logger.info(f"Number of models: {len(self.script_models)}") 191 | 192 | self._cached_frames = {} 193 | self.uncertainty_train = [] 194 | self.uncertainty_val = [] 195 | self.split_data_flag = False 196 | 197 | @property 198 | def train_cluster_data(self) -> Dict[str, torch.Tensor]: 199 | if self._train_cluster_data is None: 200 | self._train_cluster_data = torch.load(self.training_data_path, map_location="cpu") 201 | return self._train_cluster_data 202 | 203 | @property 204 | def val_cluster_data(self) -> Dict[str, torch.Tensor]: 205 | if self._val_cluster_data is None: 206 | self._val_cluster_data = torch.load(self.validation_data_path, map_location="cpu") 207 | return self._val_cluster_data 208 | 209 | def uncertainty(self): 210 | self.logger.info("Start uncertainty quantification") 211 | 212 | configs = {} 213 | 214 | data_paths = [] 215 | 216 | uncertainty_count = {} 217 | 218 | def add_frame_data(mixture_name: str, output_file: str): 219 | if mixture_name not in uncertainty_count: 220 | uncertainty_count[mixture_name] = 0 221 | uncertainty_count[mixture_name] += 1 222 | 223 | if uncertainty_count[mixture_name] % self.frame_validation_interval: 224 | self.uncertainty_train.append(output_file) 225 | else: 226 | self.uncertainty_val.append(output_file) 227 | 228 | config_file = os.path.join(self.frames_output, "config.json") 229 | if os.path.isfile(config_file): 230 | self.logger.info("Skipping uncertainty quantification, config file exists.") 231 | 232 | with open(config_file, "r") as f: 233 | configs = json.load(f) 234 | 235 | for data_path, config in configs.items(): 236 | mixture_name = config["mixture_name"] 237 | output_file = config["data"] 238 | add_frame_data(mixture_name, output_file) 239 | 240 | return 241 | 242 | # Recursively scan the frames folder to locate all data files. 
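# NOTE (added annotation, not in the original file): the walk below is equivalent to
# collecting every *.pt file recursively, e.g. (using pathlib.Path)
#   data_paths = [str(p) for d in self.frame_directories for p in Path(d).rglob("*.pt")]
# after which the list is shuffled so frames from different mixtures are interleaved.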
243 | for frame_directory in self.frame_directories: 244 | if not os.path.isdir(frame_directory): 245 | raise NotADirectoryError(f"Frame directory {frame_directory} not found.") 246 | for root, _, files in os.walk(frame_directory): 247 | for file in files: 248 | if file.endswith(".pt"): 249 | data_paths.append(os.path.join(root, file)) 250 | random.shuffle(data_paths) 251 | 252 | self.logger.info(f"Number of data: {len(data_paths)}") 253 | 254 | self.logger.info("Starting inference for uncertainty quantification.") 255 | 256 | for data_path in data_paths: # Assuming data_paths is defined and passed correctly. 257 | # Load data directly onto the specified device. 258 | single_data: Dict[str, torch.Tensor] = torch.load(data_path, map_location=self.device) 259 | mixture_name = single_data["mixture_name"] 260 | 261 | # Skip processing if max frames per mixture limit is reached. 262 | if uncertainty_count.get(mixture_name, 0) > self.max_frames_per_mixture: 263 | self.logger.info(f"Skipping {mixture_name}: max frames per mixture reached.") 264 | continue 265 | 266 | ensemble_pred = { 267 | 'energy': [], 268 | 'forces': [], 269 | 'virial': [], 270 | } 271 | 272 | # Collect predictions from all modules. 273 | for model in self.script_models.values(): 274 | pred = model.forward(single_data["inputs"]) 275 | 276 | ensemble_pred['energy'].append(torch.flatten(pred['pred_energy'].detach())) 277 | ensemble_pred['forces'].append(torch.flatten(pred['pred_forces'].detach())) 278 | ensemble_pred['virial'].append(torch.flatten(pred['pred_virial'].detach())) 279 | del pred 280 | ensemble_results = {} 281 | 282 | # Compute ensemble statistics for each property (energy, forces, virial). 283 | for property_name, predictions in ensemble_pred.items(): 284 | # Stack all predictions for the current property along a new dimension. 285 | predictions_tensor = torch.stack(predictions, dim=0) 286 | 287 | # Calculate the mean and standard deviation along the stacked dimension. 288 | mean_list = torch.mean(predictions_tensor, dim=0).detach().cpu().tolist() 289 | std_list = torch.std(predictions_tensor, dim=0).detach().cpu().tolist() 290 | 291 | # Special handling for force and virial to group results in triplets. 292 | if property_name in ["force", "virial"]: 293 | mean_list = [mean_list[i:i+3] for i in range(0, len(mean_list), 3)] 294 | std_list = [std_list[i:i+3] for i in range(0, len(std_list), 3)] 295 | 296 | # Store the computed mean and standard deviation. 
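# NOTE (added annotation, not in the original file): each property's across-model mean and
# std lists are stored below (e.g. ensemble_results['forces']['mean']); the means are what
# bulk_validation() and run() later use as distillation targets, while the stds are the
# per-frame uncertainty record produced by this pass.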
297 | ensemble_results[property_name] = { 298 | "mean": mean_list, 299 | "std": std_list 300 | } 301 | 302 | ensemble_results.update(single_data) 303 | 304 | # Save file {frames}/a/b/c.pt to {output}/c_{index}.pt 305 | file_base_name = os.path.splitext(os.path.basename(data_path))[0] 306 | index = 0 307 | 308 | def get_output_file(index: int) -> str: 309 | return os.path.join(self.frames_output, f"{file_base_name}_{index}.pt") 310 | 311 | output_file = get_output_file(index) 312 | while os.path.isfile(output_file): 313 | index += 1 314 | output_file = get_output_file(index) 315 | 316 | configs[data_path] = {"mixture_name": mixture_name, "data": output_file} 317 | self.logger.info(f"Source: {data_path}, output: {output_file}") 318 | torch.save(ensemble_results, output_file) 319 | add_frame_data(mixture_name, output_file) 320 | 321 | del single_data 322 | # Save configs 323 | self.logger.info(f"Save config to {config_file}.") 324 | with open(config_file, "w") as config_fp: 325 | json.dump(configs, config_fp, indent=4) 326 | self.logger.info("Uncertainty quantification finished") 327 | 328 | def cluster_validation(self, model: BambooGET, cluster: Dict[str, torch.Tensor]) -> Dict[str, float]: 329 | keys = ['energy', 'forces', 'virial', 'dipole'] 330 | val_rmse = {k: [] for k in keys} 331 | val_data_size = len(cluster['total_charge']) 332 | total_step = val_data_size // self.batch_size 333 | for step in range(total_step): 334 | batch_data = batchify(cluster, step*self.batch_size, (step+1)*self.batch_size, device=self.device) 335 | mse, _, _ = model.get_loss(batch_data) 336 | for k in keys: 337 | val_rmse[k].append(mse[k].item() * self.batch_size) 338 | total_val_rmse = {} 339 | total_val_rmse["cluster_energy_rmse"] = np.sqrt(sum(val_rmse["energy"]) / self.batch_size / total_step) 340 | total_val_rmse["cluster_force_rmse"] = np.sqrt(sum(val_rmse["forces"]) / self.batch_size / total_step) 341 | total_val_rmse["cluster_virial_rmse"] = np.sqrt(sum(val_rmse["virial"]) / self.batch_size / total_step) 342 | total_val_rmse["cluster_dipole_rmse"] = np.sqrt(sum(val_rmse["dipole"]) / self.batch_size / total_step) 343 | return total_val_rmse 344 | 345 | def bulk_validation(self, model: BambooGET, files: List[str]) -> Dict[str, float]: 346 | val_forces_rmse = [] 347 | val_energy_rmse = [] 348 | val_virial_rmse = [] 349 | 350 | for file in files: 351 | single_data_pt = self.load_frame(file) 352 | inputs = {k: v.to(self.device) for k, v in single_data_pt["inputs"].items()} 353 | 354 | pred = model.forward(inputs) 355 | mse_forces = torch.mean(torch.square(pred['pred_forces'].flatten() - torch.tensor(single_data_pt['forces']['mean'], device=self.device).flatten())) 356 | mse_energy = torch.mean(torch.square(pred['pred_energy'].flatten() - torch.tensor(single_data_pt['energy']['mean'], device=self.device).flatten())) 357 | mse_virial = torch.mean(torch.square(pred['pred_virial'].flatten() - torch.tensor(single_data_pt['virial']['mean'], device=self.device).flatten())) 358 | 359 | val_forces_rmse.append(mse_forces.item()) 360 | val_energy_rmse.append(mse_energy.item()) 361 | val_virial_rmse.append(mse_virial.item()) 362 | del single_data_pt, inputs, pred 363 | 364 | result = {} 365 | result['force_rmse'] = np.sqrt(np.mean(val_forces_rmse)) 366 | result['energy_rmse'] = np.sqrt(np.mean(val_energy_rmse)) 367 | result['virial_rmse'] = np.sqrt(np.mean(val_virial_rmse)) 368 | return result 369 | 370 | def should_evaluate_model(self, epoch: int) -> bool: 371 | """Determine if the model should be evaluated based on the 
epoch.""" 372 | is_validation_epoch = (epoch % self.validation_interval == 0) 373 | is_last_epoch = (epoch == self.epochs - 1) 374 | return is_validation_epoch or is_last_epoch 375 | 376 | def run(self): 377 | if not self.uncertainty_train: 378 | raise ValueError("No uncertainty_train frames available.") 379 | 380 | # Save the ensembled model to {save_dir}/ensembled.pt 381 | self.logger.info(f"Start finetuning for model: {self.ensemble_model}") 382 | checkpoint_path = os.path.join(self.checkpoint_output, f"ensemble.pt") 383 | if os.path.isfile(checkpoint_path): 384 | self.logger.info(f"checkpoint already exists: {checkpoint_path}") 385 | return 386 | 387 | script_model = self.script_models[self.ensemble_model] 388 | model = convert(script_model, device=self.device) 389 | model.train() 390 | 391 | training_curve = { 392 | 'epoch':[], 393 | 'force_rmse':[], 394 | 'energy_rmse':[], 395 | 'virial_rmse':[], 396 | 'cluster_force_rmse':[], 397 | 'cluster_energy_rmse':[], 398 | 'cluster_virial_rmse':[], 399 | 'cluster_dipole_rmse':[] 400 | } 401 | 402 | # Call logger to log the training_curve 403 | def log_train_curve() -> None: 404 | log_string = "" 405 | for k, v in training_curve.items(): 406 | if isinstance(v[-1], float): 407 | log_string += f"{k}: {v[-1]:.4f} " 408 | else: 409 | log_string += f"{k}: {v[-1]} " 410 | self.logger.info(log_string) 411 | 412 | def add_train_curve(epoch: int, bulk: Dict[str, float], cluster: Dict[str, float]) -> None: 413 | training_curve['epoch'].append(epoch) 414 | training_curve['force_rmse'].append(bulk['force_rmse']) 415 | training_curve['energy_rmse'].append(bulk['energy_rmse']) 416 | training_curve['virial_rmse'].append(bulk['virial_rmse']) 417 | training_curve['cluster_force_rmse'].append(cluster['cluster_force_rmse']) 418 | training_curve['cluster_energy_rmse'].append(cluster['cluster_energy_rmse']) 419 | training_curve['cluster_virial_rmse'].append(cluster['cluster_virial_rmse']) 420 | training_curve['cluster_dipole_rmse'].append(cluster['cluster_dipole_rmse']) 421 | 422 | # Initialize the training process 423 | energy_mlp_params = list(model.energy_mlp.parameters()) 424 | optimizer = torch.optim.Adam(energy_mlp_params, lr=self.lr) 425 | scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=self.scheduler_gamma) 426 | 427 | # Evaluate the initial model. 
428 | bulk_val_rmse = self.bulk_validation(model, self.uncertainty_val) 429 | cluster_val_rmse = self.cluster_validation(model, self.val_cluster_data) 430 | 431 | add_train_curve(-1, bulk_val_rmse, cluster_val_rmse) 432 | log_train_curve() 433 | 434 | # Start ensemble training 435 | total_cluster_data = len(self.train_cluster_data['total_charge']) 436 | cluster_batch_num = total_cluster_data // self.batch_size 437 | cluster_random_index = list(range(cluster_batch_num)) 438 | random.shuffle(cluster_random_index) 439 | n_cluster_index = 0 440 | 441 | for epoch in range(self.epochs): 442 | train_order = list(range(len(self.uncertainty_train))) 443 | random.shuffle(train_order) 444 | for idx in train_order: 445 | file = self.uncertainty_train[idx] 446 | single_data_pt = torch.load(file, map_location=self.device) 447 | inputs = single_data_pt["inputs"] 448 | 449 | natoms = len(inputs['atom_types']) 450 | #training on bulk traj data 451 | optimizer.zero_grad() 452 | pred = model.forward(inputs) 453 | mse_forces = torch.mean(torch.square(pred['pred_forces'].flatten() - torch.tensor(single_data_pt['forces']['mean'], device=self.device).flatten())) 454 | mse_energy = torch.mean(torch.square(pred['pred_energy'].flatten() - torch.tensor(single_data_pt['energy']['mean'], device=self.device).flatten())) 455 | mse_virial = torch.mean(torch.square(pred['pred_virial'].flatten() - torch.tensor(single_data_pt['virial']['mean'], device=self.device).flatten())) 456 | 457 | bulk_loss: torch.Tensor = mse_forces * self.bulk_force_ratio \ 458 | + mse_energy / natoms * self.bulk_energy_ratio \ 459 | + mse_virial / natoms * self.bulk_virial_ratio 460 | 461 | bulk_loss.backward() 462 | optimizer.step() 463 | optimizer.zero_grad() 464 | del inputs, pred, single_data_pt 465 | 466 | # Training on cluster data 467 | cluster_loss = torch.tensor(0.0, device=self.device) 468 | n_cluster_index = (n_cluster_index + 1) % cluster_batch_num 469 | start = cluster_random_index[n_cluster_index] * self.batch_size 470 | end = start + self.batch_size 471 | batch_data = batchify(self.train_cluster_data, start, end, device=self.device) 472 | mse, _, _ = model.get_loss(batch_data) 473 | for k in self.loss_ratios.keys(): 474 | cluster_loss += self.loss_ratios[k] * mse[k] 475 | del batch_data 476 | cluster_loss.backward() 477 | optimizer.step() 478 | 479 | scheduler.step() 480 | if self.should_evaluate_model(epoch): 481 | bulk_val_rmse = self.bulk_validation(model, self.uncertainty_val) 482 | cluster_val_rmse = self.cluster_validation(model, self.val_cluster_data) 483 | 484 | add_train_curve(epoch, bulk_val_rmse, cluster_val_rmse) 485 | log_train_curve() 486 | 487 | # Save ensembled model. 488 | script_model = torch.jit.script(model) 489 | torch.jit.save(script_model, checkpoint_path) 490 | self.logger.info(f"Ensembled model saved to {checkpoint_path}") 491 | 492 | # Save training curve 493 | curve_path = os.path.join(self.log_output, "training_curve.csv") 494 | df = pd.DataFrame(training_curve) 495 | df.to_csv(curve_path, index=False) 496 | 497 | def load_frame(self, file: str) -> Dict[str, torch.Tensor]: 498 | # Load a frame from the specified file. If the frame is not already cached, 499 | # it loads it into the cache. 
500 | if file not in self._cached_frames: 501 | self._cached_frames[file] = torch.load(file, map_location='cpu') 502 | 503 | return self._cached_frames[file] 504 | 505 | 506 | def main(): 507 | args = get_parser() 508 | distiller = DistillationEnsemble(args) 509 | distiller.uncertainty() 510 | distiller.run() 511 | 512 | 513 | if __name__ == "__main__": 514 | main() 515 | -------------------------------------------------------------------------------- /train/train.py: -------------------------------------------------------------------------------- 1 | # ----- BAMBOO: Bytedance AI Molecular Booster ----- 2 | # Copyright 2022-2024 Bytedance Ltd. and/or its affiliates 3 | 4 | # This program is free software; you can redistribute it and/or modify 5 | # it under the terms of the GNU General Public License as published by 6 | # the Free Software Foundation; either version 2 of the License, or 7 | # (at your option) any later version. 8 | 9 | # This program is distributed in the hope that it will be useful, 10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | # GNU General Public License for more details. 13 | 14 | # You should have received a copy of the GNU General Public License 15 | # along with this program; if not, write to the Free Software 16 | # Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA 17 | 18 | import argparse 19 | import json 20 | import logging 21 | import math 22 | import os 23 | from datetime import datetime 24 | from random import shuffle 25 | from typing import Dict, Optional 26 | 27 | import numpy as np 28 | import torch 29 | import torch.nn as nn 30 | 31 | from models.bamboo_get import BambooGET 32 | from utils.batchify import batchify 33 | from utils.log_helper import create_logger 34 | from utils.path import DATA_PATH, TRAIN_PATH 35 | 36 | 37 | def get_parser(): 38 | parser = argparse.ArgumentParser(description='Arguments for bamboo model training') 39 | 40 | # general arguments 41 | parser.add_argument('--config', default='', type=str) 42 | parser.add_argument('--job_name', default='default', type=str) 43 | parser.add_argument('--data_training', default='train_data.pt') 44 | parser.add_argument('--data_validation', default='val_data.pt') 45 | parser.add_argument('--random_seed', default=42, type=int) 46 | 47 | # training arguments 48 | parser.add_argument('--train_batch_size', default=128, type=int) 49 | parser.add_argument('--val_batch_size', default=128, type=int) 50 | parser.add_argument('--num_epoch', default=750, type=int) 51 | parser.add_argument('--lr', default=1e-2, type=float) 52 | parser.add_argument('--weight_decay', default=1e-3, type=float) 53 | parser.add_argument('--scheduler_gamma', default=0.99, type=float) 54 | parser.add_argument('--loss_charge_ratio', default=10.0, type=float) 55 | parser.add_argument('--loss_dipole_ratio', default=10.0, type=float) 56 | parser.add_argument('--loss_energy_ratio', default=0.01, type=float) 57 | parser.add_argument('--loss_forces_ratio', default=0.3, type=float) 58 | parser.add_argument('--loss_virial_ratio', default=0.01, type=float) 59 | parser.add_argument('--charge_ub', default=2.0, type=float) 60 | parser.add_argument('--qeq_force_regularizer', default=300.0, type=float) 61 | 62 | # model arguments 63 | parser.add_argument('--num_layers', type=int, default=3) 64 | parser.add_argument('--num_rbf', type=int, default=32) 65 | parser.add_argument('--emb_dim', type=int, default=64) 66 | 
parser.add_argument('--num_heads', type=int, default=16) 67 | parser.add_argument('--rcut', type=float, default=5.0) 68 | parser.add_argument('--coul_damping_beta', type=float, default=18.7) 69 | parser.add_argument('--coul_damping_r0', type=float, default=2.2) 70 | parser.add_argument('--disp_cutoff', type=float, default=10.0) 71 | parser.add_argument('--energy_mlp_layers', type=int, default=2) 72 | parser.add_argument('--charge_mlp_layers', type=int, default=2) 73 | 74 | args = parser.parse_args() 75 | 76 | # if config file is provided, read args from config file 77 | if os.path.isfile(args.config): 78 | with open(args.config, 'r') as f: 79 | config_args = json.load(f) 80 | for key, value in config_args.items(): 81 | if value is None: 82 | continue 83 | setattr(args, key, value) 84 | return args 85 | 86 | class BambooTrainer(): 87 | """ 88 | Basic trainer for bamboo 89 | """ 90 | def __init__(self, args): 91 | self.args = args 92 | 93 | # Init log and checkpoint directory 94 | job_path = os.path.join(TRAIN_PATH, self.args.job_name).lower() 95 | if not os.path.exists(job_path): 96 | os.makedirs(job_path, exist_ok=True) 97 | train_log_path = os.path.join(job_path, 'train_logs') 98 | if not os.path.exists(train_log_path): 99 | os.makedirs(train_log_path, exist_ok=True) 100 | 101 | log_file = os.path.join(train_log_path, f"train_{datetime.now().strftime('%m%d%H%M')}.log") 102 | self.logger = create_logger(name="TRAIN", log_file=log_file) 103 | self.logger.info(f"Initializing.") 104 | for k_args, v_args in vars(args).items(): 105 | self.logger.info(f'{k_args} = {v_args}') 106 | 107 | ckpt_path = os.path.join(job_path, 'checkpoints') 108 | if not os.path.exists(ckpt_path): 109 | os.makedirs(ckpt_path, exist_ok=True) 110 | self.ckpt_path = ckpt_path 111 | 112 | # Init device 113 | if torch.cuda.is_available(): 114 | self.device = torch.device(f"cuda") 115 | self.logger.info(f'device = cuda') 116 | else: 117 | raise RuntimeError("Cannot find CUDA device.") 118 | 119 | # Init random seed 120 | torch.manual_seed(args.random_seed) 121 | np.random.seed(args.random_seed) 122 | 123 | # Init loss ratios 124 | self.loss_ratios = dict() 125 | self.loss_ratios['energy'] = self.args.loss_energy_ratio 126 | self.loss_ratios['forces'] = self.args.loss_forces_ratio 127 | self.loss_ratios['virial'] = self.args.loss_virial_ratio 128 | self.loss_ratios['charge'] = self.args.loss_charge_ratio 129 | self.loss_ratios['dipole'] = self.args.loss_dipole_ratio 130 | self.loss_unit = { 131 | 'energy': 'kcal/mol', 132 | 'forces': 'kcal/mol/Ang', 133 | 'virial': 'kcal/mol', 134 | 'charge': 'a.u.', 135 | 'dipole': 'Debye', 136 | } 137 | self.qeq_force_regularizer = self.args.qeq_force_regularizer 138 | 139 | # Init dataset 140 | self.train_data = torch.load(os.path.join(DATA_PATH, self.args.data_training), map_location='cpu') 141 | self.val_data = torch.load(os.path.join(DATA_PATH, self.args.data_validation), map_location='cpu') 142 | 143 | # Init model 144 | nn_params = { 145 | 'dim': self.args.emb_dim, 146 | 'num_rbf': self.args.num_rbf, 147 | 'rcut': self.args.rcut, 148 | 'charge_ub': self.args.charge_ub, 149 | 'act_fn': nn.SiLU(), 150 | 'charge_mlp_layers': self.args.charge_mlp_layers, 151 | 'energy_mlp_layers': self.args.energy_mlp_layers, 152 | } 153 | gnn_params = { 154 | 'n_layers': self.args.num_layers, 155 | 'num_heads': self.args.num_heads, 156 | 'act_fn': nn.SiLU(), 157 | } 158 | coul_disp_params = { 159 | 'coul_damping_beta': self.args.coul_damping_beta, 160 | 'coul_damping_r0': self.args.coul_damping_r0, 161 
| 'disp_cutoff': self.args.disp_cutoff, 162 | } 163 | self.model = BambooGET(device = self.device, 164 | coul_disp_params = coul_disp_params, 165 | nn_params = nn_params, 166 | gnn_params = gnn_params) 167 | 168 | # Init optimizer and scheduler 169 | self.optimizer = torch.optim.Adamax(self.model.parameters(), lr=self.args.lr, weight_decay=self.args.weight_decay) 170 | self.scheduler = torch.optim.lr_scheduler.ExponentialLR(self.optimizer, gamma=self.args.scheduler_gamma) 171 | 172 | def print_losses(self, losses: Dict[str, float], prefix: str): 173 | message = prefix 174 | for k in self.loss_ratios.keys(): 175 | message += f' {k} {losses[k]:.4f} {self.loss_unit[k]},' 176 | self.logger.info(message) 177 | 178 | def prepare_data(self, data: Dict[str, torch.Tensor], batch_size: int): 179 | # Simple data split 180 | data_size = len(data['total_charge']) 181 | steps = (data_size-1) // batch_size + 1 182 | start = [i * batch_size for i in range(steps)] 183 | end = [(i+1) * batch_size for i in range(steps-1)] + [data_size] 184 | return steps, start, end, data_size 185 | 186 | def validate_one_epoch(self, epoch: int) -> Dict[str, float]: 187 | self.logger.info(f"[Val Start] Epoch {epoch+1}.") 188 | val_rmse = {k: [] for k in self.loss_ratios.keys()} 189 | val_mae = {k: [] for k in self.loss_ratios.keys()} 190 | steps, start, end, data_size = self.prepare_data(self.val_data, self.args.val_batch_size) 191 | 192 | for step in range(steps): 193 | batch_data = batchify(self.val_data, start[step], end[step], device=self.device) 194 | data_length = len(batch_data["total_charge"]) 195 | mse, mae, _ = self.model.get_loss(batch_data) 196 | for k in mse.keys(): 197 | val_rmse[k].append(mse[k].item() * data_length) 198 | val_mae[k].append(mae[k].item() * data_length) 199 | 200 | for k in val_rmse.keys(): 201 | val_rmse[k] = sum(val_rmse[k]) / data_size 202 | val_mae[k] = sum(val_mae[k]) / data_size 203 | val_rmse['weighted'] = sum([self.loss_ratios[k] * val_rmse[k] for k in self.loss_ratios.keys()]) 204 | for k in val_rmse: 205 | val_rmse[k] = math.sqrt(val_rmse[k]) 206 | 207 | self.logger.info(f"[Val End] Epoch {epoch+1}, Total: {data_size} clusters, {steps} batches.") 208 | self.print_losses(val_rmse, prefix=f"[Val RMSE] Epoch {epoch+1}, ") 209 | self.print_losses(val_mae, prefix=f"[Val MAE] Epoch {epoch+1}, ") 210 | return val_rmse 211 | 212 | def train_one_epoch(self, epoch: int) -> Dict[str, float]: 213 | def closure(): 214 | self.optimizer.zero_grad() 215 | mse, mae, penalty = self.model.get_loss(batch_data) 216 | data_length = len(batch_data['total_charge']) 217 | qeq_force = penalty['qeq_force'] 218 | loss = 0. 
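# NOTE (added annotation, not in the original file): the objective assembled in this closure is
#   loss = sum_k loss_ratios[k] * MSE_k  +  qeq_force_regularizer * qeq_force
# and it is wrapped in closure() so that self.optimizer.step(closure) recomputes the loss and
# its gradients for the current batch_data.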
219 | for k in mse.keys(): 220 | train_rmse[k].append(mse[k].item() * data_length) 221 | train_mae[k].append(mae[k].item() * data_length) 222 | loss += self.loss_ratios[k] * mse[k] 223 | loss += qeq_force * self.qeq_force_regularizer 224 | loss.backward() 225 | return loss 226 | 227 | self.logger.info(f"[Train Start] Epoch {epoch+1}.") 228 | train_rmse = {k: [] for k in self.loss_ratios.keys()} 229 | train_mae = {k: [] for k in self.loss_ratios.keys()} 230 | steps, start, end, data_size = self.prepare_data(self.train_data, self.args.train_batch_size) 231 | steps_shuffle = list(range(steps)) 232 | shuffle(steps_shuffle) 233 | 234 | for step in steps_shuffle: 235 | batch_data = batchify(self.train_data, start[step], end[step], device=self.device) 236 | self.optimizer.step(closure) 237 | 238 | for k in train_rmse.keys(): 239 | train_rmse[k] = sum(train_rmse[k]) / data_size 240 | train_mae[k] = sum(train_mae[k]) / data_size 241 | train_rmse['weighted'] = sum([self.loss_ratios[k] * train_rmse[k] for k in self.loss_ratios.keys()]) 242 | for k in train_rmse: 243 | train_rmse[k] = math.sqrt(train_rmse[k]) 244 | 245 | self.logger.info(f"[Train End] Epoch {epoch+1}, Total: {data_size} clusters, {steps} batches.") 246 | self.print_losses(train_rmse, prefix=f"[Train RMSE] Epoch {epoch+1}, ") 247 | self.print_losses(train_mae, prefix=f"[Train MAE] Epoch {epoch+1}, ") 248 | return train_rmse 249 | def train(self, epochs: int): 250 | val_rmse = self.validate_one_epoch(epoch = -1) 251 | best_rmse = val_rmse 252 | 253 | for epoch in range(epochs): 254 | self.train_one_epoch(epoch = epoch) 255 | self.scheduler.step() 256 | val_rmse = self.validate_one_epoch(epoch = epoch) 257 | if val_rmse['weighted'] < best_rmse['weighted']: 258 | best_rmse = val_rmse 259 | self.logger.info(f"Found best weighted RMSE {best_rmse['weighted']:.4f} at epoch {epoch+1}") 260 | # Save ckpts every epoch 261 | ckpt_filename = os.path.join(self.ckpt_path, f"epoch_{epoch+1}_loss_{int(1000*val_rmse['weighted'])}.pt") 262 | module = torch.jit.script(self.model) 263 | module.eval() 264 | module.save(ckpt_filename) 265 | self.logger.info(f'Epoch {epoch+1}, checkpoint saved at {ckpt_filename}') 266 | 267 | finish_message = "Training finished." 268 | for k in self.loss_ratios.keys(): 269 | finish_message += f" Best {k} RMSE {best_rmse[k]:.4f}. " 270 | finish_message += f" Best weighted RMSE {best_rmse['weighted']:.4f}." 271 | self.logger.info(finish_message) 272 | 273 | def main(): 274 | args = get_parser() 275 | bamboo_trainer = BambooTrainer(args) 276 | bamboo_trainer.train(args.num_epoch) 277 | 278 | if __name__ == "__main__": 279 | main() -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | # ----- BAMBOO: Bytedance AI Molecular Booster ----- 2 | # Copyright 2022-2024 Bytedance Ltd. and/or its affiliates 3 | 4 | # This program is free software; you can redistribute it and/or modify 5 | # it under the terms of the GNU General Public License as published by 6 | # the Free Software Foundation; either version 2 of the License, or 7 | # (at your option) any later version. 8 | 9 | # This program is distributed in the hope that it will be useful, 10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | # GNU General Public License for more details.
13 | 14 | # You should have received a copy of the GNU General Public License 15 | # along with this program; if not, write to the Free Software 16 | # Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA 17 | 18 | from . import constant, funcs 19 | -------------------------------------------------------------------------------- /utils/batchify.py: -------------------------------------------------------------------------------- 1 | # ----- BAMBOO: Bytedance AI Molecular Booster ----- 2 | # Copyright 2022-2024 Bytedance Ltd. and/or its affiliates 3 | 4 | # This program is free software; you can redistribute it and/or modify 5 | # it under the terms of the GNU General Public License as published by 6 | # the Free Software Foundation; either version 2 of the License, or 7 | # (at your option) any later version. 8 | 9 | # This program is distributed in the hope that it will be useful, 10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | # GNU General Public License for more details. 13 | 14 | # You should have received a copy of the GNU General Public License 15 | # along with this program; if not, write to the Free Software 16 | # Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA 17 | 18 | 19 | from typing import Dict 20 | 21 | import torch 22 | 23 | 24 | def batchify(data: dict, start: int, end: int, device: torch.device) -> Dict[str, torch.Tensor]: 25 | batch_data = dict() 26 | 27 | # molecule-level keys 28 | mol_keys = ['total_charge', 'energy', 'virial', 'dipole'] 29 | for k in list(set(mol_keys) & set(data.keys())): 30 | batch_data[k] = data[k][start:end] 31 | 32 | # atom-level keys 33 | atom_keys = ['pos', 'atom_types', 'forces', 'charge'] 34 | for k in list(set(atom_keys) & set(data.keys())): 35 | batch_data[k] = data[k][data['cumsum_atom'][start]: data['cumsum_atom'][end]] 36 | batch_data['mol_ids'] = data['mol_ids'][data['cumsum_atom'][start]: data['cumsum_atom'][end]] - start 37 | 38 | # edge-level keys 39 | batch_data['edge_index'] = (data['edge_index'][data['cumsum_edge'][start]: data['cumsum_edge'][end]] \ 40 | - data['cumsum_atom'][start].unsqueeze(-1)).transpose(-2, -1).type(torch.long) 41 | if 'edge_cell_shift' in data: 42 | batch_data['edge_cell_shift'] = data['edge_cell_shift'][data['cumsum_edge'][start]: data['cumsum_edge'][end]] 43 | else: 44 | row, col = batch_data['edge_index'][0], batch_data['edge_index'][1] 45 | batch_data['edge_cell_shift'] = batch_data['pos'][row] - batch_data['pos'][col] 46 | batch_data['all_edge_index'] = (data['all_edge_index'][data['cumsum_all_edge'][start]: data['cumsum_all_edge'][end]] \ 47 | - data['cumsum_atom'][start].unsqueeze(-1)).transpose(-2, -1).type(torch.long) 48 | if 'all_edge_cell_shift' in data: 49 | batch_data['all_edge_cell_shift'] = data['all_edge_cell_shift'][data['cumsum_all_edge'][start]: data['cumsum_all_edge'][end]] 50 | else: 51 | row_all, col_all = batch_data['all_edge_index'][0], batch_data['all_edge_index'][1] 52 | batch_data['all_edge_cell_shift'] = batch_data['pos'][row_all] - batch_data['pos'][col_all] 53 | 54 | for k in batch_data.keys(): 55 | batch_data[k] = batch_data[k].to(device) 56 | if torch.is_floating_point(batch_data[k]): 57 | batch_data[k] = batch_data[k].to(torch.get_default_dtype()) 58 | elif batch_data[k].dtype == torch.int16: 59 | batch_data[k] = batch_data[k].to(torch.int64) 60 | return batch_data 
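The sketch below is not part of the repository; it is a minimal illustration of the tensor layout `batchify` expects, assuming a toy dataset dict with the `cumsum_atom`/`cumsum_edge`/`cumsum_all_edge` prefix sums that `train.py` loads from disk, run from the repository root. All keys and values here are invented for illustration.

```python
import torch

from utils.batchify import batchify

# Toy dataset: two clusters with 2 and 3 atoms, concatenated atom-wise.
data = {
    'total_charge':    torch.tensor([0.0, -1.0]),        # one value per cluster
    'energy':          torch.tensor([-10.0, -20.0]),
    'pos':             torch.rand(5, 3),
    'atom_types':      torch.tensor([1, 8, 3, 6, 6]),
    'forces':          torch.zeros(5, 3),
    'charge':          torch.zeros(5),
    'mol_ids':         torch.tensor([0, 0, 1, 1, 1]),    # atom -> cluster index
    'cumsum_atom':     torch.tensor([0, 2, 5]),          # prefix sum of atoms per cluster
    'edge_index':      torch.tensor([[0, 1], [2, 3]]),   # [n_edges, 2], global atom indices
    'cumsum_edge':     torch.tensor([0, 1, 2]),
    'all_edge_index':  torch.tensor([[0, 1], [2, 4]]),
    'cumsum_all_edge': torch.tensor([0, 1, 2]),
}

# Slice out clusters [0, 2) and move the batch to the target device.
batch = batchify(data, start=0, end=2, device=torch.device('cpu'))
print(batch['edge_index'].shape)  # torch.Size([2, 2]): transposed to [2, n_edges]
print(batch['mol_ids'])           # tensor([0, 0, 1, 1, 1]): re-indexed within the batch
```

Because `edge_cell_shift` is absent in this toy dict, `batchify` falls back to plain position differences, which is only appropriate for non-periodic clusters.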
-------------------------------------------------------------------------------- /utils/constant.py: -------------------------------------------------------------------------------- 1 | # ----- BAMBOO: Bytedance AI Molecular Booster ----- 2 | # Copyright 2022-2024 Bytedance Ltd. and/or its affiliates 3 | 4 | # This program is free software; you can redistribute it and/or modify 5 | # it under the terms of the GNU General Public License as published by 6 | # the Free Software Foundation; either version 2 of the License, or 7 | # (at your option) any later version. 8 | 9 | # This program is distributed in the hope that it will be useful, 10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | # GNU General Public License for more details. 13 | 14 | # You should have received a copy of the GNU General Public License 15 | # along with this program; if not, write to the Free Software 16 | # Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA 17 | 18 | # constants for approximating erfc function 19 | ewald_f = 1.12837917 20 | ewald_p = 0.3275911 21 | ewald_a = [0.254829592, -0.284496736, 1.421413741, -1.453152027, 1.061405429] 22 | 23 | # constants in SI unit 24 | angstrom = 1.0e-10 25 | bohr = 5.29177e-11 26 | electron_charge = 1.60217663e-19 27 | kcal_mol = 6.9477e-21 28 | debye = 3.33564e-30 29 | coulomb_constant = 8.9875517923e+9 30 | atm_pressure = 1.01325e+5 31 | hartree = 4.3597447222071e-18 32 | kcal_mol = 6.9477e-21 33 | 34 | # unit conversion used in Bamboo 35 | # debye_ea: 0.20819427381112157 36 | debye_ea = debye / (electron_charge * angstrom) 37 | hartree_kcal_mol = hartree / kcal_mol 38 | bohr_angstrom = bohr / angstrom 39 | 40 | # ele_factor: 332.06349451357806 41 | ele_factor = coulomb_constant * electron_charge * electron_charge / kcal_mol / angstrom 42 | 43 | # nktv2p: 68568.46780162843 44 | nktv2p = kcal_mol / angstrom / angstrom / angstrom / atm_pressure 45 | 46 | 47 | nelems = 87 # placeholder, H to Rn 48 | 49 | # Hardcode Li and LI. 50 | atom_mapper = {'H': 1, 'He': 2, 'LI': 3, 51 | 'Li': 3, 'Be': 4, 'B': 5, 'C': 6, 'N': 7, 'O': 8, 'F': 9, 'Ne': 10, 52 | 'Na': 11, 'Mg': 12, 'Al': 13, 'Si': 14, 'P' : 15, 'S': 16, 'Cl': 17, 'Ar': 18, 53 | 'K': 19, 'Ca': 20, 'Sc': 21, 'Ti': 22, 'V': 23, 'Cr': 24, 54 | 'Mn': 25, 'Fe': 26, 'Co': 27, 'Ni': 28, 'Cu': 29, 'Zn': 30, 55 | 'Ga': 31, 'Ge': 32, 'As': 33, 'Se': 34, 'Br': 35, 'Kr': 36, 56 | 'Rb': 37, 'Sr': 38, 'Y': 39, 'Zr': 40, 'Nb': 41, 'Mo': 42, 'Tc': 43, 'Ru': 44, 57 | 'Rh': 45, 'Pd': 46, 'Ag': 47, 'Cd': 48, 'In': 49, 'Sn': 50, 'Sb': 51, 'Te':52, 58 | 'I': 53, 'Xe': 54, 'Cs': 55, 'Ba': 56, 'La': 57, 'Ce': 58, 'Pr': 59, 'Nd': 60, 59 | 'Pm': 61, 'Sm': 62, 'Eu': 63, 'Gd': 64, 'Tb': 65, 'Dy': 66, 'Ho': 67, 'Er': 68, 60 | 'Tm': 69, 'Yb': 70, 'Lu': 71, 'Hf': 72, 'Ta': 73, 'W': 74, 'Re': 75, 'Os': 76, 61 | 'Ir': 77, 'Pt': 78, 'Au': 79, 'Hg': 80, 'Tl': 81, 'Pb': 82, 'Bi': 83, 'Po': 84, 'At': 85, 'Rn': 86} 62 | -------------------------------------------------------------------------------- /utils/funcs.py: -------------------------------------------------------------------------------- 1 | # ----- BAMBOO: Bytedance AI Molecular Booster ----- 2 | # Copyright 2022-2024 Bytedance Ltd. 
and/or its affiliates 3 | 4 | # This program is free software; you can redistribute it and/or modify 5 | # it under the terms of the GNU General Public License as published by 6 | # the Free Software Foundation; either version 2 of the License, or 7 | # (at your option) any later version. 8 | 9 | # This program is distributed in the hope that it will be useful, 10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | # GNU General Public License for more details. 13 | 14 | # You should have received a copy of the GNU General Public License 15 | # along with this program; if not, write to the Free Software 16 | # Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA 17 | 18 | import math 19 | 20 | import torch 21 | import torch.nn as nn 22 | 23 | 24 | class CosineCutoff(nn.Module): 25 | def __init__(self, cutoff_lower=0.0, cutoff_upper=5.0): 26 | super(CosineCutoff, self).__init__() 27 | self.cutoff_lower = cutoff_lower 28 | self.cutoff_upper = cutoff_upper 29 | 30 | def forward(self, distances): 31 | if self.cutoff_lower > 0: 32 | cutoffs = 0.5 * ( 33 | torch.cos( 34 | math.pi 35 | * ( 36 | 2 37 | * (distances - self.cutoff_lower) 38 | / (self.cutoff_upper - self.cutoff_lower) 39 | + 1.0 40 | ) 41 | ) 42 | + 1.0 43 | ) 44 | # remove contributions below the cutoff radius 45 | cutoffs = cutoffs * (distances < self.cutoff_upper).float() 46 | cutoffs = cutoffs * (distances > self.cutoff_lower).float() 47 | return cutoffs 48 | else: 49 | cutoffs = 0.5 * \ 50 | (torch.cos(distances * math.pi / self.cutoff_upper) + 1.0) 51 | # remove contributions beyond the cutoff radius 52 | cutoffs = cutoffs * (distances < self.cutoff_upper).float() 53 | return cutoffs 54 | 55 | class ExpNormalSmearing(nn.Module): 56 | def __init__(self, cutoff_lower=0.0, cutoff_upper=5.0, num_rbf=32, trainable=True, device=torch.device("cpu")): 57 | super(ExpNormalSmearing, self).__init__() 58 | self.cutoff_lower = cutoff_lower 59 | self.cutoff_upper = cutoff_upper 60 | self.num_rbf = num_rbf 61 | self.trainable = trainable 62 | self.device = device 63 | 64 | self.cutoff_fn = CosineCutoff(0, cutoff_upper) 65 | self.alpha = 5.0 / (cutoff_upper - cutoff_lower) 66 | 67 | means, betas = self._initial_params() 68 | if trainable: 69 | self.register_parameter("means", nn.Parameter(means)) 70 | self.register_parameter("betas", nn.Parameter(betas)) 71 | else: 72 | self.register_buffer("means", means) 73 | self.register_buffer("betas", betas) 74 | 75 | def _initial_params(self): 76 | # initialize means and betas according to the default values in PhysNet 77 | # https://pubs.acs.org/doi/10.1021/acs.jctc.9b00181 78 | start_value = torch.exp( 79 | torch.scalar_tensor(-self.cutoff_upper + self.cutoff_lower) 80 | ).to(self.device) 81 | means = torch.linspace(start_value, 1, self.num_rbf).to(self.device) 82 | betas = torch.tensor( 83 | [(2 / self.num_rbf * (1 - start_value)) ** -2] * self.num_rbf 84 | ).to(self.device) 85 | return means, betas 86 | 87 | def reset_parameters(self): 88 | means, betas = self._initial_params() 89 | self.means.data.copy_(means) 90 | self.betas.data.copy_(betas) 91 | 92 | def forward(self, dist): 93 | dist = dist.unsqueeze(-1) 94 | return self.cutoff_fn(dist) * torch.exp( 95 | -self.betas 96 | * (torch.exp(self.alpha * (-dist + self.cutoff_lower)) - self.means) ** 2 97 | ) -------------------------------------------------------------------------------- /utils/load_traj.py: 
-------------------------------------------------------------------------------- 1 | # ----- BAMBOO: Bytedance AI Molecular Booster ----- 2 | # Copyright 2022-2024 Bytedance Ltd. and/or its affiliates 3 | 4 | # This program is free software; you can redistribute it and/or modify 5 | # it under the terms of the GNU General Public License as published by 6 | # the Free Software Foundation; either version 2 of the License, or 7 | # (at your option) any later version. 8 | 9 | # This program is distributed in the hope that it will be useful, 10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | # GNU General Public License for more details. 13 | 14 | # You should have received a copy of the GNU General Public License 15 | # along with this program; if not, write to the Free Software 16 | # Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA 17 | 18 | import argparse 19 | import logging 20 | import os 21 | from typing import List, Optional, Tuple 22 | 23 | import numpy as np 24 | import torch 25 | from scipy.spatial import KDTree 26 | 27 | from utils.constant import atom_mapper 28 | 29 | 30 | class Domain: 31 | def __init__(self) -> None: 32 | self.xlo = None 33 | self.xhi = None 34 | self.ylo = None 35 | self.yhi = None 36 | self.zlo = None 37 | self.zhi = None 38 | 39 | @property 40 | def box(self) -> List[List[float]]: 41 | if None in [self.xlo, self.xhi, self.ylo, self.yhi, self.zlo, self.zhi]: 42 | print("Domain is not initialized") 43 | return [[self.xlo, self.xhi], [self.ylo, self.yhi], [self.zlo, self.zhi]] # type: ignore 44 | 45 | 46 | class Frame: 47 | def __init__(self, filter_config: Optional[dict] = None) -> None: 48 | """ 49 | Initialize a Frame object with optional filter configuration. 50 | 51 | :param filter_config: A dictionary containing filter configuration options. 
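                              Supported keys: "field" (columns to keep), "type" (atom types to keep), and "frame_interval" (keep only frames whose step is divisible by it).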
52 | """ 53 | # Initialize instance attributes with default values or provided arguments 54 | self.step: Optional[int] = None 55 | self.header: Optional[list] = None 56 | self.filter_config: dict = filter_config if filter_config is not None else {} 57 | self.domain: Domain = Domain() 58 | 59 | self.data = None 60 | 61 | def _convert_dtype(self, x: str): 62 | # Define type conversion in a separate method for clarity 63 | int_types = ["id", "type", "ix", "iy", "iz"] 64 | float_types = ["xu", "yu", "zu", "xs", "ys", "zs", "x", "y", "z", "vx", "vy", "vz", "fx", "fy", "fz", "q", "charge"] 65 | 66 | if x in int_types: 67 | return (x, np.int32) 68 | elif x in float_types: 69 | return (x, np.float32) 70 | else: 71 | raise ValueError(f"Unknown type {x}") 72 | 73 | def _filter_data(self): 74 | # Implement data filtering in a separate method for clarity 75 | if "field" in self.filter_config: 76 | self.data = self.data[self.filter_config["field"]] 77 | if "type" in self.filter_config: 78 | self.data = self.data[np.isin(self.data["type"], self.filter_config["type"])] 79 | 80 | def parse(self, data: list): 81 | self.step = int(data[1]) 82 | self.natoms = int(data[3]) 83 | 84 | self.domain.xlo, self.domain.xhi = map(float, data[5].split()) 85 | self.domain.ylo, self.domain.yhi = map(float, data[6].split()) 86 | self.domain.zlo, self.domain.zhi = map(float, data[7].split()) 87 | 88 | # Filter by frame_interval if specified 89 | if self.filter_config.get("frame_interval", 0) and self.step % self.filter_config["frame_interval"] != 0: 90 | return 91 | 92 | # Process header and set dynamic properties 93 | self.header = data[8].strip("\n").split()[2:] 94 | for prop_name in self.header: 95 | def make_property(name): 96 | return property(lambda self: self.data[name]) 97 | setattr(self.__class__, prop_name, make_property(prop_name)) 98 | 99 | my_dtype = [self._convert_dtype(x) for x in self.header] 100 | self.data = np.genfromtxt(data, skip_header=9, dtype=my_dtype) 101 | self.data = np.sort(self.data, order="id") 102 | 103 | self._filter_data() 104 | 105 | class LammpsDump: 106 | def __init__(self, file_path: str, filter_config: Optional[dict] = None): 107 | """ 108 | Initialize a LammpsDump object to parse LAMMPS dump files. 109 | 110 | :param file_path: Path to the LAMMPS dump file. 111 | :param filter_config: Optional dictionary for filtering data during parsing. 112 | """ 113 | self.file_path = file_path 114 | self.filter_config = filter_config if filter_config is not None else {} 115 | self.frames = [] 116 | self.domain = Domain() 117 | 118 | with open(file_path, 'r') as file: 119 | self.data = file.readlines() 120 | 121 | self.length = len(self.data) 122 | if self.length < 10: 123 | raise ValueError("File too short to contain valid data.") 124 | 125 | self.natoms = int(self._line(3)) 126 | self.nframes = self.length // (self.natoms + 9) 127 | self.domain.xlo, self.domain.xhi = map(float, self._line(5).split()) 128 | self.domain.ylo, self.domain.yhi = map(float, self._line(6).split()) 129 | self.domain.zlo, self.domain.zhi = map(float, self._line(7).split()) 130 | 131 | def _line(self, n: int): 132 | return self.data[n].strip("\n") 133 | 134 | def series_parse(self): 135 | """ 136 | Parse all frames in the series. 
137 | """ 138 | self.frames.clear() 139 | if not self.natoms or not self.nframes: 140 | return 141 | 142 | for i in range(self.nframes): 143 | frame_data = self.data[i * (self.natoms + 9): (i + 1) * (self.natoms + 9)] 144 | frame = Frame(self.filter_config) 145 | frame.parse(frame_data) 146 | if frame.data is not None: 147 | self.frames.append(frame) 148 | 149 | 150 | def get_edge(pos: torch.Tensor, kd_tree: KDTree, cutoff: float, boxsize_tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 151 | pairs = kd_tree.query_pairs(r=cutoff, output_type='ndarray') 152 | edge_index = torch.tensor(pairs.T, dtype=torch.long) 153 | # Concat inverse direction edge index. 154 | edge_index = torch.cat([edge_index, edge_index.flip([0])], dim=1) 155 | row, col = edge_index[0], edge_index[1] 156 | edge_cell_shift = pos[row] - pos[col] 157 | edge_cell_shift = torch.remainder(edge_cell_shift + 0.5*boxsize_tensor, boxsize_tensor) - 0.5 * boxsize_tensor 158 | return edge_index, edge_cell_shift 159 | 160 | 161 | def parse_in_data_file(in_data_file: str) -> torch.Tensor: 162 | """Parse molecular data from the input file.""" 163 | molecular = [] 164 | with open(in_data_file) as file: 165 | start_flag = False 166 | for line in file: 167 | if line.startswith('Atoms'): 168 | start_flag = True 169 | continue 170 | if not start_flag or not line.strip(): 171 | continue 172 | molecular.append(int(line.split()[1])) 173 | return torch.tensor(molecular) 174 | 175 | 176 | def parse_in_lammps_file(in_lammps_file: str) -> np.array: 177 | """Extract type list from the LAMMPS input file.""" 178 | type_list = [] 179 | with open(in_lammps_file) as file: 180 | for line in file: 181 | if line.startswith("pair_coeff"): 182 | strs = line.split() 183 | type_list = [atom_mapper[symbol] for symbol in strs[2:]] 184 | break 185 | if not type_list: 186 | logging.warning("No type list found.") 187 | return np.array(type_list) 188 | 189 | 190 | def fetch_pressure_volume(log_lammps: str): 191 | log_info = {} 192 | with open(log_lammps, 'r') as fin: 193 | lines = fin.readlines() 194 | start_flag = False 195 | for line in lines: 196 | if line.startswith(' G vector'): 197 | log_info['g_ewald'] = torch.tensor(float(line.split()[-1])) 198 | if line.startswith('Step Temp Press'): 199 | start_flag = True 200 | continue 201 | if not start_flag: 202 | continue 203 | if line.startswith("Loop time of"): 204 | start_flag = False 205 | continue 206 | 207 | strs = line.split() 208 | log_info[int(strs[0])] = { 209 | "pressure": torch.tensor(float(strs[2])), 210 | "volume": torch.tensor(float(strs[3])) 211 | } 212 | 213 | return log_info 214 | 215 | def prepare_frame_data(frame: Frame, type_array: np.array, molecular: torch.Tensor, 216 | nn_cutoff: float, coul_cutoff: float, disp_cutoff: float): 217 | data = {} 218 | domain_box = np.array(frame.domain.box) 219 | pos = np.stack([frame.data['xu'], frame.data['yu'], frame.data['zu']], axis=1) 220 | pos = pos - domain_box[:, 0] 221 | atom_type = type_array[frame.data['type'] - 1] 222 | box_size = domain_box[:, 1] - domain_box[:, 0] 223 | 224 | # # Convert to tensors 225 | boxsize_tensor = torch.tensor(box_size) 226 | pos = np.remainder(pos, box_size) 227 | kd_tree = KDTree(pos, boxsize=boxsize_tensor) 228 | pos = torch.tensor(pos) 229 | data['pos'] = pos 230 | data['atom_types'] = torch.tensor(atom_type, dtype=torch.long) 231 | 232 | edge_index, edge_cell_shift = get_edge(pos, kd_tree, nn_cutoff, boxsize_tensor=boxsize_tensor) 233 | data['edge_index'], data['edge_cell_shift'] = edge_index, edge_cell_shift 
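    # Mark intermolecular edges: True where the two endpoints belong to different molecules (stored below as edge_outer_mask).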
234 | outer = molecular[data["edge_index"][0]] != molecular[data["edge_index"][1]] 235 | 236 | data['molecular'] = torch.LongTensor(molecular) 237 | 238 | data['edge_outer_mask'] = outer.to(torch.float32) 239 | 240 | if nn_cutoff == coul_cutoff: 241 | data['coul_edge_index'], data['coul_edge_cell_shift'] = edge_index, edge_cell_shift 242 | else: 243 | coul_edge_index, coul_edge_cell_shift = get_edge(pos, kd_tree, coul_cutoff, boxsize_tensor=boxsize_tensor) 244 | data['coul_edge_index'], data['coul_edge_cell_shift'] = coul_edge_index, coul_edge_cell_shift 245 | disp_edge_index, disp_edge_cell_shift = get_edge(pos, kd_tree, disp_cutoff, boxsize_tensor=boxsize_tensor) 246 | data['disp_edge_index'], data['disp_edge_cell_shift'] = disp_edge_index, disp_edge_cell_shift 247 | data['cell'] = torch.tensor([[box_size[0], 0., 0.], [0., box_size[1], 0.], [0., 0., box_size[2]]], dtype=torch.float32) 248 | 249 | return data 250 | 251 | 252 | def write_data( 253 | job_folder: str, 254 | output_folder: str, 255 | mixture_name: str, 256 | nn_cutoff: float = 5.0, 257 | coul_cutoff: float = 5.0, 258 | disp_cutoff: float = 10.0, 259 | interval: int = 10, 260 | log_file: str = "log.lammps" 261 | ) -> None: 262 | # interval: unit ps 263 | # cutoff: unit A 264 | 265 | in_data_file = os.path.join(job_folder, 'in.data') 266 | in_lammps_file = os.path.join(job_folder, 'in.lammps') 267 | out_lammps_log = os.path.join(job_folder, log_file) 268 | out_lammps_traj = os.path.join(job_folder, 'dump_nvt.lammpstrj') 269 | 270 | os.makedirs(output_folder, exist_ok=True) 271 | 272 | molecular = parse_in_data_file(in_data_file) 273 | 274 | type_array = parse_in_lammps_file(in_lammps_file) 275 | 276 | log_info = fetch_pressure_volume(out_lammps_log) 277 | 278 | filter_config = {"field": ["id", "type", "xu", "yu", "zu"], "frame_interval": int(1000 * interval)} 279 | lammps_dump = LammpsDump(file_path=out_lammps_traj, filter_config=filter_config) 280 | lammps_dump.series_parse() 281 | 282 | if not lammps_dump.frames: 283 | return 284 | 285 | for frame in lammps_dump.frames: 286 | data = prepare_frame_data(frame, type_array, molecular, nn_cutoff, coul_cutoff, disp_cutoff) 287 | data["g_ewald"] = log_info["g_ewald"] 288 | if frame.step not in log_info: 289 | continue 290 | data.update(log_info[frame.step]) 291 | 292 | output_data = {"inputs": data, "mixture_name": mixture_name} 293 | torch.save(output_data, os.path.join(output_folder, f'frame_{frame.step}.pt')) 294 | 295 | 296 | if __name__ == "__main__": 297 | # args for work folder and output folder 298 | parser = argparse.ArgumentParser() 299 | parser.add_argument('--job_folder', type=str, required=True) 300 | parser.add_argument('--output_folder', type=str, required=True) 301 | parser.add_argument('--mixture_name', type=str, required=True) 302 | parser.add_argument('--nn_cutoff', type=float, default=5.0) 303 | parser.add_argument('--coul_cutoff', type=float, default=5.0) 304 | parser.add_argument('--disp_cutoff', type=float, default=10.0) 305 | parser.add_argument('--interval', type=int, default=10) 306 | parser.add_argument('--log_file', type=str, default="log.lammps") 307 | args = parser.parse_args() 308 | 309 | write_data( 310 | job_folder=args.job_folder, 311 | output_folder=args.output_folder, 312 | mixture_name=args.mixture_name, 313 | nn_cutoff=args.nn_cutoff, 314 | coul_cutoff=args.coul_cutoff, 315 | disp_cutoff=args.disp_cutoff, 316 | interval=args.interval, 317 | log_file=args.log_file 318 | ) 
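As a side note on the neighbour-list construction above, here is a small, self-contained sketch (not from the repository) of how `get_edge` pairs a periodic `scipy.spatial.KDTree` query with a minimum-image fold of the displacements. The box, coordinates and cutoff below are made up for illustration.

```python
import torch
from scipy.spatial import KDTree

from utils.load_traj import get_edge

# Toy orthorhombic box of 10 x 10 x 10 Angstrom; positions must lie in [0, box).
box = torch.tensor([10.0, 10.0, 10.0])
pos = torch.tensor([[0.5, 0.5, 0.5],
                    [9.8, 0.5, 0.5],    # neighbour of atom 0 across the x boundary
                    [5.0, 5.0, 5.0]])   # isolated atom, outside the cutoff

# Periodic KDTree: boxsize enables wrap-around distance queries.
tree = KDTree(pos.numpy(), boxsize=box.numpy())
edge_index, edge_cell_shift = get_edge(pos, tree, cutoff=2.0, boxsize_tensor=box)

print(edge_index)       # tensor([[0, 1], [1, 0]]): the (0, 1) pair in both directions
print(edge_cell_shift)  # approx. [[0.7, 0, 0], [-0.7, 0, 0]]: minimum-image displacements
```

Note that the `torch.remainder` fold in `get_edge` yields the true minimum-image displacement only when every box edge is at least twice the cutoff.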
-------------------------------------------------------------------------------- /utils/log_helper.py: -------------------------------------------------------------------------------- 1 | # ----- BAMBOO: Bytedance AI Molecular Booster ----- 2 | # Copyright 2022-2024 Bytedance Ltd. and/or its affiliates 3 | 4 | # This program is free software; you can redistribute it and/or modify 5 | # it under the terms of the GNU General Public License as published by 6 | # the Free Software Foundation; either version 2 of the License, or 7 | # (at your option) any later version. 8 | 9 | # This program is distributed in the hope that it will be useful, 10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | # GNU General Public License for more details. 13 | 14 | # You should have received a copy of the GNU General Public License 15 | # along with this program; if not, write to the Free Software 16 | # Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA 17 | 18 | import logging 19 | import os 20 | from typing import Optional 21 | 22 | 23 | def create_logger(name, log_file: Optional[str] = None, screen: bool = True): 24 | logger = logging.getLogger(name) 25 | logger.setLevel(logging.INFO) 26 | logger.propagate = False # Prevents log messages from propagating to parent loggers 27 | 28 | # Clear handles if any. 29 | logger.handlers.clear() 30 | formatter = logging.Formatter("%(asctime)s - %(name)s:%(lineno)d - %(levelname)s - %(message)s") 31 | 32 | if log_file: 33 | os.makedirs(os.path.dirname(log_file), exist_ok=True) 34 | fh = logging.FileHandler(log_file) 35 | fh.setLevel(logging.INFO) 36 | fh.setFormatter(formatter) 37 | logger.addHandler(fh) 38 | 39 | if screen: 40 | ch = logging.StreamHandler() 41 | ch.setLevel(logging.INFO) 42 | ch.setFormatter(formatter) 43 | logger.addHandler(ch) 44 | 45 | return logger 46 | -------------------------------------------------------------------------------- /utils/path.py: -------------------------------------------------------------------------------- 1 | # ----- BAMBOO: Bytedance AI Molecular Booster ----- 2 | # Copyright 2022-2024 Bytedance Ltd. and/or its affiliates 3 | 4 | # This program is free software; you can redistribute it and/or modify 5 | # it under the terms of the GNU General Public License as published by 6 | # the Free Software Foundation; either version 2 of the License, or 7 | # (at your option) any later version. 8 | 9 | # This program is distributed in the hope that it will be useful, 10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | # GNU General Public License for more details. 
13 | 14 | # You should have received a copy of the GNU General Public License 15 | # along with this program; if not, write to the Free Software 16 | # Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA 17 | 18 | import os 19 | 20 | # Determine the directory two levels up from the current file's location 21 | CODE_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 22 | 23 | TRAIN_PATH = os.path.join(CODE_PATH, "train") 24 | 25 | ENSEMBLE_PATH = os.path.join(CODE_PATH, "ensemble") 26 | 27 | ALIGNMENT_PATH = os.path.join(CODE_PATH, "alignment") 28 | 29 | DATA_PATH = os.path.join(CODE_PATH, "data") 30 | 31 | 32 | -------------------------------------------------------------------------------- /utils/rejit.py: -------------------------------------------------------------------------------- 1 | # ----- BAMBOO: Bytedance AI Molecular Booster ----- 2 | # Copyright 2022-2024 Bytedance Ltd. and/or its affiliates 3 | 4 | # This program is free software; you can redistribute it and/or modify 5 | # it under the terms of the GNU General Public License as published by 6 | # the Free Software Foundation; either version 2 of the License, or 7 | # (at your option) any later version. 8 | 9 | # This program is distributed in the hope that it will be useful, 10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | # GNU General Public License for more details. 13 | 14 | # You should have received a copy of the GNU General Public License 15 | # along with this program; if not, write to the Free Software 16 | # Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA 17 | 18 | import argparse 19 | 20 | import torch 21 | import torch.nn as nn 22 | 23 | from models.bamboo_get import BambooGET 24 | 25 | 26 | def convert(checkpoint, device=torch.device('cuda')): 27 | if isinstance(checkpoint, str): 28 | model = torch.jit.load(checkpoint) 29 | elif isinstance(checkpoint, torch.jit.RecursiveScriptModule): 30 | model = checkpoint 31 | else: 32 | raise ValueError("Input must be a string or torch.jit.RecursiveScriptModule.") 33 | 34 | act_fn_map = { 35 | 'ELU': nn.ELU(), 36 | 'CELU': nn.CELU(), 37 | 'GELU': nn.GELU(), 38 | 'SiLU': nn.SiLU(), 39 | 'Mish': nn.Mish(), 40 | 'Softplus': nn.Softplus() 41 | } 42 | 43 | nn_params_act_fn_name: str = list(model.charge_mlp.children())[1].original_name 44 | gnn_params_act_fn_name: str = model.act_fn.original_name 45 | nn_params = { 46 | 'dim': model.dim, 47 | 'num_rbf': model.num_rbf, 48 | 'rcut': model.rcut, 49 | 'charge_ub': model.charge_ub, 50 | 'act_fn': act_fn_map[nn_params_act_fn_name], 51 | 'charge_mlp_layers': model.charge_mlp_layers, 52 | 'energy_mlp_layers': model.energy_mlp_layers, 53 | } 54 | 55 | gnn_params = { 56 | 'n_layers': model.n_layers, 57 | 'num_heads': model.num_heads, 58 | 'act_fn': act_fn_map[gnn_params_act_fn_name] 59 | } 60 | 61 | origin_model = BambooGET( 62 | device=model.device, 63 | coul_disp_params=model.coul_disp_params, 64 | nn_params=nn_params, 65 | gnn_params=gnn_params 66 | ) 67 | 68 | missing_keys, unexpected_keys = origin_model.load_state_dict(model.state_dict(), strict=False) 69 | print(f'missing keys in checkpoint: {missing_keys}') 70 | print(f'unexpected keys in checkpoint: {unexpected_keys}') 71 | origin_model = origin_model.to(device) 72 | return origin_model 73 | 74 | 75 | def main(): 76 | arparser = argparse.ArgumentParser() 77 | arparser.add_argument('--checkpoint', type=str, required=True, 
help='Path to the model checkpoint.') 78 | arparser.add_argument('--destination', type=str, required=True, help='Path to save the converted model.') 79 | arparser.add_argument('--no_cuda', action='store_true', help='Do not use GPU for the conversion.') 80 | args = arparser.parse_args() 81 | 82 | if args.no_cuda: 83 | device = torch.device('cpu') 84 | else: 85 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 86 | 87 | print(f'loading model from {args.checkpoint}') 88 | 89 | model = convert(args.checkpoint, device) 90 | 91 | model_jit = torch.jit.script(model) 92 | model_jit.save(args.destination) 93 | print(f'model saved at {args.destination}') 94 | 95 | 96 | if __name__ == "__main__": 97 | main() 98 | --------------------------------------------------------------------------------
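Finally, a hedged usage sketch (not part of the repository) of `convert` from utils/rejit.py: it rebuilds an editable BambooGET module from a TorchScript checkpoint so the model can be modified and scripted again. The paths are hypothetical, and running on CPU assumes the checkpoint can be deserialized without a GPU.

```python
import torch

from utils.rejit import convert

# Hypothetical checkpoint path, following the naming used by train.py.
checkpoint_path = "train/my_job/checkpoints/epoch_10_loss_123.pt"

# Rebuild a plain BambooGET module from the scripted checkpoint.
model = convert(checkpoint_path, device=torch.device("cpu"))

# ...inspect or fine-tune `model` here, then script and save it again.
rescripted = torch.jit.script(model)
rescripted.save("rejit_model.pt")
```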