├── .gitignore
├── APS 2023 GNN tutorial.pdf
├── LICENSE
├── README.md
├── alignn
│   ├── JARVIS-APS.pdf
│   ├── Training_ALIGNN_model_example.ipynb
│   └── __init__.py
└── intro_lecture
    ├── apsmarch23_gnn_tutorial.ipynb
    └── intro_to_gnns_apsmarch23.pdf
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/APS 2023 GNN tutorial.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/williamratcliff/GNN-tutorial-APS-March-2023/d7a8b65083bb80595cf14d9404080b3f3e50b286/APS 2023 GNN tutorial.pdf
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # GNN-tutorial-APS-March-2023
2 | These are the slides associated with the GNN tutorial at the APS March Meeting
3 | Info: https://march.aps.org/events/tutorial-graph-neural-networks
4 | Join GDS on Slack at https://app.slack.com/client/TJR0E256K/CJGQP2CBT
5 | March 5, 1:30 – 5:30 p.m. PST
6 | - JAX-MD: https://github.com/jax-md/jax-md/tree/main/notebooks/tutorial
7 | - JARVIS-ALIGNN: https://github.com/JARVIS-Materials-Design/jarvis-tools-notebooks
8 |
9 |
10 |
11 | # Schedule
12 | - 1:30-1:35 Welcome (William)
13 | - 1:35-2:30 Intro to GNN (Savannah)
14 | - 2:45-3:00 Break
15 | - 3:00-4:00 JAX-MD (Dogus)
16 | - 4:00-4:05 Break
17 | - 4:05-5:05 Line Graph (Kamal)
18 | - 5:05-5:10 Break
19 | - 5:10-5:40 Other libraries and summaries (Dogus)
20 |
--------------------------------------------------------------------------------
/alignn/JARVIS-APS.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/williamratcliff/GNN-tutorial-APS-March-2023/d7a8b65083bb80595cf14d9404080b3f3e50b286/alignn/JARVIS-APS.pdf
--------------------------------------------------------------------------------
/alignn/Training_ALIGNN_model_example.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "name": "Training_ALIGNN_model_example.ipynb",
7 | "provenance": [],
8 | "include_colab_link": true
9 | },
10 | "kernelspec": {
11 | "display_name": "Python 3",
12 | "name": "python3"
13 | },
14 | "language_info": {
15 | "name": "python"
16 | }
17 | },
18 | "cells": [
19 | {
20 | "cell_type": "markdown",
21 | "metadata": {
22 | "id": "view-in-github",
23 | "colab_type": "text"
24 | },
25 | "source": [
26 | ""
27 | ]
28 | },
29 | {
30 | "cell_type": "markdown",
31 | "metadata": {
32 | "id": "OUZGR6D82ij-"
33 | },
34 | "source": [
35 | "# Table of contents\n",
36 | "\n",
37 | "1. Installing [ALIGNN](https://github.com/usnistgov/alignn)\n",
38 | "2. Example training for regression on 50 materials,\n",
39 | "3. Using pre-trained models to make fast predictions\n",
40 | "4. Using ALIGNN-FF model to predict the unrelaxed energy (fast), optimized structure and energy, and EV curve\n",
41 | "5. Train ALIGNN-FF on a new dataset\n",
42 | "6. Training [JARVIS-DFT](https://jarvis.nist.gov/jarvisdft) 2D exfoliation energy model \n",
43 | "7. Training [QM9](http://quantum-machine.org/datasets/) U0 model"
44 | ]
45 | },
46 | {
47 | "cell_type": "code",
48 | "metadata": {
49 | "id": "WFrl_N-S1Bxk",
50 | "outputId": "350ce494-1992-4164-df8b-357925d1f408",
51 | "colab": {
52 | "base_uri": "https://localhost:8080/"
53 | }
54 | },
55 | "source": [
56 | "!pip install alignn"
57 | ],
58 | "execution_count": null,
59 | "outputs": [
60 | {
61 | "output_type": "stream",
62 | "name": "stdout",
63 | "text": [
64 | "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
65 | "Collecting alignn\n",
66 | " Downloading alignn-2023.1.10-py2.py3-none-any.whl (15.1 MB)\n",
67 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m15.1/15.1 MB\u001b[0m \u001b[31m36.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
68 | "\u001b[?25hRequirement already satisfied: scipy>=1.6.1 in /usr/local/lib/python3.8/dist-packages (from alignn) (1.7.3)\n",
69 | "Collecting pydocstyle>=6.0.0\n",
70 | " Downloading pydocstyle-6.3.0-py3-none-any.whl (38 kB)\n",
71 | "Collecting dgl>=0.6.0\n",
72 | " Downloading dgl-1.0.0-cp38-cp38-manylinux1_x86_64.whl (5.4 MB)\n",
73 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m5.4/5.4 MB\u001b[0m \u001b[31m71.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
74 | "\u001b[?25hRequirement already satisfied: pandas>=1.2.3 in /usr/local/lib/python3.8/dist-packages (from alignn) (1.3.5)\n",
75 | "Requirement already satisfied: tqdm>=4.60.0 in /usr/local/lib/python3.8/dist-packages (from alignn) (4.64.1)\n",
76 | "Collecting pyparsing<3,>=2.2.1\n",
77 | " Downloading pyparsing-2.4.7-py2.py3-none-any.whl (67 kB)\n",
78 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m67.8/67.8 KB\u001b[0m \u001b[31m7.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
79 | "\u001b[?25hCollecting matplotlib>=3.4.1\n",
80 | " Downloading matplotlib-3.6.3-cp38-cp38-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (9.4 MB)\n",
81 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m9.4/9.4 MB\u001b[0m \u001b[31m71.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
82 | "\u001b[?25hCollecting flake8>=3.9.1\n",
83 | " Downloading flake8-6.0.0-py2.py3-none-any.whl (57 kB)\n",
84 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m57.8/57.8 KB\u001b[0m \u001b[31m7.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
85 | "\u001b[?25hRequirement already satisfied: pydantic>=1.8.1 in /usr/local/lib/python3.8/dist-packages (from alignn) (1.10.4)\n",
86 | "Requirement already satisfied: scikit-learn>=0.22.2 in /usr/local/lib/python3.8/dist-packages (from alignn) (1.0.2)\n",
87 | "Requirement already satisfied: numpy>=1.19.5 in /usr/local/lib/python3.8/dist-packages (from alignn) (1.21.6)\n",
88 | "Collecting torch==1.12.0\n",
89 | " Downloading torch-1.12.0-cp38-cp38-manylinux1_x86_64.whl (776.3 MB)\n",
90 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m776.3/776.3 MB\u001b[0m \u001b[31m1.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
91 | "\u001b[?25hCollecting ase\n",
92 | " Downloading ase-3.22.1-py3-none-any.whl (2.2 MB)\n",
93 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2.2/2.2 MB\u001b[0m \u001b[31m73.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
94 | "\u001b[?25hCollecting pytorch-ignite==0.5.0.dev20221024\n",
95 | " Downloading pytorch_ignite-0.5.0.dev20221024-py3-none-any.whl (263 kB)\n",
96 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m263.6/263.6 KB\u001b[0m \u001b[31m20.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
97 | "\u001b[?25hCollecting pycodestyle>=2.7.0\n",
98 | " Downloading pycodestyle-2.10.0-py2.py3-none-any.whl (41 kB)\n",
99 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m41.3/41.3 KB\u001b[0m \u001b[31m4.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
100 | "\u001b[?25hCollecting jarvis-tools>=2021.07.19\n",
101 | " Downloading jarvis_tools-2023.1.8-py2.py3-none-any.whl (973 kB)\n",
102 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m973.3/973.3 KB\u001b[0m \u001b[31m33.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
103 | "\u001b[?25hRequirement already satisfied: packaging in /usr/local/lib/python3.8/dist-packages (from pytorch-ignite==0.5.0.dev20221024->alignn) (23.0)\n",
104 | "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.8/dist-packages (from torch==1.12.0->alignn) (4.4.0)\n",
105 | "Requirement already satisfied: requests>=2.19.0 in /usr/local/lib/python3.8/dist-packages (from dgl>=0.6.0->alignn) (2.25.1)\n",
106 | "Collecting psutil>=5.8.0\n",
107 | " Downloading psutil-5.9.4-cp36-abi3-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (280 kB)\n",
108 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m280.2/280.2 KB\u001b[0m \u001b[31m23.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
109 | "\u001b[?25hRequirement already satisfied: networkx>=2.1 in /usr/local/lib/python3.8/dist-packages (from dgl>=0.6.0->alignn) (3.0)\n",
110 | "Collecting mccabe<0.8.0,>=0.7.0\n",
111 | " Downloading mccabe-0.7.0-py2.py3-none-any.whl (7.3 kB)\n",
112 | "Collecting pyflakes<3.1.0,>=3.0.0\n",
113 | " Downloading pyflakes-3.0.1-py2.py3-none-any.whl (62 kB)\n",
114 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m62.8/62.8 KB\u001b[0m \u001b[31m4.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
115 | "\u001b[?25hRequirement already satisfied: joblib>=0.14.1 in /usr/local/lib/python3.8/dist-packages (from jarvis-tools>=2021.07.19->alignn) (1.2.0)\n",
116 | "Requirement already satisfied: toolz>=0.9.0 in /usr/local/lib/python3.8/dist-packages (from jarvis-tools>=2021.07.19->alignn) (0.12.0)\n",
117 | "Collecting spglib>=1.14.1\n",
118 | " Downloading spglib-2.0.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (515 kB)\n",
119 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m515.1/515.1 KB\u001b[0m \u001b[31m34.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
120 | "\u001b[?25hCollecting xmltodict>=0.11.0\n",
121 | " Downloading xmltodict-0.13.0-py2.py3-none-any.whl (10.0 kB)\n",
122 | "Collecting fonttools>=4.22.0\n",
123 | " Downloading fonttools-4.38.0-py3-none-any.whl (965 kB)\n",
124 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m965.4/965.4 KB\u001b[0m \u001b[31m47.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
125 | "\u001b[?25hCollecting contourpy>=1.0.1\n",
126 | " Downloading contourpy-1.0.7-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (300 kB)\n",
127 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m300.0/300.0 KB\u001b[0m \u001b[31m28.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
128 | "\u001b[?25hRequirement already satisfied: pillow>=6.2.0 in /usr/local/lib/python3.8/dist-packages (from matplotlib>=3.4.1->alignn) (7.1.2)\n",
129 | "Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.8/dist-packages (from matplotlib>=3.4.1->alignn) (1.4.4)\n",
130 | "Requirement already satisfied: python-dateutil>=2.7 in /usr/local/lib/python3.8/dist-packages (from matplotlib>=3.4.1->alignn) (2.8.2)\n",
131 | "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.8/dist-packages (from matplotlib>=3.4.1->alignn) (0.11.0)\n",
132 | "Requirement already satisfied: pytz>=2017.3 in /usr/local/lib/python3.8/dist-packages (from pandas>=1.2.3->alignn) (2022.7.1)\n",
133 | "Requirement already satisfied: snowballstemmer>=2.2.0 in /usr/local/lib/python3.8/dist-packages (from pydocstyle>=6.0.0->alignn) (2.2.0)\n",
134 | "Requirement already satisfied: threadpoolctl>=2.0.0 in /usr/local/lib/python3.8/dist-packages (from scikit-learn>=0.22.2->alignn) (3.1.0)\n",
135 | "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.8/dist-packages (from python-dateutil>=2.7->matplotlib>=3.4.1->alignn) (1.15.0)\n",
136 | "Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.8/dist-packages (from requests>=2.19.0->dgl>=0.6.0->alignn) (1.24.3)\n",
137 | "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.8/dist-packages (from requests>=2.19.0->dgl>=0.6.0->alignn) (2.10)\n",
138 | "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.8/dist-packages (from requests>=2.19.0->dgl>=0.6.0->alignn) (2022.12.7)\n",
139 | "Requirement already satisfied: chardet<5,>=3.0.2 in /usr/local/lib/python3.8/dist-packages (from requests>=2.19.0->dgl>=0.6.0->alignn) (4.0.0)\n",
140 | "Installing collected packages: xmltodict, torch, spglib, pyparsing, pyflakes, pydocstyle, pycodestyle, psutil, mccabe, fonttools, contourpy, pytorch-ignite, matplotlib, flake8, dgl, jarvis-tools, ase, alignn\n",
141 | " Attempting uninstall: torch\n",
142 | " Found existing installation: torch 1.13.1+cu116\n",
143 | " Uninstalling torch-1.13.1+cu116:\n",
144 | " Successfully uninstalled torch-1.13.1+cu116\n",
145 | " Attempting uninstall: pyparsing\n",
146 | " Found existing installation: pyparsing 3.0.9\n",
147 | " Uninstalling pyparsing-3.0.9:\n",
148 | " Successfully uninstalled pyparsing-3.0.9\n",
149 | " Attempting uninstall: psutil\n",
150 | " Found existing installation: psutil 5.4.8\n",
151 | " Uninstalling psutil-5.4.8:\n",
152 | " Successfully uninstalled psutil-5.4.8\n",
153 | " Attempting uninstall: matplotlib\n",
154 | " Found existing installation: matplotlib 3.2.2\n",
155 | " Uninstalling matplotlib-3.2.2:\n",
156 | " Successfully uninstalled matplotlib-3.2.2\n",
157 | "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
158 | "torchvision 0.14.1+cu116 requires torch==1.13.1, but you have torch 1.12.0 which is incompatible.\n",
159 | "torchtext 0.14.1 requires torch==1.13.1, but you have torch 1.12.0 which is incompatible.\n",
160 | "torchaudio 0.13.1+cu116 requires torch==1.13.1, but you have torch 1.12.0 which is incompatible.\u001b[0m\u001b[31m\n",
161 | "\u001b[0mSuccessfully installed alignn-2023.1.10 ase-3.22.1 contourpy-1.0.7 dgl-1.0.0 flake8-6.0.0 fonttools-4.38.0 jarvis-tools-2023.1.8 matplotlib-3.6.3 mccabe-0.7.0 psutil-5.9.4 pycodestyle-2.10.0 pydocstyle-6.3.0 pyflakes-3.0.1 pyparsing-2.4.7 pytorch-ignite-0.5.0.dev20221024 spglib-2.0.2 torch-1.12.0 xmltodict-0.13.0\n"
162 | ]
163 | }
164 | ]
165 | },
166 | {
167 | "cell_type": "code",
168 | "metadata": {
169 | "colab": {
170 | "base_uri": "https://localhost:8080/"
171 | },
172 | "id": "JyyE-cHL2iOn",
173 | "outputId": "a0439df3-23f0-448d-894b-27e731f576e0"
174 | },
175 | "source": [
176 | "import os\n",
177 | "!pwd\n",
178 | "os.chdir('/content')\n",
179 | "# Clone ALIGNN repo to get example folder\n",
180 | "if not os.path.exists('alignn'):\n",
181 | " !git clone https://github.com/usnistgov/alignn.git\n",
182 | "\n",
183 | "os.chdir('alignn')\n",
184 | "# Install using setup.py in case pip didn't work\n",
185 | "# !python setup.py develop\n",
186 | "\n",
187 | "#!pip install dgl-cu111 # Colab has cuda 11.1"
188 | ],
189 | "execution_count": null,
190 | "outputs": [
191 | {
192 | "output_type": "stream",
193 | "name": "stdout",
194 | "text": [
195 | "/content\n",
196 | "Cloning into 'alignn'...\n",
197 | "remote: Enumerating objects: 3330, done.\u001b[K\n",
198 | "remote: Counting objects: 100% (924/924), done.\u001b[K\n",
199 | "remote: Compressing objects: 100% (256/256), done.\u001b[K\n",
200 | "remote: Total 3330 (delta 711), reused 780 (delta 643), pack-reused 2406\u001b[K\n",
201 | "Receiving objects: 100% (3330/3330), 32.70 MiB | 7.98 MiB/s, done.\n",
202 | "Resolving deltas: 100% (1887/1887), done.\n"
203 | ]
204 | }
205 | ]
206 | },
207 | {
208 | "cell_type": "markdown",
209 | "metadata": {
210 | "id": "wsJg4A_s2umV"
211 | },
212 | "source": [
213 | "Example folder with id_prop.csv and POSCAR files."
214 | ]
215 | },
216 | {
217 | "cell_type": "code",
218 | "metadata": {
219 | "id": "cy1tmx3V2uC7",
220 | "colab": {
221 | "base_uri": "https://localhost:8080/"
222 | },
223 | "outputId": "14a4a5f0-34e8-446b-8856-0acd82680b3f"
224 | },
225 | "source": [
226 | "!ls \"alignn/examples/sample_data\""
227 | ],
228 | "execution_count": null,
229 | "outputs": [
230 | {
231 | "output_type": "stream",
232 | "name": "stdout",
233 | "text": [
234 | "config_example.json\t POSCAR-JVASP-64045.vasp POSCAR-JVASP-86097.vasp\n",
235 | "id_prop.csv\t\t POSCAR-JVASP-64240.vasp POSCAR-JVASP-86205.vasp\n",
236 | "POSCAR-JVASP-107772.vasp POSCAR-JVASP-64377.vasp POSCAR-JVASP-86436.vasp\n",
237 | "POSCAR-JVASP-10.vasp\t POSCAR-JVASP-64584.vasp POSCAR-JVASP-86726.vasp\n",
238 | "POSCAR-JVASP-13526.vasp POSCAR-JVASP-64664.vasp POSCAR-JVASP-86968.vasp\n",
239 | "POSCAR-JVASP-1372.vasp\t POSCAR-JVASP-64719.vasp POSCAR-JVASP-89025.vasp\n",
240 | "POSCAR-JVASP-14014.vasp POSCAR-JVASP-64906.vasp POSCAR-JVASP-89265.vasp\n",
241 | "POSCAR-JVASP-14441.vasp POSCAR-JVASP-65062.vasp POSCAR-JVASP-90228.vasp\n",
242 | "POSCAR-JVASP-14873.vasp POSCAR-JVASP-65101.vasp POSCAR-JVASP-90532.vasp\n",
243 | "POSCAR-JVASP-15345.vasp POSCAR-JVASP-655.vasp POSCAR-JVASP-90856.vasp\n",
244 | "POSCAR-JVASP-1996.vasp\t POSCAR-JVASP-676.vasp POSCAR-JVASP-97378.vasp\n",
245 | "POSCAR-JVASP-21210.vasp POSCAR-JVASP-76308.vasp POSCAR-JVASP-97499.vasp\n",
246 | "POSCAR-JVASP-22556.vasp POSCAR-JVASP-76309.vasp POSCAR-JVASP-97570.vasp\n",
247 | "POSCAR-JVASP-27901.vasp POSCAR-JVASP-76312.vasp POSCAR-JVASP-97677.vasp\n",
248 | "POSCAR-JVASP-28397.vasp POSCAR-JVASP-76313.vasp POSCAR-JVASP-97799.vasp\n",
249 | "POSCAR-JVASP-28565.vasp POSCAR-JVASP-76318.vasp POSCAR-JVASP-97915.vasp\n",
250 | "POSCAR-JVASP-28634.vasp POSCAR-JVASP-76515.vasp POSCAR-JVASP-97984.vasp\n",
251 | "POSCAR-JVASP-28704.vasp POSCAR-JVASP-76516.vasp POSCAR-JVASP-98167.vasp\n",
252 | "POSCAR-JVASP-42300.vasp POSCAR-JVASP-76525.vasp POSCAR-JVASP-98224.vasp\n",
253 | "POSCAR-JVASP-48166.vasp POSCAR-JVASP-76528.vasp POSCAR-JVASP-98225.vasp\n",
254 | "POSCAR-JVASP-50332.vasp POSCAR-JVASP-76536.vasp POSCAR-JVASP-98284.vasp\n",
255 | "POSCAR-JVASP-60596.vasp POSCAR-JVASP-76548.vasp POSCAR-JVASP-98550.vasp\n",
256 | "POSCAR-JVASP-60702.vasp POSCAR-JVASP-76549.vasp scripts\n",
257 | "POSCAR-JVASP-63912.vasp POSCAR-JVASP-76562.vasp\n",
258 | "POSCAR-JVASP-64003.vasp POSCAR-JVASP-76567.vasp\n"
259 | ]
260 | }
261 | ]
262 | },
263 | {
264 | "cell_type": "markdown",
265 | "metadata": {
266 | "id": "jUNiKBBV211E"
267 | },
268 | "source": [
269 | "# 50 materials and their bandgap data generated with the script [generate_sample_data_reg.py](https://github.com/usnistgov/alignn/blob/main/alignn/examples/sample_data/scripts/generate_sample_data_reg.py)"
270 | ]
271 | },
272 | {
273 | "cell_type": "markdown",
274 | "metadata": {
275 | "id": "FbzuGCA332yS"
276 | },
277 | "source": [
278 | "# Train a model for 3 epochs and batch size of 2. Other parameters are provided in `config_example.json` file. For an involved training, use higher batch size such as 16 and epochs such as 300."
279 | ]
280 | },
281 | {
282 | "cell_type": "markdown",
283 | "metadata": {
284 | "id": "HNHla4FDKRre"
285 | },
286 | "source": [
287 | "The command-line script `train_folder.py` is used below."
288 | ]
289 | },
290 | {
291 | "cell_type": "code",
292 | "metadata": {
293 | "id": "l5JkSMwx2cfy",
294 | "colab": {
295 | "base_uri": "https://localhost:8080/"
296 | },
297 | "outputId": "c2b83959-0b14-4d0b-918f-dab590cca53d"
298 | },
299 | "source": [
300 | "import time\n",
301 | "t1=time.time()\n",
302 | "!train_folder.py --root_dir \"alignn/examples/sample_data\" --epochs 3 --batch_size 2 --config \"alignn/examples/sample_data/config_example.json\" --output_dir=temp\n",
303 | "t2=time.time()\n",
304 | "print ('Time in s',t2-t1)"
305 | ],
306 | "execution_count": null,
307 | "outputs": [
308 | {
309 | "output_type": "stream",
310 | "name": "stdout",
311 | "text": [
312 | "DGL backend not selected or invalid. Assuming PyTorch for now.\n",
313 | "Setting the default backend to \"pytorch\". You can change it in the ~/.dgl/config.json file or export the DGLBACKEND environment variable. Valid options are: pytorch, mxnet, tensorflow (all lowercase)\n",
314 | "MAX val: 6.149\n",
315 | "MIN val: 0.0\n",
316 | "MAD: 1.0520696\n",
317 | "Baseline MAE: 0.7102749999999998\n",
318 | "data range 6.149 0.0\n",
319 | " 0% 0/40 [00:00, ?it/s]/usr/local/lib/python3.8/dist-packages/alignn/graphs.py:237: UserWarning: Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor. (Triggered internally at ../torch/csrc/utils/tensor_new.cpp:204.)\n",
320 | " g.ndata[\"lattice_mat\"] = torch.tensor(\n",
321 | "100% 40/40 [00:01<00:00, 36.30it/s]\n",
322 | "df atoms ... target\n",
323 | "0 {'lattice_mat': [[-0.0, 4.517300851474054, 4.5... ... 0.000\n",
324 | "1 {'lattice_mat': [[7.709535704177289, 2.46207e-... ... 0.000\n",
325 | "2 {'lattice_mat': [[4.191262576674699, 0.0, -0.0... ... 0.016\n",
326 | "3 {'lattice_mat': [[-0.0, 5.040771484524319, 5.0... ... 0.000\n",
327 | "4 {'lattice_mat': [[1.6712283e-08, -2.5080296697... ... 6.149\n",
328 | "5 {'lattice_mat': [[3.93712543178282, 0.0, 2.273... ... 3.851\n",
329 | "6 {'lattice_mat': [[4.927781968323723, -0.0, 0.0... ... 0.000\n",
330 | "7 {'lattice_mat': [[5.157077730332642, 0.0020004... ... 4.030\n",
331 | "8 {'lattice_mat': [[9.067075684180468, -0.0, 0.0... ... 1.197\n",
332 | "9 {'lattice_mat': [[3.790914410660539, -0.0, 0.0... ... 0.000\n",
333 | "10 {'lattice_mat': [[3.2250494729190726, 2.216578... ... 0.689\n",
334 | "11 {'lattice_mat': [[0.0, 5.129874508851702, 5.12... ... 0.000\n",
335 | "12 {'lattice_mat': [[7.843871888963013, 0.0, 0.0]... ... 0.924\n",
336 | "13 {'lattice_mat': [[3.3542337275744103, 0.0, 0.0... ... 0.051\n",
337 | "14 {'lattice_mat': [[4.084155317570781, -1.066825... ... 0.000\n",
338 | "15 {'lattice_mat': [[4.839493559425439, 9.7116505... ... 0.000\n",
339 | "16 {'lattice_mat': [[1.6777483798834445, -2.90594... ... 0.000\n",
340 | "17 {'lattice_mat': [[4.509029640475962, 0.0564034... ... 4.907\n",
341 | "18 {'lattice_mat': [[4.089078911208881, 0.0, 0.0]... ... 0.000\n",
342 | "19 {'lattice_mat': [[0.0, 4.893247728183244, 4.89... ... 0.000\n",
343 | "20 {'lattice_mat': [[5.194393535053021, 0.0345773... ... 0.000\n",
344 | "21 {'lattice_mat': [[3.5666343258756448, 0.0, 0.0... ... 0.000\n",
345 | "22 {'lattice_mat': [[-0.0127275386492899, 4.47534... ... 0.482\n",
346 | "23 {'lattice_mat': [[6.603532697435508, 0.0, -0.0... ... 4.072\n",
347 | "24 {'lattice_mat': [[10.725911963093319, 1.159968... ... 0.000\n",
348 | "25 {'lattice_mat': [[3.292134155794691, 0.0, 0.0]... ... 0.502\n",
349 | "26 {'lattice_mat': [[10.37325585559557, -2.271858... ... 1.569\n",
350 | "27 {'lattice_mat': [[-0.0, 5.037541505850243, 5.0... ... 0.000\n",
351 | "28 {'lattice_mat': [[5.140164879556414, 0.3718366... ... 0.000\n",
352 | "29 {'lattice_mat': [[9.407270982425844, 0.0171637... ... 2.472\n",
353 | "30 {'lattice_mat': [[3.566933224304235, 0.0, -0.0... ... 0.000\n",
354 | "31 {'lattice_mat': [[0.0, 4.936437902689708, 4.93... ... 0.000\n",
355 | "32 {'lattice_mat': [[4.927229198330356, -0.0, -0.... ... 2.122\n",
356 | "33 {'lattice_mat': [[4.376835486482439, 0.0086562... ... 0.000\n",
357 | "34 {'lattice_mat': [[0.0, 4.901572410735, 4.90157... ... 0.000\n",
358 | "35 {'lattice_mat': [[4.284492173131309, 1.636192e... ... 0.000\n",
359 | "36 {'lattice_mat': [[5.587070827330502, -0.006443... ... 1.517\n",
360 | "37 {'lattice_mat': [[6.9098665629767275, 0.128626... ... 2.341\n",
361 | "38 {'lattice_mat': [[0.0, 5.104615296684174, 5.10... ... 0.000\n",
362 | "39 {'lattice_mat': [[6.850665464204784, -0.0, 0.0... ... 0.560\n",
363 | "\n",
364 | "[40 rows x 3 columns]\n",
365 | "warning: could not load CGCNN features for 103\n",
366 | "Setting it to max atomic number available here, 103\n",
367 | "warning: could not load CGCNN features for 101\n",
368 | "Setting it to max atomic number available here, 103\n",
369 | "warning: could not load CGCNN features for 102\n",
370 | "Setting it to max atomic number available here, 103\n",
371 | "building line graphs\n",
372 | "100% 40/40 [00:00<00:00, 816.46it/s]\n",
373 | "data range 1.681 0.0\n",
374 | "100% 5/5 [00:00<00:00, 15.76it/s]\n",
375 | "df atoms ... target\n",
376 | "0 {'lattice_mat': [[5.464512229851642, 0.0, -2.0... ... 0.239\n",
377 | "1 {'lattice_mat': [[3.8114364321417686, 0.0, 0.0... ... 0.000\n",
378 | "2 {'lattice_mat': [[3.5058938597621094, -3.08124... ... 1.681\n",
379 | "3 {'lattice_mat': [[-1.833590720595598, 1.833590... ... 0.000\n",
380 | "4 {'lattice_mat': [[0.0, 5.1858714074842, 5.1858... ... 0.000\n",
381 | "\n",
382 | "[5 rows x 3 columns]\n",
383 | "building line graphs\n",
384 | "100% 5/5 [00:00<00:00, 938.66it/s]\n",
385 | "data range 0.658 0.0\n",
386 | "100% 5/5 [00:00<00:00, 31.51it/s]\n",
387 | "df atoms ... target\n",
388 | "0 {'lattice_mat': [[-0.0, 4.326757913323647, 4.3... ... 0.000\n",
389 | "1 {'lattice_mat': [[0.0, -3.9587610833154616, 0.... ... 0.658\n",
390 | "2 {'lattice_mat': [[4.157436115454804, -0.0, 0.0... ... 0.000\n",
391 | "3 {'lattice_mat': [[-2.2512310528422197, 1.49649... ... 0.000\n",
392 | "4 {'lattice_mat': [[7.2963518353359165, 0.0, 0.0... ... 0.472\n",
393 | "\n",
394 | "[5 rows x 3 columns]\n",
395 | "building line graphs\n",
396 | "100% 5/5 [00:00<00:00, 602.98it/s]\n",
397 | "n_train: 40\n",
398 | "n_val: 5\n",
399 | "n_test: 5\n",
400 | "version='112bbedebdaecf59fb18e11c929080fb2f358246' dataset='user_data' target='target' atom_features='cgcnn' neighbor_strategy='k-nearest' id_tag='jid' random_seed=123 classification_threshold=None n_val=None n_test=None n_train=None train_ratio=0.8 val_ratio=0.1 test_ratio=0.1 target_multiplication_factor=None epochs=3 batch_size=2 weight_decay=1e-05 learning_rate=0.001 filename='sample' warmup_steps=2000 criterion='mse' optimizer='adamw' scheduler='onecycle' pin_memory=False save_dataloader=False write_checkpoint=True write_predictions=True store_outputs=True progress=True log_tensorboard=False standard_scalar_and_pca=False use_canonize=True num_workers=0 cutoff=8.0 max_neighbors=12 keep_data_order=False normalize_graph_level_loss=False distributed=False n_early_stopping=None output_dir='temp' model=ALIGNNConfig(name='alignn', alignn_layers=4, gcn_layers=4, atom_input_features=92, edge_input_features=80, triplet_input_features=40, embedding_features=64, hidden_features=256, output_features=1, link='identity', zero_inflated=False, classification=False, num_classes=2)\n",
401 | "config:\n",
402 | "{'atom_features': 'cgcnn',\n",
403 | " 'batch_size': 2,\n",
404 | " 'classification_threshold': None,\n",
405 | " 'criterion': 'mse',\n",
406 | " 'cutoff': 8.0,\n",
407 | " 'dataset': 'user_data',\n",
408 | " 'distributed': False,\n",
409 | " 'epochs': 3,\n",
410 | " 'filename': 'sample',\n",
411 | " 'id_tag': 'jid',\n",
412 | " 'keep_data_order': False,\n",
413 | " 'learning_rate': 0.001,\n",
414 | " 'log_tensorboard': False,\n",
415 | " 'max_neighbors': 12,\n",
416 | " 'model': {'alignn_layers': 4,\n",
417 | " 'atom_input_features': 92,\n",
418 | " 'classification': False,\n",
419 | " 'edge_input_features': 80,\n",
420 | " 'embedding_features': 64,\n",
421 | " 'gcn_layers': 4,\n",
422 | " 'hidden_features': 256,\n",
423 | " 'link': 'identity',\n",
424 | " 'name': 'alignn',\n",
425 | " 'num_classes': 2,\n",
426 | " 'output_features': 1,\n",
427 | " 'triplet_input_features': 40,\n",
428 | " 'zero_inflated': False},\n",
429 | " 'n_early_stopping': None,\n",
430 | " 'n_test': None,\n",
431 | " 'n_train': None,\n",
432 | " 'n_val': None,\n",
433 | " 'neighbor_strategy': 'k-nearest',\n",
434 | " 'normalize_graph_level_loss': False,\n",
435 | " 'num_workers': 0,\n",
436 | " 'optimizer': 'adamw',\n",
437 | " 'output_dir': 'temp',\n",
438 | " 'pin_memory': False,\n",
439 | " 'progress': True,\n",
440 | " 'random_seed': 123,\n",
441 | " 'save_dataloader': False,\n",
442 | " 'scheduler': 'onecycle',\n",
443 | " 'standard_scalar_and_pca': False,\n",
444 | " 'store_outputs': True,\n",
445 | " 'target': 'target',\n",
446 | " 'target_multiplication_factor': None,\n",
447 | " 'test_ratio': 0.1,\n",
448 | " 'train_ratio': 0.8,\n",
449 | " 'use_canonize': True,\n",
450 | " 'val_ratio': 0.1,\n",
451 | " 'version': '112bbedebdaecf59fb18e11c929080fb2f358246',\n",
452 | " 'warmup_steps': 2000,\n",
453 | " 'weight_decay': 1e-05,\n",
454 | " 'write_checkpoint': True,\n",
455 | " 'write_predictions': True}\n",
456 | "Val_MAE: 1.6030\n",
457 | "Train_MAE: 1.5383\n",
458 | "Val_MAE: 0.7162\n",
459 | "Train_MAE: 1.3723\n",
460 | "Val_MAE: 0.3797\n",
461 | "Train_MAE: 1.4241\n",
462 | "Test MAE: 0.61751669049263\n",
463 | "Time taken (s): 62.5025429725647\n",
464 | "Time in s 71.04203939437866\n"
465 | ]
466 | }
467 | ]
468 | },
469 | {
470 | "cell_type": "code",
471 | "source": [
472 | "!ls"
473 | ],
474 | "metadata": {
475 | "id": "tE8JPqIWQ10F",
476 | "outputId": "0f82052e-f57b-49f4-bc2d-52d6e217fe30",
477 | "colab": {
478 | "base_uri": "https://localhost:8080/"
479 | }
480 | },
481 | "execution_count": null,
482 | "outputs": [
483 | {
484 | "output_type": "stream",
485 | "name": "stdout",
486 | "text": [
487 | "alignn\tLICENSE.rst pyproject.toml README.md\tsetup.py temp\n"
488 | ]
489 | }
490 | ]
491 | },
492 | {
493 | "cell_type": "markdown",
494 | "source": [
495 | "Training produces `*.pt` checkpoint files, which contain the trained models."
496 | ],
497 | "metadata": {
498 | "id": "WnlQxz2eRSoL"
499 | }
500 | },
501 | {
502 | "cell_type": "code",
503 | "source": [
504 | "!ls temp"
505 | ],
506 | "metadata": {
507 | "id": "ARFUTpZjQ9JN",
508 | "outputId": "63713f61-6c53-4503-8f1e-8930eb365bbd",
509 | "colab": {
510 | "base_uri": "https://localhost:8080/"
511 | }
512 | },
513 | "execution_count": null,
514 | "outputs": [
515 | {
516 | "output_type": "stream",
517 | "name": "stdout",
518 | "text": [
519 | "checkpoint_2.pt\t\t mad\n",
520 | "checkpoint_3.pt\t\t prediction_results_test_set.csv\n",
521 | "config.json\t\t prediction_results_train_set.csv\n",
522 | "history_train.json\t test_data_data_range\n",
523 | "history_val.json\t train_data_data_range\n",
524 | "ids_train_val_test.json val_data_data_range\n"
525 | ]
526 | }
527 | ]
528 | },
529 | {
530 | "cell_type": "markdown",
531 | "source": [
532 | "We can load the trained model from above as follows:"
533 | ],
534 | "metadata": {
535 | "id": "5dWY2SN3SAWm"
536 | }
537 | },
538 | {
539 | "cell_type": "code",
540 | "source": [
541 | "from alignn.models.alignn import ALIGNN, ALIGNNConfig\n",
542 | "import torch\n",
543 | "output_features = 1\n",
544 | "filename = 'temp/checkpoint_3.pt'\n",
545 | "device = \"cpu\"\n",
546 | "if torch.cuda.is_available():\n",
547 | " device = torch.device(\"cuda\")\n",
548 | "model = ALIGNN(ALIGNNConfig(name=\"alignn\", output_features=output_features))\n",
549 | "model.load_state_dict(torch.load(filename, map_location=device)[\"model\"])\n",
550 | "model.eval()"
551 | ],
552 | "metadata": {
553 | "id": "KKwVQwvCRfkD",
554 | "outputId": "522482fc-cb41-4351-e970-8f624a47450b",
555 | "colab": {
556 | "base_uri": "https://localhost:8080/"
557 | }
558 | },
559 | "execution_count": null,
560 | "outputs": [
561 | {
562 | "output_type": "execute_result",
563 | "data": {
564 | "text/plain": [
565 | "ALIGNN(\n",
566 | " (atom_embedding): MLPLayer(\n",
567 | " (layer): Sequential(\n",
568 | " (0): Linear(in_features=92, out_features=256, bias=True)\n",
569 | " (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
570 | " (2): SiLU()\n",
571 | " )\n",
572 | " )\n",
573 | " (edge_embedding): Sequential(\n",
574 | " (0): RBFExpansion()\n",
575 | " (1): MLPLayer(\n",
576 | " (layer): Sequential(\n",
577 | " (0): Linear(in_features=80, out_features=64, bias=True)\n",
578 | " (1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
579 | " (2): SiLU()\n",
580 | " )\n",
581 | " )\n",
582 | " (2): MLPLayer(\n",
583 | " (layer): Sequential(\n",
584 | " (0): Linear(in_features=64, out_features=256, bias=True)\n",
585 | " (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
586 | " (2): SiLU()\n",
587 | " )\n",
588 | " )\n",
589 | " )\n",
590 | " (angle_embedding): Sequential(\n",
591 | " (0): RBFExpansion()\n",
592 | " (1): MLPLayer(\n",
593 | " (layer): Sequential(\n",
594 | " (0): Linear(in_features=40, out_features=64, bias=True)\n",
595 | " (1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
596 | " (2): SiLU()\n",
597 | " )\n",
598 | " )\n",
599 | " (2): MLPLayer(\n",
600 | " (layer): Sequential(\n",
601 | " (0): Linear(in_features=64, out_features=256, bias=True)\n",
602 | " (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
603 | " (2): SiLU()\n",
604 | " )\n",
605 | " )\n",
606 | " )\n",
607 | " (alignn_layers): ModuleList(\n",
608 | " (0): ALIGNNConv(\n",
609 | " (node_update): EdgeGatedGraphConv(\n",
610 | " (src_gate): Linear(in_features=256, out_features=256, bias=True)\n",
611 | " (dst_gate): Linear(in_features=256, out_features=256, bias=True)\n",
612 | " (edge_gate): Linear(in_features=256, out_features=256, bias=True)\n",
613 | " (bn_edges): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
614 | " (src_update): Linear(in_features=256, out_features=256, bias=True)\n",
615 | " (dst_update): Linear(in_features=256, out_features=256, bias=True)\n",
616 | " (bn_nodes): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
617 | " )\n",
618 | " (edge_update): EdgeGatedGraphConv(\n",
619 | " (src_gate): Linear(in_features=256, out_features=256, bias=True)\n",
620 | " (dst_gate): Linear(in_features=256, out_features=256, bias=True)\n",
621 | " (edge_gate): Linear(in_features=256, out_features=256, bias=True)\n",
622 | " (bn_edges): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
623 | " (src_update): Linear(in_features=256, out_features=256, bias=True)\n",
624 | " (dst_update): Linear(in_features=256, out_features=256, bias=True)\n",
625 | " (bn_nodes): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
626 | " )\n",
627 | " )\n",
628 | " (1): ALIGNNConv(\n",
629 | " (node_update): EdgeGatedGraphConv(\n",
630 | " (src_gate): Linear(in_features=256, out_features=256, bias=True)\n",
631 | " (dst_gate): Linear(in_features=256, out_features=256, bias=True)\n",
632 | " (edge_gate): Linear(in_features=256, out_features=256, bias=True)\n",
633 | " (bn_edges): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
634 | " (src_update): Linear(in_features=256, out_features=256, bias=True)\n",
635 | " (dst_update): Linear(in_features=256, out_features=256, bias=True)\n",
636 | " (bn_nodes): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
637 | " )\n",
638 | " (edge_update): EdgeGatedGraphConv(\n",
639 | " (src_gate): Linear(in_features=256, out_features=256, bias=True)\n",
640 | " (dst_gate): Linear(in_features=256, out_features=256, bias=True)\n",
641 | " (edge_gate): Linear(in_features=256, out_features=256, bias=True)\n",
642 | " (bn_edges): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
643 | " (src_update): Linear(in_features=256, out_features=256, bias=True)\n",
644 | " (dst_update): Linear(in_features=256, out_features=256, bias=True)\n",
645 | " (bn_nodes): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
646 | " )\n",
647 | " )\n",
648 | " (2): ALIGNNConv(\n",
649 | " (node_update): EdgeGatedGraphConv(\n",
650 | " (src_gate): Linear(in_features=256, out_features=256, bias=True)\n",
651 | " (dst_gate): Linear(in_features=256, out_features=256, bias=True)\n",
652 | " (edge_gate): Linear(in_features=256, out_features=256, bias=True)\n",
653 | " (bn_edges): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
654 | " (src_update): Linear(in_features=256, out_features=256, bias=True)\n",
655 | " (dst_update): Linear(in_features=256, out_features=256, bias=True)\n",
656 | " (bn_nodes): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
657 | " )\n",
658 | " (edge_update): EdgeGatedGraphConv(\n",
659 | " (src_gate): Linear(in_features=256, out_features=256, bias=True)\n",
660 | " (dst_gate): Linear(in_features=256, out_features=256, bias=True)\n",
661 | " (edge_gate): Linear(in_features=256, out_features=256, bias=True)\n",
662 | " (bn_edges): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
663 | " (src_update): Linear(in_features=256, out_features=256, bias=True)\n",
664 | " (dst_update): Linear(in_features=256, out_features=256, bias=True)\n",
665 | " (bn_nodes): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
666 | " )\n",
667 | " )\n",
668 | " (3): ALIGNNConv(\n",
669 | " (node_update): EdgeGatedGraphConv(\n",
670 | " (src_gate): Linear(in_features=256, out_features=256, bias=True)\n",
671 | " (dst_gate): Linear(in_features=256, out_features=256, bias=True)\n",
672 | " (edge_gate): Linear(in_features=256, out_features=256, bias=True)\n",
673 | " (bn_edges): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
674 | " (src_update): Linear(in_features=256, out_features=256, bias=True)\n",
675 | " (dst_update): Linear(in_features=256, out_features=256, bias=True)\n",
676 | " (bn_nodes): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
677 | " )\n",
678 | " (edge_update): EdgeGatedGraphConv(\n",
679 | " (src_gate): Linear(in_features=256, out_features=256, bias=True)\n",
680 | " (dst_gate): Linear(in_features=256, out_features=256, bias=True)\n",
681 | " (edge_gate): Linear(in_features=256, out_features=256, bias=True)\n",
682 | " (bn_edges): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
683 | " (src_update): Linear(in_features=256, out_features=256, bias=True)\n",
684 | " (dst_update): Linear(in_features=256, out_features=256, bias=True)\n",
685 | " (bn_nodes): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
686 | " )\n",
687 | " )\n",
688 | " )\n",
689 | " (gcn_layers): ModuleList(\n",
690 | " (0): EdgeGatedGraphConv(\n",
691 | " (src_gate): Linear(in_features=256, out_features=256, bias=True)\n",
692 | " (dst_gate): Linear(in_features=256, out_features=256, bias=True)\n",
693 | " (edge_gate): Linear(in_features=256, out_features=256, bias=True)\n",
694 | " (bn_edges): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
695 | " (src_update): Linear(in_features=256, out_features=256, bias=True)\n",
696 | " (dst_update): Linear(in_features=256, out_features=256, bias=True)\n",
697 | " (bn_nodes): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
698 | " )\n",
699 | " (1): EdgeGatedGraphConv(\n",
700 | " (src_gate): Linear(in_features=256, out_features=256, bias=True)\n",
701 | " (dst_gate): Linear(in_features=256, out_features=256, bias=True)\n",
702 | " (edge_gate): Linear(in_features=256, out_features=256, bias=True)\n",
703 | " (bn_edges): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
704 | " (src_update): Linear(in_features=256, out_features=256, bias=True)\n",
705 | " (dst_update): Linear(in_features=256, out_features=256, bias=True)\n",
706 | " (bn_nodes): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
707 | " )\n",
708 | " (2): EdgeGatedGraphConv(\n",
709 | " (src_gate): Linear(in_features=256, out_features=256, bias=True)\n",
710 | " (dst_gate): Linear(in_features=256, out_features=256, bias=True)\n",
711 | " (edge_gate): Linear(in_features=256, out_features=256, bias=True)\n",
712 | " (bn_edges): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
713 | " (src_update): Linear(in_features=256, out_features=256, bias=True)\n",
714 | " (dst_update): Linear(in_features=256, out_features=256, bias=True)\n",
715 | " (bn_nodes): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
716 | " )\n",
717 | " (3): EdgeGatedGraphConv(\n",
718 | " (src_gate): Linear(in_features=256, out_features=256, bias=True)\n",
719 | " (dst_gate): Linear(in_features=256, out_features=256, bias=True)\n",
720 | " (edge_gate): Linear(in_features=256, out_features=256, bias=True)\n",
721 | " (bn_edges): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
722 | " (src_update): Linear(in_features=256, out_features=256, bias=True)\n",
723 | " (dst_update): Linear(in_features=256, out_features=256, bias=True)\n",
724 | " (bn_nodes): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
725 | " )\n",
726 | " )\n",
727 | " (readout): AvgPooling()\n",
728 | " (fc): Linear(in_features=256, out_features=1, bias=True)\n",
729 | ")"
730 | ]
731 | },
732 | "metadata": {},
733 | "execution_count": 7
734 | }
735 | ]
736 | },
737 | {
738 | "cell_type": "markdown",
739 | "source": [
740 | "Now, we can build a graph for a given structure and make a prediction as follows:"
741 | ],
742 | "metadata": {
743 | "id": "N7cNP1YqShuO"
744 | }
745 | },
746 | {
747 | "cell_type": "code",
748 | "source": [
749 | "from jarvis.core.atoms import Atoms\n",
750 | "from alignn.graphs import Graph\n",
751 | "cutoff = 8.0\n",
752 | "max_neighbors = 12\n",
753 | "atoms = Atoms.from_poscar('alignn/examples/sample_data/POSCAR-JVASP-10.vasp')\n",
754 | "g, lg = Graph.atom_dgl_multigraph(\n",
755 | " atoms, cutoff=float(cutoff), max_neighbors=max_neighbors,\n",
756 | ")\n",
757 | "out_data = (\n",
758 | " model([g.to(device), lg.to(device)])\n",
759 | " .detach()\n",
760 | " .cpu()\n",
761 | " .numpy()\n",
762 | " .flatten()\n",
763 | " .tolist()\n",
764 | ")\n",
765 | "print ('output', out_data[0])"
766 | ],
767 | "metadata": {
768 | "id": "kfr_EGHRS_aU",
769 | "outputId": "07686c7f-d39b-4e04-d710-5e94559e7f84",
770 | "colab": {
771 | "base_uri": "https://localhost:8080/"
772 | }
773 | },
774 | "execution_count": null,
775 | "outputs": [
776 | {
777 | "output_type": "stream",
778 | "name": "stdout",
779 | "text": [
780 | "output 2.280003309249878\n"
781 | ]
782 | },
783 | {
784 | "output_type": "stream",
785 | "name": "stderr",
786 | "text": [
787 | "/content/alignn/alignn/graphs.py:237: UserWarning: Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor. (Triggered internally at ../torch/csrc/utils/tensor_new.cpp:204.)\n",
788 | " g.ndata[\"lattice_mat\"] = torch.tensor(\n"
789 | ]
790 | }
791 | ]
792 | },
793 | {
794 | "cell_type": "markdown",
795 | "source": [
796 | "We have already trained multiple models on several large datasets, which can be used with the `pretrained.py` executable."
797 | ],
798 | "metadata": {
799 | "id": "QbzTNJ-WTgfF"
800 | }
801 | },
802 | {
803 | "cell_type": "markdown",
804 | "metadata": {
805 | "id": "VOWVk7MV1hQ3"
806 | },
807 | "source": [
808 | "Use pretrained models such as models trained on JARVIS-DFT, QM9, Materials project, hMOF etc. databases. The models are downloaded from figshare. See the list here: https://github.com/usnistgov/alignn/blob/main/alignn/pretrained.py#L28"
809 | ]
810 | },
811 | {
812 | "cell_type": "code",
813 | "metadata": {
814 | "id": "16HHZ7TD3uRb",
815 | "colab": {
816 | "base_uri": "https://localhost:8080/"
817 | },
818 | "outputId": "81af2ed6-78ed-49d5-ca00-041ccf89881c"
819 | },
820 | "source": [
821 | "!pretrained.py -h"
822 | ],
823 | "execution_count": null,
824 | "outputs": [
825 | {
826 | "output_type": "stream",
827 | "name": "stdout",
828 | "text": [
829 | "usage: pretrained.py\n",
830 | " [-h]\n",
831 | " [--model_name MODEL_NAME]\n",
832 | " [--file_format FILE_FORMAT]\n",
833 | " [--file_path FILE_PATH]\n",
834 | " [--cutoff CUTOFF]\n",
835 | " [--max_neighbors MAX_NEIGHBORS]\n",
836 | "\n",
837 | "Atomistic\n",
838 | "Line Graph\n",
839 | "Neural\n",
840 | "Network\n",
841 | "Pretrained\n",
842 | "Models\n",
843 | "\n",
844 | "optional arguments:\n",
845 | " -h, --help\n",
846 | " show this\n",
847 | " help\n",
848 | " message and\n",
849 | " exit\n",
850 | " --model_name MODEL_NAME\n",
851 | " Choose a\n",
852 | " model from\n",
853 | " these 40 mo\n",
854 | " dels:jv_for\n",
855 | " mation_ener\n",
856 | " gy_peratom_\n",
857 | " alignn, jv_\n",
858 | " optb88vdw_t\n",
859 | " otal_energy\n",
860 | " _alignn, jv\n",
861 | " _optb88vdw_\n",
862 | " bandgap_ali\n",
863 | " gnn, jv_mbj\n",
864 | " _bandgap_al\n",
865 | " ignn, jv_sp\n",
866 | " illage_alig\n",
867 | " nn, jv_slme\n",
868 | " _alignn, jv\n",
869 | " _bulk_modul\n",
870 | " us_kv_align\n",
871 | " n, jv_shear\n",
872 | " _modulus_gv\n",
873 | " _alignn,\n",
874 | " jv_n-Seebec\n",
875 | " k_alignn,\n",
876 | " jv_n-powerf\n",
877 | " act_alignn,\n",
878 | " jv_magmom_o\n",
879 | " szicar_alig\n",
880 | " nn, jv_kpoi\n",
881 | " nt_length_u\n",
882 | " nit_alignn,\n",
883 | " jv_avg_elec\n",
884 | " _mass_align\n",
885 | " n, jv_avg_h\n",
886 | " ole_mass_al\n",
887 | " ignn, jv_ep\n",
888 | " sx_alignn, \n",
889 | " jv_mepsx_al\n",
890 | " ignn, jv_ma\n",
891 | " x_efg_align\n",
892 | " n, jv_ehull\n",
893 | " _alignn, jv\n",
894 | " _dfpt_piezo\n",
895 | " _max_dielec\n",
896 | " tric_alignn\n",
897 | " , jv_dfpt_p\n",
898 | " iezo_max_di\n",
899 | " j_alignn, j\n",
900 | " v_exfoliati\n",
901 | " on_energy_a\n",
902 | " lignn, jv_s\n",
903 | " upercon_tc_\n",
904 | " alignn, mp_\n",
905 | " e_form_alig\n",
906 | " nnn, mp_gap\n",
907 | " pbe_alignnn\n",
908 | " , qm9_U0_al\n",
909 | " ignn, qm9_U\n",
910 | " _alignn, qm\n",
911 | " 9_alpha_ali\n",
912 | " gnn, qm9_ga\n",
913 | " p_alignn, q\n",
914 | " m9_G_alignn\n",
915 | " , qm9_HOMO_\n",
916 | " alignn, qm9\n",
917 | " _LUMO_align\n",
918 | " n, qm9_ZPVE\n",
919 | " _alignn, hm\n",
920 | " of_co2_absp\n",
921 | " _alignnn, h\n",
922 | " mof_max_co2\n",
923 | " _adsp_align\n",
924 | " nn, hmof_su\n",
925 | " rface_area_\n",
926 | " m2g_alignnn\n",
927 | " , hmof_surf\n",
928 | " ace_area_m2\n",
929 | " cm3_alignnn\n",
930 | " , hmof_pld_\n",
931 | " alignnn, hm\n",
932 | " of_lcd_alig\n",
933 | " nnn, hmof_v\n",
934 | " oid_fractio\n",
935 | " n_alignnn, \n",
936 | " jv_pdos_ali\n",
937 | " gnn\n",
938 | " --file_format FILE_FORMAT\n",
939 | " poscar/cif/\n",
940 | " xyz/pdb\n",
941 | " file\n",
942 | " format.\n",
943 | " --file_path FILE_PATH\n",
944 | " Path to\n",
945 | " file.\n",
946 | " --cutoff CUTOFF\n",
947 | " Distance\n",
948 | " cut-off for\n",
949 | " graph const\n",
950 | " uction,\n",
951 | " usually 8\n",
952 | " for solids\n",
953 | " and 5 for\n",
954 | " molecules.\n",
955 | " --max_neighbors MAX_NEIGHBORS\n",
956 | " Maximum\n",
957 | " number of\n",
958 | " nearest\n",
959 | " neighbors\n",
960 | " in the\n",
961 | " periodic\n",
962 | " atomistic\n",
963 | " graph const\n",
964 | " ruction.\n"
965 | ]
966 | }
967 | ]
968 | },
969 | {
970 | "cell_type": "code",
971 | "metadata": {
972 | "id": "_bIT4hL71wmA",
973 | "colab": {
974 | "base_uri": "https://localhost:8080/"
975 | },
976 | "outputId": "649ad620-08d0-4e6a-de66-0f43573751ac"
977 | },
978 | "source": [
979 | "!pretrained.py --model_name jv_formation_energy_peratom_alignn --file_format poscar --file_path alignn/examples/sample_data/POSCAR-JVASP-10.vasp"
980 | ],
981 | "execution_count": null,
982 | "outputs": [
983 | {
984 | "output_type": "stream",
985 | "name": "stdout",
986 | "text": [
987 | "100% 47.5M/47.5M [00:03<00:00, 15.8MiB/s]\n",
988 | "Using chk file jv_formation_energy_peratom_alignn/checkpoint_300.pt from ['jv_formation_energy_peratom_alignn/checkpoint_300.pt']\n",
989 | "Path /usr/local/bin/jv_formation_energy_peratom_alignn.zip\n",
990 | "Predicted value: jv_formation_energy_peratom_alignn alignn/examples/sample_data/POSCAR-JVASP-10.vasp [-0.70339435338974]\n"
991 | ]
992 | }
993 | ]
994 | },
995 | {
996 | "cell_type": "markdown",
997 | "source": [
998 | "Using the pretrained ALIGNN-FF model to get the unrelaxed energy and the relaxed structure"
999 | ],
1000 | "metadata": {
1001 | "id": "yMhz-GpH2hfJ"
1002 | }
1003 | },
1004 | {
1005 | "cell_type": "code",
1006 | "source": [
1007 | "!run_alignn_ff.py --file_path alignn/examples/sample_data/POSCAR-JVASP-10.vasp --task=\"unrelaxed_energy\""
1008 | ],
1009 | "metadata": {
1010 | "id": "6X67E61h2kge",
1011 | "outputId": "da825bbb-38c0-493f-c267-c0ed15f40361",
1012 | "colab": {
1013 | "base_uri": "https://localhost:8080/"
1014 | }
1015 | },
1016 | "execution_count": null,
1017 | "outputs": [
1018 | {
1019 | "output_type": "stream",
1020 | "name": "stdout",
1021 | "text": [
1022 | "model_path /usr/local/lib/python3.8/dist-packages/alignn/ff\n",
1023 | "/usr/local/lib/python3.8/dist-packages/alignn/graphs.py:237: UserWarning: Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor. (Triggered internally at ../torch/csrc/utils/tensor_new.cpp:204.)\n",
1024 | " g.ndata[\"lattice_mat\"] = torch.tensor(\n",
1025 | "Energy(eV) (-11.776981830596924, array([[ 0.0000000e+00, 3.7252903e-09, 9.3132257e-10],\n",
1026 | " [ 0.0000000e+00, 2.1071173e-07, -1.7778256e-03],\n",
1027 | " [ 0.0000000e+00, -2.1059532e-07, 1.7777842e-03]], dtype=float32))\n"
1028 | ]
1029 | }
1030 | ]
1031 | },
1032 | {
1033 | "cell_type": "code",
1034 | "source": [
1035 | "!run_alignn_ff.py --file_path alignn/examples/sample_data/POSCAR-JVASP-10.vasp --task=\"optimize\""
1036 | ],
1037 | "metadata": {
1038 | "id": "sGiXomqD21eE",
1039 | "outputId": "8f1214ee-9f6e-4192-d40b-d3c9271935da",
1040 | "colab": {
1041 | "base_uri": "https://localhost:8080/"
1042 | }
1043 | },
1044 | "execution_count": null,
1045 | "outputs": [
1046 | {
1047 | "output_type": "stream",
1048 | "name": "stdout",
1049 | "text": [
1050 | "model_path /usr/local/lib/python3.8/dist-packages/alignn/ff\n",
1051 | "OPTIMIZATION\n",
1052 | "/usr/local/lib/python3.8/dist-packages/alignn/graphs.py:237: UserWarning: Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor. (Triggered internally at ../torch/csrc/utils/tensor_new.cpp:204.)\n",
1053 | " g.ndata[\"lattice_mat\"] = torch.tensor(\n",
1054 | "a= 1.678 Ang b= 2.906 Ang c= 6.221 Ang Volume= 60.658 amu/a3 PE=-11.77698 eV KE= 0.00000 eV T= 0.000 K \n",
1055 | "a= 1.676 Ang b= 2.903 Ang c= 6.223 Ang Volume= 60.556 amu/a3 PE=-11.77832 eV KE= 0.00000 eV T= 0.000 K \n",
1056 | "a= 1.673 Ang b= 2.897 Ang c= 6.228 Ang Volume= 60.355 amu/a3 PE=-11.78092 eV KE= 0.00000 eV T= 0.000 K \n",
1057 | "a= 1.668 Ang b= 2.888 Ang c= 6.235 Ang Volume= 60.062 amu/a3 PE=-11.78462 eV KE= 0.00000 eV T= 0.000 K \n",
1058 | "a= 1.661 Ang b= 2.877 Ang c= 6.245 Ang Volume= 59.692 amu/a3 PE=-11.78918 eV KE= 0.00000 eV T= 0.000 K \n",
1059 | "a= 1.653 Ang b= 2.863 Ang c= 6.259 Ang Volume= 59.261 amu/a3 PE=-11.79420 eV KE= 0.00000 eV T= 0.000 K \n",
1060 | "a= 1.644 Ang b= 2.848 Ang c= 6.277 Ang Volume= 58.798 amu/a3 PE=-11.79909 eV KE= 0.00000 eV T= 0.000 K \n",
1061 | "a= 1.635 Ang b= 2.832 Ang c= 6.301 Ang Volume= 58.340 amu/a3 PE=-11.80319 eV KE= 0.00000 eV T= 0.000 K \n",
1062 | "a= 1.624 Ang b= 2.814 Ang c= 6.335 Ang Volume= 57.907 amu/a3 PE=-11.80618 eV KE= 0.00000 eV T= 0.000 K \n",
1063 | "a= 1.614 Ang b= 2.796 Ang c= 6.382 Ang Volume= 57.614 amu/a3 PE=-11.80765 eV KE= 0.00000 eV T= 0.000 K \n",
1064 | "a= 1.606 Ang b= 2.782 Ang c= 6.443 Ang Volume= 57.575 amu/a3 PE=-11.80819 eV KE= 0.00000 eV T= 0.000 K \n",
1065 | "a= 1.606 Ang b= 2.782 Ang c= 6.445 Ang Volume= 57.617 amu/a3 PE=-11.80841 eV KE= 0.00000 eV T= 0.000 K \n",
1066 | "a= 1.607 Ang b= 2.784 Ang c= 6.448 Ang Volume= 57.698 amu/a3 PE=-11.80883 eV KE= 0.00000 eV T= 0.000 K \n",
1067 | "a= 1.608 Ang b= 2.786 Ang c= 6.453 Ang Volume= 57.818 amu/a3 PE=-11.80939 eV KE= 0.00000 eV T= 0.000 K \n",
1068 | "initial struct:\n",
1069 | "VSe2\n",
1070 | "1.0\n",
1071 | "1.6777483798834445 -2.9059452409270157 -1.1e-15\n",
1072 | "1.6777483798834438 2.9059452409270126 -7e-16\n",
1073 | "-6.5e-15 -8e-16 6.220805465667012\n",
1074 | "V Se\n",
1075 | "1 2\n",
1076 | "Cartesian\n",
1077 | "0.0 0.0 0.0\n",
1078 | "1.67775 -0.9686519372999812 4.6529213966213625\n",
1079 | "1.67775 0.9686519372999813 1.5678886033786343\n",
1080 | "\n",
1081 | "final struct:\n",
1082 | "VSe2\n",
1083 | "1.0\n",
1084 | "1.608302852396432 -2.7856636350865687 2.6551522928085922e-06\n",
1085 | "1.6083028607260723 2.785663639895683 -2.6472172384636327e-06\n",
1086 | "1.4710913027772843e-08 -5.675435459384826e-06 6.452587626731301\n",
1087 | "V Se\n",
1088 | "1 2\n",
1089 | "Cartesian\n",
1090 | "1.6086832039991881e-06 -3.3976405802442263e-07 1.7983984100003775e-06\n",
1091 | "1.608304062444692 -0.9285619719519679 4.8224468467073915\n",
1092 | "1.6083040553107424 0.9285564334598546 1.6301449559646826\n",
1093 | "\n"
1094 | ]
1095 | }
1096 | ]
1097 | },
1098 | {
1099 | "cell_type": "code",
1100 | "source": [
1101 | "!run_alignn_ff.py --file_path alignn/examples/sample_data/POSCAR-JVASP-10.vasp --task=\"ev_curve\""
1102 | ],
1103 | "metadata": {
1104 | "id": "lSkc-nuZ3jWw",
1105 | "outputId": "dd03d6ba-30e7-46ec-c233-60f4914054bd",
1106 | "colab": {
1107 | "base_uri": "https://localhost:8080/"
1108 | }
1109 | },
1110 | "execution_count": null,
1111 | "outputs": [
1112 | {
1113 | "output_type": "stream",
1114 | "name": "stdout",
1115 | "text": [
1116 | "model_path /usr/local/lib/python3.8/dist-packages/alignn/ff\n",
1117 | "/usr/local/lib/python3.8/dist-packages/alignn/graphs.py:237: UserWarning: Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor. (Triggered internally at ../torch/csrc/utils/tensor_new.cpp:204.)\n",
1118 | " g.ndata[\"lattice_mat\"] = torch.tensor(\n",
1119 | "E [-11.59964418 -11.74018407 -11.77280045 -11.78283548 -11.78363013\n",
1120 | " -11.77698183 -11.76358366 -11.74409294 -11.71868992 -11.68687391]\n",
1121 | "V [52.00698610745743, 53.66666028373519, 55.3612736919685, 57.09119028249023, 58.85677400563329, 60.6583888117305, 62.496398651114745, 64.3711674741189, 66.28305923107584, 68.23243787231843]\n"
1122 | ]
1123 | }
1124 | ]
1125 | },
1126 | {
1127 | "cell_type": "markdown",
1128 | "source": [
1129 | "Train ALIGNN-FF model"
1130 | ],
1131 | "metadata": {
1132 | "id": "1EJ8-7dk3rva"
1133 | }
1134 | },
1135 | {
1136 | "cell_type": "code",
1137 | "source": [
1138 | "!train_folder_ff.py --root_dir \"alignn/examples/sample_data_ff\" --config \"alignn/examples/sample_data_ff/config_example_atomwise.json\" --output_dir=temp"
1139 | ],
1140 | "metadata": {
1141 | "id": "-_gLKbTf3uk7",
1142 | "outputId": "54a84d6f-7213-4629-94a4-88b3cf25be7f",
1143 | "colab": {
1144 | "base_uri": "https://localhost:8080/"
1145 | }
1146 | },
1147 | "execution_count": null,
1148 | "outputs": [
1149 | {
1150 | "output_type": "stream",
1151 | "name": "stdout",
1152 | "text": [
1153 | "len dataset 50\n",
1154 | "MAX val: -24.52653862\n",
1155 | "MIN val: -42.04135008\n",
1156 | "MAD: 7.884625411000001\n",
1157 | "Baseline MAE: 7.128754169400001\n",
1158 | "data range -24.52653862 -42.04135008\n",
1159 | "\r 0% 0/40 [00:00, ?it/s]/usr/local/lib/python3.8/dist-packages/alignn/graphs.py:237: UserWarning: Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor. (Triggered internally at ../torch/csrc/utils/tensor_new.cpp:204.)\n",
1160 | " g.ndata[\"lattice_mat\"] = torch.tensor(\n",
1161 | "100% 40/40 [00:00<00:00, 45.41it/s]\n",
1162 | "df target ... jid\n",
1163 | "0 -24.558198 ... 15609\n",
1164 | "1 -42.030165 ... 15608\n",
1165 | "2 -42.040710 ... 15608\n",
1166 | "3 -24.544487 ... 15609\n",
1167 | "4 -24.549292 ... 15609\n",
1168 | "5 -42.040186 ... 15608\n",
1169 | "6 -42.016896 ... 15608\n",
1170 | "7 -42.041135 ... 15608\n",
1171 | "8 -42.030172 ... 15608\n",
1172 | "9 -29.311899 ... 15607\n",
1173 | "10 -42.041350 ... 15608\n",
1174 | "11 -24.549425 ... 15609\n",
1175 | "12 -42.038933 ... 15608\n",
1176 | "13 -42.040889 ... 15608\n",
1177 | "14 -42.039553 ... 15608\n",
1178 | "15 -42.029257 ... 15608\n",
1179 | "16 -29.312428 ... 15607\n",
1180 | "17 -42.040889 ... 15608\n",
1181 | "18 -29.312862 ... 15607\n",
1182 | "19 -24.544738 ... 15609\n",
1183 | "20 -42.019225 ... 15608\n",
1184 | "21 -42.039996 ... 15608\n",
1185 | "22 -42.039996 ... 15608\n",
1186 | "23 -24.549290 ... 15609\n",
1187 | "24 -42.002229 ... 15608\n",
1188 | "25 -42.030166 ... 15608\n",
1189 | "26 -24.549425 ... 15609\n",
1190 | "27 -24.552079 ... 15609\n",
1191 | "28 -42.016896 ... 15608\n",
1192 | "29 -42.040184 ... 15608\n",
1193 | "30 -29.313096 ... 15607\n",
1194 | "31 -24.527151 ... 15609\n",
1195 | "32 -42.032726 ... 15608\n",
1196 | "33 -24.552075 ... 15609\n",
1197 | "34 -24.527154 ... 15609\n",
1198 | "35 -29.312800 ... 15607\n",
1199 | "36 -42.002227 ... 15608\n",
1200 | "37 -24.558202 ... 15609\n",
1201 | "38 -24.526539 ... 15609\n",
1202 | "39 -42.030172 ... 15608\n",
1203 | "\n",
1204 | "[40 rows x 6 columns]\n",
1205 | "warning: could not load CGCNN features for 103\n",
1206 | "Setting it to max atomic number available here, 103\n",
1207 | "warning: could not load CGCNN features for 101\n",
1208 | "Setting it to max atomic number available here, 103\n",
1209 | "warning: could not load CGCNN features for 102\n",
1210 | "Setting it to max atomic number available here, 103\n",
1211 | "building line graphs\n",
1212 | "100% 40/40 [00:00<00:00, 869.68it/s]\n",
1213 | "data range -24.54474058 -42.04071003\n",
1214 | "100% 5/5 [00:00<00:00, 45.42it/s]\n",
1215 | "df target ... jid\n",
1216 | "0 -24.544741 ... 15609\n",
1217 | "1 -24.558797 ... 15609\n",
1218 | "2 -42.040710 ... 15608\n",
1219 | "3 -29.312159 ... 15607\n",
1220 | "4 -29.311952 ... 15607\n",
1221 | "\n",
1222 | "[5 rows x 6 columns]\n",
1223 | "building line graphs\n",
1224 | "100% 5/5 [00:00<00:00, 579.55it/s]\n",
1225 | "data range -24.55819782 -42.03955494\n",
1226 | "100% 5/5 [00:00<00:00, 51.57it/s]\n",
1227 | "df target ... jid\n",
1228 | "0 -24.558198 ... 15609\n",
1229 | "1 -42.039555 ... 15608\n",
1230 | "2 -29.312791 ... 15607\n",
1231 | "3 -42.029256 ... 15608\n",
1232 | "4 -29.312429 ... 15607\n",
1233 | "\n",
1234 | "[5 rows x 6 columns]\n",
1235 | "building line graphs\n",
1236 | "100% 5/5 [00:00<00:00, 573.42it/s]\n",
1237 | "n_train: 40\n",
1238 | "n_val: 5\n",
1239 | "n_test: 5\n",
1240 | "version='112bbedebdaecf59fb18e11c929080fb2f358246' dataset='user_data' target='target' atom_features='cgcnn' neighbor_strategy='k-nearest' id_tag='jid' random_seed=123 classification_threshold=None n_val=None n_test=None n_train=None train_ratio=0.8 val_ratio=0.1 test_ratio=0.1 target_multiplication_factor=None epochs=3 batch_size=2 weight_decay=1e-05 learning_rate=0.001 filename='sample' warmup_steps=2000 criterion='l1' optimizer='adamw' scheduler='onecycle' pin_memory=False save_dataloader=False write_checkpoint=True write_predictions=True store_outputs=False progress=True log_tensorboard=False standard_scalar_and_pca=False use_canonize=False num_workers=0 cutoff=8.0 max_neighbors=12 keep_data_order=False normalize_graph_level_loss=False distributed=False n_early_stopping=None output_dir='temp' model=ALIGNNAtomWiseConfig(name='alignn_atomwise', alignn_layers=4, gcn_layers=4, atom_input_features=92, edge_input_features=80, triplet_input_features=40, embedding_features=64, hidden_features=256, output_features=1, grad_multiplier=-1, calculate_gradient=True, atomwise_output_features=3, graphwise_weight=0.85, gradwise_weight=0.05, stresswise_weight=0.05, atomwise_weight=0.05, link='identity', zero_inflated=False, classification=False)\n",
1241 | "config:\n",
1242 | "{'atom_features': 'cgcnn',\n",
1243 | " 'batch_size': 2,\n",
1244 | " 'classification_threshold': None,\n",
1245 | " 'criterion': 'l1',\n",
1246 | " 'cutoff': 8.0,\n",
1247 | " 'dataset': 'user_data',\n",
1248 | " 'distributed': False,\n",
1249 | " 'epochs': 3,\n",
1250 | " 'filename': 'sample',\n",
1251 | " 'id_tag': 'jid',\n",
1252 | " 'keep_data_order': False,\n",
1253 | " 'learning_rate': 0.001,\n",
1254 | " 'log_tensorboard': False,\n",
1255 | " 'max_neighbors': 12,\n",
1256 | " 'model': {'alignn_layers': 4,\n",
1257 | " 'atom_input_features': 92,\n",
1258 | " 'atomwise_output_features': 3,\n",
1259 | " 'atomwise_weight': 0.05,\n",
1260 | " 'calculate_gradient': True,\n",
1261 | " 'classification': False,\n",
1262 | " 'edge_input_features': 80,\n",
1263 | " 'embedding_features': 64,\n",
1264 | " 'gcn_layers': 4,\n",
1265 | " 'grad_multiplier': -1,\n",
1266 | " 'gradwise_weight': 0.05,\n",
1267 | " 'graphwise_weight': 0.85,\n",
1268 | " 'hidden_features': 256,\n",
1269 | " 'link': 'identity',\n",
1270 | " 'name': 'alignn_atomwise',\n",
1271 | " 'output_features': 1,\n",
1272 | " 'stresswise_weight': 0.05,\n",
1273 | " 'triplet_input_features': 40,\n",
1274 | " 'zero_inflated': False},\n",
1275 | " 'n_early_stopping': None,\n",
1276 | " 'n_test': None,\n",
1277 | " 'n_train': None,\n",
1278 | " 'n_val': None,\n",
1279 | " 'neighbor_strategy': 'k-nearest',\n",
1280 | " 'normalize_graph_level_loss': False,\n",
1281 | " 'num_workers': 0,\n",
1282 | " 'optimizer': 'adamw',\n",
1283 | " 'output_dir': 'temp',\n",
1284 | " 'pin_memory': False,\n",
1285 | " 'progress': True,\n",
1286 | " 'random_seed': 123,\n",
1287 | " 'save_dataloader': False,\n",
1288 | " 'scheduler': 'onecycle',\n",
1289 | " 'standard_scalar_and_pca': False,\n",
1290 | " 'store_outputs': False,\n",
1291 | " 'target': 'target',\n",
1292 | " 'target_multiplication_factor': None,\n",
1293 | " 'test_ratio': 0.1,\n",
1294 | " 'train_ratio': 0.8,\n",
1295 | " 'use_canonize': False,\n",
1296 | " 'val_ratio': 0.1,\n",
1297 | " 'version': '112bbedebdaecf59fb18e11c929080fb2f358246',\n",
1298 | " 'warmup_steps': 2000,\n",
1299 | " 'weight_decay': 1e-05,\n",
1300 | " 'write_checkpoint': True,\n",
1301 | " 'write_predictions': True}\n",
1302 | "/usr/local/lib/python3.8/dist-packages/torch/optim/lr_scheduler.py:131: UserWarning: Detected call of `lr_scheduler.step()` before `optimizer.step()`. In PyTorch 1.1.0 and later, you should call them in the opposite order: `optimizer.step()` before `lr_scheduler.step()`. Failure to do this will result in PyTorch skipping the first value of the learning rate schedule. See more details at https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate\n",
1303 | " warnings.warn(\"Detected call of `lr_scheduler.step()` before `optimizer.step()`. \"\n",
1304 | "TrainLoss Epoch 0 total 130.12498193979263 out 7.480552700161934 atom 0.4205492946298318 grad 0.01733575374173365 stress 2.5176925698766857\n",
1305 | "Saving data for epoch: 0\n",
1306 | "ValLoss Epoch 0 total 0.9229794442653656 out 0.36248016357421875 atom 0.20344786350178765 grad 0.009460022839114873 stress 2.8547235106405164\n",
1307 | "TrainLoss Epoch 1 total 19.854396045207977 out 0.9548561573028564 atom 0.20680722037683308 grad 0.016850458171060415 stress 3.398183008581307\n",
1308 | "ValLoss Epoch 1 total 3.4113743901252747 out 1.8229036331176758 atom 0.25263089152722384 grad 0.008425975667720801 stress 2.863325897164085\n",
1309 | "TrainLoss Epoch 2 total 22.987701281905174 out 1.1715710639953614 atom 0.16559625868690392 grad 0.0169404602865175 stress 2.8884559912202747\n",
1310 | "ValLoss Epoch 2 total 1.0481610596179962 out 0.4372735023498535 atom 0.1825272994406987 grad 0.008752041482572772 stress 2.8566820792004335\n",
1311 | "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/loss.py:96: UserWarning: Using a target size (torch.Size([1])) that is different to the input size (torch.Size([])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.\n",
1312 | " return F.l1_loss(input, target, reduction=self.reduction)\n",
1313 | "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/loss.py:96: UserWarning: Using a target size (torch.Size([10, 3, 3])) that is different to the input size (torch.Size([3])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.\n",
1314 | " return F.l1_loss(input, target, reduction=self.reduction)\n",
1315 | "TestLoss 2 2.405210852622986\n"
1316 | ]
1317 | }
1318 | ]
1319 | },
1320 | {
1321 | "cell_type": "markdown",
1322 | "source": [
1323 | "The generated model is saved as best_model.pt"
1324 | ],
1325 | "metadata": {
1326 | "id": "M-6U-D-V40lQ"
1327 | }
1328 | },
1329 | {
1330 | "cell_type": "code",
1331 | "source": [
1332 | "!ls -altr temp"
1333 | ],
1334 | "metadata": {
1335 | "id": "qvgrKa9q4oBu",
1336 | "outputId": "8219734e-3a81-4a95-e85c-58cab593480b",
1337 | "colab": {
1338 | "base_uri": "https://localhost:8080/"
1339 | }
1340 | },
1341 | "execution_count": null,
1342 | "outputs": [
1343 | {
1344 | "output_type": "stream",
1345 | "name": "stdout",
1346 | "text": [
1347 | "total 110996\n",
1348 | "-rw-r--r-- 1 root root 48648565 Feb 7 17:07 checkpoint_2.pt\n",
1349 | "-rw-r--r-- 1 root root 48648565 Feb 7 17:07 checkpoint_3.pt\n",
1350 | "-rw-r--r-- 1 root root 244 Feb 7 17:07 prediction_results_test_set.csv\n",
1351 | "-rw-r--r-- 1 root root 222 Feb 7 17:07 prediction_results_train_set.csv\n",
1352 | "drwxr-xr-x 6 root root 4096 Feb 7 17:13 ..\n",
1353 | "-rw-r--r-- 1 root root 489 Feb 7 17:14 ids_train_val_test.json\n",
1354 | "-rw-r--r-- 1 root root 68 Feb 7 17:14 mad\n",
1355 | "-rw-r--r-- 1 root root 34 Feb 7 17:14 train_data_data_range\n",
1356 | "-rw-r--r-- 1 root root 34 Feb 7 17:14 val_data_data_range\n",
1357 | "-rw-r--r-- 1 root root 34 Feb 7 17:14 test_data_data_range\n",
1358 | "-rw-r--r-- 1 root root 1785 Feb 7 17:14 config.json\n",
1359 | "-rw-r--r-- 1 root root 16177149 Feb 7 17:15 best_model.pt\n",
1360 | "-rw-r--r-- 1 root root 98076 Feb 7 17:15 Train_results.json\n",
1361 | "-rw-r--r-- 1 root root 9479 Feb 7 17:15 Val_results.json\n",
1362 | "-rw-r--r-- 1 root root 249 Feb 7 17:17 history_train.json\n",
1363 | "-rw-r--r-- 1 root root 254 Feb 7 17:17 history_val.json\n",
1364 | "drwxr-xr-x 2 root root 4096 Feb 7 17:18 .\n",
1365 | "-rw-r--r-- 1 root root 14736 Feb 7 17:18 Test_results.json\n"
1366 | ]
1367 | }
1368 | ]
1369 | },
1370 | {
1371 | "cell_type": "markdown",
1372 | "source": [
1373 | "# Train a model for JARVIS-DFT 2D Exfoliation energy"
1374 | ],
1375 | "metadata": {
1376 | "id": "kC8yi5z7Sh3G"
1377 | }
1378 | },
1379 | {
1380 | "cell_type": "markdown",
1381 | "source": [
1382 | "There are quite a few datasets available here: https://jarvis-tools.readthedocs.io/en/master/databases.html\n",
1383 | "In the following example, we will use the exfoliation energy property from the JARVIS-DFT 3D (dft_3d) dataset"
1384 | ],
1385 | "metadata": {
1386 | "id": "01bLUnEF4anq"
1387 | }
1388 | },
1389 | {
1390 | "cell_type": "markdown",
1391 | "source": [
1392 | "Get data in id_prop.csv format"
1393 | ],
1394 | "metadata": {
1395 | "id": "6e0uz_WaUmbE"
1396 | }
1397 | },
1398 | {
1399 | "cell_type": "code",
1400 | "source": [
1401 | "from jarvis.db.figshare import data as jdata\n",
1402 | "from jarvis.core.atoms import Atoms\n",
1403 | "import os\n",
1404 | "\n",
1405 | "cwd = os.getcwd() #current working directory\n",
1406 | "temp_dir_name = \"DataDir_ExfoEnergy\" \n",
1407 | "os.makedirs(temp_dir_name)\n",
1408 | "os.chdir(temp_dir_name)\n",
1409 | "\n",
1410 | "dft_3d = jdata(\"dft_3d\")\n",
1411 | "prop = \"exfoliation_energy\" #\"optb88vdw_bandgap\"\n",
1412 | "f = open(\"id_prop.csv\", \"w\")\n",
1413 | "# count = 0\n",
1414 | "for i in dft_3d:\n",
1415 | " atoms = Atoms.from_dict(i[\"atoms\"])\n",
1416 | " jid = i[\"jid\"]\n",
1417 | " poscar_name = \"POSCAR-\" + jid + \".vasp\"\n",
1418 | " target = i[prop]\n",
1419 | " if target != \"na\":\n",
1420 | " atoms.write_poscar(poscar_name)\n",
1421 | " f.write(\"%s,%6f\\n\" % (poscar_name, target))\n",
1422 | " # count += 1\n",
1423 | " # if count == max_samples:\n",
1424 | " # break\n",
1425 | "f.close()\n",
1426 | "\n",
1427 | "os.chdir(cwd)"
1428 | ],
1429 | "metadata": {
1430 | "colab": {
1431 | "base_uri": "https://localhost:8080/"
1432 | },
1433 | "id": "apVVRqVgSdbG",
1434 | "outputId": "8064210b-0fcb-4902-a006-c937af745711"
1435 | },
1436 | "execution_count": null,
1437 | "outputs": [
1438 | {
1439 | "output_type": "stream",
1440 | "name": "stdout",
1441 | "text": [
1442 | "Obtaining 3D dataset 76k ...\n",
1443 | "Reference:https://www.nature.com/articles/s41524-020-00440-1\n",
1444 | "Other versions:https://doi.org/10.6084/m9.figshare.6815699\n"
1445 | ]
1446 | },
1447 | {
1448 | "output_type": "stream",
1449 | "name": "stderr",
1450 | "text": [
1451 | "100%|██████████| 40.8M/40.8M [00:02<00:00, 15.8MiB/s]\n"
1452 | ]
1453 | },
1454 | {
1455 | "output_type": "stream",
1456 | "name": "stdout",
1457 | "text": [
1458 | "Loading the zipfile...\n",
1459 | "Loading completed.\n"
1460 | ]
1461 | }
1462 | ]
1463 | },
1464 | {
1465 | "cell_type": "code",
1466 | "metadata": {
1467 | "id": "v62Vzv2_2M2s",
1468 | "colab": {
1469 | "base_uri": "https://localhost:8080/"
1470 | },
1471 | "outputId": "d658ae8a-9514-4f7c-96c4-b317142459e3"
1472 | },
1473 | "source": [
1474 | "!ls -altr DataDir_ExfoEnergy/*.vasp | wc -l\n"
1475 | ],
1476 | "execution_count": null,
1477 | "outputs": [
1478 | {
1479 | "output_type": "stream",
1480 | "name": "stdout",
1481 | "text": [
1482 | "813\n"
1483 | ]
1484 | }
1485 | ]
1486 | },
1487 | {
1488 | "cell_type": "code",
1489 | "source": [
1490 | " !wc -l DataDir_ExfoEnergy/id_prop.csv "
1491 | ],
1492 | "metadata": {
1493 | "colab": {
1494 | "base_uri": "https://localhost:8080/"
1495 | },
1496 | "id": "SQW9wmpsToBR",
1497 | "outputId": "d409c5a2-33c9-4c02-b644-2ebe31153315"
1498 | },
1499 | "execution_count": null,
1500 | "outputs": [
1501 | {
1502 | "output_type": "stream",
1503 | "name": "stdout",
1504 | "text": [
1505 | "813 DataDir_ExfoEnergy/id_prop.csv\n"
1506 | ]
1507 | }
1508 | ]
1509 | },
1510 | {
1511 | "cell_type": "code",
1512 | "source": [
1513 | "import time\n",
1514 | "t1=time.time()\n",
1515 | "!train_folder.py --root_dir \"DataDir_ExfoEnergy\" --epochs 1 --batch_size 64 --config \"alignn/examples/sample_data/config_example.json\" --output_dir=\"ExfoEnOut\"\n",
1516 | "t2=time.time()\n",
1517 | "print ('Time in s',t2-t1)"
1518 | ],
1519 | "metadata": {
1520 | "colab": {
1521 | "base_uri": "https://localhost:8080/"
1522 | },
1523 | "id": "E8aEhUTHT-AV",
1524 | "outputId": "874753d9-fc17-4283-b961-d22ff555fe25"
1525 | },
1526 | "execution_count": null,
1527 | "outputs": [
1528 | {
1529 | "output_type": "stream",
1530 | "name": "stdout",
1531 | "text": [
1532 | "MAX val: 948.93\n",
1533 | "MIN val: 0.03\n",
1534 | "MAD: 62.629814227293544\n",
1535 | "Baseline MAE: 61.033631528964854\n",
1536 | "data range 948.93 0.03\n",
1537 | "100% 650/650 [00:13<00:00, 48.88it/s]\n",
1538 | "warning: could not load CGCNN features for 103\n",
1539 | "Setting it to max atomic number available here, 103\n",
1540 | "warning: could not load CGCNN features for 101\n",
1541 | "Setting it to max atomic number available here, 103\n",
1542 | "warning: could not load CGCNN features for 102\n",
1543 | "Setting it to max atomic number available here, 103\n",
1544 | "building line graphs\n",
1545 | "100% 650/650 [00:00<00:00, 1636.83it/s]\n",
1546 | "data range 388.51 18.3\n",
1547 | "100% 81/81 [00:01<00:00, 47.49it/s]\n",
1548 | "building line graphs\n",
1549 | "100% 81/81 [00:00<00:00, 977.37it/s]\n",
1550 | "data range 903.94 0.95\n",
1551 | "100% 81/81 [00:01<00:00, 47.50it/s]\n",
1552 | "building line graphs\n",
1553 | "100% 81/81 [00:00<00:00, 1042.25it/s]\n",
1554 | "n_train: 650\n",
1555 | "n_val: 81\n",
1556 | "n_test: 81\n",
1557 | "version='112bbedebdaecf59fb18e11c929080fb2f358246' dataset='user_data' target='target' atom_features='cgcnn' neighbor_strategy='k-nearest' id_tag='jid' random_seed=123 classification_threshold=None n_val=None n_test=None n_train=None train_ratio=0.8 val_ratio=0.1 test_ratio=0.1 target_multiplication_factor=None epochs=1 batch_size=64 weight_decay=1e-05 learning_rate=0.001 filename='sample' warmup_steps=2000 criterion='mse' optimizer='adamw' scheduler='onecycle' pin_memory=False save_dataloader=False write_checkpoint=True write_predictions=True store_outputs=True progress=True log_tensorboard=False standard_scalar_and_pca=False use_canonize=True num_workers=0 cutoff=8.0 max_neighbors=12 keep_data_order=False distributed=False n_early_stopping=None output_dir='ExfoEnOut' model=ALIGNNConfig(name='alignn', alignn_layers=4, gcn_layers=4, atom_input_features=92, edge_input_features=80, triplet_input_features=40, embedding_features=64, hidden_features=256, output_features=1, link='identity', zero_inflated=False, classification=False)\n",
1558 | "config:\n",
1559 | "{'atom_features': 'cgcnn',\n",
1560 | " 'batch_size': 64,\n",
1561 | " 'classification_threshold': None,\n",
1562 | " 'criterion': 'mse',\n",
1563 | " 'cutoff': 8.0,\n",
1564 | " 'dataset': 'user_data',\n",
1565 | " 'distributed': False,\n",
1566 | " 'epochs': 1,\n",
1567 | " 'filename': 'sample',\n",
1568 | " 'id_tag': 'jid',\n",
1569 | " 'keep_data_order': False,\n",
1570 | " 'learning_rate': 0.001,\n",
1571 | " 'log_tensorboard': False,\n",
1572 | " 'max_neighbors': 12,\n",
1573 | " 'model': {'alignn_layers': 4,\n",
1574 | " 'atom_input_features': 92,\n",
1575 | " 'classification': False,\n",
1576 | " 'edge_input_features': 80,\n",
1577 | " 'embedding_features': 64,\n",
1578 | " 'gcn_layers': 4,\n",
1579 | " 'hidden_features': 256,\n",
1580 | " 'link': 'identity',\n",
1581 | " 'name': 'alignn',\n",
1582 | " 'output_features': 1,\n",
1583 | " 'triplet_input_features': 40,\n",
1584 | " 'zero_inflated': False},\n",
1585 | " 'n_early_stopping': None,\n",
1586 | " 'n_test': None,\n",
1587 | " 'n_train': None,\n",
1588 | " 'n_val': None,\n",
1589 | " 'neighbor_strategy': 'k-nearest',\n",
1590 | " 'num_workers': 0,\n",
1591 | " 'optimizer': 'adamw',\n",
1592 | " 'output_dir': 'ExfoEnOut',\n",
1593 | " 'pin_memory': False,\n",
1594 | " 'progress': True,\n",
1595 | " 'random_seed': 123,\n",
1596 | " 'save_dataloader': False,\n",
1597 | " 'scheduler': 'onecycle',\n",
1598 | " 'standard_scalar_and_pca': False,\n",
1599 | " 'store_outputs': True,\n",
1600 | " 'target': 'target',\n",
1601 | " 'target_multiplication_factor': None,\n",
1602 | " 'test_ratio': 0.1,\n",
1603 | " 'train_ratio': 0.8,\n",
1604 | " 'use_canonize': True,\n",
1605 | " 'val_ratio': 0.1,\n",
1606 | " 'version': '112bbedebdaecf59fb18e11c929080fb2f358246',\n",
1607 | " 'warmup_steps': 2000,\n",
1608 | " 'weight_decay': 1e-05,\n",
1609 | " 'write_checkpoint': True,\n",
1610 | " 'write_predictions': True}\n",
1611 | "/usr/local/lib/python3.7/dist-packages/torch/amp/autocast_mode.py:198: UserWarning: User provided device_type of 'cuda', but CUDA is not available. Disabling\n",
1612 | " warnings.warn('User provided device_type of \\'cuda\\', but CUDA is not available. Disabling')\n",
1613 | "Val_MAE: 109.1306\n",
1614 | "Train_MAE: 120.3734\n",
1615 | "Test MAE: 114.4529100369524\n",
1616 | "Time taken (s): 148.4665551185608\n",
1617 | "Time in s 170.2420151233673\n"
1618 | ]
1619 | }
1620 | ]
1621 | },
1622 | {
1623 | "cell_type": "code",
1624 | "source": [
1625 | "!ls ExfoEnOut\t "
1626 | ],
1627 | "metadata": {
1628 | "colab": {
1629 | "base_uri": "https://localhost:8080/"
1630 | },
1631 | "id": "luExy9FkWJ81",
1632 | "outputId": "47c9a412-2ba4-41b6-c980-ccf1d110baa5"
1633 | },
1634 | "execution_count": null,
1635 | "outputs": [
1636 | {
1637 | "output_type": "stream",
1638 | "name": "stdout",
1639 | "text": [
1640 | "checkpoint_1.pt ids_train_val_test.json\t test_data_data_range\n",
1641 | "config.json\t mad\t\t\t\t train_data_data_range\n",
1642 | "history_train.json prediction_results_test_set.csv val_data_data_range\n",
1643 | "history_val.json prediction_results_train_set.csv\n"
1644 | ]
1645 | }
1646 | ]
1647 | },
1648 | {
1649 | "cell_type": "markdown",
1650 | "source": [
1651 | "Here the checkpoints are model parameter files that can be loaded with the torch library to make predictions, as in [this example](https://github.com/usnistgov/alignn/blob/main/alignn/scripts/predict.py)."
1652 | ],
1653 | "metadata": {
1654 | "id": "zzrjEIMKpdFS"
1655 | }
1656 | },
1657 | {
1658 | "cell_type": "code",
1659 | "source": [
1660 | "import matplotlib.pyplot as plt\n",
1661 | "%matplotlib inline\n",
1662 | "import pandas as pd\n",
1663 | "df = pd.read_csv('/content/alignn/ExfoEnOut/prediction_results_test_set.csv')"
1664 | ],
1665 | "metadata": {
1666 | "id": "EhzUYeD1oYpD"
1667 | },
1668 | "execution_count": null,
1669 | "outputs": []
1670 | },
1671 | {
1672 | "cell_type": "markdown",
1673 | "source": [
1674 | "These are predictions on the 10% held-out test set that the model never saw during training"
1675 | ],
1676 | "metadata": {
1677 | "id": "Tz-AsQnNo9-2"
1678 | }
1679 | },
1680 | {
1681 | "cell_type": "code",
1682 | "source": [
1683 | "df"
1684 | ],
1685 | "metadata": {
1686 | "colab": {
1687 | "base_uri": "https://localhost:8080/",
1688 | "height": 423
1689 | },
1690 | "id": "hGGnsXseo7Qz",
1691 | "outputId": "143d30cf-ee2f-4a5b-9095-3c060a47d082"
1692 | },
1693 | "execution_count": null,
1694 | "outputs": [
1695 | {
1696 | "output_type": "execute_result",
1697 | "data": {
1698 | "text/plain": [
1699 | " id target prediction\n",
1700 | "0 POSCAR-JVASP-12918.vasp 27.170000 -6.187871\n",
1701 | "1 POSCAR-JVASP-2035.vasp 82.290001 -7.342460\n",
1702 | "2 POSCAR-JVASP-13942.vasp 87.809998 -7.392377\n",
1703 | "3 POSCAR-JVASP-278.vasp 144.320007 -7.004450\n",
1704 | "4 POSCAR-JVASP-10173.vasp 33.700001 -6.653560\n",
1705 | ".. ... ... ...\n",
1706 | "76 POSCAR-JVASP-4364.vasp 54.290001 -8.038940\n",
1707 | "77 POSCAR-JVASP-29480.vasp 78.639999 -8.734681\n",
1708 | "78 POSCAR-JVASP-28375.vasp 55.480000 -7.636129\n",
1709 | "79 POSCAR-JVASP-590.vasp 88.519997 -7.158464\n",
1710 | "80 POSCAR-JVASP-4741.vasp 226.220001 -5.179449\n",
1711 | "\n",
1712 | "[81 rows x 3 columns]"
1713 | ],
1714 | "text/html": [
1715 | "\n",
1716 | "
| \n", 1736 | " | id | \n", 1737 | "target | \n", 1738 | "prediction | \n", 1739 | "
|---|---|---|---|
| 0 | \n", 1744 | "POSCAR-JVASP-12918.vasp | \n", 1745 | "27.170000 | \n", 1746 | "-6.187871 | \n", 1747 | "
| 1 | \n", 1750 | "POSCAR-JVASP-2035.vasp | \n", 1751 | "82.290001 | \n", 1752 | "-7.342460 | \n", 1753 | "
| 2 | \n", 1756 | "POSCAR-JVASP-13942.vasp | \n", 1757 | "87.809998 | \n", 1758 | "-7.392377 | \n", 1759 | "
| 3 | \n", 1762 | "POSCAR-JVASP-278.vasp | \n", 1763 | "144.320007 | \n", 1764 | "-7.004450 | \n", 1765 | "
| 4 | \n", 1768 | "POSCAR-JVASP-10173.vasp | \n", 1769 | "33.700001 | \n", 1770 | "-6.653560 | \n", 1771 | "
| ... | \n", 1774 | "... | \n", 1775 | "... | \n", 1776 | "... | \n", 1777 | "
| 76 | \n", 1780 | "POSCAR-JVASP-4364.vasp | \n", 1781 | "54.290001 | \n", 1782 | "-8.038940 | \n", 1783 | "
| 77 | \n", 1786 | "POSCAR-JVASP-29480.vasp | \n", 1787 | "78.639999 | \n", 1788 | "-8.734681 | \n", 1789 | "
| 78 | \n", 1792 | "POSCAR-JVASP-28375.vasp | \n", 1793 | "55.480000 | \n", 1794 | "-7.636129 | \n", 1795 | "
| 79 | \n", 1798 | "POSCAR-JVASP-590.vasp | \n", 1799 | "88.519997 | \n", 1800 | "-7.158464 | \n", 1801 | "
| 80 | \n", 1804 | "POSCAR-JVASP-4741.vasp | \n", 1805 | "226.220001 | \n", 1806 | "-5.179449 | \n", 1807 | "
81 rows × 3 columns
\n", 1811 | "