├── tests ├── __init__.py ├── data │ └── matbench_phonons.json.gz ├── wren │ └── test_utils.py ├── test_package.py ├── conftest.py ├── test_core.py ├── test_print_metrics.py ├── test_utils.py └── test_wrenformer.py ├── aviary ├── wren │ ├── __init__.py │ └── utils.py ├── cgcnn │ └── __init__.py ├── roost │ ├── __init__.py │ ├── data.py │ └── model.py ├── wrenformer │ └── __init__.py ├── __init__.py ├── embeddings │ ├── wyckoff │ │ └── README.md │ └── element │ │ └── README.md ├── losses.py ├── scatter.py ├── networks.py ├── data.py └── segments.py ├── .github └── workflows │ ├── link-check-config.json │ ├── link-check.yml │ ├── publish.yml │ └── test.yml ├── citation.cff ├── examples ├── inputs │ ├── raw │ │ ├── 0fd347fb29c23cad30d42852d7aa4ed906102b8a.poscar │ │ ├── b3603cef0816d0e115adeafa4eee65c056dc7a26.poscar │ │ ├── b428ccf0f8cd5488a3af5d72b74e117dc71aa575.poscar │ │ ├── c07b988125b3a0adc329b2449eeaac24c9e7def8.poscar │ │ ├── 0c9cf664171d7e10527033cda1551694141207ed.poscar │ │ ├── 0c5ef0e6b20d6e34ae388d640797f032a15df918.poscar │ │ ├── 006047d5cfe39157f2904d2533bad74b2ddd3170.poscar │ │ ├── 0a2c3a235656a97d76f731b269982f92de9b37fd.poscar │ │ ├── 0abb6669aa25624d8371765c5afa6b80ae20f30b.poscar │ │ ├── 006a8e78e4bb5c1570bf70b47b24c0b855a5b5b5.poscar │ │ ├── 0a2e1f340255f62a617974d778714e5fdd28a007.poscar │ │ ├── 000fe0f2a4730347df688bd0d3c75f4494a28dca.poscar │ │ ├── 00af9b9350147b8c20498eb9cd65b63a3132c741.poscar │ │ ├── 6f7d6a970fde2a4243091aba720fc14ceec2854b.poscar │ │ ├── 0c7147c05a270b9b2744576b0bfb69c8391ee1e4.poscar │ │ ├── 0060794138df3ba93c68007949b0a72d19cb4c62.poscar │ │ ├── 0c2fc1da358bcddfb4baf41b5dfd7285a13b766a.poscar │ │ ├── 0cd142efcf69ba63dc62760ecd72fb4f1e930115.poscar │ │ ├── 004c70f544dca7f8e1c74c6c4ff87a1e355e7cd5.poscar │ │ ├── 0c5e41fdc8fba8e90b48cbe4752dc2c4a2805f21.poscar │ │ ├── 0a74297fb11027ba2a958da5ee31980b8876e1d4.poscar │ │ ├── 0ac7f92ee7225e47dbdc25f6d60a10b116b0813b.poscar │ │ ├── 
0c13902f8200732555ad1603e954a726d4d24d85.poscar │ │ ├── 0c9c74b6f48facd507c4e41d9df797a2017634e7.poscar │ │ ├── 0ad7d4c22ba4414fbe161f27f42c112c6cb83b29.poscar │ │ ├── 0c710ee221d7567cc0f1a06fba8231ae8261a181.poscar │ │ ├── 0ac3201405a380e382377a71216b33cde892b161.poscar │ │ ├── 0c27147781fc9c64cbcf45acc1575a0bf8070841.poscar │ │ ├── 0c5264c980964b88ae8a7c2c18d5a61daff481b1.poscar │ │ ├── 0a58ef8a37185d8abcbd6e1bad7eaf7e3d3d6016.poscar │ │ ├── 0c9cbed1e5d755a9666db36f742004c14bf780bd.poscar │ │ ├── 0079b9e3ab0e210ea4afe037a9faee8cd6577922.poscar │ │ ├── 00bbba1d86168727008d4c1542b39d45d93d693c.poscar │ │ ├── 0a3eb2074380f4abdec94d4a6171a099c9ecb30a.poscar │ │ ├── 0a646e174268b493bd67f58fdf52fead02f74a96.poscar │ │ ├── 00ceac1a8da8e554218a84def07e5a6da9f106f3.poscar │ │ ├── 0a5af69c21e6c2d9b004c2f51939790b10b19814.poscar │ │ ├── 0a0ae9c18de23e72bb764e10ef542de693321542.poscar │ │ ├── 0c450e6f9c256e33ef452c98dd20917fb6484317.poscar │ │ ├── 0059fd6dd1d17138c42bb5430fa0bc074053f0de.poscar │ │ ├── 00fa7d75b7b2e1442df0a68071210c8d64b93ebe.poscar │ │ ├── 0cd1f1670176b03cc3d61935d9a6882c9508e33a.poscar │ │ ├── 00237cf70206a5de77e5154ea72633514cdd4445.poscar │ │ ├── 0c882a0621189d18ff1350e3bed4114868663acf.poscar │ │ ├── 0ad8980cb8eb0693647f411732b2165b6a0fce0e.poscar │ │ ├── 009ee0451531276f598976f46848fe7d590262ca.poscar │ │ ├── 0c5ffa8e2b3126eba02901c37b31e8faf3b5377c.poscar │ │ ├── 0a458b5f2eeebeb0e4e55e1d1a17988ec87b7123.poscar │ │ ├── 0ac8a65af75e8223b07dcd98c6958db7af6504cf.poscar │ │ ├── 0a130b596d1ce7bfe317d9e546bbeacc703fef65.poscar │ │ ├── 0a4a721fb8f0996350edd4fb224d89ce213f7f93.poscar │ │ ├── 0c429f98b79c9522a11d51fbfae484004f899422.poscar │ │ ├── 001a938f944fa33e899b81e3d0a909876587e7b0.poscar │ │ ├── 0039f3c461ce4a2cbf4c48b09bfc3c6b13a5c769.poscar │ │ ├── 0abac62147275aa8199465764751eaa67c94d16c.poscar │ │ ├── 000866a75fbc0a305808c11c58342bb973c3cad5.poscar │ │ ├── 00a8d0f6f55191596770605314ff347995a3540d.poscar │ │ ├── 
0ab8ebf80d94a1ba826857a9e81cea8c7124d86e.poscar │ │ ├── 0019407b910a29730999dab65491b4e50624c219.poscar │ │ ├── 0c0b4fd84c8f4e075436d7ec34bfb3fad5e610be.poscar │ │ ├── 00a3259f00f8a9aa503caae2045970d2f7b22e13.poscar │ │ ├── 0a2c4708d3c853ea3c4b8630027afe9b788ee2ce.poscar │ │ ├── 00cca7ad273875ba80f1d437af797b6a01cc6ebd.poscar │ │ ├── 0a946dc02260d0d4b36ef2087e60b415a5fff800.poscar │ │ ├── 0ab8cc7e5cc3d48fcc786f9060da5a8882374f44.poscar │ │ ├── 00bada6787af6463447a1460da0c031d51494b60.poscar │ │ ├── 0adeaf7a9c16772b66815cce19917ef0460affe8.poscar │ │ ├── 00ecadb66a8baab030675a817d2ee15f0002a4dc.poscar │ │ ├── 00021364446a637881257fd9ee912a422a6b1753.poscar │ │ ├── 006197ce31efdef23c82f0f6cec08ac0b0febdce.poscar │ │ ├── 0069751a4ba575438917a8b4e5a565b686388f6e.poscar │ │ ├── 0ad3ae37a9fb7eab9189ff1cb384b9d2dd26bbf9.poscar │ │ ├── 0a10bbf9e7cd3d8da2469f15e06ceca7d8d2af5d.poscar │ │ └── 0ad7c1ba7ec6a37cae0523bbd1c2a75f2d56a3f2.poscar │ ├── poscar_to_df.py │ └── examples.csv └── matbench_example │ ├── prepare_matbench_datasets.py │ ├── readme.md │ ├── train_wrenformer.py │ └── make_plots.py ├── .gitignore ├── LICENSE ├── CONTRIBUTING.md ├── .pre-commit-config.yaml ├── pyproject.toml └── README.md /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /aviary/wren/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /aviary/cgcnn/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /aviary/roost/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- 
/aviary/wrenformer/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/data/matbench_phonons.json.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompRhys/aviary/HEAD/tests/data/matbench_phonons.json.gz -------------------------------------------------------------------------------- /.github/workflows/link-check-config.json: -------------------------------------------------------------------------------- 1 | { 2 | "aliveStatusCodes": [ 3 | 200, 4 | 403, 5 | 503 6 | ] 7 | } 8 | -------------------------------------------------------------------------------- /aviary/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib.metadata 2 | from os.path import abspath, dirname 3 | 4 | __version__ = importlib.metadata.version("aviary-models") 5 | PKG_DIR = dirname(abspath(__file__)) 6 | ROOT = dirname(PKG_DIR) 7 | -------------------------------------------------------------------------------- /aviary/wren/utils.py: -------------------------------------------------------------------------------- 1 | def __getattr__(name): 2 | raise ImportError( 3 | "The functionality from aviary.wren.utils has been moved to pymatgen. " 4 | "Please install pymatgen using 'pip install pymatgen>2025.3.10' to " 5 | "access these features." 
6 | ) 7 | -------------------------------------------------------------------------------- /tests/wren/test_utils.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | 4 | def test_utils_import_error(): 5 | with pytest.raises(ImportError) as exc_info: 6 | from aviary.wren.utils import relab_dict # noqa: F401 7 | 8 | assert "functionality from aviary.wren.utils has been moved to pymatgen" in str( 9 | exc_info.value 10 | ) 11 | -------------------------------------------------------------------------------- /aviary/embeddings/wyckoff/README.md: -------------------------------------------------------------------------------- 1 | # Wyckoff Position Embeddings 2 | 3 | ## bra-alg-off 4 | 5 | A 6 + 5 + 185 + 248 = 444 dimensional embedding encoding 6 | 7 | - the crystal system 8 | - unit cell centering 9 | - the sum of whether sites in a wyckoff position sit on lines, planes or in volumes 10 | - the specific offset of that wyckoff position within the unit cell 11 | -------------------------------------------------------------------------------- /citation.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.2.0 2 | message: "If you use this software, please cite it as below." 
3 | authors: 4 | - given-names: Rhys 5 | family-names: Goodall 6 | orcid: https://orcid.org/0000-0002-6589-1700 7 | - given-names: Janosh 8 | family-names: Riebesell 9 | orcid: https://orcid.org/0000-0001-5233-3462 10 | title: "aviary" 11 | doi: 10.5281/zenodo.1234 12 | date-released: 2024-11-09 13 | url: "https://github.com/CompRhys/aviary" 14 | -------------------------------------------------------------------------------- /.github/workflows/link-check.yml: -------------------------------------------------------------------------------- 1 | name: Link check 2 | 3 | on: 4 | push: 5 | branches: [main] 6 | pull_request: 7 | branches: [main] 8 | 9 | jobs: 10 | markdown-link-check: 11 | runs-on: ubuntu-latest 12 | steps: 13 | - name: Check out repo 14 | uses: actions/checkout@v4 15 | 16 | - name: Run markdown link check 17 | uses: gaurav-nelson/github-action-markdown-link-check@v1 18 | # docs at https://git.io/JBaKu 19 | with: 20 | config-file: .github/workflows/link-check-config.json 21 | -------------------------------------------------------------------------------- /examples/inputs/raw/0fd347fb29c23cad30d42852d7aa4ed906102b8a.poscar: -------------------------------------------------------------------------------- 1 | -.17043474E+02 '002/ht.task.gamma_devel--default.0fd347fb29c23cad30d42852d7aa4ed906102b8a.cleanup.0.unclaimed.1.finished/ht.run.2015-11-30_09.47.37' ['(Tag) comment: Generated by cif2cell 1.0.12 from ICSD reference: 426988. 
Zr : Kurt Lejaeghere et al., Critical Reviews in Solid State and Materials Sciences 39, 1-24 (2014).'] 2 | 1 3 | 2.800434 -1.616831 0.000000 4 | 0.000000 3.233663 0.000000 5 | 0.000000 0.000000 5.166609 6 | Zr 7 | 2 8 | Direct 9 | 0.33333333 0.66666667 0.75000000 10 | 0.66666667 0.33333333 0.25000000 11 | -------------------------------------------------------------------------------- /examples/inputs/raw/b3603cef0816d0e115adeafa4eee65c056dc7a26.poscar: -------------------------------------------------------------------------------- 1 | -.21842114E+01 '002/ht.task.gamma_devel--default.b3603cef0816d0e115adeafa4eee65c056dc7a26.cleanup.0.unclaimed.1.finished/ht.run.2015-11-24_17.17.33' ['(Tag) comment: Generated by cif2cell 1.0.12 from ICSD reference: 426987. Zn : Kurt Lejaeghere et al., Critical Reviews in Solid State and Materials Sciences 39, 1-24 (2014).', '(Tag) type: pure'] 2 | 1 3 | 2.443685 -1.410862 0.000000 4 | 0.000000 2.821725 -0.000000 5 | 0.000000 -0.000000 4.305149 6 | Zn 7 | 2 8 | Direct 9 | 0.33333333 0.66666667 0.75000000 10 | 0.66666667 0.33333333 0.25000000 11 | -------------------------------------------------------------------------------- /examples/inputs/raw/b428ccf0f8cd5488a3af5d72b74e117dc71aa575.poscar: -------------------------------------------------------------------------------- 1 | -.15606022E+02 '002/ht.task.gamma_devel--default.b428ccf0f8cd5488a3af5d72b74e117dc71aa575.cleanup.0.unclaimed.1.finished/ht.run.2015-11-24_17.24.06' ['(Tag) comment: Generated by cif2cell 1.0.12 from ICSD reference: 426981. 
Ti : Kurt Lejaeghere et al., Critical Reviews in Solid State and Materials Sciences 39, 1-24 (2014).', '(Tag) type: pure'] 2 | 1 3 | 2.540600 -1.466816 0.000000 4 | 0.000000 2.933632 0.000000 5 | 0.000000 0.000000 4.657195 6 | Ti 7 | 2 8 | Direct 9 | 0.33333333 0.66666667 0.75000000 10 | 0.66666667 0.33333333 0.25000000 11 | -------------------------------------------------------------------------------- /examples/inputs/raw/c07b988125b3a0adc329b2449eeaac24c9e7def8.poscar: -------------------------------------------------------------------------------- 1 | -.19847840E+02 '002/ht.task.triolith--default.c07b988125b3a0adc329b2449eeaac24c9e7def8.cleanup.0.unclaimed.1.finished/ht.run.2015-12-10_10.57.55' ['(Tag) comment: Generated by cif2cell 1.0.12 from ICSD reference: 426944. Hf : Kurt Lejaeghere et al., Critical Reviews in Solid State and Materials Sciences 39, 1-24 (2014).', '(Tag) type: pure'] 2 | 1 3 | 2.771359 -1.600045 0.000000 4 | 0.000000 3.200090 0.000000 5 | 0.000000 0.000000 5.055614 6 | Hf 7 | 2 8 | Direct 9 | 0.33333333 0.66666667 0.75000000 10 | 0.66666667 0.33333333 0.25000000 11 | -------------------------------------------------------------------------------- /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | name: Publish to PyPI 2 | 3 | on: 4 | release: 5 | types: [published] 6 | 7 | jobs: 8 | tests: 9 | uses: ./.github/workflows/test.yml 10 | 11 | release: 12 | runs-on: ubuntu-latest 13 | needs: tests 14 | if: needs.tests.result == 'success' 15 | permissions: 16 | id-token: write 17 | 18 | steps: 19 | - name: Check out repo 20 | uses: actions/checkout@v4 21 | 22 | - name: Setup uv 23 | uses: astral-sh/setup-uv@v6 24 | 25 | - name: Build package with uv 26 | run: uv build 27 | 28 | - name: Publish package distributions to PyPI with uv 29 | run: uv publish -t ${{ secrets.PYPI_TOKEN }} 30 | -------------------------------------------------------------------------------- 
/examples/inputs/raw/0c9cf664171d7e10527033cda1551694141207ed.poscar: -------------------------------------------------------------------------------- 1 | -.27483177E+02 '004/ht.task.triolith--default.0c9cf664171d7e10527033cda1551694141207ed.cleanup.0.unclaimed.2.finished/ht.run.2015-11-21_17.02.04' ['(Tag) original structure: AuKO2', '(Tag) type: ternary', '(Tag) original structure path: structures/all/matches/ternaries/47/3_47_1aA1fB1nC.cif'] 2 | 1 3 | 2.999577 0.000000 0.000000 4 | 0.000000 5.448017 0.000000 5 | 0.000000 0.000000 2.860144 6 | Zr Zn N 7 | 1 1 2 8 | Direct 9 | 0.50000000 0.50000000 0.00000000 10 | 0.00000000 0.00000000 0.00000000 11 | 0.00000000 0.28297953 0.50000000 12 | 0.00000000 0.71702047 0.50000000 13 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # cache & build 2 | __pycache__/ 3 | *.egg-info 4 | .mypy_cache/ 5 | build/ 6 | 7 | # vscode 8 | .vscode/ 9 | 10 | # notebooks 11 | .ipynb_checkpoints/ 12 | 13 | # TensorBoard logs, model predictions and checkpoints 14 | runs/ 15 | results/ 16 | models/ 17 | 18 | # exclude data files and cached graphs 19 | *.pkl 20 | *.pdf 21 | *.json.gz 22 | *.json.bz2 23 | 24 | # HPC 25 | slurm-*.out 26 | plots/ 27 | datasets/ 28 | pds/ 29 | manuscript/ 30 | voro-thesis/ 31 | 32 | # run artifacts like model preds, checkpoints, metrics and slurm job logs 33 | examples/**/model_preds/ 34 | examples/**/model_scores/ 35 | examples/**/job-logs/ 36 | examples/**/artifacts/ 37 | examples/**/*.csv 38 | wandb/ 39 | 40 | # profiling 41 | *.prof 42 | -------------------------------------------------------------------------------- /examples/inputs/raw/0c5ef0e6b20d6e34ae388d640797f032a15df918.poscar: -------------------------------------------------------------------------------- 1 | -.30885650E+02 
'005/ht.task.triolith--default.0c5ef0e6b20d6e34ae388d640797f032a15df918.cleanup.0.unclaimed.2.finished/ht.run.2015-11-24_15.09.03' ['(Tag) original structure: AuCl3Cs', '(Tag) type: ternary', '(Tag) original structure path: structures/all/matches/ternaries/215/3_215_1aA1bB1cC.cif'] 2 | 1 3 | 3.779652 0.000000 0.000000 4 | 0.000000 3.779652 0.000000 5 | 0.000000 0.000000 3.779652 6 | Zr Zn N 7 | 1 1 3 8 | Direct 9 | 0.00000000 0.00000000 0.00000000 10 | 0.50000000 0.50000000 0.50000000 11 | 0.00000000 0.50000000 0.50000000 12 | 0.50000000 0.00000000 0.50000000 13 | 0.50000000 0.50000000 0.00000000 14 | -------------------------------------------------------------------------------- /examples/inputs/raw/006047d5cfe39157f2904d2533bad74b2ddd3170.poscar: -------------------------------------------------------------------------------- 1 | -.40933125E+02 '006/ht.task.triolith--default.006047d5cfe39157f2904d2533bad74b2ddd3170.cleanup.0.unclaimed.2.finished/ht.run.2015-12-10_22.22.59' ['(Tag) original structure: CsCuO', '(Tag) original structure path: structures/all/matches/ternaries/63/3_63_1aA1cB1cC.cif', '(Tag) type: ternary'] 2 | 1 3 | 1.749531 -3.048103 0.000000 4 | 1.749531 3.048103 -0.000000 5 | 0.000000 -0.000000 8.634492 6 | Zn Hf N 7 | 2 2 2 8 | Direct 9 | 0.66982580 0.33017420 0.25000000 10 | 0.33017420 0.66982580 0.75000000 11 | -0.00000000 0.00000000 -0.00000000 12 | -0.00000000 0.00000000 0.50000000 13 | 0.00158956 0.99841044 0.25000000 14 | 0.99841044 0.00158956 0.75000000 15 | -------------------------------------------------------------------------------- /examples/inputs/raw/0a2c3a235656a97d76f731b269982f92de9b37fd.poscar: -------------------------------------------------------------------------------- 1 | -.38204233E+02 '006/ht.task.gamma--default.0a2c3a235656a97d76f731b269982f92de9b37fd.cleanup.0.unclaimed.3.finished/ht.run.2015-11-08_14.53.55' ['(Tag) original structure: GaIY', '(Tag) original structure path: 
structures/all/matches/ternaries/164/3_164_1cA1dB1dC.cif', '(Tag) type: ternary'] 2 | 1 3 | 2.730114 -1.576232 -0.000000 4 | 0.000000 3.152464 0.000000 5 | -0.000000 0.000000 11.873833 6 | Ti N Zn 7 | 2 2 2 8 | Direct 9 | 0.00000000 -0.00000000 0.28482879 10 | 0.00000000 -0.00000000 0.71517121 11 | 0.33333333 0.66666667 0.33039369 12 | 0.66666667 0.33333333 0.66960631 13 | 0.66666667 0.33333333 0.08508663 14 | 0.33333333 0.66666667 0.91491337 15 | -------------------------------------------------------------------------------- /examples/inputs/raw/0abb6669aa25624d8371765c5afa6b80ae20f30b.poscar: -------------------------------------------------------------------------------- 1 | -.46777870E+02 '006/ht.task.gamma--default.0abb6669aa25624d8371765c5afa6b80ae20f30b.cleanup.0.unclaimed.3.finished/ht.run.2015-11-08_01.48.20' ['(Tag) original structure: CoO3Sr2', '(Tag) original structure path: structures/all/matches/ternaries/12/4_12_1cA1dB1iB1iC.cif', '(Tag) type: ternary'] 2 | 1 3 | 2.082063 -2.081320 0.050838 4 | 2.082063 2.081320 0.050838 5 | -1.049178 0.000000 9.996562 6 | Ti Zn N 7 | 2 1 3 8 | Direct 9 | 0.46125530 0.46125530 0.29720372 10 | 0.53874470 0.53874470 0.70279628 11 | 0.00000000 0.00000000 0.50000000 12 | 0.50000000 0.50000000 0.50000000 13 | 0.95590632 0.95590632 0.27067391 14 | 0.04409368 0.04409368 0.72932609 15 | -------------------------------------------------------------------------------- /examples/inputs/raw/006a8e78e4bb5c1570bf70b47b24c0b855a5b5b5.poscar: -------------------------------------------------------------------------------- 1 | -.23913820E+02 '006/ht.task.triolith--default.006a8e78e4bb5c1570bf70b47b24c0b855a5b5b5.cleanup.0.unclaimed.2.finished/ht.run.2015-12-10_19.42.57' ['(Tag) original structure: Br2Ca3Si', '(Tag) original structure path: structures/all/matches/ternaries/139/4_139_1aA1bB1eC1eA.cif', '(Tag) type: ternary'] 2 | 1 3 | 3.768322 0.000000 -0.000000 4 | 0.000000 3.768322 -0.000000 5 | 1.884161 1.884161 6.517125 6 | Zn Hf 
N 7 | 3 1 2 8 | Direct 9 | 0.00000000 0.00000000 0.00000000 10 | 0.70751047 0.70751047 0.58497906 11 | 0.29248953 0.29248953 0.41502094 12 | 0.50000000 0.50000000 0.00000000 13 | 0.14233929 0.14233929 0.71532143 14 | 0.85766071 0.85766071 0.28467857 15 | -------------------------------------------------------------------------------- /examples/inputs/raw/0a2e1f340255f62a617974d778714e5fdd28a007.poscar: -------------------------------------------------------------------------------- 1 | -.35119632E+02 '008/ht.task.gamma--default.0a2e1f340255f62a617974d778714e5fdd28a007.cleanup.0.unclaimed.3.finished/ht.run.2015-11-08_00.54.14' ['(Tag) original structure: CsN2Nb', '(Tag) original structure path: structures/all/matches/ternaries/227/3_227_1aA1bB1cC.cif', '(Tag) type: ternary'] 2 | 1 3 | -4.226700 -4.226700 0.000000 4 | -4.226700 0.000000 -4.226700 5 | 0.000000 -4.226700 -4.226700 6 | Zn Ti N 7 | 2 2 4 8 | Direct 9 | 0.87500000 0.87500000 0.87500000 10 | 0.12500000 0.12500000 0.12500000 11 | 0.62500000 0.62500000 0.62500000 12 | 0.37500000 0.37500000 0.37500000 13 | 0.00000000 0.00000000 0.00000000 14 | 0.00000000 0.50000000 0.00000000 15 | 0.50000000 0.00000000 0.00000000 16 | 0.00000000 0.00000000 0.50000000 17 | -------------------------------------------------------------------------------- /examples/inputs/raw/000fe0f2a4730347df688bd0d3c75f4494a28dca.poscar: -------------------------------------------------------------------------------- 1 | -.60754153E+02 '008/ht.task.triolith--default.000fe0f2a4730347df688bd0d3c75f4494a28dca.cleanup.0.unclaimed.2.finished/ht.run.2016-04-28_16.45.35' ['(Tag) original structure: CsSe2Yb', '(Tag) original structure path: structures/all/matches/ternaries/194/3_194_1aA1cB1fC.cif', '(Tag) type: ternary'] 2 | 1 3 | 2.724122 -1.572773 0.000000 4 | 0.000000 3.145546 0.000000 5 | 0.000000 0.000000 10.570134 6 | Hf Zn N 7 | 2 2 4 8 | Direct 9 | 0.33333333 0.66666667 0.25000000 10 | 0.66666667 0.33333333 0.75000000 11 | 0.00000000 
0.00000000 0.00000000 12 | 0.00000000 0.00000000 0.50000000 13 | 0.33333333 0.66666667 0.62866178 14 | 0.66666667 0.33333333 0.12866178 15 | 0.33333333 0.66666667 0.87133822 16 | 0.66666667 0.33333333 0.37133822 17 | -------------------------------------------------------------------------------- /examples/inputs/raw/00af9b9350147b8c20498eb9cd65b63a3132c741.poscar: -------------------------------------------------------------------------------- 1 | -.57885965E+02 '008/ht.task.triolith--default.00af9b9350147b8c20498eb9cd65b63a3132c741.cleanup.0.unclaimed.2.finished/ht.run.2016-04-28_17.15.45' ['(Tag) original structure: FeSe2Tl', '(Tag) original structure path: structures/all/matches/ternaries/12/4_12_1gA1iB2iC.cif', '(Tag) type: ternary'] 2 | 1 3 | 5.845331 -2.256375 1.002936 4 | 5.845331 2.256375 1.002936 5 | -2.359010 0.000000 3.498230 6 | Hf Zn N 7 | 2 2 4 8 | Direct 9 | 0.19931503 0.19931503 0.63601203 10 | 0.80068497 0.80068497 0.36398797 11 | 0.74375205 0.25624795 0.00000000 12 | 0.25624795 0.74375205 0.00000000 13 | 0.63229337 0.63229337 0.41424153 14 | 0.36770663 0.36770663 0.58575847 15 | 0.15059160 0.15059160 0.05184831 16 | 0.84940840 0.84940840 0.94815169 17 | -------------------------------------------------------------------------------- /tests/test_package.py: -------------------------------------------------------------------------------- 1 | import os 2 | from glob import glob 3 | 4 | import pytest 5 | 6 | from aviary import ROOT 7 | 8 | package_sources_path = f"{ROOT}/aviary.egg-info/SOURCES.txt" 9 | 10 | 11 | @pytest.mark.skipif( 12 | not os.path.isfile(package_sources_path), 13 | reason="No aviary.egg-info/SOURCES.txt file, run pip install . to create it", 14 | ) 15 | def test_egg_sources(): 16 | """Check we're correctly packaging all JSON files under aviary/ to prevent issues 17 | like https://github.com/CompRhys/aviary/pull/45. 18 | 19 | This test can fail due to outdated SOURCES.txt. Try `pip install -e .` to update. 
20 | """ 21 | with open(package_sources_path) as file: 22 | sources = file.read() 23 | 24 | for filepath in glob(f"{ROOT}/aviary/**/*.json", recursive=True): 25 | rel_path = filepath.split(f"{ROOT}/aviary/")[1] 26 | assert rel_path in sources 27 | -------------------------------------------------------------------------------- /examples/inputs/raw/6f7d6a970fde2a4243091aba720fc14ceec2854b.poscar: -------------------------------------------------------------------------------- 1 | -.66733607E+02 '008/ht.task.triolith--default.6f7d6a970fde2a4243091aba720fc14ceec2854b.cleanup.0.unclaimed.1.finished/ht.run.2015-12-10_11.08.37' ['(Tag) comment: Generated by cif2cell 1.0.12 from ICSD reference: 426956. N : Kurt Lejaeghere et al., Critical Reviews in Solid State and Materials Sciences 39, 1-24 (2014).', '(Tag) type: pure'] 2 | 1 3 | 5.946714 -0.000000 -0.000000 4 | 0.000000 5.946714 -0.000000 5 | 0.000000 -0.000000 5.946714 6 | N 7 | 8 8 | Direct 9 | 0.55402836 0.55402836 0.55402836 10 | 0.05402836 0.94597164 0.44597164 11 | 0.55402836 0.94597164 0.05402836 12 | 0.94597164 0.44597164 0.05402836 13 | 0.94597164 0.05402836 0.55402836 14 | 0.44597164 0.05402836 0.94597164 15 | 0.05402836 0.55402836 0.94597164 16 | 0.44597164 0.44597164 0.44597164 17 | -------------------------------------------------------------------------------- /examples/inputs/raw/0c7147c05a270b9b2744576b0bfb69c8391ee1e4.poscar: -------------------------------------------------------------------------------- 1 | -.56738203E+02 '009/ht.task.triolith--default.0c7147c05a270b9b2744576b0bfb69c8391ee1e4.cleanup.0.unclaimed.2.finished/ht.run.2015-11-25_10.02.52' ['(Tag) original structure: CdO6Ti2', '(Tag) type: ternary', '(Tag) original structure path: structures/all/matches/ternaries/162/3_162_1aA1dB1kC.cif'] 2 | 1 3 | 4.813812 -2.779256 0.000000 4 | -0.000000 5.558512 0.000000 5 | 0.000000 -0.000000 3.745624 6 | Zr Zn N 7 | 1 2 6 8 | Direct 9 | 0.00000000 0.00000000 0.00000000 10 | 0.33333333 0.66666667 
0.50000000 11 | 0.66666667 0.33333333 0.50000000 12 | 0.41946521 0.00000000 0.12157999 13 | 0.58053479 1.00000000 0.87842001 14 | 0.00000000 0.41946521 0.12157999 15 | 0.58053479 0.58053479 0.12157999 16 | 0.41946521 0.41946521 0.87842001 17 | 0.00000000 0.58053479 0.87842001 18 | -------------------------------------------------------------------------------- /examples/inputs/raw/0060794138df3ba93c68007949b0a72d19cb4c62.poscar: -------------------------------------------------------------------------------- 1 | -.56928156E+02 '010/ht.task.triolith--default.0060794138df3ba93c68007949b0a72d19cb4c62.cleanup.0.unclaimed.2.finished/ht.run.2016-04-28_01.39.04' ['(Tag) original structure: C3O6Y', '(Tag) original structure path: structures/all/matches/ternaries/160/4_160_1aA1bB2bC.cif', '(Tag) type: ternary'] 2 | 1 3 | 6.062902 -0.000000 1.294977 4 | -3.031451 5.250627 1.294978 5 | -3.031451 -5.250627 1.294978 6 | Hf N Zn 7 | 1 6 3 8 | Direct 9 | 0.07941259 0.07941259 0.07941259 10 | 0.97170032 0.35461423 0.97170032 11 | 0.97170032 0.97170032 0.35461423 12 | 0.35461423 0.97170033 0.97170033 13 | 0.62831402 0.48486499 0.62831402 14 | 0.62831402 0.62831402 0.48486499 15 | 0.48486498 0.62831401 0.62831401 16 | 0.43962517 0.92332917 0.43962517 17 | 0.43962517 0.43962517 0.92332917 18 | 0.92332917 0.43962517 0.43962517 19 | -------------------------------------------------------------------------------- /examples/inputs/raw/0c2fc1da358bcddfb4baf41b5dfd7285a13b766a.poscar: -------------------------------------------------------------------------------- 1 | -.66952114E+02 '010/ht.task.triolith--default.0c2fc1da358bcddfb4baf41b5dfd7285a13b766a.cleanup.0.unclaimed.2.finished/ht.run.2015-11-27_01.25.28' ['(Tag) original structure: CoLaO3', '(Tag) type: ternary', '(Tag) original structure path: structures/all/matches/ternaries/15/4_15_1cA1eB1eC1fC.cif'] 2 | 1 3 | 4.971455 0.000000 -0.378972 4 | 0.000000 5.742631 -0.000000 5 | 2.166218 2.871315 3.580151 6 | Zr Zn N 7 | 2 2 6 8 | 
Direct 9 | 0.25000000 0.25874194 0.00000000 10 | 0.75000000 0.74125806 1.00000000 11 | 0.50000000 0.00000000 0.50000000 12 | 0.00000000 0.50000000 0.50000000 13 | 0.25000000 0.63406163 1.00000000 14 | 0.75000000 0.36593837 0.00000000 15 | 0.87451404 0.86599373 0.39235106 16 | 0.12548596 0.13400627 0.60764894 17 | 0.62548596 0.25834479 0.60764894 18 | 0.37451404 0.74165521 0.39235106 19 | -------------------------------------------------------------------------------- /examples/inputs/raw/0cd142efcf69ba63dc62760ecd72fb4f1e930115.poscar: -------------------------------------------------------------------------------- 1 | -.57088230E+02 '010/ht.task.triolith--default.0cd142efcf69ba63dc62760ecd72fb4f1e930115.cleanup.0.unclaimed.2.finished/ht.run.2015-11-24_16.16.25' ['(Tag) original structure: Ca2FeN2', '(Tag) type: ternary', '(Tag) original structure path: structures/all/matches/ternaries/12/5_12_1iA2iB2iC.cif'] 2 | 1 3 | 5.503549 -2.105416 0.306120 4 | 5.503549 2.105416 0.306120 5 | -3.212949 0.000000 5.665865 6 | Zr Zn N 7 | 2 4 4 8 | Direct 9 | 0.33909577 0.33909577 0.88821450 10 | 0.66090423 0.66090423 0.11178550 11 | 0.01117193 0.01117193 0.73206145 12 | 0.98882807 0.98882807 0.26793855 13 | 0.33875302 0.33875302 0.40873122 14 | 0.66124698 0.66124698 0.59126878 15 | 0.51840336 0.51840336 0.23739812 16 | 0.48159664 0.48159664 0.76260188 17 | 0.15707630 0.15707630 0.59729982 18 | 0.84292370 0.84292370 0.40270018 19 | -------------------------------------------------------------------------------- /examples/inputs/raw/004c70f544dca7f8e1c74c6c4ff87a1e355e7cd5.poscar: -------------------------------------------------------------------------------- 1 | -.54542103E+02 '010/ht.task.triolith--default.004c70f544dca7f8e1c74c6c4ff87a1e355e7cd5.cleanup.0.unclaimed.2.finished/ht.run.2016-04-27_21.28.21' ['(Tag) original structure: NiS6Ti3', '(Tag) original structure path: structures/all/matches/ternaries/148/4_148_1aA1bB1cA1fC.cif', '(Tag) type: ternary'] 2 | 1 3 | 3.062449 
0.000000 4.987390 4 | -1.531224 2.652158 4.987390 5 | -1.531224 -2.652158 4.987390 6 | Hf Zn N 7 | 1 3 6 8 | Direct 9 | 0.50000000 0.50000000 0.50000000 10 | 0.00000000 -0.00000000 -0.00000000 11 | 0.30592519 0.30592519 0.30592519 12 | 0.69407481 0.69407481 0.69407481 13 | 0.60495542 0.89911462 0.27504484 14 | 0.10088538 0.72495516 0.39504458 15 | 0.72495516 0.39504458 0.10088538 16 | 0.27504484 0.60495542 0.89911462 17 | 0.89911462 0.27504484 0.60495542 18 | 0.39504458 0.10088538 0.72495516 19 | -------------------------------------------------------------------------------- /examples/inputs/raw/0c5e41fdc8fba8e90b48cbe4752dc2c4a2805f21.poscar: -------------------------------------------------------------------------------- 1 | -.61534447E+02 '010/ht.task.triolith--default.0c5e41fdc8fba8e90b48cbe4752dc2c4a2805f21.cleanup.0.unclaimed.2.finished/ht.run.2015-11-23_17.48.43' ['(Tag) original structure: CuF3K', '(Tag) type: ternary', '(Tag) original structure path: structures/all/matches/ternaries/140/4_140_1bA1cB1dC1hA.cif'] 2 | 1 3 | 6.038698 0.000000 -0.000000 4 | 0.000000 6.038698 -0.000000 5 | 3.019349 3.019349 4.167943 6 | Zr Zn N 7 | 2 2 6 8 | Direct 9 | 0.00000000 0.00000000 0.00000000 10 | 0.50000000 0.50000000 0.00000000 11 | 0.00000000 0.50000000 0.00000000 12 | 0.50000000 0.00000000 0.00000000 13 | 0.23981570 0.73981570 1.00000000 14 | 0.76018430 0.26018430 0.00000000 15 | 0.73981570 0.76018430 0.00000000 16 | 0.26018430 0.23981570 0.00000000 17 | 0.75000000 0.25000000 0.50000000 18 | 0.25000000 0.75000000 0.50000000 19 | -------------------------------------------------------------------------------- /examples/inputs/raw/0a74297fb11027ba2a958da5ee31980b8876e1d4.poscar: -------------------------------------------------------------------------------- 1 | -.70525527E+02 '012/ht.task.gamma--default.0a74297fb11027ba2a958da5ee31980b8876e1d4.cleanup.0.unclaimed.3.finished/ht.run.2015-11-08_06.45.53' ['(Tag) original structure: O4SiTi', '(Tag) original structure 
path: structures/all/matches/ternaries/63/4_63_1aA1cB1fC1gC.cif', '(Tag) type: ternary'] 2 | 1 3 | 3.058315 -4.469731 0.000000 4 | 3.058315 4.469731 0.000000 5 | 0.000000 0.000000 6.091423 6 | Ti Zn N 7 | 2 2 8 8 | Direct 9 | 0.00000000 -0.00000000 0.00000000 10 | 0.00000000 -0.00000000 0.50000000 11 | 0.61502096 0.38497904 0.25000000 12 | 0.38497904 0.61502096 0.75000000 13 | 0.77707723 0.22292277 0.02604294 14 | 0.22292277 0.77707723 0.52604294 15 | 0.22292277 0.77707723 0.97395706 16 | 0.77707723 0.22292277 0.47395706 17 | 0.79819076 0.76342840 0.25000000 18 | 0.20180924 0.23657160 0.75000000 19 | 0.76342840 0.79819076 0.75000000 20 | 0.23657160 0.20180924 0.25000000 21 | -------------------------------------------------------------------------------- /examples/inputs/raw/0ac7f92ee7225e47dbdc25f6d60a10b116b0813b.poscar: -------------------------------------------------------------------------------- 1 | -.63921795E+02 '012/ht.task.gamma--default.0ac7f92ee7225e47dbdc25f6d60a10b116b0813b.cleanup.0.unclaimed.3.finished/ht.run.2015-11-09_06.07.49' ['(Tag) original structure: Fe2S3Tl', '(Tag) original structure path: structures/all/matches/ternaries/63/4_63_1cA1cB1eC1gA.cif', '(Tag) type: ternary'] 2 | 1 3 | 4.275648 -3.887910 0.000000 4 | 4.275648 3.887910 -0.000000 5 | 0.000000 0.000000 4.969047 6 | Ti Zn N 7 | 2 4 6 8 | Direct 9 | 0.40001196 0.59998804 0.25000000 10 | 0.59998804 0.40001196 0.75000000 11 | 0.83644128 0.83644128 -0.00000000 12 | 0.16355872 0.16355872 0.50000000 13 | 0.83644128 0.83644128 0.50000000 14 | 0.16355872 0.16355872 -0.00000000 15 | 0.89385404 0.10614596 0.25000000 16 | 0.10614596 0.89385404 0.75000000 17 | 0.71612659 0.59086281 0.25000000 18 | 0.28387341 0.40913719 0.75000000 19 | 0.59086281 0.71612659 0.75000000 20 | 0.40913719 0.28387341 0.25000000 21 | -------------------------------------------------------------------------------- /examples/inputs/raw/0c13902f8200732555ad1603e954a726d4d24d85.poscar: 
-------------------------------------------------------------------------------- 1 | -.58351590E+02 '012/ht.task.triolith--default.0c13902f8200732555ad1603e954a726d4d24d85.cleanup.0.unclaimed.2.finished/ht.run.2015-11-24_21.10.12' ['(Tag) original structure: Ag3NaO2', '(Tag) type: ternary', '(Tag) original structure path: structures/all/matches/ternaries/72/4_72_1bA1cB1eB1jC.cif'] 2 | 1 3 | 5.720953 0.000000 0.000000 4 | 0.000000 11.967072 0.000000 5 | 2.860476 5.983536 3.126332 6 | Zr Zn N 7 | 2 6 4 8 | Direct 9 | 0.25000000 0.75000000 0.50000000 10 | 0.75000000 0.25000000 0.50000000 11 | 0.00000000 0.00000000 0.00000000 12 | 0.50000000 0.50000000 0.00000000 13 | 0.00000000 0.00000000 0.50000000 14 | 0.50000000 0.00000000 0.50000000 15 | 0.00000000 0.50000000 0.50000000 16 | 0.50000000 0.50000000 0.50000000 17 | 0.29083893 0.06226772 0.00000000 18 | 0.20916107 0.56226772 0.00000000 19 | 0.79083893 0.43773228 -0.00000000 20 | 0.70916107 0.93773228 0.00000000 21 | -------------------------------------------------------------------------------- /examples/inputs/raw/0c9c74b6f48facd507c4e41d9df797a2017634e7.poscar: -------------------------------------------------------------------------------- 1 | -.69099362E+02 '012/ht.task.triolith--default.0c9c74b6f48facd507c4e41d9df797a2017634e7.cleanup.0.unclaimed.2.finished/ht.run.2015-11-24_09.39.55' ['(Tag) original structure: Cl3Cs2Li', '(Tag) type: ternary', '(Tag) original structure path: structures/all/matches/ternaries/63/5_63_1cA1cB1fA2cC.cif'] 2 | 1 3 | 1.537037 -8.306579 0.000000 4 | 1.537037 8.306579 -0.000000 5 | 0.000000 0.000000 5.818854 6 | Zn Zr N 7 | 4 2 6 8 | Direct 9 | 0.01974571 0.98025429 0.25000000 10 | 0.98025429 0.01974571 0.75000000 11 | 0.16944846 0.83055154 0.25000000 12 | 0.83055154 0.16944846 0.75000000 13 | 0.33953783 0.66046217 0.25000000 14 | 0.66046217 0.33953783 0.75000000 15 | 0.41969297 0.58030703 0.97742156 16 | 0.58030703 0.41969297 0.47742156 17 | 0.58030703 0.41969297 0.02257844 18 | 
0.41969297 0.58030703 0.52257844 19 | 0.75587878 0.24412122 0.25000000 20 | 0.24412122 0.75587878 0.75000000 21 | -------------------------------------------------------------------------------- /examples/inputs/raw/0ad7d4c22ba4414fbe161f27f42c112c6cb83b29.poscar: -------------------------------------------------------------------------------- 1 | -.69790953E+02 '012/ht.task.gamma--default.0ad7d4c22ba4414fbe161f27f42c112c6cb83b29.cleanup.0.unclaimed.3.finished/ht.run.2015-11-07_17.50.22' ['(Tag) original structure: Ba3S7Zr2', '(Tag) original structure path: structures/all/matches/ternaries/139/6_139_1aA1bB1eA1eB1eC1gB.cif', '(Tag) type: ternary'] 2 | 1 3 | 4.021411 0.000000 0.000000 4 | 0.000000 4.021411 0.000000 5 | 2.010706 2.010706 8.874517 6 | Zn Ti N 7 | 3 2 7 8 | Direct 9 | -0.00000000 -0.00000000 0.00000000 10 | 0.83602820 0.83602820 0.32794360 11 | 0.16397180 0.16397180 0.67205640 12 | 0.60922880 0.60922880 0.78154240 13 | 0.39077120 0.39077120 0.21845760 14 | 0.61249780 0.11249780 0.77500439 15 | 0.11249780 0.61249780 0.77500439 16 | 0.38750220 0.88750220 0.22499561 17 | 0.88750220 0.38750220 0.22499561 18 | 0.72695695 0.72695695 0.54608610 19 | 0.27304305 0.27304305 0.45391390 20 | 0.50000000 0.50000000 0.00000000 21 | -------------------------------------------------------------------------------- /examples/inputs/raw/0c710ee221d7567cc0f1a06fba8231ae8261a181.poscar: -------------------------------------------------------------------------------- 1 | -.65122406E+02 '013/ht.task.triolith--default.0c710ee221d7567cc0f1a06fba8231ae8261a181.cleanup.0.unclaimed.2.finished/ht.run.2015-11-26_08.31.24' ['(Tag) original structure: AgMo6Te6', '(Tag) type: ternary', '(Tag) original structure path: structures/all/matches/ternaries/12/7_12_1aA3iB3iC.cif'] 2 | 1 3 | 7.932667 -1.614529 0.519404 4 | 7.932667 1.614529 0.519404 5 | -3.354285 0.000000 6.743572 6 | Zn N Zr 7 | 6 6 1 8 | Direct 9 | 0.72571673 0.72571673 0.20605258 10 | 0.27428327 0.27428327 0.79394742 11 | 
0.88941755 0.88941755 0.62365230 12 | 0.11058245 0.11058245 0.37634770 13 | 0.59020144 0.59020144 0.70963199 14 | 0.40979856 0.40979856 0.29036801 15 | 0.85546464 0.85546464 0.34923770 16 | 0.14453536 0.14453536 0.65076230 17 | 0.51240013 0.51240013 0.21277716 18 | 0.48759987 0.48759987 0.78722284 19 | 0.84710770 0.84710770 0.84881876 20 | 0.15289230 0.15289230 0.15118124 21 | -0.00000000 -0.00000000 0.00000000 22 | -------------------------------------------------------------------------------- /examples/inputs/raw/0ac3201405a380e382377a71216b33cde892b161.poscar: -------------------------------------------------------------------------------- 1 | -.80497110E+02 '014/ht.task.gamma--default.0ac3201405a380e382377a71216b33cde892b161.cleanup.0.unclaimed.3.finished/ht.run.2015-11-08_19.27.00' ['(Tag) original structure: MoO4Tl2', '(Tag) original structure path: structures/all/matches/ternaries/5/8_5_1aA1bA1cB1cA4cC.cif', '(Tag) type: ternary'] 2 | 1 3 | 4.952401 -2.607102 -0.345365 4 | 4.952401 2.607102 -0.345365 5 | -0.793138 0.000000 6.257486 6 | Zn Ti N 7 | 4 2 8 8 | Direct 9 | 0.95747786 0.04252214 -0.00000000 10 | 0.96835037 0.03164963 0.50000000 11 | 0.36677608 0.33913125 0.20324262 12 | 0.66086875 0.63322392 0.79675738 13 | 0.57577517 0.75508330 0.28765311 14 | 0.24491670 0.42422483 0.71234689 15 | 0.85703681 0.61391281 0.53609174 16 | 0.38608719 0.14296319 0.46390826 17 | 0.20259831 0.78613842 0.24921180 18 | 0.21386158 0.79740169 0.75078820 19 | 0.64474828 0.89284159 0.05563091 20 | 0.10715841 0.35525172 0.94436909 21 | 0.75781176 0.31376727 0.27014504 22 | 0.68623273 0.24218824 0.72985496 23 | -------------------------------------------------------------------------------- /examples/inputs/raw/0c27147781fc9c64cbcf45acc1575a0bf8070841.poscar: -------------------------------------------------------------------------------- 1 | -.95884543E+02 
'014/ht.task.triolith--default.0c27147781fc9c64cbcf45acc1575a0bf8070841.cleanup.0.unclaimed.2.finished/ht.run.2015-11-24_18.58.40' ['(Tag) original structure: Na2O3Zn2', '(Tag) type: ternary', '(Tag) original structure path: structures/all/matches/ternaries/14/4_14_1aA1eB1eA1eC.cif'] 2 | 1 3 | 6.561230 0.000000 0.098238 4 | 0.000000 6.060726 0.000000 5 | -2.983057 0.000000 5.510293 6 | Zn Zr N 7 | 4 4 6 8 | Direct 9 | 0.16989291 0.87830687 0.43646054 10 | 0.83010709 0.12169313 0.56353946 11 | 0.83010709 0.37830687 0.06353946 12 | 0.16989291 0.62169313 0.93646054 13 | 0.32449651 0.11021715 0.07651784 14 | 0.67550349 0.88978285 0.92348216 15 | 0.67550349 0.61021715 0.42348216 16 | 0.32449651 0.38978285 0.57651784 17 | 0.37508238 0.40328558 0.26426987 18 | 0.62491762 0.59671442 0.73573013 19 | 0.62491762 0.90328558 0.23573013 20 | 0.37508238 0.09671442 0.76426987 21 | 0.00000000 0.00000000 0.00000000 22 | 0.00000000 0.50000000 0.50000000 23 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Rhys Goodall 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /examples/inputs/raw/0c5264c980964b88ae8a7c2c18d5a61daff481b1.poscar: -------------------------------------------------------------------------------- 1 | -.77616025E+02 '014/ht.task.triolith--default.0c5264c980964b88ae8a7c2c18d5a61daff481b1.cleanup.0.unclaimed.2.finished/ht.run.2015-11-22_15.02.35' ['(Tag) original structure: MnNa2O4', '(Tag) type: ternary', '(Tag) original structure path: structures/all/matches/ternaries/186/5_186_1aA1bB1bA1bC1cC.cif'] 2 | 1 3 | 4.890752 -2.823677 0.000000 4 | 0.000000 5.647354 0.000000 5 | 0.000000 0.000000 6.428382 6 | Zr Zn N 7 | 2 4 8 8 | Direct 9 | 0.33333333 0.66666667 0.62486757 10 | 0.66666667 0.33333333 0.12486757 11 | 0.00000000 0.00000000 0.04179757 12 | 0.00000000 0.00000000 0.54179757 13 | 0.33333333 0.66666667 0.02173231 14 | 0.66666667 0.33333333 0.52173231 15 | 0.33333333 0.66666667 0.30226121 16 | 0.66666667 0.33333333 0.80226121 17 | 0.16316165 0.32632331 0.82311378 18 | 0.83683835 0.67367669 0.32311378 19 | 0.32632331 0.16316165 0.32311378 20 | 0.83683835 0.16316165 0.32311378 21 | 0.16316165 0.83683835 0.82311378 22 | 0.67367669 0.83683835 0.82311378 23 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | `aviary` is an open-source machine learning library for materials discovery. 
It aims to be modular and easy to understand. 2 | 3 | Contributing Guidelines: 4 | 5 | 1. Code should be written in Python and adhere to PEP8 style guidelines. We provide `pre-commit` hooks and run pre-commit.ci on PRs. 6 | 1. All code contributions should be accompanied by unit tests. 7 | 1. Use descriptive and meaningful variable and function names. 8 | 1. Include docstrings for all functions and classes. 9 | 1. Keep the codebase simple and easy to understand. 10 | 1. Use Git and GitHub for version control and pull requests. 11 | 1. All contributions must be released under the MIT license. 12 | 1. Before submitting a pull request, make sure all tests pass and that your code is well-documented. 13 | 1. Avoid using external libraries unless they are necessary for the functionality of the library. 14 | 1. Follow the issues and pull requests to keep track of the development of the library and contribute where you can. Don't be afraid to ask for help! 15 | 16 | Thank you for your interest in contributing to `aviary`! We look forward to your contributions.
17 | -------------------------------------------------------------------------------- /examples/inputs/raw/0a58ef8a37185d8abcbd6e1bad7eaf7e3d3d6016.poscar: -------------------------------------------------------------------------------- 1 | -.13166155E+03 '015/ht.task.gamma--default.0a58ef8a37185d8abcbd6e1bad7eaf7e3d3d6016.cleanup.0.unclaimed.3.finished/ht.run.2015-11-08_20.24.54' ['(Tag) original structure: Mo6Se8Yb', '(Tag) original structure path: structures/all/matches/ternaries/148/4_148_1aA1cB1fC1fB.cif', '(Tag) type: ternary'] 2 | 1 3 | 5.112473 -0.000003 2.815429 4 | -2.556234 4.427533 2.815429 5 | -2.556238 -4.427530 2.815429 6 | Ti Zn N 7 | 6 1 8 8 | Direct 9 | 0.55169467 0.35282435 0.16914522 10 | 0.64717565 0.83085478 0.44830533 11 | 0.83085478 0.44830533 0.64717565 12 | 0.16914522 0.55169467 0.35282435 13 | 0.35282435 0.16914522 0.55169467 14 | 0.44830533 0.64717565 0.83085478 15 | 0.00000000 0.00000000 0.00000000 16 | 0.22107771 0.22107771 0.22107771 17 | 0.77892229 0.77892229 0.77892229 18 | 0.15262017 0.41165775 0.66079431 19 | 0.58834225 0.33920569 0.84737983 20 | 0.33920569 0.84737983 0.58834225 21 | 0.66079431 0.15262017 0.41165775 22 | 0.41165775 0.66079431 0.15262017 23 | 0.84737983 0.58834225 0.33920569 24 | -------------------------------------------------------------------------------- /examples/inputs/raw/0c9cbed1e5d755a9666db36f742004c14bf780bd.poscar: -------------------------------------------------------------------------------- 1 | -.12835112E+03 '016/ht.task.triolith--default.0c9cbed1e5d755a9666db36f742004c14bf780bd.cleanup.0.unclaimed.2.finished/ht.run.2015-11-25_19.44.22' ['(Tag) original structure: Bi2GeO5', '(Tag) type: ternary', '(Tag) original structure path: structures/all/matches/ternaries/9/8_9_1aA2aB5aC.cif'] 2 | 1 3 | 6.131598 -3.234940 0.226503 4 | 6.131598 3.234940 0.226503 5 | 0.246773 0.000000 5.643709 6 | Zr Zn N 7 | 4 2 10 8 | Direct 9 | 0.00267917 0.27212232 0.25091105 10 | 0.27212232 0.00267917 0.75091105 11 | 
0.53639722 0.15487041 0.20074018 12 | 0.15487041 0.53639722 0.70074018 13 | 0.41162369 0.66084098 0.29113314 14 | 0.66084098 0.41162369 0.79113314 15 | 0.30194670 0.26792173 0.94447680 16 | 0.26792173 0.30194670 0.44447680 17 | 0.84479035 0.67122354 0.09235619 18 | 0.67122354 0.84479035 0.59235619 19 | 0.89077500 0.81130119 0.05351553 20 | 0.81130119 0.89077500 0.55351553 21 | 0.22233666 0.95259651 0.11785524 22 | 0.95259651 0.22233666 0.61785524 23 | 0.36334386 0.66043067 0.64751187 24 | 0.66043067 0.36334386 0.14751187 25 | -------------------------------------------------------------------------------- /examples/inputs/raw/0079b9e3ab0e210ea4afe037a9faee8cd6577922.poscar: -------------------------------------------------------------------------------- 1 | -.12677341E+03 '016/ht.task.triolith--default.0079b9e3ab0e210ea4afe037a9faee8cd6577922.cleanup.0.unclaimed.2.finished/ht.run.2015-12-13_17.30.39' ['(Tag) original structure: LiO5V2', '(Tag) original structure path: structures/all/matches/ternaries/31/8_31_1aA2aB5aC.cif', '(Tag) type: ternary'] 2 | 1 3 | 12.156346 0.000000 0.000000 4 | 0.000000 4.319613 0.000000 5 | 0.000000 0.000000 4.222262 6 | Hf N Zn 7 | 4 10 2 8 | Direct 9 | 0.58531349 0.50000000 0.16750956 10 | 0.08531349 0.00000000 0.83249044 11 | 0.40614012 0.00000000 0.04123820 12 | 0.90614012 0.50000000 0.95876180 13 | 0.64872327 0.50000000 0.63718450 14 | 0.14872327 0.00000000 0.36281550 15 | 0.39480052 0.00000000 0.54485042 16 | 0.89480052 0.50000000 0.45514958 17 | 0.42143717 0.50000000 0.07644136 18 | 0.92143717 0.00000000 0.92355864 19 | 0.58993952 0.00000000 0.09266172 20 | 0.08993952 0.50000000 0.90733828 21 | 0.23293583 0.00000000 0.16108981 22 | 0.73293583 0.50000000 0.83891019 23 | 0.19201007 0.50000000 0.28261666 24 | 0.69201007 0.00000000 0.71738334 25 | -------------------------------------------------------------------------------- /examples/inputs/raw/00bbba1d86168727008d4c1542b39d45d93d693c.poscar: 
-------------------------------------------------------------------------------- 1 | -.91305185E+02 '016/ht.task.triolith--default.00bbba1d86168727008d4c1542b39d45d93d693c.cleanup.0.unclaimed.2.finished/ht.run.2015-12-13_02.31.04' ['(Tag) original structure: O4Pd3Sr', '(Tag) original structure path: structures/all/matches/ternaries/223/3_223_1aA1cB1eC.cif', '(Tag) type: ternary'] 2 | 1 3 | 5.629202 0.000000 0.000000 4 | 0.000000 5.629202 0.000000 5 | 0.000000 0.000000 5.629202 6 | Hf Zn N 7 | 2 6 8 8 | Direct 9 | 0.00000000 0.00000000 0.00000000 10 | 0.50000000 0.50000000 0.50000000 11 | 0.25000000 0.00000000 0.50000000 12 | 0.50000000 0.75000000 0.00000000 13 | 0.00000000 0.50000000 0.25000000 14 | 0.00000000 0.50000000 0.75000000 15 | 0.50000000 0.25000000 0.00000000 16 | 0.75000000 0.00000000 0.50000000 17 | 0.25000000 0.25000000 0.25000000 18 | 0.25000000 0.75000000 0.75000000 19 | 0.75000000 0.75000000 0.75000000 20 | 0.25000000 0.75000000 0.25000000 21 | 0.25000000 0.25000000 0.75000000 22 | 0.75000000 0.25000000 0.75000000 23 | 0.75000000 0.75000000 0.25000000 24 | 0.75000000 0.25000000 0.25000000 25 | -------------------------------------------------------------------------------- /examples/inputs/raw/0a3eb2074380f4abdec94d4a6171a099c9ecb30a.poscar: -------------------------------------------------------------------------------- 1 | -.10306186E+03 '016/ht.task.gamma--default.0a3eb2074380f4abdec94d4a6171a099c9ecb30a.cleanup.0.unclaimed.3.finished/ht.run.2015-11-09_09.55.13' ['(Tag) original structure: Fe2O5Ti', '(Tag) original structure path: structures/all/matches/ternaries/63/5_63_1cA1cB1fC2fA.cif', '(Tag) type: ternary'] 2 | 1 3 | 1.887732 -8.293794 -0.000000 4 | 1.887732 8.293794 0.000000 5 | 0.000000 -0.000000 10.903555 6 | Ti Zn N 7 | 2 4 10 8 | Direct 9 | 0.77771064 0.22228936 0.25000000 10 | 0.22228936 0.77771064 0.75000000 11 | 0.76640818 0.23359182 0.58357468 12 | 0.23359182 0.76640818 0.08357468 13 | 0.23359182 0.76640818 0.41642532 14 | 
0.76640818 0.23359182 0.91642532 15 | 0.31288974 0.68711026 0.25000000 16 | 0.68711026 0.31288974 0.75000000 17 | 0.00945486 0.99054514 0.19890372 18 | 0.99054514 0.00945486 0.69890372 19 | 0.99054514 0.00945486 0.80109628 20 | 0.00945486 0.99054514 0.30109628 21 | 0.71261595 0.28738405 0.10220736 22 | 0.28738405 0.71261595 0.60220736 23 | 0.28738405 0.71261595 0.89779264 24 | 0.71261595 0.28738405 0.39779264 25 | -------------------------------------------------------------------------------- /examples/inputs/raw/0a646e174268b493bd67f58fdf52fead02f74a96.poscar: -------------------------------------------------------------------------------- 1 | -.11595836E+03 '016/ht.task.gamma--default.0a646e174268b493bd67f58fdf52fead02f74a96.cleanup.0.unclaimed.3.finished/ht.run.2016-06-23_09.53.03' ['(Tag) original structure: Ba3O9Yb4', '(Tag) original structure path: structures/all/matches/ternaries/146/10_146_3aA3bB4aC.cif', '(Tag) type: ternary'] 2 | 1 3 | 3.166867 -0.000000 6.757404 4 | -1.583433 2.742587 6.757404 5 | -1.583433 -2.742587 6.757404 6 | Zn Ti N 7 | 3 4 9 8 | Direct 9 | 0.01183712 0.01183712 0.01183712 10 | 0.17509278 0.17509278 0.17509278 11 | 0.58822661 0.58822661 0.58822661 12 | 0.44824029 0.44824029 0.44824029 13 | 0.88012936 0.88012936 0.88012936 14 | 0.73794629 0.73794629 0.73794629 15 | 0.30888986 0.30888986 0.30888986 16 | 0.22038392 0.96628227 0.59433793 17 | 0.96628227 0.59433793 0.22038392 18 | 0.59433793 0.22038392 0.96628227 19 | 0.10520064 0.47388568 0.80909796 20 | 0.47388568 0.80909796 0.10520064 21 | 0.80909796 0.10520064 0.47388568 22 | 0.07908029 0.37863310 0.71223592 23 | 0.37863310 0.71223592 0.07908029 24 | 0.71223592 0.07908029 0.37863310 25 | -------------------------------------------------------------------------------- /examples/inputs/raw/00ceac1a8da8e554218a84def07e5a6da9f106f3.poscar: -------------------------------------------------------------------------------- 1 | -.14491348E+03 
'016/ht.task.triolith--default.00ceac1a8da8e554218a84def07e5a6da9f106f3.cleanup.0.unclaimed.2.finished/ht.run.2015-12-13_14.53.57' ['(Tag) original structure: CsI4Li3', '(Tag) original structure path: structures/all/matches/ternaries/6/16_6_2bA3aB3bC3bB5aC.cif', '(Tag) type: ternary'] 2 | 1 3 | 8.083094 0.000000 -0.208196 4 | 0.000000 3.746822 0.000000 5 | 0.280916 0.000000 8.346883 6 | Zn Hf N 7 | 2 6 8 8 | Direct 9 | 0.11519925 0.50000000 0.14812181 10 | 0.52623739 0.50000000 0.32004718 11 | 0.07390019 0.50000000 0.78474769 12 | 0.44698331 0.50000000 0.97366906 13 | 0.37200353 0.00000000 0.59577805 14 | 0.97106300 0.00000000 0.44002485 15 | 0.70070504 0.50000000 0.65161885 16 | 0.77589491 0.00000000 0.00494085 17 | 0.50521321 0.00000000 0.05573442 18 | 0.33683218 0.50000000 0.73258869 19 | 0.93708398 0.50000000 0.56442548 20 | 0.70583489 0.50000000 0.89905921 21 | 0.99245769 0.00000000 0.84790603 22 | 0.21380499 0.00000000 0.42136796 23 | 0.89843228 0.00000000 0.21193260 24 | 0.62841745 0.00000000 0.58627726 25 | -------------------------------------------------------------------------------- /aviary/embeddings/element/README.md: -------------------------------------------------------------------------------- 1 | # Element - Embeddings 2 | 3 | ## CGCNN92 4 | 5 | The following paper describes the details of the CGCNN framework, information about how the one-hot atom embedding was constructed is available in the supplementary materials: 6 | 7 | > [Crystal Graph Convolutional Neural Networks for an Accurate and Interpretable Prediction of Material Properties](https://link.aps.org/doi/10.1103/PhysRevLett.120.145301) 8 | 9 | ## MEGNet16 10 | 11 | The following paper describes the details of the MEGnet framework, the embedding is generated from the atomic weights through one MEGnet layer on a training task to predict the computed formation energies of ∼69,000 materials from the Materials Project database: 12 | 13 | > [Graph Networks as a Universal Machine Learning 
Framework for Molecules and Crystals](https://arxiv.org/abs/1812.05055) 14 | 15 | ## MatScholar200 16 | 17 | This is an experimental NLP embedding based on the data mining of definite compositions and structure prototypes. 18 | 19 | > [Unsupervised word embeddings capture latent knowledge from materials science literature](https://www.nature.com/articles/s41586-019-1335-8) 20 | 21 | ## Onehot112 22 | 23 | This is a simple one-hot encoding for the elements 24 | -------------------------------------------------------------------------------- /examples/inputs/raw/0a5af69c21e6c2d9b004c2f51939790b10b19814.poscar: -------------------------------------------------------------------------------- 1 | -.12124245E+03 '017/ht.task.gamma--default.0a5af69c21e6c2d9b004c2f51939790b10b19814.cleanup.0.unclaimed.3.finished/ht.run.2016-06-11_04.24.19' ['(Tag) original structure: As5Cs3O9', '(Tag) original structure path: structures/all/matches/ternaries/157/5_157_1bA1cA1cB1cC1dC.cif', '(Tag) type: ternary'] 2 | 1 3 | 6.776201 -3.912241 0.000000 4 | 0.000000 7.824483 0.000000 5 | 0.000000 0.000000 4.868194 6 | Zn Ti N 7 | 3 5 9 8 | Direct 9 | 0.43304785 1.00000000 0.90142116 10 | 0.56695215 0.56695215 0.90142116 11 | 0.00000000 0.43304785 0.90142116 12 | 0.73676343 1.00000000 0.46239314 13 | 0.26323657 0.26323657 0.46239314 14 | 0.00000000 0.73676343 0.46239314 15 | 0.33333333 0.66666667 0.48072680 16 | 0.66666667 0.33333333 0.48072680 17 | 0.54692646 0.00000000 0.24748564 18 | 0.45307354 0.45307354 0.24748564 19 | 0.00000000 0.54692646 0.24748564 20 | 0.20794832 0.42235553 0.68279110 21 | 0.21440721 0.79205168 0.68279110 22 | 0.57764447 0.78559279 0.68279110 23 | 0.79205168 0.21440721 0.68279110 24 | 0.42235553 0.20794832 0.68279110 25 | 0.78559279 0.57764447 0.68279110 26 | -------------------------------------------------------------------------------- /examples/inputs/raw/0a0ae9c18de23e72bb764e10ef542de693321542.poscar: 
-------------------------------------------------------------------------------- 1 | -.12758363E+03 '018/ht.task.gamma--default.0a0ae9c18de23e72bb764e10ef542de693321542.cleanup.1.unclaimed.3.finished/ht.run.2016-05-19_22.39.46' ['(Tag) original structure: In2O5Sr2', '(Tag) original structure path: structures/all/matches/ternaries/46/6_46_1aA1bA1bB1cC2cB.cif', '(Tag) type: ternary'] 2 | 1 3 | 5.380653 0.000000 0.000000 4 | -0.000000 14.311422 0.000000 5 | 2.690327 7.155711 2.712897 6 | Zn Ti N 7 | 4 4 10 8 | Direct 9 | 0.59605935 0.66971651 0.94546880 10 | 0.45847185 0.16971651 0.94546880 11 | 0.45847185 0.38481469 0.94546880 12 | 0.59605935 0.88481469 0.94546880 13 | 0.99066415 0.99066415 0.01867171 14 | 0.99066415 0.49066415 0.01867171 15 | 0.93415690 0.25680447 0.98639105 16 | 0.07945204 0.75680447 0.98639105 17 | 0.23781504 0.81687590 0.31447074 18 | 0.44771422 0.31687590 0.31447074 19 | 0.44771422 0.86865337 0.31447074 20 | 0.23781504 0.36865337 0.31447074 21 | 0.97010020 0.02841900 0.20962174 22 | 0.82027806 0.52841900 0.20962174 23 | 0.82027806 0.76195926 0.20962174 24 | 0.97010020 0.26195926 0.20962174 25 | 0.26244227 0.61609265 0.26781470 26 | 0.46974303 0.11609265 0.26781470 27 | -------------------------------------------------------------------------------- /examples/inputs/raw/0c450e6f9c256e33ef452c98dd20917fb6484317.poscar: -------------------------------------------------------------------------------- 1 | -.11423502E+03 '018/ht.task.triolith--default.0c450e6f9c256e33ef452c98dd20917fb6484317.cleanup.0.unclaimed.2.finished/ht.run.2015-11-23_07.40.55' ['(Tag) original structure: Ca3Ga2N4', '(Tag) type: ternary', '(Tag) original structure path: structures/all/matches/ternaries/15/5_15_1eA1fA1fB2fC.cif'] 2 | 1 3 | 4.850647 -4.711321 -0.214326 4 | 4.850647 4.711321 -0.214326 5 | -0.377966 0.000000 5.614459 6 | Zn Zr N 7 | 6 4 8 8 | Direct 9 | 0.93947997 0.24500985 0.59229075 10 | 0.75499015 0.06052003 0.90770925 11 | 0.06052003 0.75499015 0.40770925 12 | 
0.24500985 0.93947997 0.09229075 13 | 0.65515963 0.34484037 0.25000000 14 | 0.34484037 0.65515963 0.75000000 15 | 0.18630163 0.38379918 0.12077952 16 | 0.61620082 0.81369837 0.37922048 17 | 0.81369837 0.61620082 0.87922048 18 | 0.38379918 0.18630163 0.62077952 19 | 0.88219627 0.32914320 0.00176296 20 | 0.67085680 0.11780373 0.49823704 21 | 0.11780373 0.67085680 0.99823704 22 | 0.32914320 0.88219627 0.50176296 23 | 0.24930985 0.44271839 0.48301702 24 | 0.55728161 0.75069015 0.01698298 25 | 0.75069015 0.55728161 0.51698298 26 | 0.44271839 0.24930985 0.98301702 27 | -------------------------------------------------------------------------------- /examples/inputs/raw/0059fd6dd1d17138c42bb5430fa0bc074053f0de.poscar: -------------------------------------------------------------------------------- 1 | -.12205554E+03 '018/ht.task.triolith--default.0059fd6dd1d17138c42bb5430fa0bc074053f0de.cleanup.0.unclaimed.2.finished/ht.run.2016-04-29_04.58.57' ['(Tag) original structure: Cl6K2Pu', '(Tag) original structure path: structures/all/matches/ternaries/10/8_10_1aA1hA1mB1mC1nB1nC2oB.cif', '(Tag) type: ternary'] 2 | 1 3 | 6.517271 0.000000 -1.602318 4 | 0.000000 5.261389 0.000000 5 | -2.700464 0.000000 6.804575 6 | Zn Hf N 7 | 4 2 12 8 | Direct 9 | 0.53948074 -0.00000000 0.19307466 10 | 0.46051926 -0.00000000 0.80692534 11 | 0.03938539 0.50000000 0.69311581 12 | 0.96061461 0.50000000 0.30688419 13 | 0.00000000 -0.00000000 0.00000000 14 | 0.50000000 0.50000000 0.50000000 15 | 0.19037601 -0.00000000 0.37098077 16 | 0.80962399 -0.00000000 0.62901923 17 | 0.69031429 0.50000000 0.87094031 18 | 0.30968571 0.50000000 0.12905969 19 | 0.24131655 0.28439444 0.00896358 20 | 0.75868345 0.28439444 0.99103642 21 | 0.75868345 0.71560556 0.99103642 22 | 0.24131655 0.71560556 0.00896358 23 | 0.74128896 0.78441677 0.50890356 24 | 0.25871104 0.78441677 0.49109644 25 | 0.25871104 0.21558323 0.49109644 26 | 0.74128896 0.21558323 0.50890356 27 | 
-------------------------------------------------------------------------------- /examples/inputs/raw/00fa7d75b7b2e1442df0a68071210c8d64b93ebe.poscar: -------------------------------------------------------------------------------- 1 | -.15718967E+03 '018/ht.task.triolith--default.00fa7d75b7b2e1442df0a68071210c8d64b93ebe.cleanup.0.unclaimed.2.finished/ht.run.2016-04-29_06.09.22' ['(Tag) original structure: Ag5RbSe3', '(Tag) original structure path: structures/all/matches/ternaries/125/5_125_1aA1cB1dC1gB1mA.cif', '(Tag) type: ternary'] 2 | 1 3 | 5.828845 0.000000 0.000000 4 | -0.000000 5.828845 0.000000 5 | -0.000000 0.000000 9.239161 6 | Zn N Hf 7 | 2 6 10 8 | Direct 9 | 0.75000000 0.25000000 0.50000000 10 | 0.25000000 0.75000000 0.50000000 11 | 0.75000000 0.25000000 0.00000000 12 | 0.25000000 0.75000000 0.00000000 13 | 0.25000000 0.25000000 0.23728861 14 | 0.25000000 0.25000000 0.76271139 15 | 0.75000000 0.75000000 0.23728861 16 | 0.75000000 0.75000000 0.76271139 17 | 0.25000000 0.25000000 0.00000000 18 | 0.75000000 0.75000000 -0.00000000 19 | 0.57763663 0.07763663 0.25057902 20 | 0.92236337 0.42236337 0.25057902 21 | 0.57763663 0.42236337 0.74942098 22 | 0.07763663 0.92236337 0.25057902 23 | 0.92236337 0.07763663 0.74942098 24 | 0.42236337 0.92236337 0.74942098 25 | 0.07763663 0.57763663 0.74942098 26 | 0.42236337 0.57763663 0.25057902 27 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | # Install these hooks with `pre-commit install`. 
2 | 3 | ci: 4 | autoupdate_schedule: quarterly 5 | 6 | repos: 7 | - repo: https://github.com/astral-sh/ruff-pre-commit 8 | rev: v0.13.3 9 | hooks: 10 | - id: ruff 11 | args: [--fix] 12 | - id: ruff-format 13 | 14 | - repo: https://github.com/pre-commit/pre-commit-hooks 15 | rev: v6.0.0 16 | hooks: 17 | - id: check-case-conflict 18 | - id: check-symlinks 19 | - id: check-yaml 20 | - id: destroyed-symlinks 21 | - id: end-of-file-fixer 22 | - id: mixed-line-ending 23 | - id: trailing-whitespace 24 | 25 | - repo: https://github.com/codespell-project/codespell 26 | rev: v2.4.1 27 | hooks: 28 | - id: codespell 29 | exclude_types: [json] 30 | args: [--check-filenames] 31 | 32 | - repo: https://github.com/pre-commit/mirrors-mypy 33 | rev: v1.18.2 34 | hooks: 35 | - id: mypy 36 | exclude: (tests|examples)/ 37 | additional_dependencies: [wandb] 38 | 39 | - repo: https://github.com/janosh/format-ipy-cells 40 | rev: v0.1.11 41 | hooks: 42 | - id: format-ipy-cells 43 | 44 | - repo: https://github.com/kynan/nbstripout 45 | rev: 0.8.1 46 | hooks: 47 | - id: nbstripout 48 | args: [--drop-empty-cells] 49 | -------------------------------------------------------------------------------- /examples/inputs/raw/0cd1f1670176b03cc3d61935d9a6882c9508e33a.poscar: -------------------------------------------------------------------------------- 1 | -.16820589E+03 '020/ht.task.triolith--default.0cd1f1670176b03cc3d61935d9a6882c9508e33a.cleanup.0.unclaimed.4.finished/ht.run.2015-11-27_20.14.49' ['(Tag) original structure: Er3Se6Sm', '(Tag) type: ternary', '(Tag) original structure path: structures/all/matches/ternaries/11/10_11_1eA3eB6eC.cif'] 2 | 1 3 | 8.161179 0.000000 0.449837 4 | 0.000000 3.324704 0.000000 5 | -2.371431 0.000000 9.270920 6 | Zr Zn N 7 | 6 2 12 8 | Direct 9 | 0.94849747 0.25000000 0.32956364 10 | 0.05150253 0.75000000 0.67043636 11 | 0.18106850 0.25000000 0.03690071 12 | 0.81893150 0.75000000 0.96309929 13 | 0.67614681 0.25000000 0.61333077 14 | 0.32385319 0.75000000 
0.38666923 15 | 0.53082596 0.25000000 0.12658890 16 | 0.46917404 0.75000000 0.87341110 17 | 0.92053932 0.25000000 0.58165729 18 | 0.07946068 0.75000000 0.41834271 19 | 0.30285053 0.25000000 0.68482012 20 | 0.69714947 0.75000000 0.31517988 21 | 0.94853790 0.25000000 0.11106725 22 | 0.05146210 0.75000000 0.88893275 23 | 0.63551390 0.25000000 0.95343270 24 | 0.36448610 0.75000000 0.04656730 25 | 0.37209085 0.25000000 0.57851740 26 | 0.62790915 0.75000000 0.42148260 27 | 0.29515767 0.25000000 0.26471150 28 | 0.70484233 0.75000000 0.73528850 29 | -------------------------------------------------------------------------------- /examples/inputs/raw/00237cf70206a5de77e5154ea72633514cdd4445.poscar: -------------------------------------------------------------------------------- 1 | -.13765617E+03 '020/ht.task.triolith--default.00237cf70206a5de77e5154ea72633514cdd4445.cleanup.0.unclaimed.2.finished/ht.run.2015-12-27_15.11.38' ['(Tag) original structure: Cl7CsTi2', '(Tag) original structure path: structures/all/matches/ternaries/11/7_11_1eA1fB2fC3eC.cif', '(Tag) type: ternary'] 2 | 1 3 | 4.086823 0.000000 0.232607 4 | 0.000000 9.465448 0.000000 5 | -0.059997 0.000000 6.856863 6 | Hf N Zn 7 | 2 14 4 8 | Direct 9 | 0.18968935 0.25000000 0.17660677 10 | 0.81031065 0.75000000 0.82339323 11 | 0.66980532 0.25000000 0.30672720 12 | 0.33019468 0.75000000 0.69327280 13 | 0.68756283 0.25000000 0.47916256 14 | 0.31243717 0.75000000 0.52083744 15 | 0.27122645 0.25000000 0.89843441 16 | 0.72877355 0.75000000 0.10156559 17 | 0.21180647 0.03552534 0.27936006 18 | 0.78819353 0.96447466 0.72063994 19 | 0.78819353 0.53552534 0.72063994 20 | 0.21180647 0.46447466 0.27936006 21 | 0.14015442 0.02839410 0.47485170 22 | 0.85984558 0.97160590 0.52514830 23 | 0.85984558 0.52839410 0.52514830 24 | 0.14015442 0.47160590 0.47485170 25 | 0.36190174 0.08645613 0.74088626 26 | 0.63809826 0.91354387 0.25911374 27 | 0.63809826 0.58645613 0.25911374 28 | 0.36190174 0.41354387 0.74088626 29 | 
-------------------------------------------------------------------------------- /examples/inputs/raw/0c882a0621189d18ff1350e3bed4114868663acf.poscar: -------------------------------------------------------------------------------- 1 | -.17215809E+03 '020/ht.task.triolith--default.0c882a0621189d18ff1350e3bed4114868663acf.cleanup.0.unclaimed.4.finished/ht.run.2016-06-04_05.34.00' ['(Tag) original structure: I5NiPr4', '(Tag) type: ternary', '(Tag) original structure path: structures/all/matches/ternaries/59/7_59_1aA1aB1eC2bC2eA.cif'] 2 | 1 3 | 3.240560 0.000000 0.000000 4 | 0.000000 12.886769 0.000000 5 | 0.000000 0.000000 7.790818 6 | Zr N Zn 7 | 8 10 2 8 | Direct 9 | 0.25000000 0.75000000 0.55699214 10 | 0.75000000 0.25000000 0.44300786 11 | 0.25000000 0.75000000 0.16926273 12 | 0.75000000 0.25000000 0.83073727 13 | 0.25000000 0.44772047 0.15767431 14 | 0.75000000 0.94772047 0.84232569 15 | 0.75000000 0.55227953 0.84232569 16 | 0.25000000 0.05227953 0.15767431 17 | 0.25000000 0.60748537 0.68873064 18 | 0.75000000 0.10748537 0.31126936 19 | 0.75000000 0.39251463 0.31126936 20 | 0.25000000 0.89251463 0.68873064 21 | 0.25000000 0.59581971 0.05569449 22 | 0.75000000 0.09581971 0.94430551 23 | 0.75000000 0.40418029 0.94430551 24 | 0.25000000 0.90418029 0.05569449 25 | 0.25000000 0.25000000 0.64078793 26 | 0.75000000 0.75000000 0.35921207 27 | 0.25000000 0.25000000 0.13097308 28 | 0.75000000 0.75000000 0.86902692 29 | -------------------------------------------------------------------------------- /examples/inputs/raw/0ad8980cb8eb0693647f411732b2165b6a0fce0e.poscar: -------------------------------------------------------------------------------- 1 | -.16275965E+03 '020/ht.task.gamma--default.0ad8980cb8eb0693647f411732b2165b6a0fce0e.cleanup.0.unclaimed.3.finished/ht.run.2016-04-12_08.07.42' ['(Tag) original structure: Mn7O12Pr', '(Tag) original structure path: structures/all/matches/ternaries/12/10_12_1aA1bB1cB1dB1eB1fB2iC2jC.cif', '(Tag) type: ternary'] 2 | 1 3 | 
7.625926 -0.000000 0.076043 4 | -0.000000 7.580728 -0.000000 5 | 3.767819 3.790364 3.834359 6 | Zn Ti N 7 | 1 7 12 8 | Direct 9 | 0.00000000 0.00000000 0.00000000 10 | 0.50000000 0.50000000 0.00000000 11 | 0.50000000 0.00000000 0.00000000 12 | 0.00000000 0.50000000 0.00000000 13 | 0.50000000 0.50000000 0.50000000 14 | 0.50000000 0.00000000 0.50000000 15 | 0.00000000 0.00000000 0.50000000 16 | 0.00000000 0.50000000 0.50000000 17 | 0.17462695 0.86728018 0.65081407 18 | 0.82537305 0.51809424 0.34918593 19 | 0.82537305 0.13271982 0.34918593 20 | 0.17462695 0.48190576 0.65081407 21 | 0.50337766 0.68510006 0.62979988 22 | 0.49662234 0.31489994 0.37020012 23 | 0.86607696 0.69069681 0.61860638 24 | 0.13392304 0.30930319 0.38139362 25 | 0.68829538 0.82298879 0.99962863 26 | 0.31170462 0.82261742 0.00037137 27 | 0.31170462 0.17701121 0.00037137 28 | 0.68829538 0.17738258 0.99962863 29 | -------------------------------------------------------------------------------- /examples/inputs/raw/009ee0451531276f598976f46848fe7d590262ca.poscar: -------------------------------------------------------------------------------- 1 | -.17482530E+03 '021/ht.task.triolith--default.009ee0451531276f598976f46848fe7d590262ca.cleanup.0.unclaimed.2.finished/ht.run.2016-01-05_18.33.01' ['(Tag) original structure: Li2O13V6', '(Tag) original structure path: structures/all/matches/ternaries/12/11_12_1bA1iB3iC6iA.cif', '(Tag) type: ternary'] 2 | 1 3 | 6.835117 -2.032658 -0.515824 4 | 6.835117 2.032658 -0.515824 5 | -3.147213 0.000000 11.027458 6 | Hf N Zn 7 | 6 13 2 8 | Direct 9 | 0.34297942 0.34297942 0.99965321 10 | 0.65702058 0.65702058 0.00034679 11 | 0.39485957 0.39485957 0.38058576 12 | 0.60514043 0.60514043 0.61941424 13 | 0.69607843 0.69607843 0.38139415 14 | 0.30392157 0.30392157 0.61860585 15 | 0.19897726 0.19897726 0.04469716 16 | 0.80102274 0.80102274 0.95530284 17 | 0.84879528 0.84879528 0.34182287 18 | 0.15120472 0.15120472 0.65817713 19 | 0.23924638 0.23924638 0.40916785 20 | 0.76075362 
0.76075362 0.59083215 21 | 0.50000000 0.50000000 0.00000000 22 | 0.39982928 0.39982928 0.19561494 23 | 0.60017072 0.60017072 0.80438506 24 | 0.62078394 0.62078394 0.18012629 25 | 0.37921606 0.37921606 0.81987371 26 | 0.55627083 0.55627083 0.41837015 27 | 0.44372917 0.44372917 0.58162985 28 | 0.11671079 0.11671079 0.16606411 29 | 0.88328921 0.88328921 0.83393589 30 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: Tests 2 | 3 | on: 4 | push: 5 | paths: ["**/*.py", ".github/workflows/test.yml"] 6 | branches: [main] 7 | pull_request: 8 | paths: ["**/*.py", ".github/workflows/test.yml"] 9 | branches: [main] 10 | release: 11 | types: [published] 12 | workflow_dispatch: 13 | workflow_call: 14 | 15 | concurrency: 16 | # Cancel only on same PR number 17 | group: ${{ github.workflow }}-pr-${{ github.event.pull_request.number }} 18 | cancel-in-progress: true 19 | 20 | jobs: 21 | tests: 22 | strategy: 23 | fail-fast: false 24 | matrix: 25 | os: [ubuntu-latest, macos-latest, windows-latest] 26 | version: 27 | - { python: "3.10", resolution: highest } 28 | - { python: "3.12", resolution: lowest-direct } 29 | runs-on: ${{ matrix.os }} 30 | 31 | steps: 32 | - name: Check out repo 33 | uses: actions/checkout@v4 34 | 35 | - name: Set up Python 36 | uses: actions/setup-python@v5 37 | with: 38 | python-version: ${{ matrix.version.python }} 39 | 40 | - name: Set up uv 41 | uses: astral-sh/setup-uv@v2 42 | 43 | - name: Install dependencies 44 | run: | 45 | uv pip install torch --index-url https://download.pytorch.org/whl/cpu --system 46 | uv pip install .[test] --system 47 | 48 | - name: Run Tests 49 | run: pytest --capture=no --cov . 
50 | -------------------------------------------------------------------------------- /examples/inputs/raw/0c5ffa8e2b3126eba02901c37b31e8faf3b5377c.poscar: -------------------------------------------------------------------------------- 1 | -.16882837E+03 '023/ht.task.triolith--default.0c5ffa8e2b3126eba02901c37b31e8faf3b5377c.cleanup.0.unclaimed.4.finished/ht.run.2016-06-02_02.38.13' ['(Tag) original structure: B6O13Zn4', '(Tag) type: ternary', '(Tag) original structure path: structures/all/matches/ternaries/217/4_217_1aA1cB1dC1gA.cif'] 2 | 1 3 | -4.856584 4.856584 4.856584 4 | 4.856584 -4.856584 4.856584 5 | 4.856584 4.856584 -4.856584 6 | Zn Zr N 7 | 4 6 13 8 | Direct 9 | 0.42585440 0.42585440 0.42585440 10 | -0.00000000 -0.00000000 0.57414560 11 | -0.00000000 0.57414560 -0.00000000 12 | 0.57414560 -0.00000000 -0.00000000 13 | 0.50000000 0.25000000 0.75000000 14 | 0.50000000 0.75000000 0.25000000 15 | 0.75000000 0.50000000 0.25000000 16 | 0.75000000 0.25000000 0.50000000 17 | 0.25000000 0.75000000 0.50000000 18 | 0.25000000 0.50000000 0.75000000 19 | -0.00000000 -0.00000000 -0.00000000 20 | 0.51838322 0.51838322 0.27053201 21 | 0.24785121 0.24785121 0.72946799 22 | 0.24785121 0.72946799 0.24785121 23 | 0.00000000 0.48161678 0.75214879 24 | 0.27053201 0.51838322 0.51838322 25 | 0.51838322 0.27053201 0.51838322 26 | -0.00000000 0.75214879 0.48161678 27 | 0.75214879 0.48161678 -0.00000000 28 | 0.48161678 0.75214879 0.00000000 29 | 0.48161678 -0.00000000 0.75214879 30 | 0.75214879 0.00000000 0.48161678 31 | 0.72946799 0.24785121 0.24785121 32 | -------------------------------------------------------------------------------- /examples/inputs/raw/0a458b5f2eeebeb0e4e55e1d1a17988ec87b7123.poscar: -------------------------------------------------------------------------------- 1 | -.14108066E+03 '024/ht.task.gamma--default.0a458b5f2eeebeb0e4e55e1d1a17988ec87b7123.cleanup.0.unclaimed.3.finished/ht.run.2016-05-01_14.31.14' ['(Tag) original structure: O3Rb2Ti', '(Tag) 
original structure path: structures/all/matches/ternaries/64/6_64_1eA1fB2fA2fC.cif', '(Tag) type: ternary'] 2 | 1 3 | 2.889034 -4.949055 -0.000000 4 | 2.889034 4.949055 0.000000 5 | 0.000000 0.000000 10.531082 6 | Zn Ti N 7 | 8 4 12 8 | Direct 9 | 0.94903530 0.05096470 0.16283211 10 | 0.55096470 0.44903530 0.66283211 11 | 0.44903530 0.55096470 0.33716789 12 | 0.05096470 0.94903530 0.83716789 13 | 0.63746097 0.36253903 0.10209048 14 | 0.86253903 0.13746097 0.60209048 15 | 0.13746097 0.86253903 0.39790952 16 | 0.36253903 0.63746097 0.89790952 17 | 0.75230031 0.24769969 0.35600488 18 | 0.74769969 0.25230031 0.85600488 19 | 0.25230031 0.74769969 0.14399512 20 | 0.24769969 0.75230031 0.64399512 21 | 0.67082105 0.32917895 0.51925674 22 | 0.82917895 0.17082105 0.01925674 23 | 0.17082105 0.82917895 0.98074326 24 | 0.32917895 0.67082105 0.48074326 25 | 0.94589238 0.05410762 0.35185694 26 | 0.55410762 0.44589238 0.85185694 27 | 0.44589238 0.55410762 0.14814306 28 | 0.05410762 0.94589238 0.64814306 29 | 0.93929385 0.56070615 0.25000000 30 | 0.56070615 0.93929385 0.75000000 31 | 0.43929385 0.06070615 0.25000000 32 | 0.06070615 0.43929385 0.75000000 33 | -------------------------------------------------------------------------------- /examples/inputs/raw/0ac8a65af75e8223b07dcd98c6958db7af6504cf.poscar: -------------------------------------------------------------------------------- 1 | -.19471572E+03 '024/ht.task.gamma--default.0ac8a65af75e8223b07dcd98c6958db7af6504cf.cleanup.0.unclaimed.3.finished/ht.run.2016-06-10_22.43.42' ['(Tag) original structure: CaGa4O7', '(Tag) original structure path: structures/all/matches/ternaries/15/7_15_1eA1eB2fC3fB.cif', '(Tag) type: ternary'] 2 | 1 3 | 6.730469 -4.584523 -0.342761 4 | 6.730469 4.584523 -0.342761 5 | -1.806793 0.000000 5.356055 6 | Zn Ti N 7 | 2 8 14 8 | Direct 9 | 0.71413180 0.28586820 0.25000000 10 | 0.28586820 0.71413180 0.75000000 11 | 0.06304672 0.16684276 0.66757928 12 | 0.83315724 0.93695328 0.83242072 13 | 0.93695328 
0.83315724 0.33242072 14 | 0.16684276 0.06304672 0.16757928 15 | 0.26619824 0.44058687 0.15513109 16 | 0.55941313 0.73380176 0.34486891 17 | 0.73380176 0.55941313 0.84486891 18 | 0.44058687 0.26619824 0.65513109 19 | 0.13372099 0.26871112 0.03056721 20 | 0.73128888 0.86627901 0.46943279 21 | 0.86627901 0.73128888 0.96943279 22 | 0.26871112 0.13372099 0.53056721 23 | 0.31540762 0.48366840 0.86537518 24 | 0.51633160 0.68459238 0.63462482 25 | 0.68459238 0.51633160 0.13462482 26 | 0.48366840 0.31540762 0.36537518 27 | 0.84326964 0.35643738 0.60951799 28 | 0.64356262 0.15673036 0.89048201 29 | 0.15673036 0.64356262 0.39048201 30 | 0.35643738 0.84326964 0.10951799 31 | 0.95866177 0.04133823 0.25000000 32 | 0.04133823 0.95866177 0.75000000 33 | -------------------------------------------------------------------------------- /examples/inputs/raw/0a130b596d1ce7bfe317d9e546bbeacc703fef65.poscar: -------------------------------------------------------------------------------- 1 | -.18818657E+03 '024/ht.task.gamma--default.0a130b596d1ce7bfe317d9e546bbeacc703fef65.cleanup.0.unclaimed.3.finished/ht.run.2015-11-10_03.32.23' ['(Tag) original structure: NaO8Ta3', '(Tag) original structure path: structures/all/matches/ternaries/72/6_72_1aA1bB1jB1kC2jC.cif', '(Tag) type: ternary'] 2 | 1 3 | 6.383894 0.000000 -0.000000 4 | -0.000000 7.842330 0.000000 5 | 3.191947 3.921165 4.697699 6 | Ti Zn N 7 | 6 2 16 8 | Direct 9 | 0.25000000 0.75000000 0.50000000 10 | 0.75000000 0.25000000 0.50000000 11 | 0.26264823 0.24098197 0.00000000 12 | 0.23735177 0.74098197 1.00000000 13 | 0.76264823 0.25901803 0.00000000 14 | 0.73735177 0.75901803 1.00000000 15 | 0.75000000 0.75000000 0.50000000 16 | 0.25000000 0.25000000 0.50000000 17 | 0.99047279 0.08221216 1.00000000 18 | 0.50952721 0.58221216 0.00000000 19 | 0.49047279 0.41778784 1.00000000 20 | 0.00952721 0.91778784 0.00000000 21 | 0.08382932 0.45000872 1.00000000 22 | 0.41617068 0.95000872 0.00000000 23 | 0.58382932 0.04999128 1.00000000 24 | 
0.91617068 0.54999128 0.00000000 25 | 0.94713593 0.03445718 0.57572495 26 | 0.55286407 0.11018214 0.42427505 27 | 0.02286088 0.46554282 0.42427505 28 | 0.52286088 0.61018214 0.42427505 29 | 0.97713912 0.53445718 0.57572495 30 | 0.44713593 0.88981786 0.57572495 31 | 0.05286407 0.96554282 0.42427505 32 | 0.47713912 0.38981786 0.57572495 33 | -------------------------------------------------------------------------------- /aviary/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | 4 | 5 | def robust_l1_loss(pred_mean: Tensor, pred_log_std: Tensor, target: Tensor) -> Tensor: 6 | """Robust L1 loss using a Lorentzian prior. Trains the model to learn to predict 7 | aleatoric (i.e. per-sample) uncertainty. 8 | 9 | Args: 10 | pred_mean (Tensor): Tensor of predicted means. 11 | pred_log_std (Tensor): Tensor of predicted log standard deviations representing 12 | per-sample model uncertainties. 13 | target (Tensor): Tensor of target values. 14 | 15 | Returns: 16 | Tensor: Evaluated robust L1 loss 17 | """ 18 | loss = 2**0.5 * (pred_mean - target).abs() * torch.exp(-pred_log_std) + pred_log_std 19 | return torch.mean(loss) 20 | 21 | 22 | def robust_l2_loss(pred_mean: Tensor, pred_log_std: Tensor, target: Tensor) -> Tensor: 23 | """Robust L2 loss using a Gaussian prior. Trains the model to learn to predict 24 | aleatoric (i.e. per-sample) uncertainty. 25 | 26 | Args: 27 | pred_mean (Tensor): Tensor of predicted means. 28 | pred_log_std (Tensor): Tensor of predicted log standard deviations representing 29 | per-sample model uncertainties. 30 | target (Tensor): Tensor of target values. 
31 | 32 | Returns: 33 | Tensor: Evaluated robust L2 loss 34 | """ 35 | loss = 0.5 * (pred_mean - target) ** 2 * torch.exp(-2 * pred_log_std) + pred_log_std 36 | return torch.mean(loss) 37 | 38 | 39 | # aliases for backwards compatibility 40 | RobustL1Loss = robust_l1_loss 41 | RobustL2Loss = robust_l2_loss 42 | -------------------------------------------------------------------------------- /examples/inputs/raw/0a4a721fb8f0996350edd4fb224d89ce213f7f93.poscar: -------------------------------------------------------------------------------- 1 | -.17811622E+03 '026/ht.task.gamma--default.0a4a721fb8f0996350edd4fb224d89ce213f7f93.cleanup.3.unclaimed.3.finished/ht.run.2016-05-11_20.19.51' ['(Tag) original structure: Al7C3N3', '(Tag) original structure path: structures/all/matches/ternaries/36/13_36_3aA3aB7aC.cif', '(Tag) type: ternary'] 2 | 1 3 | 1.739255 -2.701998 0.000000 4 | 1.739255 2.701998 0.000000 5 | 0.000000 0.000000 44.970811 6 | Ti Zn N 7 | 14 6 6 8 | Direct 9 | 0.65303539 0.34696461 0.05235786 10 | 0.34696461 0.65303539 0.55235786 11 | 0.01589023 0.98410977 0.10015629 12 | 0.98410977 0.01589023 0.60015629 13 | 0.65177435 0.34822565 0.14624567 14 | 0.34822565 0.65177435 0.64624567 15 | 0.06378092 0.93621908 0.26386996 16 | 0.93621908 0.06378092 0.76386996 17 | 0.70119409 0.29880591 0.30755386 18 | 0.29880591 0.70119409 0.80755386 19 | 0.97953586 0.02046414 0.41091222 20 | 0.02046414 0.97953586 0.91091222 21 | 0.63953190 0.36046810 0.46581246 22 | 0.36046810 0.63953190 0.96581246 23 | 0.00687154 0.99312846 0.00904892 24 | 0.99312846 0.00687154 0.50904892 25 | 0.71355352 0.28644648 0.21967736 26 | 0.28644648 0.71355352 0.71967736 27 | 0.08167077 0.91832923 0.35281686 28 | 0.91832923 0.08167077 0.85281686 29 | 0.66270201 0.33729799 0.09953666 30 | 0.33729799 0.66270201 0.59953666 31 | 0.00266910 0.99733090 0.14931080 32 | 0.99733090 0.00266910 0.64931080 33 | 0.62990179 0.37009821 0.42280110 34 | 0.37009821 0.62990179 0.92280110 35 | 
-------------------------------------------------------------------------------- /examples/inputs/raw/0c429f98b79c9522a11d51fbfae484004f899422.poscar: -------------------------------------------------------------------------------- 1 | -.17580398E+03 '026/ht.task.triolith--default.0c429f98b79c9522a11d51fbfae484004f899422.cleanup.0.unclaimed.4.finished/ht.run.2015-12-09_10.07.49' ['(Tag) original structure: CeCr2O10', '(Tag) type: ternary', '(Tag) original structure path: structures/all/matches/ternaries/11/8_11_1eA1fB2eC4fC.cif'] 2 | 1 3 | 4.103714 0.000000 -0.700673 4 | 0.000000 12.972733 0.000000 5 | -0.424296 0.000000 6.461036 6 | Zr Zn N 7 | 2 4 20 8 | Direct 9 | 0.03289830 0.25000000 0.13910651 10 | 0.96710170 0.75000000 0.86089349 11 | 0.40925565 0.04150102 0.79236623 12 | 0.59074435 0.95849898 0.20763377 13 | 0.59074435 0.54150102 0.20763377 14 | 0.40925565 0.45849898 0.79236623 15 | 0.87785521 0.97059168 0.98812405 16 | 0.12214479 0.02940832 0.01187595 17 | 0.12214479 0.47059168 0.01187595 18 | 0.87785521 0.52940832 0.98812405 19 | 0.79856089 0.41140412 0.37612110 20 | 0.20143911 0.58859588 0.62387890 21 | 0.20143911 0.91140412 0.62387890 22 | 0.79856089 0.08859588 0.37612110 23 | 0.28692864 0.84017359 0.21926568 24 | 0.71307136 0.15982641 0.78073432 25 | 0.71307136 0.34017359 0.78073432 26 | 0.28692864 0.65982641 0.21926568 27 | 0.92927501 0.88310609 0.58519048 28 | 0.07072499 0.11689391 0.41480952 29 | 0.07072499 0.38310609 0.41480952 30 | 0.92927501 0.61689391 0.58519048 31 | 0.49121727 0.25000000 0.15385230 32 | 0.50878273 0.75000000 0.84614770 33 | 0.35117108 0.75000000 0.24759333 34 | 0.64882892 0.25000000 0.75240667 35 | -------------------------------------------------------------------------------- /examples/inputs/raw/001a938f944fa33e899b81e3d0a909876587e7b0.poscar: -------------------------------------------------------------------------------- 1 | -.19821513E+03 
'028/ht.task.triolith--default.001a938f944fa33e899b81e3d0a909876587e7b0.cleanup.0.unclaimed.2.finished/ht.run.2016-04-26_03.52.10' ['(Tag) original structure: AuN12Rb', '(Tag) original structure path: structures/all/matches/ternaries/15/8_15_1eA1eB6fC.cif', '(Tag) type: ternary'] 2 | 1 3 | 5.380312 -6.691171 -0.394081 4 | 5.380312 6.691171 -0.394081 5 | -0.621649 0.000000 7.622393 6 | Zn Hf N 7 | 2 2 24 8 | Direct 9 | 0.43380935 0.56619065 0.25000000 10 | 0.56619065 0.43380935 0.75000000 11 | 0.04124018 0.95875982 0.25000000 12 | 0.95875982 0.04124018 0.75000000 13 | 0.15480179 0.73674461 0.24234883 14 | 0.26325539 0.84519821 0.25765117 15 | 0.84519821 0.26325539 0.75765117 16 | 0.73674461 0.15480179 0.74234883 17 | 0.30916777 0.45162915 0.40555707 18 | 0.54837085 0.69083223 0.09444293 19 | 0.69083223 0.54837085 0.59444293 20 | 0.45162915 0.30916777 0.90555707 21 | 0.00874248 0.11999663 0.47949900 22 | 0.88000337 0.99125752 0.02050100 23 | 0.99125752 0.88000337 0.52050100 24 | 0.11999663 0.00874248 0.97949900 25 | 0.82331917 0.56703037 0.21133745 26 | 0.43296963 0.17668083 0.28866255 27 | 0.17668083 0.43296963 0.78866255 28 | 0.56703037 0.82331917 0.71133745 29 | 0.82021728 0.43928951 0.16649660 30 | 0.56071049 0.17978272 0.33350340 31 | 0.17978272 0.56071049 0.83350340 32 | 0.43928951 0.82021728 0.66649660 33 | 0.02591212 0.12246309 0.04628719 34 | 0.87753691 0.97408788 0.45371281 35 | 0.97408788 0.87753691 0.95371281 36 | 0.12246309 0.02591212 0.54628719 37 | -------------------------------------------------------------------------------- /examples/inputs/raw/0039f3c461ce4a2cbf4c48b09bfc3c6b13a5c769.poscar: -------------------------------------------------------------------------------- 1 | -.23404935E+03 '028/ht.task.triolith--default.0039f3c461ce4a2cbf4c48b09bfc3c6b13a5c769.cleanup.0.unclaimed.2.finished/ht.run.2016-04-25_15.56.15' ['(Tag) original structure: O4SiZn2', '(Tag) original structure path: structures/all/matches/ternaries/61/4_61_1aA1cB2cC.cif', 
'(Tag) type: ternary'] 2 | 1 3 | 5.306134 0.000000 0.000000 4 | -0.000000 5.455820 0.000000 5 | 0.000000 0.000000 11.449448 6 | Hf Zn N 7 | 8 4 16 8 | Direct 9 | 0.00077530 0.08033905 0.34704739 10 | 0.49922470 0.91966095 0.84704739 11 | 0.99922470 0.91966095 0.65295261 12 | 0.49922470 0.58033905 0.34704739 13 | 0.00077530 0.41966095 0.84704739 14 | 0.50077530 0.08033905 0.15295261 15 | 0.50077530 0.41966095 0.65295261 16 | 0.99922470 0.58033905 0.15295261 17 | -0.00000000 0.00000000 -0.00000000 18 | 0.50000000 0.00000000 0.50000000 19 | 0.50000000 0.50000000 -0.00000000 20 | -0.00000000 0.50000000 0.50000000 21 | 0.19703661 0.30428541 0.05279103 22 | 0.30296339 0.69571459 0.55279103 23 | 0.80296339 0.69571459 0.94720897 24 | 0.30296339 0.80428541 0.05279103 25 | 0.19703661 0.19571459 0.55279103 26 | 0.69703661 0.30428541 0.44720897 27 | 0.69703661 0.19571459 0.94720897 28 | 0.80296339 0.80428541 0.44720897 29 | 0.86819053 0.95152275 0.18764705 30 | 0.63180947 0.04847725 0.68764705 31 | 0.13180947 0.04847725 0.81235295 32 | 0.63180947 0.45152275 0.18764705 33 | 0.86819053 0.54847725 0.68764705 34 | 0.36819053 0.95152275 0.31235295 35 | 0.36819053 0.54847725 0.81235295 36 | 0.13180947 0.45152275 0.31235295 37 | -------------------------------------------------------------------------------- /examples/inputs/raw/0abac62147275aa8199465764751eaa67c94d16c.poscar: -------------------------------------------------------------------------------- 1 | -.15943383E+03 '028/ht.task.gamma--default.0abac62147275aa8199465764751eaa67c94d16c.cleanup.0.unclaimed.3.finished/ht.run.2015-11-16_23.35.20' ['(Tag) original structure: S4Tm2Zn', '(Tag) original structure path: structures/all/matches/ternaries/62/6_62_1aA1cA1cB1dC2cC.cif', '(Tag) type: ternary'] 2 | 1 3 | 10.225880 0.000000 0.000000 4 | 0.000000 6.241736 -0.000000 5 | 0.000000 -0.000000 4.996117 6 | Zn Ti N 7 | 8 4 16 8 | Direct 9 | -0.00000000 0.00000000 -0.00000000 10 | 0.50000000 0.50000000 0.50000000 11 | 0.50000000 
0.00000000 0.50000000 12 | -0.00000000 0.50000000 -0.00000000 13 | 0.27439835 0.25000000 0.02027797 14 | 0.22560165 0.75000000 0.52027797 15 | 0.72560165 0.75000000 0.97972203 16 | 0.77439835 0.25000000 0.47972203 17 | 0.10637322 0.25000000 0.56181050 18 | 0.39362678 0.75000000 0.06181050 19 | 0.89362678 0.75000000 0.43818950 20 | 0.60637322 0.25000000 0.93818950 21 | 0.09067969 0.25000000 0.17405839 22 | 0.40932031 0.75000000 0.67405839 23 | 0.90932031 0.75000000 0.82594161 24 | 0.59067969 0.25000000 0.32594161 25 | 0.43282869 0.25000000 0.79687706 26 | 0.06717131 0.75000000 0.29687706 27 | 0.56717131 0.75000000 0.20312294 28 | 0.93282869 0.25000000 0.70312294 29 | 0.33440626 0.99742074 0.26413801 30 | 0.16559374 0.49742074 0.76413801 31 | 0.16559374 0.00257926 0.76413801 32 | 0.66559374 0.49742074 0.73586199 33 | 0.66559374 0.00257926 0.73586199 34 | 0.83440626 0.50257926 0.23586199 35 | 0.83440626 0.99742074 0.23586199 36 | 0.33440626 0.50257926 0.26413801 37 | -------------------------------------------------------------------------------- /examples/inputs/raw/000866a75fbc0a305808c11c58342bb973c3cad5.poscar: -------------------------------------------------------------------------------- 1 | -.12200503E+03 '028/ht.task.triolith--default.000866a75fbc0a305808c11c58342bb973c3cad5.cleanup.0.unclaimed.2.finished/ht.run.2016-04-25_22.05.20' ['(Tag) original structure: Mo4NNi2', '(Tag) original structure path: structures/all/matches/ternaries/227/4_227_1cA1dB1eC1fB.cif', '(Tag) type: ternary'] 2 | 1 3 | -6.117621 -6.117621 0.000000 4 | -6.117621 0.000000 -6.117621 5 | 0.000000 -6.117621 -6.117621 6 | N Zn Hf 7 | 4 16 8 8 | Direct 9 | 0.00000000 -0.00000000 0.00000000 10 | 0.00000000 0.50000000 0.00000000 11 | 0.50000000 -0.00000000 0.00000000 12 | 0.00000000 -0.00000000 0.50000000 13 | 0.50000000 0.50000000 0.50000000 14 | 0.50000000 -0.00000000 0.50000000 15 | 0.00000000 0.50000000 0.50000000 16 | 0.50000000 0.50000000 0.00000000 17 | 0.67137393 0.67137393 0.07862607 
18 | 0.32862607 0.92137393 0.92137393 19 | 0.92137393 0.92137393 0.32862607 20 | 0.07862607 0.67137393 0.67137393 21 | 0.67137393 0.07862607 0.07862607 22 | 0.32862607 0.32862607 0.92137393 23 | 0.07862607 0.07862607 0.67137393 24 | 0.67137393 0.07862607 0.67137393 25 | 0.32862607 0.92137393 0.32862607 26 | 0.92137393 0.32862607 0.92137393 27 | 0.07862607 0.67137393 0.07862607 28 | 0.92137393 0.32862607 0.32862607 29 | 0.71250038 0.71250038 0.71250038 30 | 0.28749962 0.63750113 0.28749962 31 | 0.63750113 0.28749962 0.28749962 32 | 0.36249887 0.71250038 0.71250038 33 | 0.71250038 0.36249887 0.71250038 34 | 0.28749962 0.28749962 0.63750113 35 | 0.28749962 0.28749962 0.28749962 36 | 0.71250038 0.71250038 0.36249887 37 | -------------------------------------------------------------------------------- /examples/inputs/raw/00a8d0f6f55191596770605314ff347995a3540d.poscar: -------------------------------------------------------------------------------- 1 | -.12200932E+03 '028/ht.task.triolith--default.00a8d0f6f55191596770605314ff347995a3540d.cleanup.0.unclaimed.2.finished/ht.run.2016-04-25_22.41.50' ['(Tag) original structure: Co2NZr4', '(Tag) original structure path: structures/all/matches/ternaries/227/4_227_1cA1dB1eC1fA.cif', '(Tag) type: ternary'] 2 | 1 3 | -6.116632 -6.116632 -0.000000 4 | -6.116632 -0.000000 -6.116632 5 | -0.000000 -6.116632 -6.116632 6 | Zn Hf N 7 | 16 8 4 8 | Direct 9 | 0.00000000 0.00000000 -0.00000000 10 | 0.00000000 0.50000000 -0.00000000 11 | 0.50000000 0.00000000 -0.00000000 12 | 0.00000000 0.00000000 0.50000000 13 | 0.57230170 0.57230170 0.17769830 14 | 0.42769830 0.82230170 0.82230170 15 | 0.82230170 0.82230170 0.42769830 16 | 0.17769830 0.57230170 0.57230170 17 | 0.57230170 0.17769830 0.17769830 18 | 0.42769830 0.42769830 0.82230170 19 | 0.17769830 0.17769830 0.57230170 20 | 0.57230170 0.17769830 0.57230170 21 | 0.42769830 0.82230170 0.42769830 22 | 0.82230170 0.42769830 0.82230170 23 | 0.17769830 0.57230170 0.17769830 24 | 0.82230170 
0.42769830 0.42769830 25 | 0.78788573 0.78788573 0.78788573 26 | 0.21211427 0.86365720 0.21211427 27 | 0.86365720 0.21211427 0.21211427 28 | 0.13634280 0.78788573 0.78788573 29 | 0.78788573 0.13634280 0.78788573 30 | 0.21211427 0.21211427 0.86365720 31 | 0.21211427 0.21211427 0.21211427 32 | 0.78788573 0.78788573 0.13634280 33 | 0.50000000 0.50000000 0.50000000 34 | 0.50000000 0.00000000 0.50000000 35 | 0.00000000 0.50000000 0.50000000 36 | 0.50000000 0.50000000 -0.00000000 37 | -------------------------------------------------------------------------------- /examples/inputs/raw/0ab8ebf80d94a1ba826857a9e81cea8c7124d86e.poscar: -------------------------------------------------------------------------------- 1 | -.21135696E+03 '030/ht.task.gamma--default.0ab8ebf80d94a1ba826857a9e81cea8c7124d86e.cleanup.3.unclaimed.3.finished/ht.run.2016-06-20_16.32.26' ['(Tag) original structure: H4O8Ti3', '(Tag) original structure path: structures/all/matches/ternaries/12/15_12_3iA4iB8iC.cif', '(Tag) type: ternary'] 2 | 1 3 | 14.680227 -1.507597 -0.760453 4 | 14.680227 1.507597 -0.760453 5 | -7.316016 0.000000 9.877532 6 | Ti Zn N 7 | 8 6 16 8 | Direct 9 | 0.36914484 0.36914484 0.97637487 10 | 0.63085516 0.63085516 0.02362513 11 | 0.48270716 0.48270716 0.65034958 12 | 0.51729284 0.51729284 0.34965042 13 | 0.29723292 0.29723292 0.67782370 14 | 0.70276708 0.70276708 0.32217630 15 | 0.90549748 0.90549748 0.30384910 16 | 0.09450252 0.09450252 0.69615090 17 | 0.24540401 0.24540401 0.18401970 18 | 0.75459599 0.75459599 0.81598030 19 | 0.16410137 0.16410137 0.49746173 20 | 0.83589863 0.83589863 0.50253827 21 | 0.20389767 0.20389767 0.90895922 22 | 0.79610233 0.79610233 0.09104078 23 | 0.16237373 0.16237373 0.99331877 24 | 0.83762627 0.83762627 0.00668123 25 | 0.35419025 0.35419025 0.23475963 26 | 0.64580975 0.64580975 0.76524037 27 | 0.15175152 0.15175152 0.18211561 28 | 0.84824848 0.84824848 0.81788439 29 | 0.29589949 0.29589949 0.52234703 30 | 0.70410051 0.70410051 0.47765297 31 | 
0.07845619 0.07845619 0.41454019 32 | 0.92154381 0.92154381 0.58545981 33 | 0.23587806 0.23587806 0.70148934 34 | 0.76412194 0.76412194 0.29851066 35 | 0.04432356 0.04432356 0.75256226 36 | 0.95567644 0.95567644 0.24743774 37 | 0.93768367 0.93768367 0.50017113 38 | 0.06231633 0.06231633 0.49982887 39 | -------------------------------------------------------------------------------- /examples/inputs/raw/0019407b910a29730999dab65491b4e50624c219.poscar: -------------------------------------------------------------------------------- 1 | -.18172455E+03 '030/ht.task.triolith--default.0019407b910a29730999dab65491b4e50624c219.cleanup.0.unclaimed.2.finished/ht.run.2016-04-15_10.15.11' ['(Tag) original structure: K2O9Si4', '(Tag) original structure path: structures/all/matches/ternaries/176/5_176_1bA1fB1hC1hA1iC.cif', '(Tag) type: ternary'] 2 | 1 3 | 5.370680 -3.100763 0.000000 4 | -0.000000 6.201526 0.000000 5 | 0.000000 0.000000 10.101875 6 | Hf Zn N 7 | 4 8 18 8 | Direct 9 | 0.33333333 0.66666667 0.07050127 10 | 0.66666667 0.33333333 0.57050127 11 | 0.33333333 0.66666667 0.42949873 12 | 0.66666667 0.33333333 0.92949873 13 | 0.00000000 0.00000000 -0.00000000 14 | 0.00000000 0.00000000 0.50000000 15 | 0.34920449 0.19805089 0.25000000 16 | 0.19805089 0.84884640 0.75000000 17 | 0.84884640 0.65079551 0.25000000 18 | 0.15115360 0.34920449 0.75000000 19 | 0.80194911 0.15115360 0.25000000 20 | 0.65079551 0.80194911 0.75000000 21 | 0.43067355 0.93955068 0.25000000 22 | 0.93955068 0.50887713 0.75000000 23 | 0.50887713 0.56932645 0.25000000 24 | 0.49112287 0.43067355 0.75000000 25 | 0.06044932 0.49112287 0.25000000 26 | 0.56932645 0.06044932 0.75000000 27 | 0.34135426 0.31478008 0.07552230 28 | 0.31478008 0.97342582 0.57552230 29 | 0.34135426 0.31478008 0.42447770 30 | 0.97342582 0.65864574 0.42447770 31 | 0.02657418 0.34135426 0.92447770 32 | 0.68521992 0.02657418 0.42447770 33 | 0.31478008 0.97342582 0.92447770 34 | 0.65864574 0.68521992 0.57552230 35 | 0.65864574 0.68521992 
0.92447770 36 | 0.68521992 0.02657418 0.07552230 37 | 0.97342582 0.65864574 0.07552230 38 | 0.02657418 0.34135426 0.57552230 39 | -------------------------------------------------------------------------------- /examples/inputs/raw/0c0b4fd84c8f4e075436d7ec34bfb3fad5e610be.poscar: -------------------------------------------------------------------------------- 1 | -.21275901E+03 '030/ht.task.triolith--default.0c0b4fd84c8f4e075436d7ec34bfb3fad5e610be.cleanup.0.unclaimed.4.finished/ht.run.2015-12-07_19.52.24' ['(Tag) original structure: MnO3Yb', '(Tag) type: ternary', '(Tag) original structure path: structures/all/matches/ternaries/185/7_185_1aA1aB1bA1bB1cC2cA.cif'] 2 | 1 3 | 5.543294 -3.200422 0.000000 4 | -0.000000 6.400844 -0.000000 5 | 0.000000 0.000000 10.157681 6 | Zn Zr N 7 | 6 6 18 8 | Direct 9 | -0.00000000 -0.00000000 0.26067791 10 | -0.00000000 -0.00000000 0.76067791 11 | 0.33333333 0.66666667 0.22343397 12 | 0.33333333 0.66666667 0.72343397 13 | 0.66666667 0.33333333 0.72343397 14 | 0.66666667 0.33333333 0.22343397 15 | 0.34838803 1.00000000 0.98014843 16 | 0.65161197 0.00000000 0.48014843 17 | 1.00000000 0.34838803 0.98014843 18 | 0.34838803 0.34838803 0.48014843 19 | 0.65161197 0.65161197 0.98014843 20 | 0.00000000 0.65161197 0.48014843 21 | 0.29706565 1.00000000 0.20897943 22 | 0.70293435 0.00000000 0.70897943 23 | -0.00000000 0.29706565 0.20897943 24 | 0.29706565 0.29706565 0.70897943 25 | 0.70293435 0.70293435 0.20897943 26 | -0.00000000 0.70293435 0.70897943 27 | 0.44196668 0.00000000 0.31072121 28 | 0.55803332 1.00000000 0.81072121 29 | -0.00000000 0.44196668 0.31072121 30 | 0.44196668 0.44196668 0.81072121 31 | 0.55803332 0.55803332 0.31072121 32 | -0.00000000 0.55803332 0.81072121 33 | -0.00000000 -0.00000000 0.46416939 34 | -0.00000000 -0.00000000 0.96416939 35 | 0.33333333 0.66666667 0.02996876 36 | 0.33333333 0.66666667 0.52996876 37 | 0.66666667 0.33333333 0.52996876 38 | 0.66666667 0.33333333 0.02996876 39 | 
import os

import pytest
import torch
from matminer.datasets import load_dataset
from pymatgen.analysis.prototypes import get_protostructure_label_from_spglib

torch.manual_seed(0)  # ensure reproducible results (applies to all tests)

TEST_DIR = os.path.dirname(os.path.abspath(__file__))


@pytest.fixture(scope="session")
def df_matbench_phonons():
    """Returns the dataframe for the Matbench phonon DOS peak task.

    Adds a 1-based ``material_id`` column (also used as the index), a
    whitespace-free ``composition`` formula string and a binary ``phdos_clf``
    classification target (1 where "last phdos peak" exceeds 450, else 0).
    """
    df = load_dataset("matbench_phonons")
    df["material_id"] = [f"mb_phdos_{idx + 1}" for idx in range(len(df))]
    df = df.set_index("material_id", drop=False)
    df["composition"] = [x.composition.formula.replace(" ", "") for x in df.structure]

    df["phdos_clf"] = [1 if x > 450 else 0 for x in df["last phdos peak"]]

    return df


@pytest.fixture(scope="session")
def df_matbench_jdft2d():
    """Returns the dataframe for the Matbench ``matbench_jdft2d`` task. Currently
    unused.

    Adds a 1-based ``material_id`` column (also used as the index), a
    whitespace-free ``composition`` formula string and a ``protostructure``
    label computed from each structure.
    """
    df = load_dataset("matbench_jdft2d")
    df["material_id"] = [f"mb_jdft2d_{idx + 1}" for idx in range(len(df))]
    df = df.set_index("material_id", drop=False)
    df["composition"] = [x.composition.formula.replace(" ", "") for x in df.structure]

    df["protostructure"] = df.structure.map(get_protostructure_label_from_spglib)

    return df


@pytest.fixture(scope="session")
def df_matbench_phonons_wyckoff(df_matbench_phonons):
    """Returns the phonons dataframe with an added ``protostructure`` column.

    Getting protostructure labels is expensive so we split into a separate fixture
    to avoid paying for it unless requested.
    """
    # Work on a copy: df_matbench_phonons is session-scoped and shared, so adding
    # the column in place would make its schema depend on test execution order.
    df = df_matbench_phonons.copy()
    df["protostructure"] = df.structure.map(get_protostructure_label_from_spglib)

    return df
0.50000000 0.92134346 39 | -0.00000000 0.00000000 -0.00000000 40 | -------------------------------------------------------------------------------- /examples/inputs/raw/0a2c4708d3c853ea3c4b8630027afe9b788ee2ce.poscar: -------------------------------------------------------------------------------- 1 | -.23520684E+03 '032/ht.task.gamma--default.0a2c4708d3c853ea3c4b8630027afe9b788ee2ce.cleanup.0.unclaimed.3.finished/ht.run.2016-05-02_10.51.25' ['(Tag) original structure: Cu2O5P', '(Tag) original structure path: structures/all/matches/ternaries/58/7_58_1fA1gA1gB1hC3gC.cif', '(Tag) type: ternary'] 2 | 1 3 | 8.736347 -0.000000 0.000000 4 | 0.000000 8.231356 0.000000 5 | 0.000000 0.000000 6.998135 6 | N Zn Ti 7 | 20 4 8 8 | Direct 9 | 0.04933257 0.15515104 0.00000000 10 | 0.45066743 0.65515104 0.50000000 11 | 0.95066743 0.84484896 -0.00000000 12 | 0.54933257 0.34484896 0.50000000 13 | 0.06055200 0.97872251 0.50000000 14 | 0.43944800 0.47872251 0.00000000 15 | 0.93944800 0.02127749 0.50000000 16 | 0.56055200 0.52127749 0.00000000 17 | 0.14268832 0.43858699 0.50000000 18 | 0.35731168 0.93858699 -0.00000000 19 | 0.85731168 0.56141301 0.50000000 20 | 0.64268832 0.06141301 0.00000000 21 | 0.36196035 0.18003421 0.24899865 22 | 0.13803965 0.68003421 0.74899865 23 | 0.63803965 0.81996579 0.75100135 24 | 0.86196035 0.31996579 0.25100135 25 | 0.13803965 0.68003421 0.25100135 26 | 0.63803965 0.81996579 0.24899865 27 | 0.86196035 0.31996579 0.74899865 28 | 0.36196035 0.18003421 0.75100135 29 | 0.32947201 0.28836426 0.50000000 30 | 0.17052799 0.78836426 0.00000000 31 | 0.67052799 0.71163574 0.50000000 32 | 0.82947201 0.21163574 0.00000000 33 | 0.00000000 0.50000000 0.30022293 34 | 0.50000000 -0.00000000 0.80022293 35 | 0.00000000 0.50000000 0.69977707 36 | 0.50000000 -0.00000000 0.19977707 37 | 0.26884529 0.16402866 0.00000000 38 | 0.23115471 0.66402866 0.50000000 39 | 0.73115471 0.83597134 -0.00000000 40 | 0.76884529 0.33597134 0.50000000 41 | 
"""Download Matbench datasets and cache them as compressed JSON files.

For every selected dataset that has a ``structure`` column, a ``composition``
formula and a ``protostructure`` label are derived before saving. Datasets that
only ship a ``composition`` column are saved unchanged. ``DATA_PATHS`` maps each
cached dataset name to its file path.
"""

import os
from glob import glob

from matbench.data_ops import load
from pymatgen.analysis.prototypes import get_protostructure_label_from_spglib
from tqdm import tqdm

tqdm.pandas()  # enables the ``progress_apply`` used below

current_dir = os.path.dirname(os.path.abspath(__file__))

# When True, only one dataset is prepared (see the slice below).
SMOKE_TEST = True

matbench_datasets = [
    "matbench_steels",
    "matbench_jdft2d",
    "matbench_phonons",
    "matbench_expt_gap",
    "matbench_dielectric",
    "matbench_expt_is_metal",
    "matbench_glass",
    "matbench_log_gvrh",
    "matbench_log_kvrh",
    "matbench_perovskites",
    "matbench_mp_gap",
    "matbench_mp_is_metal",
    "matbench_mp_e_form",
]

if SMOKE_TEST:
    matbench_datasets = matbench_datasets[2:3]  # keeps only "matbench_phonons"

os.makedirs(f"{current_dir}/datasets", exist_ok=True)
for dataset in matbench_datasets:
    dataset_path = f"{current_dir}/datasets/{dataset}.json.bz2"

    if os.path.exists(dataset_path):
        print(f"Dataset {dataset} already exists, skipping")
        continue

    df = load(dataset)

    if "structure" in df:
        df["composition"] = [struct.formula for struct in df.structure]
        df["protostructure"] = df["structure"].progress_apply(
            get_protostructure_label_from_spglib
        )
    elif "composition" not in df:
        # Only fail when neither input column exists, as the message states.
        # Composition-only datasets are written out as loaded.
        raise ValueError("No structure or composition column found")

    df.to_json(
        dataset_path,
        # objects pandas cannot serialise natively are stored via as_dict()
        default_handler=lambda x: x.as_dict(),
    )


DATA_PATHS = {
    # basename handles whichever path separator glob returns on this platform
    os.path.basename(path).split(".")[0]: path
    for path in glob(f"{current_dir}/datasets/matbench_*.json.bz2")
}

assert len(DATA_PATHS) == len(matbench_datasets), (
    f"glob found {len(DATA_PATHS)} data sets, expected {len(matbench_datasets)}"
)
-------------------------------------------------------------------------------- /examples/inputs/raw/00cca7ad273875ba80f1d437af797b6a01cc6ebd.poscar: -------------------------------------------------------------------------------- 1 | -.28166871E+03 '034/ht.task.triolith--default.00cca7ad273875ba80f1d437af797b6a01cc6ebd.cleanup.0.unclaimed.2.finished/ht.run.2015-12-28_06.56.02' ['(Tag) original structure: AuK3Se13', '(Tag) original structure path: structures/all/matches/ternaries/13/10_13_1dA1eB1fC1gB6gC.cif', '(Tag) type: ternary'] 2 | 1 3 | 10.962471 0.000000 0.122473 4 | 0.000000 3.175217 0.000000 5 | -0.573407 0.000000 9.709596 6 | Zn N Hf 7 | 2 26 6 8 | Direct 9 | 0.50000000 0.00000000 -0.00000000 10 | 0.50000000 0.00000000 0.50000000 11 | 0.67360320 0.18759091 0.93174711 12 | 0.67360320 0.81240909 0.43174711 13 | 0.32639680 0.81240909 0.06825289 14 | 0.32639680 0.18759091 0.56825289 15 | 0.75558845 0.23022353 0.03949303 16 | 0.75558845 0.76977647 0.53949303 17 | 0.24441155 0.76977647 0.96050697 18 | 0.24441155 0.23022353 0.46050697 19 | 0.12699532 0.23002080 0.47579975 20 | 0.12699532 0.76997920 0.97579975 21 | 0.87300468 0.76997920 0.52420025 22 | 0.87300468 0.23002080 0.02420025 23 | 0.07990668 0.20674338 0.60169736 24 | 0.07990668 0.79325662 0.10169736 25 | 0.92009332 0.79325662 0.39830264 26 | 0.92009332 0.20674338 0.89830264 27 | 0.14854564 0.29132038 0.73090146 28 | 0.14854564 0.70867962 0.23090146 29 | 0.85145436 0.70867962 0.26909854 30 | 0.85145436 0.29132038 0.76909854 31 | 0.41976842 0.75041369 0.33129433 32 | 0.41976842 0.24958631 0.83129433 33 | 0.58023158 0.24958631 0.66870567 34 | 0.58023158 0.75041369 0.16870567 35 | 0.50000000 0.45635328 0.75000000 36 | 0.50000000 0.54364672 0.25000000 37 | -0.00000000 0.23886969 0.25000000 38 | -0.00000000 0.76113031 0.75000000 39 | 0.69243328 0.26495359 0.26261208 40 | 0.69243328 0.73504641 0.76261208 41 | 0.30756672 0.73504641 0.73738792 42 | 0.30756672 0.26495359 0.23738792 43 | 
-------------------------------------------------------------------------------- /examples/inputs/raw/0a946dc02260d0d4b36ef2087e60b415a5fff800.poscar: -------------------------------------------------------------------------------- 1 | -.29008227E+03 '036/ht.task.gamma--default.0a946dc02260d0d4b36ef2087e60b415a5fff800.cleanup.0.unclaimed.3.finished/ht.run.2015-11-16_19.23.17' ['(Tag) original structure: Cl5PbTl3', '(Tag) original structure path: structures/all/matches/ternaries/76/9_76_1aA3aB5aC.cif', '(Tag) type: ternary'] 2 | 1 3 | 5.839669 0.000000 0.000000 4 | -0.000000 5.839669 0.000000 5 | 0.000000 0.000000 10.063436 6 | Ti Zn N 7 | 12 4 20 8 | Direct 9 | 0.96905875 0.80718338 0.81186643 10 | 0.03094125 0.19281662 0.31186643 11 | 0.19281662 0.96905875 0.06186643 12 | 0.80718338 0.03094125 0.56186643 13 | 0.10899183 0.34228606 0.88023088 14 | 0.89100817 0.65771394 0.38023088 15 | 0.65771394 0.10899183 0.13023088 16 | 0.34228606 0.89100817 0.63023088 17 | 0.60631294 0.18018496 0.83510905 18 | 0.39368706 0.81981504 0.33510905 19 | 0.81981504 0.60631294 0.08510905 20 | 0.18018496 0.39368706 0.58510905 21 | 0.48525695 0.69477963 0.89229795 22 | 0.51474305 0.30522037 0.39229795 23 | 0.30522037 0.48525695 0.14229795 24 | 0.69477963 0.51474305 0.64229795 25 | 0.34294622 0.11501552 0.23713654 26 | 0.65705378 0.88498448 0.73713654 27 | 0.88498448 0.34294622 0.48713654 28 | 0.11501552 0.65705378 0.98713654 29 | 0.73779301 0.97899975 0.33796876 30 | 0.26220699 0.02100025 0.83796876 31 | 0.02100025 0.73779301 0.58796876 32 | 0.97899975 0.26220699 0.08796876 33 | 0.00065757 0.82328758 0.21040091 34 | 0.99934243 0.17671242 0.71040091 35 | 0.17671242 0.00065757 0.46040091 36 | 0.82328758 0.99934243 0.96040091 37 | 0.55444204 0.73450842 0.50475735 38 | 0.44555796 0.26549158 0.00475735 39 | 0.26549158 0.55444204 0.75475735 40 | 0.73450842 0.44555796 0.25475735 41 | 0.51196313 0.78128955 0.10023213 42 | 0.48803687 0.21871045 0.60023213 43 | 0.21871045 0.51196313 0.35023213 44 | 
0.78128955 0.48803687 0.85023213 45 | -------------------------------------------------------------------------------- /examples/inputs/raw/0ab8cc7e5cc3d48fcc786f9060da5a8882374f44.poscar: -------------------------------------------------------------------------------- 1 | -.22445649E+03 '036/ht.task.gamma--default.0ab8cc7e5cc3d48fcc786f9060da5a8882374f44.cleanup.0.unclaimed.3.finished/ht.run.2016-06-23_23.03.48' ['(Tag) original structure: Al2Ca3N4', '(Tag) original structure path: structures/all/matches/ternaries/14/9_14_2eA3eB4eC.cif', '(Tag) type: ternary'] 2 | 1 3 | 9.402404 -0.001130 0.397882 4 | -0.000650 5.646741 0.000464 5 | -3.073322 0.000794 8.683129 6 | Zn Ti N 7 | 12 8 16 8 | Direct 9 | 0.05277259 0.04468460 0.22554858 10 | 0.94722716 0.95531565 0.77445205 11 | 0.94695245 0.54466241 0.27464315 12 | 0.05304916 0.45533967 0.72535521 13 | 0.18677287 0.65379652 0.12185680 14 | 0.81322239 0.34620769 0.87814188 15 | 0.81300997 0.15355040 0.37838871 16 | 0.18699246 0.84645005 0.62161212 17 | 0.42929711 0.36288193 0.28830955 18 | 0.57069576 0.63711251 0.71169070 19 | 0.57061516 0.86310666 0.21156040 20 | 0.42938961 0.13689561 0.78843907 21 | 0.22918593 0.15003641 0.98725724 22 | 0.77081261 0.84996283 0.01274258 23 | 0.77079612 0.65007753 0.51282933 24 | 0.22920238 0.34992259 0.48717364 25 | 0.61109232 0.38180476 0.12711576 26 | 0.38891406 0.61819165 0.87288127 27 | 0.38888351 0.88195244 0.37279498 28 | 0.61111673 0.11804393 0.62720759 29 | 0.02632592 0.24383232 0.39072107 30 | 0.97367511 0.75617134 0.60927740 31 | 0.97352280 0.74373049 0.10950746 32 | 0.02648239 0.25627025 0.89049600 33 | 0.25275843 0.02636994 0.18684925 34 | 0.74723814 0.97364108 0.81314111 35 | 0.74694060 0.52616949 0.31323453 36 | 0.25306032 0.47383160 0.68676848 37 | 0.36602370 0.41637224 0.03775171 38 | 0.63397686 0.58361155 0.96224839 39 | 0.63414571 0.91658293 0.46221345 40 | 0.36585057 0.08341002 0.53778712 41 | 0.71185292 0.07797699 0.14532722 42 | 0.28814492 0.92203290 0.85467531 43 
| 0.28829662 0.57788891 0.35464532 44 | 0.71170460 0.42211213 0.64535556 45 | -------------------------------------------------------------------------------- /examples/matbench_example/readme.md: -------------------------------------------------------------------------------- 1 | # Matbench 2 | 3 | This directory contains the files needed to create Matbench submissions for Roostformer and Wrenformer (structure tasks only for Wren) which are rewrites of Roost and Wren using PyTorch's builtin `TransformerEncoder` instead of the custom self-attention modules used by Roost and Wren. 4 | 5 | Added in [aviary#44](https://github.com/CompRhys/aviary/pull/44). 6 | 7 | ## Speed difference between Wren and Wrenformer 8 | 9 | According to Rhys, Wren could run 500 epochs in 5.5 h on a P100 training on 120k samples of MP data (similar to the `matbench_mp_e_form` dataset with 132k samples). Wrenformer only managed 207 epochs in 4h on the more powerful A100 training on `matbench_mp_e_form`. However, to avoid out-of-memory issues, Rhys constrained Wren to only run on systems with <= 16 Wyckoff positions. The code below shows that this lightens the workload by a factor of about 7.5, likely explaining the apparent slowdown in Wrenformer. 
10 | 11 | ```py 12 | import os 13 | import sys 14 | from pathlib import Path 15 | 16 | # Add the parent directory to system path 17 | sys.path.append(str(Path(__file__).parent.parent)) 18 | 19 | import pandas as pd 20 | from pymatgen.analysis.prototypes import count_wyckoff_positions 21 | from matbench_example import DATA_PATHS 22 | 23 | df = pd.read_json(DATA_PATHS["matbench_mp_e_form"]) 24 | 25 | df["n_wyckoff"] = df.wyckoff.map(count_wyckoff_positions) 26 | 27 | 28 | sum_wyckoffs_sqr = (df.n_wyckoff**2).sum() 29 | sum_wyckoffs_lte_16_sqr = (df.query("n_wyckoff <= 16").n_wyckoff ** 2).sum() 30 | print(f"{sum_wyckoffs_sqr=}") 31 | print(f"{sum_wyckoffs_lte_16_sqr=}") 32 | print(f"{sum_wyckoffs_sqr/sum_wyckoffs_lte_16_sqr=:.3}") 33 | # prints 7.45, so Wrenformer has to do 7.45x more work, explaining the about 2x slow down 34 | # on a more powerful GPU (Nvidia A100 vs Wren on a P100) 35 | ``` 36 | 37 | ## Benchmarks 38 | 39 | JSON files in `model_scores/` contain only the calculated scores (MAE/ROCAUC) for a given model run. Files with the same name in `model_preds/` contain the full set of model predictions, targets and material ids. Code for loading these into memory is in `make_plots.py`. 
40 | -------------------------------------------------------------------------------- /examples/inputs/raw/00bada6787af6463447a1460da0c031d51494b60.poscar: -------------------------------------------------------------------------------- 1 | -.30469940E+03 '039/ht.task.triolith--default.00bada6787af6463447a1460da0c031d51494b60.cleanup.0.unclaimed.2.finished/ht.run.2016-01-28_08.04.36' ['(Tag) original structure: Ba3S8Ta2', '(Tag) original structure path: structures/all/matches/ternaries/10/17_10_1aA1cB1eA1mA1nA2mB2mC2nB2nC4oC.cif', '(Tag) type: ternary'] 2 | 1 3 | 9.521579 0.000000 0.373799 4 | 0.000000 4.706142 0.000000 5 | -2.621791 0.000000 9.869892 6 | Hf Zn N 7 | 9 6 24 8 | Direct 9 | -0.00000000 0.00000000 0.50000000 10 | 0.71433010 0.00000000 0.98271042 11 | 0.28566990 0.00000000 0.01728958 12 | 0.33416582 0.50000000 0.41311094 13 | 0.66583418 0.50000000 0.58688906 14 | 0.55172795 -0.00000000 0.35393491 15 | 0.44827205 0.00000000 0.64606509 16 | 0.96996488 0.50000000 0.19091014 17 | 0.03003512 0.50000000 0.80908986 18 | -0.00000000 0.00000000 0.00000000 19 | 0.50000000 0.50000000 0.00000000 20 | 0.18814200 0.00000000 0.27782706 21 | 0.81185800 0.00000000 0.72217294 22 | 0.67649413 0.50000000 0.24016345 23 | 0.32350587 0.50000000 0.75983655 24 | 0.73909176 0.50000000 0.06433623 25 | 0.26090824 0.50000000 0.93566377 26 | 0.36649012 0.00000000 0.41827029 27 | 0.63350988 0.00000000 0.58172971 28 | 0.56803545 0.50000000 0.38970358 29 | 0.43196455 0.50000000 0.61029642 30 | 0.82102394 0.00000000 0.23359175 31 | 0.17897606 -0.00000000 0.76640825 32 | 0.13781094 0.25687296 0.11731282 33 | 0.86218906 0.25687296 0.88268718 34 | 0.86218906 0.74312704 0.88268718 35 | 0.13781094 0.74312704 0.11731282 36 | 0.48082977 0.15574699 0.15015500 37 | 0.51917023 0.15574699 0.84984500 38 | 0.51917023 0.84425301 0.84984500 39 | 0.48082977 0.84425301 0.15015500 40 | 0.11588547 0.34907379 0.40667526 41 | 0.88411453 0.34907379 0.59332474 42 | 0.88411453 0.65092621 0.59332474 43 | 
0.11588547 0.65092621 0.40667526 44 | 0.85246030 0.22762806 0.30914063 45 | 0.14753970 0.22762806 0.69085937 46 | 0.14753970 0.77237194 0.69085937 47 | 0.85246030 0.77237194 0.30914063 48 | -------------------------------------------------------------------------------- /examples/inputs/raw/0adeaf7a9c16772b66815cce19917ef0460affe8.poscar: -------------------------------------------------------------------------------- 1 | -.27195782E+03 '040/ht.task.gamma--default.0adeaf7a9c16772b66815cce19917ef0460affe8.cleanup.0.unclaimed.3.finished/ht.run.2016-04-13_19.08.23' ['(Tag) original structure: O3SrZr', '(Tag) original structure path: structures/all/matches/ternaries/198/6_198_2aA2aB2bC.cif', '(Tag) type: ternary'] 2 | 1 3 | 7.685200 0.000000 0.000000 4 | 0.000000 7.685200 0.000000 5 | 0.000000 0.000000 7.685200 6 | Ti Zn N 7 | 8 8 24 8 | Direct 9 | 0.00179104 0.00179104 0.00179104 10 | 0.49820896 0.99820896 0.50179104 11 | 0.50179104 0.49820896 0.99820896 12 | 0.99820896 0.50179104 0.49820896 13 | 0.47336587 0.47336587 0.47336587 14 | 0.02663413 0.52663413 0.97336587 15 | 0.97336587 0.02663413 0.52663413 16 | 0.52663413 0.97336587 0.02663413 17 | 0.29008367 0.29008367 0.29008367 18 | 0.20991633 0.70991633 0.79008367 19 | 0.79008367 0.20991633 0.70991633 20 | 0.70991633 0.79008367 0.20991633 21 | 0.75881592 0.75881592 0.75881592 22 | 0.74118408 0.24118408 0.25881592 23 | 0.25881592 0.74118408 0.24118408 24 | 0.24118408 0.25881592 0.74118408 25 | 0.23092657 0.11049583 0.08125003 26 | 0.41874997 0.76907343 0.61049583 27 | 0.58125003 0.26907343 0.88950417 28 | 0.91874997 0.73092657 0.38950417 29 | 0.11049583 0.08125003 0.23092657 30 | 0.38950417 0.91874997 0.73092657 31 | 0.76907343 0.61049583 0.41874997 32 | 0.08125003 0.23092657 0.11049583 33 | 0.26907343 0.88950417 0.58125003 34 | 0.73092657 0.38950417 0.91874997 35 | 0.61049583 0.41874997 0.76907343 36 | 0.88950417 0.58125003 0.26907343 37 | 0.75939082 0.03937817 0.09224158 38 | 0.40775842 0.24060918 0.53937817 39 | 
0.59224158 0.74060918 0.96062183 40 | 0.90775842 0.25939082 0.46062183 41 | 0.03937817 0.09224158 0.75939082 42 | 0.46062183 0.90775842 0.25939082 43 | 0.24060918 0.53937817 0.40775842 44 | 0.09224158 0.75939082 0.03937817 45 | 0.74060918 0.96062183 0.59224158 46 | 0.25939082 0.46062183 0.90775842 47 | 0.53937817 0.40775842 0.24060918 48 | 0.96062183 0.59224158 0.74060918 49 | -------------------------------------------------------------------------------- /examples/inputs/raw/00ecadb66a8baab030675a817d2ee15f0002a4dc.poscar: -------------------------------------------------------------------------------- 1 | -.37086156E+03 '046/ht.task.triolith--default.00ecadb66a8baab030675a817d2ee15f0002a4dc.cleanup.0.unclaimed.2.finished/ht.run.2016-01-11_18.05.48' ['(Tag) original structure: Al2F16Sr5', '(Tag) original structure path: structures/all/matches/ternaries/68/8_68_1bA1gA1hB1hA4iC.cif', '(Tag) type: ternary'] 2 | 1 3 | 3.444884 -6.785908 0.000000 4 | 3.444884 6.785908 0.000000 5 | 0.000000 0.000000 12.043561 6 | Hf Zn N 7 | 10 4 32 8 | Direct 9 | 0.75000000 0.75000000 0.14314454 10 | 0.25000000 0.25000000 0.35685546 11 | 0.75000000 0.75000000 0.64314454 12 | 0.25000000 0.25000000 0.85685546 13 | 0.25000000 0.75000000 0.25000000 14 | 0.75000000 0.25000000 0.75000000 15 | 0.75000000 0.25000000 0.02657995 16 | 0.75000000 0.25000000 0.47342005 17 | 0.25000000 0.75000000 0.52657995 18 | 0.25000000 0.75000000 0.97342005 19 | 0.25000000 0.25000000 0.10718944 20 | 0.75000000 0.75000000 0.39281056 21 | 0.25000000 0.25000000 0.60718944 22 | 0.75000000 0.75000000 0.89281056 23 | 0.32171573 0.49040581 0.24334585 24 | 0.50959419 0.67828427 0.25665415 25 | 0.00959419 0.17828427 0.74334585 26 | 0.17828427 0.00959419 0.24334585 27 | 0.67828427 0.50959419 0.75665415 28 | 0.99040581 0.82171573 0.25665415 29 | 0.49040581 0.32171573 0.74334585 30 | 0.82171573 0.99040581 0.75665415 31 | 0.02449093 0.58763465 0.10661258 32 | 0.41236535 0.97550907 0.39338742 33 | 0.91236535 0.47550907 
0.60661258 34 | 0.47550907 0.91236535 0.10661258 35 | 0.97550907 0.41236535 0.89338742 36 | 0.08763465 0.52449093 0.39338742 37 | 0.58763465 0.02449093 0.60661258 38 | 0.52449093 0.08763465 0.89338742 39 | 0.92789297 0.09878444 0.10280100 40 | 0.90121556 0.07210703 0.39719900 41 | 0.40121556 0.57210703 0.60280100 42 | 0.57210703 0.40121556 0.10280100 43 | 0.07210703 0.90121556 0.89719900 44 | 0.59878444 0.42789297 0.39719900 45 | 0.09878444 0.92789297 0.60280100 46 | 0.42789297 0.59878444 0.89719900 47 | 0.28120306 0.47712413 0.97303318 48 | 0.52287587 0.71879694 0.52696682 49 | 0.02287587 0.21879694 0.47303318 50 | 0.21879694 0.02287587 0.97303318 51 | 0.71879694 0.52287587 0.02696682 52 | 0.97712413 0.78120306 0.52696682 53 | 0.47712413 0.28120306 0.47303318 54 | 0.78120306 0.97712413 0.02696682 55 | -------------------------------------------------------------------------------- /examples/inputs/raw/00021364446a637881257fd9ee912a422a6b1753.poscar: -------------------------------------------------------------------------------- 1 | -.38267572E+03 '048/ht.task.triolith--default.00021364446a637881257fd9ee912a422a6b1753.cleanup.0.unclaimed.2.finished/ht.run.2016-03-10_13.38.07' ['(Tag) original structure: F9RbTh2', '(Tag) original structure path: structures/all/matches/ternaries/62/7_62_1cA1cB1dC4dA.cif', '(Tag) type: ternary'] 2 | 1 3 | 6.729996 0.000000 0.000000 4 | 0.000000 9.412429 0.000000 5 | 0.000000 0.000000 8.919448 6 | Hf Zn N 7 | 8 4 36 8 | Direct 9 | 0.25298222 0.45025421 0.31557996 10 | 0.24701778 0.95025421 0.81557996 11 | 0.24701778 0.54974579 0.81557996 12 | 0.74701778 0.95025421 0.68442004 13 | 0.74701778 0.54974579 0.68442004 14 | 0.75298222 0.04974579 0.18442004 15 | 0.75298222 0.45025421 0.18442004 16 | 0.25298222 0.04974579 0.31557996 17 | 0.59697916 0.25000000 0.82686578 18 | 0.90302084 0.75000000 0.32686578 19 | 0.40302084 0.75000000 0.17313422 20 | 0.09697916 0.25000000 0.67313422 21 | 0.01809884 0.57498929 0.15848322 22 | 0.48190116 
"""Unit tests for the numpy/torch helpers imported from ``aviary.core``."""

import numpy as np
import pytest
import torch

from aviary.core import masked_mean, masked_std, np_one_hot, np_softmax


def test_np_one_hot():
    """One-hot encoding of 0..2 with inferred and with explicit class count."""
    identity_3 = np.eye(3)
    assert np.allclose(np_one_hot(np.arange(3)), identity_3)

    # explicit n_classes kwarg: first three rows of the 5x5 identity
    encoded = np_one_hot(np.arange(3), n_classes=5)
    reference = np.eye(5)[np.arange(3)]
    assert np.allclose(encoded, reference)


def test_np_softmax():
    """Softmax matches reference values and sums to 1 along the chosen axis."""
    logits = np.array([[0, 1, 2], [3, 4, 5]])
    reference = np.array([[0.0900, 0.2447, 0.6652], [0.0900, 0.2447, 0.6652]])
    assert np.allclose(np_softmax(logits), reference, atol=1e-4)

    samples = np.random.rand(3, 2, 3)
    for sample in samples:
        for axis in (0, 1):
            # entries along the softmaxed dimension must add up to 1
            probs = np_softmax(sample, axis=axis)
            assert np.allclose(probs.sum(axis=axis), 1)


def test_masked_mean():
    """Masked mean on 1d and 2d tensors, default dim and dim=1."""
    # 1d case: only indices 1 and 3 are kept
    values_1d = torch.arange(5).float()
    keep_1d = torch.tensor([0, 1, 0, 1, 0]).bool()
    assert masked_mean(values_1d, keep_1d) == 2

    manual_mean = (values_1d * keep_1d).sum() / keep_1d.sum()
    assert masked_mean(values_1d, keep_1d) == manual_mean

    # 2d case
    values_2d = torch.tensor([[1, 2, 3], [4, 5, 6]]).float()
    keep_2d = torch.tensor([[1, 1, 0], [0, 0, 1]]).bool()

    assert masked_mean(values_2d, keep_2d) == pytest.approx([1, 2, 6])
    assert masked_mean(values_2d, keep_2d, dim=1) == pytest.approx([1.5, 6])


def test_masked_std():
    """Masked std on fixed inputs and against a NaN-based reference."""
    # 1d case
    values_1d = torch.arange(5).float()
    keep_1d = torch.tensor([0, 1, 0, 1, 0]).bool()
    assert masked_std(values_1d, keep_1d) == 1

    # 2d case
    values_2d = torch.tensor([[1, 1, 1], [2, 2, 4]]).float()
    keep_2d = torch.tensor([[0, 1, 0], [1, 0, 1]]).bool()
    assert masked_std(values_2d, keep_2d, dim=0) == pytest.approx([0, 0, 0], abs=1e-4)
    assert masked_std(values_2d, keep_2d, dim=1) == pytest.approx([0, 1], abs=1e-4)

    # compare with an explicit calculation that marks masked entries as NaN
    rand_floats = torch.rand(3, 4, 5)
    rand_masks = torch.randint(0, 2, (3, 4, 5)).bool()
    nan = torch.tensor(float("nan"))
    for sample, keep in zip(rand_floats, rand_masks, strict=False):
        for dim in (0, 1):
            actual = masked_std(sample, keep, dim=dim)

            sample_nan = torch.where(keep, sample, nan)
            centre = sample_nan.nanmean(dim=dim).unsqueeze(dim=dim)
            sq_dev = (sample_nan - centre).pow(2)
            reference = sq_dev.nanmean(dim=dim).sqrt()

            assert actual == pytest.approx(reference, abs=1e-4, nan_ok=True)
/examples/inputs/raw/006197ce31efdef23c82f0f6cec08ac0b0febdce.poscar: -------------------------------------------------------------------------------- 1 | -.39938747E+03 '060/ht.task.triolith--default.006197ce31efdef23c82f0f6cec08ac0b0febdce.cleanup.0.unclaimed.2.finished/ht.run.2016-01-14_17.42.34' ['(Tag) original structure: As2Pr4S9', '(Tag) original structure path: structures/all/matches/ternaries/60/8_60_1cA1dB2dC4dA.cif', '(Tag) type: ternary'] 2 | 1 3 | 19.735812 0.000000 0.000000 4 | 0.000000 6.101939 0.000000 5 | 0.000000 0.000000 5.570787 6 | Zn N Hf 7 | 16 36 8 8 | Direct 9 | 0.43941653 0.35413699 0.14372104 10 | 0.56058347 0.35413699 0.35627896 11 | 0.06058347 0.14586301 0.64372104 12 | 0.56058347 0.64586301 0.85627896 13 | 0.06058347 0.85413699 0.14372104 14 | 0.43941653 0.64586301 0.64372104 15 | 0.93941653 0.85413699 0.35627896 16 | 0.93941653 0.14586301 0.85627896 17 | 0.28716220 0.40737164 0.21147903 18 | 0.71283780 0.40737164 0.28852097 19 | 0.21283780 0.09262836 0.71147903 20 | 0.71283780 0.59262836 0.78852097 21 | 0.21283780 0.90737164 0.21147903 22 | 0.28716220 0.59262836 0.71147903 23 | 0.78716220 0.90737164 0.28852097 24 | 0.78716220 0.09262836 0.78852097 25 | -0.00000000 0.13213508 0.25000000 26 | 0.50000000 0.36786492 0.75000000 27 | -0.00000000 0.86786492 0.75000000 28 | 0.50000000 0.63213508 0.25000000 29 | 0.28161480 0.90975240 0.50653017 30 | 0.71838520 0.90975240 0.99346983 31 | 0.21838520 0.59024760 0.00653017 32 | 0.71838520 0.09024760 0.49346983 33 | 0.21838520 0.40975240 0.50653017 34 | 0.28161480 0.09024760 0.00653017 35 | 0.78161480 0.40975240 0.99346983 36 | 0.78161480 0.59024760 0.49346983 37 | 0.46760771 0.05450945 0.24347796 38 | 0.53239229 0.05450945 0.25652204 39 | 0.03239229 0.44549055 0.74347796 40 | 0.53239229 0.94549055 0.75652204 41 | 0.03239229 0.55450945 0.24347796 42 | 0.46760771 0.94549055 0.74347796 43 | 0.96760771 0.55450945 0.25652204 44 | 0.96760771 0.44549055 0.75652204 45 | 0.34446810 0.04626765 0.98351114 46 
| 0.65553190 0.04626765 0.51648886 47 | 0.15553190 0.45373235 0.48351114 48 | 0.65553190 0.95373235 0.01648886 49 | 0.15553190 0.54626765 0.98351114 50 | 0.34446810 0.95373235 0.48351114 51 | 0.84446810 0.54626765 0.51648886 52 | 0.84446810 0.45373235 0.01648886 53 | 0.37061771 0.46850177 0.41085414 54 | 0.62938229 0.46850177 0.08914586 55 | 0.12938229 0.03149823 0.91085414 56 | 0.62938229 0.53149823 0.58914586 57 | 0.12938229 0.96850177 0.41085414 58 | 0.37061771 0.53149823 0.91085414 59 | 0.87061771 0.96850177 0.08914586 60 | 0.87061771 0.03149823 0.58914586 61 | 0.09968678 0.26946500 0.18570777 62 | 0.90031322 0.26946500 0.31429223 63 | 0.40031322 0.23053500 0.68570777 64 | 0.90031322 0.73053500 0.81429223 65 | 0.40031322 0.76946500 0.18570777 66 | 0.09968678 0.73053500 0.68570777 67 | 0.59968678 0.76946500 0.31429223 68 | 0.59968678 0.23053500 0.81429223 69 | -------------------------------------------------------------------------------- /tests/test_print_metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from aviary.utils import print_metrics_classification, print_metrics_regression 4 | 5 | # generate random data to test print functions 6 | rng = np.random.RandomState(42) 7 | 8 | # Generate reg data 9 | xs = rng.rand(100) 10 | y_pred = xs + 0.1 * rng.normal(size=100) 11 | y_true = xs + 0.1 * rng.normal(size=100) 12 | 13 | # Generate clf data 14 | y_binary = rng.choice([0, 1], (100)) 15 | y_proba = np.clip(y_binary - 0.1 * rng.normal(scale=5, size=(100)), 0.1, 0.9) 16 | 17 | # NOTE binary clf is handled as a multi-class clf problem therefore we need 18 | # to add another prediction dimension to accommodate the negative class 19 | y_probs = np.expand_dims(y_proba, axis=(0, 2)) 20 | y_probs = np.tile(y_probs, (1, 1, 2)) 21 | y_probs[0, :, 1] = 1 - y_proba 22 | 23 | 24 | def test_regression_metrics(capsys): 25 | print_metrics_regression(y_true, y_pred[None, :]) 26 | out, err = 
capsys.readouterr() 27 | assert err == "" 28 | lines = out.split("\n") 29 | assert len(lines) == 5 30 | assert out.startswith("Model Performance Metrics:\nR2 Score: ") 31 | assert lines[2].startswith("MAE: ") 32 | assert lines[3].startswith("RMSE: ") 33 | 34 | 35 | def test_regression_metrics_ensemble(capsys): 36 | # simulate 2-model ensemble by duplicating predictions along 0-axis 37 | y_preds = np.tile(y_pred, (2, 1)) 38 | print_metrics_regression(y_true, y_preds) 39 | out, err = capsys.readouterr() 40 | assert err == "" 41 | lines = out.split("\n") 42 | assert len(lines) == 10 43 | assert out.startswith("Model Performance Metrics:\nR2 Score: ") 44 | assert lines[2].startswith("MAE: ") 45 | assert "+/-" in lines[2] 46 | assert lines[3].startswith("RMSE: ") 47 | assert "+/-" in lines[3] 48 | assert lines[5].startswith("Ensemble Performance Metrics:") 49 | 50 | 51 | def test_classification_metrics(capsys): 52 | print_metrics_classification(y_binary, y_probs) 53 | out, err = capsys.readouterr() 54 | assert err == "" 55 | lines = out.split("\n") 56 | assert len(lines) == 8 57 | assert out.startswith("\nModel Performance Metrics:\nAccuracy") 58 | assert lines[3].startswith("ROC-AUC") 59 | assert lines[4].startswith("Weighted Precision") 60 | assert lines[5].startswith("Weighted Recall") 61 | assert lines[6].startswith("Weighted F-score") 62 | 63 | 64 | def test_classification_metrics_ensemble(capsys): 65 | y_probs_ens = np.tile(y_probs, (2, 1, 1)) 66 | print_metrics_classification(y_binary, y_probs_ens) 67 | out, err = capsys.readouterr() 68 | assert err == "" 69 | lines = out.split("\n") 70 | assert len(lines) == 15 71 | assert out.startswith("\nModel Performance Metrics:\nAccuracy") 72 | assert lines[3].startswith("ROC-AUC") 73 | assert "+/-" in lines[3] 74 | assert lines[4].startswith("Weighted Precision") 75 | assert "+/-" in lines[4] 76 | assert lines[5].startswith("Weighted Recall") 77 | assert "+/-" in lines[5] 78 | assert lines[6].startswith("Weighted 
import torch


def scatter_reduce(src, index, dim=-1, dim_size=None, reduce="sum"):
    """Performs a scatter-reduce operation on the input tensor.

    Elements of the source tensor (src) are grouped by ``index`` along dimension
    ``dim`` and each group is collapsed with the requested reduction. The result has
    the shape of ``src`` except that dimension ``dim`` has length ``dim_size``.
    All reductions are built on differentiable torch scatter ops.

    NOTE this function was written by Claude 3.5 Sonnet.

    Args:
        src (torch.Tensor): The source tensor.
        index (torch.Tensor): The indices of elements to scatter. Must be 1D (one
            entry per slice of src along `dim`) or have the same number of
            dimensions as src.
        dim (int, optional): The axis along which to index. Defaults to -1.
        dim_size (int, optional): The size of the output tensor's dimension `dim`.
            If None, it's inferred as index.max().item() + 1. Defaults to None.
        reduce (str, optional): The reduction operation to perform.
            Options: "sum", "mean", "amax", "max", "amin", "min", "prod".
            Defaults to "sum".

    Returns:
        torch.Tensor: The output tensor after the scatter-reduce operation. Slots
            that receive no element hold the reduction's identity: 0 for "sum" and
            "mean", -inf for "amax"/"max", +inf for "amin"/"min", 1 for "prod".

    Raises:
        ValueError: If an unsupported reduction method is specified.
        RuntimeError: If index and src tensors are incompatible.

    Example:
        >>> src = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0])
        >>> index = torch.tensor([0, 1, 0, 1, 2])
        >>> scatter_reduce(src, index, dim=0, reduce="sum")
        tensor([4., 6., 5.])
    """
    if dim_size is None:
        dim_size = index.max().item() + 1

    # Prepare the output tensor shape
    shape = list(src.shape)
    shape[dim] = dim_size

    # Ensure index has the same number of dimensions as src
    if index.dim() != src.dim():
        if index.dim() != 1:
            raise RuntimeError(
                "Index tensor must be 1D or have the same number of dimensions "
                f"as src tensor. {index.shape=} != {src.shape=}"
            )
        # Place the 1D index along `dim` (not always along axis 0) and broadcast it
        # over every other axis so it lines up element-wise with src.
        view_shape = [1] * src.dim()
        view_shape[dim] = -1
        index = index.view(view_shape).expand_as(src)

    # Perform scatter_reduce operation
    if reduce in ("sum", "mean"):
        out = torch.zeros(shape, dtype=src.dtype, device=src.device)
        out = out.scatter_add(dim, index, src)
        if reduce == "mean":
            count = torch.zeros(shape, dtype=src.dtype, device=src.device)
            count = count.scatter_add(dim, index, torch.ones_like(src))
            # empty slots have count 0: divide by 1 instead so they stay 0
            out = out / count.clamp(min=1)
    elif reduce in ("amax", "max"):
        # Plain Tensor.scatter overwrites on duplicate indices, so a true reduction
        # is required; the -inf fill is the identity element kept by empty slots.
        out = torch.full(shape, float("-inf"), dtype=src.dtype, device=src.device)
        out = out.scatter_reduce(dim, index, src, reduce="amax", include_self=True)
    elif reduce in ("amin", "min"):
        out = torch.full(shape, float("inf"), dtype=src.dtype, device=src.device)
        out = out.scatter_reduce(dim, index, src, reduce="amin", include_self=True)
    elif reduce == "prod":
        # Tensor.scatter(..., reduce="multiply") with a tensor src is deprecated
        out = torch.ones(shape, dtype=src.dtype, device=src.device)
        out = out.scatter_reduce(dim, index, src, reduce="prod", include_self=True)
    else:
        raise ValueError(f"Unsupported reduction method: {reduce}")

    return out
/examples/inputs/raw/0069751a4ba575438917a8b4e5a565b686388f6e.poscar: -------------------------------------------------------------------------------- 1 | -.38833339E+03 '072/ht.task.triolith--default.0069751a4ba575438917a8b4e5a565b686388f6e.cleanup.0.unclaimed.2.finished/ht.run.2016-02-05_08.20.36' ['(Tag) original structure: GeNa4Se4', '(Tag) original structure path: structures/all/matches/ternaries/62/13_62_2cA2cB2dC3dB4cC.cif', '(Tag) type: ternary'] 2 | 1 3 | 21.954219 0.000000 0.000000 4 | 0.000000 7.464043 0.000000 5 | 0.000000 -0.000000 5.699575 6 | Zn Hf N 7 | 32 8 32 8 | Direct 9 | 0.15188087 0.25000000 0.37123552 10 | 0.34811913 0.75000000 0.87123552 11 | 0.84811913 0.75000000 0.62876448 12 | 0.65188087 0.25000000 0.12876448 13 | 0.07011749 0.75000000 0.53590908 14 | 0.42988251 0.25000000 0.03590908 15 | 0.92988251 0.25000000 0.46409092 16 | 0.57011749 0.75000000 0.96409092 17 | 0.22545090 0.58089851 0.61111204 18 | 0.27454910 0.08089851 0.11111204 19 | 0.27454910 0.41910149 0.11111204 20 | 0.77454910 0.08089851 0.38888796 21 | 0.77454910 0.41910149 0.38888796 22 | 0.72545090 0.91910149 0.88888796 23 | 0.72545090 0.58089851 0.88888796 24 | 0.22545090 0.91910149 0.61111204 25 | 0.46623328 0.53080619 0.77253392 26 | 0.03376672 0.03080619 0.27253392 27 | 0.03376672 0.46919381 0.27253392 28 | 0.53376672 0.03080619 0.22746608 29 | 0.53376672 0.46919381 0.22746608 30 | 0.96623328 0.96919381 0.72746608 31 | 0.96623328 0.53080619 0.72746608 32 | 0.46623328 0.96919381 0.77253392 33 | 0.36582353 0.48136016 0.38535976 34 | 0.13417647 0.98136016 0.88535976 35 | 0.13417647 0.51863984 0.88535976 36 | 0.63417647 0.98136016 0.61464024 37 | 0.63417647 0.51863984 0.61464024 38 | 0.86582353 0.01863984 0.11464024 39 | 0.86582353 0.48136016 0.11464024 40 | 0.36582353 0.01863984 0.38535976 41 | 0.22946066 0.75000000 0.15437777 42 | 0.27053934 0.25000000 0.65437777 43 | 0.77053934 0.25000000 0.84562223 44 | 0.72946066 0.75000000 0.34562223 45 | 0.44935397 0.75000000 0.30458592 
46 | 0.05064603 0.25000000 0.80458592 47 | 0.55064603 0.25000000 0.69541408 48 | 0.94935397 0.75000000 0.19541408 49 | 0.32504633 0.75000000 0.23191919 50 | 0.17495367 0.25000000 0.73191919 51 | 0.67495367 0.25000000 0.76808081 52 | 0.82504633 0.75000000 0.26808081 53 | 0.31700735 0.25000000 0.33328863 54 | 0.18299265 0.75000000 0.83328863 55 | 0.68299265 0.75000000 0.66671137 56 | 0.81700735 0.25000000 0.16671137 57 | 0.29931688 0.49244928 0.78952729 58 | 0.20068312 0.99244928 0.28952729 59 | 0.20068312 0.50755072 0.28952729 60 | 0.70068312 0.99244928 0.21047271 61 | 0.70068312 0.50755072 0.21047271 62 | 0.79931688 0.00755072 0.71047271 63 | 0.79931688 0.49244928 0.71047271 64 | 0.29931688 0.00755072 0.78952729 65 | 0.45714683 0.25000000 0.69923780 66 | 0.04285317 0.75000000 0.19923780 67 | 0.54285317 0.75000000 0.30076220 68 | 0.95714683 0.25000000 0.80076220 69 | 0.42259913 0.75000000 0.66077824 70 | 0.07740087 0.25000000 0.16077824 71 | 0.57740087 0.25000000 0.33922176 72 | 0.92259913 0.75000000 0.83922176 73 | 0.06081497 0.50723160 0.66631613 74 | 0.43918503 0.00723160 0.16631613 75 | 0.43918503 0.49276840 0.16631613 76 | 0.93918503 0.00723160 0.33368387 77 | 0.93918503 0.49276840 0.33368387 78 | 0.56081497 0.99276840 0.83368387 79 | 0.56081497 0.50723160 0.83368387 80 | 0.06081497 0.99276840 0.66631613 81 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "aviary-models" 3 | version = "1.2.1" 4 | description = "A collection of machine learning models for materials discovery" 5 | authors = [{ name = "Rhys Goodall", email = "rhys.goodall@outlook.com" }] 6 | readme = "README.md" 7 | license = { file = "LICENSE" } 8 | keywords = [ 9 | "Graph Neural Network", 10 | "Machine Learning", 11 | "Materials Discovery", 12 | "Materials Informatics", 13 | "Materials Science", 14 | "Roost", 15 | "Self-Attention", 16 | 
"Transformer", 17 | "Wren", 18 | "Wyckoff positions", 19 | ] 20 | classifiers = [ 21 | "Intended Audience :: Science/Research", 22 | "License :: OSI Approved :: MIT License", 23 | "Operating System :: OS Independent", 24 | "Programming Language :: Python :: 3.10", 25 | "Programming Language :: Python :: 3.11", 26 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 27 | "Topic :: Scientific/Engineering :: Chemistry", 28 | "Topic :: Scientific/Engineering :: Physics", 29 | ] 30 | 31 | requires-python = ">=3.10" 32 | dependencies = [ 33 | "numpy>=2,<3", 34 | "pandas", 35 | "pymatgen>=2025.4.10", 36 | "scikit_learn", 37 | "tensorboard", 38 | "torch>=2.3.0", 39 | "tqdm", 40 | "typing-extensions", 41 | "wandb", 42 | ] 43 | 44 | [project.urls] 45 | Repo = "https://github.com/CompRhys/aviary" 46 | 47 | [project.optional-dependencies] 48 | test = ["matminer", "moyopy>=0.3.3", "pytest", "pytest-cov"] 49 | moyopy = ["moyopy>=0.3.3"] 50 | 51 | [build-system] 52 | requires = ["uv_build>=0.7.5"] 53 | build-backend = "uv_build" 54 | 55 | [tool.uv.build-backend] 56 | module-name = "aviary" 57 | module-root = "" 58 | 59 | [tool.pytest.ini_options] 60 | testpaths = ["tests"] 61 | addopts = "-p no:warnings" 62 | 63 | [tool.mypy] 64 | no_implicit_optional = false 65 | 66 | [tool.ruff] 67 | line-length = 90 68 | target-version = "py310" 69 | output-format = "concise" 70 | 71 | [tool.ruff.lint] 72 | select = [ 73 | "B", # flake8-bugbear 74 | "C4", # flake8-comprehensions 75 | "D", # pydocstyle 76 | "E", # pycodestyle error 77 | "EXE", # flake8-executable 78 | "F", # pyflakes 79 | "FA", # flake8-future-annotations 80 | "FLY", # flynt 81 | "I", # isort 82 | "ICN", # flake8-import-conventions 83 | "ISC", # flake8-implicit-str-concat 84 | "PD", # pandas-vet 85 | "PERF", # perflint 86 | "PIE", # flake8-pie 87 | "PL", # pylint 88 | "PT", # flake8-pytest-style 89 | "PYI", # flakes8-pyi 90 | "Q", # flake8-quotes 91 | "RET", # flake8-return 92 | "RSE", # flake8-raise 93 | "RUF", # 
Ruff-specific rules 94 | "SIM", # flake8-simplify 95 | "SLOT", # flake8-slots 96 | "TCH", # flake8-type-checking 97 | "TID", # tidy imports 98 | "TID", # flake8-tidy-imports 99 | "UP", # pyupgrade 100 | "W", # pycodestyle warning 101 | "YTT", # flake8-2020 102 | ] 103 | ignore = [ 104 | "C408", # Unnecessary dict call - rewrite as a literal 105 | "D100", # Missing docstring in public module 106 | "D104", # Missing docstring in public package 107 | "D105", # Missing docstring in magic method 108 | "D205", # 1 blank line required between summary line and description 109 | "E731", # Do not assign a lambda expression, use a def 110 | "ISC001", 111 | "PD901", # pandas-df-variable-name 112 | "PLC0415", 113 | "PLR", # pylint refactor 114 | "PT006", # pytest-parametrize-names-wrong-type 115 | ] 116 | pydocstyle.convention = "google" 117 | isort.known-third-party = ["wandb"] 118 | 119 | [tool.ruff.lint.per-file-ignores] 120 | "tests/*" = ["D"] 121 | "examples/notebooks/*.py" = ["E402"] 122 | -------------------------------------------------------------------------------- /examples/inputs/raw/0ad3ae37a9fb7eab9189ff1cb384b9d2dd26bbf9.poscar: -------------------------------------------------------------------------------- 1 | -.61080027E+03 '078/ht.task.gamma--default.0ad3ae37a9fb7eab9189ff1cb384b9d2dd26bbf9.cleanup.0.unclaimed.3.finished/ht.run.2016-02-12_22.11.44' ['(Tag) original structure: Ba11O24Os4', '(Tag) original structure path: structures/all/matches/ternaries/88/12_88_1aA1cB1dB1eA2fA6fC.cif', '(Tag) type: ternary'] 2 | 1 3 | 10.139712 0.000001 0.000000 4 | -0.000001 10.139712 -0.000000 5 | 5.069855 5.069856 8.570492 6 | Ti Zn N 7 | 22 8 48 8 | Direct 9 | 0.87500000 0.12500000 0.25000000 10 | 0.12500000 0.87500000 0.75000000 11 | 0.85706887 0.60706887 0.28586226 12 | 0.60706887 0.85706887 0.78586226 13 | 0.39293113 0.14293113 0.21413774 14 | 0.14293113 0.39293113 0.71413774 15 | 0.36659200 0.17340598 0.69506440 16 | 0.17340598 0.93834360 0.19506440 17 | 0.36847038 
0.63340800 0.80493560 18 | 0.06165640 0.36847038 0.30493560 19 | 0.93834360 0.63152962 0.69506440 20 | 0.63340800 0.82659402 0.30493560 21 | 0.82659402 0.06165640 0.80493560 22 | 0.63152962 0.36659200 0.19506440 23 | 0.64923651 0.69033378 0.13171933 24 | 0.69033378 0.21904416 0.63171933 25 | 0.32205310 0.35076349 0.36828067 26 | 0.78095584 0.32205310 0.86828067 27 | 0.21904416 0.67794690 0.13171933 28 | 0.35076349 0.30966622 0.86828067 29 | 0.30966622 0.78095584 0.36828067 30 | 0.67794690 0.64923651 0.63171933 31 | -0.00000000 0.00000000 0.00000000 32 | -0.00000000 0.00000000 0.50000000 33 | 0.50000000 0.00000000 0.50000000 34 | -0.00000000 0.50000000 0.00000000 35 | 0.50000000 0.50000000 0.00000000 36 | 0.50000000 0.50000000 0.50000000 37 | -0.00000000 0.50000000 0.50000000 38 | 0.50000000 0.00000000 0.00000000 39 | 0.53440356 0.79247721 0.60695497 40 | 0.79247721 0.85864147 0.10695497 41 | 0.89943218 0.46559644 0.89304503 42 | 0.14135853 0.89943218 0.39304503 43 | 0.85864147 0.10056782 0.60695497 44 | 0.46559644 0.20752279 0.39304503 45 | 0.20752279 0.14135853 0.89304503 46 | 0.10056782 0.53440356 0.10695497 47 | 0.27999134 0.82335358 0.57883182 48 | 0.82335358 0.14117683 0.07883182 49 | 0.90218540 0.72000866 0.92116818 50 | 0.85882317 0.90218540 0.42116818 51 | 0.14117683 0.09781460 0.57883182 52 | 0.72000866 0.17664642 0.42116818 53 | 0.17664642 0.85882317 0.92116818 54 | 0.09781460 0.27999134 0.07883182 55 | 0.90435227 0.82714494 0.81183548 56 | 0.82714494 0.28381226 0.31183548 57 | 0.13898042 0.09564773 0.68816452 58 | 0.71618774 0.13898042 0.18816452 59 | 0.28381226 0.86101958 0.81183548 60 | 0.09564773 0.17285506 0.18816452 61 | 0.17285506 0.71618774 0.68816452 62 | 0.86101958 0.90435227 0.31183548 63 | 0.18007599 0.49398734 0.32307195 64 | 0.49398734 0.49685207 0.82307195 65 | 0.31705928 0.81992401 0.17692805 66 | 0.50314793 0.31705928 0.67692805 67 | 0.49685207 0.68294072 0.32307195 68 | 0.81992401 0.50601266 0.67692805 69 | 0.50601266 0.50314793 
0.17692805 70 | 0.68294072 0.18007599 0.82307195 71 | 0.41501395 0.69820587 0.93773867 72 | 0.69820587 0.64724739 0.43773867 73 | 0.13594453 0.58498605 0.56226133 74 | 0.35275261 0.13594453 0.06226133 75 | 0.64724739 0.86405547 0.93773867 76 | 0.58498605 0.30179413 0.06226133 77 | 0.30179413 0.35275261 0.56226133 78 | 0.86405547 0.41501395 0.43773867 79 | 0.74435821 0.52432675 0.21225406 80 | 0.52432675 0.04338773 0.71225406 81 | 0.23658080 0.25564179 0.28774594 82 | 0.95661227 0.23658080 0.78774594 83 | 0.04338773 0.76341920 0.21225406 84 | 0.25564179 0.47567325 0.78774594 85 | 0.47567325 0.95661227 0.28774594 86 | 0.76341920 0.74435821 0.71225406 87 | -------------------------------------------------------------------------------- /examples/inputs/raw/0a10bbf9e7cd3d8da2469f15e06ceca7d8d2af5d.poscar: -------------------------------------------------------------------------------- 1 | -.59711780E+03 '080/ht.task.gamma--default.0a10bbf9e7cd3d8da2469f15e06ceca7d8d2af5d.cleanup.0.unclaimed.3.finished/ht.run.2016-01-04_11.35.30' ['(Tag) original structure: N11P6Rb3', '(Tag) original structure path: structures/all/matches/ternaries/213/6_213_1aA1cB1cA1dB1eB1eC.cif', '(Tag) type: ternary'] 2 | 1 3 | 10.640122 -0.000335 -0.000327 4 | -0.000336 10.647665 0.001083 5 | -0.000326 0.001083 10.648228 6 | Zn Ti N 7 | 12 24 44 8 | Direct 9 | 0.37508445 0.37508887 0.37504493 10 | 0.62494728 0.87508013 0.12491953 11 | 0.12492617 0.62492222 0.87505292 12 | 0.87507172 0.12492622 0.62485794 13 | 0.72704563 0.72702715 0.72698344 14 | 0.27292335 0.22692544 0.77304857 15 | 0.97704783 0.52297581 0.47696343 16 | 0.77294307 0.27297746 0.22699959 17 | 0.02286844 0.02290848 0.02295900 18 | 0.22706647 0.77295062 0.27297118 19 | 0.52287739 0.47709822 0.97699114 20 | 0.47709683 0.97708896 0.52300783 21 | 0.06997315 0.27086890 0.27986617 22 | 0.72011622 0.56998201 0.22902245 23 | 0.31997055 0.97004664 0.02086265 24 | 0.27990750 0.06991229 0.27104539 25 | 0.22872668 0.72045335 0.56996215 26 | 
0.46999786 0.47910899 0.68011386 27 | 0.27120699 0.27942688 0.07017936 28 | 0.56998693 0.22908425 0.72016603 29 | 0.97876892 0.82007522 0.52957564 30 | 0.72869531 0.77947918 0.42980911 31 | 0.52129812 0.17987130 0.02962728 32 | 0.67988556 0.47028440 0.47908583 33 | 0.17998432 0.02971812 0.52094208 34 | 0.02994635 0.52092573 0.18011369 35 | 0.02140657 0.32020134 0.97055809 36 | 0.22024769 0.92998691 0.77095846 37 | 0.93001137 0.77095766 0.22030524 38 | 0.43003656 0.72905476 0.77967244 39 | 0.77125825 0.22056725 0.92993063 40 | 0.52993122 0.97910780 0.81993856 41 | 0.97008829 0.02089702 0.31991814 42 | 0.82009873 0.52987022 0.97918264 43 | 0.47866263 0.67994669 0.47039571 44 | 0.77996067 0.43009416 0.72906280 45 | 0.94438375 0.42715485 0.32289350 46 | 0.67710879 0.44437890 0.07278568 47 | 0.19437105 0.92712974 0.17722607 48 | 0.32291109 0.94432523 0.42718949 49 | 0.07271000 0.67721224 0.44424549 50 | 0.42710803 0.32279449 0.80565916 51 | 0.42733405 0.32275592 0.94426201 52 | 0.44439866 0.07285410 0.67712669 53 | 0.82264647 0.69422908 0.57276947 54 | 0.57269570 0.82281734 0.55570719 55 | 0.67729931 0.30576928 0.07274910 56 | 0.80557950 0.42712590 0.32276359 57 | 0.30560411 0.07290852 0.67715989 58 | 0.07289180 0.67716750 0.30568349 59 | 0.17730793 0.19426276 0.92727206 60 | 0.17711398 0.05559526 0.92719094 61 | 0.05562500 0.92718605 0.17712709 62 | 0.55560647 0.57281449 0.82283937 63 | 0.92738006 0.17724268 0.05575621 64 | 0.57291827 0.82288582 0.69428940 65 | 0.92710292 0.17714834 0.19431068 66 | 0.69442518 0.57285312 0.82276475 67 | 0.32278145 0.80580202 0.42709828 68 | 0.82285963 0.55565492 0.57280392 69 | 0.12500800 0.16511079 0.41532897 70 | 0.58457373 0.62518218 0.33495937 71 | 0.37501382 0.83481996 0.91517228 72 | 0.41546897 0.12511502 0.16510698 73 | 0.33450992 0.58487564 0.62474184 74 | 0.16547506 0.41494480 0.12485987 75 | 0.62494007 0.33488535 0.58467423 76 | 0.08454666 0.87501842 0.66510908 77 | 0.83453232 0.91497982 0.37511309 78 | 0.91541597 0.37494171 
0.83503988 79 | 0.87503930 0.66517762 0.08477465 80 | 0.66541513 0.08507066 0.87518667 81 | 0.17719115 0.17712798 0.17720488 82 | 0.82277684 0.67702595 0.32291383 83 | 0.42716227 0.07267635 0.92722819 84 | 0.32279849 0.82299536 0.67703616 85 | 0.57272894 0.57281672 0.57292929 86 | 0.67724781 0.32293006 0.82278498 87 | 0.07280887 0.92705742 0.42716620 88 | 0.92711841 0.42728898 0.07290209 89 | -------------------------------------------------------------------------------- /examples/inputs/raw/0ad7c1ba7ec6a37cae0523bbd1c2a75f2d56a3f2.poscar: -------------------------------------------------------------------------------- 1 | -.50650847E+03 '080/ht.task.gamma--default.0ad7c1ba7ec6a37cae0523bbd1c2a75f2d56a3f2.cleanup.0.unclaimed.3.finished/ht.run.2016-01-03_13.19.29' ['(Tag) original structure: Fe5O12Y3', '(Tag) original structure path: structures/all/matches/ternaries/148/16_148_1aA1bA1dA1eA2fA2fB8fC.cif', '(Tag) type: ternary'] 2 | 1 3 | 10.216080 -0.000000 3.537716 4 | -5.108039 8.847384 3.537716 5 | -5.108039 -8.847384 3.537716 6 | Ti Zn N 7 | 12 20 48 8 | Direct 9 | 0.36754946 0.12846507 0.24712369 10 | 0.87153493 0.75287631 0.63245054 11 | 0.75287631 0.63245054 0.87153493 12 | 0.24712369 0.36754946 0.12846507 13 | 0.12846507 0.24712369 0.36754946 14 | 0.63245054 0.87153493 0.75287631 15 | 0.87840662 0.62291357 0.25024206 16 | 0.37708643 0.74975794 0.12159338 17 | 0.74975794 0.12159338 0.37708643 18 | 0.25024206 0.87840662 0.62291357 19 | 0.62291357 0.25024206 0.87840662 20 | 0.12159338 0.37708643 0.74975794 21 | 0.00000000 0.00000000 0.00000000 22 | 0.50000000 0.50000000 0.50000000 23 | 0.50000000 0.00000000 0.00000000 24 | 0.00000000 0.00000000 0.50000000 25 | 0.00000000 0.50000000 0.00000000 26 | 0.50000000 0.50000000 0.00000000 27 | 0.50000000 0.00000000 0.50000000 28 | 0.00000000 0.50000000 0.50000000 29 | 0.62493308 0.37417690 0.24949405 30 | 0.62582310 0.75050595 0.37506692 31 | 0.75050595 0.37506692 0.62582310 32 | 0.24949405 0.62493308 0.37417690 33 | 
0.37417690 0.24949405 0.62493308 34 | 0.37506692 0.62582310 0.75050595 35 | 0.87025560 0.12728315 0.74740531 36 | 0.87271685 0.25259469 0.12974440 37 | 0.25259469 0.12974440 0.87271685 38 | 0.74740531 0.87025560 0.12728315 39 | 0.12728315 0.74740531 0.87025560 40 | 0.12974440 0.87271685 0.25259469 41 | 0.27851882 0.17126274 0.07800032 42 | 0.82873726 0.92199968 0.72148118 43 | 0.92199968 0.72148118 0.82873726 44 | 0.07800032 0.27851882 0.17126274 45 | 0.17126274 0.07800032 0.27851882 46 | 0.72148118 0.82873726 0.92199968 47 | 0.61707782 0.20288302 0.29370045 48 | 0.79711698 0.70629955 0.38292218 49 | 0.70629955 0.38292218 0.79711698 50 | 0.29370045 0.61707782 0.20288302 51 | 0.20288302 0.29370045 0.61707782 52 | 0.38292218 0.79711698 0.70629955 53 | 0.17814519 0.59047415 0.89135218 54 | 0.40952585 0.10864782 0.82185481 55 | 0.10864782 0.82185481 0.40952585 56 | 0.89135218 0.17814519 0.59047415 57 | 0.59047415 0.89135218 0.17814519 58 | 0.82185481 0.40952585 0.10864782 59 | 0.91394805 0.09170363 0.30370414 60 | 0.90829637 0.69629586 0.08605195 61 | 0.69629585 0.08605195 0.90829637 62 | 0.30370415 0.91394805 0.09170363 63 | 0.09170363 0.30370414 0.91394805 64 | 0.08605195 0.90829637 0.69629586 65 | 0.31496298 0.20893908 0.40710560 66 | 0.79106092 0.59289440 0.68503702 67 | 0.59289440 0.68503702 0.79106092 68 | 0.40710560 0.31496298 0.20893908 69 | 0.20893908 0.40710560 0.31496298 70 | 0.68503702 0.79106092 0.59289440 71 | 0.30086016 0.88715915 0.20512678 72 | 0.11284085 0.79487322 0.69913984 73 | 0.79487322 0.69913984 0.11284085 74 | 0.20512678 0.30086016 0.88715915 75 | 0.88715915 0.20512678 0.30086016 76 | 0.69913984 0.11284085 0.79487322 77 | 0.90727520 0.31850294 0.60566245 78 | 0.68149706 0.39433755 0.09272480 79 | 0.39433755 0.09272480 0.68149706 80 | 0.60566245 0.90727520 0.31850294 81 | 0.31850294 0.60566245 0.90727520 82 | 0.09272480 0.68149706 0.39433755 83 | 0.40249571 0.58352797 0.19640396 84 | 0.41647203 0.80359604 0.59750429 85 | 0.80359604 0.59750429 
0.41647203 86 | 0.19640396 0.40249571 0.58352797 87 | 0.58352797 0.19640396 0.40249571 88 | 0.59750429 0.41647203 0.80359604 89 | -------------------------------------------------------------------------------- /examples/matbench_example/train_wrenformer.py: -------------------------------------------------------------------------------- 1 | """Train a Wrenformer ensemble of size n_folds on collection of Matbench datasets.""" 2 | 3 | # %% 4 | import os 5 | import sys 6 | from pathlib import Path 7 | 8 | # Add the parent directory to system path 9 | sys.path.append(str(Path(__file__).parent.parent)) 10 | 11 | from datetime import datetime 12 | from itertools import product 13 | 14 | import wandb 15 | from matbench.metadata import mbv01_metadata 16 | from matbench.task import MatbenchTask 17 | from matbench_example.prepare_matbench_datasets import DATA_PATHS 18 | from matbench_example.trainer import train_wrenformer 19 | from matbench_example.utils import merge_json_on_disk, slurm_submit 20 | from matminer.utils.io import load_dataframe_from_json 21 | 22 | from aviary.core import TaskType 23 | 24 | MODULE_DIR = os.path.dirname(__file__) 25 | 26 | 27 | # %% 28 | epochs = 10 29 | folds = list(range(5)) 30 | timestamp = f"{datetime.now():%Y-%m-%d@%H-%M-%S}" 31 | today = timestamp.split("@")[0] 32 | # job_name unlike run_name doesn't include dataset and fold since not yet known without 33 | # SLURM_ARRAY_TASK_ID 34 | job_name = f"matbench-wrenformer-robust-{epochs=}" 35 | out_dir = f"{os.path.dirname(__file__)}/{today}-{job_name}" 36 | 37 | if "roost" in job_name.lower(): 38 | # deploy Roost on all tasks 39 | datasets = list(DATA_PATHS) 40 | else: 41 | # deploy Wren on structure tasks only 42 | datasets = [ 43 | k 44 | for k, v in mbv01_metadata.items() 45 | if v.input_type == "structure" and k in DATA_PATHS 46 | ] 47 | 48 | # NOTE: this script will run as is if you want to run it locally without slurm. 
49 | slurm_submit( 50 | job_name=job_name, 51 | partition="ampere", 52 | account="LEE-SL3-GPU", 53 | time="8:0:0", 54 | array=f"0-{len(datasets) * len(folds) - 1}", 55 | out_dir=out_dir, 56 | slurm_flags=("--nodes", "1", "--gpus-per-node", "1"), 57 | # prepend into sbatch script to source module command and load default env 58 | # for Ampere GPU partition before actual job command 59 | pre_cmd=". /etc/profile.d/modules.sh; module load rhel8/default-amp;", 60 | ) 61 | 62 | 63 | # %% 64 | slurm_array_task_id = int(os.getenv("SLURM_ARRAY_TASK_ID", "0")) 65 | print(f"Job started running {timestamp}") 66 | 67 | dataset_name, fold = list(product(datasets, folds))[slurm_array_task_id] 68 | print(f"{dataset_name=}") 69 | print(f"{fold=}") 70 | 71 | 72 | data_path = DATA_PATHS[dataset_name] 73 | id_col = "mbid" 74 | df = load_dataframe_from_json(data_path) 75 | df.index.name = id_col 76 | 77 | matbench_task = MatbenchTask(dataset_name, autoload=False) 78 | matbench_task.df = df 79 | 80 | target_col = matbench_task.metadata.target 81 | run_name = f"{job_name}-{dataset_name}-{fold=}-{target_col}" 82 | print(f"{run_name=}") 83 | task_type: TaskType = matbench_task.metadata.task_type 84 | 85 | train_df = matbench_task.get_train_and_val_data(fold, as_type="df") 86 | test_df = matbench_task.get_test_data(fold, as_type="df", include_target=True) 87 | 88 | wandb_path = None 89 | 90 | test_metrics, run_params, test_df = train_wrenformer( 91 | checkpoint=None, # None | 'local' | 'wandb' 92 | run_name=run_name, 93 | train_df=train_df, 94 | test_df=test_df, 95 | target_col=target_col, 96 | task_type=task_type, 97 | id_col=id_col, 98 | # set to None to disable logging 99 | wandb_path=wandb_path, 100 | run_params=dict(dataset=dataset_name, fold=fold), 101 | timestamp=timestamp, 102 | epochs=epochs, 103 | ) 104 | 105 | 106 | # %% 107 | # save model predictions to JSON 108 | preds_path = f"{MODULE_DIR}/model_preds/{timestamp}-{run_name}.json" 109 | os.makedirs(os.path.dirname(preds_path), 
exist_ok=True) 110 | 111 | # record model predictions 112 | test_df[id_col] = test_df.index 113 | preds_dict = test_df[[id_col, target_col, f"{target_col}_pred_0"]].to_dict(orient="list") 114 | merge_json_on_disk({dataset_name: {f"fold_{fold}": preds_dict}}, preds_path) 115 | 116 | # save model scores to JSON 117 | scores_path = f"{MODULE_DIR}/model_scores/{timestamp}-{run_name}.json" 118 | os.makedirs(os.path.dirname(scores_path), exist_ok=True) 119 | 120 | scores_dict = {dataset_name: {f"fold_{fold}": test_metrics}} 121 | scores_dict["params"] = run_params 122 | if wandb_path is not None: 123 | scores_dict["wandb_run"] = wandb.run.get_url() 124 | merge_json_on_disk(scores_dict, scores_path) 125 | 126 | print(f"scores for {fold = } of task {dataset_name} written to {scores_path}") 127 | -------------------------------------------------------------------------------- /aviary/networks.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Sequence 2 | 3 | from torch import Tensor, nn 4 | 5 | 6 | class SimpleNetwork(nn.Module): 7 | """Simple Feed Forward Neural Network.""" 8 | 9 | def __init__( 10 | self, 11 | input_dim: int, 12 | output_dim: int, 13 | hidden_layer_dims: Sequence[int], 14 | activation: type[nn.Module] = nn.LeakyReLU, 15 | batch_norm: bool = False, 16 | ) -> None: 17 | """Create a simple feed forward neural network. 18 | 19 | Args: 20 | input_dim (int): Number of input features 21 | output_dim (int): Number of output features 22 | hidden_layer_dims (list[int]): List of hidden layer sizes 23 | activation (type[nn.Module], optional): Which activation function to use. 24 | Defaults to nn.LeakyReLU. 25 | batch_norm (bool, optional): Whether to use batch_norm. Defaults to False. 
26 | """ 27 | super().__init__() 28 | 29 | dims = [input_dim, *list(hidden_layer_dims)] 30 | 31 | self.fcs = nn.ModuleList( 32 | nn.Linear(dims[idx], dims[idx + 1]) for idx in range(len(dims) - 1) 33 | ) 34 | 35 | if batch_norm: 36 | self.bns = nn.ModuleList( 37 | nn.BatchNorm1d(dims[idx + 1]) for idx in range(len(dims) - 1) 38 | ) 39 | else: 40 | self.bns = nn.ModuleList(nn.Identity() for _ in range(len(dims) - 1)) 41 | 42 | self.acts = nn.ModuleList(activation() for _ in range(len(dims) - 1)) 43 | 44 | self.fc_out = nn.Linear(dims[-1], output_dim) 45 | 46 | def forward(self, x: Tensor) -> Tensor: 47 | """Forward pass through network.""" 48 | for fc, bn, act in zip(self.fcs, self.bns, self.acts, strict=False): 49 | x = act(bn(fc(x))) 50 | 51 | return self.fc_out(x) 52 | 53 | def reset_parameters(self) -> None: 54 | """Reinitialize network weights using PyTorch defaults.""" 55 | for fc in self.fcs: 56 | fc.reset_parameters() 57 | 58 | self.fc_out.reset_parameters() 59 | 60 | def __repr__(self) -> str: 61 | input_dim = self.fcs[0].in_features 62 | output_dim = self.fc_out.out_features 63 | activation = type(self.acts[0]).__name__ 64 | return f"{type(self).__name__}({input_dim=}, {output_dim=}, {activation=})" 65 | 66 | 67 | class ResidualNetwork(nn.Module): 68 | """Feed forward Residual Neural Network.""" 69 | 70 | def __init__( 71 | self, 72 | input_dim: int, 73 | output_dim: int, 74 | hidden_layer_dims: Sequence[int], 75 | activation: type[nn.Module] = nn.ReLU, 76 | batch_norm: bool = False, 77 | ) -> None: 78 | """Create a feed forward neural network with skip connections. 79 | 80 | Args: 81 | input_dim (int): Number of input features 82 | output_dim (int): Number of output features 83 | hidden_layer_dims (list[int]): List of hidden layer sizes 84 | activation (type[nn.Module], optional): Which activation function to use. 85 | Defaults to nn.LeakyReLU. 86 | batch_norm (bool, optional): Whether to use batch_norm. Defaults to False. 
87 | """ 88 | super().__init__() 89 | 90 | dims = [input_dim, *list(hidden_layer_dims)] 91 | 92 | self.fcs = nn.ModuleList( 93 | nn.Linear(dims[idx], dims[idx + 1]) for idx in range(len(dims) - 1) 94 | ) 95 | 96 | if batch_norm: 97 | self.bns = nn.ModuleList( 98 | nn.BatchNorm1d(dims[idx + 1]) for idx in range(len(dims) - 1) 99 | ) 100 | else: 101 | self.bns = nn.ModuleList(nn.Identity() for _ in range(len(dims) - 1)) 102 | 103 | self.res_fcs = nn.ModuleList( 104 | nn.Linear(dims[idx], dims[idx + 1], bias=False) 105 | if (dims[idx] != dims[idx + 1]) 106 | else nn.Identity() 107 | for idx in range(len(dims) - 1) 108 | ) 109 | self.acts = nn.ModuleList(activation() for _ in range(len(dims) - 1)) 110 | 111 | self.fc_out = nn.Linear(dims[-1], output_dim) 112 | 113 | def forward(self, x: Tensor) -> Tensor: 114 | """Forward pass through network.""" 115 | for fc, bn, res_fc, act in zip( 116 | self.fcs, self.bns, self.res_fcs, self.acts, strict=False 117 | ): 118 | x = act(bn(fc(x))) + res_fc(x) 119 | 120 | return self.fc_out(x) 121 | 122 | def __repr__(self) -> str: 123 | input_dim = self.fcs[0].in_features 124 | output_dim = self.fc_out.out_features 125 | activation = type(self.acts[0]).__name__ 126 | return f"{type(self).__name__}({input_dim=}, {output_dim=}, {activation=})" 127 | -------------------------------------------------------------------------------- /tests/test_utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | import numpy as np 4 | import pandas as pd 5 | import pytest 6 | import torch 7 | 8 | from aviary.utils import get_element_embedding, get_metrics, get_sym_embedding 9 | 10 | 11 | @pytest.fixture 12 | def temp_element_embedding(tmp_path): 13 | embedding_data = { 14 | "H": [1.0, 2.0], 15 | "He": [3.0, 4.0], 16 | "Li": [5.0, 6.0], 17 | } 18 | path = tmp_path / "test_elem_embedding.json" 19 | with open(path, "w") as f: 20 | json.dump(embedding_data, f) 21 | return str(path) 22 | 23 | 24 | 
@pytest.fixture
def temp_sym_embedding(tmp_path):
    """Write a small symmetry-embedding JSON file and return its path as a string.

    The payload has two top-level keys ("1" and "2") that map letters to
    two-component vectors, three vectors in total.
    """
    json_file = tmp_path / "test_sym_embedding.json"
    payload = {
        "1": {"a": [1.0, 2.0], "b": [3.0, 4.0]},
        "2": {"c": [5.0, 6.0]},
    }
    json_file.write_text(json.dumps(payload))
    return str(json_file)
def test_classification_metrics():
    """Perfectly separated binary predictions must score 1.0 on every metric."""
    # Each row holds the predicted probabilities for class 0 and class 1
    class_probs = np.array([[0.9, 0.1], [0.1, 0.9], [0.8, 0.2], [0.2, 0.8], [0.7, 0.3]])
    labels = np.array([0, 1, 0, 1, 0])

    scores = get_metrics(labels, class_probs, "classification")

    expected_names = ("accuracy", "balanced_accuracy", "F1", "ROCAUC")
    assert set(scores.keys()) == set(expected_names)
    for name in expected_names:
        assert scores[name] == 1.0
def test_mismatched_shapes():
    """get_metrics must raise ValueError when targets and predictions differ in length."""
    three_labels = np.array([0, 1, 0])
    two_prob_rows = np.array([[0.9, 0.1], [0.1, 0.9]])  # one row short of the labels

    with pytest.raises(ValueError):  # noqa: PT011
        get_metrics(three_labels, two_prob_rows, "classification")
44 | 45 | Args: 46 | normed_tensor (Tensor): Tensor to be restored 47 | 48 | Returns: 49 | Tensor: Restored Tensor 50 | """ 51 | return normed_tensor * self.std + self.mean 52 | 53 | def state_dict(self) -> dict[str, Tensor]: 54 | """Get Normalizer parameters mean and std. 55 | 56 | Returns: 57 | dict[str, Tensor]: Dictionary storing Normalizer parameters. 58 | """ 59 | return {"mean": self.mean, "std": self.std} 60 | 61 | def load_state_dict(self, state_dict: dict[str, Tensor]) -> None: 62 | """Overwrite Normalizer parameters given a new state_dict. 63 | 64 | Args: 65 | state_dict (dict[str, Tensor]): Dictionary storing Normalizer parameters. 66 | """ 67 | self.mean = state_dict["mean"].cpu() 68 | self.std = state_dict["std"].cpu() 69 | 70 | @classmethod 71 | def from_state_dict(cls, state_dict: dict[str, Tensor]) -> Self: 72 | """Create a new Normalizer given a state_dict. 73 | 74 | Args: 75 | state_dict (dict[str, Tensor]): Dictionary storing Normalizer parameters. 76 | 77 | Returns: 78 | Normalizer 79 | """ 80 | instance = cls() 81 | instance.mean = state_dict["mean"].cpu() 82 | instance.std = state_dict["std"].cpu() 83 | 84 | return instance 85 | 86 | 87 | @dataclass 88 | class InMemoryDataLoader: 89 | """In-memory DataLoader using array/tensor slicing to generate whole batches at 90 | once instead of sample-by-sample. 91 | Source: https://discuss.pytorch.org/t/27014/6. 92 | 93 | Args: 94 | *tensors: List of arrays or tensors. Must all have the same length in 95 | dimension 0. 96 | collate_fn (Callable): Should accept variadic list of tensors and 97 | output a minibatch of data ready for model consumption. 98 | batch_size (int, optional): Usually 64, 128 or 256. Can be larger for test set 99 | loaders to speedup inference. Defaults to 64. 100 | shuffle (bool, optional): If True, shuffle the data *in-place* whenever an 101 | iterator is created from this object. Defaults to False. 
102 | """ 103 | 104 | # each item must be indexable (usually torch.tensor, np.array or pd.Series) 105 | tensors: list[Tensor | np.ndarray] 106 | collate_fn: Callable 107 | batch_size: int = 64 108 | shuffle: bool = False 109 | 110 | def __post_init__(self): 111 | self.dataset_len = len(self.tensors[0]) 112 | if not all(len(t) == self.dataset_len for t in self.tensors): 113 | raise ValueError("All tensors must have the same length in dim 0") 114 | 115 | def __iter__(self) -> Iterator[tuple[Tensor, ...]]: 116 | self.indices = np.random.permutation(self.dataset_len) if self.shuffle else None 117 | self.current_idx = 0 118 | return self 119 | 120 | def __next__(self) -> tuple[Tensor, ...]: 121 | start_idx = self.current_idx 122 | if start_idx >= self.dataset_len: 123 | raise StopIteration 124 | 125 | end_idx = start_idx + self.batch_size 126 | 127 | if self.indices is None: # shuffle=False 128 | slices = (t[start_idx:end_idx] for t in self.tensors) 129 | else: 130 | idx = self.indices[start_idx:end_idx] 131 | slices = (t[idx] for t in self.tensors) 132 | 133 | batch = self.collate_fn(*slices) 134 | 135 | self.current_idx += self.batch_size 136 | return batch 137 | 138 | def __len__(self) -> int: 139 | """Get the number of batches in this data loader.""" 140 | n_batches, remainder = divmod(self.dataset_len, self.batch_size) 141 | return n_batches + bool(remainder) 142 | -------------------------------------------------------------------------------- /tests/test_wrenformer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | import torch 4 | from sklearn.model_selection import train_test_split as split 5 | 6 | from aviary.utils import get_metrics, results_multitask, train_ensemble 7 | from aviary.wrenformer.data import df_to_in_mem_dataloader 8 | from aviary.wrenformer.model import Wrenformer 9 | 10 | 11 | @pytest.fixture 12 | def base_config(): 13 | return { 14 | "robust": True, 15 | 
def test_wrenformer_regression(
    df_matbench_phonons_wyckoff, base_config, model_architecture, training_config
):
    """Train a small Wrenformer ensemble on the phonon data and check its test metrics.

    The held-out split doubles as validation set. After training, the ensemble
    checkpoints are evaluated and the averaged prediction must reach R2 > 0.7,
    MAE < 150 and RMSE < 300.
    """
    target_name = "last phdos peak"
    task = "regression"
    losses = ["L1"]
    epochs = 25
    model_name = "wrenformer-reg-test"
    input_col = "protostructure"
    embedding_type = "protostructure"

    task_dict = {target_name: task}
    loss_dict = {target_name: losses[0]}

    device = training_config["device"]
    batch_size = training_config["batch_size"]
    full_df = df_matbench_phonons_wyckoff

    all_rows = list(range(len(full_df)))
    train_rows, test_rows = split(
        all_rows,
        random_state=base_config["data_seed"],
        test_size=base_config["test_size"],
    )

    # optionally sub-sample the training rows with a fixed stride
    train_df = full_df.iloc[train_rows[:: base_config["sample"]]]
    test_df = full_df.iloc[test_rows]
    val_df = test_df  # Using test set for validation

    data_loader_kwargs = {
        "id_col": "material_id",
        "input_col": input_col,
        "target_col": target_name,
        "embedding_type": embedding_type,
        "device": device,
    }

    train_loader = df_to_in_mem_dataloader(
        train_df, batch_size=batch_size, shuffle=True, **data_loader_kwargs
    )
    val_loader = df_to_in_mem_dataloader(
        val_df, batch_size=batch_size * 16, shuffle=False, **data_loader_kwargs
    )

    setup_params = {
        key: training_config[key]
        for key in ("optim", "learning_rate", "weight_decay", "momentum", "device")
    }
    restart_params = {
        key: training_config[key] for key in ("resume", "fine_tune", "transfer")
    }

    model_params = {
        "task_dict": task_dict,
        "robust": base_config["robust"],
        "n_targets": [1],  # Regression task has 1 target
        "n_features": train_loader.tensors[0][0].shape[-1],
        **model_architecture,
    }

    train_ensemble(
        model_class=Wrenformer,
        model_name=model_name,
        run_id=base_config["run_id"],
        ensemble_folds=base_config["ensemble"],
        epochs=epochs,
        train_loader=train_loader,
        val_loader=val_loader,
        log=base_config["log"],
        setup_params=setup_params,
        restart_params=restart_params,
        model_params=model_params,
        loss_dict=loss_dict,
    )

    test_loader = df_to_in_mem_dataloader(
        test_df, batch_size=batch_size * 64, shuffle=False, **data_loader_kwargs
    )

    results_dict = results_multitask(
        model_class=Wrenformer,
        model_name=model_name,
        run_id=base_config["run_id"],
        ensemble_folds=base_config["ensemble"],
        test_loader=test_loader,
        robust=base_config["robust"],
        task_dict=task_dict,
        device=device,
        eval_type="checkpoint",
        save_results=False,
    )

    target_results = results_dict[target_name]
    targets = target_results["targets"]

    # average the predictions of the individual ensemble members
    y_ens = np.mean(target_results["preds"], axis=0)
    mae, rmse, r2 = get_metrics(targets, y_ens, task).values()

    assert len(targets) == len(test_df)
    assert r2 > 0.7
    assert mae < 150
    assert rmse < 300
def get_formation_energy(comp: str, energy: float, el_refs: dict[str, float]) -> float:
    """Return the formation energy per atom of a compound relative to its elements.

    Args:
        comp (str): Formula string
        energy (float): energy in eV/atom
        el_refs (dict[str, float]): elemental reference energies in eV/atom. It is
            looked up with the element objects yielded by ``Composition.elements``.

    Returns:
        float: formation energy per atom in eV
    """
    composition = Composition(comp)

    # NOTE our references use energies_per_atom for energy, so weight each
    # reference by the amount of that element in the formula
    reference_total = 0
    for element in composition.elements:
        reference_total += composition[element] * el_refs[element]

    return energy - reference_total / composition.num_atoms
class AttentionPooling(nn.Module):
    """Softmax attention layer. Currently unused."""

    def __init__(self, gate_nn: nn.Module, message_nn: nn.Module) -> None:
        """Initialize softmax attention layer.

        Args:
            gate_nn (nn.Module): Neural network to calculate attention scalars
            message_nn (nn.Module): Neural network to evaluate message updates
        """
        super().__init__()
        self.gate_nn = gate_nn
        self.message_nn = message_nn

    def forward(self, x: Tensor, index: Tensor) -> Tensor:
        """Forward pass.

        Args:
            x (Tensor): Input features for nodes
            index (Tensor): The indices for scatter operation over nodes

        Returns:
            Tensor: Output features for nodes
        """
        gate = self.gate_nn(x)

        # Shift by the per-index maximum before exponentiating so exp() cannot
        # overflow. All updates are out-of-place on purpose: exp() keeps its output
        # for the backward pass, and overwriting that tensor in place makes
        # autograd raise once backward() is called.
        gate = gate - scatter_reduce(gate, index, dim=0, reduce="amax")[index]
        gate = gate.exp()
        # the small constant guards against division by zero
        gate = gate / (scatter_reduce(gate, index, dim=0, reduce="sum")[index] + 1e-10)

        x = self.message_nn(x)
        return scatter_reduce(gate * x, index, dim=0, reduce="sum")

    def __repr__(self) -> str:
        gate_nn, message_nn = self.gate_nn, self.message_nn
        return f"{type(self).__name__}({gate_nn=}, {message_nn=})"
class MessageLayer(nn.Module):
    """Message passing layer that updates node features from their neighbours."""

    def __init__(
        self,
        msg_fea_len: int,
        num_msg_heads: int,
        msg_gate_layers: Sequence[int],
        msg_net_layers: Sequence[int],
    ) -> None:
        """Build one weighted attention pooling head per requested message head.

        Args:
            msg_fea_len (int): Number of input features
            num_msg_heads (int): Number of attention heads
            msg_gate_layers (list[int]): List of hidden layer sizes for gate network
            msg_net_layers (list[int]): List of hidden layer sizes for message network
        """
        super().__init__()

        # the repr is fixed at construction time from the constructor arguments
        self._repr = (
            f"{self._get_name()}({msg_fea_len=}, {num_msg_heads=}, {msg_gate_layers=}, "
            f"{msg_net_layers=})"
        )

        # Pooling and Output: every head consumes the concatenated
        # (self, neighbour) features, hence the doubled input width
        pair_fea_len = 2 * msg_fea_len
        heads = []
        for _ in range(num_msg_heads):
            gate_net = SimpleNetwork(pair_fea_len, 1, msg_gate_layers)
            msg_net = SimpleNetwork(pair_fea_len, msg_fea_len, msg_net_layers)
            heads.append(WeightedAttentionPooling(gate_nn=gate_net, message_nn=msg_net))
        self.pooling = nn.ModuleList(heads)

    def forward(
        self,
        node_weights: Tensor,
        node_prev_features: Tensor,
        self_idx: LongTensor,
        neighbor_idx: LongTensor,
    ) -> Tensor:
        """Run one round of message passing.

        Args:
            node_weights (Tensor): The fractional weights of elements in their materials
            node_prev_features (Tensor): Node hidden features before message passing
            self_idx (LongTensor): Indices of the 1st element in each of the node pairs
            neighbor_idx (LongTensor): Indices of the 2nd element in each of the node
                pairs

        Returns:
            Tensor: node hidden features after message passing
        """
        # gather per-pair inputs: neighbour weights plus features of both partners
        nbr_weights = node_weights[neighbor_idx, :]
        nbr_fea = node_prev_features[neighbor_idx, :]
        own_fea = node_prev_features[self_idx, :]
        pair_fea = torch.cat([own_fea, nbr_fea], dim=1)

        # each head pools the pair messages back onto the node given by self_idx
        per_head = [
            head(pair_fea, index=self_idx, weights=nbr_weights) for head in self.pooling
        ]

        # average the attention heads and add the residual connection
        return torch.stack(per_head).mean(dim=0) + node_prev_features

    def __repr__(self) -> str:
        return self._repr
00a3259f00f8a9aa503caae2045970d2f7b22e13,Hf7Zn5N19,0.3828095201612909,A7B19C5_oP31_47_ek2l_ai2j3x_cij:Hf-N-Zn 18 | 00237cf70206a5de77e5154ea72633514cdd4445,HfZn2N7,0.1671952525000009,AB7C2_mP20_11_e_3e2f_f:Hf-N-Zn 19 | 006a8e78e4bb5c1570bf70b47b24c0b855a5b5b5,HfZn3N2,0.9949698083333343,AB2C3_tI12_139_a_e_be:Hf-N-Zn 20 | 00bbba1d86168727008d4c1542b39d45d93d693c,HfZn3N4,0.11430601250000105,AB4C3_cP16_223_a_e_c:Hf-N-Zn 21 | 006047d5cfe39157f2904d2533bad74b2ddd3170,HfZnN,-0.36961197499999887,ABC_hP6_194_a_b_c:Hf-N-Zn 22 | 001a938f944fa33e899b81e3d0a909876587e7b0,HfZnN12,0.8577765142857157,AB12C_mC56_15_e_6f_e:Hf-N-Zn 23 | 000fe0f2a4730347df688bd0d3c75f4494a28dca,HfZnN2,-0.6694122624999999,AB2C_hP8_194_c_f_a:Hf-N-Zn 24 | 00af9b9350147b8c20498eb9cd65b63a3132c741,HfZnN2,-0.3108887624999994,AB2C_mC16_12_i_2i_g:Hf-N-Zn 25 | 6f7d6a970fde2a4243091aba720fc14ceec2854b,N2,0.0,A_cP8_205_c:N 26 | b428ccf0f8cd5488a3af5d72b74e117dc71aa575,Ti,0.0,A_hP2_194_c:Ti 27 | 0ac3201405a380e382377a71216b33cde892b161,Ti(ZnN2)2,0.44363870000000016,A4BC2_mC28_5_4c_c_abc:N-Ti-Zn 28 | 0abac62147275aa8199465764751eaa67c94d16c,Ti(ZnN2)2,0.4993669142857149,A4BC2_oP28_62_2cd_c_ac:N-Ti-Zn 29 | 0a0ae9c18de23e72bb764e10ef542de693321542,Ti2Zn2N5,-0.4770085805555553,A5B2C2_oI36_46_b2c_ab_c:N-Ti-Zn 30 | 0ab8cc7e5cc3d48fcc786f9060da5a8882374f44,Ti2Zn3N4,-0.4294422111111116,A4B2C3_mP36_14_4e_2e_3e:N-Ti-Zn 31 | 0ad7d4c22ba4414fbe161f27f42c112c6cb83b29,Ti2Zn3N7,0.6236076854166672,A7B2C3_tI24_139_aeg_e_be:N-Ti-Zn 32 | 0abb6669aa25624d8371765c5afa6b80ae20f30b,Ti2ZnN3,-0.8424399458333331,A3B2C_mC12_12_ai_i_b:N-Ti-Zn 33 | 0a2c4708d3c853ea3c4b8630027afe9b788ee2ce,Ti2ZnN5,-0.049384740624999424,A5B2C_oP32_58_3gh_eg_g:N-Ti-Zn 34 | 0a946dc02260d0d4b36ef2087e60b415a5fff800,Ti3ZnN5,-0.701213825,A5B3C_tP36_76_5a_3a_a:N-Ti-Zn 35 | 0a130b596d1ce7bfe317d9e546bbeacc703fef65,Ti3ZnN8,-0.23821160833333277,A8B3C_oI48_72_2jk_aj_b:N-Ti-Zn 36 | 
0a646e174268b493bd67f58fdf52fead02f74a96,Ti4(ZnN3)3,-0.3996681890624991,A9B4C3_hR48_146_3b_4a_3a:N-Ti-Zn 37 | 0ab8ebf80d94a1ba826857a9e81cea8c7124d86e,Ti4Zn3N8,-0.2971007933333327,A8B4C3_mC60_12_8i_4i_3i:N-Ti-Zn 38 | 0ac8a65af75e8223b07dcd98c6958db7af6504cf,Ti4ZnN7,-0.5551503479166673,A7B4C_mC48_15_e3f_2f_e:N-Ti-Zn 39 | 0a5af69c21e6c2d9b004c2f51939790b10b19814,Ti5(ZnN3)3,-0.22798647205882272,A9B5C3_hP17_157_cd_bc_c:N-Ti-Zn 40 | 0a58ef8a37185d8abcbd6e1bad7eaf7e3d3d6016,Ti6ZnN8,-1.1345180866666658,A8B6C_hR45_148_cf_f_a:N-Ti-Zn 41 | 0a4a721fb8f0996350edd4fb224d89ce213f7f93,Ti7(ZnN)3,-0.47197025192307684,A3B7C3_oC52_36_3a_7a_3a:N-Ti-Zn 42 | 0ad8980cb8eb0693647f411732b2165b6a0fce0e,Ti7ZnN12,-0.34730283999999934,A12B7C_cI40_204_g_bc_a:N-Ti-Zn 43 | 0a458b5f2eeebeb0e4e55e1d1a17988ec87b7123,TiZn2N3,-0.04297332916666541,A3BC2_oC48_64_e2f_f_2f:N-Ti-Zn 44 | 0ac7f92ee7225e47dbdc25f6d60a10b116b0813b,TiZn2N3,0.5085712541666672,A3BC2_oC24_63_cg_c_e:N-Ti-Zn 45 | 0a3eb2074380f4abdec94d4a6171a099c9ecb30a,TiZn2N5,0.020599596875000792,A5BC2_oC32_63_c2f_c_f:N-Ti-Zn 46 | 0a2c3a235656a97d76f731b269982f92de9b37fd,TiZnN,-0.6217663083333331,ABC_hP6_164_d_c_d:N-Ti-Zn 47 | 0a2e1f340255f62a617974d778714e5fdd28a007,TiZnN2,2.0046756125,A2BC_cF32_227_c_b_a:N-Ti-Zn 48 | 0adeaf7a9c16772b66815cce19917ef0460affe8,TiZnN3,-0.014901634999999303,A3BC_cP40_198_2b_2a_2a:N-Ti-Zn 49 | 0a74297fb11027ba2a958da5ee31980b8876e1d4,TiZnN4,1.1665261166666676,A4BC_oC24_63_fg_a_c:N-Ti-Zn 50 | b3603cef0816d0e115adeafa4eee65c056dc7a26,Zn,0.0,A_hP2_194_c:Zn 51 | 0fd347fb29c23cad30d42852d7aa4ed906102b8a,Zr,0.0,A_hP2_194_c:Zr 52 | 0cd142efcf69ba63dc62760ecd72fb4f1e930115,Zr(ZnN)2,-0.2309529700000006,A2B2C_mC20_12_2i_2i_i:N-Zn-Zr 53 | 0c710ee221d7567cc0f1a06fba8231ae8261a181,Zr(ZnN)6,0.00016695769230778978,A6B6C_mC26_12_3i_3i_a:N-Zn-Zr 54 | 0c5264c980964b88ae8a7c2c18d5a61daff481b1,Zr(ZnN2)2,0.7521056285714289,A4B2C_hP14_186_bc_ab_b:N-Zn-Zr 55 | 
0c7147c05a270b9b2744576b0bfb69c8391ee1e4,Zr(ZnN3)2,0.44643896111111125,A6B2C_hP9_162_k_c_b:N-Zn-Zr 56 | 0c429f98b79c9522a11d51fbfae484004f899422,Zr(ZnN5)2,0.478535934615385,A10B2C_mP26_11_2e4f_f_e:N-Zn-Zr 57 | 0c27147781fc9c64cbcf45acc1575a0bf8070841,Zr2Zn2N3,-0.5270690678571421,A3B2C2_mP14_14_ae_e_e:N-Zn-Zr 58 | 0c450e6f9c256e33ef452c98dd20917fb6484317,Zr2Zn3N4,-0.3812128222222215,A4B3C2_mC36_15_2f_ef_f:N-Zn-Zr 59 | 0c9cbed1e5d755a9666db36f742004c14bf780bd,Zr2ZnN5,-0.5414344906249999,A5BC2_mC32_9_5a_a_2a:N-Zn-Zr 60 | 0cd1f1670176b03cc3d61935d9a6882c9508e33a,Zr3ZnN6,-0.7395423050000005,A6BC3_mP20_11_6e_e_3e:N-Zn-Zr 61 | 0c882a0621189d18ff1350e3bed4114868663acf,Zr4ZnN5,-0.9191486925000003,A5BC4_oP20_59_a2e_a_2be:N-Zn-Zr 62 | 0c5ffa8e2b3126eba02901c37b31e8faf3b5377c,Zr6Zn4N13,-0.21249625326086985,A13B4C6_cI46_217_ag_c_d:N-Zn-Zr 63 | 0c9c74b6f48facd507c4e41d9df797a2017634e7,ZrZn2N3,0.1968950041666666,A3B2C_oC24_63_cf_2c_c:N-Zn-Zr 64 | 0c13902f8200732555ad1603e954a726d4d24d85,ZrZn3N2,-0.11572319166666656,A2B3C_oI24_72_j_ce_b:N-Zn-Zr 65 | 0c9cf664171d7e10527033cda1551694141207ed,ZrZnN2,-0.2964831375000001,A2BC_oP4_47_k_a_d:N-Zn-Zr 66 | 0c0b4fd84c8f4e075436d7ec34bfb3fad5e610be,ZrZnN3,-0.16417793499999878,A3BC_hP30_185_ab2c_ab_c:N-Zn-Zr 67 | 0c2fc1da358bcddfb4baf41b5dfd7285a13b766a,ZrZnN3,0.23257766500000088,A3BC_mC20_15_ef_c_e:N-Zn-Zr 68 | 0c5ef0e6b20d6e34ae388d640797f032a15df918,ZrZnN3,0.7506590650000007,A3BC_cP5_221_c_b_a:N-Zn-Zr 69 | 0c5e41fdc8fba8e90b48cbe4752dc2c4a2805f21,ZrZnN3,0.774344365000001,A3BC_tI20_140_bh_d_c:N-Zn-Zr 70 | -------------------------------------------------------------------------------- /aviary/roost/data.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Sequence 2 | from functools import cache 3 | from typing import Any 4 | 5 | import numpy as np 6 | import pandas as pd 7 | import torch 8 | from pymatgen.core import Composition 9 | from torch import LongTensor, Tensor 10 | 
from torch.utils.data import Dataset 11 | 12 | 13 | class CompositionData(Dataset): 14 | """Dataset class for the Roost composition model.""" 15 | 16 | def __init__( 17 | self, 18 | df: pd.DataFrame, 19 | task_dict: dict[str, str], 20 | inputs: str = "composition", 21 | identifiers: Sequence[str] = ("material_id", "composition"), 22 | ): 23 | """Data class for Roost models. 24 | 25 | Args: 26 | df (pd.DataFrame): Pandas dataframe holding input and target values. 27 | task_dict (dict[str, "regression" | "classification"]): Map from target 28 | names to task type. 29 | elem_embedding (str, optional): One of "matscholar200", "cgcnn92", 30 | "megnet16", "onehot112" or path to a file with custom element 31 | embeddings. Defaults to "matscholar200". 32 | inputs (str, optional): df column name holding material compositions. 33 | Defaults to "composition". 34 | identifiers (list, optional): df columns for distinguishing data points. 35 | Will be copied over into the model's output CSV. Defaults to 36 | ["material_id", "composition"]. 37 | """ 38 | if len(identifiers) != 2: 39 | raise AssertionError("Two identifiers are required") 40 | 41 | self.inputs = inputs 42 | self.task_dict = task_dict 43 | self.identifiers = list(identifiers) 44 | self.df = df 45 | 46 | self.n_targets = [] 47 | for target, task in self.task_dict.items(): 48 | if task == "regression": 49 | self.n_targets.append(1) 50 | elif task == "classification": 51 | n_classes = np.max(self.df[target].values) + 1 52 | self.n_targets.append(n_classes) 53 | 54 | def __len__(self) -> int: 55 | return len(self.df) 56 | 57 | def __repr__(self) -> str: 58 | df_repr = f"cols=[{', '.join(self.df.columns)}], len={len(self.df)}" 59 | return f"{type(self).__name__}({df_repr}, task_dict={self.task_dict})" 60 | 61 | # Cache data for faster training 62 | @cache # noqa: B019 63 | def __getitem__(self, idx: int): 64 | """Get an entry out of the Dataset. 
65 | 66 | Args: 67 | idx (int): index of entry in Dataset 68 | 69 | Returns: 70 | tuple: containing 71 | - tuple[Tensor, Tensor, LongTensor, LongTensor]: Roost model inputs 72 | - list[Tensor | LongTensor]: regression or classification targets 73 | - list[str | int]: identifiers like material_id, composition 74 | """ 75 | row = self.df.iloc[idx] 76 | composition = row[self.inputs] 77 | material_ids = row[self.identifiers].to_list() 78 | 79 | comp_dict = Composition(composition).fractional_composition 80 | weights = list(comp_dict.values()) 81 | weights = np.atleast_2d(weights).T / np.sum(weights) 82 | elem_fea = [elem.Z for elem in comp_dict] 83 | 84 | n_elems = len(comp_dict) 85 | self_idx = [] 86 | nbr_idx = [] 87 | for elem_idx in range(n_elems): 88 | self_idx += [elem_idx] * n_elems 89 | nbr_idx += list(range(n_elems)) 90 | 91 | # convert all data to tensors 92 | elem_weights = Tensor(weights) 93 | elem_fea = LongTensor(elem_fea) 94 | self_idx = LongTensor(self_idx) 95 | nbr_idx = LongTensor(nbr_idx) 96 | 97 | targets = [] 98 | for target in self.task_dict: 99 | if self.task_dict[target] == "regression": 100 | targets.append(Tensor([row[target]])) 101 | elif self.task_dict[target] == "classification": 102 | targets.append(LongTensor([row[target]])) 103 | 104 | return ( 105 | (elem_weights, elem_fea, self_idx, nbr_idx), 106 | targets, 107 | *material_ids, 108 | ) 109 | 110 | 111 | def collate_batch( 112 | samples: tuple[ 113 | tuple[Tensor, Tensor, LongTensor, LongTensor], 114 | list[Tensor | LongTensor], 115 | list[str | int], 116 | ], 117 | ) -> tuple[Any, ...]: 118 | """Collate a list of data and return a batch for predicting crystal properties. 
119 | 120 | Args: 121 | samples (list): list of tuples for each data point where each tuple contains: 122 | (elem_fea, nbr_fea, nbr_idx, target) 123 | - elem_fea (Tensor): Atom hidden features before convolution 124 | - self_idx (LongTensor): Indices of the atom's self 125 | - nbr_idx (LongTensor): Indices of M neighbors of each atom 126 | - target (Tensor | LongTensor): target values containing floats for 127 | regression or integers as class labels for classification 128 | - cif_id: str or int 129 | 130 | Returns: 131 | tuple[ 132 | tuple[Tensor, Tensor, LongTensor, LongTensor, LongTensor]: batched Roost 133 | model inputs, 134 | tuple[Tensor | LongTensor]: Target values for different tasks, 135 | # TODO this last tuple is unpacked how to do type hint? 136 | *tuple[str | int]: Identifiers like material_id, composition 137 | ] 138 | """ 139 | # define the lists 140 | batch_elem_weights = [] 141 | batch_elem_fea = [] 142 | batch_self_idx = [] 143 | batch_nbr_idx = [] 144 | crystal_elem_idx = [] 145 | batch_targets = [] 146 | batch_cry_ids = [] 147 | 148 | cry_base_idx = 0 149 | for idx, (inputs, target, *cry_ids) in enumerate(samples): 150 | elem_weights, elem_fea, self_idx, nbr_idx = inputs 151 | 152 | n_sites = elem_fea.shape[0] # number of atoms for this crystal 153 | 154 | # batch the features together 155 | batch_elem_weights.append(elem_weights) 156 | batch_elem_fea.append(elem_fea) 157 | 158 | # mappings from bonds to atoms 159 | batch_self_idx.append(self_idx + cry_base_idx) 160 | batch_nbr_idx.append(nbr_idx + cry_base_idx) 161 | 162 | # mapping from atoms to crystals 163 | crystal_elem_idx.append(torch.tensor([idx] * n_sites)) 164 | 165 | # batch the targets and ids 166 | batch_targets.append(target) 167 | batch_cry_ids.append(cry_ids) 168 | 169 | # increment the id counter 170 | cry_base_idx += n_sites 171 | 172 | return ( 173 | ( 174 | torch.cat(batch_elem_weights, dim=0), 175 | torch.cat(batch_elem_fea, dim=0), 176 | torch.cat(batch_self_idx, dim=0), 
177 | torch.cat(batch_nbr_idx, dim=0), 178 | torch.cat(crystal_elem_idx), 179 | ), 180 | tuple( 181 | torch.stack(b_target, dim=0) for b_target in zip(*batch_targets, strict=False) 182 | ), 183 | *zip(*batch_cry_ids, strict=False), 184 | ) 185 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

Aviary

2 | 3 |

4 | 5 | ![License: MIT](https://img.shields.io/badge/License-MIT-green.svg) 6 | [![GitHub Repo Size](https://img.shields.io/github/repo-size/comprhys/aviary?label=Repo+Size)](https://github.com/comprhys/aviary/graphs/contributors) 7 | [![PyPI](https://img.shields.io/pypi/v/aviary-models?logo=pypi&logoColor=white)](https://pypi.org/project/aviary-models) 8 | [![GitHub last commit](https://img.shields.io/github/last-commit/comprhys/aviary?label=Last+Commit)](https://github.com/comprhys/aviary/commits) 9 | [![Tests](https://github.com/CompRhys/aviary/actions/workflows/test.yml/badge.svg)](https://github.com/CompRhys/aviary/actions/workflows/test.yml) 10 | [![pre-commit.ci status](https://results.pre-commit.ci/badge/github/CompRhys/aviary/main.svg)](https://results.pre-commit.ci/latest/github/CompRhys/aviary/main) 11 | [![This project supports Python 3.10+](https://img.shields.io/badge/Python-3.10+-blue.svg?logo=python&logoColor=white)](https://python.org/downloads) 12 | 13 |

14 | 15 | The aim of `aviary` is to contain multiple models for materials discovery under a common interface. Over time we hope to add more models, with a particular focus on coordinate-free deep learning models. 16 | 17 | ## Installation 18 | 19 | Users can install `aviary` from PyPI with 20 | 21 | ```sh 22 | pip install aviary-models 23 | ``` 24 | 25 | or for an editable source install from a local clone: 26 | 27 | ```sh 28 | git clone https://github.com/CompRhys/aviary 29 | pip install -e ./aviary 30 | ``` 31 | 32 | ## Example Use from CLI 33 | 34 | To test the input file generation and cleaning/canonicalization please run: 35 | 36 | ```sh 37 | python examples/inputs/poscar_to_df.py 38 | ``` 39 | 40 | This script will load and parse a subset of raw POSCAR files from the TAATA dataset and produce the `datasets/examples/examples.csv` and `datasets/examples/examples.json` files used for the next example. 41 | For the coordinate-free `roost` and `wren` models where the inputs are easily expressed as strings we use CSV inputs. 42 | For the structure-based `cgcnn` model we first construct `pymatgen` structures from the raw POSCAR files then determine their dictionary serializations before saving in a JSON format. 43 | The raw POSCAR files have been selected to ensure that the subset contains all the correct endpoints for the 5 elemental species in the `Hf-N-Ti-Zr-Zn` chemical system. 
44 | To test each of the four models provided please run: 45 | 46 | ```sh 47 | python examples/roost-example.py --train --evaluate --data-path examples/inputs/examples.csv --targets E_f --tasks regression --losses L1 --robust --epoch 10 48 | ``` 49 | 50 | ```sh 51 | python examples/wren-example.py --train --evaluate --data-path examples/inputs/examples.csv --targets E_f --tasks regression --losses L1 --robust --epoch 10 52 | ``` 53 | 54 | ```sh 55 | python examples/wrenformer-example.py --train --evaluate --data-path examples/inputs/examples.csv --targets E_f --tasks regression --losses L1 --robust --epoch 10 56 | ``` 57 | 58 | ```sh 59 | python examples/cgcnn-example.py --train --evaluate --data-path examples/inputs/examples.json --targets E_f --tasks regression --losses L1 --robust --epoch 10 60 | ``` 61 | 62 | Please note that for speed/demonstration purposes this example runs on only ~68 materials for 10 epochs - running all these examples should take < 30 sec. These examples do not have sufficient data or training to make accurate predictions, however, the same scripts were used for all experiments conducted as part of the development and publication of these models. 63 | Consequently, understanding these examples will ensure you can deploy the models as intended for your research. 
64 | 65 | ## Notebooks 66 | 67 | We also provide some notebooks that show a more pythonic way to interact with the codebase. These examples make use of the TAATA dataset examined in the `wren` manuscript: 68 | 69 | | | | | 70 | | ---------------------------------------------------------------------------------------- | ------------------------------------ | ---------------------------------------------------------------------------------------------------------------------------- | 71 | | **[Roost](https://github.com/CompRhys/aviary/blob/main/examples/notebooks/Roost.ipynb)** | [![Launch Codespace]][codespace url] | [![Open in Google Colab]](https://colab.research.google.com/github/CompRhys/aviary/blob/main/examples/notebooks/Roost.ipynb) | 72 | | **[Wren](https://github.com/CompRhys/aviary/blob/main/examples/notebooks/Wren.ipynb)** | [![Launch Codespace]][codespace url] | [![Open in Google Colab]](https://colab.research.google.com/github/CompRhys/aviary/blob/main/examples/notebooks/Wren.ipynb) | 73 | 74 | [Open in Google Colab]: https://colab.research.google.com/assets/colab-badge.svg 75 | [Launch Codespace]: https://img.shields.io/badge/Launch-Codespace-darkblue?logo=github 76 | [codespace url]: https://github.com/codespaces/new?hide_repo_select=true&ref=main&repo=411272553 77 | 78 | ## Cite This Work 79 | 80 | If you use this code please cite the relevant work: 81 | 82 | `roost` - Predicting materials properties without crystal structure: Deep representation learning from stoichiometry. 
[[Paper]](https://doi.org/10.1038/s41467-020-19964-7) [[arXiv]](https://arxiv.org/abs/1910.00617) 83 | 84 | ```bibtex 85 | @article{goodall_2020_predicting, 86 | title={Predicting materials properties without crystal structure: Deep representation learning from stoichiometry}, 87 | author={Goodall, Rhys EA and Lee, Alpha A}, 88 | journal={Nature Communications}, 89 | volume={11}, 90 | number={1}, 91 | pages={1--9}, 92 | year={2020}, 93 | publisher={Nature Publishing Group} 94 | } 95 | ``` 96 | 97 | `wren` - Rapid Discovery of Stable Materials by Coordinate-free Coarse Graining. [[Paper]](https://www.science.org/doi/10.1126/sciadv.abn4117) [[arXiv]](https://arxiv.org/abs/2106.11132) 98 | 99 | ```bibtex 100 | @article{goodall_2022_rapid, 101 | title={Rapid discovery of stable materials by coordinate-free coarse graining}, 102 | author={Goodall, Rhys EA and Parackal, Abhijith S and Faber, Felix A and Armiento, Rickard and Lee, Alpha A}, 103 | journal={Science Advances}, 104 | volume={8}, 105 | number={30}, 106 | pages={eabn4117}, 107 | year={2022}, 108 | publisher={American Association for the Advancement of Science} 109 | } 110 | ``` 111 | 112 | `cgcnn` - Crystal Graph Convolutional Neural Networks for an Accurate and Interpretable Prediction of Material Properties. [[Paper]](https://journals.aps.org/prl/abstract/10.1103/PhysRevLett.120.145301) [[arXiv]](https://arxiv.org/abs/1710.10324) 113 | 114 | ```bibtex 115 | @article{xie_2018_crystal, 116 | title={Crystal graph convolutional neural networks for an accurate and interpretable prediction of material properties}, 117 | author={Xie, Tian and Grossman, Jeffrey C}, 118 | journal={Physical review letters}, 119 | volume={120}, 120 | number={14}, 121 | pages={145301}, 122 | year={2018}, 123 | publisher={APS} 124 | } 125 | ``` 126 | 127 | ## Disclaimer 128 | 129 | This research code is provided as-is. We have checked for potential bugs and believe that the code is being shared in a bug-free state. 
# --------------------------------------------------------------------------------
# /aviary/roost/model.py:
# --------------------------------------------------------------------------------
from collections.abc import Sequence

import torch
import torch.nn.functional as F
from pymatgen.util.due import Doi, due
from torch import LongTensor, Tensor, nn

from aviary.core import BaseModelClass
from aviary.networks import ResidualNetwork, SimpleNetwork
from aviary.segments import MessageLayer, WeightedAttentionPooling
from aviary.utils import get_element_embedding


@due.dcite(Doi("10.1038/s41467-020-19964-7"), description="Roost model")
class Roost(BaseModelClass):
    """The Roost model is comprised of a fully connected network
    and message passing graph layers.

    The message passing layers are used to determine a descriptor set
    for the fully connected network. The graphs are used to represent
    the stoichiometry of inorganic materials in a trainable manner.
    This makes them systematically improvable with more data.
    """

    def __init__(
        self,
        robust: bool,
        n_targets: Sequence[int],
        elem_embedding: str = "matscholar200",
        elem_fea_len: int = 64,
        n_graph: int = 3,
        elem_heads: int = 3,
        elem_gate: Sequence[int] = (256,),
        elem_msg: Sequence[int] = (256,),
        cry_heads: int = 3,
        cry_gate: Sequence[int] = (256,),
        cry_msg: Sequence[int] = (256,),
        trunk_hidden: Sequence[int] = (1024, 512),
        out_hidden: Sequence[int] = (256, 128, 64),
        **kwargs,
    ) -> None:
        """Composition-only model.

        Args:
            robust (bool): Forwarded to the base class. When ``self.robust`` is
                truthy every output head gets twice as many outputs as requested.
            n_targets (Sequence[int]): Number of outputs for each task; one output
                network is created per entry.
            elem_embedding (str, optional): Passed to ``get_element_embedding`` to
                obtain the element embedding. Defaults to "matscholar200".
            elem_fea_len (int, optional): Length of the element features used by
                the DescriptorNetwork and input size of the trunk network.
                Defaults to 64.
            n_graph (int, optional): Number of message passing layers.
                Defaults to 3.
            elem_heads (int, optional): Heads per message passing layer.
                Defaults to 3.
            elem_gate (Sequence[int], optional): Hidden layers of the message
                layers' gate networks. Defaults to (256,).
            elem_msg (Sequence[int], optional): Hidden layers of the message
                layers' message networks. Defaults to (256,).
            cry_heads (int, optional): Number of attention pooling heads.
                Defaults to 3.
            cry_gate (Sequence[int], optional): Hidden layers of the pooling gate
                networks. Defaults to (256,).
            cry_msg (Sequence[int], optional): Hidden layers of the pooling message
                networks. Defaults to (256,).
            trunk_hidden (Sequence[int], optional): Hidden layers of the shared
                trunk network. Defaults to (1024, 512).
            out_hidden (Sequence[int], optional): The first entry is the trunk's
                output size, the remaining entries are the hidden layers of each
                output network. Defaults to (256, 128, 64).
            **kwargs: Forwarded to the base class constructor.
        """
        super().__init__(robust=robust, **kwargs)

        self.elem_embedding = get_element_embedding(elem_embedding)
        elem_emb_len = self.elem_embedding.weight.shape[1]
        desc_dict = {
            "elem_emb_len": elem_emb_len,
            "elem_fea_len": elem_fea_len,
            "n_graph": n_graph,
            "elem_heads": elem_heads,
            "elem_gate": elem_gate,
            "elem_msg": elem_msg,
            "cry_heads": cry_heads,
            "cry_gate": cry_gate,
            "cry_msg": cry_msg,
        }

        self.material_nn = DescriptorNetwork(**desc_dict)  # type: ignore[arg-type]

        # record the constructor arguments (with the un-doubled n_targets)
        model_params = {
            "robust": robust,
            "n_targets": n_targets,
            "out_hidden": out_hidden,
            "trunk_hidden": trunk_hidden,
            "elem_embedding": elem_embedding,
            **desc_dict,
        }
        self.model_params.update(model_params)

        # define an output neural network
        if self.robust:
            n_targets = [2 * n for n in n_targets]

        self.trunk_nn = ResidualNetwork(elem_fea_len, out_hidden[0], trunk_hidden)

        self.output_nns = nn.ModuleList(
            ResidualNetwork(out_hidden[0], n, out_hidden[1:]) for n in n_targets
        )

    def forward(
        self,
        elem_weights: Tensor,
        elem_fea: Tensor,
        self_idx: LongTensor,
        nbr_idx: LongTensor,
        cry_elem_idx: LongTensor,
    ) -> tuple[Tensor, ...]:
        """Forward pass through the material_nn and output_nn.

        Returns:
            tuple[Tensor, ...]: One output tensor per entry of ``self.output_nns``.
        """
        elem_fea = self.elem_embedding(elem_fea)

        crys_fea = self.material_nn(
            elem_weights, elem_fea, self_idx, nbr_idx, cry_elem_idx
        )

        crys_fea = F.relu(self.trunk_nn(crys_fea))

        # apply neural network to map from learned features to target
        return tuple(output_nn(crys_fea) for output_nn in self.output_nns)


class DescriptorNetwork(nn.Module):
    """The Descriptor Network is the message passing section of the Roost Model."""

    def __init__(
        self,
        elem_emb_len: int,
        elem_fea_len: int = 64,
        n_graph: int = 3,
        elem_heads: int = 3,
        elem_gate: Sequence[int] = (256,),
        elem_msg: Sequence[int] = (256,),
        cry_heads: int = 3,
        cry_gate: Sequence[int] = (256,),
        cry_msg: Sequence[int] = (256,),
    ) -> None:
        """Bundles n_graph message passing layers followed by cry_heads weighted
        attention pooling layers.
        """
        super().__init__()

        # apply linear transform to the input to get a trainable embedding
        # NOTE -1 here so we can add the weights as a node feature
        self.embedding = nn.Linear(elem_emb_len, elem_fea_len - 1)

        # create a list of Message passing layers
        self.graphs = nn.ModuleList(
            MessageLayer(
                msg_fea_len=elem_fea_len,
                num_msg_heads=elem_heads,
                msg_gate_layers=elem_gate,
                msg_net_layers=elem_msg,
            )
            for _ in range(n_graph)
        )

        # define a global pooling function for materials
        self.cry_pool = nn.ModuleList(
            WeightedAttentionPooling(
                gate_nn=SimpleNetwork(elem_fea_len, 1, cry_gate),
                message_nn=SimpleNetwork(elem_fea_len, elem_fea_len, cry_msg),
            )
            for _ in range(cry_heads)
        )

    def forward(
        self,
        elem_weights: Tensor,
        elem_fea: Tensor,
        self_idx: LongTensor,
        nbr_idx: LongTensor,
        cry_elem_idx: LongTensor,
    ) -> Tensor:
        """Forward pass through the DescriptorNetwork.

        Args:
            elem_weights (Tensor): Fractional weight of each Element in its
                stoichiometry
            elem_fea (Tensor): Element features of each of the elements in the batch
            self_idx (LongTensor): Indices of the 1st element in each of the pairs
            nbr_idx (LongTensor): Indices of the 2nd element in each of the pairs
            cry_elem_idx (list[LongTensor]): Mapping from the elem idx to crystal idx

        Returns:
            Tensor: Composition representation/features after message passing
        """
        # embed the original features into a trainable embedding space
        elem_fea = self.embedding(elem_fea)

        # add weights as a node feature
        elem_fea = torch.cat([elem_fea, elem_weights], dim=1)

        # apply the message passing functions
        for graph_func in self.graphs:
            elem_fea = graph_func(elem_weights, elem_fea, self_idx, nbr_idx)

        # generate crystal features by pooling the elemental features
        head_fea = [
            attn_head(elem_fea, index=cry_elem_idx, weights=elem_weights)
            for attn_head in self.cry_pool
        ]

        # average the pooled features over the attention heads
        return torch.mean(torch.stack(head_fea), dim=0)

    def __repr__(self) -> str:
        # the embedding layer is built with elem_fea_len - 1 outputs (the element
        # weight is concatenated afterwards), so add 1 to report the elem_fea_len
        # that was actually passed to the constructor
        elem_fea_len = self.embedding.out_features + 1
        return (
            f"{type(self).__name__}(n_graph={len(self.graphs)}, cry_heads="
            f"{len(self.cry_pool)}, elem_emb_len={self.embedding.in_features}, "
            f"elem_fea_len={elem_fea_len})"
        )


# --------------------------------------------------------------------------------
# /examples/matbench_example/make_plots.py:
# --------------------------------------------------------------------------------
# %%
import sys
from pathlib import Path

# Add the parent directory to system path
sys.path.append(str(Path(__file__).parent.parent))

import gzip
import json
import logging
import re
from collections import defaultdict
from datetime import datetime
from glob import glob

import pandas as pd
import plotly.express as px
import pymatviz as pmv
from matbench import MatbenchBenchmark
from matbench.constants import CLF_KEY, REG_KEY
from matbench.metadata import mbv01_metadata as matbench_metadata
from matbench_example.plotting_functions import (
    dataset_labels_html,
    error_heatmap,
    plot_leaderboard,
    scale_errors,
)
from matbench_example.prepare_matbench_datasets import DATA_PATHS
from matbench_example.utils import recursive_dict_merge
from sklearn.metrics import r2_score, roc_auc_score

logging.getLogger("matbench").setLevel("ERROR")

today = f"{datetime.now():%Y-%m-%d}"
matbench_repo_path = "/Users/janosh/dev/matbench"  # path to clone of matbench repo
bench_dir = f"{matbench_repo_path}/benchmarks"

# model name -> task name -> score
our_scores: dict[str, dict[str, float]] = defaultdict(dict)
others_scores: dict[str, dict[str, float]] = defaultdict(dict)


# %% --- Load other's scores ---
# load benchmark data for models with existing Matbench submission
for idx, dirname in enumerate(glob(f"{bench_dir}/*"), start=1):
    model_name = dirname.split("/matbench_v0.1_")[-1]
    print(f"{idx}. {model_name}")
    mbbm = MatbenchBenchmark.from_file(f"{dirname}/results.json.gz")

    for task in mbbm.tasks:
        task_name = task.dataset_name
        task_type = task.metadata.task_type

        if task_type == REG_KEY:
            score = task.scores.mae.mean
        elif task_type == CLF_KEY:
            score = task.scores.rocauc.mean
        else:
            raise ValueError(f"Unknown {task_type = }")

        others_scores[model_name][task_name] = round(score, 3)


# %% --- Load our scores ---
our_score_files = sorted(glob("model_scores/*.json"), key=lambda s: s.split("@")[0])

for idx, filename in enumerate(our_score_files, start=1):
    # the file name is split on a "@dd-dd-" marker into a date prefix and model name
    date, model_name = re.split(r"@\d\d-\d\d-", filename.split("/")[-1])

    print(f"{idx}. {date} {model_name}")

    with open(filename) as file:
        data = json.load(file)

    # filter params and other unwanted keys
    data = {k: data[k] for k in data if k.startswith("matbench_")}

    mean_scores = {}
    for task in data:
        df_fold_means = pd.DataFrame(data[task]).mean(1).round(3)
        key = "mae" if "mae" in df_fold_means else "rocauc"
        mean_scores[task] = df_fold_means[key]

    # mean_scores["date"] = f"2022-{date}"
    our_scores[model_name] = mean_scores


# %%
matbench_dfs: dict[str, pd.DataFrame] = {}
tasks_to_load = ["matbench_mp_e_form"]
# tasks_to_load = list(DATA_PATHS)

for task_name in tasks_to_load:
    if task_name in matbench_dfs:
        continue
    task_df = pd.read_json(DATA_PATHS[task_name]).set_index("mbid")
    task_df = task_df.drop(columns=["structure"], errors="ignore")
    matbench_dfs[task_name] = task_df


# %% --- Load other's predictions ---
for task_name in tasks_to_load:
    for model_name in ["alignn", "Crabnet"]:
        results_path = f"{bench_dir}/matbench_v0.1_{model_name}/results.json.gz"

        # results are gzip-compressed JSON, so decompress in text mode before
        # parsing (plain open() would hand compressed bytes to json.load)
        with gzip.open(results_path, mode="rt", encoding="utf-8") as file:
            data = json.load(file)["tasks"][task_name]["results"]

        # merge the per-fold prediction dicts into a single mapping
        join_folds: dict[str, float] = {}
        for idx in range(5):
            join_folds |= data[f"fold_{idx}"]["data"]

        task_df = matbench_dfs[task_name]

        task_df[model_name] = pd.Series(join_folds)


# %% --- Load our predictions ---
our_pred_files = sorted(glob("model_preds/*swa*.json"))
pred_col = "predictions"

for file_path in our_pred_files:
    model_name = file_path.split("/")[-1].split("-2022")[0]
    print(f"\nReading {model_name}...")
    with open(file_path) as file:
        json_data = json.load(file)

    # for task_name, folds in json_data.items():  # loads all tasks
    for task_name in tasks_to_load:  # loads only selected tasks
        folds = json_data[task_name]
        if len(folds) != 5:
            print(f" {task_name} only partially recorded: {sorted(folds)}")

        target = matbench_metadata[task_name].target
        task_type = matbench_metadata[task_name].task_type

        dfs = {idx: pd.DataFrame(fold_dict) for idx, fold_dict in folds.items()}

        task_df = pd.concat(dfs.values())
        task_df = task_df.set_index("mbid")

        if task_type == CLF_KEY:
            # take the 2nd entry of each prediction as the class-1 probability
            proba_cls_1 = task_df[pred_col].str[1]
            folds_mean_score = roc_auc_score(task_df[target], proba_cls_1)
        else:
            # mean absolute error over all recorded folds
            folds_mean_score = (task_df[target] - task_df[pred_col]).abs().mean()
        our_scores[model_name][task_name] = round(folds_mean_score, 4)

        # record our model preds into new df columns
        matbench_dfs[task_name][model_name] = task_df[pred_col]


# %%
df_err = pd.DataFrame(recursive_dict_merge(our_scores, others_scores))
df_err.index.name = "dataset"
print(f"{df_err=}")


# %%
# html_path = f"plots/{today}-matbench-leaderboard.html"
html_path = None
plot_leaderboard(df_err.dropna(axis=1, thresh=5), html_path, width=1200, height=600)


# %% error heatmap using Pandas dataframe styler
# thresh=x means require at least x non-NA values
df_err_scaled = (
    scale_errors(df_err).dropna(thresh=4).dropna(thresh=8, axis=1)
)  # .drop("matbench_jdft2d")

df_err_scaled.loc["mean scaled error"] = df_err_scaled.mean(0)

df_display = df_err_scaled.T.sort_values(by="mean scaled error")
df_display = df_display.rename(columns=dataset_labels_html)

df_display.style.format(precision=3).background_gradient(cmap="viridis")


# %% error heatmap using plotly
# thresh=x means require at least x non-NA values
fig = error_heatmap(df_err.dropna(thresh=9, axis=1), width=1200, height=600)
fig.show()
fig.write_image(f"plots/{today}-matbench-scaled-errors-heatmap.png", scale=2)


# %% scatter plot of predictions vs. targets for multiple models on matbench_mp_e_form
df = matbench_dfs["matbench_mp_e_form"]
df = df.dropna(axis=1)  # drop models with missing predictions
target = df.columns[0]


y_cols = [c for c in df if c not in [target, "composition", "protostructure"]]
labels = {}

for y_col in y_cols:
    MAE = (df[y_col] - df[target]).abs().mean()
    pretty_title = y_col.replace("-", " ")
    R2 = r2_score(df[target], df[y_col])
    # <br> puts the metrics on their own line in the plotly legend entry
    labels[y_col] = f"{pretty_title}<br>{MAE=:.2f}, {R2=:.2f}"

fig = px.scatter(
    df.rename(columns=labels).reset_index(),
    x=target,
    y=list(labels.values()),
    hover_data=["mbid"],
    opacity=0.7,
    width=1200,
    height=800,
    labels={
        "e_form": "DFT formation energy (eV/atom)",
        # with several y columns plotly names the melted y column "value"
        "value": "Predicted formation energy (eV/atom)",
    },
)
pmv.powerups.add_identity_line(fig)

fig.update_layout(legend=dict(x=0.02, y=0.95, xanchor="left", title="Models"))

# fig.write_image(f"plots/{today}-matbench-mp-e-form-scatter.png", scale=2)
# --------------------------------------------------------------------------------