├── tests ├── __init__.py ├── data │ └── matbench_phonons.json.gz ├── wren │ └── test_utils.py ├── test_package.py ├── conftest.py ├── test_core.py ├── test_print_metrics.py ├── test_utils.py └── test_wrenformer.py ├── aviary ├── wren │ ├── __init__.py │ └── utils.py ├── cgcnn │ └── __init__.py ├── roost │ ├── __init__.py │ ├── data.py │ └── model.py ├── wrenformer │ └── __init__.py ├── __init__.py ├── embeddings │ ├── wyckoff │ │ └── README.md │ └── element │ │ └── README.md ├── losses.py ├── scatter.py ├── networks.py ├── data.py └── segments.py ├── .github └── workflows │ ├── link-check-config.json │ ├── link-check.yml │ ├── publish.yml │ └── test.yml ├── citation.cff ├── examples ├── inputs │ ├── raw │ │ ├── 0fd347fb29c23cad30d42852d7aa4ed906102b8a.poscar │ │ ├── b3603cef0816d0e115adeafa4eee65c056dc7a26.poscar │ │ ├── b428ccf0f8cd5488a3af5d72b74e117dc71aa575.poscar │ │ ├── c07b988125b3a0adc329b2449eeaac24c9e7def8.poscar │ │ ├── 0c9cf664171d7e10527033cda1551694141207ed.poscar │ │ ├── 0c5ef0e6b20d6e34ae388d640797f032a15df918.poscar │ │ ├── 006047d5cfe39157f2904d2533bad74b2ddd3170.poscar │ │ ├── 0a2c3a235656a97d76f731b269982f92de9b37fd.poscar │ │ ├── 0abb6669aa25624d8371765c5afa6b80ae20f30b.poscar │ │ ├── 006a8e78e4bb5c1570bf70b47b24c0b855a5b5b5.poscar │ │ ├── 0a2e1f340255f62a617974d778714e5fdd28a007.poscar │ │ ├── 000fe0f2a4730347df688bd0d3c75f4494a28dca.poscar │ │ ├── 00af9b9350147b8c20498eb9cd65b63a3132c741.poscar │ │ ├── 6f7d6a970fde2a4243091aba720fc14ceec2854b.poscar │ │ ├── 0c7147c05a270b9b2744576b0bfb69c8391ee1e4.poscar │ │ ├── 0060794138df3ba93c68007949b0a72d19cb4c62.poscar │ │ ├── 0c2fc1da358bcddfb4baf41b5dfd7285a13b766a.poscar │ │ ├── 0cd142efcf69ba63dc62760ecd72fb4f1e930115.poscar │ │ ├── 004c70f544dca7f8e1c74c6c4ff87a1e355e7cd5.poscar │ │ ├── 0c5e41fdc8fba8e90b48cbe4752dc2c4a2805f21.poscar │ │ ├── 0a74297fb11027ba2a958da5ee31980b8876e1d4.poscar │ │ ├── 0ac7f92ee7225e47dbdc25f6d60a10b116b0813b.poscar │ │ ├── 
0c13902f8200732555ad1603e954a726d4d24d85.poscar │ │ ├── 0c9c74b6f48facd507c4e41d9df797a2017634e7.poscar │ │ ├── 0ad7d4c22ba4414fbe161f27f42c112c6cb83b29.poscar │ │ ├── 0c710ee221d7567cc0f1a06fba8231ae8261a181.poscar │ │ ├── 0ac3201405a380e382377a71216b33cde892b161.poscar │ │ ├── 0c27147781fc9c64cbcf45acc1575a0bf8070841.poscar │ │ ├── 0c5264c980964b88ae8a7c2c18d5a61daff481b1.poscar │ │ ├── 0a58ef8a37185d8abcbd6e1bad7eaf7e3d3d6016.poscar │ │ ├── 0c9cbed1e5d755a9666db36f742004c14bf780bd.poscar │ │ ├── 0079b9e3ab0e210ea4afe037a9faee8cd6577922.poscar │ │ ├── 00bbba1d86168727008d4c1542b39d45d93d693c.poscar │ │ ├── 0a3eb2074380f4abdec94d4a6171a099c9ecb30a.poscar │ │ ├── 0a646e174268b493bd67f58fdf52fead02f74a96.poscar │ │ ├── 00ceac1a8da8e554218a84def07e5a6da9f106f3.poscar │ │ ├── 0a5af69c21e6c2d9b004c2f51939790b10b19814.poscar │ │ ├── 0a0ae9c18de23e72bb764e10ef542de693321542.poscar │ │ ├── 0c450e6f9c256e33ef452c98dd20917fb6484317.poscar │ │ ├── 0059fd6dd1d17138c42bb5430fa0bc074053f0de.poscar │ │ ├── 00fa7d75b7b2e1442df0a68071210c8d64b93ebe.poscar │ │ ├── 0cd1f1670176b03cc3d61935d9a6882c9508e33a.poscar │ │ ├── 00237cf70206a5de77e5154ea72633514cdd4445.poscar │ │ ├── 0c882a0621189d18ff1350e3bed4114868663acf.poscar │ │ ├── 0ad8980cb8eb0693647f411732b2165b6a0fce0e.poscar │ │ ├── 009ee0451531276f598976f46848fe7d590262ca.poscar │ │ ├── 0c5ffa8e2b3126eba02901c37b31e8faf3b5377c.poscar │ │ ├── 0a458b5f2eeebeb0e4e55e1d1a17988ec87b7123.poscar │ │ ├── 0ac8a65af75e8223b07dcd98c6958db7af6504cf.poscar │ │ ├── 0a130b596d1ce7bfe317d9e546bbeacc703fef65.poscar │ │ ├── 0a4a721fb8f0996350edd4fb224d89ce213f7f93.poscar │ │ ├── 0c429f98b79c9522a11d51fbfae484004f899422.poscar │ │ ├── 001a938f944fa33e899b81e3d0a909876587e7b0.poscar │ │ ├── 0039f3c461ce4a2cbf4c48b09bfc3c6b13a5c769.poscar │ │ ├── 0abac62147275aa8199465764751eaa67c94d16c.poscar │ │ ├── 000866a75fbc0a305808c11c58342bb973c3cad5.poscar │ │ ├── 00a8d0f6f55191596770605314ff347995a3540d.poscar │ │ ├── 
0ab8ebf80d94a1ba826857a9e81cea8c7124d86e.poscar │ │ ├── 0019407b910a29730999dab65491b4e50624c219.poscar │ │ ├── 0c0b4fd84c8f4e075436d7ec34bfb3fad5e610be.poscar │ │ ├── 00a3259f00f8a9aa503caae2045970d2f7b22e13.poscar │ │ ├── 0a2c4708d3c853ea3c4b8630027afe9b788ee2ce.poscar │ │ ├── 00cca7ad273875ba80f1d437af797b6a01cc6ebd.poscar │ │ ├── 0a946dc02260d0d4b36ef2087e60b415a5fff800.poscar │ │ ├── 0ab8cc7e5cc3d48fcc786f9060da5a8882374f44.poscar │ │ ├── 00bada6787af6463447a1460da0c031d51494b60.poscar │ │ ├── 0adeaf7a9c16772b66815cce19917ef0460affe8.poscar │ │ ├── 00ecadb66a8baab030675a817d2ee15f0002a4dc.poscar │ │ ├── 00021364446a637881257fd9ee912a422a6b1753.poscar │ │ ├── 006197ce31efdef23c82f0f6cec08ac0b0febdce.poscar │ │ ├── 0069751a4ba575438917a8b4e5a565b686388f6e.poscar │ │ ├── 0ad3ae37a9fb7eab9189ff1cb384b9d2dd26bbf9.poscar │ │ ├── 0a10bbf9e7cd3d8da2469f15e06ceca7d8d2af5d.poscar │ │ └── 0ad7c1ba7ec6a37cae0523bbd1c2a75f2d56a3f2.poscar │ ├── poscar_to_df.py │ └── examples.csv └── matbench_example │ ├── prepare_matbench_datasets.py │ ├── readme.md │ ├── train_wrenformer.py │ └── make_plots.py ├── .gitignore ├── LICENSE ├── CONTRIBUTING.md ├── .pre-commit-config.yaml ├── pyproject.toml └── README.md /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /aviary/wren/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /aviary/cgcnn/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /aviary/roost/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- 
/aviary/wrenformer/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/data/matbench_phonons.json.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompRhys/aviary/HEAD/tests/data/matbench_phonons.json.gz -------------------------------------------------------------------------------- /.github/workflows/link-check-config.json: -------------------------------------------------------------------------------- 1 | { 2 | "aliveStatusCodes": [ 3 | 200, 4 | 403, 5 | 503 6 | ] 7 | } 8 | -------------------------------------------------------------------------------- /aviary/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib.metadata 2 | from os.path import abspath, dirname 3 | 4 | __version__ = importlib.metadata.version("aviary-models") 5 | PKG_DIR = dirname(abspath(__file__)) 6 | ROOT = dirname(PKG_DIR) 7 | -------------------------------------------------------------------------------- /aviary/wren/utils.py: -------------------------------------------------------------------------------- 1 | def __getattr__(name): 2 | raise ImportError( 3 | "The functionality from aviary.wren.utils has been moved to pymatgen. " 4 | "Please install pymatgen using 'pip install pymatgen>2025.3.10' to " 5 | "access these features." 
6 | ) 7 | -------------------------------------------------------------------------------- /tests/wren/test_utils.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | 4 | def test_utils_import_error(): 5 | with pytest.raises(ImportError) as exc_info: 6 | from aviary.wren.utils import relab_dict # noqa: F401 7 | 8 | assert "functionality from aviary.wren.utils has been moved to pymatgen" in str( 9 | exc_info.value 10 | ) 11 | -------------------------------------------------------------------------------- /aviary/embeddings/wyckoff/README.md: -------------------------------------------------------------------------------- 1 | # Wyckoff Position Embeddings 2 | 3 | ## bra-alg-off 4 | 5 | A 6 + 5 + 185 + 248 = 444 dimensional embedding encoding 6 | 7 | - the crystal system 8 | - unit cell centering 9 | - the sum of whether sites in a wyckoff position sit on lines, planes or in volumes 10 | - the specific offset of that wyckoff position within the unit cell 11 | -------------------------------------------------------------------------------- /citation.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.2.0 2 | message: "If you use this software, please cite it as below." 
3 | authors: 4 | authors: 5 | - given-names: Rhys 6 | family-names: Goodall 7 | orcid: https://orcid.org/0000-0002-6589-1700 8 | - given-names: Janosh 9 | family-names: Riebesell 10 | orcid: https://orcid.org/0000-0001-5233-3462 11 | title: "aviary" 12 | doi: 10.5281/zenodo.1234 13 | date-released: 2024-11-09 14 | url: "https://github.com/CompRhys/aviary" 15 | -------------------------------------------------------------------------------- /.github/workflows/link-check.yml: -------------------------------------------------------------------------------- 1 | name: Link check 2 | 3 | on: 4 | push: 5 | branches: [main] 6 | pull_request: 7 | branches: [main] 8 | 9 | jobs: 10 | markdown-link-check: 11 | runs-on: ubuntu-latest 12 | steps: 13 | - name: Check out repo 14 | uses: actions/checkout@v4 15 | 16 | - name: Run markdown link check 17 | uses: gaurav-nelson/github-action-markdown-link-check@v1 18 | # docs at https://git.io/JBaKu 19 | with: 20 | config-file: .github/workflows/link-check-config.json 21 | -------------------------------------------------------------------------------- /examples/inputs/raw/0fd347fb29c23cad30d42852d7aa4ed906102b8a.poscar: -------------------------------------------------------------------------------- 1 | -.17043474E+02 '002/ht.task.gamma_devel--default.0fd347fb29c23cad30d42852d7aa4ed906102b8a.cleanup.0.unclaimed.1.finished/ht.run.2015-11-30_09.47.37' ['(Tag) comment: Generated by cif2cell 1.0.12 from ICSD reference: 426988. 
Zr : Kurt Lejaeghere et al., Critical Reviews in Solid State and Materials Sciences 39, 1-24 (2014).'] 2 | 1 3 | 2.800434 -1.616831 0.000000 4 | 0.000000 3.233663 0.000000 5 | 0.000000 0.000000 5.166609 6 | Zr 7 | 2 8 | Direct 9 | 0.33333333 0.66666667 0.75000000 10 | 0.66666667 0.33333333 0.25000000 11 | -------------------------------------------------------------------------------- /examples/inputs/raw/b3603cef0816d0e115adeafa4eee65c056dc7a26.poscar: -------------------------------------------------------------------------------- 1 | -.21842114E+01 '002/ht.task.gamma_devel--default.b3603cef0816d0e115adeafa4eee65c056dc7a26.cleanup.0.unclaimed.1.finished/ht.run.2015-11-24_17.17.33' ['(Tag) comment: Generated by cif2cell 1.0.12 from ICSD reference: 426987. Zn : Kurt Lejaeghere et al., Critical Reviews in Solid State and Materials Sciences 39, 1-24 (2014).', '(Tag) type: pure'] 2 | 1 3 | 2.443685 -1.410862 0.000000 4 | 0.000000 2.821725 -0.000000 5 | 0.000000 -0.000000 4.305149 6 | Zn 7 | 2 8 | Direct 9 | 0.33333333 0.66666667 0.75000000 10 | 0.66666667 0.33333333 0.25000000 11 | -------------------------------------------------------------------------------- /examples/inputs/raw/b428ccf0f8cd5488a3af5d72b74e117dc71aa575.poscar: -------------------------------------------------------------------------------- 1 | -.15606022E+02 '002/ht.task.gamma_devel--default.b428ccf0f8cd5488a3af5d72b74e117dc71aa575.cleanup.0.unclaimed.1.finished/ht.run.2015-11-24_17.24.06' ['(Tag) comment: Generated by cif2cell 1.0.12 from ICSD reference: 426981. 
Ti : Kurt Lejaeghere et al., Critical Reviews in Solid State and Materials Sciences 39, 1-24 (2014).', '(Tag) type: pure'] 2 | 1 3 | 2.540600 -1.466816 0.000000 4 | 0.000000 2.933632 0.000000 5 | 0.000000 0.000000 4.657195 6 | Ti 7 | 2 8 | Direct 9 | 0.33333333 0.66666667 0.75000000 10 | 0.66666667 0.33333333 0.25000000 11 | -------------------------------------------------------------------------------- /examples/inputs/raw/c07b988125b3a0adc329b2449eeaac24c9e7def8.poscar: -------------------------------------------------------------------------------- 1 | -.19847840E+02 '002/ht.task.triolith--default.c07b988125b3a0adc329b2449eeaac24c9e7def8.cleanup.0.unclaimed.1.finished/ht.run.2015-12-10_10.57.55' ['(Tag) comment: Generated by cif2cell 1.0.12 from ICSD reference: 426944. Hf : Kurt Lejaeghere et al., Critical Reviews in Solid State and Materials Sciences 39, 1-24 (2014).', '(Tag) type: pure'] 2 | 1 3 | 2.771359 -1.600045 0.000000 4 | 0.000000 3.200090 0.000000 5 | 0.000000 0.000000 5.055614 6 | Hf 7 | 2 8 | Direct 9 | 0.33333333 0.66666667 0.75000000 10 | 0.66666667 0.33333333 0.25000000 11 | -------------------------------------------------------------------------------- /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | name: Publish to PyPI 2 | 3 | on: 4 | release: 5 | types: [published] 6 | 7 | jobs: 8 | tests: 9 | uses: ./.github/workflows/test.yml 10 | 11 | release: 12 | runs-on: ubuntu-latest 13 | needs: tests 14 | if: needs.tests.result == 'success' 15 | permissions: 16 | id-token: write 17 | 18 | steps: 19 | - name: Check out repo 20 | uses: actions/checkout@v4 21 | 22 | - name: Setup uv 23 | uses: astral-sh/setup-uv@v6 24 | 25 | - name: Build package with uv 26 | run: uv build 27 | 28 | - name: Publish package distributions to PyPI with uv 29 | run: uv publish -t ${{ secrets.PYPI_TOKEN }} 30 | -------------------------------------------------------------------------------- 
/examples/inputs/raw/0c9cf664171d7e10527033cda1551694141207ed.poscar: -------------------------------------------------------------------------------- 1 | -.27483177E+02 '004/ht.task.triolith--default.0c9cf664171d7e10527033cda1551694141207ed.cleanup.0.unclaimed.2.finished/ht.run.2015-11-21_17.02.04' ['(Tag) original structure: AuKO2', '(Tag) type: ternary', '(Tag) original structure path: structures/all/matches/ternaries/47/3_47_1aA1fB1nC.cif'] 2 | 1 3 | 2.999577 0.000000 0.000000 4 | 0.000000 5.448017 0.000000 5 | 0.000000 0.000000 2.860144 6 | Zr Zn N 7 | 1 1 2 8 | Direct 9 | 0.50000000 0.50000000 0.00000000 10 | 0.00000000 0.00000000 0.00000000 11 | 0.00000000 0.28297953 0.50000000 12 | 0.00000000 0.71702047 0.50000000 13 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # cache & build 2 | __pycache__/ 3 | *.egg-info 4 | .mypy_cache/ 5 | build/ 6 | 7 | # vscode 8 | .vscode/ 9 | 10 | # notebooks 11 | .ipynb_checkpoints/ 12 | 13 | # TensorBoard logs, model predictions and checkpoints 14 | runs/ 15 | results/ 16 | models/ 17 | 18 | # exclude data files and cached graphs 19 | *.pkl 20 | *.pdf 21 | *.json.gz 22 | *.json.bz2 23 | 24 | # HPC 25 | slurm-*.out 26 | plots/ 27 | datasets/ 28 | pds/ 29 | manuscript/ 30 | voro-thesis/ 31 | 32 | # run artifacts like model preds, checkpoints, metrics and slurm job logs 33 | examples/**/model_preds/ 34 | examples/**/model_scores/ 35 | examples/**/job-logs/ 36 | examples/**/artifacts/ 37 | examples/**/*.csv 38 | wandb/ 39 | 40 | # profiling 41 | *.prof 42 | -------------------------------------------------------------------------------- /examples/inputs/raw/0c5ef0e6b20d6e34ae388d640797f032a15df918.poscar: -------------------------------------------------------------------------------- 1 | -.30885650E+02 
'005/ht.task.triolith--default.0c5ef0e6b20d6e34ae388d640797f032a15df918.cleanup.0.unclaimed.2.finished/ht.run.2015-11-24_15.09.03' ['(Tag) original structure: AuCl3Cs', '(Tag) type: ternary', '(Tag) original structure path: structures/all/matches/ternaries/215/3_215_1aA1bB1cC.cif'] 2 | 1 3 | 3.779652 0.000000 0.000000 4 | 0.000000 3.779652 0.000000 5 | 0.000000 0.000000 3.779652 6 | Zr Zn N 7 | 1 1 3 8 | Direct 9 | 0.00000000 0.00000000 0.00000000 10 | 0.50000000 0.50000000 0.50000000 11 | 0.00000000 0.50000000 0.50000000 12 | 0.50000000 0.00000000 0.50000000 13 | 0.50000000 0.50000000 0.00000000 14 | -------------------------------------------------------------------------------- /examples/inputs/raw/006047d5cfe39157f2904d2533bad74b2ddd3170.poscar: -------------------------------------------------------------------------------- 1 | -.40933125E+02 '006/ht.task.triolith--default.006047d5cfe39157f2904d2533bad74b2ddd3170.cleanup.0.unclaimed.2.finished/ht.run.2015-12-10_22.22.59' ['(Tag) original structure: CsCuO', '(Tag) original structure path: structures/all/matches/ternaries/63/3_63_1aA1cB1cC.cif', '(Tag) type: ternary'] 2 | 1 3 | 1.749531 -3.048103 0.000000 4 | 1.749531 3.048103 -0.000000 5 | 0.000000 -0.000000 8.634492 6 | Zn Hf N 7 | 2 2 2 8 | Direct 9 | 0.66982580 0.33017420 0.25000000 10 | 0.33017420 0.66982580 0.75000000 11 | -0.00000000 0.00000000 -0.00000000 12 | -0.00000000 0.00000000 0.50000000 13 | 0.00158956 0.99841044 0.25000000 14 | 0.99841044 0.00158956 0.75000000 15 | -------------------------------------------------------------------------------- /examples/inputs/raw/0a2c3a235656a97d76f731b269982f92de9b37fd.poscar: -------------------------------------------------------------------------------- 1 | -.38204233E+02 '006/ht.task.gamma--default.0a2c3a235656a97d76f731b269982f92de9b37fd.cleanup.0.unclaimed.3.finished/ht.run.2015-11-08_14.53.55' ['(Tag) original structure: GaIY', '(Tag) original structure path: 
structures/all/matches/ternaries/164/3_164_1cA1dB1dC.cif', '(Tag) type: ternary'] 2 | 1 3 | 2.730114 -1.576232 -0.000000 4 | 0.000000 3.152464 0.000000 5 | -0.000000 0.000000 11.873833 6 | Ti N Zn 7 | 2 2 2 8 | Direct 9 | 0.00000000 -0.00000000 0.28482879 10 | 0.00000000 -0.00000000 0.71517121 11 | 0.33333333 0.66666667 0.33039369 12 | 0.66666667 0.33333333 0.66960631 13 | 0.66666667 0.33333333 0.08508663 14 | 0.33333333 0.66666667 0.91491337 15 | -------------------------------------------------------------------------------- /examples/inputs/raw/0abb6669aa25624d8371765c5afa6b80ae20f30b.poscar: -------------------------------------------------------------------------------- 1 | -.46777870E+02 '006/ht.task.gamma--default.0abb6669aa25624d8371765c5afa6b80ae20f30b.cleanup.0.unclaimed.3.finished/ht.run.2015-11-08_01.48.20' ['(Tag) original structure: CoO3Sr2', '(Tag) original structure path: structures/all/matches/ternaries/12/4_12_1cA1dB1iB1iC.cif', '(Tag) type: ternary'] 2 | 1 3 | 2.082063 -2.081320 0.050838 4 | 2.082063 2.081320 0.050838 5 | -1.049178 0.000000 9.996562 6 | Ti Zn N 7 | 2 1 3 8 | Direct 9 | 0.46125530 0.46125530 0.29720372 10 | 0.53874470 0.53874470 0.70279628 11 | 0.00000000 0.00000000 0.50000000 12 | 0.50000000 0.50000000 0.50000000 13 | 0.95590632 0.95590632 0.27067391 14 | 0.04409368 0.04409368 0.72932609 15 | -------------------------------------------------------------------------------- /examples/inputs/raw/006a8e78e4bb5c1570bf70b47b24c0b855a5b5b5.poscar: -------------------------------------------------------------------------------- 1 | -.23913820E+02 '006/ht.task.triolith--default.006a8e78e4bb5c1570bf70b47b24c0b855a5b5b5.cleanup.0.unclaimed.2.finished/ht.run.2015-12-10_19.42.57' ['(Tag) original structure: Br2Ca3Si', '(Tag) original structure path: structures/all/matches/ternaries/139/4_139_1aA1bB1eC1eA.cif', '(Tag) type: ternary'] 2 | 1 3 | 3.768322 0.000000 -0.000000 4 | 0.000000 3.768322 -0.000000 5 | 1.884161 1.884161 6.517125 6 | Zn Hf 
N 7 | 3 1 2 8 | Direct 9 | 0.00000000 0.00000000 0.00000000 10 | 0.70751047 0.70751047 0.58497906 11 | 0.29248953 0.29248953 0.41502094 12 | 0.50000000 0.50000000 0.00000000 13 | 0.14233929 0.14233929 0.71532143 14 | 0.85766071 0.85766071 0.28467857 15 | -------------------------------------------------------------------------------- /examples/inputs/raw/0a2e1f340255f62a617974d778714e5fdd28a007.poscar: -------------------------------------------------------------------------------- 1 | -.35119632E+02 '008/ht.task.gamma--default.0a2e1f340255f62a617974d778714e5fdd28a007.cleanup.0.unclaimed.3.finished/ht.run.2015-11-08_00.54.14' ['(Tag) original structure: CsN2Nb', '(Tag) original structure path: structures/all/matches/ternaries/227/3_227_1aA1bB1cC.cif', '(Tag) type: ternary'] 2 | 1 3 | -4.226700 -4.226700 0.000000 4 | -4.226700 0.000000 -4.226700 5 | 0.000000 -4.226700 -4.226700 6 | Zn Ti N 7 | 2 2 4 8 | Direct 9 | 0.87500000 0.87500000 0.87500000 10 | 0.12500000 0.12500000 0.12500000 11 | 0.62500000 0.62500000 0.62500000 12 | 0.37500000 0.37500000 0.37500000 13 | 0.00000000 0.00000000 0.00000000 14 | 0.00000000 0.50000000 0.00000000 15 | 0.50000000 0.00000000 0.00000000 16 | 0.00000000 0.00000000 0.50000000 17 | -------------------------------------------------------------------------------- /examples/inputs/raw/000fe0f2a4730347df688bd0d3c75f4494a28dca.poscar: -------------------------------------------------------------------------------- 1 | -.60754153E+02 '008/ht.task.triolith--default.000fe0f2a4730347df688bd0d3c75f4494a28dca.cleanup.0.unclaimed.2.finished/ht.run.2016-04-28_16.45.35' ['(Tag) original structure: CsSe2Yb', '(Tag) original structure path: structures/all/matches/ternaries/194/3_194_1aA1cB1fC.cif', '(Tag) type: ternary'] 2 | 1 3 | 2.724122 -1.572773 0.000000 4 | 0.000000 3.145546 0.000000 5 | 0.000000 0.000000 10.570134 6 | Hf Zn N 7 | 2 2 4 8 | Direct 9 | 0.33333333 0.66666667 0.25000000 10 | 0.66666667 0.33333333 0.75000000 11 | 0.00000000 
0.00000000 0.00000000 12 | 0.00000000 0.00000000 0.50000000 13 | 0.33333333 0.66666667 0.62866178 14 | 0.66666667 0.33333333 0.12866178 15 | 0.33333333 0.66666667 0.87133822 16 | 0.66666667 0.33333333 0.37133822 17 | -------------------------------------------------------------------------------- /examples/inputs/raw/00af9b9350147b8c20498eb9cd65b63a3132c741.poscar: -------------------------------------------------------------------------------- 1 | -.57885965E+02 '008/ht.task.triolith--default.00af9b9350147b8c20498eb9cd65b63a3132c741.cleanup.0.unclaimed.2.finished/ht.run.2016-04-28_17.15.45' ['(Tag) original structure: FeSe2Tl', '(Tag) original structure path: structures/all/matches/ternaries/12/4_12_1gA1iB2iC.cif', '(Tag) type: ternary'] 2 | 1 3 | 5.845331 -2.256375 1.002936 4 | 5.845331 2.256375 1.002936 5 | -2.359010 0.000000 3.498230 6 | Hf Zn N 7 | 2 2 4 8 | Direct 9 | 0.19931503 0.19931503 0.63601203 10 | 0.80068497 0.80068497 0.36398797 11 | 0.74375205 0.25624795 0.00000000 12 | 0.25624795 0.74375205 0.00000000 13 | 0.63229337 0.63229337 0.41424153 14 | 0.36770663 0.36770663 0.58575847 15 | 0.15059160 0.15059160 0.05184831 16 | 0.84940840 0.84940840 0.94815169 17 | -------------------------------------------------------------------------------- /tests/test_package.py: -------------------------------------------------------------------------------- 1 | import os 2 | from glob import glob 3 | 4 | import pytest 5 | 6 | from aviary import ROOT 7 | 8 | package_sources_path = f"{ROOT}/aviary.egg-info/SOURCES.txt" 9 | 10 | 11 | @pytest.mark.skipif( 12 | not os.path.isfile(package_sources_path), 13 | reason="No aviary.egg-info/SOURCES.txt file, run pip install . to create it", 14 | ) 15 | def test_egg_sources(): 16 | """Check we're correctly packaging all JSON files under aviary/ to prevent issues 17 | like https://github.com/CompRhys/aviary/pull/45. 18 | 19 | This test can fail due to outdated SOURCES.txt. Try `pip install -e .` to update. 
20 | """ 21 | with open(package_sources_path) as file: 22 | sources = file.read() 23 | 24 | for filepath in glob(f"{ROOT}/aviary/**/*.json", recursive=True): 25 | rel_path = filepath.split(f"{ROOT}/aviary/")[1] 26 | assert rel_path in sources 27 | -------------------------------------------------------------------------------- /examples/inputs/raw/6f7d6a970fde2a4243091aba720fc14ceec2854b.poscar: -------------------------------------------------------------------------------- 1 | -.66733607E+02 '008/ht.task.triolith--default.6f7d6a970fde2a4243091aba720fc14ceec2854b.cleanup.0.unclaimed.1.finished/ht.run.2015-12-10_11.08.37' ['(Tag) comment: Generated by cif2cell 1.0.12 from ICSD reference: 426956. N : Kurt Lejaeghere et al., Critical Reviews in Solid State and Materials Sciences 39, 1-24 (2014).', '(Tag) type: pure'] 2 | 1 3 | 5.946714 -0.000000 -0.000000 4 | 0.000000 5.946714 -0.000000 5 | 0.000000 -0.000000 5.946714 6 | N 7 | 8 8 | Direct 9 | 0.55402836 0.55402836 0.55402836 10 | 0.05402836 0.94597164 0.44597164 11 | 0.55402836 0.94597164 0.05402836 12 | 0.94597164 0.44597164 0.05402836 13 | 0.94597164 0.05402836 0.55402836 14 | 0.44597164 0.05402836 0.94597164 15 | 0.05402836 0.55402836 0.94597164 16 | 0.44597164 0.44597164 0.44597164 17 | -------------------------------------------------------------------------------- /examples/inputs/raw/0c7147c05a270b9b2744576b0bfb69c8391ee1e4.poscar: -------------------------------------------------------------------------------- 1 | -.56738203E+02 '009/ht.task.triolith--default.0c7147c05a270b9b2744576b0bfb69c8391ee1e4.cleanup.0.unclaimed.2.finished/ht.run.2015-11-25_10.02.52' ['(Tag) original structure: CdO6Ti2', '(Tag) type: ternary', '(Tag) original structure path: structures/all/matches/ternaries/162/3_162_1aA1dB1kC.cif'] 2 | 1 3 | 4.813812 -2.779256 0.000000 4 | -0.000000 5.558512 0.000000 5 | 0.000000 -0.000000 3.745624 6 | Zr Zn N 7 | 1 2 6 8 | Direct 9 | 0.00000000 0.00000000 0.00000000 10 | 0.33333333 0.66666667 
0.50000000 11 | 0.66666667 0.33333333 0.50000000 12 | 0.41946521 0.00000000 0.12157999 13 | 0.58053479 1.00000000 0.87842001 14 | 0.00000000 0.41946521 0.12157999 15 | 0.58053479 0.58053479 0.12157999 16 | 0.41946521 0.41946521 0.87842001 17 | 0.00000000 0.58053479 0.87842001 18 | -------------------------------------------------------------------------------- /examples/inputs/raw/0060794138df3ba93c68007949b0a72d19cb4c62.poscar: -------------------------------------------------------------------------------- 1 | -.56928156E+02 '010/ht.task.triolith--default.0060794138df3ba93c68007949b0a72d19cb4c62.cleanup.0.unclaimed.2.finished/ht.run.2016-04-28_01.39.04' ['(Tag) original structure: C3O6Y', '(Tag) original structure path: structures/all/matches/ternaries/160/4_160_1aA1bB2bC.cif', '(Tag) type: ternary'] 2 | 1 3 | 6.062902 -0.000000 1.294977 4 | -3.031451 5.250627 1.294978 5 | -3.031451 -5.250627 1.294978 6 | Hf N Zn 7 | 1 6 3 8 | Direct 9 | 0.07941259 0.07941259 0.07941259 10 | 0.97170032 0.35461423 0.97170032 11 | 0.97170032 0.97170032 0.35461423 12 | 0.35461423 0.97170033 0.97170033 13 | 0.62831402 0.48486499 0.62831402 14 | 0.62831402 0.62831402 0.48486499 15 | 0.48486498 0.62831401 0.62831401 16 | 0.43962517 0.92332917 0.43962517 17 | 0.43962517 0.43962517 0.92332917 18 | 0.92332917 0.43962517 0.43962517 19 | -------------------------------------------------------------------------------- /examples/inputs/raw/0c2fc1da358bcddfb4baf41b5dfd7285a13b766a.poscar: -------------------------------------------------------------------------------- 1 | -.66952114E+02 '010/ht.task.triolith--default.0c2fc1da358bcddfb4baf41b5dfd7285a13b766a.cleanup.0.unclaimed.2.finished/ht.run.2015-11-27_01.25.28' ['(Tag) original structure: CoLaO3', '(Tag) type: ternary', '(Tag) original structure path: structures/all/matches/ternaries/15/4_15_1cA1eB1eC1fC.cif'] 2 | 1 3 | 4.971455 0.000000 -0.378972 4 | 0.000000 5.742631 -0.000000 5 | 2.166218 2.871315 3.580151 6 | Zr Zn N 7 | 2 2 6 8 | 
Direct 9 | 0.25000000 0.25874194 0.00000000 10 | 0.75000000 0.74125806 1.00000000 11 | 0.50000000 0.00000000 0.50000000 12 | 0.00000000 0.50000000 0.50000000 13 | 0.25000000 0.63406163 1.00000000 14 | 0.75000000 0.36593837 0.00000000 15 | 0.87451404 0.86599373 0.39235106 16 | 0.12548596 0.13400627 0.60764894 17 | 0.62548596 0.25834479 0.60764894 18 | 0.37451404 0.74165521 0.39235106 19 | -------------------------------------------------------------------------------- /examples/inputs/raw/0cd142efcf69ba63dc62760ecd72fb4f1e930115.poscar: -------------------------------------------------------------------------------- 1 | -.57088230E+02 '010/ht.task.triolith--default.0cd142efcf69ba63dc62760ecd72fb4f1e930115.cleanup.0.unclaimed.2.finished/ht.run.2015-11-24_16.16.25' ['(Tag) original structure: Ca2FeN2', '(Tag) type: ternary', '(Tag) original structure path: structures/all/matches/ternaries/12/5_12_1iA2iB2iC.cif'] 2 | 1 3 | 5.503549 -2.105416 0.306120 4 | 5.503549 2.105416 0.306120 5 | -3.212949 0.000000 5.665865 6 | Zr Zn N 7 | 2 4 4 8 | Direct 9 | 0.33909577 0.33909577 0.88821450 10 | 0.66090423 0.66090423 0.11178550 11 | 0.01117193 0.01117193 0.73206145 12 | 0.98882807 0.98882807 0.26793855 13 | 0.33875302 0.33875302 0.40873122 14 | 0.66124698 0.66124698 0.59126878 15 | 0.51840336 0.51840336 0.23739812 16 | 0.48159664 0.48159664 0.76260188 17 | 0.15707630 0.15707630 0.59729982 18 | 0.84292370 0.84292370 0.40270018 19 | -------------------------------------------------------------------------------- /examples/inputs/raw/004c70f544dca7f8e1c74c6c4ff87a1e355e7cd5.poscar: -------------------------------------------------------------------------------- 1 | -.54542103E+02 '010/ht.task.triolith--default.004c70f544dca7f8e1c74c6c4ff87a1e355e7cd5.cleanup.0.unclaimed.2.finished/ht.run.2016-04-27_21.28.21' ['(Tag) original structure: NiS6Ti3', '(Tag) original structure path: structures/all/matches/ternaries/148/4_148_1aA1bB1cA1fC.cif', '(Tag) type: ternary'] 2 | 1 3 | 3.062449 
0.000000 4.987390 4 | -1.531224 2.652158 4.987390 5 | -1.531224 -2.652158 4.987390 6 | Hf Zn N 7 | 1 3 6 8 | Direct 9 | 0.50000000 0.50000000 0.50000000 10 | 0.00000000 -0.00000000 -0.00000000 11 | 0.30592519 0.30592519 0.30592519 12 | 0.69407481 0.69407481 0.69407481 13 | 0.60495542 0.89911462 0.27504484 14 | 0.10088538 0.72495516 0.39504458 15 | 0.72495516 0.39504458 0.10088538 16 | 0.27504484 0.60495542 0.89911462 17 | 0.89911462 0.27504484 0.60495542 18 | 0.39504458 0.10088538 0.72495516 19 | -------------------------------------------------------------------------------- /examples/inputs/raw/0c5e41fdc8fba8e90b48cbe4752dc2c4a2805f21.poscar: -------------------------------------------------------------------------------- 1 | -.61534447E+02 '010/ht.task.triolith--default.0c5e41fdc8fba8e90b48cbe4752dc2c4a2805f21.cleanup.0.unclaimed.2.finished/ht.run.2015-11-23_17.48.43' ['(Tag) original structure: CuF3K', '(Tag) type: ternary', '(Tag) original structure path: structures/all/matches/ternaries/140/4_140_1bA1cB1dC1hA.cif'] 2 | 1 3 | 6.038698 0.000000 -0.000000 4 | 0.000000 6.038698 -0.000000 5 | 3.019349 3.019349 4.167943 6 | Zr Zn N 7 | 2 2 6 8 | Direct 9 | 0.00000000 0.00000000 0.00000000 10 | 0.50000000 0.50000000 0.00000000 11 | 0.00000000 0.50000000 0.00000000 12 | 0.50000000 0.00000000 0.00000000 13 | 0.23981570 0.73981570 1.00000000 14 | 0.76018430 0.26018430 0.00000000 15 | 0.73981570 0.76018430 0.00000000 16 | 0.26018430 0.23981570 0.00000000 17 | 0.75000000 0.25000000 0.50000000 18 | 0.25000000 0.75000000 0.50000000 19 | -------------------------------------------------------------------------------- /examples/inputs/raw/0a74297fb11027ba2a958da5ee31980b8876e1d4.poscar: -------------------------------------------------------------------------------- 1 | -.70525527E+02 '012/ht.task.gamma--default.0a74297fb11027ba2a958da5ee31980b8876e1d4.cleanup.0.unclaimed.3.finished/ht.run.2015-11-08_06.45.53' ['(Tag) original structure: O4SiTi', '(Tag) original structure 
path: structures/all/matches/ternaries/63/4_63_1aA1cB1fC1gC.cif', '(Tag) type: ternary'] 2 | 1 3 | 3.058315 -4.469731 0.000000 4 | 3.058315 4.469731 0.000000 5 | 0.000000 0.000000 6.091423 6 | Ti Zn N 7 | 2 2 8 8 | Direct 9 | 0.00000000 -0.00000000 0.00000000 10 | 0.00000000 -0.00000000 0.50000000 11 | 0.61502096 0.38497904 0.25000000 12 | 0.38497904 0.61502096 0.75000000 13 | 0.77707723 0.22292277 0.02604294 14 | 0.22292277 0.77707723 0.52604294 15 | 0.22292277 0.77707723 0.97395706 16 | 0.77707723 0.22292277 0.47395706 17 | 0.79819076 0.76342840 0.25000000 18 | 0.20180924 0.23657160 0.75000000 19 | 0.76342840 0.79819076 0.75000000 20 | 0.23657160 0.20180924 0.25000000 21 | -------------------------------------------------------------------------------- /examples/inputs/raw/0ac7f92ee7225e47dbdc25f6d60a10b116b0813b.poscar: -------------------------------------------------------------------------------- 1 | -.63921795E+02 '012/ht.task.gamma--default.0ac7f92ee7225e47dbdc25f6d60a10b116b0813b.cleanup.0.unclaimed.3.finished/ht.run.2015-11-09_06.07.49' ['(Tag) original structure: Fe2S3Tl', '(Tag) original structure path: structures/all/matches/ternaries/63/4_63_1cA1cB1eC1gA.cif', '(Tag) type: ternary'] 2 | 1 3 | 4.275648 -3.887910 0.000000 4 | 4.275648 3.887910 -0.000000 5 | 0.000000 0.000000 4.969047 6 | Ti Zn N 7 | 2 4 6 8 | Direct 9 | 0.40001196 0.59998804 0.25000000 10 | 0.59998804 0.40001196 0.75000000 11 | 0.83644128 0.83644128 -0.00000000 12 | 0.16355872 0.16355872 0.50000000 13 | 0.83644128 0.83644128 0.50000000 14 | 0.16355872 0.16355872 -0.00000000 15 | 0.89385404 0.10614596 0.25000000 16 | 0.10614596 0.89385404 0.75000000 17 | 0.71612659 0.59086281 0.25000000 18 | 0.28387341 0.40913719 0.75000000 19 | 0.59086281 0.71612659 0.75000000 20 | 0.40913719 0.28387341 0.25000000 21 | -------------------------------------------------------------------------------- /examples/inputs/raw/0c13902f8200732555ad1603e954a726d4d24d85.poscar: 
-------------------------------------------------------------------------------- 1 | -.58351590E+02 '012/ht.task.triolith--default.0c13902f8200732555ad1603e954a726d4d24d85.cleanup.0.unclaimed.2.finished/ht.run.2015-11-24_21.10.12' ['(Tag) original structure: Ag3NaO2', '(Tag) type: ternary', '(Tag) original structure path: structures/all/matches/ternaries/72/4_72_1bA1cB1eB1jC.cif'] 2 | 1 3 | 5.720953 0.000000 0.000000 4 | 0.000000 11.967072 0.000000 5 | 2.860476 5.983536 3.126332 6 | Zr Zn N 7 | 2 6 4 8 | Direct 9 | 0.25000000 0.75000000 0.50000000 10 | 0.75000000 0.25000000 0.50000000 11 | 0.00000000 0.00000000 0.00000000 12 | 0.50000000 0.50000000 0.00000000 13 | 0.00000000 0.00000000 0.50000000 14 | 0.50000000 0.00000000 0.50000000 15 | 0.00000000 0.50000000 0.50000000 16 | 0.50000000 0.50000000 0.50000000 17 | 0.29083893 0.06226772 0.00000000 18 | 0.20916107 0.56226772 0.00000000 19 | 0.79083893 0.43773228 -0.00000000 20 | 0.70916107 0.93773228 0.00000000 21 | -------------------------------------------------------------------------------- /examples/inputs/raw/0c9c74b6f48facd507c4e41d9df797a2017634e7.poscar: -------------------------------------------------------------------------------- 1 | -.69099362E+02 '012/ht.task.triolith--default.0c9c74b6f48facd507c4e41d9df797a2017634e7.cleanup.0.unclaimed.2.finished/ht.run.2015-11-24_09.39.55' ['(Tag) original structure: Cl3Cs2Li', '(Tag) type: ternary', '(Tag) original structure path: structures/all/matches/ternaries/63/5_63_1cA1cB1fA2cC.cif'] 2 | 1 3 | 1.537037 -8.306579 0.000000 4 | 1.537037 8.306579 -0.000000 5 | 0.000000 0.000000 5.818854 6 | Zn Zr N 7 | 4 2 6 8 | Direct 9 | 0.01974571 0.98025429 0.25000000 10 | 0.98025429 0.01974571 0.75000000 11 | 0.16944846 0.83055154 0.25000000 12 | 0.83055154 0.16944846 0.75000000 13 | 0.33953783 0.66046217 0.25000000 14 | 0.66046217 0.33953783 0.75000000 15 | 0.41969297 0.58030703 0.97742156 16 | 0.58030703 0.41969297 0.47742156 17 | 0.58030703 0.41969297 0.02257844 18 | 
0.41969297 0.58030703 0.52257844 19 | 0.75587878 0.24412122 0.25000000 20 | 0.24412122 0.75587878 0.75000000 21 | -------------------------------------------------------------------------------- /examples/inputs/raw/0ad7d4c22ba4414fbe161f27f42c112c6cb83b29.poscar: -------------------------------------------------------------------------------- 1 | -.69790953E+02 '012/ht.task.gamma--default.0ad7d4c22ba4414fbe161f27f42c112c6cb83b29.cleanup.0.unclaimed.3.finished/ht.run.2015-11-07_17.50.22' ['(Tag) original structure: Ba3S7Zr2', '(Tag) original structure path: structures/all/matches/ternaries/139/6_139_1aA1bB1eA1eB1eC1gB.cif', '(Tag) type: ternary'] 2 | 1 3 | 4.021411 0.000000 0.000000 4 | 0.000000 4.021411 0.000000 5 | 2.010706 2.010706 8.874517 6 | Zn Ti N 7 | 3 2 7 8 | Direct 9 | -0.00000000 -0.00000000 0.00000000 10 | 0.83602820 0.83602820 0.32794360 11 | 0.16397180 0.16397180 0.67205640 12 | 0.60922880 0.60922880 0.78154240 13 | 0.39077120 0.39077120 0.21845760 14 | 0.61249780 0.11249780 0.77500439 15 | 0.11249780 0.61249780 0.77500439 16 | 0.38750220 0.88750220 0.22499561 17 | 0.88750220 0.38750220 0.22499561 18 | 0.72695695 0.72695695 0.54608610 19 | 0.27304305 0.27304305 0.45391390 20 | 0.50000000 0.50000000 0.00000000 21 | -------------------------------------------------------------------------------- /examples/inputs/raw/0c710ee221d7567cc0f1a06fba8231ae8261a181.poscar: -------------------------------------------------------------------------------- 1 | -.65122406E+02 '013/ht.task.triolith--default.0c710ee221d7567cc0f1a06fba8231ae8261a181.cleanup.0.unclaimed.2.finished/ht.run.2015-11-26_08.31.24' ['(Tag) original structure: AgMo6Te6', '(Tag) type: ternary', '(Tag) original structure path: structures/all/matches/ternaries/12/7_12_1aA3iB3iC.cif'] 2 | 1 3 | 7.932667 -1.614529 0.519404 4 | 7.932667 1.614529 0.519404 5 | -3.354285 0.000000 6.743572 6 | Zn N Zr 7 | 6 6 1 8 | Direct 9 | 0.72571673 0.72571673 0.20605258 10 | 0.27428327 0.27428327 0.79394742 11 | 
0.88941755 0.88941755 0.62365230 12 | 0.11058245 0.11058245 0.37634770 13 | 0.59020144 0.59020144 0.70963199 14 | 0.40979856 0.40979856 0.29036801 15 | 0.85546464 0.85546464 0.34923770 16 | 0.14453536 0.14453536 0.65076230 17 | 0.51240013 0.51240013 0.21277716 18 | 0.48759987 0.48759987 0.78722284 19 | 0.84710770 0.84710770 0.84881876 20 | 0.15289230 0.15289230 0.15118124 21 | -0.00000000 -0.00000000 0.00000000 22 | -------------------------------------------------------------------------------- /examples/inputs/raw/0ac3201405a380e382377a71216b33cde892b161.poscar: -------------------------------------------------------------------------------- 1 | -.80497110E+02 '014/ht.task.gamma--default.0ac3201405a380e382377a71216b33cde892b161.cleanup.0.unclaimed.3.finished/ht.run.2015-11-08_19.27.00' ['(Tag) original structure: MoO4Tl2', '(Tag) original structure path: structures/all/matches/ternaries/5/8_5_1aA1bA1cB1cA4cC.cif', '(Tag) type: ternary'] 2 | 1 3 | 4.952401 -2.607102 -0.345365 4 | 4.952401 2.607102 -0.345365 5 | -0.793138 0.000000 6.257486 6 | Zn Ti N 7 | 4 2 8 8 | Direct 9 | 0.95747786 0.04252214 -0.00000000 10 | 0.96835037 0.03164963 0.50000000 11 | 0.36677608 0.33913125 0.20324262 12 | 0.66086875 0.63322392 0.79675738 13 | 0.57577517 0.75508330 0.28765311 14 | 0.24491670 0.42422483 0.71234689 15 | 0.85703681 0.61391281 0.53609174 16 | 0.38608719 0.14296319 0.46390826 17 | 0.20259831 0.78613842 0.24921180 18 | 0.21386158 0.79740169 0.75078820 19 | 0.64474828 0.89284159 0.05563091 20 | 0.10715841 0.35525172 0.94436909 21 | 0.75781176 0.31376727 0.27014504 22 | 0.68623273 0.24218824 0.72985496 23 | -------------------------------------------------------------------------------- /examples/inputs/raw/0c27147781fc9c64cbcf45acc1575a0bf8070841.poscar: -------------------------------------------------------------------------------- 1 | -.95884543E+02 
'014/ht.task.triolith--default.0c27147781fc9c64cbcf45acc1575a0bf8070841.cleanup.0.unclaimed.2.finished/ht.run.2015-11-24_18.58.40' ['(Tag) original structure: Na2O3Zn2', '(Tag) type: ternary', '(Tag) original structure path: structures/all/matches/ternaries/14/4_14_1aA1eB1eA1eC.cif'] 2 | 1 3 | 6.561230 0.000000 0.098238 4 | 0.000000 6.060726 0.000000 5 | -2.983057 0.000000 5.510293 6 | Zn Zr N 7 | 4 4 6 8 | Direct 9 | 0.16989291 0.87830687 0.43646054 10 | 0.83010709 0.12169313 0.56353946 11 | 0.83010709 0.37830687 0.06353946 12 | 0.16989291 0.62169313 0.93646054 13 | 0.32449651 0.11021715 0.07651784 14 | 0.67550349 0.88978285 0.92348216 15 | 0.67550349 0.61021715 0.42348216 16 | 0.32449651 0.38978285 0.57651784 17 | 0.37508238 0.40328558 0.26426987 18 | 0.62491762 0.59671442 0.73573013 19 | 0.62491762 0.90328558 0.23573013 20 | 0.37508238 0.09671442 0.76426987 21 | 0.00000000 0.00000000 0.00000000 22 | 0.00000000 0.50000000 0.50000000 23 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Rhys Goodall 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /examples/inputs/raw/0c5264c980964b88ae8a7c2c18d5a61daff481b1.poscar: -------------------------------------------------------------------------------- 1 | -.77616025E+02 '014/ht.task.triolith--default.0c5264c980964b88ae8a7c2c18d5a61daff481b1.cleanup.0.unclaimed.2.finished/ht.run.2015-11-22_15.02.35' ['(Tag) original structure: MnNa2O4', '(Tag) type: ternary', '(Tag) original structure path: structures/all/matches/ternaries/186/5_186_1aA1bB1bA1bC1cC.cif'] 2 | 1 3 | 4.890752 -2.823677 0.000000 4 | 0.000000 5.647354 0.000000 5 | 0.000000 0.000000 6.428382 6 | Zr Zn N 7 | 2 4 8 8 | Direct 9 | 0.33333333 0.66666667 0.62486757 10 | 0.66666667 0.33333333 0.12486757 11 | 0.00000000 0.00000000 0.04179757 12 | 0.00000000 0.00000000 0.54179757 13 | 0.33333333 0.66666667 0.02173231 14 | 0.66666667 0.33333333 0.52173231 15 | 0.33333333 0.66666667 0.30226121 16 | 0.66666667 0.33333333 0.80226121 17 | 0.16316165 0.32632331 0.82311378 18 | 0.83683835 0.67367669 0.32311378 19 | 0.32632331 0.16316165 0.32311378 20 | 0.83683835 0.16316165 0.32311378 21 | 0.16316165 0.83683835 0.82311378 22 | 0.67367669 0.83683835 0.82311378 23 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | `aviary` is an open-source machine learning library for materials discovery. 
It aims to be modular and easy to understand. 2 | 3 | Contributing Guidelines: 4 | 5 | 1. Code should be written in Python and adhere to PEP8 style guidelines. We provide `pre-commit` hooks and run pre-commit.ci on PRs. 6 | 1. All code contributions should be accompanied by unit tests. 7 | 1. Use descriptive and meaningful variable and function names. 8 | 1. Include docstrings for all functions and classes. 9 | 1. Keep the codebase simple and easy to understand. 10 | 1. Use Git and GitHub for version control and pull requests. 11 | 1. All contributions must be released under the MIT license. 12 | 1. Before submitting a pull request, make sure all tests pass and that your code is well-documented. 13 | 1. Avoid using external libraries unless they are necessary for the functionality of the library. 14 | 1. Follow the issues and pull requests to keep track of the development of the library and contribute where you can. Don't be afraid to ask for help! 15 | 16 | Thank you for your interest in contributing to `aviary`! We look forward to your contributions.
17 | -------------------------------------------------------------------------------- /examples/inputs/raw/0a58ef8a37185d8abcbd6e1bad7eaf7e3d3d6016.poscar: -------------------------------------------------------------------------------- 1 | -.13166155E+03 '015/ht.task.gamma--default.0a58ef8a37185d8abcbd6e1bad7eaf7e3d3d6016.cleanup.0.unclaimed.3.finished/ht.run.2015-11-08_20.24.54' ['(Tag) original structure: Mo6Se8Yb', '(Tag) original structure path: structures/all/matches/ternaries/148/4_148_1aA1cB1fC1fB.cif', '(Tag) type: ternary'] 2 | 1 3 | 5.112473 -0.000003 2.815429 4 | -2.556234 4.427533 2.815429 5 | -2.556238 -4.427530 2.815429 6 | Ti Zn N 7 | 6 1 8 8 | Direct 9 | 0.55169467 0.35282435 0.16914522 10 | 0.64717565 0.83085478 0.44830533 11 | 0.83085478 0.44830533 0.64717565 12 | 0.16914522 0.55169467 0.35282435 13 | 0.35282435 0.16914522 0.55169467 14 | 0.44830533 0.64717565 0.83085478 15 | 0.00000000 0.00000000 0.00000000 16 | 0.22107771 0.22107771 0.22107771 17 | 0.77892229 0.77892229 0.77892229 18 | 0.15262017 0.41165775 0.66079431 19 | 0.58834225 0.33920569 0.84737983 20 | 0.33920569 0.84737983 0.58834225 21 | 0.66079431 0.15262017 0.41165775 22 | 0.41165775 0.66079431 0.15262017 23 | 0.84737983 0.58834225 0.33920569 24 | -------------------------------------------------------------------------------- /examples/inputs/raw/0c9cbed1e5d755a9666db36f742004c14bf780bd.poscar: -------------------------------------------------------------------------------- 1 | -.12835112E+03 '016/ht.task.triolith--default.0c9cbed1e5d755a9666db36f742004c14bf780bd.cleanup.0.unclaimed.2.finished/ht.run.2015-11-25_19.44.22' ['(Tag) original structure: Bi2GeO5', '(Tag) type: ternary', '(Tag) original structure path: structures/all/matches/ternaries/9/8_9_1aA2aB5aC.cif'] 2 | 1 3 | 6.131598 -3.234940 0.226503 4 | 6.131598 3.234940 0.226503 5 | 0.246773 0.000000 5.643709 6 | Zr Zn N 7 | 4 2 10 8 | Direct 9 | 0.00267917 0.27212232 0.25091105 10 | 0.27212232 0.00267917 0.75091105 11 | 
0.53639722 0.15487041 0.20074018 12 | 0.15487041 0.53639722 0.70074018 13 | 0.41162369 0.66084098 0.29113314 14 | 0.66084098 0.41162369 0.79113314 15 | 0.30194670 0.26792173 0.94447680 16 | 0.26792173 0.30194670 0.44447680 17 | 0.84479035 0.67122354 0.09235619 18 | 0.67122354 0.84479035 0.59235619 19 | 0.89077500 0.81130119 0.05351553 20 | 0.81130119 0.89077500 0.55351553 21 | 0.22233666 0.95259651 0.11785524 22 | 0.95259651 0.22233666 0.61785524 23 | 0.36334386 0.66043067 0.64751187 24 | 0.66043067 0.36334386 0.14751187 25 | -------------------------------------------------------------------------------- /examples/inputs/raw/0079b9e3ab0e210ea4afe037a9faee8cd6577922.poscar: -------------------------------------------------------------------------------- 1 | -.12677341E+03 '016/ht.task.triolith--default.0079b9e3ab0e210ea4afe037a9faee8cd6577922.cleanup.0.unclaimed.2.finished/ht.run.2015-12-13_17.30.39' ['(Tag) original structure: LiO5V2', '(Tag) original structure path: structures/all/matches/ternaries/31/8_31_1aA2aB5aC.cif', '(Tag) type: ternary'] 2 | 1 3 | 12.156346 0.000000 0.000000 4 | 0.000000 4.319613 0.000000 5 | 0.000000 0.000000 4.222262 6 | Hf N Zn 7 | 4 10 2 8 | Direct 9 | 0.58531349 0.50000000 0.16750956 10 | 0.08531349 0.00000000 0.83249044 11 | 0.40614012 0.00000000 0.04123820 12 | 0.90614012 0.50000000 0.95876180 13 | 0.64872327 0.50000000 0.63718450 14 | 0.14872327 0.00000000 0.36281550 15 | 0.39480052 0.00000000 0.54485042 16 | 0.89480052 0.50000000 0.45514958 17 | 0.42143717 0.50000000 0.07644136 18 | 0.92143717 0.00000000 0.92355864 19 | 0.58993952 0.00000000 0.09266172 20 | 0.08993952 0.50000000 0.90733828 21 | 0.23293583 0.00000000 0.16108981 22 | 0.73293583 0.50000000 0.83891019 23 | 0.19201007 0.50000000 0.28261666 24 | 0.69201007 0.00000000 0.71738334 25 | -------------------------------------------------------------------------------- /examples/inputs/raw/00bbba1d86168727008d4c1542b39d45d93d693c.poscar: 
-------------------------------------------------------------------------------- 1 | -.91305185E+02 '016/ht.task.triolith--default.00bbba1d86168727008d4c1542b39d45d93d693c.cleanup.0.unclaimed.2.finished/ht.run.2015-12-13_02.31.04' ['(Tag) original structure: O4Pd3Sr', '(Tag) original structure path: structures/all/matches/ternaries/223/3_223_1aA1cB1eC.cif', '(Tag) type: ternary'] 2 | 1 3 | 5.629202 0.000000 0.000000 4 | 0.000000 5.629202 0.000000 5 | 0.000000 0.000000 5.629202 6 | Hf Zn N 7 | 2 6 8 8 | Direct 9 | 0.00000000 0.00000000 0.00000000 10 | 0.50000000 0.50000000 0.50000000 11 | 0.25000000 0.00000000 0.50000000 12 | 0.50000000 0.75000000 0.00000000 13 | 0.00000000 0.50000000 0.25000000 14 | 0.00000000 0.50000000 0.75000000 15 | 0.50000000 0.25000000 0.00000000 16 | 0.75000000 0.00000000 0.50000000 17 | 0.25000000 0.25000000 0.25000000 18 | 0.25000000 0.75000000 0.75000000 19 | 0.75000000 0.75000000 0.75000000 20 | 0.25000000 0.75000000 0.25000000 21 | 0.25000000 0.25000000 0.75000000 22 | 0.75000000 0.25000000 0.75000000 23 | 0.75000000 0.75000000 0.25000000 24 | 0.75000000 0.25000000 0.25000000 25 | -------------------------------------------------------------------------------- /examples/inputs/raw/0a3eb2074380f4abdec94d4a6171a099c9ecb30a.poscar: -------------------------------------------------------------------------------- 1 | -.10306186E+03 '016/ht.task.gamma--default.0a3eb2074380f4abdec94d4a6171a099c9ecb30a.cleanup.0.unclaimed.3.finished/ht.run.2015-11-09_09.55.13' ['(Tag) original structure: Fe2O5Ti', '(Tag) original structure path: structures/all/matches/ternaries/63/5_63_1cA1cB1fC2fA.cif', '(Tag) type: ternary'] 2 | 1 3 | 1.887732 -8.293794 -0.000000 4 | 1.887732 8.293794 0.000000 5 | 0.000000 -0.000000 10.903555 6 | Ti Zn N 7 | 2 4 10 8 | Direct 9 | 0.77771064 0.22228936 0.25000000 10 | 0.22228936 0.77771064 0.75000000 11 | 0.76640818 0.23359182 0.58357468 12 | 0.23359182 0.76640818 0.08357468 13 | 0.23359182 0.76640818 0.41642532 14 | 
0.76640818 0.23359182 0.91642532 15 | 0.31288974 0.68711026 0.25000000 16 | 0.68711026 0.31288974 0.75000000 17 | 0.00945486 0.99054514 0.19890372 18 | 0.99054514 0.00945486 0.69890372 19 | 0.99054514 0.00945486 0.80109628 20 | 0.00945486 0.99054514 0.30109628 21 | 0.71261595 0.28738405 0.10220736 22 | 0.28738405 0.71261595 0.60220736 23 | 0.28738405 0.71261595 0.89779264 24 | 0.71261595 0.28738405 0.39779264 25 | -------------------------------------------------------------------------------- /examples/inputs/raw/0a646e174268b493bd67f58fdf52fead02f74a96.poscar: -------------------------------------------------------------------------------- 1 | -.11595836E+03 '016/ht.task.gamma--default.0a646e174268b493bd67f58fdf52fead02f74a96.cleanup.0.unclaimed.3.finished/ht.run.2016-06-23_09.53.03' ['(Tag) original structure: Ba3O9Yb4', '(Tag) original structure path: structures/all/matches/ternaries/146/10_146_3aA3bB4aC.cif', '(Tag) type: ternary'] 2 | 1 3 | 3.166867 -0.000000 6.757404 4 | -1.583433 2.742587 6.757404 5 | -1.583433 -2.742587 6.757404 6 | Zn Ti N 7 | 3 4 9 8 | Direct 9 | 0.01183712 0.01183712 0.01183712 10 | 0.17509278 0.17509278 0.17509278 11 | 0.58822661 0.58822661 0.58822661 12 | 0.44824029 0.44824029 0.44824029 13 | 0.88012936 0.88012936 0.88012936 14 | 0.73794629 0.73794629 0.73794629 15 | 0.30888986 0.30888986 0.30888986 16 | 0.22038392 0.96628227 0.59433793 17 | 0.96628227 0.59433793 0.22038392 18 | 0.59433793 0.22038392 0.96628227 19 | 0.10520064 0.47388568 0.80909796 20 | 0.47388568 0.80909796 0.10520064 21 | 0.80909796 0.10520064 0.47388568 22 | 0.07908029 0.37863310 0.71223592 23 | 0.37863310 0.71223592 0.07908029 24 | 0.71223592 0.07908029 0.37863310 25 | -------------------------------------------------------------------------------- /examples/inputs/raw/00ceac1a8da8e554218a84def07e5a6da9f106f3.poscar: -------------------------------------------------------------------------------- 1 | -.14491348E+03 
'016/ht.task.triolith--default.00ceac1a8da8e554218a84def07e5a6da9f106f3.cleanup.0.unclaimed.2.finished/ht.run.2015-12-13_14.53.57' ['(Tag) original structure: CsI4Li3', '(Tag) original structure path: structures/all/matches/ternaries/6/16_6_2bA3aB3bC3bB5aC.cif', '(Tag) type: ternary'] 2 | 1 3 | 8.083094 0.000000 -0.208196 4 | 0.000000 3.746822 0.000000 5 | 0.280916 0.000000 8.346883 6 | Zn Hf N 7 | 2 6 8 8 | Direct 9 | 0.11519925 0.50000000 0.14812181 10 | 0.52623739 0.50000000 0.32004718 11 | 0.07390019 0.50000000 0.78474769 12 | 0.44698331 0.50000000 0.97366906 13 | 0.37200353 0.00000000 0.59577805 14 | 0.97106300 0.00000000 0.44002485 15 | 0.70070504 0.50000000 0.65161885 16 | 0.77589491 0.00000000 0.00494085 17 | 0.50521321 0.00000000 0.05573442 18 | 0.33683218 0.50000000 0.73258869 19 | 0.93708398 0.50000000 0.56442548 20 | 0.70583489 0.50000000 0.89905921 21 | 0.99245769 0.00000000 0.84790603 22 | 0.21380499 0.00000000 0.42136796 23 | 0.89843228 0.00000000 0.21193260 24 | 0.62841745 0.00000000 0.58627726 25 | -------------------------------------------------------------------------------- /aviary/embeddings/element/README.md: -------------------------------------------------------------------------------- 1 | # Element - Embeddings 2 | 3 | ## CGCNN92 4 | 5 | The following paper describes the details of the CGCNN framework. Information about how the one-hot atom embedding was constructed is available in its supplementary materials: 6 | 7 | > [Crystal Graph Convolutional Neural Networks for an Accurate and Interpretable Prediction of Material Properties](https://link.aps.org/doi/10.1103/PhysRevLett.120.145301) 8 | 9 | ## MEGNet16 10 | 11 | The following paper describes the details of the MEGNet framework. The embedding is generated from the atomic weights through one MEGNet layer on a training task to predict the computed formation energies of ∼69,000 materials from the Materials Project database: 12 | 13 | > [Graph Networks as a Universal Machine Learning
Framework for Molecules and Crystals](https://arxiv.org/abs/1812.05055) 14 | 15 | ## MatScholar200 16 | 17 | This is an experimental NLP embedding based on the data mining of definite compositions and structure prototypes. 18 | 19 | > [Unsupervised word embeddings capture latent knowledge from materials science literature](https://www.nature.com/articles/s41586-019-1335-8) 20 | 21 | ## Onehot112 22 | 23 | This is a simple one-hot encoding for the elements 24 | -------------------------------------------------------------------------------- /examples/inputs/raw/0a5af69c21e6c2d9b004c2f51939790b10b19814.poscar: -------------------------------------------------------------------------------- 1 | -.12124245E+03 '017/ht.task.gamma--default.0a5af69c21e6c2d9b004c2f51939790b10b19814.cleanup.0.unclaimed.3.finished/ht.run.2016-06-11_04.24.19' ['(Tag) original structure: As5Cs3O9', '(Tag) original structure path: structures/all/matches/ternaries/157/5_157_1bA1cA1cB1cC1dC.cif', '(Tag) type: ternary'] 2 | 1 3 | 6.776201 -3.912241 0.000000 4 | 0.000000 7.824483 0.000000 5 | 0.000000 0.000000 4.868194 6 | Zn Ti N 7 | 3 5 9 8 | Direct 9 | 0.43304785 1.00000000 0.90142116 10 | 0.56695215 0.56695215 0.90142116 11 | 0.00000000 0.43304785 0.90142116 12 | 0.73676343 1.00000000 0.46239314 13 | 0.26323657 0.26323657 0.46239314 14 | 0.00000000 0.73676343 0.46239314 15 | 0.33333333 0.66666667 0.48072680 16 | 0.66666667 0.33333333 0.48072680 17 | 0.54692646 0.00000000 0.24748564 18 | 0.45307354 0.45307354 0.24748564 19 | 0.00000000 0.54692646 0.24748564 20 | 0.20794832 0.42235553 0.68279110 21 | 0.21440721 0.79205168 0.68279110 22 | 0.57764447 0.78559279 0.68279110 23 | 0.79205168 0.21440721 0.68279110 24 | 0.42235553 0.20794832 0.68279110 25 | 0.78559279 0.57764447 0.68279110 26 | -------------------------------------------------------------------------------- /examples/inputs/raw/0a0ae9c18de23e72bb764e10ef542de693321542.poscar: 
-------------------------------------------------------------------------------- 1 | -.12758363E+03 '018/ht.task.gamma--default.0a0ae9c18de23e72bb764e10ef542de693321542.cleanup.1.unclaimed.3.finished/ht.run.2016-05-19_22.39.46' ['(Tag) original structure: In2O5Sr2', '(Tag) original structure path: structures/all/matches/ternaries/46/6_46_1aA1bA1bB1cC2cB.cif', '(Tag) type: ternary'] 2 | 1 3 | 5.380653 0.000000 0.000000 4 | -0.000000 14.311422 0.000000 5 | 2.690327 7.155711 2.712897 6 | Zn Ti N 7 | 4 4 10 8 | Direct 9 | 0.59605935 0.66971651 0.94546880 10 | 0.45847185 0.16971651 0.94546880 11 | 0.45847185 0.38481469 0.94546880 12 | 0.59605935 0.88481469 0.94546880 13 | 0.99066415 0.99066415 0.01867171 14 | 0.99066415 0.49066415 0.01867171 15 | 0.93415690 0.25680447 0.98639105 16 | 0.07945204 0.75680447 0.98639105 17 | 0.23781504 0.81687590 0.31447074 18 | 0.44771422 0.31687590 0.31447074 19 | 0.44771422 0.86865337 0.31447074 20 | 0.23781504 0.36865337 0.31447074 21 | 0.97010020 0.02841900 0.20962174 22 | 0.82027806 0.52841900 0.20962174 23 | 0.82027806 0.76195926 0.20962174 24 | 0.97010020 0.26195926 0.20962174 25 | 0.26244227 0.61609265 0.26781470 26 | 0.46974303 0.11609265 0.26781470 27 | -------------------------------------------------------------------------------- /examples/inputs/raw/0c450e6f9c256e33ef452c98dd20917fb6484317.poscar: -------------------------------------------------------------------------------- 1 | -.11423502E+03 '018/ht.task.triolith--default.0c450e6f9c256e33ef452c98dd20917fb6484317.cleanup.0.unclaimed.2.finished/ht.run.2015-11-23_07.40.55' ['(Tag) original structure: Ca3Ga2N4', '(Tag) type: ternary', '(Tag) original structure path: structures/all/matches/ternaries/15/5_15_1eA1fA1fB2fC.cif'] 2 | 1 3 | 4.850647 -4.711321 -0.214326 4 | 4.850647 4.711321 -0.214326 5 | -0.377966 0.000000 5.614459 6 | Zn Zr N 7 | 6 4 8 8 | Direct 9 | 0.93947997 0.24500985 0.59229075 10 | 0.75499015 0.06052003 0.90770925 11 | 0.06052003 0.75499015 0.40770925 12 | 
0.24500985 0.93947997 0.09229075 13 | 0.65515963 0.34484037 0.25000000 14 | 0.34484037 0.65515963 0.75000000 15 | 0.18630163 0.38379918 0.12077952 16 | 0.61620082 0.81369837 0.37922048 17 | 0.81369837 0.61620082 0.87922048 18 | 0.38379918 0.18630163 0.62077952 19 | 0.88219627 0.32914320 0.00176296 20 | 0.67085680 0.11780373 0.49823704 21 | 0.11780373 0.67085680 0.99823704 22 | 0.32914320 0.88219627 0.50176296 23 | 0.24930985 0.44271839 0.48301702 24 | 0.55728161 0.75069015 0.01698298 25 | 0.75069015 0.55728161 0.51698298 26 | 0.44271839 0.24930985 0.98301702 27 | -------------------------------------------------------------------------------- /examples/inputs/raw/0059fd6dd1d17138c42bb5430fa0bc074053f0de.poscar: -------------------------------------------------------------------------------- 1 | -.12205554E+03 '018/ht.task.triolith--default.0059fd6dd1d17138c42bb5430fa0bc074053f0de.cleanup.0.unclaimed.2.finished/ht.run.2016-04-29_04.58.57' ['(Tag) original structure: Cl6K2Pu', '(Tag) original structure path: structures/all/matches/ternaries/10/8_10_1aA1hA1mB1mC1nB1nC2oB.cif', '(Tag) type: ternary'] 2 | 1 3 | 6.517271 0.000000 -1.602318 4 | 0.000000 5.261389 0.000000 5 | -2.700464 0.000000 6.804575 6 | Zn Hf N 7 | 4 2 12 8 | Direct 9 | 0.53948074 -0.00000000 0.19307466 10 | 0.46051926 -0.00000000 0.80692534 11 | 0.03938539 0.50000000 0.69311581 12 | 0.96061461 0.50000000 0.30688419 13 | 0.00000000 -0.00000000 0.00000000 14 | 0.50000000 0.50000000 0.50000000 15 | 0.19037601 -0.00000000 0.37098077 16 | 0.80962399 -0.00000000 0.62901923 17 | 0.69031429 0.50000000 0.87094031 18 | 0.30968571 0.50000000 0.12905969 19 | 0.24131655 0.28439444 0.00896358 20 | 0.75868345 0.28439444 0.99103642 21 | 0.75868345 0.71560556 0.99103642 22 | 0.24131655 0.71560556 0.00896358 23 | 0.74128896 0.78441677 0.50890356 24 | 0.25871104 0.78441677 0.49109644 25 | 0.25871104 0.21558323 0.49109644 26 | 0.74128896 0.21558323 0.50890356 27 | 
-------------------------------------------------------------------------------- /examples/inputs/raw/00fa7d75b7b2e1442df0a68071210c8d64b93ebe.poscar: -------------------------------------------------------------------------------- 1 | -.15718967E+03 '018/ht.task.triolith--default.00fa7d75b7b2e1442df0a68071210c8d64b93ebe.cleanup.0.unclaimed.2.finished/ht.run.2016-04-29_06.09.22' ['(Tag) original structure: Ag5RbSe3', '(Tag) original structure path: structures/all/matches/ternaries/125/5_125_1aA1cB1dC1gB1mA.cif', '(Tag) type: ternary'] 2 | 1 3 | 5.828845 0.000000 0.000000 4 | -0.000000 5.828845 0.000000 5 | -0.000000 0.000000 9.239161 6 | Zn N Hf 7 | 2 6 10 8 | Direct 9 | 0.75000000 0.25000000 0.50000000 10 | 0.25000000 0.75000000 0.50000000 11 | 0.75000000 0.25000000 0.00000000 12 | 0.25000000 0.75000000 0.00000000 13 | 0.25000000 0.25000000 0.23728861 14 | 0.25000000 0.25000000 0.76271139 15 | 0.75000000 0.75000000 0.23728861 16 | 0.75000000 0.75000000 0.76271139 17 | 0.25000000 0.25000000 0.00000000 18 | 0.75000000 0.75000000 -0.00000000 19 | 0.57763663 0.07763663 0.25057902 20 | 0.92236337 0.42236337 0.25057902 21 | 0.57763663 0.42236337 0.74942098 22 | 0.07763663 0.92236337 0.25057902 23 | 0.92236337 0.07763663 0.74942098 24 | 0.42236337 0.92236337 0.74942098 25 | 0.07763663 0.57763663 0.74942098 26 | 0.42236337 0.57763663 0.25057902 27 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | # Install these hooks with `pre-commit install`. 
2 | 3 | ci: 4 | autoupdate_schedule: quarterly 5 | 6 | repos: 7 | - repo: https://github.com/astral-sh/ruff-pre-commit 8 | rev: v0.13.3 9 | hooks: 10 | - id: ruff 11 | args: [--fix] 12 | - id: ruff-format 13 | 14 | - repo: https://github.com/pre-commit/pre-commit-hooks 15 | rev: v6.0.0 16 | hooks: 17 | - id: check-case-conflict 18 | - id: check-symlinks 19 | - id: check-yaml 20 | - id: destroyed-symlinks 21 | - id: end-of-file-fixer 22 | - id: mixed-line-ending 23 | - id: trailing-whitespace 24 | 25 | - repo: https://github.com/codespell-project/codespell 26 | rev: v2.4.1 27 | hooks: 28 | - id: codespell 29 | exclude_types: [json] 30 | args: [--check-filenames] 31 | 32 | - repo: https://github.com/pre-commit/mirrors-mypy 33 | rev: v1.18.2 34 | hooks: 35 | - id: mypy 36 | exclude: (tests|examples)/ 37 | additional_dependencies: [wandb] 38 | 39 | - repo: https://github.com/janosh/format-ipy-cells 40 | rev: v0.1.11 41 | hooks: 42 | - id: format-ipy-cells 43 | 44 | - repo: https://github.com/kynan/nbstripout 45 | rev: 0.8.1 46 | hooks: 47 | - id: nbstripout 48 | args: [--drop-empty-cells] 49 | -------------------------------------------------------------------------------- /examples/inputs/raw/0cd1f1670176b03cc3d61935d9a6882c9508e33a.poscar: -------------------------------------------------------------------------------- 1 | -.16820589E+03 '020/ht.task.triolith--default.0cd1f1670176b03cc3d61935d9a6882c9508e33a.cleanup.0.unclaimed.4.finished/ht.run.2015-11-27_20.14.49' ['(Tag) original structure: Er3Se6Sm', '(Tag) type: ternary', '(Tag) original structure path: structures/all/matches/ternaries/11/10_11_1eA3eB6eC.cif'] 2 | 1 3 | 8.161179 0.000000 0.449837 4 | 0.000000 3.324704 0.000000 5 | -2.371431 0.000000 9.270920 6 | Zr Zn N 7 | 6 2 12 8 | Direct 9 | 0.94849747 0.25000000 0.32956364 10 | 0.05150253 0.75000000 0.67043636 11 | 0.18106850 0.25000000 0.03690071 12 | 0.81893150 0.75000000 0.96309929 13 | 0.67614681 0.25000000 0.61333077 14 | 0.32385319 0.75000000 
0.38666923 15 | 0.53082596 0.25000000 0.12658890 16 | 0.46917404 0.75000000 0.87341110 17 | 0.92053932 0.25000000 0.58165729 18 | 0.07946068 0.75000000 0.41834271 19 | 0.30285053 0.25000000 0.68482012 20 | 0.69714947 0.75000000 0.31517988 21 | 0.94853790 0.25000000 0.11106725 22 | 0.05146210 0.75000000 0.88893275 23 | 0.63551390 0.25000000 0.95343270 24 | 0.36448610 0.75000000 0.04656730 25 | 0.37209085 0.25000000 0.57851740 26 | 0.62790915 0.75000000 0.42148260 27 | 0.29515767 0.25000000 0.26471150 28 | 0.70484233 0.75000000 0.73528850 29 | -------------------------------------------------------------------------------- /examples/inputs/raw/00237cf70206a5de77e5154ea72633514cdd4445.poscar: -------------------------------------------------------------------------------- 1 | -.13765617E+03 '020/ht.task.triolith--default.00237cf70206a5de77e5154ea72633514cdd4445.cleanup.0.unclaimed.2.finished/ht.run.2015-12-27_15.11.38' ['(Tag) original structure: Cl7CsTi2', '(Tag) original structure path: structures/all/matches/ternaries/11/7_11_1eA1fB2fC3eC.cif', '(Tag) type: ternary'] 2 | 1 3 | 4.086823 0.000000 0.232607 4 | 0.000000 9.465448 0.000000 5 | -0.059997 0.000000 6.856863 6 | Hf N Zn 7 | 2 14 4 8 | Direct 9 | 0.18968935 0.25000000 0.17660677 10 | 0.81031065 0.75000000 0.82339323 11 | 0.66980532 0.25000000 0.30672720 12 | 0.33019468 0.75000000 0.69327280 13 | 0.68756283 0.25000000 0.47916256 14 | 0.31243717 0.75000000 0.52083744 15 | 0.27122645 0.25000000 0.89843441 16 | 0.72877355 0.75000000 0.10156559 17 | 0.21180647 0.03552534 0.27936006 18 | 0.78819353 0.96447466 0.72063994 19 | 0.78819353 0.53552534 0.72063994 20 | 0.21180647 0.46447466 0.27936006 21 | 0.14015442 0.02839410 0.47485170 22 | 0.85984558 0.97160590 0.52514830 23 | 0.85984558 0.52839410 0.52514830 24 | 0.14015442 0.47160590 0.47485170 25 | 0.36190174 0.08645613 0.74088626 26 | 0.63809826 0.91354387 0.25911374 27 | 0.63809826 0.58645613 0.25911374 28 | 0.36190174 0.41354387 0.74088626 29 | 
-------------------------------------------------------------------------------- /examples/inputs/raw/0c882a0621189d18ff1350e3bed4114868663acf.poscar: -------------------------------------------------------------------------------- 1 | -.17215809E+03 '020/ht.task.triolith--default.0c882a0621189d18ff1350e3bed4114868663acf.cleanup.0.unclaimed.4.finished/ht.run.2016-06-04_05.34.00' ['(Tag) original structure: I5NiPr4', '(Tag) type: ternary', '(Tag) original structure path: structures/all/matches/ternaries/59/7_59_1aA1aB1eC2bC2eA.cif'] 2 | 1 3 | 3.240560 0.000000 0.000000 4 | 0.000000 12.886769 0.000000 5 | 0.000000 0.000000 7.790818 6 | Zr N Zn 7 | 8 10 2 8 | Direct 9 | 0.25000000 0.75000000 0.55699214 10 | 0.75000000 0.25000000 0.44300786 11 | 0.25000000 0.75000000 0.16926273 12 | 0.75000000 0.25000000 0.83073727 13 | 0.25000000 0.44772047 0.15767431 14 | 0.75000000 0.94772047 0.84232569 15 | 0.75000000 0.55227953 0.84232569 16 | 0.25000000 0.05227953 0.15767431 17 | 0.25000000 0.60748537 0.68873064 18 | 0.75000000 0.10748537 0.31126936 19 | 0.75000000 0.39251463 0.31126936 20 | 0.25000000 0.89251463 0.68873064 21 | 0.25000000 0.59581971 0.05569449 22 | 0.75000000 0.09581971 0.94430551 23 | 0.75000000 0.40418029 0.94430551 24 | 0.25000000 0.90418029 0.05569449 25 | 0.25000000 0.25000000 0.64078793 26 | 0.75000000 0.75000000 0.35921207 27 | 0.25000000 0.25000000 0.13097308 28 | 0.75000000 0.75000000 0.86902692 29 | -------------------------------------------------------------------------------- /examples/inputs/raw/0ad8980cb8eb0693647f411732b2165b6a0fce0e.poscar: -------------------------------------------------------------------------------- 1 | -.16275965E+03 '020/ht.task.gamma--default.0ad8980cb8eb0693647f411732b2165b6a0fce0e.cleanup.0.unclaimed.3.finished/ht.run.2016-04-12_08.07.42' ['(Tag) original structure: Mn7O12Pr', '(Tag) original structure path: structures/all/matches/ternaries/12/10_12_1aA1bB1cB1dB1eB1fB2iC2jC.cif', '(Tag) type: ternary'] 2 | 1 3 | 
7.625926 -0.000000 0.076043 4 | -0.000000 7.580728 -0.000000 5 | 3.767819 3.790364 3.834359 6 | Zn Ti N 7 | 1 7 12 8 | Direct 9 | 0.00000000 0.00000000 0.00000000 10 | 0.50000000 0.50000000 0.00000000 11 | 0.50000000 0.00000000 0.00000000 12 | 0.00000000 0.50000000 0.00000000 13 | 0.50000000 0.50000000 0.50000000 14 | 0.50000000 0.00000000 0.50000000 15 | 0.00000000 0.00000000 0.50000000 16 | 0.00000000 0.50000000 0.50000000 17 | 0.17462695 0.86728018 0.65081407 18 | 0.82537305 0.51809424 0.34918593 19 | 0.82537305 0.13271982 0.34918593 20 | 0.17462695 0.48190576 0.65081407 21 | 0.50337766 0.68510006 0.62979988 22 | 0.49662234 0.31489994 0.37020012 23 | 0.86607696 0.69069681 0.61860638 24 | 0.13392304 0.30930319 0.38139362 25 | 0.68829538 0.82298879 0.99962863 26 | 0.31170462 0.82261742 0.00037137 27 | 0.31170462 0.17701121 0.00037137 28 | 0.68829538 0.17738258 0.99962863 29 | -------------------------------------------------------------------------------- /examples/inputs/raw/009ee0451531276f598976f46848fe7d590262ca.poscar: -------------------------------------------------------------------------------- 1 | -.17482530E+03 '021/ht.task.triolith--default.009ee0451531276f598976f46848fe7d590262ca.cleanup.0.unclaimed.2.finished/ht.run.2016-01-05_18.33.01' ['(Tag) original structure: Li2O13V6', '(Tag) original structure path: structures/all/matches/ternaries/12/11_12_1bA1iB3iC6iA.cif', '(Tag) type: ternary'] 2 | 1 3 | 6.835117 -2.032658 -0.515824 4 | 6.835117 2.032658 -0.515824 5 | -3.147213 0.000000 11.027458 6 | Hf N Zn 7 | 6 13 2 8 | Direct 9 | 0.34297942 0.34297942 0.99965321 10 | 0.65702058 0.65702058 0.00034679 11 | 0.39485957 0.39485957 0.38058576 12 | 0.60514043 0.60514043 0.61941424 13 | 0.69607843 0.69607843 0.38139415 14 | 0.30392157 0.30392157 0.61860585 15 | 0.19897726 0.19897726 0.04469716 16 | 0.80102274 0.80102274 0.95530284 17 | 0.84879528 0.84879528 0.34182287 18 | 0.15120472 0.15120472 0.65817713 19 | 0.23924638 0.23924638 0.40916785 20 | 0.76075362 
0.76075362 0.59083215 21 | 0.50000000 0.50000000 0.00000000 22 | 0.39982928 0.39982928 0.19561494 23 | 0.60017072 0.60017072 0.80438506 24 | 0.62078394 0.62078394 0.18012629 25 | 0.37921606 0.37921606 0.81987371 26 | 0.55627083 0.55627083 0.41837015 27 | 0.44372917 0.44372917 0.58162985 28 | 0.11671079 0.11671079 0.16606411 29 | 0.88328921 0.88328921 0.83393589 30 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: Tests 2 | 3 | on: 4 | push: 5 | paths: ["**/*.py", ".github/workflows/test.yml"] 6 | branches: [main] 7 | pull_request: 8 | paths: ["**/*.py", ".github/workflows/test.yml"] 9 | branches: [main] 10 | release: 11 | types: [published] 12 | workflow_dispatch: 13 | workflow_call: 14 | 15 | concurrency: 16 | # Cancel only on same PR number 17 | group: ${{ github.workflow }}-pr-${{ github.event.pull_request.number }} 18 | cancel-in-progress: true 19 | 20 | jobs: 21 | tests: 22 | strategy: 23 | fail-fast: false 24 | matrix: 25 | os: [ubuntu-latest, macos-latest, windows-latest] 26 | version: 27 | - { python: "3.10", resolution: highest } 28 | - { python: "3.12", resolution: lowest-direct } 29 | runs-on: ${{ matrix.os }} 30 | 31 | steps: 32 | - name: Check out repo 33 | uses: actions/checkout@v4 34 | 35 | - name: Set up Python 36 | uses: actions/setup-python@v5 37 | with: 38 | python-version: ${{ matrix.version.python }} 39 | 40 | - name: Set up uv 41 | uses: astral-sh/setup-uv@v2 42 | 43 | - name: Install dependencies 44 | run: | 45 | uv pip install torch --index-url https://download.pytorch.org/whl/cpu --system 46 | uv pip install .[test] --system 47 | 48 | - name: Run Tests 49 | run: pytest --capture=no --cov . 
50 | -------------------------------------------------------------------------------- /examples/inputs/raw/0c5ffa8e2b3126eba02901c37b31e8faf3b5377c.poscar: -------------------------------------------------------------------------------- 1 | -.16882837E+03 '023/ht.task.triolith--default.0c5ffa8e2b3126eba02901c37b31e8faf3b5377c.cleanup.0.unclaimed.4.finished/ht.run.2016-06-02_02.38.13' ['(Tag) original structure: B6O13Zn4', '(Tag) type: ternary', '(Tag) original structure path: structures/all/matches/ternaries/217/4_217_1aA1cB1dC1gA.cif'] 2 | 1 3 | -4.856584 4.856584 4.856584 4 | 4.856584 -4.856584 4.856584 5 | 4.856584 4.856584 -4.856584 6 | Zn Zr N 7 | 4 6 13 8 | Direct 9 | 0.42585440 0.42585440 0.42585440 10 | -0.00000000 -0.00000000 0.57414560 11 | -0.00000000 0.57414560 -0.00000000 12 | 0.57414560 -0.00000000 -0.00000000 13 | 0.50000000 0.25000000 0.75000000 14 | 0.50000000 0.75000000 0.25000000 15 | 0.75000000 0.50000000 0.25000000 16 | 0.75000000 0.25000000 0.50000000 17 | 0.25000000 0.75000000 0.50000000 18 | 0.25000000 0.50000000 0.75000000 19 | -0.00000000 -0.00000000 -0.00000000 20 | 0.51838322 0.51838322 0.27053201 21 | 0.24785121 0.24785121 0.72946799 22 | 0.24785121 0.72946799 0.24785121 23 | 0.00000000 0.48161678 0.75214879 24 | 0.27053201 0.51838322 0.51838322 25 | 0.51838322 0.27053201 0.51838322 26 | -0.00000000 0.75214879 0.48161678 27 | 0.75214879 0.48161678 -0.00000000 28 | 0.48161678 0.75214879 0.00000000 29 | 0.48161678 -0.00000000 0.75214879 30 | 0.75214879 0.00000000 0.48161678 31 | 0.72946799 0.24785121 0.24785121 32 | -------------------------------------------------------------------------------- /examples/inputs/raw/0a458b5f2eeebeb0e4e55e1d1a17988ec87b7123.poscar: -------------------------------------------------------------------------------- 1 | -.14108066E+03 '024/ht.task.gamma--default.0a458b5f2eeebeb0e4e55e1d1a17988ec87b7123.cleanup.0.unclaimed.3.finished/ht.run.2016-05-01_14.31.14' ['(Tag) original structure: O3Rb2Ti', '(Tag) 
original structure path: structures/all/matches/ternaries/64/6_64_1eA1fB2fA2fC.cif', '(Tag) type: ternary'] 2 | 1 3 | 2.889034 -4.949055 -0.000000 4 | 2.889034 4.949055 0.000000 5 | 0.000000 0.000000 10.531082 6 | Zn Ti N 7 | 8 4 12 8 | Direct 9 | 0.94903530 0.05096470 0.16283211 10 | 0.55096470 0.44903530 0.66283211 11 | 0.44903530 0.55096470 0.33716789 12 | 0.05096470 0.94903530 0.83716789 13 | 0.63746097 0.36253903 0.10209048 14 | 0.86253903 0.13746097 0.60209048 15 | 0.13746097 0.86253903 0.39790952 16 | 0.36253903 0.63746097 0.89790952 17 | 0.75230031 0.24769969 0.35600488 18 | 0.74769969 0.25230031 0.85600488 19 | 0.25230031 0.74769969 0.14399512 20 | 0.24769969 0.75230031 0.64399512 21 | 0.67082105 0.32917895 0.51925674 22 | 0.82917895 0.17082105 0.01925674 23 | 0.17082105 0.82917895 0.98074326 24 | 0.32917895 0.67082105 0.48074326 25 | 0.94589238 0.05410762 0.35185694 26 | 0.55410762 0.44589238 0.85185694 27 | 0.44589238 0.55410762 0.14814306 28 | 0.05410762 0.94589238 0.64814306 29 | 0.93929385 0.56070615 0.25000000 30 | 0.56070615 0.93929385 0.75000000 31 | 0.43929385 0.06070615 0.25000000 32 | 0.06070615 0.43929385 0.75000000 33 | -------------------------------------------------------------------------------- /examples/inputs/raw/0ac8a65af75e8223b07dcd98c6958db7af6504cf.poscar: -------------------------------------------------------------------------------- 1 | -.19471572E+03 '024/ht.task.gamma--default.0ac8a65af75e8223b07dcd98c6958db7af6504cf.cleanup.0.unclaimed.3.finished/ht.run.2016-06-10_22.43.42' ['(Tag) original structure: CaGa4O7', '(Tag) original structure path: structures/all/matches/ternaries/15/7_15_1eA1eB2fC3fB.cif', '(Tag) type: ternary'] 2 | 1 3 | 6.730469 -4.584523 -0.342761 4 | 6.730469 4.584523 -0.342761 5 | -1.806793 0.000000 5.356055 6 | Zn Ti N 7 | 2 8 14 8 | Direct 9 | 0.71413180 0.28586820 0.25000000 10 | 0.28586820 0.71413180 0.75000000 11 | 0.06304672 0.16684276 0.66757928 12 | 0.83315724 0.93695328 0.83242072 13 | 0.93695328 
0.83315724 0.33242072 14 | 0.16684276 0.06304672 0.16757928 15 | 0.26619824 0.44058687 0.15513109 16 | 0.55941313 0.73380176 0.34486891 17 | 0.73380176 0.55941313 0.84486891 18 | 0.44058687 0.26619824 0.65513109 19 | 0.13372099 0.26871112 0.03056721 20 | 0.73128888 0.86627901 0.46943279 21 | 0.86627901 0.73128888 0.96943279 22 | 0.26871112 0.13372099 0.53056721 23 | 0.31540762 0.48366840 0.86537518 24 | 0.51633160 0.68459238 0.63462482 25 | 0.68459238 0.51633160 0.13462482 26 | 0.48366840 0.31540762 0.36537518 27 | 0.84326964 0.35643738 0.60951799 28 | 0.64356262 0.15673036 0.89048201 29 | 0.15673036 0.64356262 0.39048201 30 | 0.35643738 0.84326964 0.10951799 31 | 0.95866177 0.04133823 0.25000000 32 | 0.04133823 0.95866177 0.75000000 33 | -------------------------------------------------------------------------------- /examples/inputs/raw/0a130b596d1ce7bfe317d9e546bbeacc703fef65.poscar: -------------------------------------------------------------------------------- 1 | -.18818657E+03 '024/ht.task.gamma--default.0a130b596d1ce7bfe317d9e546bbeacc703fef65.cleanup.0.unclaimed.3.finished/ht.run.2015-11-10_03.32.23' ['(Tag) original structure: NaO8Ta3', '(Tag) original structure path: structures/all/matches/ternaries/72/6_72_1aA1bB1jB1kC2jC.cif', '(Tag) type: ternary'] 2 | 1 3 | 6.383894 0.000000 -0.000000 4 | -0.000000 7.842330 0.000000 5 | 3.191947 3.921165 4.697699 6 | Ti Zn N 7 | 6 2 16 8 | Direct 9 | 0.25000000 0.75000000 0.50000000 10 | 0.75000000 0.25000000 0.50000000 11 | 0.26264823 0.24098197 0.00000000 12 | 0.23735177 0.74098197 1.00000000 13 | 0.76264823 0.25901803 0.00000000 14 | 0.73735177 0.75901803 1.00000000 15 | 0.75000000 0.75000000 0.50000000 16 | 0.25000000 0.25000000 0.50000000 17 | 0.99047279 0.08221216 1.00000000 18 | 0.50952721 0.58221216 0.00000000 19 | 0.49047279 0.41778784 1.00000000 20 | 0.00952721 0.91778784 0.00000000 21 | 0.08382932 0.45000872 1.00000000 22 | 0.41617068 0.95000872 0.00000000 23 | 0.58382932 0.04999128 1.00000000 24 | 
0.91617068 0.54999128 0.00000000 25 | 0.94713593 0.03445718 0.57572495 26 | 0.55286407 0.11018214 0.42427505 27 | 0.02286088 0.46554282 0.42427505 28 | 0.52286088 0.61018214 0.42427505 29 | 0.97713912 0.53445718 0.57572495 30 | 0.44713593 0.88981786 0.57572495 31 | 0.05286407 0.96554282 0.42427505 32 | 0.47713912 0.38981786 0.57572495 33 | -------------------------------------------------------------------------------- /aviary/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | 4 | 5 | def robust_l1_loss(pred_mean: Tensor, pred_log_std: Tensor, target: Tensor) -> Tensor: 6 | """Robust L1 loss using a Laplace prior. Trains the model to learn to predict 7 | aleatoric (i.e. per-sample) uncertainty. 8 | 9 | Args: 10 | pred_mean (Tensor): Tensor of predicted means. 11 | pred_log_std (Tensor): Tensor of predicted log standard deviations representing 12 | per-sample model uncertainties. 13 | target (Tensor): Tensor of target values. 14 | 15 | Returns: 16 | Tensor: Evaluated robust L1 loss 17 | """ 18 | loss = 2**0.5 * (pred_mean - target).abs() * torch.exp(-pred_log_std) + pred_log_std 19 | return torch.mean(loss) 20 | 21 | 22 | def robust_l2_loss(pred_mean: Tensor, pred_log_std: Tensor, target: Tensor) -> Tensor: 23 | """Robust L2 loss using a Gaussian prior. Trains the model to learn to predict 24 | aleatoric (i.e. per-sample) uncertainty. 25 | 26 | Args: 27 | pred_mean (Tensor): Tensor of predicted means. 28 | pred_log_std (Tensor): Tensor of predicted log standard deviations representing 29 | per-sample model uncertainties. 30 | target (Tensor): Tensor of target values. 
31 | 32 | Returns: 33 | Tensor: Evaluated robust L2 loss 34 | """ 35 | loss = 0.5 * (pred_mean - target) ** 2 * torch.exp(-2 * pred_log_std) + pred_log_std 36 | return torch.mean(loss) 37 | 38 | 39 | # aliases for backwards compatibility 40 | RobustL1Loss = robust_l1_loss 41 | RobustL2Loss = robust_l2_loss 42 | -------------------------------------------------------------------------------- /examples/inputs/raw/0a4a721fb8f0996350edd4fb224d89ce213f7f93.poscar: -------------------------------------------------------------------------------- 1 | -.17811622E+03 '026/ht.task.gamma--default.0a4a721fb8f0996350edd4fb224d89ce213f7f93.cleanup.3.unclaimed.3.finished/ht.run.2016-05-11_20.19.51' ['(Tag) original structure: Al7C3N3', '(Tag) original structure path: structures/all/matches/ternaries/36/13_36_3aA3aB7aC.cif', '(Tag) type: ternary'] 2 | 1 3 | 1.739255 -2.701998 0.000000 4 | 1.739255 2.701998 0.000000 5 | 0.000000 0.000000 44.970811 6 | Ti Zn N 7 | 14 6 6 8 | Direct 9 | 0.65303539 0.34696461 0.05235786 10 | 0.34696461 0.65303539 0.55235786 11 | 0.01589023 0.98410977 0.10015629 12 | 0.98410977 0.01589023 0.60015629 13 | 0.65177435 0.34822565 0.14624567 14 | 0.34822565 0.65177435 0.64624567 15 | 0.06378092 0.93621908 0.26386996 16 | 0.93621908 0.06378092 0.76386996 17 | 0.70119409 0.29880591 0.30755386 18 | 0.29880591 0.70119409 0.80755386 19 | 0.97953586 0.02046414 0.41091222 20 | 0.02046414 0.97953586 0.91091222 21 | 0.63953190 0.36046810 0.46581246 22 | 0.36046810 0.63953190 0.96581246 23 | 0.00687154 0.99312846 0.00904892 24 | 0.99312846 0.00687154 0.50904892 25 | 0.71355352 0.28644648 0.21967736 26 | 0.28644648 0.71355352 0.71967736 27 | 0.08167077 0.91832923 0.35281686 28 | 0.91832923 0.08167077 0.85281686 29 | 0.66270201 0.33729799 0.09953666 30 | 0.33729799 0.66270201 0.59953666 31 | 0.00266910 0.99733090 0.14931080 32 | 0.99733090 0.00266910 0.64931080 33 | 0.62990179 0.37009821 0.42280110 34 | 0.37009821 0.62990179 0.92280110 35 | 
-------------------------------------------------------------------------------- /examples/inputs/raw/0c429f98b79c9522a11d51fbfae484004f899422.poscar: -------------------------------------------------------------------------------- 1 | -.17580398E+03 '026/ht.task.triolith--default.0c429f98b79c9522a11d51fbfae484004f899422.cleanup.0.unclaimed.4.finished/ht.run.2015-12-09_10.07.49' ['(Tag) original structure: CeCr2O10', '(Tag) type: ternary', '(Tag) original structure path: structures/all/matches/ternaries/11/8_11_1eA1fB2eC4fC.cif'] 2 | 1 3 | 4.103714 0.000000 -0.700673 4 | 0.000000 12.972733 0.000000 5 | -0.424296 0.000000 6.461036 6 | Zr Zn N 7 | 2 4 20 8 | Direct 9 | 0.03289830 0.25000000 0.13910651 10 | 0.96710170 0.75000000 0.86089349 11 | 0.40925565 0.04150102 0.79236623 12 | 0.59074435 0.95849898 0.20763377 13 | 0.59074435 0.54150102 0.20763377 14 | 0.40925565 0.45849898 0.79236623 15 | 0.87785521 0.97059168 0.98812405 16 | 0.12214479 0.02940832 0.01187595 17 | 0.12214479 0.47059168 0.01187595 18 | 0.87785521 0.52940832 0.98812405 19 | 0.79856089 0.41140412 0.37612110 20 | 0.20143911 0.58859588 0.62387890 21 | 0.20143911 0.91140412 0.62387890 22 | 0.79856089 0.08859588 0.37612110 23 | 0.28692864 0.84017359 0.21926568 24 | 0.71307136 0.15982641 0.78073432 25 | 0.71307136 0.34017359 0.78073432 26 | 0.28692864 0.65982641 0.21926568 27 | 0.92927501 0.88310609 0.58519048 28 | 0.07072499 0.11689391 0.41480952 29 | 0.07072499 0.38310609 0.41480952 30 | 0.92927501 0.61689391 0.58519048 31 | 0.49121727 0.25000000 0.15385230 32 | 0.50878273 0.75000000 0.84614770 33 | 0.35117108 0.75000000 0.24759333 34 | 0.64882892 0.25000000 0.75240667 35 | -------------------------------------------------------------------------------- /examples/inputs/raw/001a938f944fa33e899b81e3d0a909876587e7b0.poscar: -------------------------------------------------------------------------------- 1 | -.19821513E+03 
'028/ht.task.triolith--default.001a938f944fa33e899b81e3d0a909876587e7b0.cleanup.0.unclaimed.2.finished/ht.run.2016-04-26_03.52.10' ['(Tag) original structure: AuN12Rb', '(Tag) original structure path: structures/all/matches/ternaries/15/8_15_1eA1eB6fC.cif', '(Tag) type: ternary'] 2 | 1 3 | 5.380312 -6.691171 -0.394081 4 | 5.380312 6.691171 -0.394081 5 | -0.621649 0.000000 7.622393 6 | Zn Hf N 7 | 2 2 24 8 | Direct 9 | 0.43380935 0.56619065 0.25000000 10 | 0.56619065 0.43380935 0.75000000 11 | 0.04124018 0.95875982 0.25000000 12 | 0.95875982 0.04124018 0.75000000 13 | 0.15480179 0.73674461 0.24234883 14 | 0.26325539 0.84519821 0.25765117 15 | 0.84519821 0.26325539 0.75765117 16 | 0.73674461 0.15480179 0.74234883 17 | 0.30916777 0.45162915 0.40555707 18 | 0.54837085 0.69083223 0.09444293 19 | 0.69083223 0.54837085 0.59444293 20 | 0.45162915 0.30916777 0.90555707 21 | 0.00874248 0.11999663 0.47949900 22 | 0.88000337 0.99125752 0.02050100 23 | 0.99125752 0.88000337 0.52050100 24 | 0.11999663 0.00874248 0.97949900 25 | 0.82331917 0.56703037 0.21133745 26 | 0.43296963 0.17668083 0.28866255 27 | 0.17668083 0.43296963 0.78866255 28 | 0.56703037 0.82331917 0.71133745 29 | 0.82021728 0.43928951 0.16649660 30 | 0.56071049 0.17978272 0.33350340 31 | 0.17978272 0.56071049 0.83350340 32 | 0.43928951 0.82021728 0.66649660 33 | 0.02591212 0.12246309 0.04628719 34 | 0.87753691 0.97408788 0.45371281 35 | 0.97408788 0.87753691 0.95371281 36 | 0.12246309 0.02591212 0.54628719 37 | -------------------------------------------------------------------------------- /examples/inputs/raw/0039f3c461ce4a2cbf4c48b09bfc3c6b13a5c769.poscar: -------------------------------------------------------------------------------- 1 | -.23404935E+03 '028/ht.task.triolith--default.0039f3c461ce4a2cbf4c48b09bfc3c6b13a5c769.cleanup.0.unclaimed.2.finished/ht.run.2016-04-25_15.56.15' ['(Tag) original structure: O4SiZn2', '(Tag) original structure path: structures/all/matches/ternaries/61/4_61_1aA1cB2cC.cif', 
'(Tag) type: ternary'] 2 | 1 3 | 5.306134 0.000000 0.000000 4 | -0.000000 5.455820 0.000000 5 | 0.000000 0.000000 11.449448 6 | Hf Zn N 7 | 8 4 16 8 | Direct 9 | 0.00077530 0.08033905 0.34704739 10 | 0.49922470 0.91966095 0.84704739 11 | 0.99922470 0.91966095 0.65295261 12 | 0.49922470 0.58033905 0.34704739 13 | 0.00077530 0.41966095 0.84704739 14 | 0.50077530 0.08033905 0.15295261 15 | 0.50077530 0.41966095 0.65295261 16 | 0.99922470 0.58033905 0.15295261 17 | -0.00000000 0.00000000 -0.00000000 18 | 0.50000000 0.00000000 0.50000000 19 | 0.50000000 0.50000000 -0.00000000 20 | -0.00000000 0.50000000 0.50000000 21 | 0.19703661 0.30428541 0.05279103 22 | 0.30296339 0.69571459 0.55279103 23 | 0.80296339 0.69571459 0.94720897 24 | 0.30296339 0.80428541 0.05279103 25 | 0.19703661 0.19571459 0.55279103 26 | 0.69703661 0.30428541 0.44720897 27 | 0.69703661 0.19571459 0.94720897 28 | 0.80296339 0.80428541 0.44720897 29 | 0.86819053 0.95152275 0.18764705 30 | 0.63180947 0.04847725 0.68764705 31 | 0.13180947 0.04847725 0.81235295 32 | 0.63180947 0.45152275 0.18764705 33 | 0.86819053 0.54847725 0.68764705 34 | 0.36819053 0.95152275 0.31235295 35 | 0.36819053 0.54847725 0.81235295 36 | 0.13180947 0.45152275 0.31235295 37 | -------------------------------------------------------------------------------- /examples/inputs/raw/0abac62147275aa8199465764751eaa67c94d16c.poscar: -------------------------------------------------------------------------------- 1 | -.15943383E+03 '028/ht.task.gamma--default.0abac62147275aa8199465764751eaa67c94d16c.cleanup.0.unclaimed.3.finished/ht.run.2015-11-16_23.35.20' ['(Tag) original structure: S4Tm2Zn', '(Tag) original structure path: structures/all/matches/ternaries/62/6_62_1aA1cA1cB1dC2cC.cif', '(Tag) type: ternary'] 2 | 1 3 | 10.225880 0.000000 0.000000 4 | 0.000000 6.241736 -0.000000 5 | 0.000000 -0.000000 4.996117 6 | Zn Ti N 7 | 8 4 16 8 | Direct 9 | -0.00000000 0.00000000 -0.00000000 10 | 0.50000000 0.50000000 0.50000000 11 | 0.50000000 
0.00000000 0.50000000 12 | -0.00000000 0.50000000 -0.00000000 13 | 0.27439835 0.25000000 0.02027797 14 | 0.22560165 0.75000000 0.52027797 15 | 0.72560165 0.75000000 0.97972203 16 | 0.77439835 0.25000000 0.47972203 17 | 0.10637322 0.25000000 0.56181050 18 | 0.39362678 0.75000000 0.06181050 19 | 0.89362678 0.75000000 0.43818950 20 | 0.60637322 0.25000000 0.93818950 21 | 0.09067969 0.25000000 0.17405839 22 | 0.40932031 0.75000000 0.67405839 23 | 0.90932031 0.75000000 0.82594161 24 | 0.59067969 0.25000000 0.32594161 25 | 0.43282869 0.25000000 0.79687706 26 | 0.06717131 0.75000000 0.29687706 27 | 0.56717131 0.75000000 0.20312294 28 | 0.93282869 0.25000000 0.70312294 29 | 0.33440626 0.99742074 0.26413801 30 | 0.16559374 0.49742074 0.76413801 31 | 0.16559374 0.00257926 0.76413801 32 | 0.66559374 0.49742074 0.73586199 33 | 0.66559374 0.00257926 0.73586199 34 | 0.83440626 0.50257926 0.23586199 35 | 0.83440626 0.99742074 0.23586199 36 | 0.33440626 0.50257926 0.26413801 37 | -------------------------------------------------------------------------------- /examples/inputs/raw/000866a75fbc0a305808c11c58342bb973c3cad5.poscar: -------------------------------------------------------------------------------- 1 | -.12200503E+03 '028/ht.task.triolith--default.000866a75fbc0a305808c11c58342bb973c3cad5.cleanup.0.unclaimed.2.finished/ht.run.2016-04-25_22.05.20' ['(Tag) original structure: Mo4NNi2', '(Tag) original structure path: structures/all/matches/ternaries/227/4_227_1cA1dB1eC1fB.cif', '(Tag) type: ternary'] 2 | 1 3 | -6.117621 -6.117621 0.000000 4 | -6.117621 0.000000 -6.117621 5 | 0.000000 -6.117621 -6.117621 6 | N Zn Hf 7 | 4 16 8 8 | Direct 9 | 0.00000000 -0.00000000 0.00000000 10 | 0.00000000 0.50000000 0.00000000 11 | 0.50000000 -0.00000000 0.00000000 12 | 0.00000000 -0.00000000 0.50000000 13 | 0.50000000 0.50000000 0.50000000 14 | 0.50000000 -0.00000000 0.50000000 15 | 0.00000000 0.50000000 0.50000000 16 | 0.50000000 0.50000000 0.00000000 17 | 0.67137393 0.67137393 0.07862607 
18 | 0.32862607 0.92137393 0.92137393 19 | 0.92137393 0.92137393 0.32862607 20 | 0.07862607 0.67137393 0.67137393 21 | 0.67137393 0.07862607 0.07862607 22 | 0.32862607 0.32862607 0.92137393 23 | 0.07862607 0.07862607 0.67137393 24 | 0.67137393 0.07862607 0.67137393 25 | 0.32862607 0.92137393 0.32862607 26 | 0.92137393 0.32862607 0.92137393 27 | 0.07862607 0.67137393 0.07862607 28 | 0.92137393 0.32862607 0.32862607 29 | 0.71250038 0.71250038 0.71250038 30 | 0.28749962 0.63750113 0.28749962 31 | 0.63750113 0.28749962 0.28749962 32 | 0.36249887 0.71250038 0.71250038 33 | 0.71250038 0.36249887 0.71250038 34 | 0.28749962 0.28749962 0.63750113 35 | 0.28749962 0.28749962 0.28749962 36 | 0.71250038 0.71250038 0.36249887 37 | -------------------------------------------------------------------------------- /examples/inputs/raw/00a8d0f6f55191596770605314ff347995a3540d.poscar: -------------------------------------------------------------------------------- 1 | -.12200932E+03 '028/ht.task.triolith--default.00a8d0f6f55191596770605314ff347995a3540d.cleanup.0.unclaimed.2.finished/ht.run.2016-04-25_22.41.50' ['(Tag) original structure: Co2NZr4', '(Tag) original structure path: structures/all/matches/ternaries/227/4_227_1cA1dB1eC1fA.cif', '(Tag) type: ternary'] 2 | 1 3 | -6.116632 -6.116632 -0.000000 4 | -6.116632 -0.000000 -6.116632 5 | -0.000000 -6.116632 -6.116632 6 | Zn Hf N 7 | 16 8 4 8 | Direct 9 | 0.00000000 0.00000000 -0.00000000 10 | 0.00000000 0.50000000 -0.00000000 11 | 0.50000000 0.00000000 -0.00000000 12 | 0.00000000 0.00000000 0.50000000 13 | 0.57230170 0.57230170 0.17769830 14 | 0.42769830 0.82230170 0.82230170 15 | 0.82230170 0.82230170 0.42769830 16 | 0.17769830 0.57230170 0.57230170 17 | 0.57230170 0.17769830 0.17769830 18 | 0.42769830 0.42769830 0.82230170 19 | 0.17769830 0.17769830 0.57230170 20 | 0.57230170 0.17769830 0.57230170 21 | 0.42769830 0.82230170 0.42769830 22 | 0.82230170 0.42769830 0.82230170 23 | 0.17769830 0.57230170 0.17769830 24 | 0.82230170 
0.42769830 0.42769830 25 | 0.78788573 0.78788573 0.78788573 26 | 0.21211427 0.86365720 0.21211427 27 | 0.86365720 0.21211427 0.21211427 28 | 0.13634280 0.78788573 0.78788573 29 | 0.78788573 0.13634280 0.78788573 30 | 0.21211427 0.21211427 0.86365720 31 | 0.21211427 0.21211427 0.21211427 32 | 0.78788573 0.78788573 0.13634280 33 | 0.50000000 0.50000000 0.50000000 34 | 0.50000000 0.00000000 0.50000000 35 | 0.00000000 0.50000000 0.50000000 36 | 0.50000000 0.50000000 -0.00000000 37 | -------------------------------------------------------------------------------- /examples/inputs/raw/0ab8ebf80d94a1ba826857a9e81cea8c7124d86e.poscar: -------------------------------------------------------------------------------- 1 | -.21135696E+03 '030/ht.task.gamma--default.0ab8ebf80d94a1ba826857a9e81cea8c7124d86e.cleanup.3.unclaimed.3.finished/ht.run.2016-06-20_16.32.26' ['(Tag) original structure: H4O8Ti3', '(Tag) original structure path: structures/all/matches/ternaries/12/15_12_3iA4iB8iC.cif', '(Tag) type: ternary'] 2 | 1 3 | 14.680227 -1.507597 -0.760453 4 | 14.680227 1.507597 -0.760453 5 | -7.316016 0.000000 9.877532 6 | Ti Zn N 7 | 8 6 16 8 | Direct 9 | 0.36914484 0.36914484 0.97637487 10 | 0.63085516 0.63085516 0.02362513 11 | 0.48270716 0.48270716 0.65034958 12 | 0.51729284 0.51729284 0.34965042 13 | 0.29723292 0.29723292 0.67782370 14 | 0.70276708 0.70276708 0.32217630 15 | 0.90549748 0.90549748 0.30384910 16 | 0.09450252 0.09450252 0.69615090 17 | 0.24540401 0.24540401 0.18401970 18 | 0.75459599 0.75459599 0.81598030 19 | 0.16410137 0.16410137 0.49746173 20 | 0.83589863 0.83589863 0.50253827 21 | 0.20389767 0.20389767 0.90895922 22 | 0.79610233 0.79610233 0.09104078 23 | 0.16237373 0.16237373 0.99331877 24 | 0.83762627 0.83762627 0.00668123 25 | 0.35419025 0.35419025 0.23475963 26 | 0.64580975 0.64580975 0.76524037 27 | 0.15175152 0.15175152 0.18211561 28 | 0.84824848 0.84824848 0.81788439 29 | 0.29589949 0.29589949 0.52234703 30 | 0.70410051 0.70410051 0.47765297 31 | 
0.07845619 0.07845619 0.41454019 32 | 0.92154381 0.92154381 0.58545981 33 | 0.23587806 0.23587806 0.70148934 34 | 0.76412194 0.76412194 0.29851066 35 | 0.04432356 0.04432356 0.75256226 36 | 0.95567644 0.95567644 0.24743774 37 | 0.93768367 0.93768367 0.50017113 38 | 0.06231633 0.06231633 0.49982887 39 | -------------------------------------------------------------------------------- /examples/inputs/raw/0019407b910a29730999dab65491b4e50624c219.poscar: -------------------------------------------------------------------------------- 1 | -.18172455E+03 '030/ht.task.triolith--default.0019407b910a29730999dab65491b4e50624c219.cleanup.0.unclaimed.2.finished/ht.run.2016-04-15_10.15.11' ['(Tag) original structure: K2O9Si4', '(Tag) original structure path: structures/all/matches/ternaries/176/5_176_1bA1fB1hC1hA1iC.cif', '(Tag) type: ternary'] 2 | 1 3 | 5.370680 -3.100763 0.000000 4 | -0.000000 6.201526 0.000000 5 | 0.000000 0.000000 10.101875 6 | Hf Zn N 7 | 4 8 18 8 | Direct 9 | 0.33333333 0.66666667 0.07050127 10 | 0.66666667 0.33333333 0.57050127 11 | 0.33333333 0.66666667 0.42949873 12 | 0.66666667 0.33333333 0.92949873 13 | 0.00000000 0.00000000 -0.00000000 14 | 0.00000000 0.00000000 0.50000000 15 | 0.34920449 0.19805089 0.25000000 16 | 0.19805089 0.84884640 0.75000000 17 | 0.84884640 0.65079551 0.25000000 18 | 0.15115360 0.34920449 0.75000000 19 | 0.80194911 0.15115360 0.25000000 20 | 0.65079551 0.80194911 0.75000000 21 | 0.43067355 0.93955068 0.25000000 22 | 0.93955068 0.50887713 0.75000000 23 | 0.50887713 0.56932645 0.25000000 24 | 0.49112287 0.43067355 0.75000000 25 | 0.06044932 0.49112287 0.25000000 26 | 0.56932645 0.06044932 0.75000000 27 | 0.34135426 0.31478008 0.07552230 28 | 0.31478008 0.97342582 0.57552230 29 | 0.34135426 0.31478008 0.42447770 30 | 0.97342582 0.65864574 0.42447770 31 | 0.02657418 0.34135426 0.92447770 32 | 0.68521992 0.02657418 0.42447770 33 | 0.31478008 0.97342582 0.92447770 34 | 0.65864574 0.68521992 0.57552230 35 | 0.65864574 0.68521992 
0.92447770 36 | 0.68521992 0.02657418 0.07552230 37 | 0.97342582 0.65864574 0.07552230 38 | 0.02657418 0.34135426 0.57552230 39 | -------------------------------------------------------------------------------- /examples/inputs/raw/0c0b4fd84c8f4e075436d7ec34bfb3fad5e610be.poscar: -------------------------------------------------------------------------------- 1 | -.21275901E+03 '030/ht.task.triolith--default.0c0b4fd84c8f4e075436d7ec34bfb3fad5e610be.cleanup.0.unclaimed.4.finished/ht.run.2015-12-07_19.52.24' ['(Tag) original structure: MnO3Yb', '(Tag) type: ternary', '(Tag) original structure path: structures/all/matches/ternaries/185/7_185_1aA1aB1bA1bB1cC2cA.cif'] 2 | 1 3 | 5.543294 -3.200422 0.000000 4 | -0.000000 6.400844 -0.000000 5 | 0.000000 0.000000 10.157681 6 | Zn Zr N 7 | 6 6 18 8 | Direct 9 | -0.00000000 -0.00000000 0.26067791 10 | -0.00000000 -0.00000000 0.76067791 11 | 0.33333333 0.66666667 0.22343397 12 | 0.33333333 0.66666667 0.72343397 13 | 0.66666667 0.33333333 0.72343397 14 | 0.66666667 0.33333333 0.22343397 15 | 0.34838803 1.00000000 0.98014843 16 | 0.65161197 0.00000000 0.48014843 17 | 1.00000000 0.34838803 0.98014843 18 | 0.34838803 0.34838803 0.48014843 19 | 0.65161197 0.65161197 0.98014843 20 | 0.00000000 0.65161197 0.48014843 21 | 0.29706565 1.00000000 0.20897943 22 | 0.70293435 0.00000000 0.70897943 23 | -0.00000000 0.29706565 0.20897943 24 | 0.29706565 0.29706565 0.70897943 25 | 0.70293435 0.70293435 0.20897943 26 | -0.00000000 0.70293435 0.70897943 27 | 0.44196668 0.00000000 0.31072121 28 | 0.55803332 1.00000000 0.81072121 29 | -0.00000000 0.44196668 0.31072121 30 | 0.44196668 0.44196668 0.81072121 31 | 0.55803332 0.55803332 0.31072121 32 | -0.00000000 0.55803332 0.81072121 33 | -0.00000000 -0.00000000 0.46416939 34 | -0.00000000 -0.00000000 0.96416939 35 | 0.33333333 0.66666667 0.02996876 36 | 0.33333333 0.66666667 0.52996876 37 | 0.66666667 0.33333333 0.52996876 38 | 0.66666667 0.33333333 0.02996876 39 | 
-------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pytest 4 | import torch 5 | from matminer.datasets import load_dataset 6 | from pymatgen.analysis.prototypes import get_protostructure_label_from_spglib 7 | 8 | torch.manual_seed(0) # ensure reproducible results (applies to all tests) 9 | 10 | TEST_DIR = os.path.dirname(os.path.abspath(__file__)) 11 | 12 | 13 | @pytest.fixture(scope="session") 14 | def df_matbench_phonons(): 15 | """Returns the dataframe for the Matbench phonon DOS peak task.""" 16 | 17 | df = load_dataset("matbench_phonons") 18 | df["material_id"] = [f"mb_phdos_{idx + 1}" for idx in range(len(df))] 19 | df = df.set_index("material_id", drop=False) 20 | df["composition"] = [x.composition.formula.replace(" ", "") for x in df.structure] 21 | 22 | df["phdos_clf"] = [1 if x > 450 else 0 for x in df["last phdos peak"]] 23 | 24 | return df 25 | 26 | 27 | @pytest.fixture(scope="session") 28 | def df_matbench_jdft2d(): 29 | """Returns Matbench jdft2d (exfoliation energy) task dataframe. Currently unused.""" 30 | 31 | df = load_dataset("matbench_jdft2d") 32 | df["material_id"] = [f"mb_jdft2d_{idx + 1}" for idx in range(len(df))] 33 | df = df.set_index("material_id", drop=False) 34 | df["composition"] = [x.composition.formula.replace(" ", "") for x in df.structure] 35 | 36 | df["protostructure"] = df.structure.map(get_protostructure_label_from_spglib) 37 | 38 | return df 39 | 40 | 41 | @pytest.fixture(scope="session") 42 | def df_matbench_phonons_wyckoff(df_matbench_phonons): 43 | """Getting Aflow labels is expensive so we split into a separate fixture to avoid 44 | paying for it unless requested. 
45 | """ 46 | df_matbench_phonons["protostructure"] = df_matbench_phonons.structure.map( 47 | get_protostructure_label_from_spglib 48 | ) 49 | 50 | return df_matbench_phonons 51 | -------------------------------------------------------------------------------- /examples/inputs/raw/00a3259f00f8a9aa503caae2045970d2f7b22e13.poscar: -------------------------------------------------------------------------------- 1 | -.22155319E+03 '031/ht.task.triolith--default.00a3259f00f8a9aa503caae2045970d2f7b22e13.cleanup.0.unclaimed.2.finished/ht.run.2016-01-28_12.35.41' ['(Tag) original structure: O19Sr5Ti7', '(Tag) original structure path: structures/all/matches/ternaries/47/14_47_1aA1eB1fC1qA1rB1rA1tC2qB2sC3vB.cif', '(Tag) type: ternary'] 2 | 1 3 | 4.184151 -0.000000 0.000000 4 | -0.000000 7.172117 0.000000 5 | 0.000000 0.000000 25.251255 6 | Hf N Zn 7 | 7 19 5 8 | Direct 9 | 0.50000000 0.00000000 0.25755200 10 | 0.50000000 0.00000000 0.74244800 11 | 0.50000000 0.50000000 0.17987301 12 | 0.50000000 0.50000000 0.82012699 13 | 0.50000000 0.00000000 0.09174676 14 | 0.50000000 0.00000000 0.90825324 15 | 0.50000000 0.50000000 -0.00000000 16 | 0.00000000 0.00000000 0.27199352 17 | 0.00000000 0.00000000 0.72800658 18 | 0.50000000 0.78477034 0.20603554 19 | 0.50000000 0.78477034 0.79396446 20 | 0.50000000 0.21522966 0.79396446 21 | 0.50000000 0.21522966 0.20603554 22 | 0.00000000 0.50000000 0.15419995 23 | -0.00000000 0.50000000 0.84580005 24 | 0.50000000 0.26138331 0.13101457 25 | 0.50000000 0.26138331 0.86898543 26 | 0.50000000 0.73861669 0.86898543 27 | 0.50000000 0.73861669 0.13101457 28 | -0.00000000 0.00000000 0.11651130 29 | 0.00000000 0.00000000 0.88348870 30 | 0.50000000 0.25939093 0.05225807 31 | 0.50000000 0.25939093 0.94774193 32 | 0.50000000 0.74060907 0.94774193 33 | 0.50000000 0.74060907 0.05225807 34 | -0.00000000 0.50000000 -0.00000000 35 | 0.00000000 0.00000000 0.19141435 36 | 0.00000000 0.00000000 0.80858565 37 | 0.00000000 0.50000000 0.07865654 38 | -0.00000000 
0.50000000 0.92134346 39 | -0.00000000 0.00000000 -0.00000000 40 | -------------------------------------------------------------------------------- /examples/inputs/raw/0a2c4708d3c853ea3c4b8630027afe9b788ee2ce.poscar: -------------------------------------------------------------------------------- 1 | -.23520684E+03 '032/ht.task.gamma--default.0a2c4708d3c853ea3c4b8630027afe9b788ee2ce.cleanup.0.unclaimed.3.finished/ht.run.2016-05-02_10.51.25' ['(Tag) original structure: Cu2O5P', '(Tag) original structure path: structures/all/matches/ternaries/58/7_58_1fA1gA1gB1hC3gC.cif', '(Tag) type: ternary'] 2 | 1 3 | 8.736347 -0.000000 0.000000 4 | 0.000000 8.231356 0.000000 5 | 0.000000 0.000000 6.998135 6 | N Zn Ti 7 | 20 4 8 8 | Direct 9 | 0.04933257 0.15515104 0.00000000 10 | 0.45066743 0.65515104 0.50000000 11 | 0.95066743 0.84484896 -0.00000000 12 | 0.54933257 0.34484896 0.50000000 13 | 0.06055200 0.97872251 0.50000000 14 | 0.43944800 0.47872251 0.00000000 15 | 0.93944800 0.02127749 0.50000000 16 | 0.56055200 0.52127749 0.00000000 17 | 0.14268832 0.43858699 0.50000000 18 | 0.35731168 0.93858699 -0.00000000 19 | 0.85731168 0.56141301 0.50000000 20 | 0.64268832 0.06141301 0.00000000 21 | 0.36196035 0.18003421 0.24899865 22 | 0.13803965 0.68003421 0.74899865 23 | 0.63803965 0.81996579 0.75100135 24 | 0.86196035 0.31996579 0.25100135 25 | 0.13803965 0.68003421 0.25100135 26 | 0.63803965 0.81996579 0.24899865 27 | 0.86196035 0.31996579 0.74899865 28 | 0.36196035 0.18003421 0.75100135 29 | 0.32947201 0.28836426 0.50000000 30 | 0.17052799 0.78836426 0.00000000 31 | 0.67052799 0.71163574 0.50000000 32 | 0.82947201 0.21163574 0.00000000 33 | 0.00000000 0.50000000 0.30022293 34 | 0.50000000 -0.00000000 0.80022293 35 | 0.00000000 0.50000000 0.69977707 36 | 0.50000000 -0.00000000 0.19977707 37 | 0.26884529 0.16402866 0.00000000 38 | 0.23115471 0.66402866 0.50000000 39 | 0.73115471 0.83597134 -0.00000000 40 | 0.76884529 0.33597134 0.50000000 41 | 
"""Download Matbench datasets and cache them as compressed JSON.

For every dataset in ``matbench_datasets`` that is not cached yet, the data is
loaded via ``matbench.data_ops.load``. Datasets with a ``structure`` column get
``composition`` and ``protostructure`` columns added. The result is written to
``datasets/<name>.json.bz2`` next to this file. ``DATA_PATHS`` maps dataset names
to the cached files found on disk.
"""

import os
from glob import glob

from matbench.data_ops import load
from pymatgen.analysis.prototypes import get_protostructure_label_from_spglib
from tqdm import tqdm

tqdm.pandas()  # enables the .progress_apply used below

current_dir = os.path.dirname(os.path.abspath(__file__))

# When True, only the (small) matbench_phonons dataset is prepared.
SMOKE_TEST = True

matbench_datasets = [
    "matbench_steels",
    "matbench_jdft2d",
    "matbench_phonons",
    "matbench_expt_gap",
    "matbench_dielectric",
    "matbench_expt_is_metal",
    "matbench_glass",
    "matbench_log_gvrh",
    "matbench_log_kvrh",
    "matbench_perovskites",
    "matbench_mp_gap",
    "matbench_mp_is_metal",
    "matbench_mp_e_form",
]

if SMOKE_TEST:
    matbench_datasets = matbench_datasets[2:3]  # just "matbench_phonons"

os.makedirs(f"{current_dir}/datasets", exist_ok=True)
for dataset in matbench_datasets:
    dataset_path = f"{current_dir}/datasets/{dataset}.json.bz2"

    if os.path.exists(dataset_path):
        print(f"Dataset {dataset} already exists, skipping")
        continue

    df = load(dataset)

    if "structure" in df:
        # Structure tasks: derive composition and protostructure label per entry.
        df["composition"] = [struct.formula for struct in df.structure]
        df["protostructure"] = df["structure"].progress_apply(
            get_protostructure_label_from_spglib
        )
    elif "composition" not in df:
        # Composition-only tasks are saved as loaded; only fail when the dataset
        # has neither column (previously every composition-only task raised here).
        raise ValueError("No structure or composition column found")

    df.to_json(
        dataset_path,
        # serialize non-JSON-native objects through their as_dict()
        default_handler=lambda x: x.as_dict(),
    )


# dataset name (file name up to the first dot) -> path of the cached file
DATA_PATHS = {
    os.path.basename(path).split(".")[0]: path
    for path in glob(f"{current_dir}/datasets/matbench_*.json.bz2")
}

assert len(DATA_PATHS) == len(matbench_datasets), (
    f"glob found {len(DATA_PATHS)} data sets, expected {len(matbench_datasets)}"
)
-------------------------------------------------------------------------------- /examples/inputs/raw/00cca7ad273875ba80f1d437af797b6a01cc6ebd.poscar: -------------------------------------------------------------------------------- 1 | -.28166871E+03 '034/ht.task.triolith--default.00cca7ad273875ba80f1d437af797b6a01cc6ebd.cleanup.0.unclaimed.2.finished/ht.run.2015-12-28_06.56.02' ['(Tag) original structure: AuK3Se13', '(Tag) original structure path: structures/all/matches/ternaries/13/10_13_1dA1eB1fC1gB6gC.cif', '(Tag) type: ternary'] 2 | 1 3 | 10.962471 0.000000 0.122473 4 | 0.000000 3.175217 0.000000 5 | -0.573407 0.000000 9.709596 6 | Zn N Hf 7 | 2 26 6 8 | Direct 9 | 0.50000000 0.00000000 -0.00000000 10 | 0.50000000 0.00000000 0.50000000 11 | 0.67360320 0.18759091 0.93174711 12 | 0.67360320 0.81240909 0.43174711 13 | 0.32639680 0.81240909 0.06825289 14 | 0.32639680 0.18759091 0.56825289 15 | 0.75558845 0.23022353 0.03949303 16 | 0.75558845 0.76977647 0.53949303 17 | 0.24441155 0.76977647 0.96050697 18 | 0.24441155 0.23022353 0.46050697 19 | 0.12699532 0.23002080 0.47579975 20 | 0.12699532 0.76997920 0.97579975 21 | 0.87300468 0.76997920 0.52420025 22 | 0.87300468 0.23002080 0.02420025 23 | 0.07990668 0.20674338 0.60169736 24 | 0.07990668 0.79325662 0.10169736 25 | 0.92009332 0.79325662 0.39830264 26 | 0.92009332 0.20674338 0.89830264 27 | 0.14854564 0.29132038 0.73090146 28 | 0.14854564 0.70867962 0.23090146 29 | 0.85145436 0.70867962 0.26909854 30 | 0.85145436 0.29132038 0.76909854 31 | 0.41976842 0.75041369 0.33129433 32 | 0.41976842 0.24958631 0.83129433 33 | 0.58023158 0.24958631 0.66870567 34 | 0.58023158 0.75041369 0.16870567 35 | 0.50000000 0.45635328 0.75000000 36 | 0.50000000 0.54364672 0.25000000 37 | -0.00000000 0.23886969 0.25000000 38 | -0.00000000 0.76113031 0.75000000 39 | 0.69243328 0.26495359 0.26261208 40 | 0.69243328 0.73504641 0.76261208 41 | 0.30756672 0.73504641 0.73738792 42 | 0.30756672 0.26495359 0.23738792 43 | 
-------------------------------------------------------------------------------- /examples/inputs/raw/0a946dc02260d0d4b36ef2087e60b415a5fff800.poscar: -------------------------------------------------------------------------------- 1 | -.29008227E+03 '036/ht.task.gamma--default.0a946dc02260d0d4b36ef2087e60b415a5fff800.cleanup.0.unclaimed.3.finished/ht.run.2015-11-16_19.23.17' ['(Tag) original structure: Cl5PbTl3', '(Tag) original structure path: structures/all/matches/ternaries/76/9_76_1aA3aB5aC.cif', '(Tag) type: ternary'] 2 | 1 3 | 5.839669 0.000000 0.000000 4 | -0.000000 5.839669 0.000000 5 | 0.000000 0.000000 10.063436 6 | Ti Zn N 7 | 12 4 20 8 | Direct 9 | 0.96905875 0.80718338 0.81186643 10 | 0.03094125 0.19281662 0.31186643 11 | 0.19281662 0.96905875 0.06186643 12 | 0.80718338 0.03094125 0.56186643 13 | 0.10899183 0.34228606 0.88023088 14 | 0.89100817 0.65771394 0.38023088 15 | 0.65771394 0.10899183 0.13023088 16 | 0.34228606 0.89100817 0.63023088 17 | 0.60631294 0.18018496 0.83510905 18 | 0.39368706 0.81981504 0.33510905 19 | 0.81981504 0.60631294 0.08510905 20 | 0.18018496 0.39368706 0.58510905 21 | 0.48525695 0.69477963 0.89229795 22 | 0.51474305 0.30522037 0.39229795 23 | 0.30522037 0.48525695 0.14229795 24 | 0.69477963 0.51474305 0.64229795 25 | 0.34294622 0.11501552 0.23713654 26 | 0.65705378 0.88498448 0.73713654 27 | 0.88498448 0.34294622 0.48713654 28 | 0.11501552 0.65705378 0.98713654 29 | 0.73779301 0.97899975 0.33796876 30 | 0.26220699 0.02100025 0.83796876 31 | 0.02100025 0.73779301 0.58796876 32 | 0.97899975 0.26220699 0.08796876 33 | 0.00065757 0.82328758 0.21040091 34 | 0.99934243 0.17671242 0.71040091 35 | 0.17671242 0.00065757 0.46040091 36 | 0.82328758 0.99934243 0.96040091 37 | 0.55444204 0.73450842 0.50475735 38 | 0.44555796 0.26549158 0.00475735 39 | 0.26549158 0.55444204 0.75475735 40 | 0.73450842 0.44555796 0.25475735 41 | 0.51196313 0.78128955 0.10023213 42 | 0.48803687 0.21871045 0.60023213 43 | 0.21871045 0.51196313 0.35023213 44 | 
0.78128955 0.48803687 0.85023213 45 | -------------------------------------------------------------------------------- /examples/inputs/raw/0ab8cc7e5cc3d48fcc786f9060da5a8882374f44.poscar: -------------------------------------------------------------------------------- 1 | -.22445649E+03 '036/ht.task.gamma--default.0ab8cc7e5cc3d48fcc786f9060da5a8882374f44.cleanup.0.unclaimed.3.finished/ht.run.2016-06-23_23.03.48' ['(Tag) original structure: Al2Ca3N4', '(Tag) original structure path: structures/all/matches/ternaries/14/9_14_2eA3eB4eC.cif', '(Tag) type: ternary'] 2 | 1 3 | 9.402404 -0.001130 0.397882 4 | -0.000650 5.646741 0.000464 5 | -3.073322 0.000794 8.683129 6 | Zn Ti N 7 | 12 8 16 8 | Direct 9 | 0.05277259 0.04468460 0.22554858 10 | 0.94722716 0.95531565 0.77445205 11 | 0.94695245 0.54466241 0.27464315 12 | 0.05304916 0.45533967 0.72535521 13 | 0.18677287 0.65379652 0.12185680 14 | 0.81322239 0.34620769 0.87814188 15 | 0.81300997 0.15355040 0.37838871 16 | 0.18699246 0.84645005 0.62161212 17 | 0.42929711 0.36288193 0.28830955 18 | 0.57069576 0.63711251 0.71169070 19 | 0.57061516 0.86310666 0.21156040 20 | 0.42938961 0.13689561 0.78843907 21 | 0.22918593 0.15003641 0.98725724 22 | 0.77081261 0.84996283 0.01274258 23 | 0.77079612 0.65007753 0.51282933 24 | 0.22920238 0.34992259 0.48717364 25 | 0.61109232 0.38180476 0.12711576 26 | 0.38891406 0.61819165 0.87288127 27 | 0.38888351 0.88195244 0.37279498 28 | 0.61111673 0.11804393 0.62720759 29 | 0.02632592 0.24383232 0.39072107 30 | 0.97367511 0.75617134 0.60927740 31 | 0.97352280 0.74373049 0.10950746 32 | 0.02648239 0.25627025 0.89049600 33 | 0.25275843 0.02636994 0.18684925 34 | 0.74723814 0.97364108 0.81314111 35 | 0.74694060 0.52616949 0.31323453 36 | 0.25306032 0.47383160 0.68676848 37 | 0.36602370 0.41637224 0.03775171 38 | 0.63397686 0.58361155 0.96224839 39 | 0.63414571 0.91658293 0.46221345 40 | 0.36585057 0.08341002 0.53778712 41 | 0.71185292 0.07797699 0.14532722 42 | 0.28814492 0.92203290 0.85467531 43 
| 0.28829662 0.57788891 0.35464532 44 | 0.71170460 0.42211213 0.64535556 45 | -------------------------------------------------------------------------------- /examples/matbench_example/readme.md: -------------------------------------------------------------------------------- 1 | # Matbench 2 | 3 | This directory contains the files needed to create Matbench submissions for Roostformer and Wrenformer (structure tasks only for Wren) which are rewrites of Roost and Wren using PyTorch's builtin `TransformerEncoder` instead of the custom self-attention modules used by Roost and Wren. 4 | 5 | Added in [aviary#44](https://github.com/CompRhys/aviary/pull/44). 6 | 7 | ## Speed difference between Wren and Wrenformer 8 | 9 | According to Rhys, Wren could run 500 epochs in 5.5 h on a P100 training on 120k samples of MP data (similar to the `matbench_mp_e_form` dataset with 132k samples). Wrenformer only managed 207 epochs in 4h on the more powerful A100 training on `matbench_mp_e_form`. However, to avoid out-of-memory issues, Rhys constrained Wren to only run on systems with <= 16 Wyckoff positions. The code below shows that this lightens the workload by a factor of about 7.5, likely explaining the apparent slowdown in Wrenformer. 
10 | 11 | ```py 12 | import os 13 | import sys 14 | from pathlib import Path 15 | 16 | # Add the parent directory to system path 17 | sys.path.append(str(Path(__file__).parent.parent)) 18 | 19 | import pandas as pd 20 | from pymatgen.analysis.prototypes import count_wyckoff_positions 21 | from matbench_example import DATA_PATHS 22 | 23 | df = pd.read_json(DATA_PATHS["matbench_mp_e_form"]) 24 | 25 | df["n_wyckoff"] = df.protostructure.map(count_wyckoff_positions) 26 | 27 | 28 | sum_wyckoffs_sqr = (df.n_wyckoff**2).sum() 29 | sum_wyckoffs_lte_16_sqr = (df.query("n_wyckoff <= 16").n_wyckoff ** 2).sum() 30 | print(f"{sum_wyckoffs_sqr=}") 31 | print(f"{sum_wyckoffs_lte_16_sqr=}") 32 | print(f"{sum_wyckoffs_sqr/sum_wyckoffs_lte_16_sqr=:.3}") 33 | # prints 7.45, so Wrenformer has to do 7.45x more work, which explains the roughly 2x slowdown 34 | # even on a more powerful GPU (Wrenformer on an Nvidia A100 vs Wren on a P100) 35 | ``` 36 | 37 | ## Benchmarks 38 | 39 | JSON files in `model_scores/` contain only the calculated scores (MAE/ROCAUC) for a given model run. Files with the same name in `model_preds/` contain the full set of model predictions, targets and material ids. Code for loading these into memory is in `make_plots.py`. 
40 | -------------------------------------------------------------------------------- /examples/inputs/raw/00bada6787af6463447a1460da0c031d51494b60.poscar: -------------------------------------------------------------------------------- 1 | -.30469940E+03 '039/ht.task.triolith--default.00bada6787af6463447a1460da0c031d51494b60.cleanup.0.unclaimed.2.finished/ht.run.2016-01-28_08.04.36' ['(Tag) original structure: Ba3S8Ta2', '(Tag) original structure path: structures/all/matches/ternaries/10/17_10_1aA1cB1eA1mA1nA2mB2mC2nB2nC4oC.cif', '(Tag) type: ternary'] 2 | 1 3 | 9.521579 0.000000 0.373799 4 | 0.000000 4.706142 0.000000 5 | -2.621791 0.000000 9.869892 6 | Hf Zn N 7 | 9 6 24 8 | Direct 9 | -0.00000000 0.00000000 0.50000000 10 | 0.71433010 0.00000000 0.98271042 11 | 0.28566990 0.00000000 0.01728958 12 | 0.33416582 0.50000000 0.41311094 13 | 0.66583418 0.50000000 0.58688906 14 | 0.55172795 -0.00000000 0.35393491 15 | 0.44827205 0.00000000 0.64606509 16 | 0.96996488 0.50000000 0.19091014 17 | 0.03003512 0.50000000 0.80908986 18 | -0.00000000 0.00000000 0.00000000 19 | 0.50000000 0.50000000 0.00000000 20 | 0.18814200 0.00000000 0.27782706 21 | 0.81185800 0.00000000 0.72217294 22 | 0.67649413 0.50000000 0.24016345 23 | 0.32350587 0.50000000 0.75983655 24 | 0.73909176 0.50000000 0.06433623 25 | 0.26090824 0.50000000 0.93566377 26 | 0.36649012 0.00000000 0.41827029 27 | 0.63350988 0.00000000 0.58172971 28 | 0.56803545 0.50000000 0.38970358 29 | 0.43196455 0.50000000 0.61029642 30 | 0.82102394 0.00000000 0.23359175 31 | 0.17897606 -0.00000000 0.76640825 32 | 0.13781094 0.25687296 0.11731282 33 | 0.86218906 0.25687296 0.88268718 34 | 0.86218906 0.74312704 0.88268718 35 | 0.13781094 0.74312704 0.11731282 36 | 0.48082977 0.15574699 0.15015500 37 | 0.51917023 0.15574699 0.84984500 38 | 0.51917023 0.84425301 0.84984500 39 | 0.48082977 0.84425301 0.15015500 40 | 0.11588547 0.34907379 0.40667526 41 | 0.88411453 0.34907379 0.59332474 42 | 0.88411453 0.65092621 0.59332474 43 | 
0.11588547 0.65092621 0.40667526 44 | 0.85246030 0.22762806 0.30914063 45 | 0.14753970 0.22762806 0.69085937 46 | 0.14753970 0.77237194 0.69085937 47 | 0.85246030 0.77237194 0.30914063 48 | -------------------------------------------------------------------------------- /examples/inputs/raw/0adeaf7a9c16772b66815cce19917ef0460affe8.poscar: -------------------------------------------------------------------------------- 1 | -.27195782E+03 '040/ht.task.gamma--default.0adeaf7a9c16772b66815cce19917ef0460affe8.cleanup.0.unclaimed.3.finished/ht.run.2016-04-13_19.08.23' ['(Tag) original structure: O3SrZr', '(Tag) original structure path: structures/all/matches/ternaries/198/6_198_2aA2aB2bC.cif', '(Tag) type: ternary'] 2 | 1 3 | 7.685200 0.000000 0.000000 4 | 0.000000 7.685200 0.000000 5 | 0.000000 0.000000 7.685200 6 | Ti Zn N 7 | 8 8 24 8 | Direct 9 | 0.00179104 0.00179104 0.00179104 10 | 0.49820896 0.99820896 0.50179104 11 | 0.50179104 0.49820896 0.99820896 12 | 0.99820896 0.50179104 0.49820896 13 | 0.47336587 0.47336587 0.47336587 14 | 0.02663413 0.52663413 0.97336587 15 | 0.97336587 0.02663413 0.52663413 16 | 0.52663413 0.97336587 0.02663413 17 | 0.29008367 0.29008367 0.29008367 18 | 0.20991633 0.70991633 0.79008367 19 | 0.79008367 0.20991633 0.70991633 20 | 0.70991633 0.79008367 0.20991633 21 | 0.75881592 0.75881592 0.75881592 22 | 0.74118408 0.24118408 0.25881592 23 | 0.25881592 0.74118408 0.24118408 24 | 0.24118408 0.25881592 0.74118408 25 | 0.23092657 0.11049583 0.08125003 26 | 0.41874997 0.76907343 0.61049583 27 | 0.58125003 0.26907343 0.88950417 28 | 0.91874997 0.73092657 0.38950417 29 | 0.11049583 0.08125003 0.23092657 30 | 0.38950417 0.91874997 0.73092657 31 | 0.76907343 0.61049583 0.41874997 32 | 0.08125003 0.23092657 0.11049583 33 | 0.26907343 0.88950417 0.58125003 34 | 0.73092657 0.38950417 0.91874997 35 | 0.61049583 0.41874997 0.76907343 36 | 0.88950417 0.58125003 0.26907343 37 | 0.75939082 0.03937817 0.09224158 38 | 0.40775842 0.24060918 0.53937817 39 | 
0.59224158 0.74060918 0.96062183 40 | 0.90775842 0.25939082 0.46062183 41 | 0.03937817 0.09224158 0.75939082 42 | 0.46062183 0.90775842 0.25939082 43 | 0.24060918 0.53937817 0.40775842 44 | 0.09224158 0.75939082 0.03937817 45 | 0.74060918 0.96062183 0.59224158 46 | 0.25939082 0.46062183 0.90775842 47 | 0.53937817 0.40775842 0.24060918 48 | 0.96062183 0.59224158 0.74060918 49 | -------------------------------------------------------------------------------- /examples/inputs/raw/00ecadb66a8baab030675a817d2ee15f0002a4dc.poscar: -------------------------------------------------------------------------------- 1 | -.37086156E+03 '046/ht.task.triolith--default.00ecadb66a8baab030675a817d2ee15f0002a4dc.cleanup.0.unclaimed.2.finished/ht.run.2016-01-11_18.05.48' ['(Tag) original structure: Al2F16Sr5', '(Tag) original structure path: structures/all/matches/ternaries/68/8_68_1bA1gA1hB1hA4iC.cif', '(Tag) type: ternary'] 2 | 1 3 | 3.444884 -6.785908 0.000000 4 | 3.444884 6.785908 0.000000 5 | 0.000000 0.000000 12.043561 6 | Hf Zn N 7 | 10 4 32 8 | Direct 9 | 0.75000000 0.75000000 0.14314454 10 | 0.25000000 0.25000000 0.35685546 11 | 0.75000000 0.75000000 0.64314454 12 | 0.25000000 0.25000000 0.85685546 13 | 0.25000000 0.75000000 0.25000000 14 | 0.75000000 0.25000000 0.75000000 15 | 0.75000000 0.25000000 0.02657995 16 | 0.75000000 0.25000000 0.47342005 17 | 0.25000000 0.75000000 0.52657995 18 | 0.25000000 0.75000000 0.97342005 19 | 0.25000000 0.25000000 0.10718944 20 | 0.75000000 0.75000000 0.39281056 21 | 0.25000000 0.25000000 0.60718944 22 | 0.75000000 0.75000000 0.89281056 23 | 0.32171573 0.49040581 0.24334585 24 | 0.50959419 0.67828427 0.25665415 25 | 0.00959419 0.17828427 0.74334585 26 | 0.17828427 0.00959419 0.24334585 27 | 0.67828427 0.50959419 0.75665415 28 | 0.99040581 0.82171573 0.25665415 29 | 0.49040581 0.32171573 0.74334585 30 | 0.82171573 0.99040581 0.75665415 31 | 0.02449093 0.58763465 0.10661258 32 | 0.41236535 0.97550907 0.39338742 33 | 0.91236535 0.47550907 
0.60661258 34 | 0.47550907 0.91236535 0.10661258 35 | 0.97550907 0.41236535 0.89338742 36 | 0.08763465 0.52449093 0.39338742 37 | 0.58763465 0.02449093 0.60661258 38 | 0.52449093 0.08763465 0.89338742 39 | 0.92789297 0.09878444 0.10280100 40 | 0.90121556 0.07210703 0.39719900 41 | 0.40121556 0.57210703 0.60280100 42 | 0.57210703 0.40121556 0.10280100 43 | 0.07210703 0.90121556 0.89719900 44 | 0.59878444 0.42789297 0.39719900 45 | 0.09878444 0.92789297 0.60280100 46 | 0.42789297 0.59878444 0.89719900 47 | 0.28120306 0.47712413 0.97303318 48 | 0.52287587 0.71879694 0.52696682 49 | 0.02287587 0.21879694 0.47303318 50 | 0.21879694 0.02287587 0.97303318 51 | 0.71879694 0.52287587 0.02696682 52 | 0.97712413 0.78120306 0.52696682 53 | 0.47712413 0.28120306 0.47303318 54 | 0.78120306 0.97712413 0.02696682 55 | -------------------------------------------------------------------------------- /examples/inputs/raw/00021364446a637881257fd9ee912a422a6b1753.poscar: -------------------------------------------------------------------------------- 1 | -.38267572E+03 '048/ht.task.triolith--default.00021364446a637881257fd9ee912a422a6b1753.cleanup.0.unclaimed.2.finished/ht.run.2016-03-10_13.38.07' ['(Tag) original structure: F9RbTh2', '(Tag) original structure path: structures/all/matches/ternaries/62/7_62_1cA1cB1dC4dA.cif', '(Tag) type: ternary'] 2 | 1 3 | 6.729996 0.000000 0.000000 4 | 0.000000 9.412429 0.000000 5 | 0.000000 0.000000 8.919448 6 | Hf Zn N 7 | 8 4 36 8 | Direct 9 | 0.25298222 0.45025421 0.31557996 10 | 0.24701778 0.95025421 0.81557996 11 | 0.24701778 0.54974579 0.81557996 12 | 0.74701778 0.95025421 0.68442004 13 | 0.74701778 0.54974579 0.68442004 14 | 0.75298222 0.04974579 0.18442004 15 | 0.75298222 0.45025421 0.18442004 16 | 0.25298222 0.04974579 0.31557996 17 | 0.59697916 0.25000000 0.82686578 18 | 0.90302084 0.75000000 0.32686578 19 | 0.40302084 0.75000000 0.17313422 20 | 0.09697916 0.25000000 0.67313422 21 | 0.01809884 0.57498929 0.15848322 22 | 0.48190116 
"""Tests for the numpy/torch helpers in aviary.core."""

import numpy as np
import pytest
import torch

from aviary.core import masked_mean, masked_std, np_one_hot, np_softmax


def test_np_one_hot():
    """Labels 0..2 encode to identity rows; n_classes widens the encoding."""
    labels = np.arange(3)
    assert np.allclose(np_one_hot(labels), np.eye(3))

    # with n_classes=5 each label selects the matching row of a 5x5 identity
    widened = np_one_hot(labels, n_classes=5)
    assert np.allclose(widened, np.eye(5)[labels])


def test_np_softmax():
    """Softmax matches reference values and normalizes along the chosen axis."""
    logits = np.array([[0, 1, 2], [3, 4, 5]])
    # both rows differ only by a constant shift, so they share one softmax
    row_probs = [0.0900, 0.2447, 0.6652]
    reference = np.array([row_probs, row_probs])
    assert np.allclose(np_softmax(logits), reference, atol=1e-4)

    for sample in np.random.rand(3, 2, 3):
        for axis in (0, 1):
            # entries along the softmaxed axis must sum to 1
            totals = np_softmax(sample, axis=axis).sum(axis=axis)
            assert np.allclose(totals, 1)


def test_masked_mean():
    """masked_mean averages only the entries selected by the mask."""
    # 1d case: the mask keeps the values 1 and 3
    values = torch.arange(5).float()
    keep = torch.tensor([0, 1, 0, 1, 0]).bool()
    assert masked_mean(values, keep) == 2

    manual_mean = sum(values * keep) / sum(keep)
    assert masked_mean(values, keep) == manual_mean

    # 2d case: default reduction vs. explicit dim=1
    grid = torch.tensor([[1, 2, 3], [4, 5, 6]]).float()
    grid_keep = torch.tensor([[1, 1, 0], [0, 0, 1]]).bool()

    assert masked_mean(grid, grid_keep) == pytest.approx([1, 2, 6])
    assert masked_mean(grid, grid_keep, dim=1) == pytest.approx([1.5, 6])


def test_masked_std():
    """masked_std agrees with hand-computed and NaN-based reference values."""
    # 1d case: the mask keeps the values 1 and 3
    values = torch.arange(5).float()
    keep = torch.tensor([0, 1, 0, 1, 0]).bool()
    assert masked_std(values, keep) == 1

    # 2d case
    grid = torch.tensor([[1, 1, 1], [2, 2, 4]]).float()
    grid_keep = torch.tensor([[0, 1, 0], [1, 0, 1]]).bool()
    assert masked_std(grid, grid_keep, dim=0) == pytest.approx([0, 0, 0], abs=1e-4)
    assert masked_std(grid, grid_keep, dim=1) == pytest.approx([0, 1], abs=1e-4)

    # compare against an explicit calculation that NaNs out masked entries
    rand_floats = torch.rand(3, 4, 5)
    rand_masks = torch.randint(0, 2, (3, 4, 5)).bool()
    for sample, sample_mask in zip(rand_floats, rand_masks, strict=False):
        nan_filled = torch.where(sample_mask, sample, torch.tensor(float("nan")))
        for dim in (0, 1):
            result = masked_std(sample, sample_mask, dim=dim)
            center = nan_filled.nanmean(dim=dim).unsqueeze(dim=dim)
            reference = (nan_filled - center).pow(2).nanmean(dim=dim).sqrt()

            assert result == pytest.approx(reference, abs=1e-4, nan_ok=True)
/examples/inputs/raw/006197ce31efdef23c82f0f6cec08ac0b0febdce.poscar: -------------------------------------------------------------------------------- 1 | -.39938747E+03 '060/ht.task.triolith--default.006197ce31efdef23c82f0f6cec08ac0b0febdce.cleanup.0.unclaimed.2.finished/ht.run.2016-01-14_17.42.34' ['(Tag) original structure: As2Pr4S9', '(Tag) original structure path: structures/all/matches/ternaries/60/8_60_1cA1dB2dC4dA.cif', '(Tag) type: ternary'] 2 | 1 3 | 19.735812 0.000000 0.000000 4 | 0.000000 6.101939 0.000000 5 | 0.000000 0.000000 5.570787 6 | Zn N Hf 7 | 16 36 8 8 | Direct 9 | 0.43941653 0.35413699 0.14372104 10 | 0.56058347 0.35413699 0.35627896 11 | 0.06058347 0.14586301 0.64372104 12 | 0.56058347 0.64586301 0.85627896 13 | 0.06058347 0.85413699 0.14372104 14 | 0.43941653 0.64586301 0.64372104 15 | 0.93941653 0.85413699 0.35627896 16 | 0.93941653 0.14586301 0.85627896 17 | 0.28716220 0.40737164 0.21147903 18 | 0.71283780 0.40737164 0.28852097 19 | 0.21283780 0.09262836 0.71147903 20 | 0.71283780 0.59262836 0.78852097 21 | 0.21283780 0.90737164 0.21147903 22 | 0.28716220 0.59262836 0.71147903 23 | 0.78716220 0.90737164 0.28852097 24 | 0.78716220 0.09262836 0.78852097 25 | -0.00000000 0.13213508 0.25000000 26 | 0.50000000 0.36786492 0.75000000 27 | -0.00000000 0.86786492 0.75000000 28 | 0.50000000 0.63213508 0.25000000 29 | 0.28161480 0.90975240 0.50653017 30 | 0.71838520 0.90975240 0.99346983 31 | 0.21838520 0.59024760 0.00653017 32 | 0.71838520 0.09024760 0.49346983 33 | 0.21838520 0.40975240 0.50653017 34 | 0.28161480 0.09024760 0.00653017 35 | 0.78161480 0.40975240 0.99346983 36 | 0.78161480 0.59024760 0.49346983 37 | 0.46760771 0.05450945 0.24347796 38 | 0.53239229 0.05450945 0.25652204 39 | 0.03239229 0.44549055 0.74347796 40 | 0.53239229 0.94549055 0.75652204 41 | 0.03239229 0.55450945 0.24347796 42 | 0.46760771 0.94549055 0.74347796 43 | 0.96760771 0.55450945 0.25652204 44 | 0.96760771 0.44549055 0.75652204 45 | 0.34446810 0.04626765 0.98351114 46 
"""Tests for the metric printing helpers in aviary.utils."""

import numpy as np

from aviary.utils import print_metrics_classification, print_metrics_regression

# Seeded random data shared by all tests below. The order of the rng calls is
# kept fixed so the generated values do not change.
rng = np.random.RandomState(42)

# Regression data: predictions and targets are noisy copies of the same signal.
xs = rng.rand(100)
y_pred = xs + 0.1 * rng.normal(size=100)
y_true = xs + 0.1 * rng.normal(size=100)

# Classification data: binary labels plus clipped, noisy probabilities.
y_binary = rng.choice([0, 1], 100)
y_proba = np.clip(y_binary - 0.1 * rng.normal(scale=5, size=100), 0.1, 0.9)

# NOTE binary clf is handled as a multi-class clf problem, so the probabilities
# get a trailing class axis (y_proba and 1 - y_proba) and a leading model axis.
y_probs = np.stack([y_proba, 1 - y_proba], axis=-1)[None]


def _stdout_of(capsys):
    """Return captured stdout and its newline-split lines; stderr must be empty."""
    captured = capsys.readouterr()
    assert captured.err == ""
    return captured.out, captured.out.split("\n")


def test_regression_metrics(capsys):
    """A single model prints one block with R2, MAE and RMSE."""
    print_metrics_regression(y_true, y_pred[None, :])
    out, lines = _stdout_of(capsys)
    assert len(lines) == 5
    assert out.startswith("Model Performance Metrics:\nR2 Score: ")
    assert lines[2].startswith("MAE: ")
    assert lines[3].startswith("RMSE: ")


def test_regression_metrics_ensemble(capsys):
    """An ensemble additionally prints +/- spreads and an ensemble block."""
    # simulate a 2-model ensemble by duplicating predictions along the 0-axis
    y_preds = np.tile(y_pred, (2, 1))
    print_metrics_regression(y_true, y_preds)
    out, lines = _stdout_of(capsys)
    assert len(lines) == 10
    assert out.startswith("Model Performance Metrics:\nR2 Score: ")
    for line_idx, prefix in ((2, "MAE: "), (3, "RMSE: ")):
        assert lines[line_idx].startswith(prefix)
        assert "+/-" in lines[line_idx]
    assert lines[5].startswith("Ensemble Performance Metrics:")


def test_classification_metrics(capsys):
    """A single classifier prints accuracy, ROC-AUC and weighted P/R/F."""
    print_metrics_classification(y_binary, y_probs)
    out, lines = _stdout_of(capsys)
    assert len(lines) == 8
    assert out.startswith("\nModel Performance Metrics:\nAccuracy")
    assert lines[3].startswith("ROC-AUC")
    assert lines[4].startswith("Weighted Precision")
    assert lines[5].startswith("Weighted Recall")
    assert lines[6].startswith("Weighted F-score")


def test_classification_metrics_ensemble(capsys):
    """An ensemble of classifiers adds +/- spreads and an ensemble block."""
    y_probs_ens = np.tile(y_probs, (2, 1, 1))
    print_metrics_classification(y_binary, y_probs_ens)
    out, lines = _stdout_of(capsys)
    assert len(lines) == 15
    assert out.startswith("\nModel Performance Metrics:\nAccuracy")
    expected_prefixes = (
        (3, "ROC-AUC"),
        (4, "Weighted Precision"),
        (5, "Weighted Recall"),
        (6, "Weighted F-score"),
    )
    for line_idx, prefix in expected_prefixes:
        assert lines[line_idx].startswith(prefix)
        assert "+/-" in lines[line_idx]
    assert lines[8].startswith("Ensemble Performance Metrics:")
def scatter_reduce(src, index, dim=-1, dim_size=None, reduce="sum"):
    """Performs a scatter-reduce operation on the input tensor.

    This function scatters the elements from the source tensor (src) into a new tensor
    of shape determined by dim_size along the specified dimension (dim), using the
    given reduction method. It's compatible with autograd for gradient computation.

    NOTE this function was written by Claude 3.5 Sonnet.

    Args:
        src (torch.Tensor): The source tensor.
        index (torch.Tensor): The indices of elements to scatter. Must be 1D or have
            the same number of dimensions as src. A 1D index is broadcast along
            every dimension of src other than `dim`.
        dim (int, optional): The axis along which to index. Defaults to -1.
        dim_size (int, optional): The size of the output tensor's dimension `dim`.
            If None, it's inferred as index.max().item() + 1 (0 for an empty index).
            Defaults to None.
        reduce (str, optional): The reduction operation to perform.
            Options: "sum", "mean", "amax", "max", "amin", "min", "prod".
            Defaults to "sum".

    Returns:
        torch.Tensor: The output tensor after the scatter-reduce operation. Slots that
            receive no element hold 0 for "sum"/"mean", -inf for "amax"/"max",
            +inf for "amin"/"min" and 1 for "prod".

    Raises:
        ValueError: If an unsupported reduction method is specified.
        RuntimeError: If index and src tensors are incompatible.

    Example:
        >>> src = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0])
        >>> index = torch.tensor([0, 1, 0, 1, 2])
        >>> scatter_reduce(src, index, dim=0, reduce="sum")
        tensor([4., 6., 5.])
    """
    if dim_size is None:
        # max() is undefined for an empty tensor; nothing to scatter means size 0.
        dim_size = index.max().item() + 1 if index.numel() > 0 else 0

    # Prepare the output tensor shape
    shape = list(src.shape)
    shape[dim] = dim_size

    # Ensure index has the same number of dimensions as src
    if index.dim() != src.dim():
        if index.dim() != 1:
            raise RuntimeError(
                "Index tensor must be 1D or have the same number of dimensions "
                f"as src tensor. {index.shape=} != {src.shape=}"
            )
        # Place the index values along `dim` (not always along axis 0) and broadcast
        # over the remaining axes so that any value of `dim` is handled.
        view_shape = [1] * src.dim()
        view_shape[dim] = -1
        index = index.view(view_shape).expand_as(src)

    # Perform scatter_reduce operation
    if reduce in ("sum", "mean"):
        out = torch.zeros(shape, dtype=src.dtype, device=src.device)
        out = out.scatter_add(dim, index, src)
        if reduce == "mean":
            count = torch.zeros(shape, dtype=src.dtype, device=src.device)
            count = count.scatter_add(dim, index, torch.ones_like(src))
            out = out / (count + (count == 0).float())  # avoid division by zero
    elif reduce in ("amax", "max"):
        # Plain Tensor.scatter keeps an arbitrary element when indices repeat, so a
        # true reduction is required. Starting from -inf with include_self=True
        # leaves untouched slots at -inf.
        out = torch.full(shape, float("-inf"), dtype=src.dtype, device=src.device)
        out = out.scatter_reduce(dim, index, src, reduce="amax", include_self=True)
    elif reduce in ("amin", "min"):
        out = torch.full(shape, float("inf"), dtype=src.dtype, device=src.device)
        out = out.scatter_reduce(dim, index, src, reduce="amin", include_self=True)
    elif reduce == "prod":
        # Multiplicative identity as the starting value; replaces the deprecated
        # Tensor.scatter(..., reduce="multiply").
        out = torch.ones(shape, dtype=src.dtype, device=src.device)
        out = out.scatter_reduce(dim, index, src, reduce="prod", include_self=True)
    else:
        raise ValueError(f"Unsupported reduction method: {reduce}")

    return out
/examples/inputs/raw/0069751a4ba575438917a8b4e5a565b686388f6e.poscar: -------------------------------------------------------------------------------- 1 | -.38833339E+03 '072/ht.task.triolith--default.0069751a4ba575438917a8b4e5a565b686388f6e.cleanup.0.unclaimed.2.finished/ht.run.2016-02-05_08.20.36' ['(Tag) original structure: GeNa4Se4', '(Tag) original structure path: structures/all/matches/ternaries/62/13_62_2cA2cB2dC3dB4cC.cif', '(Tag) type: ternary'] 2 | 1 3 | 21.954219 0.000000 0.000000 4 | 0.000000 7.464043 0.000000 5 | 0.000000 -0.000000 5.699575 6 | Zn Hf N 7 | 32 8 32 8 | Direct 9 | 0.15188087 0.25000000 0.37123552 10 | 0.34811913 0.75000000 0.87123552 11 | 0.84811913 0.75000000 0.62876448 12 | 0.65188087 0.25000000 0.12876448 13 | 0.07011749 0.75000000 0.53590908 14 | 0.42988251 0.25000000 0.03590908 15 | 0.92988251 0.25000000 0.46409092 16 | 0.57011749 0.75000000 0.96409092 17 | 0.22545090 0.58089851 0.61111204 18 | 0.27454910 0.08089851 0.11111204 19 | 0.27454910 0.41910149 0.11111204 20 | 0.77454910 0.08089851 0.38888796 21 | 0.77454910 0.41910149 0.38888796 22 | 0.72545090 0.91910149 0.88888796 23 | 0.72545090 0.58089851 0.88888796 24 | 0.22545090 0.91910149 0.61111204 25 | 0.46623328 0.53080619 0.77253392 26 | 0.03376672 0.03080619 0.27253392 27 | 0.03376672 0.46919381 0.27253392 28 | 0.53376672 0.03080619 0.22746608 29 | 0.53376672 0.46919381 0.22746608 30 | 0.96623328 0.96919381 0.72746608 31 | 0.96623328 0.53080619 0.72746608 32 | 0.46623328 0.96919381 0.77253392 33 | 0.36582353 0.48136016 0.38535976 34 | 0.13417647 0.98136016 0.88535976 35 | 0.13417647 0.51863984 0.88535976 36 | 0.63417647 0.98136016 0.61464024 37 | 0.63417647 0.51863984 0.61464024 38 | 0.86582353 0.01863984 0.11464024 39 | 0.86582353 0.48136016 0.11464024 40 | 0.36582353 0.01863984 0.38535976 41 | 0.22946066 0.75000000 0.15437777 42 | 0.27053934 0.25000000 0.65437777 43 | 0.77053934 0.25000000 0.84562223 44 | 0.72946066 0.75000000 0.34562223 45 | 0.44935397 0.75000000 0.30458592 
46 | 0.05064603 0.25000000 0.80458592 47 | 0.55064603 0.25000000 0.69541408 48 | 0.94935397 0.75000000 0.19541408 49 | 0.32504633 0.75000000 0.23191919 50 | 0.17495367 0.25000000 0.73191919 51 | 0.67495367 0.25000000 0.76808081 52 | 0.82504633 0.75000000 0.26808081 53 | 0.31700735 0.25000000 0.33328863 54 | 0.18299265 0.75000000 0.83328863 55 | 0.68299265 0.75000000 0.66671137 56 | 0.81700735 0.25000000 0.16671137 57 | 0.29931688 0.49244928 0.78952729 58 | 0.20068312 0.99244928 0.28952729 59 | 0.20068312 0.50755072 0.28952729 60 | 0.70068312 0.99244928 0.21047271 61 | 0.70068312 0.50755072 0.21047271 62 | 0.79931688 0.00755072 0.71047271 63 | 0.79931688 0.49244928 0.71047271 64 | 0.29931688 0.00755072 0.78952729 65 | 0.45714683 0.25000000 0.69923780 66 | 0.04285317 0.75000000 0.19923780 67 | 0.54285317 0.75000000 0.30076220 68 | 0.95714683 0.25000000 0.80076220 69 | 0.42259913 0.75000000 0.66077824 70 | 0.07740087 0.25000000 0.16077824 71 | 0.57740087 0.25000000 0.33922176 72 | 0.92259913 0.75000000 0.83922176 73 | 0.06081497 0.50723160 0.66631613 74 | 0.43918503 0.00723160 0.16631613 75 | 0.43918503 0.49276840 0.16631613 76 | 0.93918503 0.00723160 0.33368387 77 | 0.93918503 0.49276840 0.33368387 78 | 0.56081497 0.99276840 0.83368387 79 | 0.56081497 0.50723160 0.83368387 80 | 0.06081497 0.99276840 0.66631613 81 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "aviary-models" 3 | version = "1.2.1" 4 | description = "A collection of machine learning models for materials discovery" 5 | authors = [{ name = "Rhys Goodall", email = "rhys.goodall@outlook.com" }] 6 | readme = "README.md" 7 | license = { file = "LICENSE" } 8 | keywords = [ 9 | "Graph Neural Network", 10 | "Machine Learning", 11 | "Materials Discovery", 12 | "Materials Informatics", 13 | "Materials Science", 14 | "Roost", 15 | "Self-Attention", 16 | 
"Transformer", 17 | "Wren", 18 | "Wyckoff positions", 19 | ] 20 | classifiers = [ 21 | "Intended Audience :: Science/Research", 22 | "License :: OSI Approved :: MIT License", 23 | "Operating System :: OS Independent", 24 | "Programming Language :: Python :: 3.10", 25 | "Programming Language :: Python :: 3.11", 26 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 27 | "Topic :: Scientific/Engineering :: Chemistry", 28 | "Topic :: Scientific/Engineering :: Physics", 29 | ] 30 | 31 | requires-python = ">=3.10" 32 | dependencies = [ 33 | "numpy>=2,<3", 34 | "pandas", 35 | "pymatgen>=2025.4.10", 36 | "scikit_learn", 37 | "tensorboard", 38 | "torch>=2.3.0", 39 | "tqdm", 40 | "typing-extensions", 41 | "wandb", 42 | ] 43 | 44 | [project.urls] 45 | Repo = "https://github.com/CompRhys/aviary" 46 | 47 | [project.optional-dependencies] 48 | test = ["matminer", "moyopy>=0.3.3", "pytest", "pytest-cov"] 49 | moyopy = ["moyopy>=0.3.3"] 50 | 51 | [build-system] 52 | requires = ["uv_build>=0.7.5"] 53 | build-backend = "uv_build" 54 | 55 | [tool.uv.build-backend] 56 | module-name = "aviary" 57 | module-root = "" 58 | 59 | [tool.pytest.ini_options] 60 | testpaths = ["tests"] 61 | addopts = "-p no:warnings" 62 | 63 | [tool.mypy] 64 | no_implicit_optional = false 65 | 66 | [tool.ruff] 67 | line-length = 90 68 | target-version = "py310" 69 | output-format = "concise" 70 | 71 | [tool.ruff.lint] 72 | select = [ 73 | "B", # flake8-bugbear 74 | "C4", # flake8-comprehensions 75 | "D", # pydocstyle 76 | "E", # pycodestyle error 77 | "EXE", # flake8-executable 78 | "F", # pyflakes 79 | "FA", # flake8-future-annotations 80 | "FLY", # flynt 81 | "I", # isort 82 | "ICN", # flake8-import-conventions 83 | "ISC", # flake8-implicit-str-concat 84 | "PD", # pandas-vet 85 | "PERF", # perflint 86 | "PIE", # flake8-pie 87 | "PL", # pylint 88 | "PT", # flake8-pytest-style 89 | "PYI", # flakes8-pyi 90 | "Q", # flake8-quotes 91 | "RET", # flake8-return 92 | "RSE", # flake8-raise 93 | "RUF", # 
Ruff-specific rules 94 | "SIM", # flake8-simplify 95 | "SLOT", # flake8-slots 96 | "TCH", # flake8-type-checking 97 | "TID", # tidy imports 98 | "TID", # flake8-tidy-imports 99 | "UP", # pyupgrade 100 | "W", # pycodestyle warning 101 | "YTT", # flake8-2020 102 | ] 103 | ignore = [ 104 | "C408", # Unnecessary dict call - rewrite as a literal 105 | "D100", # Missing docstring in public module 106 | "D104", # Missing docstring in public package 107 | "D105", # Missing docstring in magic method 108 | "D205", # 1 blank line required between summary line and description 109 | "E731", # Do not assign a lambda expression, use a def 110 | "ISC001", 111 | "PD901", # pandas-df-variable-name 112 | "PLC0415", 113 | "PLR", # pylint refactor 114 | "PT006", # pytest-parametrize-names-wrong-type 115 | ] 116 | pydocstyle.convention = "google" 117 | isort.known-third-party = ["wandb"] 118 | 119 | [tool.ruff.lint.per-file-ignores] 120 | "tests/*" = ["D"] 121 | "examples/notebooks/*.py" = ["E402"] 122 | -------------------------------------------------------------------------------- /examples/inputs/raw/0ad3ae37a9fb7eab9189ff1cb384b9d2dd26bbf9.poscar: -------------------------------------------------------------------------------- 1 | -.61080027E+03 '078/ht.task.gamma--default.0ad3ae37a9fb7eab9189ff1cb384b9d2dd26bbf9.cleanup.0.unclaimed.3.finished/ht.run.2016-02-12_22.11.44' ['(Tag) original structure: Ba11O24Os4', '(Tag) original structure path: structures/all/matches/ternaries/88/12_88_1aA1cB1dB1eA2fA6fC.cif', '(Tag) type: ternary'] 2 | 1 3 | 10.139712 0.000001 0.000000 4 | -0.000001 10.139712 -0.000000 5 | 5.069855 5.069856 8.570492 6 | Ti Zn N 7 | 22 8 48 8 | Direct 9 | 0.87500000 0.12500000 0.25000000 10 | 0.12500000 0.87500000 0.75000000 11 | 0.85706887 0.60706887 0.28586226 12 | 0.60706887 0.85706887 0.78586226 13 | 0.39293113 0.14293113 0.21413774 14 | 0.14293113 0.39293113 0.71413774 15 | 0.36659200 0.17340598 0.69506440 16 | 0.17340598 0.93834360 0.19506440 17 | 0.36847038 
0.63340800 0.80493560 18 | 0.06165640 0.36847038 0.30493560 19 | 0.93834360 0.63152962 0.69506440 20 | 0.63340800 0.82659402 0.30493560 21 | 0.82659402 0.06165640 0.80493560 22 | 0.63152962 0.36659200 0.19506440 23 | 0.64923651 0.69033378 0.13171933 24 | 0.69033378 0.21904416 0.63171933 25 | 0.32205310 0.35076349 0.36828067 26 | 0.78095584 0.32205310 0.86828067 27 | 0.21904416 0.67794690 0.13171933 28 | 0.35076349 0.30966622 0.86828067 29 | 0.30966622 0.78095584 0.36828067 30 | 0.67794690 0.64923651 0.63171933 31 | -0.00000000 0.00000000 0.00000000 32 | -0.00000000 0.00000000 0.50000000 33 | 0.50000000 0.00000000 0.50000000 34 | -0.00000000 0.50000000 0.00000000 35 | 0.50000000 0.50000000 0.00000000 36 | 0.50000000 0.50000000 0.50000000 37 | -0.00000000 0.50000000 0.50000000 38 | 0.50000000 0.00000000 0.00000000 39 | 0.53440356 0.79247721 0.60695497 40 | 0.79247721 0.85864147 0.10695497 41 | 0.89943218 0.46559644 0.89304503 42 | 0.14135853 0.89943218 0.39304503 43 | 0.85864147 0.10056782 0.60695497 44 | 0.46559644 0.20752279 0.39304503 45 | 0.20752279 0.14135853 0.89304503 46 | 0.10056782 0.53440356 0.10695497 47 | 0.27999134 0.82335358 0.57883182 48 | 0.82335358 0.14117683 0.07883182 49 | 0.90218540 0.72000866 0.92116818 50 | 0.85882317 0.90218540 0.42116818 51 | 0.14117683 0.09781460 0.57883182 52 | 0.72000866 0.17664642 0.42116818 53 | 0.17664642 0.85882317 0.92116818 54 | 0.09781460 0.27999134 0.07883182 55 | 0.90435227 0.82714494 0.81183548 56 | 0.82714494 0.28381226 0.31183548 57 | 0.13898042 0.09564773 0.68816452 58 | 0.71618774 0.13898042 0.18816452 59 | 0.28381226 0.86101958 0.81183548 60 | 0.09564773 0.17285506 0.18816452 61 | 0.17285506 0.71618774 0.68816452 62 | 0.86101958 0.90435227 0.31183548 63 | 0.18007599 0.49398734 0.32307195 64 | 0.49398734 0.49685207 0.82307195 65 | 0.31705928 0.81992401 0.17692805 66 | 0.50314793 0.31705928 0.67692805 67 | 0.49685207 0.68294072 0.32307195 68 | 0.81992401 0.50601266 0.67692805 69 | 0.50601266 0.50314793 
0.17692805 70 | 0.68294072 0.18007599 0.82307195 71 | 0.41501395 0.69820587 0.93773867 72 | 0.69820587 0.64724739 0.43773867 73 | 0.13594453 0.58498605 0.56226133 74 | 0.35275261 0.13594453 0.06226133 75 | 0.64724739 0.86405547 0.93773867 76 | 0.58498605 0.30179413 0.06226133 77 | 0.30179413 0.35275261 0.56226133 78 | 0.86405547 0.41501395 0.43773867 79 | 0.74435821 0.52432675 0.21225406 80 | 0.52432675 0.04338773 0.71225406 81 | 0.23658080 0.25564179 0.28774594 82 | 0.95661227 0.23658080 0.78774594 83 | 0.04338773 0.76341920 0.21225406 84 | 0.25564179 0.47567325 0.78774594 85 | 0.47567325 0.95661227 0.28774594 86 | 0.76341920 0.74435821 0.71225406 87 | -------------------------------------------------------------------------------- /examples/inputs/raw/0a10bbf9e7cd3d8da2469f15e06ceca7d8d2af5d.poscar: -------------------------------------------------------------------------------- 1 | -.59711780E+03 '080/ht.task.gamma--default.0a10bbf9e7cd3d8da2469f15e06ceca7d8d2af5d.cleanup.0.unclaimed.3.finished/ht.run.2016-01-04_11.35.30' ['(Tag) original structure: N11P6Rb3', '(Tag) original structure path: structures/all/matches/ternaries/213/6_213_1aA1cB1cA1dB1eB1eC.cif', '(Tag) type: ternary'] 2 | 1 3 | 10.640122 -0.000335 -0.000327 4 | -0.000336 10.647665 0.001083 5 | -0.000326 0.001083 10.648228 6 | Zn Ti N 7 | 12 24 44 8 | Direct 9 | 0.37508445 0.37508887 0.37504493 10 | 0.62494728 0.87508013 0.12491953 11 | 0.12492617 0.62492222 0.87505292 12 | 0.87507172 0.12492622 0.62485794 13 | 0.72704563 0.72702715 0.72698344 14 | 0.27292335 0.22692544 0.77304857 15 | 0.97704783 0.52297581 0.47696343 16 | 0.77294307 0.27297746 0.22699959 17 | 0.02286844 0.02290848 0.02295900 18 | 0.22706647 0.77295062 0.27297118 19 | 0.52287739 0.47709822 0.97699114 20 | 0.47709683 0.97708896 0.52300783 21 | 0.06997315 0.27086890 0.27986617 22 | 0.72011622 0.56998201 0.22902245 23 | 0.31997055 0.97004664 0.02086265 24 | 0.27990750 0.06991229 0.27104539 25 | 0.22872668 0.72045335 0.56996215 26 | 
0.46999786 0.47910899 0.68011386 27 | 0.27120699 0.27942688 0.07017936 28 | 0.56998693 0.22908425 0.72016603 29 | 0.97876892 0.82007522 0.52957564 30 | 0.72869531 0.77947918 0.42980911 31 | 0.52129812 0.17987130 0.02962728 32 | 0.67988556 0.47028440 0.47908583 33 | 0.17998432 0.02971812 0.52094208 34 | 0.02994635 0.52092573 0.18011369 35 | 0.02140657 0.32020134 0.97055809 36 | 0.22024769 0.92998691 0.77095846 37 | 0.93001137 0.77095766 0.22030524 38 | 0.43003656 0.72905476 0.77967244 39 | 0.77125825 0.22056725 0.92993063 40 | 0.52993122 0.97910780 0.81993856 41 | 0.97008829 0.02089702 0.31991814 42 | 0.82009873 0.52987022 0.97918264 43 | 0.47866263 0.67994669 0.47039571 44 | 0.77996067 0.43009416 0.72906280 45 | 0.94438375 0.42715485 0.32289350 46 | 0.67710879 0.44437890 0.07278568 47 | 0.19437105 0.92712974 0.17722607 48 | 0.32291109 0.94432523 0.42718949 49 | 0.07271000 0.67721224 0.44424549 50 | 0.42710803 0.32279449 0.80565916 51 | 0.42733405 0.32275592 0.94426201 52 | 0.44439866 0.07285410 0.67712669 53 | 0.82264647 0.69422908 0.57276947 54 | 0.57269570 0.82281734 0.55570719 55 | 0.67729931 0.30576928 0.07274910 56 | 0.80557950 0.42712590 0.32276359 57 | 0.30560411 0.07290852 0.67715989 58 | 0.07289180 0.67716750 0.30568349 59 | 0.17730793 0.19426276 0.92727206 60 | 0.17711398 0.05559526 0.92719094 61 | 0.05562500 0.92718605 0.17712709 62 | 0.55560647 0.57281449 0.82283937 63 | 0.92738006 0.17724268 0.05575621 64 | 0.57291827 0.82288582 0.69428940 65 | 0.92710292 0.17714834 0.19431068 66 | 0.69442518 0.57285312 0.82276475 67 | 0.32278145 0.80580202 0.42709828 68 | 0.82285963 0.55565492 0.57280392 69 | 0.12500800 0.16511079 0.41532897 70 | 0.58457373 0.62518218 0.33495937 71 | 0.37501382 0.83481996 0.91517228 72 | 0.41546897 0.12511502 0.16510698 73 | 0.33450992 0.58487564 0.62474184 74 | 0.16547506 0.41494480 0.12485987 75 | 0.62494007 0.33488535 0.58467423 76 | 0.08454666 0.87501842 0.66510908 77 | 0.83453232 0.91497982 0.37511309 78 | 0.91541597 0.37494171 
0.83503988 79 | 0.87503930 0.66517762 0.08477465 80 | 0.66541513 0.08507066 0.87518667 81 | 0.17719115 0.17712798 0.17720488 82 | 0.82277684 0.67702595 0.32291383 83 | 0.42716227 0.07267635 0.92722819 84 | 0.32279849 0.82299536 0.67703616 85 | 0.57272894 0.57281672 0.57292929 86 | 0.67724781 0.32293006 0.82278498 87 | 0.07280887 0.92705742 0.42716620 88 | 0.92711841 0.42728898 0.07290209 89 | -------------------------------------------------------------------------------- /examples/inputs/raw/0ad7c1ba7ec6a37cae0523bbd1c2a75f2d56a3f2.poscar: -------------------------------------------------------------------------------- 1 | -.50650847E+03 '080/ht.task.gamma--default.0ad7c1ba7ec6a37cae0523bbd1c2a75f2d56a3f2.cleanup.0.unclaimed.3.finished/ht.run.2016-01-03_13.19.29' ['(Tag) original structure: Fe5O12Y3', '(Tag) original structure path: structures/all/matches/ternaries/148/16_148_1aA1bA1dA1eA2fA2fB8fC.cif', '(Tag) type: ternary'] 2 | 1 3 | 10.216080 -0.000000 3.537716 4 | -5.108039 8.847384 3.537716 5 | -5.108039 -8.847384 3.537716 6 | Ti Zn N 7 | 12 20 48 8 | Direct 9 | 0.36754946 0.12846507 0.24712369 10 | 0.87153493 0.75287631 0.63245054 11 | 0.75287631 0.63245054 0.87153493 12 | 0.24712369 0.36754946 0.12846507 13 | 0.12846507 0.24712369 0.36754946 14 | 0.63245054 0.87153493 0.75287631 15 | 0.87840662 0.62291357 0.25024206 16 | 0.37708643 0.74975794 0.12159338 17 | 0.74975794 0.12159338 0.37708643 18 | 0.25024206 0.87840662 0.62291357 19 | 0.62291357 0.25024206 0.87840662 20 | 0.12159338 0.37708643 0.74975794 21 | 0.00000000 0.00000000 0.00000000 22 | 0.50000000 0.50000000 0.50000000 23 | 0.50000000 0.00000000 0.00000000 24 | 0.00000000 0.00000000 0.50000000 25 | 0.00000000 0.50000000 0.00000000 26 | 0.50000000 0.50000000 0.00000000 27 | 0.50000000 0.00000000 0.50000000 28 | 0.00000000 0.50000000 0.50000000 29 | 0.62493308 0.37417690 0.24949405 30 | 0.62582310 0.75050595 0.37506692 31 | 0.75050595 0.37506692 0.62582310 32 | 0.24949405 0.62493308 0.37417690 33 | 
0.37417690 0.24949405 0.62493308 34 | 0.37506692 0.62582310 0.75050595 35 | 0.87025560 0.12728315 0.74740531 36 | 0.87271685 0.25259469 0.12974440 37 | 0.25259469 0.12974440 0.87271685 38 | 0.74740531 0.87025560 0.12728315 39 | 0.12728315 0.74740531 0.87025560 40 | 0.12974440 0.87271685 0.25259469 41 | 0.27851882 0.17126274 0.07800032 42 | 0.82873726 0.92199968 0.72148118 43 | 0.92199968 0.72148118 0.82873726 44 | 0.07800032 0.27851882 0.17126274 45 | 0.17126274 0.07800032 0.27851882 46 | 0.72148118 0.82873726 0.92199968 47 | 0.61707782 0.20288302 0.29370045 48 | 0.79711698 0.70629955 0.38292218 49 | 0.70629955 0.38292218 0.79711698 50 | 0.29370045 0.61707782 0.20288302 51 | 0.20288302 0.29370045 0.61707782 52 | 0.38292218 0.79711698 0.70629955 53 | 0.17814519 0.59047415 0.89135218 54 | 0.40952585 0.10864782 0.82185481 55 | 0.10864782 0.82185481 0.40952585 56 | 0.89135218 0.17814519 0.59047415 57 | 0.59047415 0.89135218 0.17814519 58 | 0.82185481 0.40952585 0.10864782 59 | 0.91394805 0.09170363 0.30370414 60 | 0.90829637 0.69629586 0.08605195 61 | 0.69629585 0.08605195 0.90829637 62 | 0.30370415 0.91394805 0.09170363 63 | 0.09170363 0.30370414 0.91394805 64 | 0.08605195 0.90829637 0.69629586 65 | 0.31496298 0.20893908 0.40710560 66 | 0.79106092 0.59289440 0.68503702 67 | 0.59289440 0.68503702 0.79106092 68 | 0.40710560 0.31496298 0.20893908 69 | 0.20893908 0.40710560 0.31496298 70 | 0.68503702 0.79106092 0.59289440 71 | 0.30086016 0.88715915 0.20512678 72 | 0.11284085 0.79487322 0.69913984 73 | 0.79487322 0.69913984 0.11284085 74 | 0.20512678 0.30086016 0.88715915 75 | 0.88715915 0.20512678 0.30086016 76 | 0.69913984 0.11284085 0.79487322 77 | 0.90727520 0.31850294 0.60566245 78 | 0.68149706 0.39433755 0.09272480 79 | 0.39433755 0.09272480 0.68149706 80 | 0.60566245 0.90727520 0.31850294 81 | 0.31850294 0.60566245 0.90727520 82 | 0.09272480 0.68149706 0.39433755 83 | 0.40249571 0.58352797 0.19640396 84 | 0.41647203 0.80359604 0.59750429 85 | 0.80359604 0.59750429 
0.41647203 86 | 0.19640396 0.40249571 0.58352797 87 | 0.58352797 0.19640396 0.40249571 88 | 0.59750429 0.41647203 0.80359604 89 | -------------------------------------------------------------------------------- /examples/matbench_example/train_wrenformer.py: -------------------------------------------------------------------------------- 1 | """Train a Wrenformer ensemble of size n_folds on collection of Matbench datasets.""" 2 | 3 | # %% 4 | import os 5 | import sys 6 | from pathlib import Path 7 | 8 | # Add the parent directory to system path 9 | sys.path.append(str(Path(__file__).parent.parent)) 10 | 11 | from datetime import datetime 12 | from itertools import product 13 | 14 | import wandb 15 | from matbench.metadata import mbv01_metadata 16 | from matbench.task import MatbenchTask 17 | from matbench_example.prepare_matbench_datasets import DATA_PATHS 18 | from matbench_example.trainer import train_wrenformer 19 | from matbench_example.utils import merge_json_on_disk, slurm_submit 20 | from matminer.utils.io import load_dataframe_from_json 21 | 22 | from aviary.core import TaskType 23 | 24 | MODULE_DIR = os.path.dirname(__file__) 25 | 26 | 27 | # %% 28 | epochs = 10 29 | folds = list(range(5)) 30 | timestamp = f"{datetime.now():%Y-%m-%d@%H-%M-%S}" 31 | today = timestamp.split("@")[0] 32 | # job_name unlike run_name doesn't include dataset and fold since not yet known without 33 | # SLURM_ARRAY_TASK_ID 34 | job_name = f"matbench-wrenformer-robust-{epochs=}" 35 | out_dir = f"{os.path.dirname(__file__)}/{today}-{job_name}" 36 | 37 | if "roost" in job_name.lower(): 38 | # deploy Roost on all tasks 39 | datasets = list(DATA_PATHS) 40 | else: 41 | # deploy Wren on structure tasks only 42 | datasets = [ 43 | k 44 | for k, v in mbv01_metadata.items() 45 | if v.input_type == "structure" and k in DATA_PATHS 46 | ] 47 | 48 | # NOTE: this script will run as is if you want to run it locally without slurm. 
49 | slurm_submit( 50 | job_name=job_name, 51 | partition="ampere", 52 | account="LEE-SL3-GPU", 53 | time="8:0:0", 54 | array=f"0-{len(datasets) * len(folds) - 1}", 55 | out_dir=out_dir, 56 | slurm_flags=("--nodes", "1", "--gpus-per-node", "1"), 57 | # prepend into sbatch script to source module command and load default env 58 | # for Ampere GPU partition before actual job command 59 | pre_cmd=". /etc/profile.d/modules.sh; module load rhel8/default-amp;", 60 | ) 61 | 62 | 63 | # %% 64 | slurm_array_task_id = int(os.getenv("SLURM_ARRAY_TASK_ID", "0")) 65 | print(f"Job started running {timestamp}") 66 | 67 | dataset_name, fold = list(product(datasets, folds))[slurm_array_task_id] 68 | print(f"{dataset_name=}") 69 | print(f"{fold=}") 70 | 71 | 72 | data_path = DATA_PATHS[dataset_name] 73 | id_col = "mbid" 74 | df = load_dataframe_from_json(data_path) 75 | df.index.name = id_col 76 | 77 | matbench_task = MatbenchTask(dataset_name, autoload=False) 78 | matbench_task.df = df 79 | 80 | target_col = matbench_task.metadata.target 81 | run_name = f"{job_name}-{dataset_name}-{fold=}-{target_col}" 82 | print(f"{run_name=}") 83 | task_type: TaskType = matbench_task.metadata.task_type 84 | 85 | train_df = matbench_task.get_train_and_val_data(fold, as_type="df") 86 | test_df = matbench_task.get_test_data(fold, as_type="df", include_target=True) 87 | 88 | wandb_path = None 89 | 90 | test_metrics, run_params, test_df = train_wrenformer( 91 | checkpoint=None, # None | 'local' | 'wandb' 92 | run_name=run_name, 93 | train_df=train_df, 94 | test_df=test_df, 95 | target_col=target_col, 96 | task_type=task_type, 97 | id_col=id_col, 98 | # set to None to disable logging 99 | wandb_path=wandb_path, 100 | run_params=dict(dataset=dataset_name, fold=fold), 101 | timestamp=timestamp, 102 | epochs=epochs, 103 | ) 104 | 105 | 106 | # %% 107 | # save model predictions to JSON 108 | preds_path = f"{MODULE_DIR}/model_preds/{timestamp}-{run_name}.json" 109 | os.makedirs(os.path.dirname(preds_path), 
exist_ok=True) 110 | 111 | # record model predictions 112 | test_df[id_col] = test_df.index 113 | preds_dict = test_df[[id_col, target_col, f"{target_col}_pred_0"]].to_dict(orient="list") 114 | merge_json_on_disk({dataset_name: {f"fold_{fold}": preds_dict}}, preds_path) 115 | 116 | # save model scores to JSON 117 | scores_path = f"{MODULE_DIR}/model_scores/{timestamp}-{run_name}.json" 118 | os.makedirs(os.path.dirname(scores_path), exist_ok=True) 119 | 120 | scores_dict = {dataset_name: {f"fold_{fold}": test_metrics}} 121 | scores_dict["params"] = run_params 122 | if wandb_path is not None: 123 | scores_dict["wandb_run"] = wandb.run.get_url() 124 | merge_json_on_disk(scores_dict, scores_path) 125 | 126 | print(f"scores for {fold = } of task {dataset_name} written to {scores_path}") 127 | -------------------------------------------------------------------------------- /aviary/networks.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Sequence 2 | 3 | from torch import Tensor, nn 4 | 5 | 6 | class SimpleNetwork(nn.Module): 7 | """Simple Feed Forward Neural Network.""" 8 | 9 | def __init__( 10 | self, 11 | input_dim: int, 12 | output_dim: int, 13 | hidden_layer_dims: Sequence[int], 14 | activation: type[nn.Module] = nn.LeakyReLU, 15 | batch_norm: bool = False, 16 | ) -> None: 17 | """Create a simple feed forward neural network. 18 | 19 | Args: 20 | input_dim (int): Number of input features 21 | output_dim (int): Number of output features 22 | hidden_layer_dims (list[int]): List of hidden layer sizes 23 | activation (type[nn.Module], optional): Which activation function to use. 24 | Defaults to nn.LeakyReLU. 25 | batch_norm (bool, optional): Whether to use batch_norm. Defaults to False. 
26 | """ 27 | super().__init__() 28 | 29 | dims = [input_dim, *list(hidden_layer_dims)] 30 | 31 | self.fcs = nn.ModuleList( 32 | nn.Linear(dims[idx], dims[idx + 1]) for idx in range(len(dims) - 1) 33 | ) 34 | 35 | if batch_norm: 36 | self.bns = nn.ModuleList( 37 | nn.BatchNorm1d(dims[idx + 1]) for idx in range(len(dims) - 1) 38 | ) 39 | else: 40 | self.bns = nn.ModuleList(nn.Identity() for _ in range(len(dims) - 1)) 41 | 42 | self.acts = nn.ModuleList(activation() for _ in range(len(dims) - 1)) 43 | 44 | self.fc_out = nn.Linear(dims[-1], output_dim) 45 | 46 | def forward(self, x: Tensor) -> Tensor: 47 | """Forward pass through network.""" 48 | for fc, bn, act in zip(self.fcs, self.bns, self.acts, strict=False): 49 | x = act(bn(fc(x))) 50 | 51 | return self.fc_out(x) 52 | 53 | def reset_parameters(self) -> None: 54 | """Reinitialize network weights using PyTorch defaults.""" 55 | for fc in self.fcs: 56 | fc.reset_parameters() 57 | 58 | self.fc_out.reset_parameters() 59 | 60 | def __repr__(self) -> str: 61 | input_dim = self.fcs[0].in_features 62 | output_dim = self.fc_out.out_features 63 | activation = type(self.acts[0]).__name__ 64 | return f"{type(self).__name__}({input_dim=}, {output_dim=}, {activation=})" 65 | 66 | 67 | class ResidualNetwork(nn.Module): 68 | """Feed forward Residual Neural Network.""" 69 | 70 | def __init__( 71 | self, 72 | input_dim: int, 73 | output_dim: int, 74 | hidden_layer_dims: Sequence[int], 75 | activation: type[nn.Module] = nn.ReLU, 76 | batch_norm: bool = False, 77 | ) -> None: 78 | """Create a feed forward neural network with skip connections. 79 | 80 | Args: 81 | input_dim (int): Number of input features 82 | output_dim (int): Number of output features 83 | hidden_layer_dims (list[int]): List of hidden layer sizes 84 | activation (type[nn.Module], optional): Which activation function to use. 85 | Defaults to nn.LeakyReLU. 86 | batch_norm (bool, optional): Whether to use batch_norm. Defaults to False. 
87 | """ 88 | super().__init__() 89 | 90 | dims = [input_dim, *list(hidden_layer_dims)] 91 | 92 | self.fcs = nn.ModuleList( 93 | nn.Linear(dims[idx], dims[idx + 1]) for idx in range(len(dims) - 1) 94 | ) 95 | 96 | if batch_norm: 97 | self.bns = nn.ModuleList( 98 | nn.BatchNorm1d(dims[idx + 1]) for idx in range(len(dims) - 1) 99 | ) 100 | else: 101 | self.bns = nn.ModuleList(nn.Identity() for _ in range(len(dims) - 1)) 102 | 103 | self.res_fcs = nn.ModuleList( 104 | nn.Linear(dims[idx], dims[idx + 1], bias=False) 105 | if (dims[idx] != dims[idx + 1]) 106 | else nn.Identity() 107 | for idx in range(len(dims) - 1) 108 | ) 109 | self.acts = nn.ModuleList(activation() for _ in range(len(dims) - 1)) 110 | 111 | self.fc_out = nn.Linear(dims[-1], output_dim) 112 | 113 | def forward(self, x: Tensor) -> Tensor: 114 | """Forward pass through network.""" 115 | for fc, bn, res_fc, act in zip( 116 | self.fcs, self.bns, self.res_fcs, self.acts, strict=False 117 | ): 118 | x = act(bn(fc(x))) + res_fc(x) 119 | 120 | return self.fc_out(x) 121 | 122 | def __repr__(self) -> str: 123 | input_dim = self.fcs[0].in_features 124 | output_dim = self.fc_out.out_features 125 | activation = type(self.acts[0]).__name__ 126 | return f"{type(self).__name__}({input_dim=}, {output_dim=}, {activation=})" 127 | -------------------------------------------------------------------------------- /tests/test_utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | import numpy as np 4 | import pandas as pd 5 | import pytest 6 | import torch 7 | 8 | from aviary.utils import get_element_embedding, get_metrics, get_sym_embedding 9 | 10 | 11 | @pytest.fixture 12 | def temp_element_embedding(tmp_path): 13 | embedding_data = { 14 | "H": [1.0, 2.0], 15 | "He": [3.0, 4.0], 16 | "Li": [5.0, 6.0], 17 | } 18 | path = tmp_path / "test_elem_embedding.json" 19 | with open(path, "w") as f: 20 | json.dump(embedding_data, f) 21 | return str(path) 22 | 23 | 24 | 
@pytest.fixture 25 | def temp_sym_embedding(tmp_path): 26 | embedding_data = { 27 | "1": {"a": [1.0, 2.0], "b": [3.0, 4.0]}, 28 | "2": {"c": [5.0, 6.0]}, 29 | } 30 | path = tmp_path / "test_sym_embedding.json" 31 | with open(path, "w") as f: 32 | json.dump(embedding_data, f) 33 | return str(path) 34 | 35 | 36 | def test_get_element_embedding_custom(temp_element_embedding): 37 | embedding = get_element_embedding(temp_element_embedding) 38 | assert isinstance(embedding, torch.nn.Embedding) 39 | assert embedding.weight.shape == (3 + 1, 2) # max_Z + 1, embedding_dim 40 | assert torch.allclose(embedding.weight[1], torch.tensor([1.0, 2.0])) # H 41 | assert torch.allclose(embedding.weight[2], torch.tensor([3.0, 4.0])) # He 42 | 43 | 44 | def test_get_element_embedding_builtin(): 45 | embedding = get_element_embedding("matscholar200") 46 | assert isinstance(embedding, torch.nn.Embedding) 47 | assert embedding.weight.shape[1] == 200 48 | 49 | 50 | def test_get_element_embedding_invalid(): 51 | with pytest.raises(ValueError, match="Invalid element embedding: invalid_embedding"): 52 | get_element_embedding("invalid_embedding") 53 | 54 | 55 | def test_get_sym_embedding_custom(temp_sym_embedding): 56 | embedding = get_sym_embedding(temp_sym_embedding) 57 | assert isinstance(embedding, torch.nn.Embedding) 58 | assert embedding.weight.shape == (3, 2) # total features, embedding_dim 59 | assert torch.allclose(embedding.weight[0], torch.tensor([1.0, 2.0])) 60 | assert torch.allclose(embedding.weight[1], torch.tensor([3.0, 4.0])) 61 | 62 | 63 | def test_get_sym_embedding_builtin(): 64 | embedding = get_sym_embedding("bra-alg-off") 65 | assert isinstance(embedding, torch.nn.Embedding) 66 | assert isinstance(embedding.weight, torch.Tensor) 67 | 68 | 69 | def test_get_sym_embedding_invalid(): 70 | with pytest.raises(ValueError, match="Invalid symmetry embedding: invalid_embedding"): 71 | get_sym_embedding("invalid_embedding") 72 | 73 | 74 | def test_regression_metrics(): 75 | targets = 
np.array([1.0, 2.0, 3.0, 4.0, 5.0]) 76 | predictions = np.array([1.1, 2.1, 3.1, 4.1, 5.1]) 77 | 78 | metrics = get_metrics(targets, predictions, "regression") 79 | 80 | assert set(metrics.keys()) == {"MAE", "RMSE", "R2"} 81 | assert metrics["MAE"] == pytest.approx(0.1, abs=1e-4) 82 | assert metrics["RMSE"] == pytest.approx(0.1, abs=1e-4) 83 | assert metrics["R2"] == pytest.approx(0.995, abs=1e-4) 84 | 85 | 86 | def test_classification_metrics(): 87 | targets = np.array([0, 1, 0, 1, 0]) 88 | # Probabilities for class 0 and 1 89 | predictions = np.array([[0.9, 0.1], [0.1, 0.9], [0.8, 0.2], [0.2, 0.8], [0.7, 0.3]]) 90 | 91 | metrics = get_metrics(targets, predictions, "classification") 92 | 93 | assert set(metrics.keys()) == {"accuracy", "balanced_accuracy", "F1", "ROCAUC"} 94 | assert metrics["accuracy"] == 1.0 95 | assert metrics["balanced_accuracy"] == 1.0 96 | assert metrics["F1"] == 1.0 97 | assert metrics["ROCAUC"] == 1.0 98 | 99 | 100 | def test_nan_handling(): 101 | targets = np.array([1.0, np.nan, 3.0, 4.0]) 102 | predictions = np.array([1.1, 2.1, np.nan, 4.1]) 103 | 104 | metrics = get_metrics(targets, predictions, "regression") 105 | assert not np.isnan(metrics["MAE"]) 106 | assert not np.isnan(metrics["RMSE"]) 107 | assert not np.isnan(metrics["R2"]) 108 | 109 | 110 | def test_pandas_input(): 111 | targets = pd.Series([1.0, 2.0, 3.0]) 112 | predictions = pd.Series([1.1, 2.1, 3.1]) 113 | 114 | metrics = get_metrics(targets, predictions, "regression") 115 | assert set(metrics.keys()) == {"MAE", "RMSE", "R2"} 116 | 117 | 118 | def test_precision(): 119 | targets = np.array([1.0, 2.0, 3.0]) 120 | predictions = np.array([1.12345, 2.12345, 3.12345]) 121 | 122 | metrics = get_metrics(targets, predictions, "regression", prec=2) 123 | assert all(len(str(v).split(".")[-1]) <= 2 for v in metrics.values()) 124 | 125 | 126 | def test_invalid_type(): 127 | targets = np.array([1.0, 2.0]) 128 | predictions = np.array([1.1, 2.1]) 129 | 130 | with pytest.raises(ValueError, 
match="Invalid task type: invalid_type"): 131 | get_metrics(targets, predictions, "invalid_type") 132 | 133 | 134 | def test_mismatched_shapes(): 135 | targets = np.array([0, 1, 0]) 136 | predictions = np.array([[0.9, 0.1], [0.1, 0.9]]) # Wrong shape 137 | 138 | with pytest.raises(ValueError): # noqa: PT011 139 | get_metrics(targets, predictions, "classification") 140 | -------------------------------------------------------------------------------- /aviary/data.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Callable, Iterator 2 | from dataclasses import dataclass 3 | 4 | import numpy as np 5 | import torch 6 | from torch import Tensor 7 | from typing_extensions import Self 8 | 9 | 10 | class Normalizer: 11 | """Normalize a Tensor and restore it later.""" 12 | 13 | def __init__(self) -> None: 14 | """Initialize Normalizer with mean 0 and std 1.""" 15 | self.mean = torch.tensor(0) 16 | self.std = torch.tensor(1) 17 | 18 | def fit(self, tensor: Tensor, dim: int = 0, keepdim: bool = False) -> None: 19 | """Compute the mean and standard deviation of the given tensor. 20 | 21 | Args: 22 | tensor (Tensor): Tensor to determine the mean and standard deviation over. 23 | dim (int, optional): Which dimension to take mean and standard deviation 24 | over. Defaults to 0. 25 | keepdim (bool, optional): Whether to keep the reduced dimension in Tensor. 26 | Defaults to False. 27 | """ 28 | self.mean = torch.mean(tensor, dim, keepdim) 29 | self.std = torch.std(tensor, dim, keepdim) 30 | 31 | def norm(self, tensor: Tensor) -> Tensor: 32 | """Normalize a Tensor. 33 | 34 | Args: 35 | tensor (Tensor): Tensor to be normalized 36 | 37 | Returns: 38 | Tensor: Normalized Tensor 39 | """ 40 | return (tensor - self.mean) / self.std 41 | 42 | def denorm(self, normed_tensor: Tensor) -> Tensor: 43 | """Restore normalized Tensor to original. 
44 | 45 | Args: 46 | normed_tensor (Tensor): Tensor to be restored 47 | 48 | Returns: 49 | Tensor: Restored Tensor 50 | """ 51 | return normed_tensor * self.std + self.mean 52 | 53 | def state_dict(self) -> dict[str, Tensor]: 54 | """Get Normalizer parameters mean and std. 55 | 56 | Returns: 57 | dict[str, Tensor]: Dictionary storing Normalizer parameters. 58 | """ 59 | return {"mean": self.mean, "std": self.std} 60 | 61 | def load_state_dict(self, state_dict: dict[str, Tensor]) -> None: 62 | """Overwrite Normalizer parameters given a new state_dict. 63 | 64 | Args: 65 | state_dict (dict[str, Tensor]): Dictionary storing Normalizer parameters. 66 | """ 67 | self.mean = state_dict["mean"].cpu() 68 | self.std = state_dict["std"].cpu() 69 | 70 | @classmethod 71 | def from_state_dict(cls, state_dict: dict[str, Tensor]) -> Self: 72 | """Create a new Normalizer given a state_dict. 73 | 74 | Args: 75 | state_dict (dict[str, Tensor]): Dictionary storing Normalizer parameters. 76 | 77 | Returns: 78 | Normalizer 79 | """ 80 | instance = cls() 81 | instance.mean = state_dict["mean"].cpu() 82 | instance.std = state_dict["std"].cpu() 83 | 84 | return instance 85 | 86 | 87 | @dataclass 88 | class InMemoryDataLoader: 89 | """In-memory DataLoader using array/tensor slicing to generate whole batches at 90 | once instead of sample-by-sample. 91 | Source: https://discuss.pytorch.org/t/27014/6. 92 | 93 | Args: 94 | *tensors: List of arrays or tensors. Must all have the same length in 95 | dimension 0. 96 | collate_fn (Callable): Should accept variadic list of tensors and 97 | output a minibatch of data ready for model consumption. 98 | batch_size (int, optional): Usually 64, 128 or 256. Can be larger for test set 99 | loaders to speedup inference. Defaults to 64. 100 | shuffle (bool, optional): If True, shuffle the data *in-place* whenever an 101 | iterator is created from this object. Defaults to False. 
102 | """ 103 | 104 | # each item must be indexable (usually torch.tensor, np.array or pd.Series) 105 | tensors: list[Tensor | np.ndarray] 106 | collate_fn: Callable 107 | batch_size: int = 64 108 | shuffle: bool = False 109 | 110 | def __post_init__(self): 111 | self.dataset_len = len(self.tensors[0]) 112 | if not all(len(t) == self.dataset_len for t in self.tensors): 113 | raise ValueError("All tensors must have the same length in dim 0") 114 | 115 | def __iter__(self) -> Iterator[tuple[Tensor, ...]]: 116 | self.indices = np.random.permutation(self.dataset_len) if self.shuffle else None 117 | self.current_idx = 0 118 | return self 119 | 120 | def __next__(self) -> tuple[Tensor, ...]: 121 | start_idx = self.current_idx 122 | if start_idx >= self.dataset_len: 123 | raise StopIteration 124 | 125 | end_idx = start_idx + self.batch_size 126 | 127 | if self.indices is None: # shuffle=False 128 | slices = (t[start_idx:end_idx] for t in self.tensors) 129 | else: 130 | idx = self.indices[start_idx:end_idx] 131 | slices = (t[idx] for t in self.tensors) 132 | 133 | batch = self.collate_fn(*slices) 134 | 135 | self.current_idx += self.batch_size 136 | return batch 137 | 138 | def __len__(self) -> int: 139 | """Get the number of batches in this data loader.""" 140 | n_batches, remainder = divmod(self.dataset_len, self.batch_size) 141 | return n_batches + bool(remainder) 142 | -------------------------------------------------------------------------------- /tests/test_wrenformer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | import torch 4 | from sklearn.model_selection import train_test_split as split 5 | 6 | from aviary.utils import get_metrics, results_multitask, train_ensemble 7 | from aviary.wrenformer.data import df_to_in_mem_dataloader 8 | from aviary.wrenformer.model import Wrenformer 9 | 10 | 11 | @pytest.fixture 12 | def base_config(): 13 | return { 14 | "robust": True, 15 | 
"ensemble": 2, 16 | "run_id": 1, 17 | "data_seed": 42, 18 | "log": False, 19 | "sample": 1, 20 | "test_size": 0.2, 21 | } 22 | 23 | 24 | @pytest.fixture 25 | def model_architecture(): 26 | return { 27 | "d_model": 128, 28 | "n_attn_layers": 2, 29 | "n_attn_heads": 4, 30 | "trunk_hidden": (1024, 512), 31 | "out_hidden": (256, 128, 64), 32 | "embedding_aggregations": ("mean",), 33 | } 34 | 35 | 36 | @pytest.fixture 37 | def training_config(): 38 | return { 39 | "resume": False, 40 | "fine_tune": None, 41 | "transfer": None, 42 | "optim": "AdamW", 43 | "learning_rate": 3e-4, 44 | "momentum": 0.9, 45 | "weight_decay": 1e-6, 46 | "batch_size": 128, 47 | "workers": 0, 48 | "device": "cuda" if torch.cuda.is_available() else "cpu", 49 | } 50 | 51 | 52 | def test_wrenformer_regression( 53 | df_matbench_phonons_wyckoff, base_config, model_architecture, training_config 54 | ): 55 | target_name = "last phdos peak" 56 | task = "regression" 57 | losses = ["L1"] 58 | epochs = 25 59 | model_name = "wrenformer-reg-test" 60 | input_col = "protostructure" 61 | embedding_type = "protostructure" 62 | 63 | task_dict = dict(zip([target_name], [task], strict=False)) 64 | loss_dict = dict(zip([target_name], losses, strict=False)) 65 | 66 | train_idx = list(range(len(df_matbench_phonons_wyckoff))) 67 | train_idx, test_idx = split( 68 | train_idx, 69 | random_state=base_config["data_seed"], 70 | test_size=base_config["test_size"], 71 | ) 72 | 73 | train_df = df_matbench_phonons_wyckoff.iloc[train_idx[0 :: base_config["sample"]]] 74 | test_df = df_matbench_phonons_wyckoff.iloc[test_idx] 75 | val_df = test_df # Using test set for validation 76 | 77 | data_loader_kwargs = dict( 78 | id_col="material_id", 79 | input_col=input_col, 80 | target_col=target_name, 81 | embedding_type=embedding_type, 82 | device=training_config["device"], 83 | ) 84 | 85 | train_loader = df_to_in_mem_dataloader( 86 | train_df, 87 | batch_size=training_config["batch_size"], 88 | shuffle=True, 89 | **data_loader_kwargs, 
90 | ) 91 | 92 | val_loader = df_to_in_mem_dataloader( 93 | val_df, 94 | batch_size=training_config["batch_size"] * 16, 95 | shuffle=False, 96 | **data_loader_kwargs, 97 | ) 98 | 99 | setup_params = { 100 | "optim": training_config["optim"], 101 | "learning_rate": training_config["learning_rate"], 102 | "weight_decay": training_config["weight_decay"], 103 | "momentum": training_config["momentum"], 104 | "device": training_config["device"], 105 | } 106 | 107 | restart_params = { 108 | "resume": training_config["resume"], 109 | "fine_tune": training_config["fine_tune"], 110 | "transfer": training_config["transfer"], 111 | } 112 | 113 | n_targets = [1] # Regression task has 1 target 114 | 115 | model_params = { 116 | "task_dict": task_dict, 117 | "robust": base_config["robust"], 118 | "n_targets": n_targets, 119 | "n_features": train_loader.tensors[0][0].shape[-1], 120 | **model_architecture, 121 | } 122 | 123 | train_ensemble( 124 | model_class=Wrenformer, 125 | model_name=model_name, 126 | run_id=base_config["run_id"], 127 | ensemble_folds=base_config["ensemble"], 128 | epochs=epochs, 129 | train_loader=train_loader, 130 | val_loader=val_loader, 131 | log=base_config["log"], 132 | setup_params=setup_params, 133 | restart_params=restart_params, 134 | model_params=model_params, 135 | loss_dict=loss_dict, 136 | ) 137 | 138 | test_loader = df_to_in_mem_dataloader( 139 | test_df, 140 | batch_size=training_config["batch_size"] * 64, 141 | shuffle=False, 142 | **data_loader_kwargs, 143 | ) 144 | 145 | results_dict = results_multitask( 146 | model_class=Wrenformer, 147 | model_name=model_name, 148 | run_id=base_config["run_id"], 149 | ensemble_folds=base_config["ensemble"], 150 | test_loader=test_loader, 151 | robust=base_config["robust"], 152 | task_dict=task_dict, 153 | device=training_config["device"], 154 | eval_type="checkpoint", 155 | save_results=False, 156 | ) 157 | 158 | preds = results_dict[target_name]["preds"] 159 | targets = results_dict[target_name]["targets"] 
# %%
import glob
import os

import pandas as pd
from pymatgen.analysis.prototypes import (
    count_wyckoff_positions,
    get_protostructure_label_from_spglib,
)
from pymatgen.core import Composition, Element, Structure
from tqdm import tqdm

tqdm.pandas()  # prime progress_map functionality

final_dir = os.path.dirname(os.path.abspath(__file__))

idx_list = []
structs = []
E_vasp_list = []
meta_list = []
ht_paths = []

# The pattern has no "**" component, so a plain (non-recursive) glob is enough.
for filepath in glob.glob(os.path.join(final_dir, "raw", "*.poscar")):
    # basename rather than splitting on "/" so Windows-style paths work too
    task_id = os.path.basename(filepath).split(".")[0]

    with open(filepath) as file_contents:
        file_str = file_contents.read()

    struct = Structure.from_str(file_str, fmt="poscar")
    lines = file_str.splitlines()

    # First line: "<energy> <ht path> ... [<meta data>"; seventh line: atom counts.
    header = lines[0]
    n_atoms = sum(int(count) for count in lines[6].split())
    E_vasp_per_atom = float(header.split()[0]) / n_atoms

    ht_path = header.split()[1]
    meta_data = "[" + header.split("[")[-1]

    idx_list.append(task_id)
    structs.append(struct)
    E_vasp_list.append(E_vasp_per_atom)
    ht_paths.append(ht_path)
    meta_list.append(meta_data)


df = pd.DataFrame()
df["material_id"] = idx_list
df["final_structure"] = structs
df["E_vasp_per_atom"] = E_vasp_list
df["meta_data"] = meta_list
df["ht_data"] = ht_paths

df["E_vasp_per_atom"] = df["E_vasp_per_atom"].astype(float)

print("\n~~~~ LOAD DATA ~~~~")
# Remove duplicated ID's keeping lowest energy
# NOTE this is a bug in TAATA we really shouldn't have to do this
df = df.sort_values(by=["material_id", "E_vasp_per_atom"], ascending=True)
df = df.drop_duplicates(subset="material_id", keep="first")


# %%
# Count number of datapoints
print(f"Number of points in dataset: {len(df)}")

# takes ~ 15mins
df["protostructure"] = df.final_structure.progress_map(
    get_protostructure_label_from_spglib
)

# lattice, sites = zip(*df.final_structure.progress_map(get_cgcnn_input))

df["composition"] = df.final_structure.map(lambda x: x.composition.reduced_formula)
df["nelements"] = df.final_structure.map(lambda x: len(x.composition.elements))
df["volume"] = df.final_structure.map(lambda x: x.volume)
df["n_sites"] = df.final_structure.map(lambda x: x.num_sites)

# df["lattice"] = lattice
# df["sites"] = sites


# %%
# Calculate Formation Enthalpy
df_el = df[df["nelements"] == 1]
df_el = df_el.sort_values(by=["composition", "E_vasp_per_atom"], ascending=True)
# Rows are sorted by ascending energy, so keeping the FIRST entry seen for each
# element selects its lowest-energy elemental phase as the reference. (A plain dict
# comprehension would let later, higher-energy rows overwrite it.)
el_refs: dict[Element, float] = {}
for struct, energy in zip(df_el.final_structure, df_el.E_vasp_per_atom, strict=True):
    el_refs.setdefault(struct.composition.elements[0], energy)


def get_formation_energy(
    comp: str, energy: float, el_refs: dict[Element, float]
) -> float:
    """Compute formation energy per atom for formula/energy pair and elemental
    references.

    Args:
        comp (str): Formula string
        energy (float): energy in eV/atom
        el_refs (dict[Element, float]): elemental reference energies in eV/atom,
            keyed by the elements of the composition

    Returns:
        float: formation energy per atom in eV

    Raises:
        KeyError: If an element of comp has no entry in el_refs.
    """
    c = Composition(comp)
    # NOTE our references use energies_per_atom for energy
    ref_e = sum(c[el] * el_refs[el] for el in c.elements)
    return energy - ref_e / c.num_atoms


df["E_f"] = [
    get_formation_energy(row.composition, row.E_vasp_per_atom, el_refs=el_refs)
    for row in df.itertuples()
]


# %%
# Remove invalid Wyckoff Sequences
# The frame has no "wyckoff" column; the labels live in "protostructure".
# NOTE(review): assumes labels that failed validation contain the substring
# "Invalid" - confirm against get_protostructure_label_from_spglib.
# Filtering happens before counting so malformed labels are never parsed.
df = df[~df["protostructure"].str.contains("Invalid", na=False)].copy()
df["n_wyckoff"] = df["protostructure"].map(count_wyckoff_positions)
print(f"Valid Wyckoff representation {len(df)}")


# %%
# Drop duplicated wyckoff representations
df = df.sort_values(by=["protostructure", "E_vasp_per_atom"], ascending=True)
df_wyk = df.drop_duplicates(["protostructure"], keep="first")
print(f"Lowest energy unique wyckoff sequences: {len(df_wyk)}")


# %%
# NOTE searching after having dropped wyckoff duplicates will remove
# some scaled duplicates. This value may still contain duplicates.
df_wyk = df_wyk.sort_values(by=["composition", "E_vasp_per_atom"], ascending=True)
df_comp = df_wyk.drop_duplicates("composition", keep="first")
print(f"Lowest energy polymorphs only: {len(df_comp)}")


# %%
# Clean the data

print("\n~~~~ DATA CLEANING ~~~~")
print(f"Total systems: {len(df_wyk)}")

wyk_lim = 16
df_wyk = df_wyk[df_wyk["n_wyckoff"] <= wyk_lim]
print(f"Less than {wyk_lim} Wyckoff species in cell: {len(df_wyk)}")

cell_lim = 64
df_wyk = df_wyk[df_wyk["n_sites"] <= cell_lim]
print(f"Less than {cell_lim} atoms in cell: {len(df_wyk)}")

vol_lim = 500
# copy so the "structure" column below is added to an independent frame
df_wyk = df_wyk[df_wyk["volume"] / df_wyk["n_sites"] < vol_lim].copy()
print(f"Less than {vol_lim} A^3 per site: {len(df_wyk)}")

fields = ["material_id", "composition", "E_f", "protostructure"]  # , "lattice", "sites"]

df_wyk[fields].to_csv(os.path.join(final_dir, "examples.csv"), index=False)

df_wyk["structure"] = df_wyk["final_structure"].map(lambda x: x.as_dict())

df_wyk[[*fields, "structure"]].to_json(os.path.join(final_dir, "examples.json"))
class WeightedAttentionPooling(nn.Module):
    """Softmax attention pooling where each node also carries a weight."""

    def __init__(self, gate_nn: nn.Module, message_nn: nn.Module) -> None:
        """Set up the weighted attention pooling layer.

        Args:
            gate_nn (nn.Module): Neural network to calculate attention scalars
            message_nn (nn.Module): Neural network to evaluate message updates
        """
        super().__init__()
        self.gate_nn = gate_nn
        self.message_nn = message_nn
        # learnable exponent applied to the node weights, randomly initialised
        self.pow = torch.nn.Parameter(torch.randn(1))

    def forward(self, x: Tensor, index: Tensor, weights: Tensor) -> Tensor:
        """Pool node messages using weight-scaled softmax attention.

        Args:
            x (Tensor): Input features for nodes
            index (Tensor): The indices for scatter operation over nodes
            weights (Tensor): The weights to assign to nodes

        Returns:
            Tensor: Output features for nodes
        """
        scores = self.gate_nn(x)

        # shift by the "amax" reduction over index before exponentiating
        segment_max = scatter_reduce(scores, index, dim=0, reduce="amax")
        scores -= segment_max[index]

        attention = (weights**self.pow) * scores.exp()
        # small constant keeps the division finite when a sum is zero
        normaliser = scatter_reduce(attention, index, dim=0, reduce="sum")[index] + 1e-10
        attention = attention / normaliser

        messages = self.message_nn(x)
        return scatter_reduce(attention * messages, index, dim=0, reduce="sum")

    def __repr__(self) -> str:
        exponent = float(self.pow)
        return (
            f"{type(self).__name__}(pow={exponent:.3}, "
            f"gate_nn={self.gate_nn!r}, message_nn={self.message_nn!r})"
        )
99 | 100 | Args: 101 | msg_fea_len (int): Number of input features 102 | num_msg_heads (int): Number of attention heads 103 | msg_gate_layers (list[int]): List of hidden layer sizes for gate network 104 | msg_net_layers (list[int]): List of hidden layer sizes for message network 105 | """ 106 | super().__init__() 107 | 108 | self._repr = ( 109 | f"{self._get_name()}({msg_fea_len=}, {num_msg_heads=}, {msg_gate_layers=}, " 110 | f"{msg_net_layers=})" 111 | ) 112 | 113 | # Pooling and Output 114 | self.pooling = nn.ModuleList( 115 | WeightedAttentionPooling( 116 | gate_nn=SimpleNetwork(2 * msg_fea_len, 1, msg_gate_layers), 117 | message_nn=SimpleNetwork(2 * msg_fea_len, msg_fea_len, msg_net_layers), 118 | ) 119 | for _ in range(num_msg_heads) 120 | ) 121 | 122 | def forward( 123 | self, 124 | node_weights: Tensor, 125 | node_prev_features: Tensor, 126 | self_idx: LongTensor, 127 | neighbor_idx: LongTensor, 128 | ) -> Tensor: 129 | """Forward pass. 130 | 131 | Args: 132 | node_weights (Tensor): The fractional weights of elements in their materials 133 | node_prev_features (Tensor): Node hidden features before message passing 134 | self_idx (LongTensor): Indices of the 1st element in each of the node pairs 135 | neighbor_idx (LongTensor): Indices of the 2nd element in each of the node 136 | pairs 137 | 138 | Returns: 139 | Tensor: node hidden features after message passing 140 | """ 141 | # construct the total features for passing 142 | node_nbr_weights = node_weights[neighbor_idx, :] 143 | msg_nbr_fea = node_prev_features[neighbor_idx, :] 144 | msg_self_fea = node_prev_features[self_idx, :] 145 | message = torch.cat([msg_self_fea, msg_nbr_fea], dim=1) 146 | 147 | # sum selectivity over the neighbors to get node updates 148 | head_features = [] 149 | for attn_head in self.pooling: 150 | out_msg = attn_head(message, index=self_idx, weights=node_nbr_weights) 151 | head_features.append(out_msg) 152 | 153 | # average the attention heads 154 | node_update = 
torch.stack(head_features).mean(dim=0) 155 | 156 | return node_update + node_prev_features 157 | 158 | def __repr__(self) -> str: 159 | return self._repr 160 | -------------------------------------------------------------------------------- /examples/inputs/examples.csv: -------------------------------------------------------------------------------- 1 | material_id,composition,E_f,protostructure 2 | c07b988125b3a0adc329b2449eeaac24c9e7def8,Hf,0.0,A_hP2_194_c:Hf 3 | 0060794138df3ba93c68007949b0a72d19cb4c62,Hf(ZnN2)3,0.6322286350000006,AB6C3_hR30_160_a_2b_b:Hf-N-Zn 4 | 004c70f544dca7f8e1c74c6c4ff87a1e355e7cd5,Hf(ZnN2)3,0.8708339350000012,AB6C3_hR30_148_a_f_bc:Hf-N-Zn 5 | 0059fd6dd1d17138c42bb5430fa0bc074053f0de,Hf(ZnN3)2,0.1256185166666688,AB6C2_mC18_12_a_ij_i:Hf-N-Zn 6 | 00a8d0f6f55191596770605314ff347995a3540d,Hf2Zn4N,0.2936619535714291,A2BC4_cF112_227_e_c_df:Hf-N-Zn 7 | 006197ce31efdef23c82f0f6cec08ac0b0febdce,Hf2Zn4N9,-0.037019788333333636,A2B9C4_oP60_60_d_c4d_2d:Hf-N-Zn 8 | 0019407b910a29730999dab65491b4e50624c219,Hf2Zn4N9,0.5619530450000001,A2B9C4_hP30_176_f_hi_bh:Hf-N-Zn 9 | 0039f3c461ce4a2cbf4c48b09bfc3c6b13a5c769,Hf2ZnN4,-0.600798328571428,A2B4C_oP28_61_c_2c_a:Hf-N-Zn 10 | 0079b9e3ab0e210ea4afe037a9faee8cd6577922,Hf2ZnN5,-0.09228186562499907,A2B5C_oP16_31_2a_5a_a:Hf-N-Zn 11 | 00021364446a637881257fd9ee912a422a6b1753,Hf2ZnN9,0.028860297916666333,A2B9C_oP48_62_d_c4d_c:Hf-N-Zn 12 | 00cca7ad273875ba80f1d437af797b6a01cc6ebd,Hf3ZnN13,-0.08990458382352706,A3B13C_mP34_13_fg_e6g_a:Hf-N-Zn 13 | 00ceac1a8da8e554218a84def07e5a6da9f106f3,Hf3ZnN4,-1.0282588499999985,A3B4C_mP16_6_3a3b_3a5b_2a:Hf-N-Zn 14 | 00ecadb66a8baab030675a817d2ee15f0002a4dc,Hf5(ZnN8)2,-0.0069458521739118595,A5B16C2_oC92_68_agh_4i_h:Hf-N-Zn 15 | 00fa7d75b7b2e1442df0a68071210c8d64b93ebe,Hf5ZnN3,-0.3175585194444448,A5B3C_tP18_125_am_cg_d:Hf-N-Zn 16 | 009ee0451531276f598976f46848fe7d590262ca,Hf6Zn2N13,-0.2216884392857139,A6B13C2_mC42_12_3i_a6i_i:Hf-N-Zn 17 | 
00a3259f00f8a9aa503caae2045970d2f7b22e13,Hf7Zn5N19,0.3828095201612909,A7B19C5_oP31_47_ek2l_ai2j3x_cij:Hf-N-Zn 18 | 00237cf70206a5de77e5154ea72633514cdd4445,HfZn2N7,0.1671952525000009,AB7C2_mP20_11_e_3e2f_f:Hf-N-Zn 19 | 006a8e78e4bb5c1570bf70b47b24c0b855a5b5b5,HfZn3N2,0.9949698083333343,AB2C3_tI12_139_a_e_be:Hf-N-Zn 20 | 00bbba1d86168727008d4c1542b39d45d93d693c,HfZn3N4,0.11430601250000105,AB4C3_cP16_223_a_e_c:Hf-N-Zn 21 | 006047d5cfe39157f2904d2533bad74b2ddd3170,HfZnN,-0.36961197499999887,ABC_hP6_194_a_b_c:Hf-N-Zn 22 | 001a938f944fa33e899b81e3d0a909876587e7b0,HfZnN12,0.8577765142857157,AB12C_mC56_15_e_6f_e:Hf-N-Zn 23 | 000fe0f2a4730347df688bd0d3c75f4494a28dca,HfZnN2,-0.6694122624999999,AB2C_hP8_194_c_f_a:Hf-N-Zn 24 | 00af9b9350147b8c20498eb9cd65b63a3132c741,HfZnN2,-0.3108887624999994,AB2C_mC16_12_i_2i_g:Hf-N-Zn 25 | 6f7d6a970fde2a4243091aba720fc14ceec2854b,N2,0.0,A_cP8_205_c:N 26 | b428ccf0f8cd5488a3af5d72b74e117dc71aa575,Ti,0.0,A_hP2_194_c:Ti 27 | 0ac3201405a380e382377a71216b33cde892b161,Ti(ZnN2)2,0.44363870000000016,A4BC2_mC28_5_4c_c_abc:N-Ti-Zn 28 | 0abac62147275aa8199465764751eaa67c94d16c,Ti(ZnN2)2,0.4993669142857149,A4BC2_oP28_62_2cd_c_ac:N-Ti-Zn 29 | 0a0ae9c18de23e72bb764e10ef542de693321542,Ti2Zn2N5,-0.4770085805555553,A5B2C2_oI36_46_b2c_ab_c:N-Ti-Zn 30 | 0ab8cc7e5cc3d48fcc786f9060da5a8882374f44,Ti2Zn3N4,-0.4294422111111116,A4B2C3_mP36_14_4e_2e_3e:N-Ti-Zn 31 | 0ad7d4c22ba4414fbe161f27f42c112c6cb83b29,Ti2Zn3N7,0.6236076854166672,A7B2C3_tI24_139_aeg_e_be:N-Ti-Zn 32 | 0abb6669aa25624d8371765c5afa6b80ae20f30b,Ti2ZnN3,-0.8424399458333331,A3B2C_mC12_12_ai_i_b:N-Ti-Zn 33 | 0a2c4708d3c853ea3c4b8630027afe9b788ee2ce,Ti2ZnN5,-0.049384740624999424,A5B2C_oP32_58_3gh_eg_g:N-Ti-Zn 34 | 0a946dc02260d0d4b36ef2087e60b415a5fff800,Ti3ZnN5,-0.701213825,A5B3C_tP36_76_5a_3a_a:N-Ti-Zn 35 | 0a130b596d1ce7bfe317d9e546bbeacc703fef65,Ti3ZnN8,-0.23821160833333277,A8B3C_oI48_72_2jk_aj_b:N-Ti-Zn 36 | 
0a646e174268b493bd67f58fdf52fead02f74a96,Ti4(ZnN3)3,-0.3996681890624991,A9B4C3_hR48_146_3b_4a_3a:N-Ti-Zn 37 | 0ab8ebf80d94a1ba826857a9e81cea8c7124d86e,Ti4Zn3N8,-0.2971007933333327,A8B4C3_mC60_12_8i_4i_3i:N-Ti-Zn 38 | 0ac8a65af75e8223b07dcd98c6958db7af6504cf,Ti4ZnN7,-0.5551503479166673,A7B4C_mC48_15_e3f_2f_e:N-Ti-Zn 39 | 0a5af69c21e6c2d9b004c2f51939790b10b19814,Ti5(ZnN3)3,-0.22798647205882272,A9B5C3_hP17_157_cd_bc_c:N-Ti-Zn 40 | 0a58ef8a37185d8abcbd6e1bad7eaf7e3d3d6016,Ti6ZnN8,-1.1345180866666658,A8B6C_hR45_148_cf_f_a:N-Ti-Zn 41 | 0a4a721fb8f0996350edd4fb224d89ce213f7f93,Ti7(ZnN)3,-0.47197025192307684,A3B7C3_oC52_36_3a_7a_3a:N-Ti-Zn 42 | 0ad8980cb8eb0693647f411732b2165b6a0fce0e,Ti7ZnN12,-0.34730283999999934,A12B7C_cI40_204_g_bc_a:N-Ti-Zn 43 | 0a458b5f2eeebeb0e4e55e1d1a17988ec87b7123,TiZn2N3,-0.04297332916666541,A3BC2_oC48_64_e2f_f_2f:N-Ti-Zn 44 | 0ac7f92ee7225e47dbdc25f6d60a10b116b0813b,TiZn2N3,0.5085712541666672,A3BC2_oC24_63_cg_c_e:N-Ti-Zn 45 | 0a3eb2074380f4abdec94d4a6171a099c9ecb30a,TiZn2N5,0.020599596875000792,A5BC2_oC32_63_c2f_c_f:N-Ti-Zn 46 | 0a2c3a235656a97d76f731b269982f92de9b37fd,TiZnN,-0.6217663083333331,ABC_hP6_164_d_c_d:N-Ti-Zn 47 | 0a2e1f340255f62a617974d778714e5fdd28a007,TiZnN2,2.0046756125,A2BC_cF32_227_c_b_a:N-Ti-Zn 48 | 0adeaf7a9c16772b66815cce19917ef0460affe8,TiZnN3,-0.014901634999999303,A3BC_cP40_198_2b_2a_2a:N-Ti-Zn 49 | 0a74297fb11027ba2a958da5ee31980b8876e1d4,TiZnN4,1.1665261166666676,A4BC_oC24_63_fg_a_c:N-Ti-Zn 50 | b3603cef0816d0e115adeafa4eee65c056dc7a26,Zn,0.0,A_hP2_194_c:Zn 51 | 0fd347fb29c23cad30d42852d7aa4ed906102b8a,Zr,0.0,A_hP2_194_c:Zr 52 | 0cd142efcf69ba63dc62760ecd72fb4f1e930115,Zr(ZnN)2,-0.2309529700000006,A2B2C_mC20_12_2i_2i_i:N-Zn-Zr 53 | 0c710ee221d7567cc0f1a06fba8231ae8261a181,Zr(ZnN)6,0.00016695769230778978,A6B6C_mC26_12_3i_3i_a:N-Zn-Zr 54 | 0c5264c980964b88ae8a7c2c18d5a61daff481b1,Zr(ZnN2)2,0.7521056285714289,A4B2C_hP14_186_bc_ab_b:N-Zn-Zr 55 | 
0c7147c05a270b9b2744576b0bfb69c8391ee1e4,Zr(ZnN3)2,0.44643896111111125,A6B2C_hP9_162_k_c_b:N-Zn-Zr 56 | 0c429f98b79c9522a11d51fbfae484004f899422,Zr(ZnN5)2,0.478535934615385,A10B2C_mP26_11_2e4f_f_e:N-Zn-Zr 57 | 0c27147781fc9c64cbcf45acc1575a0bf8070841,Zr2Zn2N3,-0.5270690678571421,A3B2C2_mP14_14_ae_e_e:N-Zn-Zr 58 | 0c450e6f9c256e33ef452c98dd20917fb6484317,Zr2Zn3N4,-0.3812128222222215,A4B3C2_mC36_15_2f_ef_f:N-Zn-Zr 59 | 0c9cbed1e5d755a9666db36f742004c14bf780bd,Zr2ZnN5,-0.5414344906249999,A5BC2_mC32_9_5a_a_2a:N-Zn-Zr 60 | 0cd1f1670176b03cc3d61935d9a6882c9508e33a,Zr3ZnN6,-0.7395423050000005,A6BC3_mP20_11_6e_e_3e:N-Zn-Zr 61 | 0c882a0621189d18ff1350e3bed4114868663acf,Zr4ZnN5,-0.9191486925000003,A5BC4_oP20_59_a2e_a_2be:N-Zn-Zr 62 | 0c5ffa8e2b3126eba02901c37b31e8faf3b5377c,Zr6Zn4N13,-0.21249625326086985,A13B4C6_cI46_217_ag_c_d:N-Zn-Zr 63 | 0c9c74b6f48facd507c4e41d9df797a2017634e7,ZrZn2N3,0.1968950041666666,A3B2C_oC24_63_cf_2c_c:N-Zn-Zr 64 | 0c13902f8200732555ad1603e954a726d4d24d85,ZrZn3N2,-0.11572319166666656,A2B3C_oI24_72_j_ce_b:N-Zn-Zr 65 | 0c9cf664171d7e10527033cda1551694141207ed,ZrZnN2,-0.2964831375000001,A2BC_oP4_47_k_a_d:N-Zn-Zr 66 | 0c0b4fd84c8f4e075436d7ec34bfb3fad5e610be,ZrZnN3,-0.16417793499999878,A3BC_hP30_185_ab2c_ab_c:N-Zn-Zr 67 | 0c2fc1da358bcddfb4baf41b5dfd7285a13b766a,ZrZnN3,0.23257766500000088,A3BC_mC20_15_ef_c_e:N-Zn-Zr 68 | 0c5ef0e6b20d6e34ae388d640797f032a15df918,ZrZnN3,0.7506590650000007,A3BC_cP5_221_c_b_a:N-Zn-Zr 69 | 0c5e41fdc8fba8e90b48cbe4752dc2c4a2805f21,ZrZnN3,0.774344365000001,A3BC_tI20_140_bh_d_c:N-Zn-Zr 70 | -------------------------------------------------------------------------------- /aviary/roost/data.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Sequence 2 | from functools import cache 3 | from typing import Any 4 | 5 | import numpy as np 6 | import pandas as pd 7 | import torch 8 | from pymatgen.core import Composition 9 | from torch import LongTensor, Tensor 10 | 
from torch.utils.data import Dataset 11 | 12 | 13 | class CompositionData(Dataset): 14 | """Dataset class for the Roost composition model.""" 15 | 16 | def __init__( 17 | self, 18 | df: pd.DataFrame, 19 | task_dict: dict[str, str], 20 | inputs: str = "composition", 21 | identifiers: Sequence[str] = ("material_id", "composition"), 22 | ): 23 | """Data class for Roost models. 24 | 25 | Args: 26 | df (pd.DataFrame): Pandas dataframe holding input and target values. 27 | task_dict (dict[str, "regression" | "classification"]): Map from target 28 | names to task type. 29 | elem_embedding (str, optional): One of "matscholar200", "cgcnn92", 30 | "megnet16", "onehot112" or path to a file with custom element 31 | embeddings. Defaults to "matscholar200". 32 | inputs (str, optional): df column name holding material compositions. 33 | Defaults to "composition". 34 | identifiers (list, optional): df columns for distinguishing data points. 35 | Will be copied over into the model's output CSV. Defaults to 36 | ["material_id", "composition"]. 37 | """ 38 | if len(identifiers) != 2: 39 | raise AssertionError("Two identifiers are required") 40 | 41 | self.inputs = inputs 42 | self.task_dict = task_dict 43 | self.identifiers = list(identifiers) 44 | self.df = df 45 | 46 | self.n_targets = [] 47 | for target, task in self.task_dict.items(): 48 | if task == "regression": 49 | self.n_targets.append(1) 50 | elif task == "classification": 51 | n_classes = np.max(self.df[target].values) + 1 52 | self.n_targets.append(n_classes) 53 | 54 | def __len__(self) -> int: 55 | return len(self.df) 56 | 57 | def __repr__(self) -> str: 58 | df_repr = f"cols=[{', '.join(self.df.columns)}], len={len(self.df)}" 59 | return f"{type(self).__name__}({df_repr}, task_dict={self.task_dict})" 60 | 61 | # Cache data for faster training 62 | @cache # noqa: B019 63 | def __getitem__(self, idx: int): 64 | """Get an entry out of the Dataset. 
65 | 66 | Args: 67 | idx (int): index of entry in Dataset 68 | 69 | Returns: 70 | tuple: containing 71 | - tuple[Tensor, Tensor, LongTensor, LongTensor]: Roost model inputs 72 | - list[Tensor | LongTensor]: regression or classification targets 73 | - list[str | int]: identifiers like material_id, composition 74 | """ 75 | row = self.df.iloc[idx] 76 | composition = row[self.inputs] 77 | material_ids = row[self.identifiers].to_list() 78 | 79 | comp_dict = Composition(composition).fractional_composition 80 | weights = list(comp_dict.values()) 81 | weights = np.atleast_2d(weights).T / np.sum(weights) 82 | elem_fea = [elem.Z for elem in comp_dict] 83 | 84 | n_elems = len(comp_dict) 85 | self_idx = [] 86 | nbr_idx = [] 87 | for elem_idx in range(n_elems): 88 | self_idx += [elem_idx] * n_elems 89 | nbr_idx += list(range(n_elems)) 90 | 91 | # convert all data to tensors 92 | elem_weights = Tensor(weights) 93 | elem_fea = LongTensor(elem_fea) 94 | self_idx = LongTensor(self_idx) 95 | nbr_idx = LongTensor(nbr_idx) 96 | 97 | targets = [] 98 | for target in self.task_dict: 99 | if self.task_dict[target] == "regression": 100 | targets.append(Tensor([row[target]])) 101 | elif self.task_dict[target] == "classification": 102 | targets.append(LongTensor([row[target]])) 103 | 104 | return ( 105 | (elem_weights, elem_fea, self_idx, nbr_idx), 106 | targets, 107 | *material_ids, 108 | ) 109 | 110 | 111 | def collate_batch( 112 | samples: tuple[ 113 | tuple[Tensor, Tensor, LongTensor, LongTensor], 114 | list[Tensor | LongTensor], 115 | list[str | int], 116 | ], 117 | ) -> tuple[Any, ...]: 118 | """Collate a list of data and return a batch for predicting crystal properties. 
119 | 120 | Args: 121 | samples (list): list of tuples for each data point where each tuple contains: 122 | (elem_fea, nbr_fea, nbr_idx, target) 123 | - elem_fea (Tensor): Atom hidden features before convolution 124 | - self_idx (LongTensor): Indices of the atom's self 125 | - nbr_idx (LongTensor): Indices of M neighbors of each atom 126 | - target (Tensor | LongTensor): target values containing floats for 127 | regression or integers as class labels for classification 128 | - cif_id: str or int 129 | 130 | Returns: 131 | tuple[ 132 | tuple[Tensor, Tensor, LongTensor, LongTensor, LongTensor]: batched Roost 133 | model inputs, 134 | tuple[Tensor | LongTensor]: Target values for different tasks, 135 | # TODO this last tuple is unpacked how to do type hint? 136 | *tuple[str | int]: Identifiers like material_id, composition 137 | ] 138 | """ 139 | # define the lists 140 | batch_elem_weights = [] 141 | batch_elem_fea = [] 142 | batch_self_idx = [] 143 | batch_nbr_idx = [] 144 | crystal_elem_idx = [] 145 | batch_targets = [] 146 | batch_cry_ids = [] 147 | 148 | cry_base_idx = 0 149 | for idx, (inputs, target, *cry_ids) in enumerate(samples): 150 | elem_weights, elem_fea, self_idx, nbr_idx = inputs 151 | 152 | n_sites = elem_fea.shape[0] # number of atoms for this crystal 153 | 154 | # batch the features together 155 | batch_elem_weights.append(elem_weights) 156 | batch_elem_fea.append(elem_fea) 157 | 158 | # mappings from bonds to atoms 159 | batch_self_idx.append(self_idx + cry_base_idx) 160 | batch_nbr_idx.append(nbr_idx + cry_base_idx) 161 | 162 | # mapping from atoms to crystals 163 | crystal_elem_idx.append(torch.tensor([idx] * n_sites)) 164 | 165 | # batch the targets and ids 166 | batch_targets.append(target) 167 | batch_cry_ids.append(cry_ids) 168 | 169 | # increment the id counter 170 | cry_base_idx += n_sites 171 | 172 | return ( 173 | ( 174 | torch.cat(batch_elem_weights, dim=0), 175 | torch.cat(batch_elem_fea, dim=0), 176 | torch.cat(batch_self_idx, dim=0), 
177 | torch.cat(batch_nbr_idx, dim=0), 178 | torch.cat(crystal_elem_idx), 179 | ), 180 | tuple( 181 | torch.stack(b_target, dim=0) for b_target in zip(*batch_targets, strict=False) 182 | ), 183 | *zip(*batch_cry_ids, strict=False), 184 | ) 185 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |