├── .github └── workflows │ └── ci.yml ├── .gitignore ├── .pre-commit-config.yaml ├── CITATION.cff ├── LICENSE ├── README.md ├── data ├── datasets │ └── roost │ │ ├── README.md │ │ ├── expt-non-metals.csv │ │ ├── mp-non-metals.csv │ │ └── oqmd-form-enthalpy.csv ├── el-embeddings │ ├── README.md │ ├── cgcnn-embedding.json │ ├── matscholar-embedding.json │ ├── megnet16-embedding.json │ └── onehot-embedding.json └── wp-embeddings │ ├── README.md │ ├── bra-alg-off.json │ └── spg-alg-off.json ├── environment-gpu-cu111.yml ├── examples ├── cgcnn-example.py ├── roost-example.py └── wren-example.py ├── requirements.txt ├── roost ├── cgcnn │ ├── data.py │ ├── model.py │ └── utils.py ├── core.py ├── pretrain │ ├── dist_data.py │ ├── dist_model.py │ ├── ele_data.py │ ├── ele_model.py │ ├── ener_data.py │ └── ener_model.py ├── roost │ ├── data.py │ └── model.py ├── segments.py ├── utils.py └── wren │ ├── allowed-wp-mult.json │ ├── data.py │ ├── model.py │ ├── relab.json │ ├── utils.py │ └── wp-params.json ├── setup.cfg ├── setup.py └── tests ├── data ├── roost-classification.csv └── roost-regression.csv ├── test_single_roost_classification.py └── test_single_roost_regression.py /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: Tests 2 | 3 | on: 4 | push: 5 | paths: 6 | - '**.py' 7 | - '.github/workflows/ci.yml' 8 | pull_request: 9 | paths: 10 | - '**.py' 11 | - '.github/workflows/ci.yml' 12 | 13 | jobs: 14 | tests: 15 | runs-on: ubuntu-latest 16 | steps: 17 | - name: Checkout repo 18 | uses: actions/checkout@v2 19 | 20 | - name: Setup Python 21 | uses: actions/setup-python@v2 22 | with: 23 | python-version: 3.8 24 | 25 | - name: Install dependencies 26 | run: | 27 | # install torch first because torch_scatter needs it 28 | pip install torch==1.10.0+cpu -f https://download.pytorch.org/whl/torch_stable.html 29 | pip install torch-scatter -f https://data.pyg.org/whl/torch-1.10.0+cpu.html 30 | pip install -r requirements.txt 31 | pip install . 32 | 33 | - name: Run Tests 34 | run: | 35 | python -m pytest 36 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | *.pyc 6 | 7 | # C extensions 8 | *.so 9 | 10 | # Distribution / packaging 11 | .Python 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .coverage 43 | .coverage.* 44 | .cache 45 | nosetests.xml 46 | coverage.xml 47 | *.cover 48 | .hypothesis/ 49 | .pytest_cache/ 50 | 51 | # Translations 52 | *.mo 53 | *.pot 54 | 55 | # Django stuff: 56 | *.log 57 | local_settings.py 58 | db.sqlite3 59 | 60 | # Flask stuff: 61 | instance/ 62 | .webassets-cache 63 | 64 | # Scrapy stuff: 65 | .scrapy 66 | 67 | # Sphinx documentation 68 | docs/_build/ 69 | 70 | # PyBuilder 71 | target/ 72 | 73 | # Jupyter Notebook 74 | .ipynb_checkpoints 75 | 76 | # pyenv 77 | .python-version 78 | 79 | # celery beat schedule file 80 | celerybeat-schedule 81 | 82 | # SageMath parsed files 83 | *.sage.py 84 | 85 | # Environments 86 | .env 87 | .venv 88 | env/ 89 | venv/ 90 | ENV/ 91 | env.bak/ 92 | venv.bak/ 93 | 94 | # Spyder project settings 95 | .spyderproject 96 | .spyproject 97 | 98 | # Rope project settings 99 | .ropeproject 100 | 101 | # mkdocs documentation 102 | /site 103 | 104 | # mypy 105 | .mypy_cache/ 106 | 107 | # vscode 108 | .vscode/ 109 | 110 | # notebooks 111 | .ipynb_checkpoints/ 112 | 113 | # exclude model outputs from github 114 | runs/ 115 | results/ 116 | models/ 117 | 118 | # exclude local slurm and plotting scripts 119 | process/ 120 | papers/ 121 | cds3/ 122 | 123 | # exclude local data and cached graphs 124 | *.csv 125 | *.pkl 126 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | # Install these hooks with `pre-commit install`. 2 | 3 | ci: 4 | autoupdate_schedule: quarterly 5 | 6 | default_stages: [commit] 7 | 8 | repos: 9 | - repo: https://github.com/PyCQA/isort 10 | rev: 5.10.1 11 | hooks: 12 | - id: isort 13 | 14 | - repo: https://github.com/PyCQA/flake8 15 | rev: 5.0.4 16 | hooks: 17 | - id: flake8 18 | 19 | - repo: https://github.com/asottile/pyupgrade 20 | rev: v2.38.2 21 | hooks: 22 | - id: pyupgrade 23 | args: [--py37-plus] 24 | 25 | - repo: https://github.com/pre-commit/pre-commit-hooks 26 | rev: v4.3.0 27 | hooks: 28 | - id: end-of-file-fixer 29 | - id: forbid-new-submodules 30 | - id: mixed-line-ending 31 | - id: trailing-whitespace 32 | 33 | - repo: https://github.com/codespell-project/codespell 34 | rev: v2.2.1 35 | hooks: 36 | - id: codespell 37 | stages: [commit, commit-msg] 38 | exclude_types: [csv, json] 39 | -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.2.0 2 | message: "If you use this software, please cite it as below." 
3 | authors: 4 | - family-names: "Goodall" 5 | given-names: "Rhys" 6 | orcid: "https://orcid.org/0000-0002-6589-1700" 7 | title: "roost" 8 | version: 0.0.2 9 | date-released: 2021-07-28 10 | url: "https://github.com/CompRhys/roost" 11 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019-2020 Rhys Goodall 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

# Roost

## Representation Learning from Stoichiometry

![License: MIT](https://img.shields.io/badge/License-MIT-green.svg)
[![GitHub Repo Size](https://img.shields.io/github/repo-size/comprhys/roost?label=Repo+Size)](https://github.com/comprhys/roost/graphs/contributors)
[![GitHub last commit](https://img.shields.io/github/last-commit/comprhys/roost?label=Last+Commit)](https://github.com/comprhys/roost/commits)
[![Tests](https://github.com/CompRhys/roost/workflows/Tests/badge.svg)](https://github.com/CompRhys/roost/actions)
[![pre-commit.ci status](https://results.pre-commit.ci/badge/github/CompRhys/roost/main.svg)](https://results.pre-commit.ci/latest/github/CompRhys/roost/main)

:warning: UNSUPPORTED NOTICE :warning:

16 | Development of `roost` has been moved from this repository to the https://github.com/CompRhys/aviary repository.
17 |
18 | ## Premise
19 |
20 | In materials discovery applications we often know the composition of trial materials but have little knowledge about the structure.
21 |
22 | Many current SOTA results within the field of machine learning for materials discovery are reliant on knowledge of the structure of the material. This means that such models can only be applied to systems that have undergone structural characterisation. As structural characterisation is a time-consuming process, whether done experimentally or via ab-initio methods, the use of structures as our model inputs is a prohibitive bottleneck to many materials screening applications we would like to pursue.
23 |
24 | One approach for avoiding the structure bottleneck is to develop models that learn from the stoichiometry alone. In this work, we show that via a novel recasting of how we view the stoichiometry of a material we can leverage a message-passing neural network to learn materials properties whilst remaining agnostic to the structure. The proposed model exhibits increased sample efficiency compared to more widely used descriptor-based approaches. This work draws inspiration from recent progress in using graph-based methods for the study of small molecules and crystalline materials.
25 |
26 | ## Environment Setup
27 |
28 | To use `roost` you need to create an environment with the correct dependencies. The easiest way to get up and running is to use `Anaconda`.
29 | A `cudatoolkit=11.1` environment file (`environment-gpu-cu111.yml`) is provided, allowing a working environment to be created with:
30 |
31 | ```bash
32 | conda env create -f environment-gpu-cu111.yml
33 | ```
34 |
35 | If you are not using `cudatoolkit=11.1` or do not have access to a GPU this setup will not work for you. If so, please check the [PyTorch](https://pytorch.org/get-started/locally/) and [PyTorch-Scatter](https://pytorch-geometric.readthedocs.io/en/latest/notes/installation.html) pages for how to install the core packages, and then install the remaining requirements as detailed in `requirements.txt`.
36 |
37 | The code was developed and tested on Linux Mint 19.1 Tessa. It should work with other operating systems but has not been tested for such use.
38 |
39 | ## Roost Setup
40 |
41 | Once you have set up an environment with the correct dependencies you can install `roost` using the following commands:
42 |
43 | ```bash
44 | conda activate roost
45 | git clone https://github.com/CompRhys/roost
46 | cd roost
47 | python setup.py sdist
48 | pip install -e .
49 | ```
50 |
51 | This will install the library in an editable state, allowing advanced users to make changes as desired.
52 |
53 | ## Example Use
54 |
55 | To test your installation, run the following example from the top of your `roost` directory:
56 |
57 | ```sh
58 | cd /path/to/roost/
59 | python examples/roost-example.py --train --evaluate --epochs 10 --tasks regression --targets Eg --losses L2
60 | ```
61 |
62 | This command runs a default task for 10 epochs -- experimental band gap regression using the data from Zhou et al. (see the `data/` folder for the reference). This default task has been set up to work out of the box without any changes and to give a flavour of how the model can be used. The demo task should take less than a minute when a GPU is available and give a test set MAE of 0.42-0.45 eV after 10 epochs.
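As an illustration of the Premise section above, the sketch below shows one way to recast a composition string as the dense, structure-agnostic element graph that the model operates on. It is a minimal sketch for intuition only: the `composition_graph` helper is not part of the `roost` API, and the use of `pymatgen` to parse the formula is an assumption of the example.

```python
from pymatgen.core import Composition  # assumed available; used only to parse the formula


def composition_graph(formula: str):
    """Illustrative only (not the roost API): recast a composition as a graph.

    Nodes are the distinct elements weighted by their fractional amounts, and
    every element is connected to every other element, so no structural
    information is required.
    """
    amounts = Composition(formula).fractional_composition.get_el_amt_dict()
    elements = list(amounts)
    weights = [amounts[el] for el in elements]

    # dense edge list connecting each element node to every other element node
    n = len(elements)
    self_idx = [i for i in range(n) for j in range(n) if j != i]
    nbr_idx = [j for i in range(n) for j in range(n) if j != i]

    return elements, weights, self_idx, nbr_idx


# e.g. Fe2O3 -> (['Fe', 'O'], [0.4, 0.6], [0, 1], [1, 0])
print(composition_graph("Fe2O3"))
```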
63 | 64 | If you want to use your own data set on a regression task this can be done with: 65 | 66 | ```sh 67 | python examples/roost-example.py --data-path /path/to/your/data/data.csv --train --tasks [regression/classification, ...] --targets [
, ..] --losses [L1/L2, ..] 68 | ``` 69 | 70 | You can then test your model with: 71 | 72 | ```sh 73 | python examples/roost-example.py --test-path /path/to/testset.csv --evaluate --tasks [regression/classification, ...] --targets [
, ..] --losses [L1/L2, ..] 74 | ``` 75 | 76 | The model takes input in the form csv files with materials-ids, composition strings and target values as the columns. 77 | 78 | | material-id | composition | target | 79 | | ----------- | ----------- | ------ | 80 | | foo-1 | Fe2O3 | 2.3 | 81 | | foo-2 | La2CuO4 | 4.3 | 82 | 83 | Basic hints about more advanced use of the model (i.e. classification, robust losses, ensembles, tensorboard logging etc..) 84 | are available via the command: 85 | 86 | ```sh 87 | python examples/roost-example.py --help 88 | ``` 89 | 90 | This will output the various command-line flags that can be used to control the code. 91 | 92 | ## Cite This Work 93 | 94 | If you use this code please cite our work for which this model was built: 95 | 96 | Predicting materials properties without crystal structure: Deep representation learning from stoichiometry. [[Paper]](https://doi.org/10.1038/s41467-020-19964-7) [[arXiv](https://arxiv.org/abs/1910.00617)] 97 | 98 | ```tex 99 | @article{goodall2020predicting, 100 | title={Predicting materials properties without crystal structure: Deep representation learning from stoichiometry}, 101 | author={Goodall, Rhys EA and Lee, Alpha A}, 102 | journal={Nature Communications}, 103 | volume={11}, 104 | number={1}, 105 | pages={1--9}, 106 | year={2020}, 107 | publisher={Nature Publishing Group} 108 | } 109 | ``` 110 | 111 | ## Work Featuring Roost 112 | 113 | Work using Roost as presented: 114 | 115 | * A critical examination of compound stability predictions from machine-learned formation energies [[Paper]](https://www.nature.com/articles/s41524-020-00362-y) [[arXiv]](https://arxiv.org/abs/2001.10591) 116 | 117 | * Active learning based generative design for the discovery of wide bandgap materials. [[arXiv]](https://arxiv.org/abs/2103.00608) 118 | 119 | * MaterialsAtlas formation energy WebApp - 120 | 121 | Work building-on/using-parts-of the code shared here: 122 | 123 | * Predicting the Outcomes of Material Syntheses with Deep Learning [[Paper]](https://pubs.acs.org/doi/abs/10.1021/acs.chemmater.0c03885) 124 | 125 | * Compositionally restricted attention-based network for materials property predictions [[Paper]](https://www.nature.com/articles/s41524-021-00545-1) 126 | 127 | * Materials Representation and Transfer Learning for Multi-Property Prediction [[arXiv]](https://arxiv.org/abs/2106.02225) 128 | 129 | If you have used Roost in your work please contact me and I will add your paper here. 130 | 131 | ## Disclaimer 132 | 133 | This is research code shared without support or any guarantee on its quality. However, please do raise an issue or submit a pull request if you spot something wrong or that could be improved and I will try my best to solve it. 
134 | -------------------------------------------------------------------------------- /data/datasets/roost/README.md: -------------------------------------------------------------------------------- 1 | # Dataset sources 2 | 3 | ## OQMD 4 | `oqmd-form-enthalpy.csv` - [The Open Quantum Materials Database (OQMD)](http://oqmd.org/) 5 | 6 | ## Materials Project 7 | `mp-non-metals.csv` - [The Materials Project (MP)](https://materialsproject.org/) 8 | 9 | ## Experimental Bandgaps 10 | `expt-non-metals.csv` - [Predicting the Band Gaps of Inorganic Solids by Machine Learning](https://doi.org/10.1021/acs.jpclett.8b00124) 11 | -------------------------------------------------------------------------------- /data/el-embeddings/README.md: -------------------------------------------------------------------------------- 1 | # Element - Embeddings 2 | 3 | ## CGCNN 4 | 5 | The following paper describes the details of the CGCNN framework, information about how the one-hot atom embedding was constructed is available in the supplementary materials: 6 | 7 | [Crystal Graph Convolutional Neural Networks for an Accurate and Interpretable Prediction of Material Properties](https://link.aps.org/doi/10.1103/PhysRevLett.120.145301) 8 | 9 | ## MEGnet 10 | 11 | The following paper describes the details of the MEGnet framework, the embedding is generated from the atomic weights through one MEGnet layer on a training task to predict the computed formation energies of ∼69,000 materials from the [Materials Project](https://materialsproject.org/) Database: 12 | 13 | [Graph Networks as a Universal Machine Learning Framework for Molecules and Crystals](https://arxiv.org/abs/1812.05055) 14 | 15 | ## MatScholar 16 | 17 | This is an experimental NLP embedding based on the data mining of definite compositions and structure prototypes. 18 | 19 | [Unsupervised word embeddings capture latent knowledge from materials science literature](https://www.nature.com/articles/s41586-019-1335-8) 20 | 21 | ## One-hot 22 | 23 | This is a simple one-hot encoding for the elements 24 | -------------------------------------------------------------------------------- /data/wp-embeddings/README.md: -------------------------------------------------------------------------------- 1 | # Wyckoff - Embeddings 2 | 3 | ## wyk-ohe 4 | 5 | A 1731 dimension OHE of possible wyckoff positions. 6 | 7 | ## spg-letter 8 | 9 | A 230 + 27 dimension embedding consisting of a OHE of the spacegroup and a OHE of the wyckoff position letter. 10 | 11 | ## alg 12 | 13 | A 185 dimension embedding of the sum of whether sites in a wyckoff position sit on lines, planes or in volumes. 14 | 15 | ## alg-off 16 | 17 | A 185 + 248 dimension embedding of the sum of whether sites in a wyckoff position sit on lines, planes or in 18 | volumes and the specific offset of that wyckoff position within the unit cell. 19 | 20 | ## bra-alg-off 21 | 22 | A 6 + 5 + 185 + 248 dimension embedding of the crystal system, unit cell centering, the sum of whether sites 23 | in a wyckoff position sit on lines, planes or in volumes, and the specific offset of that wyckoff position within 24 | the unit cell. 25 | 26 | ## spg-alg-off 27 | 28 | A 230 + 185 + 248 dimension embedding consisting of a OHE of the spacegroup, the sum of whether sites 29 | in a wyckoff position sit on lines, planes or in volumes, and the specific offset of that wyckoff position within 30 | the unit cell. 
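Both the element embeddings and these Wyckoff embeddings are stored as plain JSON. As a quick, illustrative way to inspect an element-embedding file (assuming it maps element symbols to fixed-length float vectors, which is the format read by `roost.core.Featurizer.from_json`):

```python
import json

# Illustrative inspection of an element-embedding file; assumes the JSON maps
# element symbols (e.g. "Fe") to fixed-length lists of floats.
with open("data/el-embeddings/matscholar-embedding.json") as file:
    embedding = json.load(file)

print(f"{len(embedding)} elements, {len(embedding['Fe'])}-dimensional vectors")
```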
31 | -------------------------------------------------------------------------------- /environment-gpu-cu111.yml: -------------------------------------------------------------------------------- 1 | name: roost 2 | channels: 3 | - pytorch 4 | - pyg 5 | - nvidia 6 | - conda-forge 7 | - defaults 8 | dependencies: 9 | - _libgcc_mutex=0.1=main 10 | - _openmp_mutex=4.5=1_gnu 11 | - absl-py=0.13.0=py37h06a4308_0 12 | - aiohttp=3.7.4=py37h27cfd23_1 13 | - async-timeout=3.0.1=py37h06a4308_0 14 | - attrs=21.2.0=pyhd3eb1b0_0 15 | - blas=1.0=mkl 16 | - blinker=1.4=py37h06a4308_0 17 | - brotlipy=0.7.0=py37h5e8e339_1001 18 | - c-ares=1.17.1=h27cfd23_0 19 | - ca-certificates=2021.7.5=h06a4308_1 20 | - cachetools=4.2.2=pyhd3eb1b0_0 21 | - certifi=2021.5.30=py37h06a4308_0 22 | - cffi=1.14.6=py37hc58025e_0 23 | - chardet=3.0.4=py37h06a4308_1003 24 | - charset-normalizer=2.0.0=pyhd8ed1ab_0 25 | - click=8.0.1=pyhd3eb1b0_0 26 | - colorama=0.4.4=pyh9f0ad1d_0 27 | - coverage=5.5=py37h27cfd23_2 28 | - cryptography=3.4.7=py37h5d9358c_0 29 | - cudatoolkit=11.1.74=h6bb024c_0 30 | - cython=0.29.24=py37h295c915_0 31 | - decorator=4.4.2=py_0 32 | - future=0.18.2=py37h89c1867_3 33 | - google-auth=1.33.0=pyhd3eb1b0_0 34 | - google-auth-oauthlib=0.4.1=py_2 35 | - googledrivedownloader=0.4=pyhd3deb0d_1 36 | - grpcio=1.36.1=py37h2157cd5_1 37 | - idna=3.1=pyhd3deb0d_0 38 | - importlib-metadata=4.6.4=py37h06a4308_0 39 | - intel-openmp=2021.3.0=h06a4308_3350 40 | - jinja2=3.0.1=pyhd8ed1ab_0 41 | - joblib=1.0.1=pyhd8ed1ab_0 42 | - ld_impl_linux-64=2.35.1=h7274673_9 43 | - libblas=3.9.0=11_linux64_mkl 44 | - libcblas=3.9.0=11_linux64_mkl 45 | - libffi=3.3=he6710b0_2 46 | - libgcc-ng=9.3.0=h5101ec6_17 47 | - libgfortran-ng=11.1.0=h69a702a_8 48 | - libgfortran5=11.1.0=h6c583b3_8 49 | - libgomp=9.3.0=h5101ec6_17 50 | - liblapack=3.9.0=11_linux64_mkl 51 | - libprotobuf=3.17.2=h4ff587b_1 52 | - libstdcxx-ng=9.3.0=hd4cf53a_17 53 | - libuv=1.40.0=h7b6447c_0 54 | - markdown=3.3.4=py37h06a4308_0 55 | - markupsafe=2.0.1=py37h5e8e339_0 56 | - mkl=2021.3.0=h06a4308_520 57 | - multidict=5.1.0=py37h27cfd23_2 58 | - ncurses=6.2=he6710b0_1 59 | - networkx=2.5.1=pyhd8ed1ab_0 60 | - ninja=1.10.2=h4bd325d_0 61 | - numpy=1.20.3=py37h038b26d_1 62 | - oauthlib=3.1.1=pyhd3eb1b0_0 63 | - openssl=1.1.1l=h7f8727e_0 64 | - pandas=1.3.0=py37h219a48f_0 65 | - pip=21.0.1=py37h06a4308_0 66 | - protobuf=3.17.2=py37h295c915_0 67 | - pyasn1=0.4.8=pyhd3eb1b0_0 68 | - pyasn1-modules=0.2.8=py_0 69 | - pycparser=2.20=pyh9f0ad1d_2 70 | - pyg=2.0.0=py37_torch_1.9.0_cu111 71 | - pyjwt=2.1.0=py37h06a4308_0 72 | - pyopenssl=20.0.1=pyhd8ed1ab_0 73 | - pyparsing=2.4.7=pyh9f0ad1d_0 74 | - pysocks=1.7.1=py37h89c1867_3 75 | - python=3.7.10=h12debd9_4 76 | - python-dateutil=2.8.2=pyhd8ed1ab_0 77 | - python-louvain=0.15=pyhd3deb0d_0 78 | - python_abi=3.7=2_cp37m 79 | - pytorch=1.9.0=py3.7_cuda11.1_cudnn8.0.5_0 80 | - pytorch-cluster=1.5.9=py37_torch_1.9.0_cu111 81 | - pytorch-scatter=2.0.8=py37_torch_1.9.0_cu111 82 | - pytorch-sparse=0.6.12=py37_torch_1.9.0_cu111 83 | - pytorch-spline-conv=1.2.1=py37_torch_1.9.0_cu111 84 | - pytz=2021.1=pyhd8ed1ab_0 85 | - pyyaml=5.4.1=py37h5e8e339_0 86 | - readline=8.1=h27cfd23_0 87 | - requests=2.26.0=pyhd8ed1ab_0 88 | - requests-oauthlib=1.3.0=py_0 89 | - rsa=4.7.2=pyhd3eb1b0_1 90 | - scikit-learn=0.24.2=py37h18a542f_0 91 | - scipy=1.6.3=py37h29e03ee_0 92 | - setuptools=52.0.0=py37h06a4308_0 93 | - six=1.16.0=pyh6c4a22f_0 94 | - sleef=3.5.1=h7f98852_1 95 | - sqlite=3.36.0=hc218d9a_0 96 | - tensorboard=2.5.0=py_0 97 | - 
tensorboard-plugin-wit=1.6.0=py_0 98 | - threadpoolctl=2.2.0=pyh8a188c0_0 99 | - tk=8.6.10=hbc83047_0 100 | - tqdm=4.62.2=pyhd8ed1ab_0 101 | - typing-extensions=3.10.0.0=hd3eb1b0_0 102 | - typing_extensions=3.10.0.0=pyh06a4308_0 103 | - urllib3=1.26.6=pyhd8ed1ab_0 104 | - werkzeug=1.0.1=pyhd3eb1b0_0 105 | - wheel=0.37.0=pyhd3eb1b0_1 106 | - xz=5.2.5=h7b6447c_0 107 | - yacs=0.1.6=py_0 108 | - yaml=0.2.5=h516909a_0 109 | - yarl=1.6.3=py37h27cfd23_0 110 | - zipp=3.5.0=pyhd3eb1b0_0 111 | - zlib=1.2.11=h7b6447c_3 112 | - pip: 113 | - cycler==0.10.0 114 | - kiwisolver==1.3.2 115 | - matplotlib==3.4.3 116 | - monty==2021.8.17 117 | - mpmath==1.2.1 118 | - palettable==3.3.0 119 | - pillow==8.3.2 120 | - plotly==5.3.1 121 | - pymatgen==2022.0.14 122 | - ruamel-yaml==0.17.16 123 | - ruamel-yaml-clib==0.2.6 124 | - spglib==1.16.2 125 | - sympy==1.8 126 | - tabulate==0.8.9 127 | - tenacity==8.0.1 128 | - uncertainties==3.1.6 129 | prefix: /home/reag2/miniconda3/envs/roost 130 | -------------------------------------------------------------------------------- /examples/roost-example.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | 5 | import torch 6 | from sklearn.model_selection import train_test_split as split 7 | 8 | from roost.roost.data import CompositionData, collate_batch 9 | from roost.roost.model import Roost 10 | from roost.utils import results_multitask, train_ensemble 11 | 12 | 13 | def main( 14 | data_path, 15 | fea_path, 16 | targets, 17 | tasks, 18 | losses, 19 | robust, 20 | model_name="roost", 21 | elem_fea_len=64, 22 | n_graph=3, 23 | ensemble=1, 24 | run_id=1, 25 | data_seed=42, 26 | epochs=100, 27 | patience=None, 28 | log=True, 29 | sample=1, 30 | test_size=0.2, 31 | test_path=None, 32 | val_size=0.0, 33 | val_path=None, 34 | resume=None, 35 | fine_tune=None, 36 | transfer=None, 37 | train=True, 38 | evaluate=True, 39 | optim="AdamW", 40 | learning_rate=3e-4, 41 | momentum=0.9, 42 | weight_decay=1e-6, 43 | batch_size=128, 44 | workers=0, 45 | device=torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"), 46 | **kwargs, 47 | ): 48 | assert len(targets) == len(tasks) == len(losses) 49 | 50 | assert ( 51 | evaluate or train 52 | ), "No action given - At least one of 'train' or 'evaluate' cli flags required" 53 | 54 | if test_path: 55 | test_size = 0.0 56 | 57 | if not (test_path and val_path): 58 | assert test_size + val_size < 1.0, ( 59 | f"'test_size'({test_size}) " 60 | f"plus 'val_size'({val_size}) must be less than 1" 61 | ) 62 | 63 | if ensemble > 1 and (fine_tune or transfer): 64 | raise NotImplementedError( 65 | "If training an ensemble with fine tuning or transferring" 66 | " options the models must be trained one by one using the" 67 | " run-id flag." 68 | ) 69 | 70 | assert not (fine_tune and transfer), ( 71 | "Cannot fine-tune and" " transfer checkpoint(s) at the same time." 
72 | ) 73 | 74 | task_dict = {k: v for k, v in zip(targets, tasks)} 75 | loss_dict = {k: v for k, v in zip(targets, losses)} 76 | 77 | dataset = CompositionData( 78 | data_path=data_path, fea_path=fea_path, task_dict=task_dict 79 | ) 80 | n_targets = dataset.n_targets 81 | elem_emb_len = dataset.elem_emb_len 82 | 83 | train_idx = list(range(len(dataset))) 84 | 85 | if evaluate: 86 | if test_path: 87 | print(f"using independent test set: {test_path}") 88 | test_set = CompositionData( 89 | data_path=test_path, fea_path=fea_path, task_dict=task_dict 90 | ) 91 | test_set = torch.utils.data.Subset(test_set, range(len(test_set))) 92 | elif test_size == 0.0: 93 | raise ValueError("test-size must be non-zero to evaluate model") 94 | else: 95 | print(f"using {test_size} of training set as test set") 96 | train_idx, test_idx = split( 97 | train_idx, random_state=data_seed, test_size=test_size 98 | ) 99 | test_set = torch.utils.data.Subset(dataset, test_idx) 100 | 101 | if train: 102 | if val_path: 103 | print(f"using independent validation set: {val_path}") 104 | val_set = CompositionData( 105 | data_path=val_path, fea_path=fea_path, task_dict=task_dict 106 | ) 107 | val_set = torch.utils.data.Subset(val_set, range(len(val_set))) 108 | else: 109 | if val_size == 0.0 and evaluate: 110 | print("No validation set used, using test set for evaluation purposes") 111 | # NOTE that when using this option care must be taken not to 112 | # peak at the test-set. The only valid model to use is the one 113 | # obtained after the final epoch where the epoch count is 114 | # decided in advance of the experiment. 115 | val_set = test_set 116 | elif val_size == 0.0: 117 | val_set = None 118 | else: 119 | print(f"using {val_size} of training set as validation set") 120 | train_idx, val_idx = split( 121 | train_idx, random_state=data_seed, test_size=val_size / (1 - test_size), 122 | ) 123 | val_set = torch.utils.data.Subset(dataset, val_idx) 124 | 125 | train_set = torch.utils.data.Subset(dataset, train_idx[0::sample]) 126 | 127 | data_params = { 128 | "batch_size": batch_size, 129 | "num_workers": workers, 130 | "pin_memory": False, 131 | "shuffle": True, 132 | "collate_fn": collate_batch, 133 | } 134 | 135 | setup_params = { 136 | "optim": optim, 137 | "learning_rate": learning_rate, 138 | "weight_decay": weight_decay, 139 | "momentum": momentum, 140 | "device": device, 141 | } 142 | 143 | if resume: 144 | resume = f"models/{model_name}/checkpoint-r{run_id}.pth.tar" 145 | 146 | restart_params = { 147 | "resume": resume, 148 | "fine_tune": fine_tune, 149 | "transfer": transfer, 150 | } 151 | 152 | model_params = { 153 | "task_dict": task_dict, 154 | "robust": robust, 155 | "n_targets": n_targets, 156 | "elem_emb_len": elem_emb_len, 157 | "elem_fea_len": elem_fea_len, 158 | "n_graph": n_graph, 159 | "elem_heads": 3, 160 | "elem_gate": [256], 161 | "elem_msg": [256], 162 | "cry_heads": 3, 163 | "cry_gate": [256], 164 | "cry_msg": [256], 165 | "trunk_hidden": [1024, 512], 166 | "out_hidden": [256, 128, 64], 167 | } 168 | 169 | os.makedirs(f"models/{model_name}/", exist_ok=True) 170 | 171 | if log: 172 | os.makedirs("runs/", exist_ok=True) 173 | 174 | os.makedirs("results/", exist_ok=True) 175 | 176 | # TODO dump all args/kwargs to a file for reproducibility. 
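    # One possible (untested) sketch for the TODO above: serialise the run
    # configuration next to the checkpoints so a run can be reproduced, e.g.
    #
    #   import json
    #   with open(f"models/{model_name}/run-args-r{run_id}.json", "w") as file:
    #       json.dump(dict(data_path=data_path, targets=targets, tasks=tasks,
    #                      losses=losses, data_seed=data_seed, epochs=epochs),
    #                 file, indent=2, default=str)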
177 | 178 | if train: 179 | train_ensemble( 180 | model_class=Roost, 181 | model_name=model_name, 182 | run_id=run_id, 183 | ensemble_folds=ensemble, 184 | epochs=epochs, 185 | patience=patience, 186 | train_set=train_set, 187 | val_set=val_set, 188 | log=log, 189 | data_params=data_params, 190 | setup_params=setup_params, 191 | restart_params=restart_params, 192 | model_params=model_params, 193 | loss_dict=loss_dict, 194 | ) 195 | 196 | if evaluate: 197 | 198 | data_reset = { 199 | "batch_size": 16 * batch_size, # faster model inference 200 | "shuffle": False, # need fixed data order due to ensembling 201 | } 202 | data_params.update(data_reset) 203 | 204 | results_multitask( 205 | model_class=Roost, 206 | model_name=model_name, 207 | run_id=run_id, 208 | ensemble_folds=ensemble, 209 | test_set=test_set, 210 | data_params=data_params, 211 | robust=robust, 212 | task_dict=task_dict, 213 | device=device, 214 | eval_type="checkpoint", 215 | ) 216 | 217 | 218 | def input_parser(): 219 | """ 220 | parse input 221 | """ 222 | parser = argparse.ArgumentParser( 223 | description=( 224 | "Roost - a Structure Agnostic Message Passing " 225 | "Neural Network for Inorganic Materials" 226 | ) 227 | ) 228 | 229 | # data inputs 230 | parser.add_argument( 231 | "--data-path", 232 | type=str, 233 | default="data/datasets/roost/expt-non-metals.csv", 234 | metavar="PATH", 235 | help="Path to main data set/training set", 236 | ) 237 | valid_group = parser.add_mutually_exclusive_group() 238 | valid_group.add_argument( 239 | "--val-path", 240 | type=str, 241 | metavar="PATH", 242 | help="Path to independent validation set", 243 | ) 244 | valid_group.add_argument( 245 | "--val-size", 246 | default=0.0, 247 | type=float, 248 | metavar="FLOAT", 249 | help="Proportion of data used for validation", 250 | ) 251 | test_group = parser.add_mutually_exclusive_group() 252 | test_group.add_argument( 253 | "--test-path", type=str, metavar="PATH", help="Path to independent test set" 254 | ) 255 | test_group.add_argument( 256 | "--test-size", 257 | default=0.2, 258 | type=float, 259 | metavar="FLOAT", 260 | help="Proportion of data set for testing", 261 | ) 262 | 263 | # data embeddings 264 | parser.add_argument( 265 | "--fea-path", 266 | type=str, 267 | default="data/el-embeddings/matscholar-embedding.json", 268 | metavar="PATH", 269 | help="Element embedding feature path", 270 | ) 271 | 272 | # dataloader inputs 273 | parser.add_argument( 274 | "--workers", 275 | default=0, 276 | type=int, 277 | metavar="INT", 278 | help="Number of data loading workers (default: 0)", 279 | ) 280 | parser.add_argument( 281 | "--batch-size", 282 | "--bsize", 283 | default=128, 284 | type=int, 285 | metavar="INT", 286 | help="Mini-batch size (default: 128)", 287 | ) 288 | parser.add_argument( 289 | "--data-seed", 290 | default=0, 291 | type=int, 292 | metavar="INT", 293 | help="Seed used when splitting data sets (default: 0)", 294 | ) 295 | parser.add_argument( 296 | "--sample", 297 | default=1, 298 | type=int, 299 | metavar="INT", 300 | help="Sub-sample the training set for learning curves", 301 | ) 302 | 303 | # task inputs 304 | parser.add_argument( 305 | "--targets", 306 | nargs="*", 307 | type=str, 308 | metavar="STR", 309 | help="Task types for targets", 310 | ) 311 | 312 | parser.add_argument( 313 | "--tasks", 314 | nargs="*", 315 | default=["regression"], 316 | type=str, 317 | metavar="STR", 318 | help="Task types for targets", 319 | ) 320 | 321 | parser.add_argument( 322 | "--losses", 323 | nargs="*", 324 | default=["L1"], 325 | 
type=str, 326 | metavar="STR", 327 | help="Loss function if regression (default: 'L1')", 328 | ) 329 | 330 | # optimiser inputs 331 | parser.add_argument( 332 | "--epochs", 333 | default=100, 334 | type=int, 335 | metavar="INT", 336 | help="Number of training epochs to run (default: 100)", 337 | ) 338 | parser.add_argument( 339 | "--robust", 340 | action="store_true", 341 | help="Specifies whether to use hetroskedastic loss variants", 342 | ) 343 | parser.add_argument( 344 | "--optim", 345 | default="AdamW", 346 | type=str, 347 | metavar="STR", 348 | help="Optimizer used for training (default: 'AdamW')", 349 | ) 350 | parser.add_argument( 351 | "--learning-rate", 352 | "--lr", 353 | default=3e-4, 354 | type=float, 355 | metavar="FLOAT", 356 | help="Initial learning rate (default: 3e-4)", 357 | ) 358 | parser.add_argument( 359 | "--momentum", 360 | default=0.9, 361 | type=float, 362 | metavar="FLOAT [0,1]", 363 | help="Optimizer momentum (default: 0.9)", 364 | ) 365 | parser.add_argument( 366 | "--weight-decay", 367 | default=1e-6, 368 | type=float, 369 | metavar="FLOAT [0,1]", 370 | help="Optimizer weight decay (default: 1e-6)", 371 | ) 372 | 373 | # graph inputs 374 | parser.add_argument( 375 | "--elem-fea-len", 376 | default=64, 377 | type=int, 378 | metavar="INT", 379 | help="Number of hidden features for elements (default: 64)", 380 | ) 381 | parser.add_argument( 382 | "--n-graph", 383 | default=3, 384 | type=int, 385 | metavar="INT", 386 | help="Number of message passing layers (default: 3)", 387 | ) 388 | 389 | # ensemble inputs 390 | parser.add_argument( 391 | "--ensemble", 392 | default=1, 393 | type=int, 394 | metavar="INT", 395 | help="Number models to ensemble", 396 | ) 397 | name_group = parser.add_mutually_exclusive_group() 398 | name_group.add_argument( 399 | "--model-name", 400 | type=str, 401 | default=None, 402 | metavar="STR", 403 | help="Name for sub-directory where models will be stored", 404 | ) 405 | name_group.add_argument( 406 | "--data-id", 407 | default="roost", 408 | type=str, 409 | metavar="STR", 410 | help="Partial identifier for sub-directory where models will be stored", 411 | ) 412 | parser.add_argument( 413 | "--run-id", 414 | default=0, 415 | type=int, 416 | metavar="INT", 417 | help="Index for model in an ensemble of models", 418 | ) 419 | 420 | # restart inputs 421 | use_group = parser.add_mutually_exclusive_group() 422 | use_group.add_argument( 423 | "--fine-tune", type=str, metavar="PATH", help="Checkpoint path for fine tuning" 424 | ) 425 | use_group.add_argument( 426 | "--transfer", 427 | type=str, 428 | metavar="PATH", 429 | help="Checkpoint path for transfer learning", 430 | ) 431 | use_group.add_argument( 432 | "--resume", action="store_true", help="Resume from previous checkpoint" 433 | ) 434 | 435 | # task type 436 | parser.add_argument( 437 | "--evaluate", 438 | action="store_true", 439 | help="Evaluate the model/ensemble", 440 | ) 441 | parser.add_argument("--train", action="store_true", help="Train the model/ensemble") 442 | 443 | # misc 444 | parser.add_argument("--disable-cuda", action="store_true", help="Disable CUDA") 445 | parser.add_argument( 446 | "--log", action="store_true", help="Log training metrics to tensorboard" 447 | ) 448 | 449 | args = parser.parse_args(sys.argv[1:]) 450 | 451 | assert all( 452 | [i in ["regression", "classification"] for i in args.tasks] 453 | ), "Only `regression` and `classification` are allowed as tasks" 454 | 455 | if args.model_name is None: 456 | args.model_name = 
f"{args.data_id}_s-{args.data_seed}_t-{args.sample}" 457 | 458 | args.device = ( 459 | torch.device("cuda") 460 | if (not args.disable_cuda) and torch.cuda.is_available() 461 | else torch.device("cpu") 462 | ) 463 | 464 | return args 465 | 466 | 467 | if __name__ == "__main__": 468 | args = input_parser() 469 | 470 | print(f"The model will run on the {args.device} device") 471 | 472 | main(**vars(args)) 473 | -------------------------------------------------------------------------------- /examples/wren-example.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | 5 | import torch 6 | from sklearn.model_selection import train_test_split as split 7 | 8 | from roost.utils import results_multitask, train_ensemble 9 | from roost.wren.data import WyckoffData, collate_batch 10 | from roost.wren.model import Wren 11 | 12 | 13 | def main( 14 | data_path, 15 | fea_path, 16 | sym_path, 17 | targets, 18 | tasks, 19 | losses, 20 | robust, 21 | model_name="wren", 22 | sym_fea_len=32, 23 | elem_fea_len=32, 24 | n_graph=3, 25 | ensemble=1, 26 | run_id=1, 27 | data_seed=42, 28 | epochs=100, 29 | patience=None, 30 | log=True, 31 | sample=1, 32 | test_size=0.2, 33 | test_path=None, 34 | val_size=0.0, 35 | val_path=None, 36 | resume=None, 37 | fine_tune=None, 38 | transfer=None, 39 | train=True, 40 | evaluate=True, 41 | optim="AdamW", 42 | learning_rate=3e-4, 43 | momentum=0.9, 44 | weight_decay=1e-6, 45 | batch_size=128, 46 | workers=0, 47 | device=torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"), 48 | **kwargs, 49 | ): 50 | 51 | assert len(targets) == len(tasks) == len(losses) 52 | 53 | assert ( 54 | evaluate or train 55 | ), "No action given - At least one of 'train' or 'evaluate' cli flags required" 56 | 57 | if test_path: 58 | test_size = 0.0 59 | 60 | if not (test_path and val_path): 61 | assert test_size + val_size < 1.0, ( 62 | f"'test_size'({test_size}) " 63 | f"plus 'val_size'({val_size}) must be less than 1" 64 | ) 65 | 66 | if ensemble > 1 and (fine_tune or transfer): 67 | raise NotImplementedError( 68 | "If training an ensemble with fine tuning or transferring" 69 | " options the models must be trained one by one using the" 70 | " run-id flag." 71 | ) 72 | 73 | assert not (fine_tune and transfer), ( 74 | "Cannot fine-tune and" " transfer checkpoint(s) at the same time." 75 | ) 76 | 77 | # TODO CLI controls for loss dict. 
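    # NOTE as currently written, per-target losses are taken from the --losses
    # CLI flag and paired with --targets in the zip below.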
78 | 79 | task_dict = {k: v for k, v in zip(targets, tasks)} 80 | loss_dict = {k: v for k, v in zip(targets, losses)} 81 | 82 | dataset = WyckoffData( 83 | data_path=data_path, fea_path=fea_path, sym_path=sym_path, task_dict=task_dict 84 | ) 85 | n_targets = dataset.n_targets 86 | sym_emb_len = dataset.sym_fea_dim 87 | elem_emb_len = dataset.elem_fea_dim 88 | 89 | train_idx = list(range(len(dataset))) 90 | 91 | if evaluate: 92 | if test_path: 93 | print(f"using independent test set: {test_path}") 94 | test_set = WyckoffData( 95 | data_path=test_path, 96 | fea_path=fea_path, 97 | sym_path=sym_path, 98 | task_dict=task_dict, 99 | ) 100 | test_set = torch.utils.data.Subset(test_set, range(len(test_set))) 101 | elif test_size == 0.0: 102 | raise ValueError("test-size must be non-zero to evaluate model") 103 | else: 104 | print(f"using {test_size} of training set as test set") 105 | train_idx, test_idx = split( 106 | train_idx, random_state=data_seed, test_size=test_size 107 | ) 108 | test_set = torch.utils.data.Subset(dataset, test_idx) 109 | 110 | if train: 111 | if val_path: 112 | print(f"using independent validation set: {val_path}") 113 | val_set = WyckoffData( 114 | data_path=val_path, 115 | fea_path=fea_path, 116 | sym_path=sym_path, 117 | task_dict=task_dict, 118 | ) 119 | val_set = torch.utils.data.Subset(val_set, range(len(val_set))) 120 | else: 121 | if val_size == 0.0 and evaluate: 122 | print("No validation set used, using test set for evaluation purposes") 123 | # NOTE that when using this option care must be taken not to 124 | # peak at the test-set. The only valid model to use is the one 125 | # obtained after the final epoch where the epoch count is 126 | # decided in advance of the experiment. 127 | val_set = test_set 128 | elif val_size == 0.0: 129 | val_set = None 130 | else: 131 | print(f"using {val_size} of training set as validation set") 132 | train_idx, val_idx = split( 133 | train_idx, 134 | random_state=data_seed, 135 | test_size=val_size / (1 - test_size), 136 | ) 137 | val_set = torch.utils.data.Subset(dataset, val_idx) 138 | 139 | train_set = torch.utils.data.Subset(dataset, train_idx[0::sample]) 140 | 141 | data_params = { 142 | "batch_size": batch_size, 143 | "num_workers": workers, 144 | "pin_memory": False, 145 | "shuffle": True, 146 | "collate_fn": collate_batch, 147 | } 148 | 149 | setup_params = { 150 | "optim": optim, 151 | "learning_rate": learning_rate, 152 | "weight_decay": weight_decay, 153 | "momentum": momentum, 154 | "device": device, 155 | } 156 | 157 | if resume: 158 | resume = f"models/{model_name}/checkpoint-r{run_id}.pth.tar" 159 | 160 | restart_params = { 161 | "resume": resume, 162 | "fine_tune": fine_tune, 163 | "transfer": transfer, 164 | } 165 | 166 | model_params = { 167 | "task_dict": task_dict, 168 | "robust": robust, 169 | "n_targets": n_targets, 170 | "elem_emb_len": elem_emb_len, 171 | "sym_emb_len": sym_emb_len, 172 | "elem_fea_len": elem_fea_len, 173 | "sym_fea_len": sym_fea_len, 174 | "n_graph": n_graph, 175 | "elem_heads": 1, 176 | "elem_gate": [256], 177 | "elem_msg": [256], 178 | "cry_heads": 1, 179 | "cry_gate": [256], 180 | "cry_msg": [256], 181 | # "out_hidden": [256] * 6, 182 | # "out_hidden": [1024, 512, 256, 128, 64], 183 | "out_hidden": [256, 256], 184 | "trunk_hidden": [128, 64], 185 | } 186 | 187 | os.makedirs(f"models/{model_name}/", exist_ok=True) 188 | 189 | if log: 190 | os.makedirs("runs/", exist_ok=True) 191 | 192 | os.makedirs("results/", exist_ok=True) 193 | 194 | if train: 195 | train_ensemble( 196 | 
model_class=Wren, 197 | model_name=model_name, 198 | run_id=run_id, 199 | ensemble_folds=ensemble, 200 | epochs=epochs, 201 | patience=patience, 202 | train_set=train_set, 203 | val_set=val_set, 204 | log=log, 205 | data_params=data_params, 206 | setup_params=setup_params, 207 | restart_params=restart_params, 208 | model_params=model_params, 209 | loss_dict=loss_dict, 210 | ) 211 | 212 | if evaluate: 213 | 214 | data_reset = { 215 | "batch_size": 16 * batch_size, # faster model inference 216 | "shuffle": False, # need fixed data order due to ensembling 217 | } 218 | data_params.update(data_reset) 219 | 220 | results_multitask( 221 | model_class=Wren, 222 | model_name=model_name, 223 | run_id=run_id, 224 | ensemble_folds=ensemble, 225 | test_set=test_set, 226 | data_params=data_params, 227 | robust=robust, 228 | task_dict=task_dict, 229 | device=device, 230 | eval_type="checkpoint", 231 | ) 232 | 233 | 234 | def input_parser(): 235 | """ 236 | parse input 237 | """ 238 | parser = argparse.ArgumentParser(description=("Wren")) 239 | 240 | # data inputs 241 | parser.add_argument( 242 | "--data-path", 243 | type=str, 244 | default="/home/reag2/PhD/roost/data/datasets/wren/taata-c-spglib-test.csv", 245 | metavar="PATH", 246 | help="Path to main data set/training set", 247 | ) 248 | valid_group = parser.add_mutually_exclusive_group() 249 | valid_group.add_argument( 250 | "--val-path", 251 | type=str, 252 | metavar="PATH", 253 | help="Path to independent validation set", 254 | ) 255 | valid_group.add_argument( 256 | "--val-size", 257 | default=0.0, 258 | type=float, 259 | metavar="FLOAT", 260 | help="Proportion of data used for validation", 261 | ) 262 | test_group = parser.add_mutually_exclusive_group() 263 | test_group.add_argument( 264 | "--test-path", type=str, metavar="PATH", help="Path to independent test set" 265 | ) 266 | test_group.add_argument( 267 | "--test-size", 268 | default=0.2, 269 | type=float, 270 | metavar="FLOAT", 271 | help="Proportion of data set for testing", 272 | ) 273 | 274 | # data embeddings 275 | parser.add_argument( 276 | "--fea-path", 277 | type=str, 278 | default="data/el-embeddings/matscholar-embedding.json", 279 | metavar="PATH", 280 | help="Element embedding feature path", 281 | ) 282 | parser.add_argument( 283 | "--sym-path", 284 | type=str, 285 | default="data/wp-embeddings/bra-alg-off.json", 286 | metavar="PATH", 287 | help="Element embedding feature path", 288 | ) 289 | 290 | # dataloader inputs 291 | parser.add_argument( 292 | "--workers", 293 | default=0, 294 | type=int, 295 | metavar="INT", 296 | help="Number of data loading workers (default: 0)", 297 | ) 298 | parser.add_argument( 299 | "--batch-size", 300 | "--bsize", 301 | default=128, 302 | type=int, 303 | metavar="INT", 304 | help="Mini-batch size (default: 128)", 305 | ) 306 | parser.add_argument( 307 | "--data-seed", 308 | default=0, 309 | type=int, 310 | metavar="INT", 311 | help="Seed used when splitting data sets (default: 0)", 312 | ) 313 | parser.add_argument( 314 | "--sample", 315 | default=1, 316 | type=int, 317 | metavar="INT", 318 | help="Sub-sample the training set for learning curves", 319 | ) 320 | 321 | # task inputs 322 | parser.add_argument( 323 | "--targets", 324 | nargs="*", 325 | type=str, 326 | metavar="STR", 327 | help="Task types for targets", 328 | ) 329 | parser.add_argument( 330 | "--tasks", 331 | nargs="*", 332 | default=["regression"], 333 | type=str, 334 | metavar="STR", 335 | help="Task types for targets", 336 | ) 337 | parser.add_argument( 338 | "--losses", 339 | 
nargs="*", 340 | default=["L1"], 341 | type=str, 342 | metavar="STR", 343 | help="Loss function if regression (default: 'L1')", 344 | ) 345 | 346 | # optimiser inputs 347 | parser.add_argument( 348 | "--epochs", 349 | default=100, 350 | type=int, 351 | metavar="INT", 352 | help="Number of training epochs to run (default: 100)", 353 | ) 354 | parser.add_argument( 355 | "--robust", 356 | action="store_true", 357 | help="Specifies whether to use hetroskedastic loss variants", 358 | ) 359 | parser.add_argument( 360 | "--optim", 361 | default="AdamW", 362 | type=str, 363 | metavar="STR", 364 | help="Optimizer used for training (default: 'AdamW')", 365 | ) 366 | parser.add_argument( 367 | "--learning-rate", 368 | "--lr", 369 | default=3e-4, 370 | type=float, 371 | metavar="FLOAT", 372 | help="Initial learning rate (default: 3e-4)", 373 | ) 374 | parser.add_argument( 375 | "--momentum", 376 | default=0.9, 377 | type=float, 378 | metavar="FLOAT [0,1]", 379 | help="Optimizer momentum (default: 0.9)", 380 | ) 381 | parser.add_argument( 382 | "--weight-decay", 383 | default=1e-6, 384 | type=float, 385 | metavar="FLOAT [0,1]", 386 | help="Optimizer weight decay (default: 1e-6)", 387 | ) 388 | 389 | # graph inputs 390 | parser.add_argument( 391 | "--elem-fea-len", 392 | default=32, 393 | type=int, 394 | metavar="INT", 395 | help="Number of hidden features for elements (default: 64)", 396 | ) 397 | parser.add_argument( 398 | "--sym-fea-len", 399 | default=32, 400 | type=int, 401 | metavar="INT", 402 | help="Number of hidden features for elements (default: 64)", 403 | ) 404 | parser.add_argument( 405 | "--n-graph", 406 | default=3, 407 | type=int, 408 | metavar="INT", 409 | help="Number of message passing layers (default: 3)", 410 | ) 411 | 412 | # ensemble inputs 413 | parser.add_argument( 414 | "--ensemble", 415 | default=1, 416 | type=int, 417 | metavar="INT", 418 | help="Number models to ensemble", 419 | ) 420 | name_group = parser.add_mutually_exclusive_group() 421 | name_group.add_argument( 422 | "--model-name", 423 | type=str, 424 | default=None, 425 | metavar="STR", 426 | help="Name for sub-directory where models will be stored", 427 | ) 428 | name_group.add_argument( 429 | "--data-id", 430 | default="wren", 431 | type=str, 432 | metavar="STR", 433 | help="Partial identifier for sub-directory where models will be stored", 434 | ) 435 | parser.add_argument( 436 | "--run-id", 437 | default=0, 438 | type=int, 439 | metavar="INT", 440 | help="Index for model in an ensemble of models", 441 | ) 442 | 443 | # restart inputs 444 | use_group = parser.add_mutually_exclusive_group() 445 | use_group.add_argument( 446 | "--fine-tune", type=str, metavar="PATH", help="Checkpoint path for fine tuning" 447 | ) 448 | use_group.add_argument( 449 | "--transfer", 450 | type=str, 451 | metavar="PATH", 452 | help="Checkpoint path for transfer learning", 453 | ) 454 | use_group.add_argument( 455 | "--resume", action="store_true", help="Resume from previous checkpoint" 456 | ) 457 | 458 | # task type 459 | parser.add_argument( 460 | "--evaluate", 461 | action="store_true", 462 | help="Evaluate the model/ensemble", 463 | ) 464 | parser.add_argument("--train", action="store_true", help="Train the model/ensemble") 465 | 466 | # misc 467 | parser.add_argument("--disable-cuda", action="store_true", help="Disable CUDA") 468 | parser.add_argument( 469 | "--log", action="store_true", help="Log training metrics to tensorboard" 470 | ) 471 | 472 | args = parser.parse_args(sys.argv[1:]) 473 | 474 | assert all( 475 | [i in 
["regression", "classification"] for i in args.tasks] 476 | ), "Only `regression` and `classification` are allowed as tasks" 477 | 478 | if args.model_name is None: 479 | args.model_name = f"{args.data_id}_s-{args.data_seed}_t-{args.sample}" 480 | 481 | args.device = ( 482 | torch.device("cuda") 483 | if (not args.disable_cuda) and torch.cuda.is_available() 484 | else torch.device("cpu") 485 | ) 486 | 487 | return args 488 | 489 | 490 | if __name__ == "__main__": 491 | args = input_parser() 492 | 493 | print(f"The model will run on the {args.device} device") 494 | 495 | main(**vars(args)) 496 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | scipy 2 | tqdm 3 | torch 4 | numpy 5 | pymatgen 6 | torch_scatter 7 | pandas 8 | scikit_learn 9 | pytest 10 | tensorboard 11 | -------------------------------------------------------------------------------- /roost/cgcnn/data.py: -------------------------------------------------------------------------------- 1 | import ast 2 | import functools 3 | import os 4 | from itertools import groupby 5 | 6 | import numpy as np 7 | import pandas as pd 8 | import torch 9 | from pymatgen.core.structure import Structure 10 | from torch.utils.data import Dataset 11 | 12 | from roost.core import Featurizer 13 | 14 | 15 | class CrystalGraphData(Dataset): 16 | def __init__( 17 | self, 18 | data_path, 19 | fea_path, 20 | task_dict, 21 | inputs=["lattice", "sites"], 22 | identifiers=["material_id", "composition"], 23 | radius=5, 24 | max_num_nbr=12, 25 | dmin=0, 26 | step=0.2, 27 | ): 28 | """CrystalGraphData returns neighbourhood graphs 29 | 30 | Args: 31 | data_path (str): The path to the dataset 32 | fea_path (str): The path to the element embedding 33 | task_dict ({target: task}): task dict for multi-task learning 34 | inputs (list, optional): df columns for lattice and sites. 35 | Defaults to ["lattice", "sites"]. 36 | identifiers (list, optional): df columns for distinguishing data points. 37 | Defaults to ["material_id", "composition"]. 38 | radius (int, optional): cut-off radius for neighbourhood. 39 | Defaults to 5. 40 | max_num_nbr (int, optional): maximum number of neighbours to consider. 41 | Defaults to 12. 42 | dmin (int, optional): minimum distance in gaussian basis. 43 | Defaults to 0. 44 | step (float, optional): increment size of gaussian basis. 45 | Defaults to 0.2. 46 | """ 47 | assert len(identifiers) == 2, "Two identifiers are required" 48 | assert len(inputs) == 2, "One input column required are required" 49 | 50 | self.inputs = inputs 51 | self.task_dict = task_dict 52 | self.identifiers = identifiers 53 | 54 | self.radius = radius 55 | self.max_num_nbr = max_num_nbr 56 | 57 | assert os.path.exists(fea_path), f"{fea_path} does not exist!" 58 | self.ari = Featurizer.from_json(fea_path) 59 | self.elem_fea_dim = self.ari.embedding_size 60 | 61 | self.gdf = GaussianDistance(dmin=dmin, dmax=self.radius, step=step) 62 | self.nbr_fea_dim = self.gdf.embedding_size 63 | 64 | assert os.path.exists(data_path), f"{data_path} does not exist!" 
65 | # NOTE make sure to use dense datasets, here do not use the default na 66 | # as they can clash with "NaN" which is a valid material 67 | self.df = pd.read_csv( 68 | data_path, keep_default_na=False, na_values=[], comment="#" 69 | ) 70 | 71 | self.df["Structure_obj"] = self.df[inputs].apply(get_structure, axis=1) 72 | 73 | self._pre_check() 74 | 75 | self.n_targets = [] 76 | for target in self.task_dict: 77 | if self.task_dict[target] == "regression": 78 | self.n_targets.append(1) 79 | elif self.task == "classification": 80 | n_classes = np.max(self.df[target].values) + 1 81 | self.n_targets.append(n_classes) 82 | 83 | def __len__(self): 84 | return len(self.df) 85 | 86 | def _get_nbr_data(self, crystal): 87 | """get neighbours for every site 88 | 89 | Args: 90 | crystal ([Structure]): pymatgen structure to get neighbours for 91 | """ 92 | self_idx, nbr_idx, _, nbr_dist = crystal.get_neighbor_list( 93 | self.radius, 94 | numerical_tol=1e-8, 95 | ) 96 | 97 | if self.max_num_nbr is not None: 98 | _self_idx, _nbr_idx, _nbr_dist = [], [], [] 99 | 100 | for i, g in groupby(zip(self_idx, nbr_idx, nbr_dist), key=lambda x: x[0]): 101 | s, n, d = zip(*sorted(g, key=lambda x: x[2])) 102 | _self_idx.extend(s[: self.max_num_nbr]) 103 | _nbr_idx.extend(n[: self.max_num_nbr]) 104 | _nbr_dist.extend(d[: self.max_num_nbr]) 105 | 106 | self_idx = np.array(_self_idx) 107 | nbr_idx = np.array(_nbr_idx) 108 | nbr_dist = np.array(_nbr_dist) 109 | 110 | return self_idx, nbr_idx, nbr_dist 111 | 112 | def _pre_check(self): 113 | """Check that none of the structures have isolated atoms. 114 | 115 | Raises: 116 | ValueError: if isolated structures are present 117 | """ 118 | print("Precheck that all structures are valid") 119 | all_iso = [] 120 | some_iso = [] 121 | 122 | for cif_id, crystal in zip(self.df["material_id"], self.df["Structure_obj"]): 123 | self_idx, nbr_idx, _ = self._get_nbr_data(crystal) 124 | 125 | if len(self_idx) == 0: 126 | all_iso.append(cif_id) 127 | elif len(nbr_idx) == 0: 128 | all_iso.append(cif_id) 129 | elif set(self_idx) != set(range(crystal.num_sites)): 130 | some_iso.append(cif_id) 131 | 132 | if (len(all_iso) > 0) or (len(some_iso) > 0): 133 | # drop the data points that do not give rise to dense crystal graphs 134 | self.df = self.df.drop(self.df[self.df["material_id"].isin(all_iso)].index) 135 | self.df = self.df.drop(self.df[self.df["material_id"].isin(some_iso)].index) 136 | 137 | print(all_iso) 138 | print(some_iso) 139 | 140 | @functools.lru_cache(maxsize=None) # Cache loaded structures 141 | def __getitem__(self, idx): 142 | # NOTE sites must be given in fractional coordinates 143 | df_idx = self.df.iloc[idx] 144 | crystal = df_idx["Structure_obj"] 145 | cif_id, comp = df_idx[self.identifiers] 146 | 147 | # atom features for disordered sites 148 | site_atoms = [atom.species.as_dict() for atom in crystal] 149 | atom_fea = np.vstack( 150 | [ 151 | np.sum([self.ari.get_fea(el) * amt for el, amt in site.items()], axis=0) 152 | for site in site_atoms 153 | ] 154 | ) 155 | 156 | # # # neighbours 157 | self_idx, nbr_idx, nbr_dist = self._get_nbr_data(crystal) 158 | 159 | assert len(self_idx), f"All atoms in {cif_id} are isolated" 160 | assert len(nbr_idx), f"This should not be triggered but was for {cif_id}" 161 | assert set(self_idx) == set( 162 | range(crystal.num_sites) 163 | ), f"At least one atom in {cif_id} is isolated" 164 | 165 | nbr_dist = self.gdf.expand(nbr_dist) 166 | 167 | atom_fea = torch.Tensor(atom_fea) 168 | nbr_dist = torch.Tensor(nbr_dist) 169 | self_idx = 
torch.LongTensor(self_idx) 170 | nbr_idx = torch.LongTensor(nbr_idx) 171 | 172 | targets = [] 173 | for target in self.task_dict: 174 | if self.task_dict[target] == "regression": 175 | targets.append(torch.Tensor([df_idx[target]])) 176 | elif self.task_dict[target] == "classification": 177 | targets.append(torch.LongTensor([df_idx[target]])) 178 | 179 | return ((atom_fea, nbr_dist, self_idx, nbr_idx), targets, comp, cif_id) 180 | 181 | 182 | class GaussianDistance: 183 | """ 184 | Expands the distance by Gaussian basis. 185 | 186 | Unit: angstrom 187 | """ 188 | 189 | def __init__(self, dmin, dmax, step, var=None): 190 | """ 191 | Args: 192 | dmin (float): Minimum interatomic distance 193 | dmax (float): Maximum interatomic distance 194 | step (float): Step size for the Gaussian filter 195 | var (float, optional): Variance of Gaussian basis. Defaults to step if not given 196 | """ 197 | assert dmin < dmax 198 | assert dmax - dmin > step 199 | 200 | self.filter = np.arange(dmin, dmax + step, step) 201 | self.embedding_size = len(self.filter) 202 | 203 | if var is None: 204 | var = step 205 | 206 | self.var = var 207 | 208 | def expand(self, distances): 209 | """Apply Gaussian distance filter to a numpy distance array 210 | 211 | Args: 212 | distances (ArrayLike): A distance matrix of any shape 213 | 214 | Returns: 215 | Expanded distance matrix with the last dimension of length 216 | len(self.filter) 217 | """ 218 | distances = np.array(distances) 219 | 220 | return np.exp( 221 | -((distances[..., np.newaxis] - self.filter) ** 2) / self.var ** 2 222 | ) 223 | 224 | 225 | def collate_batch(dataset_list): 226 | """ 227 | Collate a list of data and return a batch for predicting crystal 228 | properties. 229 | 230 | Parameters 231 | ---------- 232 | 233 | dataset_list: list of tuples for each data point. 
234 | (atom_fea, nbr_dist, nbr_idx, target) 235 | 236 | atom_fea: torch.Tensor shape (n_i, atom_fea_len) 237 | nbr_dist: torch.Tensor shape (n_i, M, nbr_dist_len) 238 | nbr_idx: torch.LongTensor shape (n_i, M) 239 | target: torch.Tensor shape (1, ) 240 | cif_id: str or int 241 | 242 | Returns 243 | ------- 244 | N = sum(n_i); N0 = sum(i) 245 | 246 | batch_atom_fea: torch.Tensor shape (N, orig_atom_fea_len) 247 | Atom features from atom type 248 | batch_nbr_dist: torch.Tensor shape (N, M, nbr_dist_len) 249 | Bond features of each atom's M neighbors 250 | batch_nbr_idx: torch.LongTensor shape (N, M) 251 | Indices of M neighbors of each atom 252 | crystal_atom_idx: list of torch.LongTensor of length N0 253 | Mapping from the crystal idx to atom idx 254 | target: torch.Tensor shape (N, 1) 255 | Target value for prediction 256 | batch_cif_ids: list 257 | """ 258 | batch_atom_fea = [] 259 | batch_nbr_dist = [] 260 | batch_self_idx = [] 261 | batch_nbr_idx = [] 262 | crystal_atom_idx = [] 263 | batch_targets = [] 264 | batch_comps = [] 265 | batch_cif_ids = [] 266 | base_idx = 0 267 | 268 | for i, (inputs, target, comp, cif_id) in enumerate(dataset_list): 269 | atom_fea, nbr_dist, self_idx, nbr_idx = inputs 270 | n_i = atom_fea.shape[0] # number of atoms for this crystal 271 | 272 | batch_atom_fea.append(atom_fea) 273 | batch_nbr_dist.append(nbr_dist) 274 | batch_self_idx.append(self_idx + base_idx) 275 | batch_nbr_idx.append(nbr_idx + base_idx) 276 | 277 | crystal_atom_idx.extend([i] * n_i) 278 | batch_targets.append(target) 279 | batch_comps.append(comp) 280 | batch_cif_ids.append(cif_id) 281 | base_idx += n_i 282 | 283 | atom_fea = torch.cat(batch_atom_fea, dim=0) 284 | nbr_dist = torch.cat(batch_nbr_dist, dim=0) 285 | self_idx = torch.cat(batch_self_idx, dim=0) 286 | nbr_idx = torch.cat(batch_nbr_idx, dim=0) 287 | cry_idx = torch.LongTensor(crystal_atom_idx) 288 | 289 | return ( 290 | (atom_fea, nbr_dist, self_idx, nbr_idx, cry_idx), 291 | tuple(torch.stack(b_target, dim=0) for b_target in zip(*batch_targets)), 292 | batch_comps, 293 | batch_cif_ids, 294 | ) 295 | 296 | 297 | def get_structure(cols): 298 | """Return pymatgen structure from lattice and sites cols""" 299 | cell, sites = cols 300 | cell, elems, coords = parse_cgcnn(cell, sites) 301 | # NOTE getting primitive structure before constructing graph 302 | # significantly harms the performance of this model. 303 | return Structure(lattice=cell, species=elems, coords=coords, to_unit_cell=True) 304 | 305 | 306 | def parse_cgcnn(cell, sites): 307 | """Parse str representation into lists""" 308 | cell = np.array(ast.literal_eval(cell), dtype=float) 309 | elems = [] 310 | coords = [] 311 | for site in ast.literal_eval(sites): 312 | ele, pos = site.split(" @ ") 313 | elems.append(ele) 314 | coords.append(pos.split(" ")) 315 | 316 | coords = np.array(coords, dtype=float) 317 | return cell, elems, coords 318 | -------------------------------------------------------------------------------- /roost/cgcnn/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from roost.core import BaseModelClass 6 | from roost.segments import MeanPooling, SimpleNetwork, SumPooling 7 | 8 | 9 | class CrystalGraphConvNet(BaseModelClass): 10 | """ 11 | Create a crystal graph convolutional neural network for predicting total 12 | material properties. 13 | 14 | This model is based on: https://github.com/txie-93/cgcnn [MIT License]. 
15 | Changes to the code were made to allow for the removal of zero-padding 16 | and to benefit from the BaseModelClass functionality. The architectural 17 | choices of the model remain unchanged. 18 | """ 19 | 20 | def __init__( 21 | self, 22 | robust, 23 | n_targets, 24 | elem_emb_len, 25 | nbr_fea_len, 26 | elem_fea_len=64, 27 | n_graph=4, 28 | h_fea_len=128, 29 | n_trunk=1, 30 | n_hidden=1, 31 | **kwargs, 32 | ): 33 | """ 34 | Initialize CrystalGraphConvNet. 35 | 36 | Parameters 37 | ---------- 38 | 39 | orig_elem_fea_len: int 40 | Number of atom features in the input. 41 | nbr_fea_len: int 42 | Number of bond features. 43 | elem_fea_len: int 44 | Number of hidden atom features in the convolutional layers 45 | n_graph: int 46 | Number of convolutional layers 47 | h_fea_len: int 48 | Number of hidden features after pooling 49 | n_hidden: int 50 | Number of hidden layers after pooling 51 | """ 52 | super().__init__(robust=robust, **kwargs) 53 | 54 | desc_dict = { 55 | "elem_emb_len": elem_emb_len, 56 | "nbr_fea_len": nbr_fea_len, 57 | "elem_fea_len": elem_fea_len, 58 | "n_graph": n_graph, 59 | } 60 | 61 | self.node_nn = DescriptorNetwork(**desc_dict) 62 | 63 | self.model_params.update( 64 | { 65 | "robust": robust, 66 | "n_targets": n_targets, 67 | "h_fea_len": h_fea_len, 68 | "n_hidden": n_hidden, 69 | } 70 | ) 71 | 72 | self.model_params.update(desc_dict) 73 | 74 | self.pooling = MeanPooling() 75 | 76 | # define an output neural network 77 | if self.robust: 78 | n_targets = [2 * n for n in n_targets] 79 | 80 | out_hidden = [h_fea_len] * n_hidden 81 | trunk_hidden = [h_fea_len] * n_trunk 82 | self.trunk_nn = SimpleNetwork(elem_fea_len, h_fea_len, trunk_hidden) 83 | 84 | self.output_nns = nn.ModuleList( 85 | SimpleNetwork(h_fea_len, n, out_hidden) for n in n_targets 86 | ) 87 | 88 | def forward(self, atom_fea, nbr_fea, self_idx, nbr_idx, crystal_atom_idx): 89 | """ 90 | Forward pass 91 | 92 | N: Total number of atoms in the batch 93 | M: Max number of neighbors 94 | N0: Total number of crystals in the batch 95 | 96 | Parameters 97 | ---------- 98 | 99 | atom_fea: Variable(torch.Tensor) shape (N, orig_elem_fea_len) 100 | Atom features from atom type 101 | nbr_fea: Variable(torch.Tensor) shape (N, M, nbr_fea_len) 102 | Bond features of each atom's M neighbors 103 | nbr_fea_idx: torch.LongTensor shape (N, M) 104 | Indices of M neighbors of each atom 105 | crystal_atom_idx: list of torch.LongTensor of length N0 106 | Mapping from the crystal idx to atom idx 107 | 108 | Returns 109 | ------- 110 | 111 | prediction: nn.Variable shape (N, ) 112 | Atom hidden features after convolution 113 | 114 | """ 115 | atom_fea = self.node_nn( 116 | atom_fea, nbr_fea, self_idx, nbr_idx 117 | ) 118 | 119 | crys_fea = self.pooling(atom_fea, crystal_atom_idx) 120 | 121 | # NOTE required to match the reference implementation 122 | crys_fea = nn.functional.softplus(crys_fea) 123 | 124 | crys_fea = F.relu(self.trunk_nn(crys_fea)) 125 | 126 | # apply neural network to map from learned features to target 127 | return (output_nn(crys_fea) for output_nn in self.output_nns) 128 | 129 | 130 | class DescriptorNetwork(nn.Module): 131 | """ 132 | The Descriptor Network is the message passing section of the 133 | CrystalGraphConvNet Model. 
134 | """ 135 | 136 | def __init__( 137 | self, elem_emb_len, nbr_fea_len, elem_fea_len=64, n_graph=4, 138 | ): 139 | """ 140 | """ 141 | super().__init__() 142 | 143 | self.embedding = nn.Linear(elem_emb_len, elem_fea_len) 144 | 145 | self.convs = nn.ModuleList( 146 | [CGCNNConv( 147 | elem_fea_len=elem_fea_len, 148 | nbr_fea_len=nbr_fea_len 149 | ) for _ in range(n_graph)] 150 | ) 151 | 152 | def forward(self, atom_fea, nbr_fea, self_fea_idx, nbr_fea_idx): 153 | """ 154 | Forward pass 155 | 156 | N: Total number of atoms in the batch 157 | M: Max number of neighbors 158 | N0: Total number of crystals in the batch 159 | 160 | Parameters 161 | ---------- 162 | 163 | atom_fea: Variable(torch.Tensor) shape (N, orig_elem_fea_len) 164 | Atom features from atom type 165 | nbr_fea: Variable(torch.Tensor) shape (N, M, nbr_fea_len) 166 | Bond features of each atom's M neighbors 167 | nbr_fea_idx: torch.LongTensor shape (N, M) 168 | Indices of M neighbors of each atom 169 | crystal_atom_idx: list of torch.LongTensor of length N0 170 | Mapping from the crystal idx to atom idx 171 | 172 | Returns 173 | ------- 174 | 175 | prediction: nn.Variable shape (N, ) 176 | Atom hidden features after convolution 177 | 178 | """ 179 | atom_fea = self.embedding(atom_fea) 180 | 181 | for conv_func in self.convs: 182 | atom_fea = conv_func(atom_fea, nbr_fea, self_fea_idx, nbr_fea_idx) 183 | 184 | return atom_fea 185 | 186 | 187 | class CGCNNConv(nn.Module): 188 | """ 189 | Convolutional operation on graphs 190 | """ 191 | 192 | def __init__(self, elem_fea_len, nbr_fea_len): 193 | """ 194 | Initialize CGCNNConv. 195 | 196 | Parameters 197 | ---------- 198 | 199 | elem_fea_len: int 200 | Number of atom hidden features. 201 | nbr_fea_len: int 202 | Number of bond features. 203 | """ 204 | super().__init__() 205 | self.elem_fea_len = elem_fea_len 206 | self.nbr_fea_len = nbr_fea_len 207 | self.fc_full = nn.Linear( 208 | 2 * self.elem_fea_len + self.nbr_fea_len, 2 * self.elem_fea_len 209 | ) 210 | self.sigmoid = nn.Sigmoid() 211 | self.softplus1 = nn.Softplus() 212 | self.bn1 = nn.BatchNorm1d(2 * self.elem_fea_len) 213 | self.bn2 = nn.BatchNorm1d(self.elem_fea_len) 214 | self.softplus2 = nn.Softplus() 215 | self.pooling = SumPooling() 216 | 217 | def forward(self, atom_in_fea, nbr_fea, self_fea_idx, nbr_fea_idx): 218 | """ 219 | Forward pass 220 | 221 | N: Total number of atoms in the batch 222 | M: Max number of neighbors 223 | 224 | Parameters 225 | ---------- 226 | 227 | atom_in_fea: Variable(torch.Tensor) shape (N, elem_fea_len) 228 | Atom hidden features before convolution 229 | nbr_fea: Variable(torch.Tensor) shape (N, M, nbr_fea_len) 230 | Bond features of each atom's M neighbors 231 | nbr_fea_idx: torch.LongTensor shape (N, M) 232 | Indices of M neighbors of each atom 233 | 234 | Returns 235 | ------- 236 | 237 | atom_out_fea: nn.Variable shape (N, elem_fea_len) 238 | Atom hidden features after convolution 239 | 240 | """ 241 | # convolution 242 | atom_nbr_fea = atom_in_fea[nbr_fea_idx, :] 243 | atom_self_fea = atom_in_fea[self_fea_idx, :] 244 | 245 | total_fea = torch.cat([atom_self_fea, atom_nbr_fea, nbr_fea], dim=1) 246 | 247 | total_fea = self.fc_full(total_fea) 248 | total_fea = self.bn1(total_fea) 249 | 250 | filter_fea, core_fea = total_fea.chunk(2, dim=1) 251 | filter_fea = self.sigmoid(filter_fea) 252 | core_fea = self.softplus1(core_fea) 253 | 254 | # take the elementwise product of the filter and core 255 | nbr_msg = filter_fea * core_fea 256 | nbr_sumed = self.pooling(nbr_msg, self_fea_idx) 257 | 258 
| nbr_sumed = self.bn2(nbr_sumed) 259 | out = self.softplus2(atom_in_fea + nbr_sumed) 260 | 261 | return out 262 | -------------------------------------------------------------------------------- /roost/cgcnn/utils.py: -------------------------------------------------------------------------------- 1 | def get_cgcnn_input(struct): 2 | """return the wren input string 3 | 4 | TODO update wren to use a more standard convention 5 | 6 | Args: 7 | struct (Structure): input structure to get inputs for 8 | 9 | Returns: 10 | cgcnn inputs as a lattice matrix and list of sites of the form [f"{el} @ {x,y,z}", ] 11 | """ 12 | elems = [atom.specie.symbol for atom in struct] 13 | cell = struct.lattice.matrix.tolist() 14 | coords = struct.frac_coords 15 | sites = [" @ ".join((el, " ".join(map(str, x)))) for el, x in zip(elems, coords)] 16 | 17 | return cell, sites 18 | -------------------------------------------------------------------------------- /roost/pretrain/dist_data.py: -------------------------------------------------------------------------------- 1 | import ast 2 | import os 3 | from itertools import groupby 4 | 5 | import numpy as np 6 | import pandas as pd 7 | import torch 8 | from pymatgen.core.structure import Structure 9 | from torch.utils.data import Dataset 10 | 11 | from roost.core import Featurizer 12 | 13 | 14 | class CrystalGraphData(Dataset): 15 | def __init__( 16 | self, 17 | data_path, 18 | fea_path, 19 | task_dict, 20 | inputs=["lattice", "sites"], 21 | identifiers=["material_id", "composition"], 22 | radius=5, 23 | max_num_nbr=12, 24 | dmin=0, 25 | step=0.2, 26 | p_mask=0.15, 27 | p_zero=0.8, 28 | ): 29 | """CrystalGraphData returns neighbourhood graphs 30 | 31 | Args: 32 | data_path (str): The path to the dataset 33 | fea_path (str): The path to the element embedding 34 | task_dict ({target: task}): task dict for multi-task learning 35 | inputs (list, optional): df columns for lattice and sites. 36 | Defaults to ["lattice", "sites"]. 37 | identifiers (list, optional): df columns for distinguishing data points. 38 | Defaults to ["material_id", "composition"]. 39 | radius (int, optional): cut-off radius for neighbourhood. Defaults to 5. 40 | max_num_nbr (int, optional): maximum number of neighbours to consider. Defaults to 12. 41 | dmin (int, optional): minimum distance in gaussian basis. Defaults to 0. 42 | step (float, optional): increment size of gaussian basis. Defaults to 0.2. 43 | """ 44 | assert len(identifiers) == 2, "Two identifiers are required" 45 | assert len(inputs) == 2, "One input column required are required" 46 | 47 | self.inputs = inputs 48 | self.task_dict = task_dict 49 | self.identifiers = identifiers 50 | self.radius = radius 51 | self.max_num_nbr = max_num_nbr 52 | self.p_mask = p_mask 53 | self.p_zero = p_zero 54 | 55 | self.graph = ["self", "nbr", "dist"] 56 | 57 | assert os.path.exists(fea_path), f"{fea_path} does not exist!" 58 | self.ari = Featurizer.from_json(fea_path) 59 | self.elem_fea_dim = self.ari.embedding_size 60 | 61 | self.gdf = GaussianDistance(dmin=dmin, dmax=self.radius, step=step) 62 | self.nbr_fea_dim = self.gdf.embedding_size 63 | 64 | assert os.path.exists(data_path), f"{data_path} does not exist!" 
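# --- Illustrative aside, not part of this module: why keep_default_na matters. ---
# The read_csv call just below turns off pandas' default NA handling so that a
# composition string such as "NaN" survives as text instead of becoming a missing
# value. A minimal sketch with an in-memory csv (the rows are made up):

import io

import pandas as pd

csv_text = "material_id,composition\nhypothetical-1,NaN\nhypothetical-2,Fe2O3\n"

default_df = pd.read_csv(io.StringIO(csv_text))
dense_df = pd.read_csv(io.StringIO(csv_text), keep_default_na=False, na_values=[])

print(default_df["composition"].isna().tolist())  # [True, False] -> "NaN" was lost
print(dense_df["composition"].tolist())           # ['NaN', 'Fe2O3'] -> kept as text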
65 | 66 | # NOTE make sure to use dense datasets, here do not use the default na 67 | # as they can clash with "NaN" which is a valid material 68 | self.df = pd.read_csv( 69 | data_path, keep_default_na=False, na_values=[], comment="#" 70 | )[:1000] 71 | 72 | self.df["Structure_obj"] = self.df[inputs].apply(get_structure, axis=1) 73 | 74 | self._pre_check() 75 | 76 | self.n_targets = [] 77 | for target in self.task_dict: 78 | if self.task_dict[target] == "dist": 79 | self.n_targets.append(1) 80 | elif self.task_dict[target] == "regression": 81 | self.n_targets.append(1) 82 | else: 83 | raise NotImplementedError("bad user") 84 | 85 | def __len__(self): 86 | return len(self.df) 87 | 88 | # @functools.lru_cache(maxsize=None) # Cache loaded structures 89 | def _get_nbr_data(self, crystal): 90 | """get neighbours for every site 91 | 92 | Args: 93 | crystal ([Structure]): pymatgen structure to get neighbours for 94 | """ 95 | # # # neighbours 96 | self_idx, nbr_idx, _, nbr_dist = crystal.get_neighbor_list( 97 | self.radius, 98 | numerical_tol=1e-8, 99 | ) 100 | 101 | if self.max_num_nbr is not None: 102 | _self_idx, _nbr_idx, _nbr_dist = [], [], [] 103 | 104 | for _, g in groupby(zip(self_idx, nbr_idx, nbr_dist), key=lambda x: x[0]): 105 | s, n, d = zip(*sorted(g, key=lambda x: x[2])) 106 | _self_idx.extend(s[: self.max_num_nbr]) 107 | _nbr_idx.extend(n[: self.max_num_nbr]) 108 | _nbr_dist.extend(d[: self.max_num_nbr]) 109 | 110 | self_idx = np.array(_self_idx) 111 | nbr_idx = np.array(_nbr_idx) 112 | nbr_dist = np.array(_nbr_dist) 113 | 114 | return self_idx, nbr_idx, nbr_dist 115 | 116 | def _pre_check(self): 117 | """Check that none of the structures have isolated atoms. 118 | 119 | Raises: 120 | ValueError: if isolated structures are present 121 | """ 122 | print("checking all structures valid") 123 | all_iso = [] 124 | some_iso = [] 125 | 126 | # initialise empty columns of objects to insert the lists 127 | self.df[self.graph[0]] = [[] for _ in range(len(self.df))] 128 | self.df[self.graph[1]] = [[] for _ in range(len(self.df))] 129 | self.df[self.graph[2]] = [[] for _ in range(len(self.df))] 130 | 131 | for index, row in self.df.iterrows(): 132 | # for cif_id, crystal in zip(self.df["material_id"], self.df["Structure_obj"]): 133 | image = 0 134 | 135 | cif_id = row["material_id"] 136 | crystal = row["Structure_obj"] 137 | 138 | while image == 0: 139 | self_idx, nbr_idx, nbr_dist = self._get_nbr_data(crystal) 140 | 141 | if np.any(self_idx == nbr_idx): 142 | # TODO only double along the shortest dimension 143 | crystal.make_supercell([2, 2, 2]) 144 | else: 145 | image = 1 146 | 147 | if len(self_idx) == 0: 148 | all_iso.append(cif_id) 149 | elif len(nbr_idx) == 0: 150 | all_iso.append(cif_id) 151 | elif set(self_idx) != set(range(crystal.num_sites)): 152 | some_iso.append(cif_id) 153 | 154 | # nbr_dist = self.gdf.expand(nbr_dist) 155 | 156 | # nbr_dist = torch.Tensor(nbr_dist) 157 | self_idx = torch.LongTensor(self_idx) 158 | nbr_idx = torch.LongTensor(nbr_idx) 159 | 160 | self.df.at[index, self.graph[0]] = self_idx 161 | self.df.at[index, self.graph[1]] = nbr_idx 162 | self.df.at[index, self.graph[2]] = nbr_dist 163 | 164 | # TODO have the option for the pre-check to delete non-valid entries from the df? 
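# --- Illustrative aside, not part of this module: the max_num_nbr truncation. ---
# _get_nbr_data above keeps only the max_num_nbr nearest neighbours of each site
# by grouping the flat pair list on the central-site index and sorting each group
# by distance. The toy indices and distances below are made up; like the code
# above, this relies on the pair list already being grouped by central site.

from itertools import groupby

self_idx = [0, 0, 0, 1, 1]            # central site of each pair
nbr_idx = [1, 2, 3, 0, 2]             # neighbouring site of each pair
nbr_dist = [2.0, 1.0, 3.0, 1.5, 0.5]  # pair distance in angstrom
max_num_nbr = 2

keep_self, keep_nbr, keep_dist = [], [], []
for _, group in groupby(zip(self_idx, nbr_idx, nbr_dist), key=lambda x: x[0]):
    s, n, d = zip(*sorted(group, key=lambda x: x[2]))
    keep_self.extend(s[:max_num_nbr])
    keep_nbr.extend(n[:max_num_nbr])
    keep_dist.extend(d[:max_num_nbr])

print(keep_self)  # [0, 0, 1, 1] -> at most two neighbours kept per site
print(keep_nbr)   # [2, 1, 2, 0]
print(keep_dist)  # [1.0, 2.0, 0.5, 1.5]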
165 | 166 | if (len(all_iso) > 0) or (len(some_iso) > 0): 167 | self.df = self.df.drop(self.df[self.df["material_id"].isin(all_iso)].index) 168 | self.df = self.df.drop(self.df[self.df["material_id"].isin(some_iso)].index) 169 | 170 | print(all_iso) 171 | print(some_iso) 172 | # raise ValueError("isolated structures contained in dataframe") 173 | 174 | # NOTE do not cache the pre-training structures as we want to see new sets of 175 | # masked structures each epoch as this effectively expands the training set 176 | # @functools.lru_cache(maxsize=None) # Cache loaded structures 177 | def __getitem__(self, idx): 178 | # NOTE sites must be given in fractional coordinates 179 | df_idx = self.df.iloc[idx] 180 | crystal = df_idx["Structure_obj"] 181 | cif_id, comp = df_idx[self.identifiers] 182 | self_idx, nbr_idx, nbr_dist = df_idx[self.graph] 183 | 184 | # atom features 185 | # TODO can this be vectorised with numpy? 186 | 187 | # handle disordered structures (multiple fractional elements per site) 188 | site_atoms = [atom.species.as_dict() for atom in crystal] 189 | atom_fea = np.vstack( 190 | [ 191 | np.sum([self.ari.get_fea(el) * amt for el, amt in site.items()], axis=0) 192 | for site in site_atoms 193 | ] 194 | ) 195 | 196 | # mask distances 197 | mask_ids = np.sort( 198 | np.random.choice( 199 | np.arange(len(self_idx), dtype=int), 200 | max(1, int(self.p_mask * len(self_idx))), 201 | ) 202 | ) 203 | 204 | # mask_labels = np.atleast_2d(1/nbr_dist[mask_ids]).T 205 | mask_labels = np.atleast_2d(nbr_dist[mask_ids]).T 206 | 207 | # TODO currently 20% no mask -> 10% no mask, 10% random mask 208 | mask_filter = np.random.rand(len(mask_ids)) 209 | nbr_dist[mask_ids[np.where(mask_filter < self.p_zero)]] = 0 210 | 211 | nbr_dist = self.gdf.expand(nbr_dist) 212 | 213 | atom_fea = torch.Tensor(atom_fea) 214 | mask_ids = torch.LongTensor(mask_ids) 215 | nbr_dist = torch.Tensor(nbr_dist) 216 | 217 | targets = [] 218 | for target in self.task_dict: 219 | if self.task_dict[target] == "dist": 220 | targets.append(torch.Tensor(mask_labels)) 221 | elif self.task_dict[target] == "regression": 222 | targets.append( 223 | torch.Tensor( 224 | [ 225 | [ 226 | df_idx["e_form"], 227 | ], 228 | ] 229 | ) 230 | ) 231 | else: 232 | raise NotImplementedError("bad user") 233 | 234 | return ( 235 | (atom_fea, nbr_dist, self_idx, nbr_idx, mask_ids), 236 | targets, 237 | comp, 238 | cif_id, 239 | ) 240 | 241 | 242 | class GaussianDistance: 243 | """ 244 | Expands the distance by Gaussian basis. 245 | 246 | Unit: angstrom 247 | """ 248 | 249 | def __init__(self, dmin, dmax, step, var=None): 250 | """ 251 | Args: 252 | dmin (float): Minimum interatomic distance 253 | dmax (float): Maximum interatomic distance 254 | step (float): Step size for the Gaussian filter 255 | var (float, optional): Variance of Gaussian basis. 
Defaults to step if not given 256 | """ 257 | assert dmin < dmax 258 | assert dmax - dmin > step 259 | 260 | self.filter = np.arange(dmin, dmax + step, step) 261 | self.embedding_size = len(self.filter) 262 | 263 | if var is None: 264 | var = step 265 | 266 | self.var = var 267 | 268 | def expand(self, distances): 269 | """Apply Gaussian distance filter to a numpy distance array 270 | 271 | Args: 272 | distances (ArrayLike): A distance matrix of any shape 273 | 274 | Returns: 275 | Expanded distance matrix with the last dimension of length 276 | len(self.filter) 277 | """ 278 | distances = np.array(distances) 279 | 280 | return np.exp( 281 | -((distances[..., np.newaxis] - self.filter) ** 2) / self.var ** 2 282 | ) 283 | 284 | 285 | def collate_batch(dataset_list): 286 | """ 287 | Collate a list of data and return a batch for predicting crystal 288 | properties. 289 | 290 | Parameters 291 | ---------- 292 | 293 | dataset_list: list of tuples for each data point. 294 | (atom_fea, nbr_dist, nbr_idx, target) 295 | 296 | atom_fea: torch.Tensor shape (n_i, atom_fea_len) 297 | nbr_dist: torch.Tensor shape (n_i, M, nbr_dist_len) 298 | nbr_idx: torch.LongTensor shape (n_i, M) 299 | target: torch.Tensor shape (1, ) 300 | cif_id: str or int 301 | 302 | Returns 303 | ------- 304 | N = sum(n_i); N0 = sum(i) 305 | 306 | batch_atom_fea: torch.Tensor shape (N, orig_atom_fea_len) 307 | Atom features from atom type 308 | batch_nbr_dist: torch.Tensor shape (N, M, nbr_dist_len) 309 | Bond features of each atom's M neighbors 310 | batch_nbr_idx: torch.LongTensor shape (N, M) 311 | Indices of M neighbors of each atom 312 | crystal_atom_idx: list of torch.LongTensor of length N0 313 | Mapping from the crystal idx to atom idx 314 | target: torch.Tensor shape (N, 1) 315 | Target value for prediction 316 | batch_cif_ids: list 317 | """ 318 | batch_atom_fea = [] 319 | batch_nbr_dist = [] 320 | batch_self_idx = [] 321 | batch_nbr_idx = [] 322 | crystal_atom_idx = [] 323 | batch_mask_idx = [] 324 | batch_targets = [] 325 | batch_comps = [] 326 | batch_cif_ids = [] 327 | base_idx = 0 328 | base_bond_idx = 0 329 | 330 | for i, (inputs, target, comp, cif_id) in enumerate(dataset_list): 331 | atom_fea, nbr_dist, self_idx, nbr_idx, mask_idx = inputs 332 | n_atoms = atom_fea.shape[0] # number of atoms for this crystal 333 | n_bonds = self_idx.shape[0] 334 | 335 | batch_atom_fea.append(atom_fea) 336 | batch_nbr_dist.append(nbr_dist) 337 | batch_self_idx.append(self_idx + base_idx) 338 | batch_nbr_idx.append(nbr_idx + base_idx) 339 | # batch_mask_idx.append(mask_idx + base_idx) 340 | 341 | batch_mask_idx.append(mask_idx + base_bond_idx) 342 | 343 | crystal_atom_idx.extend([i] * n_atoms) 344 | 345 | batch_targets.append(target) 346 | batch_comps.append(comp) 347 | batch_cif_ids.append(cif_id) 348 | 349 | base_idx += n_atoms 350 | base_bond_idx += n_bonds 351 | 352 | atom_fea = torch.cat(batch_atom_fea, dim=0) 353 | nbr_dist = torch.cat(batch_nbr_dist, dim=0) 354 | self_idx = torch.cat(batch_self_idx, dim=0) 355 | nbr_idx = torch.cat(batch_nbr_idx, dim=0) 356 | mask_idx = torch.cat(batch_mask_idx, dim=0) 357 | 358 | cry_idx = torch.LongTensor(crystal_atom_idx) 359 | 360 | return ( 361 | (atom_fea, nbr_dist, self_idx, nbr_idx, mask_idx, cry_idx), 362 | tuple(torch.cat(b_target, dim=0) for b_target in zip(*batch_targets)), 363 | batch_comps, 364 | batch_cif_ids, 365 | ) 366 | 367 | 368 | def get_structure(cols): 369 | """Return pymatgen structure from lattice and sites cols""" 370 | cell, sites = cols 371 | cell, elems, coords = 
parse_cgcnn(cell, sites) 372 | # NOTE getting primitive structure before constructing graph 373 | # significantly harms the performance of this model. 374 | 375 | crystal = Structure(lattice=cell, species=elems, coords=coords, to_unit_cell=True) 376 | 377 | # In place modification of structures that only contain a few sites 378 | # this is to allow us to mask ~15% of sites without having a 379 | # disproportionate impact on small unit cell structures. 380 | if crystal.num_sites < 7: 381 | crystal.make_supercell([2, 2, 2]) 382 | 383 | return crystal 384 | 385 | 386 | def parse_cgcnn(cell, sites): 387 | """Parse str representation into lists""" 388 | cell = np.array(ast.literal_eval(cell), dtype=float) 389 | elems = [] 390 | coords = [] 391 | for site in ast.literal_eval(sites): 392 | ele, pos = site.split(" @ ") 393 | elems.append(ele) 394 | coords.append(pos.split(" ")) 395 | 396 | coords = np.array(coords, dtype=float) 397 | return cell, elems, coords 398 | -------------------------------------------------------------------------------- /roost/pretrain/dist_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from roost.cgcnn.model import DescriptorNetwork 5 | from roost.core import BaseModelClass 6 | from roost.segments import MeanPooling 7 | 8 | 9 | class CrystalGraphPreNet(BaseModelClass): 10 | """ 11 | Create a crystal graph convolutional neural network for predicting total 12 | material properties. 13 | 14 | This model is based on: https://github.com/txie-93/cgcnn [MIT License]. 15 | Changes to the code were made to allow for the removal of zero-padding 16 | and to benefit from the BaseModelClass functionality. The architectural 17 | choices of the model remain unchanged. 18 | """ 19 | 20 | def __init__( 21 | self, 22 | robust, 23 | n_targets, 24 | elem_emb_len, 25 | nbr_fea_len, 26 | elem_fea_len=64, 27 | n_graph=4, 28 | **kwargs, 29 | ): 30 | """ 31 | Initialize CrystalGraphConvNet. 32 | 33 | Parameters 34 | ---------- 35 | 36 | orig_elem_fea_len: int 37 | Number of atom features in the input. 38 | nbr_fea_len: int 39 | Number of bond features. 
40 | elem_fea_len: int 41 | Number of hidden atom features in the convolutional layers 42 | n_graph: int 43 | Number of convolutional layers 44 | h_fea_len: int 45 | Number of hidden features after pooling 46 | n_hidden: int 47 | Number of hidden layers after pooling 48 | """ 49 | super().__init__(robust=robust, **kwargs) 50 | 51 | desc_dict = { 52 | "elem_emb_len": elem_emb_len, 53 | "nbr_fea_len": nbr_fea_len, 54 | "elem_fea_len": elem_fea_len, 55 | "n_graph": n_graph, 56 | } 57 | 58 | self.node_nn = DescriptorNetwork(**desc_dict) 59 | 60 | self.model_params.update( 61 | { 62 | "robust": robust, 63 | "n_targets": n_targets, 64 | } 65 | ) 66 | 67 | self.node_linear = nn.Linear(2*elem_fea_len, 1) 68 | 69 | self.global_pool = MeanPooling() 70 | self.global_linear = nn.Linear(elem_fea_len, 1) 71 | 72 | self.model_params.update(desc_dict) 73 | 74 | def forward(self, atom_fea, nbr_fea, self_idx, nbr_idx, mask_idx, cry_idx): 75 | """ 76 | Forward pass 77 | 78 | N: Total number of atoms in the batch 79 | M: Max number of neighbors 80 | N0: Total number of crystals in the batch 81 | 82 | Parameters 83 | ---------- 84 | 85 | atom_fea: Variable(torch.Tensor) shape (N, orig_elem_fea_len) 86 | Atom features from atom type 87 | nbr_fea: Variable(torch.Tensor) shape (N, M, nbr_fea_len) 88 | Bond features of each atom's M neighbors 89 | nbr_idx: torch.LongTensor shape (N, M) 90 | Indices of M neighbors of each atom 91 | crystal_atom_idx: list of torch.LongTensor of length N0 92 | Mapping from the crystal idx to atom idx 93 | 94 | Returns 95 | ------- 96 | 97 | prediction: nn.Variable shape (N, ) 98 | Atom hidden features after convolution 99 | 100 | """ 101 | crys_fea = self.node_nn(atom_fea, nbr_fea, self_idx, nbr_idx) 102 | 103 | atom_nbr_fea = crys_fea[nbr_idx[mask_idx], :] 104 | atom_self_fea = crys_fea[self_idx[mask_idx], :] 105 | 106 | total_fea = torch.cat([atom_self_fea, atom_nbr_fea], dim=1) 107 | 108 | nodes = self.node_linear(total_fea) 109 | 110 | glob = self.global_linear(self.global_pool(crys_fea, cry_idx)) 111 | 112 | return [nodes, glob] 113 | -------------------------------------------------------------------------------- /roost/pretrain/ele_data.py: -------------------------------------------------------------------------------- 1 | import ast 2 | import os 3 | from itertools import groupby 4 | 5 | import numpy as np 6 | import pandas as pd 7 | import torch 8 | from pymatgen.core.structure import Structure 9 | from torch.utils.data import Dataset 10 | 11 | from roost.core import Featurizer 12 | 13 | 14 | class CrystalGraphData(Dataset): 15 | def __init__( 16 | self, 17 | data_path, 18 | fea_path, 19 | task_dict, 20 | inputs=["lattice", "sites"], 21 | identifiers=["material_id", "composition"], 22 | radius=5, 23 | max_num_nbr=12, 24 | dmin=0, 25 | step=0.2, 26 | p_mask=0.15, 27 | p_zero=0.8, 28 | ): 29 | """CrystalGraphData returns neighbourhood graphs 30 | 31 | Args: 32 | data_path (str): The path to the dataset 33 | fea_path (str): The path to the element embedding 34 | task_dict ({target: task}): task dict for multi-task learning 35 | inputs (list, optional): df columns for lattice and sites. 36 | Defaults to ["lattice", "sites"]. 37 | identifiers (list, optional): df columns for distinguishing data points. 38 | Defaults to ["material_id", "composition"]. 39 | radius (int, optional): cut-off radius for neighbourhood. Defaults to 5. 40 | max_num_nbr (int, optional): maximum number of neighbours to consider. Defaults to 12. 41 | dmin (int, optional): minimum distance in gaussian basis. 
Defaults to 0. 42 | step (float, optional): increment size of gaussian basis. Defaults to 0.2. 43 | """ 44 | assert len(identifiers) == 2, "Two identifiers are required" 45 | assert len(inputs) == 2, "One input column required are required" 46 | 47 | self.inputs = inputs 48 | self.task_dict = task_dict 49 | self.identifiers = identifiers 50 | self.radius = radius 51 | self.max_num_nbr = max_num_nbr 52 | self.p_mask = p_mask 53 | self.p_zero = p_zero 54 | 55 | self.graph = ["self", "nbr", "dist"] 56 | 57 | assert os.path.exists(fea_path), f"{fea_path} does not exist!" 58 | self.ari = Featurizer.from_json(fea_path) 59 | self.ohe = Featurizer.from_json("data/el-embeddings/onehot-embedding.json") 60 | self.elem_fea_dim = self.ari.embedding_size 61 | 62 | self.gdf = GaussianDistance(dmin=dmin, dmax=self.radius, step=step) 63 | self.nbr_fea_dim = self.gdf.embedding_size 64 | 65 | assert os.path.exists(data_path), f"{data_path} does not exist!" 66 | 67 | # NOTE make sure to use dense datasets, here do not use the default na 68 | # as they can clash with "NaN" which is a valid material 69 | self.df = pd.read_csv( 70 | data_path, keep_default_na=False, na_values=[], comment="#" 71 | )[:1000] 72 | 73 | self.df["Structure_obj"] = self.df[inputs].apply(get_structure, axis=1) 74 | 75 | self._pre_check() 76 | 77 | self.n_targets = [] 78 | for target in self.task_dict: 79 | if self.task_dict[target] == "mask": 80 | self.n_targets.append(self.ohe.embedding_size) 81 | elif self.task_dict[target] == "regression": 82 | self.n_targets.append(1) 83 | else: 84 | raise NotImplementedError("bad user") 85 | 86 | def __len__(self): 87 | return len(self.df) 88 | 89 | # @functools.lru_cache(maxsize=None) # Cache loaded structures 90 | def _get_nbr_data(self, crystal): 91 | """get neighbours for every site 92 | 93 | Args: 94 | crystal ([Structure]): pymatgen structure to get neighbours for 95 | """ 96 | # # # neighbours 97 | self_idx, nbr_idx, _, nbr_dist = crystal.get_neighbor_list( 98 | self.radius, 99 | numerical_tol=1e-8, 100 | ) 101 | 102 | if self.max_num_nbr is not None: 103 | _self_idx, _nbr_idx, _nbr_dist = [], [], [] 104 | 105 | for _, g in groupby(zip(self_idx, nbr_idx, nbr_dist), key=lambda x: x[0]): 106 | s, n, d = zip(*sorted(g, key=lambda x: x[2])) 107 | _self_idx.extend(s[: self.max_num_nbr]) 108 | _nbr_idx.extend(n[: self.max_num_nbr]) 109 | _nbr_dist.extend(d[: self.max_num_nbr]) 110 | 111 | self_idx = np.array(_self_idx) 112 | nbr_idx = np.array(_nbr_idx) 113 | nbr_dist = np.array(_nbr_dist) 114 | 115 | return self_idx, nbr_idx, nbr_dist 116 | 117 | def _pre_check(self): 118 | """Check that none of the structures have isolated atoms. 
119 | 120 | Raises: 121 | ValueError: if isolated structures are present 122 | """ 123 | print("checking all structures valid") 124 | all_iso = [] 125 | some_iso = [] 126 | 127 | # initialise empty columns of objects to insert the lists 128 | self.df[self.graph[0]] = [[] for _ in range(len(self.df))] 129 | self.df[self.graph[1]] = [[] for _ in range(len(self.df))] 130 | self.df[self.graph[2]] = [[] for _ in range(len(self.df))] 131 | 132 | for index, row in self.df.iterrows(): 133 | # for cif_id, crystal in zip(self.df["material_id"], self.df["Structure_obj"]): 134 | image = 0 135 | 136 | cif_id = row["material_id"] 137 | crystal = row["Structure_obj"] 138 | 139 | while image == 0: 140 | self_idx, nbr_idx, nbr_dist = self._get_nbr_data(crystal) 141 | 142 | if np.any(self_idx == nbr_idx): 143 | # TODO only double along the shortest dimension 144 | crystal.make_supercell([2, 2, 2]) 145 | else: 146 | image = 1 147 | 148 | if len(self_idx) == 0: 149 | all_iso.append(cif_id) 150 | elif len(nbr_idx) == 0: 151 | all_iso.append(cif_id) 152 | elif set(self_idx) != set(range(crystal.num_sites)): 153 | some_iso.append(cif_id) 154 | 155 | nbr_dist = self.gdf.expand(nbr_dist) 156 | 157 | nbr_dist = torch.Tensor(nbr_dist) 158 | self_idx = torch.LongTensor(self_idx) 159 | nbr_idx = torch.LongTensor(nbr_idx) 160 | 161 | self.df.at[index, self.graph[0]] = self_idx 162 | self.df.at[index, self.graph[1]] = nbr_idx 163 | self.df.at[index, self.graph[2]] = nbr_dist 164 | 165 | # TODO have the option for the pre-check to delete non-valid entries from the df? 166 | 167 | if (len(all_iso) > 0) or (len(some_iso) > 0): 168 | self.df = self.df.drop(self.df[self.df["material_id"].isin(all_iso)].index) 169 | self.df = self.df.drop(self.df[self.df["material_id"].isin(some_iso)].index) 170 | 171 | print(all_iso) 172 | print(some_iso) 173 | # raise ValueError("isolated structures contained in dataframe") 174 | 175 | # NOTE do not cache the pre-training structures as we want to see new sets of 176 | # masked structures each epoch as this effectively expands the training set 177 | # @functools.lru_cache(maxsize=None) # Cache loaded structures 178 | def __getitem__(self, idx): 179 | # NOTE sites must be given in fractional coordinates 180 | df_idx = self.df.iloc[idx] 181 | crystal = df_idx["Structure_obj"] 182 | cif_id, comp = df_idx[self.identifiers] 183 | self_idx, nbr_idx, nbr_dist = df_idx[self.graph] 184 | 185 | # atom features 186 | # TODO can this be vectorised with numpy? 
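# --- Illustrative aside, not part of this module: the site masking used below. ---
# __getitem__ picks roughly p_mask of the sites as reconstruction targets and
# zeroes the input features of about p_zero of those picks, in the spirit of
# masked-language-model pre-training. This standalone sketch uses made-up sizes
# and numpy's Generator API instead of the legacy np.random calls used below.

import numpy as np

rng = np.random.default_rng(0)
n_sites, n_fea = 10, 4
p_mask, p_zero = 0.15, 0.8

atom_fea = rng.normal(size=(n_sites, n_fea))

# pick at least one site to serve as a masking target
mask_ids = np.sort(rng.choice(np.arange(n_sites), max(1, int(p_mask * n_sites))))

# blank the input features for roughly p_zero of the chosen sites
mask_filter = rng.random(len(mask_ids))
atom_fea[mask_ids[mask_filter < p_zero], :] = 0

print(mask_ids)                     # site indices the model must reconstruct
print((atom_fea == 0).all(axis=1))  # which rows were actually blanked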
187 | 188 | # handle disordered structures (multiple fractional elements per site) 189 | site_atoms = [atom.species.as_dict() for atom in crystal] 190 | atom_fea = np.vstack( 191 | [ 192 | np.sum([self.ari.get_fea(el) * amt for el, amt in site.items()], axis=0) 193 | for site in site_atoms 194 | ] 195 | ) 196 | 197 | # select at least one site in the crystal to mask 198 | mask_ids = np.sort( 199 | np.random.choice( 200 | np.arange(crystal.num_sites, dtype=int), 201 | max(1, int(self.p_mask * crystal.num_sites)), 202 | ) 203 | ) 204 | 205 | # get the mask labels via use of a OHE of elements to handle disordered structures 206 | mask_labels = np.vstack( 207 | [ 208 | np.sum( 209 | [self.ohe.get_fea(el) * amt for el, amt in site_atoms[idx].items()], 210 | axis=0, 211 | ) 212 | for idx in mask_ids 213 | ] 214 | ) 215 | 216 | # TODO currently 20% no mask -> 10% no mask, 10% random mask 217 | mask_filter = np.random.rand(len(mask_ids)) 218 | atom_fea[mask_ids[np.where(mask_filter < self.p_zero)], :] = 0 219 | 220 | atom_fea = torch.Tensor(atom_fea) 221 | mask_ids = torch.LongTensor(mask_ids) 222 | 223 | targets = [] 224 | for target in self.task_dict: 225 | if self.task_dict[target] == "mask": 226 | targets.append(torch.Tensor(mask_labels)) 227 | elif self.task_dict[target] == "regression": 228 | targets.append( 229 | torch.Tensor( 230 | [ 231 | [ 232 | df_idx["e_form"], 233 | ], 234 | ] 235 | ) 236 | ) 237 | else: 238 | raise NotImplementedError("bad user") 239 | 240 | return ( 241 | (atom_fea, nbr_dist, self_idx, nbr_idx, mask_ids), 242 | targets, 243 | comp, 244 | cif_id, 245 | ) 246 | 247 | 248 | class GaussianDistance: 249 | """ 250 | Expands the distance by Gaussian basis. 251 | 252 | Unit: angstrom 253 | """ 254 | 255 | def __init__(self, dmin, dmax, step, var=None): 256 | """ 257 | Args: 258 | dmin (float): Minimum interatomic distance 259 | dmax (float): Maximum interatomic distance 260 | step (float): Step size for the Gaussian filter 261 | var (float, optional): Variance of Gaussian basis. Defaults to step if not given 262 | """ 263 | assert dmin < dmax 264 | assert dmax - dmin > step 265 | 266 | self.filter = np.arange(dmin, dmax + step, step) 267 | self.embedding_size = len(self.filter) 268 | 269 | if var is None: 270 | var = step 271 | 272 | self.var = var 273 | 274 | def expand(self, distances): 275 | """Apply Gaussian distance filter to a numpy distance array 276 | 277 | Args: 278 | distances (ArrayLike): A distance matrix of any shape 279 | 280 | Returns: 281 | Expanded distance matrix with the last dimension of length 282 | len(self.filter) 283 | """ 284 | distances = np.array(distances) 285 | 286 | return np.exp( 287 | -((distances[..., np.newaxis] - self.filter) ** 2) / self.var ** 2 288 | ) 289 | 290 | 291 | def collate_batch(dataset_list): 292 | """ 293 | Collate a list of data and return a batch for predicting crystal 294 | properties. 295 | 296 | Parameters 297 | ---------- 298 | 299 | dataset_list: list of tuples for each data point. 
300 | (atom_fea, nbr_dist, nbr_idx, target) 301 | 302 | atom_fea: torch.Tensor shape (n_i, atom_fea_len) 303 | nbr_dist: torch.Tensor shape (n_i, M, nbr_dist_len) 304 | nbr_idx: torch.LongTensor shape (n_i, M) 305 | target: torch.Tensor shape (1, ) 306 | cif_id: str or int 307 | 308 | Returns 309 | ------- 310 | N = sum(n_i); N0 = sum(i) 311 | 312 | batch_atom_fea: torch.Tensor shape (N, orig_atom_fea_len) 313 | Atom features from atom type 314 | batch_nbr_dist: torch.Tensor shape (N, M, nbr_dist_len) 315 | Bond features of each atom's M neighbors 316 | batch_nbr_idx: torch.LongTensor shape (N, M) 317 | Indices of M neighbors of each atom 318 | crystal_atom_idx: list of torch.LongTensor of length N0 319 | Mapping from the crystal idx to atom idx 320 | target: torch.Tensor shape (N, 1) 321 | Target value for prediction 322 | batch_cif_ids: list 323 | """ 324 | batch_atom_fea = [] 325 | batch_nbr_dist = [] 326 | batch_self_idx = [] 327 | batch_nbr_idx = [] 328 | crystal_atom_idx = [] 329 | batch_mask_idx = [] 330 | batch_targets = [] 331 | batch_comps = [] 332 | batch_cif_ids = [] 333 | base_idx = 0 334 | 335 | for i, (inputs, target, comp, cif_id) in enumerate(dataset_list): 336 | atom_fea, nbr_dist, self_idx, nbr_idx, mask_idx = inputs 337 | n_atoms = atom_fea.shape[0] # number of atoms for this crystal 338 | 339 | batch_atom_fea.append(atom_fea) 340 | batch_nbr_dist.append(nbr_dist) 341 | batch_self_idx.append(self_idx + base_idx) 342 | batch_nbr_idx.append(nbr_idx + base_idx) 343 | batch_mask_idx.append(mask_idx + base_idx) 344 | 345 | crystal_atom_idx.extend([i] * n_atoms) 346 | 347 | batch_targets.append(target) 348 | batch_comps.append(comp) 349 | batch_cif_ids.append(cif_id) 350 | 351 | base_idx += n_atoms 352 | 353 | atom_fea = torch.cat(batch_atom_fea, dim=0) 354 | nbr_dist = torch.cat(batch_nbr_dist, dim=0) 355 | self_idx = torch.cat(batch_self_idx, dim=0) 356 | nbr_idx = torch.cat(batch_nbr_idx, dim=0) 357 | mask_idx = torch.cat(batch_mask_idx, dim=0) 358 | 359 | cry_idx = torch.LongTensor(crystal_atom_idx) 360 | 361 | return ( 362 | (atom_fea, nbr_dist, self_idx, nbr_idx, mask_idx, cry_idx), 363 | tuple(torch.cat(b_target, dim=0) for b_target in zip(*batch_targets)), 364 | batch_comps, 365 | batch_cif_ids, 366 | ) 367 | 368 | 369 | def get_structure(cols): 370 | """Return pymatgen structure from lattice and sites cols""" 371 | cell, sites = cols 372 | cell, elems, coords = parse_cgcnn(cell, sites) 373 | # NOTE getting primitive structure before constructing graph 374 | # significantly harms the performance of this model. 375 | 376 | crystal = Structure(lattice=cell, species=elems, coords=coords, to_unit_cell=True) 377 | 378 | # In place modification of structures that only contain a few sites 379 | # this is to allow us to mask ~15% of sites without having a 380 | # disproportionate impact on small unit cell structures. 
381 | if crystal.num_sites < 7: 382 | crystal.make_supercell([2, 2, 2]) 383 | 384 | return crystal 385 | 386 | 387 | def parse_cgcnn(cell, sites): 388 | """Parse str representation into lists""" 389 | cell = np.array(ast.literal_eval(cell), dtype=float) 390 | elems = [] 391 | coords = [] 392 | for site in ast.literal_eval(sites): 393 | ele, pos = site.split(" @ ") 394 | elems.append(ele) 395 | coords.append(pos.split(" ")) 396 | 397 | coords = np.array(coords, dtype=float) 398 | return cell, elems, coords 399 | -------------------------------------------------------------------------------- /roost/pretrain/ele_model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from roost.cgcnn.model import DescriptorNetwork 4 | from roost.core import BaseModelClass 5 | from roost.segments import MeanPooling 6 | 7 | 8 | class CrystalGraphPreNet(BaseModelClass): 9 | """ 10 | Create a crystal graph convolutional neural network for predicting total 11 | material properties. 12 | 13 | This model is based on: https://github.com/txie-93/cgcnn [MIT License]. 14 | Changes to the code were made to allow for the removal of zero-padding 15 | and to benefit from the BaseModelClass functionality. The architectural 16 | choices of the model remain unchanged. 17 | """ 18 | 19 | def __init__( 20 | self, 21 | robust, 22 | n_targets, 23 | elem_emb_len, 24 | nbr_fea_len, 25 | elem_fea_len=64, 26 | n_graph=4, 27 | **kwargs, 28 | ): 29 | """ 30 | Initialize CrystalGraphConvNet. 31 | 32 | Parameters 33 | ---------- 34 | 35 | orig_elem_fea_len: int 36 | Number of atom features in the input. 37 | nbr_fea_len: int 38 | Number of bond features. 39 | elem_fea_len: int 40 | Number of hidden atom features in the convolutional layers 41 | n_graph: int 42 | Number of convolutional layers 43 | h_fea_len: int 44 | Number of hidden features after pooling 45 | n_hidden: int 46 | Number of hidden layers after pooling 47 | """ 48 | super().__init__(robust=robust, **kwargs) 49 | 50 | desc_dict = { 51 | "elem_emb_len": elem_emb_len, 52 | "nbr_fea_len": nbr_fea_len, 53 | "elem_fea_len": elem_fea_len, 54 | "n_graph": n_graph, 55 | } 56 | 57 | self.node_nn = DescriptorNetwork(**desc_dict) 58 | 59 | self.model_params.update( 60 | { 61 | "robust": robust, 62 | "n_targets": n_targets, 63 | } 64 | ) 65 | 66 | self.node_linear = nn.Linear(elem_fea_len, n_targets[0]) 67 | 68 | self.global_pool = MeanPooling() 69 | self.global_linear = nn.Linear(elem_fea_len, 1) 70 | 71 | self.model_params.update(desc_dict) 72 | 73 | def forward(self, atom_fea, nbr_fea, self_idx, nbr_idx, mask_idx, cry_idx): 74 | """ 75 | Forward pass 76 | 77 | N: Total number of atoms in the batch 78 | M: Max number of neighbors 79 | N0: Total number of crystals in the batch 80 | 81 | Parameters 82 | ---------- 83 | 84 | atom_fea: Variable(torch.Tensor) shape (N, orig_elem_fea_len) 85 | Atom features from atom type 86 | nbr_fea: Variable(torch.Tensor) shape (N, M, nbr_fea_len) 87 | Bond features of each atom's M neighbors 88 | nbr_idx: torch.LongTensor shape (N, M) 89 | Indices of M neighbors of each atom 90 | crystal_atom_idx: list of torch.LongTensor of length N0 91 | Mapping from the crystal idx to atom idx 92 | 93 | Returns 94 | ------- 95 | 96 | prediction: nn.Variable shape (N, ) 97 | Atom hidden features after convolution 98 | 99 | """ 100 | crys_fea = self.node_nn(atom_fea, nbr_fea, self_idx, nbr_idx) 101 | 102 | nodes = self.node_linear(crys_fea[mask_idx, :]) 103 | 104 | glob = 
self.global_linear(self.global_pool(crys_fea, cry_idx)) 105 | 106 | return [nodes, glob] 107 | -------------------------------------------------------------------------------- /roost/pretrain/ener_data.py: -------------------------------------------------------------------------------- 1 | import ast 2 | import os 3 | from itertools import groupby 4 | 5 | import numpy as np 6 | import pandas as pd 7 | import torch 8 | from pymatgen.core.structure import Structure 9 | from torch.utils.data import Dataset 10 | 11 | from roost.core import Featurizer 12 | 13 | 14 | class CrystalGraphData(Dataset): 15 | def __init__( 16 | self, 17 | data_path, 18 | fea_path, 19 | task_dict, 20 | inputs=["lattice", "sites"], 21 | identifiers=["material_id", "composition"], 22 | radius=5, 23 | max_num_nbr=12, 24 | dmin=0, 25 | step=0.2, 26 | p_mask=0.15, 27 | p_zero=0.8, 28 | ): 29 | """CrystalGraphData returns neighbourhood graphs 30 | 31 | Args: 32 | data_path (str): The path to the dataset 33 | fea_path (str): The path to the element embedding 34 | task_dict ({target: task}): task dict for multi-task learning 35 | inputs (list, optional): df columns for lattice and sites. 36 | Defaults to ["lattice", "sites"]. 37 | identifiers (list, optional): df columns for distinguishing data points. 38 | Defaults to ["material_id", "composition"]. 39 | radius (int, optional): cut-off radius for neighbourhood. Defaults to 5. 40 | max_num_nbr (int, optional): maximum number of neighbours to consider. Defaults to 12. 41 | dmin (int, optional): minimum distance in gaussian basis. Defaults to 0. 42 | step (float, optional): increment size of gaussian basis. Defaults to 0.2. 43 | """ 44 | assert len(identifiers) == 2, "Two identifiers are required" 45 | assert len(inputs) == 2, "One input column required are required" 46 | 47 | self.inputs = inputs 48 | self.task_dict = task_dict 49 | self.identifiers = identifiers 50 | self.radius = radius 51 | self.max_num_nbr = max_num_nbr 52 | self.p_mask = p_mask 53 | self.p_zero = p_zero 54 | 55 | self.graph = ["self", "nbr", "dist"] 56 | 57 | assert os.path.exists(fea_path), f"{fea_path} does not exist!" 58 | self.ari = Featurizer.from_json(fea_path) 59 | self.ohe = Featurizer.from_json("data/el-embeddings/onehot-embedding.json") 60 | self.elem_fea_dim = self.ari.embedding_size 61 | 62 | self.gdf = GaussianDistance(dmin=dmin, dmax=self.radius, step=step) 63 | self.nbr_fea_dim = self.gdf.embedding_size 64 | 65 | assert os.path.exists(data_path), f"{data_path} does not exist!" 
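# --- Illustrative aside, not part of this module: featurising disordered sites. ---
# In __getitem__ below each site's feature vector is the occupancy-weighted sum
# of the element embeddings returned by the Featurizer, which is what lets
# partially occupied (disordered) sites be handled without special cases. The
# tiny 3-dimensional embedding table here is hypothetical.

import numpy as np

toy_embedding = {
    "Fe": np.array([1.0, 0.0, 2.0]),  # hypothetical element vectors
    "Ni": np.array([0.0, 1.0, 2.0]),
}

# a 50/50 Fe/Ni disordered site, in the form returned by site.species.as_dict()
site = {"Fe": 0.5, "Ni": 0.5}

site_fea = np.sum([toy_embedding[el] * amt for el, amt in site.items()], axis=0)
print(site_fea)  # [0.5 0.5 2. ]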
66 | 67 | # NOTE make sure to use dense datasets, here do not use the default na 68 | # as they can clash with "NaN" which is a valid material 69 | self.df = pd.read_csv( 70 | data_path, keep_default_na=False, na_values=[], comment="#" 71 | )[:1000] 72 | 73 | self.df["Structure_obj"] = self.df[inputs].apply(get_structure, axis=1) 74 | 75 | self._pre_check() 76 | 77 | self.n_targets = [] 78 | for target in self.task_dict: 79 | if self.task_dict[target] == "mask": 80 | self.n_targets.append(self.ohe.embedding_size) 81 | elif self.task_dict[target] == "regression": 82 | self.n_targets.append(1) 83 | else: 84 | raise NotImplementedError("bad user") 85 | 86 | def __len__(self): 87 | return len(self.df) 88 | 89 | # @functools.lru_cache(maxsize=None) # Cache loaded structures 90 | def _get_nbr_data(self, crystal): 91 | """get neighbours for every site 92 | 93 | Args: 94 | crystal ([Structure]): pymatgen structure to get neighbours for 95 | """ 96 | # # # neighbours 97 | self_idx, nbr_idx, _, nbr_dist = crystal.get_neighbor_list( 98 | self.radius, 99 | numerical_tol=1e-8, 100 | ) 101 | 102 | if self.max_num_nbr is not None: 103 | _self_idx, _nbr_idx, _nbr_dist = [], [], [] 104 | 105 | for _, g in groupby(zip(self_idx, nbr_idx, nbr_dist), key=lambda x: x[0]): 106 | s, n, d = zip(*sorted(g, key=lambda x: x[2])) 107 | _self_idx.extend(s[: self.max_num_nbr]) 108 | _nbr_idx.extend(n[: self.max_num_nbr]) 109 | _nbr_dist.extend(d[: self.max_num_nbr]) 110 | 111 | self_idx = np.array(_self_idx) 112 | nbr_idx = np.array(_nbr_idx) 113 | nbr_dist = np.array(_nbr_dist) 114 | 115 | return self_idx, nbr_idx, nbr_dist 116 | 117 | def _pre_check(self): 118 | """Check that none of the structures have isolated atoms. 119 | 120 | Raises: 121 | ValueError: if isolated structures are present 122 | """ 123 | print("checking all structures valid") 124 | all_iso = [] 125 | some_iso = [] 126 | 127 | # initialise empty columns of objects to insert the lists 128 | self.df[self.graph[0]] = [[] for _ in range(len(self.df))] 129 | self.df[self.graph[1]] = [[] for _ in range(len(self.df))] 130 | self.df[self.graph[2]] = [[] for _ in range(len(self.df))] 131 | 132 | for index, row in self.df.iterrows(): 133 | # for cif_id, crystal in zip(self.df["material_id"], self.df["Structure_obj"]): 134 | image = 0 135 | 136 | cif_id = row["material_id"] 137 | crystal = row["Structure_obj"] 138 | 139 | while image == 0: 140 | self_idx, nbr_idx, nbr_dist = self._get_nbr_data(crystal) 141 | 142 | if np.any(self_idx == nbr_idx): 143 | # TODO only double along the shortest dimension 144 | crystal.make_supercell([2, 2, 2]) 145 | else: 146 | image = 1 147 | 148 | if len(self_idx) == 0: 149 | all_iso.append(cif_id) 150 | elif len(nbr_idx) == 0: 151 | all_iso.append(cif_id) 152 | elif set(self_idx) != set(range(crystal.num_sites)): 153 | some_iso.append(cif_id) 154 | 155 | nbr_dist = self.gdf.expand(nbr_dist) 156 | 157 | nbr_dist = torch.Tensor(nbr_dist) 158 | self_idx = torch.LongTensor(self_idx) 159 | nbr_idx = torch.LongTensor(nbr_idx) 160 | 161 | self.df.at[index, self.graph[0]] = self_idx 162 | self.df.at[index, self.graph[1]] = nbr_idx 163 | self.df.at[index, self.graph[2]] = nbr_dist 164 | 165 | # TODO have the option for the pre-check to delete non-valid entries from the df? 
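# --- Illustrative aside, not part of this module: why _pre_check doubles cells. ---
# In a very small unit cell a site can fall within the cutoff radius of its own
# periodic images, which shows up as pairs with self_idx == nbr_idx. The loop
# above keeps calling make_supercell([2, 2, 2]) until that no longer happens.
# A pymatgen sketch with a one-atom cubic cell (element and lattice constant
# chosen arbitrarily):

import numpy as np
from pymatgen.core.lattice import Lattice
from pymatgen.core.structure import Structure

crystal = Structure(Lattice.cubic(3.35), ["Po"], [[0.0, 0.0, 0.0]])

self_idx, nbr_idx, _, _ = crystal.get_neighbor_list(5.0, numerical_tol=1e-8)
print(np.any(self_idx == nbr_idx))  # True: the lone site neighbours its own images

crystal.make_supercell([2, 2, 2])
self_idx, nbr_idx, _, _ = crystal.get_neighbor_list(5.0, numerical_tol=1e-8)
print(np.any(self_idx == nbr_idx))  # False once the cell has been doubled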
166 | 167 | if (len(all_iso) > 0) or (len(some_iso) > 0): 168 | self.df = self.df.drop(self.df[self.df["material_id"].isin(all_iso)].index) 169 | self.df = self.df.drop(self.df[self.df["material_id"].isin(some_iso)].index) 170 | 171 | print(all_iso) 172 | print(some_iso) 173 | # raise ValueError("isolated structures contained in dataframe") 174 | 175 | # NOTE do not cache the pre-training structures as we want to see new sets of 176 | # masked structures each epoch as this effectively expands the training set 177 | # @functools.lru_cache(maxsize=None) # Cache loaded structures 178 | def __getitem__(self, idx): 179 | # NOTE sites must be given in fractional coordinates 180 | df_idx = self.df.iloc[idx] 181 | crystal = df_idx["Structure_obj"] 182 | cif_id, comp = df_idx[self.identifiers] 183 | self_idx, nbr_idx, nbr_dist = df_idx[self.graph] 184 | 185 | # atom features 186 | # TODO can this be vectorised with numpy? 187 | 188 | # handle disordered structures (multiple fractional elements per site) 189 | site_atoms = [atom.species.as_dict() for atom in crystal] 190 | atom_fea = np.vstack( 191 | [ 192 | np.sum([self.ari.get_fea(el) * amt for el, amt in site.items()], axis=0) 193 | for site in site_atoms 194 | ] 195 | ) 196 | 197 | atom_fea = torch.Tensor(atom_fea) 198 | 199 | targets = [] 200 | for target in self.task_dict: 201 | if self.task_dict[target] == "regression": 202 | targets.append( 203 | torch.Tensor( 204 | [ 205 | [ 206 | df_idx[target], 207 | ], 208 | ] 209 | ) 210 | ) 211 | else: 212 | raise NotImplementedError("bad user") 213 | 214 | return ((atom_fea, nbr_dist, self_idx, nbr_idx), targets, comp, cif_id) 215 | 216 | 217 | class GaussianDistance: 218 | """ 219 | Expands the distance by Gaussian basis. 220 | 221 | Unit: angstrom 222 | """ 223 | 224 | def __init__(self, dmin, dmax, step, var=None): 225 | """ 226 | Args: 227 | dmin (float): Minimum interatomic distance 228 | dmax (float): Maximum interatomic distance 229 | step (float): Step size for the Gaussian filter 230 | var (float, optional): Variance of Gaussian basis. Defaults to step if not given 231 | """ 232 | assert dmin < dmax 233 | assert dmax - dmin > step 234 | 235 | self.filter = np.arange(dmin, dmax + step, step) 236 | self.embedding_size = len(self.filter) 237 | 238 | if var is None: 239 | var = step 240 | 241 | self.var = var 242 | 243 | def expand(self, distances): 244 | """Apply Gaussian distance filter to a numpy distance array 245 | 246 | Args: 247 | distances (ArrayLike): A distance matrix of any shape 248 | 249 | Returns: 250 | Expanded distance matrix with the last dimension of length 251 | len(self.filter) 252 | """ 253 | distances = np.array(distances) 254 | 255 | return np.exp( 256 | -((distances[..., np.newaxis] - self.filter) ** 2) / self.var ** 2 257 | ) 258 | 259 | 260 | def collate_batch(dataset_list): 261 | """ 262 | Collate a list of data and return a batch for predicting crystal 263 | properties. 264 | 265 | Parameters 266 | ---------- 267 | 268 | dataset_list: list of tuples for each data point. 
269 | (atom_fea, nbr_dist, nbr_idx, target) 270 | 271 | atom_fea: torch.Tensor shape (n_i, atom_fea_len) 272 | nbr_dist: torch.Tensor shape (n_i, M, nbr_dist_len) 273 | nbr_idx: torch.LongTensor shape (n_i, M) 274 | target: torch.Tensor shape (1, ) 275 | cif_id: str or int 276 | 277 | Returns 278 | ------- 279 | N = sum(n_i); N0 = sum(i) 280 | 281 | batch_atom_fea: torch.Tensor shape (N, orig_atom_fea_len) 282 | Atom features from atom type 283 | batch_nbr_dist: torch.Tensor shape (N, M, nbr_dist_len) 284 | Bond features of each atom's M neighbors 285 | batch_nbr_idx: torch.LongTensor shape (N, M) 286 | Indices of M neighbors of each atom 287 | crystal_atom_idx: list of torch.LongTensor of length N0 288 | Mapping from the crystal idx to atom idx 289 | target: torch.Tensor shape (N, 1) 290 | Target value for prediction 291 | batch_cif_ids: list 292 | """ 293 | batch_atom_fea = [] 294 | batch_nbr_dist = [] 295 | batch_self_idx = [] 296 | batch_nbr_idx = [] 297 | crystal_atom_idx = [] 298 | batch_targets = [] 299 | batch_comps = [] 300 | batch_cif_ids = [] 301 | base_idx = 0 302 | 303 | for i, (inputs, target, comp, cif_id) in enumerate(dataset_list): 304 | atom_fea, nbr_dist, self_idx, nbr_idx = inputs 305 | n_atoms = atom_fea.shape[0] # number of atoms for this crystal 306 | 307 | batch_atom_fea.append(atom_fea) 308 | batch_nbr_dist.append(nbr_dist) 309 | batch_self_idx.append(self_idx + base_idx) 310 | batch_nbr_idx.append(nbr_idx + base_idx) 311 | 312 | crystal_atom_idx.extend([i] * n_atoms) 313 | 314 | batch_targets.append(target) 315 | batch_comps.append(comp) 316 | batch_cif_ids.append(cif_id) 317 | 318 | base_idx += n_atoms 319 | 320 | atom_fea = torch.cat(batch_atom_fea, dim=0) 321 | nbr_dist = torch.cat(batch_nbr_dist, dim=0) 322 | self_idx = torch.cat(batch_self_idx, dim=0) 323 | nbr_idx = torch.cat(batch_nbr_idx, dim=0) 324 | 325 | cry_idx = torch.LongTensor(crystal_atom_idx) 326 | 327 | return ( 328 | (atom_fea, nbr_dist, self_idx, nbr_idx, cry_idx), 329 | tuple(torch.cat(b_target, dim=0) for b_target in zip(*batch_targets)), 330 | batch_comps, 331 | batch_cif_ids, 332 | ) 333 | 334 | 335 | def get_structure(cols): 336 | """Return pymatgen structure from lattice and sites cols""" 337 | cell, sites = cols 338 | cell, elems, coords = parse_cgcnn(cell, sites) 339 | # NOTE getting primitive structure before constructing graph 340 | # significantly harms the performance of this model. 341 | 342 | crystal = Structure(lattice=cell, species=elems, coords=coords, to_unit_cell=True) 343 | 344 | # In place modification of structures that only contain a few sites 345 | # this is to allow us to mask ~15% of sites without having a 346 | # disproportionate impact on small unit cell structures. 
347 | if crystal.num_sites < 7: 348 | crystal.make_supercell([2, 2, 2]) 349 | 350 | return crystal 351 | 352 | 353 | def parse_cgcnn(cell, sites): 354 | """Parse str representation into lists""" 355 | cell = np.array(ast.literal_eval(cell), dtype=float) 356 | elems = [] 357 | coords = [] 358 | for site in ast.literal_eval(sites): 359 | ele, pos = site.split(" @ ") 360 | elems.append(ele) 361 | coords.append(pos.split(" ")) 362 | 363 | coords = np.array(coords, dtype=float) 364 | return cell, elems, coords 365 | -------------------------------------------------------------------------------- /roost/pretrain/ener_model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from roost.cgcnn.model import DescriptorNetwork 4 | from roost.core import BaseModelClass 5 | from roost.segments import MeanPooling 6 | 7 | 8 | class CrystalGraphPreNet(BaseModelClass): 9 | """ 10 | Create a crystal graph convolutional neural network for predicting total 11 | material properties. 12 | 13 | This model is based on: https://github.com/txie-93/cgcnn [MIT License]. 14 | Changes to the code were made to allow for the removal of zero-padding 15 | and to benefit from the BaseModelClass functionality. The architectural 16 | choices of the model remain unchanged. 17 | """ 18 | 19 | def __init__( 20 | self, 21 | robust, 22 | n_targets, 23 | elem_emb_len, 24 | nbr_fea_len, 25 | elem_fea_len=64, 26 | n_graph=4, 27 | **kwargs, 28 | ): 29 | """ 30 | Initialize CrystalGraphConvNet. 31 | 32 | Parameters 33 | ---------- 34 | 35 | orig_elem_fea_len: int 36 | Number of atom features in the input. 37 | nbr_fea_len: int 38 | Number of bond features. 39 | elem_fea_len: int 40 | Number of hidden atom features in the convolutional layers 41 | n_graph: int 42 | Number of convolutional layers 43 | h_fea_len: int 44 | Number of hidden features after pooling 45 | n_hidden: int 46 | Number of hidden layers after pooling 47 | """ 48 | super().__init__(robust=robust, **kwargs) 49 | 50 | desc_dict = { 51 | "elem_emb_len": elem_emb_len, 52 | "nbr_fea_len": nbr_fea_len, 53 | "elem_fea_len": elem_fea_len, 54 | "n_graph": n_graph, 55 | } 56 | 57 | self.node_nn = DescriptorNetwork(**desc_dict) 58 | 59 | self.model_params.update( 60 | { 61 | "robust": robust, 62 | "n_targets": n_targets, 63 | } 64 | ) 65 | 66 | self.global_pool = MeanPooling() 67 | self.global_linear = nn.Linear(elem_fea_len, 1) 68 | 69 | self.model_params.update(desc_dict) 70 | 71 | def forward(self, atom_fea, nbr_fea, self_idx, nbr_idx, cry_idx): 72 | """ 73 | Forward pass 74 | 75 | N: Total number of atoms in the batch 76 | M: Max number of neighbors 77 | N0: Total number of crystals in the batch 78 | 79 | Parameters 80 | ---------- 81 | 82 | atom_fea: Variable(torch.Tensor) shape (N, orig_elem_fea_len) 83 | Atom features from atom type 84 | nbr_fea: Variable(torch.Tensor) shape (N, M, nbr_fea_len) 85 | Bond features of each atom's M neighbors 86 | nbr_idx: torch.LongTensor shape (N, M) 87 | Indices of M neighbors of each atom 88 | crystal_atom_idx: list of torch.LongTensor of length N0 89 | Mapping from the crystal idx to atom idx 90 | 91 | Returns 92 | ------- 93 | 94 | prediction: nn.Variable shape (N, ) 95 | Atom hidden features after convolution 96 | 97 | """ 98 | crys_fea = self.node_nn(atom_fea, nbr_fea, self_idx, nbr_idx) 99 | 100 | glob = self.global_linear(self.global_pool(crys_fea, cry_idx)) 101 | 102 | return [glob] 103 | 
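# --- Illustrative aside, not part of this repo: the Gaussian distance expansion. ---
# Every data module above smears each raw neighbour distance over a fixed grid of
# Gaussians via GaussianDistance.expand. With the defaults used above (dmin=0,
# dmax=radius=5, step=0.2) each distance becomes a 26-dimensional soft one-hot
# vector peaked at the nearest filter centre. The numpy snippet below mirrors
# that expand() formula on two arbitrary distances.

import numpy as np

dmin, dmax, step = 0.0, 5.0, 0.2
filters = np.arange(dmin, dmax + step, step)  # grid of Gaussian centres
var = step  # the class defaults var to step when not given

distances = np.array([1.0, 2.45])
expanded = np.exp(-((distances[..., np.newaxis] - filters) ** 2) / var ** 2)

print(filters.size)                   # 26
print(expanded.shape)                 # (2, 26)
print(filters[expanded[1].argmax()])  # ~2.4, the centre closest to 2.45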
-------------------------------------------------------------------------------- /roost/roost/data.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import os 3 | 4 | import numpy as np 5 | import pandas as pd 6 | import torch 7 | from pymatgen.core.composition import Composition 8 | from torch.utils.data import Dataset 9 | 10 | from roost.core import Featurizer 11 | 12 | 13 | class CompositionData(Dataset): 14 | """ 15 | The CompositionData dataset is a wrapper for a dataset whose data points are 16 | automatically constructed from composition strings. 17 | """ 18 | 19 | def __init__( 20 | self, 21 | data_path, 22 | fea_path, 23 | task_dict, 24 | inputs=["composition"], 25 | identifiers=["material_id", "composition"], 26 | # identifiers=["material_id", "composition"], 27 | ): 28 | """Dataset wrapper for composition-only (Roost) model inputs. 29 | 30 | Args: 31 | data_path (str): path to a CSV file of targets and composition strings. 32 | fea_path (str): path to a JSON file of element embeddings. 33 | task_dict ({name: task}): map from target names to "regression" or "classification". 34 | inputs (list, optional): column name for compositions. 35 | Defaults to ["composition"]. 36 | identifiers (list, optional): column names for unique identifier 37 | and pretty name. Defaults to ["material_id", "composition"]. 38 | """ 39 | 40 | assert len(identifiers) == 2, "Two identifiers are required" 41 | assert len(inputs) == 1, "One input column is required" 42 | 43 | self.inputs = inputs 44 | self.task_dict = task_dict 45 | self.identifiers = identifiers 46 | 47 | assert os.path.exists(data_path), f"{data_path} does not exist!" 48 | # NOTE make sure to use dense datasets, 49 | # NOTE do not use default_na as "NaN" is a valid material 50 | self.df = pd.read_csv(data_path, keep_default_na=False, na_values=[]) 51 | 52 | assert os.path.exists(fea_path), f"{fea_path} does not exist!"
53 | self.elem_features = Featurizer.from_json(fea_path) 54 | self.elem_emb_len = self.elem_features.embedding_size 55 | 56 | self.n_targets = [] 57 | for target, task in self.task_dict.items(): 58 | if task == "regression": 59 | self.n_targets.append(1) 60 | elif task == "classification": 61 | n_classes = np.max(self.df[target].values) + 1 62 | self.n_targets.append(n_classes) 63 | 64 | def __len__(self): 65 | return len(self.df) 66 | 67 | @functools.lru_cache(maxsize=None) # Cache data for faster training 68 | def __getitem__(self, idx): 69 | """[summary] 70 | 71 | Args: 72 | idx (int): dataset index 73 | 74 | Raises: 75 | AssertionError: [description] 76 | ValueError: [description] 77 | 78 | Returns: 79 | atom_weights: torch.Tensor shape (M, 1) 80 | weights of atoms in the material 81 | atom_fea: torch.Tensor shape (M, n_fea) 82 | features of atoms in the material 83 | self_fea_idx: torch.Tensor shape (M*M, 1) 84 | list of self indices 85 | nbr_fea_idx: torch.Tensor shape (M*M, 1) 86 | list of neighbor indices 87 | target: torch.Tensor shape (1,) 88 | target value for material 89 | cry_id: torch.Tensor shape (1,) 90 | input id for the material 91 | 92 | """ 93 | df_idx = self.df.iloc[idx] 94 | composition = df_idx[self.inputs][0] 95 | cry_ids = df_idx[self.identifiers].values 96 | 97 | comp_dict = Composition(composition).get_el_amt_dict() 98 | elements = list(comp_dict.keys()) 99 | 100 | weights = list(comp_dict.values()) 101 | weights = np.atleast_2d(weights).T / np.sum(weights) 102 | 103 | try: 104 | atom_fea = np.vstack( 105 | [self.elem_features.get_fea(element) for element in elements] 106 | ) 107 | except AssertionError: 108 | raise AssertionError( 109 | f"cry-id {cry_ids[0]} [{composition}] contains element types not in embedding" 110 | ) 111 | except ValueError: 112 | raise ValueError( 113 | f"cry-id {cry_ids[0]} [{composition}] composition cannot be parsed into elements" 114 | ) 115 | 116 | nele = len(elements) 117 | self_fea_idx = [] 118 | nbr_fea_idx = [] 119 | for i, _ in enumerate(elements): 120 | self_fea_idx += [i] * nele 121 | nbr_fea_idx += list(range(nele)) 122 | 123 | # convert all data to tensors 124 | atom_weights = torch.Tensor(weights) 125 | atom_fea = torch.Tensor(atom_fea) 126 | self_fea_idx = torch.LongTensor(self_fea_idx) 127 | nbr_fea_idx = torch.LongTensor(nbr_fea_idx) 128 | 129 | targets = [] 130 | for target in self.task_dict: 131 | if self.task_dict[target] == "regression": 132 | targets.append(torch.Tensor([df_idx[target]])) 133 | elif self.task_dict[target] == "classification": 134 | targets.append(torch.LongTensor([df_idx[target]])) 135 | 136 | return ( 137 | (atom_weights, atom_fea, self_fea_idx, nbr_fea_idx), 138 | targets, 139 | *cry_ids, 140 | ) 141 | 142 | 143 | def collate_batch(dataset_list): 144 | """ 145 | Collate a list of data and return a batch for predicting crystal 146 | properties. 147 | 148 | Parameters 149 | ---------- 150 | 151 | dataset_list: list of tuples for each data point. 
152 | (atom_fea, nbr_fea, nbr_fea_idx, target) 153 | 154 | atom_fea: torch.Tensor shape (n_i, atom_fea_len) 155 | nbr_fea: torch.Tensor shape (n_i, M, nbr_fea_len) 156 | self_fea_idx: torch.LongTensor shape (n_i, M) 157 | nbr_fea_idx: torch.LongTensor shape (n_i, M) 158 | target: torch.Tensor shape (1, ) 159 | cif_id: str or int 160 | 161 | Returns 162 | ------- 163 | N = sum(n_i); N0 = sum(i) 164 | 165 | batch_atom_weights: torch.Tensor shape (N, 1) 166 | batch_atom_fea: torch.Tensor shape (N, orig_atom_fea_len) 167 | Atom features from atom type 168 | batch_self_fea_idx: torch.LongTensor shape (N, M) 169 | Indices of mapping atom to copies of itself 170 | batch_nbr_fea_idx: torch.LongTensor shape (N, M) 171 | Indices of M neighbors of each atom 172 | crystal_atom_idx: list of torch.LongTensor of length N0 173 | Mapping from the crystal idx to atom idx 174 | target: torch.Tensor shape (N, 1) 175 | Target value for prediction 176 | batch_comps: list 177 | batch_ids: list 178 | """ 179 | # define the lists 180 | batch_atom_weights = [] 181 | batch_atom_fea = [] 182 | batch_self_fea_idx = [] 183 | batch_nbr_fea_idx = [] 184 | crystal_atom_idx = [] 185 | batch_targets = [] 186 | batch_cry_ids = [] 187 | 188 | cry_base_idx = 0 189 | for i, (inputs, target, *cry_ids) in enumerate(dataset_list): 190 | atom_weights, atom_fea, self_fea_idx, nbr_fea_idx = inputs 191 | 192 | # number of atoms for this crystal 193 | n_i = atom_fea.shape[0] 194 | 195 | # batch the features together 196 | batch_atom_weights.append(atom_weights) 197 | batch_atom_fea.append(atom_fea) 198 | 199 | # mappings from bonds to atoms 200 | batch_self_fea_idx.append(self_fea_idx + cry_base_idx) 201 | batch_nbr_fea_idx.append(nbr_fea_idx + cry_base_idx) 202 | 203 | # mapping from atoms to crystals 204 | crystal_atom_idx.append(torch.tensor([i] * n_i)) 205 | 206 | # batch the targets and ids 207 | batch_targets.append(target) 208 | batch_cry_ids.append(cry_ids) 209 | 210 | # increment the id counter 211 | cry_base_idx += n_i 212 | 213 | return ( 214 | ( 215 | torch.cat(batch_atom_weights, dim=0), 216 | torch.cat(batch_atom_fea, dim=0), 217 | torch.cat(batch_self_fea_idx, dim=0), 218 | torch.cat(batch_nbr_fea_idx, dim=0), 219 | torch.cat(crystal_atom_idx), 220 | ), 221 | tuple(torch.stack(b_target, dim=0) for b_target in zip(*batch_targets)), 222 | *zip(*batch_cry_ids), 223 | ) 224 | -------------------------------------------------------------------------------- /roost/roost/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from roost.core import BaseModelClass 6 | from roost.segments import ResidualNetwork, SimpleNetwork, WeightedAttentionPooling 7 | 8 | 9 | class Roost(BaseModelClass): 10 | """ 11 | The Roost model is comprised of a fully connected network 12 | and message passing graph layers. 13 | 14 | The message passing layers are used to determine a descriptor set 15 | for the fully connected network. The graphs are used to represent 16 | the stoichiometry of inorganic materials in a trainable manner. 17 | This makes them systematically improvable with more data. 
18 | """ 19 | 20 | def __init__( 21 | self, 22 | robust, 23 | n_targets, 24 | elem_emb_len, 25 | elem_fea_len=64, 26 | n_graph=3, 27 | elem_heads=3, 28 | elem_gate=[256], 29 | elem_msg=[256], 30 | cry_heads=3, 31 | cry_gate=[256], 32 | cry_msg=[256], 33 | trunk_hidden=[1024, 512], 34 | out_hidden=[256, 128, 64], 35 | **kwargs 36 | ): 37 | if isinstance(out_hidden[0], list): 38 | raise ValueError("boo hiss bad user") 39 | # assert all([isinstance(x, list) for x in out_hidden]), 40 | # 'all elements of out_hidden must be ints or all lists' 41 | # assert len(out_hidden) == len(n_targets), 42 | # 'out_hidden-n_targets length mismatch' 43 | 44 | super().__init__(robust=robust, **kwargs) 45 | 46 | desc_dict = { 47 | "elem_emb_len": elem_emb_len, 48 | "elem_fea_len": elem_fea_len, 49 | "n_graph": n_graph, 50 | "elem_heads": elem_heads, 51 | "elem_gate": elem_gate, 52 | "elem_msg": elem_msg, 53 | "cry_heads": cry_heads, 54 | "cry_gate": cry_gate, 55 | "cry_msg": cry_msg, 56 | } 57 | 58 | self.material_nn = DescriptorNetwork(**desc_dict) 59 | 60 | self.model_params.update( 61 | { 62 | "robust": robust, 63 | "n_targets": n_targets, 64 | "out_hidden": out_hidden, 65 | "trunk_hidden": trunk_hidden, 66 | } 67 | ) 68 | 69 | self.model_params.update(desc_dict) 70 | 71 | # define an output neural network 72 | if self.robust: 73 | n_targets = [2 * n for n in n_targets] 74 | 75 | self.trunk_nn = ResidualNetwork(elem_fea_len, out_hidden[0], trunk_hidden) 76 | 77 | self.output_nns = nn.ModuleList([ 78 | ResidualNetwork(out_hidden[0], n, out_hidden[1:]) for n in n_targets 79 | ]) 80 | 81 | def forward(self, elem_weights, elem_fea, self_fea_idx, nbr_fea_idx, cry_elem_idx): 82 | """ 83 | Forward pass through the material_nn and output_nn 84 | """ 85 | crys_fea = self.material_nn( 86 | elem_weights, elem_fea, self_fea_idx, nbr_fea_idx, cry_elem_idx 87 | ) 88 | 89 | crys_fea = F.relu(self.trunk_nn(crys_fea)) 90 | 91 | # apply neural network to map from learned features to target 92 | return (output_nn(crys_fea) for output_nn in self.output_nns) 93 | 94 | def __repr__(self): 95 | return self.__class__.__name__ 96 | 97 | 98 | class DescriptorNetwork(nn.Module): 99 | """ 100 | The Descriptor Network is the message passing section of the 101 | Roost Model. 
102 | """ 103 | 104 | def __init__( 105 | self, 106 | elem_emb_len, 107 | elem_fea_len=64, 108 | n_graph=3, 109 | elem_heads=3, 110 | elem_gate=[256], 111 | elem_msg=[256], 112 | cry_heads=3, 113 | cry_gate=[256], 114 | cry_msg=[256], 115 | ): 116 | """ 117 | """ 118 | super().__init__() 119 | 120 | # apply linear transform to the input to get a trainable embedding 121 | # NOTE -1 here so we can add the weights as a node feature 122 | self.embedding = nn.Linear(elem_emb_len, elem_fea_len - 1) 123 | 124 | # create a list of Message passing layers 125 | self.graphs = nn.ModuleList( 126 | [ 127 | MessageLayer( 128 | elem_fea_len=elem_fea_len, 129 | elem_heads=elem_heads, 130 | elem_gate=elem_gate, 131 | elem_msg=elem_msg, 132 | ) 133 | for i in range(n_graph) 134 | ] 135 | ) 136 | 137 | # define a global pooling function for materials 138 | self.cry_pool = nn.ModuleList( 139 | [ 140 | WeightedAttentionPooling( 141 | gate_nn=SimpleNetwork(elem_fea_len, 1, cry_gate), 142 | message_nn=SimpleNetwork(elem_fea_len, elem_fea_len, cry_msg), 143 | ) 144 | for _ in range(cry_heads) 145 | ] 146 | ) 147 | 148 | def forward(self, elem_weights, elem_fea, self_fea_idx, nbr_fea_idx, cry_elem_idx): 149 | """ 150 | Forward pass 151 | 152 | Parameters 153 | ---------- 154 | N: Total number of elements (nodes) in the batch 155 | M: Total number of pairs (edges) in the batch 156 | C: Total number of crystals (graphs) in the batch 157 | 158 | Inputs 159 | ---------- 160 | elem_weights: Variable(torch.Tensor) shape (N) 161 | Fractional weight of each Element in its stoichiometry 162 | elem_fea: Variable(torch.Tensor) shape (N, orig_elem_fea_len) 163 | Element features of each of the N elems in the batch 164 | self_fea_idx: torch.Tensor shape (M,) 165 | Indices of the first element in each of the M pairs 166 | nbr_fea_idx: torch.Tensor shape (M,) 167 | Indices of the second element in each of the M pairs 168 | cry_elem_idx: list of torch.LongTensor of length C 169 | Mapping from the elem idx to crystal idx 170 | 171 | Returns 172 | ------- 173 | cry_fea: nn.Variable shape (C,) 174 | Material representation after message passing 175 | """ 176 | 177 | # embed the original features into a trainable embedding space 178 | elem_fea = self.embedding(elem_fea) 179 | 180 | # add weights as a node feature 181 | elem_fea = torch.cat([elem_fea, elem_weights], dim=1) 182 | 183 | # apply the message passing functions 184 | for graph_func in self.graphs: 185 | elem_fea = graph_func(elem_weights, elem_fea, self_fea_idx, nbr_fea_idx) 186 | 187 | # generate crystal features by pooling the elemental features 188 | head_fea = [] 189 | for attnhead in self.cry_pool: 190 | head_fea.append( 191 | attnhead(elem_fea, index=cry_elem_idx, weights=elem_weights) 192 | ) 193 | 194 | # head_fea = [ 195 | # head(elem_fea, index=cry_elem_idx, weights=elem_weights) 196 | # for head in self.cry_pool 197 | # ] 198 | 199 | return torch.mean(torch.stack(head_fea), dim=0) 200 | 201 | def __repr__(self): 202 | return self.__class__.__name__ 203 | 204 | 205 | class MessageLayer(nn.Module): 206 | """ 207 | Massage Layers are used to propagate information between nodes in 208 | the stoichiometry graph. 
209 | """ 210 | 211 | def __init__(self, elem_fea_len, elem_heads, elem_gate, elem_msg): 212 | """ 213 | """ 214 | super().__init__() 215 | 216 | # Pooling and Output 217 | self.pooling = nn.ModuleList( 218 | [ 219 | WeightedAttentionPooling( 220 | gate_nn=SimpleNetwork(2 * elem_fea_len, 1, elem_gate), 221 | message_nn=SimpleNetwork(2 * elem_fea_len, elem_fea_len, elem_msg), 222 | ) 223 | for _ in range(elem_heads) 224 | ] 225 | ) 226 | 227 | def forward(self, elem_weights, elem_in_fea, self_fea_idx, nbr_fea_idx): 228 | """ 229 | Forward pass 230 | 231 | Parameters 232 | ---------- 233 | N: Total number of elements (nodes) in the batch 234 | M: Total number of pairs (edges) in the batch 235 | C: Total number of crystals (graphs) in the batch 236 | 237 | Inputs 238 | ---------- 239 | elem_weights: Variable(torch.Tensor) shape (N,) 240 | The fractional weights of elems in their materials 241 | elem_in_fea: Variable(torch.Tensor) shape (N, elem_fea_len) 242 | Element hidden features before message passing 243 | self_fea_idx: torch.Tensor shape (M,) 244 | Indices of the first element in each of the M pairs 245 | nbr_fea_idx: torch.Tensor shape (M,) 246 | Indices of the second element in each of the M pairs 247 | 248 | Returns 249 | ------- 250 | elem_out_fea: nn.Variable shape (N, elem_fea_len) 251 | Element hidden features after message passing 252 | """ 253 | # construct the total features for passing 254 | elem_nbr_weights = elem_weights[nbr_fea_idx, :] 255 | elem_nbr_fea = elem_in_fea[nbr_fea_idx, :] 256 | elem_self_fea = elem_in_fea[self_fea_idx, :] 257 | fea = torch.cat([elem_self_fea, elem_nbr_fea], dim=1) 258 | 259 | # sum selectivity over the neighbours to get elems 260 | head_fea = [] 261 | for attnhead in self.pooling: 262 | head_fea.append( 263 | attnhead(fea, index=self_fea_idx, weights=elem_nbr_weights) 264 | ) 265 | 266 | # average the attention heads 267 | fea = torch.mean(torch.stack(head_fea), dim=0) 268 | 269 | return fea + elem_in_fea 270 | 271 | def __repr__(self): 272 | return self.__class__.__name__ 273 | -------------------------------------------------------------------------------- /roost/segments.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch_scatter import scatter_add, scatter_max, scatter_mean 4 | 5 | 6 | class MeanPooling(nn.Module): 7 | """Mean pooling""" 8 | 9 | def __init__(self): 10 | super().__init__() 11 | 12 | def forward(self, x, index): 13 | return scatter_mean(x, index, dim=0) 14 | 15 | def __repr__(self): 16 | return self.__class__.__name__ 17 | 18 | 19 | class SumPooling(nn.Module): 20 | """Sum pooling""" 21 | 22 | def __init__(self): 23 | super().__init__() 24 | 25 | def forward(self, x, index): 26 | return scatter_add(x, index, dim=0) 27 | 28 | def __repr__(self): 29 | return self.__class__.__name__ 30 | 31 | 32 | class AttentionPooling(nn.Module): 33 | """ 34 | softmax attention layer 35 | """ 36 | 37 | def __init__(self, gate_nn, message_nn): 38 | """ 39 | Args: 40 | gate_nn: Variable(nn.Module) 41 | message_nn 42 | """ 43 | super().__init__() 44 | self.gate_nn = gate_nn 45 | self.message_nn = message_nn 46 | 47 | def forward(self, x, index): 48 | gate = self.gate_nn(x) 49 | 50 | gate = gate - scatter_max(gate, index, dim=0)[0][index] 51 | gate = gate.exp() 52 | gate = gate / (scatter_add(gate, index, dim=0)[index] + 1e-10) 53 | 54 | x = self.message_nn(x) 55 | out = scatter_add(gate * x, index, dim=0) 56 | 57 | return out 58 | 59 | def __repr__(self): 
60 | return self.__class__.__name__ 61 | 62 | 63 | class WeightedAttentionPooling(nn.Module): 64 | """ 65 | Weighted softmax attention layer 66 | """ 67 | 68 | def __init__(self, gate_nn, message_nn): 69 | """ 70 | Inputs 71 | ---------- 72 | gate_nn: Variable(nn.Module) 73 | """ 74 | super().__init__() 75 | self.gate_nn = gate_nn 76 | self.message_nn = message_nn 77 | self.pow = torch.nn.Parameter(torch.randn(1)) 78 | 79 | def forward(self, x, index, weights): 80 | gate = self.gate_nn(x) 81 | 82 | gate = gate - scatter_max(gate, index, dim=0)[0][index] 83 | gate = (weights ** self.pow) * gate.exp() 84 | # gate = weights * gate.exp() 85 | # gate = gate.exp() 86 | gate = gate / (scatter_add(gate, index, dim=0)[index] + 1e-10) 87 | 88 | x = self.message_nn(x) 89 | out = scatter_add(gate * x, index, dim=0) 90 | 91 | return out 92 | 93 | def __repr__(self): 94 | return self.__class__.__name__ 95 | 96 | 97 | class SimpleNetwork(nn.Module): 98 | """ 99 | Simple Feed Forward Neural Network 100 | """ 101 | 102 | def __init__( 103 | self, 104 | input_dim, 105 | output_dim, 106 | hidden_layer_dims, 107 | activation=nn.LeakyReLU, 108 | batchnorm=False, 109 | ): 110 | """ 111 | Inputs 112 | ---------- 113 | input_dim: int 114 | output_dim: int 115 | hidden_layer_dims: list(int) 116 | 117 | """ 118 | super().__init__() 119 | 120 | dims = [input_dim] + hidden_layer_dims 121 | 122 | self.fcs = nn.ModuleList( 123 | [nn.Linear(dims[i], dims[i + 1]) for i in range(len(dims) - 1)] 124 | ) 125 | 126 | if batchnorm: 127 | self.bns = nn.ModuleList( 128 | [nn.BatchNorm1d(dims[i + 1]) for i in range(len(dims) - 1)] 129 | ) 130 | else: 131 | self.bns = nn.ModuleList([nn.Identity() for i in range(len(dims) - 1)]) 132 | 133 | self.acts = nn.ModuleList([activation() for _ in range(len(dims) - 1)]) 134 | 135 | self.fc_out = nn.Linear(dims[-1], output_dim) 136 | 137 | def forward(self, x): 138 | for fc, bn, act in zip(self.fcs, self.bns, self.acts): 139 | x = act(bn(fc(x))) 140 | 141 | return self.fc_out(x) 142 | 143 | def __repr__(self): 144 | return self.__class__.__name__ 145 | 146 | def reset_parameters(self): 147 | for fc in self.fcs: 148 | fc.reset_parameters() 149 | 150 | self.fc_out.reset_parameters() 151 | 152 | 153 | class ResidualNetwork(nn.Module): 154 | """ 155 | Feed forward Residual Neural Network 156 | """ 157 | 158 | def __init__( 159 | self, 160 | input_dim, 161 | output_dim, 162 | hidden_layer_dims, 163 | activation=nn.ReLU, 164 | batchnorm=False, 165 | return_features=False, 166 | ): 167 | """ 168 | Inputs 169 | ---------- 170 | input_dim: int 171 | output_dim: int 172 | hidden_layer_dims: list(int) 173 | 174 | """ 175 | super().__init__() 176 | 177 | dims = [input_dim] + hidden_layer_dims 178 | 179 | self.fcs = nn.ModuleList( 180 | [nn.Linear(dims[i], dims[i + 1]) for i in range(len(dims) - 1)] 181 | ) 182 | 183 | if batchnorm: 184 | self.bns = nn.ModuleList( 185 | [nn.BatchNorm1d(dims[i + 1]) for i in range(len(dims) - 1)] 186 | ) 187 | else: 188 | self.bns = nn.ModuleList([nn.Identity() for i in range(len(dims) - 1)]) 189 | 190 | self.res_fcs = nn.ModuleList( 191 | [ 192 | nn.Linear(dims[i], dims[i + 1], bias=False) 193 | if (dims[i] != dims[i + 1]) 194 | else nn.Identity() 195 | for i in range(len(dims) - 1) 196 | ] 197 | ) 198 | self.acts = nn.ModuleList([activation() for _ in range(len(dims) - 1)]) 199 | 200 | self.return_features = return_features 201 | if not self.return_features: 202 | self.fc_out = nn.Linear(dims[-1], output_dim) 203 | 204 | def forward(self, x): 205 | for fc, bn, res_fc, 
act in zip(self.fcs, self.bns, self.res_fcs, self.acts): 206 | x = act(bn(fc(x))) + res_fc(x) 207 | 208 | if self.return_features: 209 | return x 210 | else: 211 | return self.fc_out(x) 212 | 213 | def __repr__(self): 214 | return self.__class__.__name__ 215 | -------------------------------------------------------------------------------- /roost/wren/data.py: -------------------------------------------------------------------------------- 1 | import ast 2 | import functools 3 | import json 4 | import os 5 | import re 6 | from itertools import groupby 7 | 8 | import numpy as np 9 | import pandas as pd 10 | import torch 11 | from torch.utils.data import Dataset 12 | 13 | from roost.wren.utils import mult_dict, relab_dict 14 | 15 | 16 | class WyckoffData(Dataset): 17 | """ 18 | The WyckoffData dataset is a wrapper for a dataset whose data points are 19 | automatically constructed from Wyckoff representation strings. 20 | """ 21 | 22 | def __init__( 23 | self, 24 | data_path, 25 | sym_path, 26 | fea_path, 27 | task_dict, 28 | inputs=["wyckoff"], 29 | identifiers=["material_id", "composition", "wyckoff"], 30 | ): 31 | """Dataset wrapper for Wyckoff-representation (Wren) model inputs. 32 | 33 | Args: 34 | data_path (str): path to a CSV file of targets and Wyckoff representations. 35 | sym_path (str): path to a JSON file of Wyckoff position embeddings. 36 | fea_path (str): path to a JSON file of element embeddings. 37 | task_dict ({name: task}): map from target names to "regression" or "classification". 38 | inputs (list, optional): column name for the Wyckoff representation. Defaults to ["wyckoff"]. 39 | identifiers (list, optional): column names used to identify each material. Defaults to ["material_id", "composition", "wyckoff"]. 40 | """ 41 | assert len(identifiers) >= 2, "At least two identifiers are required" 42 | assert len(inputs) == 1, "One input column is required" 43 | 44 | self.inputs = inputs 45 | self.task_dict = task_dict 46 | self.identifiers = identifiers 47 | 48 | assert os.path.exists(data_path), f"{data_path} does not exist!" 49 | # NOTE make sure to use dense datasets, 50 | # NOTE do not use default_na as "NaN" is a valid material composition 51 | self.df = pd.read_csv(data_path, keep_default_na=False, na_values=[]) 52 | 53 | assert os.path.exists(fea_path), f"{fea_path} does not exist!" 54 | 55 | # TODO now using 2 level dicts so can't use featuriser, can this be standardised? 56 | # self.atom_features = Featurizer.from_json(fea_path) 57 | with open(fea_path) as f: 58 | self.atom_features = json.load(f) 59 | 60 | assert os.path.exists(sym_path), f"{sym_path} does not exist!"
61 | # self.sym_features = Featurizer.from_json(sym_path) 62 | with open(sym_path) as f: 63 | self.sym_features = json.load(f) 64 | 65 | # self.elem_fea_dim = self.atom_features.embedding_size 66 | # self.sym_fea_dim = self.sym_features.embedding_size 67 | 68 | self.elem_fea_dim = len(list(self.atom_features.values())[0]) 69 | self.sym_fea_dim = len(list(list(self.sym_features.values())[0].values())[0]) 70 | 71 | self.n_targets = [] 72 | for target, task in self.task_dict.items(): 73 | if task == "regression": 74 | self.n_targets.append(1) 75 | elif task == "classification": 76 | n_classes = np.max(self.df[target].values) + 1 77 | self.n_targets.append(n_classes) 78 | 79 | def __len__(self): 80 | return len(self.df) 81 | 82 | @functools.lru_cache(maxsize=None) # Cache loaded structures 83 | def __getitem__(self, idx): 84 | """[summary] 85 | 86 | Args: 87 | idx ([type]): [description] 88 | 89 | Raises: 90 | AssertionError: [description] 91 | 92 | Returns: 93 | atom_weights: torch.Tensor shape (M, 1) 94 | weights of atoms in the material 95 | atom_fea: torch.Tensor shape (M, n_fea) 96 | features of atoms in the material 97 | self_fea_idx: torch.Tensor shape (M*M, 1) 98 | list of self indices 99 | nbr_fea_idx: torch.Tensor shape (M*M, 1) 100 | list of neighbour indices 101 | target: torch.Tensor shape (1,) 102 | target value for material 103 | cry_id: torch.Tensor shape (1,) 104 | input id for the material 105 | """ 106 | df_idx = self.df.iloc[idx] 107 | swyks = df_idx[self.inputs][0] 108 | cry_ids = df_idx[self.identifiers].values 109 | 110 | # print(cry_id, composition, swyks) 111 | 112 | spg_no, weights, elements, aug_wyks = parse_aflow(swyks) 113 | # spg_no, weights, elements, aug_wyks = parse_wren(swyks) 114 | weights = np.atleast_2d(weights).T / np.sum(weights) 115 | 116 | try: 117 | atom_fea = np.vstack([self.atom_features[el] for el in elements]) 118 | sym_fea = np.vstack( 119 | [self.sym_features[spg_no][wyk] for wyks in aug_wyks for wyk in wyks] 120 | ) 121 | except AssertionError: 122 | print(f"failed to process {cry_ids[0]}: {cry_ids[1]}-{swyks}") 123 | raise 124 | 125 | n_wyks = len(elements) 126 | self_fea_idx = [] 127 | nbr_fea_idx = [] 128 | for i in range(n_wyks): 129 | self_fea_idx += [i] * n_wyks 130 | nbr_fea_idx += list(range(n_wyks)) 131 | 132 | self_aug_fea_idx = [] 133 | nbr_aug_fea_idx = [] 134 | n_aug = len(aug_wyks) 135 | for i in range(n_aug): 136 | self_aug_fea_idx += [x + i * n_wyks for x in self_fea_idx] 137 | nbr_aug_fea_idx += [x + i * n_wyks for x in nbr_fea_idx] 138 | 139 | # convert all data to tensors 140 | atom_weights = torch.Tensor(weights) 141 | atom_fea = torch.Tensor(atom_fea) 142 | sym_fea = torch.Tensor(sym_fea) 143 | self_fea_idx = torch.LongTensor(self_aug_fea_idx) 144 | nbr_fea_idx = torch.LongTensor(nbr_aug_fea_idx) 145 | 146 | targets = [] 147 | for name in self.task_dict: 148 | if self.task_dict[name] == "regression": 149 | targets.append(torch.Tensor([float(self.df.iloc[idx][name])])) 150 | elif self.task_dict[name] == "classification": 151 | targets.append(torch.LongTensor([int(self.df.iloc[idx][name])])) 152 | 153 | return ( 154 | (atom_weights, atom_fea, sym_fea, self_fea_idx, nbr_fea_idx), 155 | targets, 156 | *cry_ids, 157 | ) 158 | 159 | 160 | def collate_batch(dataset_list): 161 | """Collate a list of data and return a batch for predicting 162 | crystal properties. 163 | 164 | N = sum(n_i); N0 = sum(i) 165 | 166 | Args: 167 | dataset_list ([tuple]): list of tuples for each data point. 
168 | (atom_fea, nbr_fea, nbr_fea_idx, target) 169 | 170 | atom_fea: torch.Tensor shape (n_i, atom_fea_len) 171 | nbr_fea: torch.Tensor shape (n_i, M, nbr_fea_len) 172 | nbr_fea_idx: torch.LongTensor shape (n_i, M) 173 | target: torch.Tensor shape (1, ) 174 | cif_id: str or int 175 | 176 | Returns: 177 | batch_atom_fea: torch.Tensor shape (N, orig_atom_fea_len) 178 | Atom features from atom type 179 | batch_nbr_fea: torch.Tensor shape (N, M, nbr_fea_len) 180 | Bond features of each atom"s M neighbors 181 | batch_nbr_fea_idx: torch.LongTensor shape (N, M) 182 | Indices of M neighbors of each atom 183 | crystal_atom_idx: list of torch.LongTensor of length N0 184 | Mapping from the crystal idx to atom idx 185 | target: torch.Tensor shape (N, 1) 186 | Target value for prediction 187 | batch_cif_ids: list 188 | """ 189 | # define the lists 190 | batch_atom_weights = [] 191 | batch_atom_fea = [] 192 | batch_sym_fea = [] 193 | batch_self_fea_idx = [] 194 | batch_nbr_fea_idx = [] 195 | crystal_atom_idx = [] 196 | aug_cry_idx = [] 197 | batch_targets = [] 198 | batch_cry_ids = [] 199 | 200 | aug_count = 0 201 | cry_base_idx = 0 202 | for i, (inputs, target, *cry_ids) in enumerate(dataset_list): 203 | atom_weights, atom_fea, sym_fea, self_fea_idx, nbr_fea_idx = inputs 204 | 205 | # number of atoms for this crystal 206 | n_el = atom_fea.shape[0] 207 | n_i = sym_fea.shape[0] 208 | n_aug = int(float(n_i) / float(n_el)) 209 | 210 | # batch the features together 211 | batch_atom_weights.append(atom_weights.repeat((n_aug, 1))) 212 | batch_atom_fea.append(atom_fea.repeat((n_aug, 1))) 213 | batch_sym_fea.append(sym_fea) 214 | 215 | # mappings from bonds to atoms 216 | batch_self_fea_idx.append(self_fea_idx + cry_base_idx) 217 | batch_nbr_fea_idx.append(nbr_fea_idx + cry_base_idx) 218 | 219 | # mapping from atoms to crystals 220 | # print(torch.tensor(range(i, i+n_aug)).size()) 221 | crystal_atom_idx.append( 222 | torch.tensor(range(aug_count, aug_count + n_aug)).repeat_interleave(n_el) 223 | ) 224 | aug_cry_idx.append(torch.tensor([i] * n_aug)) 225 | 226 | # batch the targets and ids 227 | batch_targets.append(target) 228 | batch_cry_ids.append(cry_ids) 229 | 230 | # increment the id counter 231 | aug_count += n_aug 232 | cry_base_idx += n_i 233 | 234 | return ( 235 | ( 236 | torch.cat(batch_atom_weights, dim=0), 237 | torch.cat(batch_atom_fea, dim=0), 238 | torch.cat(batch_sym_fea, dim=0), 239 | torch.cat(batch_self_fea_idx, dim=0), 240 | torch.cat(batch_nbr_fea_idx, dim=0), 241 | torch.cat(crystal_atom_idx), 242 | torch.cat(aug_cry_idx), 243 | ), 244 | tuple(torch.stack(b_target, dim=0) for b_target in zip(*batch_targets)), 245 | *zip(*batch_cry_ids), 246 | ) 247 | 248 | 249 | def parse_wren(swyk_list): 250 | """parse the wyckoff format 251 | 252 | Args: 253 | swyk_list ([type]): [description] 254 | 255 | Returns: 256 | mult_list, ele_list, aug_wyks 257 | """ 258 | swyk_list = ast.literal_eval(swyk_list) 259 | 260 | mult_list = [] 261 | ele_list = [] 262 | wyk_list = [] 263 | 264 | spg_no = swyk_list[0].split("-")[-1] 265 | 266 | for swyk in swyk_list: 267 | ele_mult, wyk = swyk.split(" @ ") 268 | ele, mult = ele_mult.split("-") 269 | let, _ = wyk.split("-") 270 | 271 | # ele, wyk = swyk.split(" @ ") 272 | # mult, let, _ = wyk.split("-") 273 | 274 | mult_list.append(float(mult)) 275 | ele_list.append(ele) 276 | wyk_list.append(let) 277 | 278 | aug_wyks = [] 279 | for trans in relab_dict[spg_no]: 280 | t = str.maketrans(trans) 281 | aug_wyks.append(tuple(",".join(wyk_list).translate(t).split(","))) 282 | 
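# Illustrative example of the relabelling above (hypothetical translation): a
# mapping such as {"a": "b", "b": "a"} turns the letter sequence "a,c,b" into
# "b,c,a", i.e. an equivalent setting of the same Wyckoff sites; duplicate
# settings are removed below.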
283 | aug_wyks = list(set(aug_wyks)) 284 | # print(len(aug_wyks)) 285 | # print(aug_wyks) 286 | # exit() 287 | 288 | return spg_no, mult_list, ele_list, aug_wyks 289 | 290 | 291 | def parse_aflow(aflow_label): 292 | """parse the wyckoff format 293 | 294 | Args: 295 | swyk_list ([type]): [description] 296 | relab_dict ([type]): [description] 297 | 298 | Returns: 299 | mult_list, ele_list, aug_wyks 300 | """ 301 | proto, chemsys = aflow_label.split(":") 302 | elems = chemsys.split("-") 303 | _, _, spg_no, *wyks = proto.split("_") 304 | 305 | mult_list = [] 306 | ele_list = [] 307 | wyk_list = [] 308 | 309 | subst = r"1\g<1>" 310 | for el, wyk in zip(elems, wyks): 311 | wyk = re.sub(r"((? str: 58 | """get aflow prototype label for pymatgen structure 59 | 60 | args: 61 | struct (Structure): pymatgen structure object 62 | 63 | returns: 64 | aflow prototype labels 65 | """ 66 | 67 | poscar = Poscar(struct) 68 | 69 | cmd = f"{aflow_executable} --prototype --print=json cat" 70 | 71 | output = subprocess.run( 72 | cmd, input=poscar.get_string(), text=True, capture_output=True, shell=True 73 | ) 74 | 75 | aflow_proto = json.loads(output.stdout) 76 | 77 | aflow_label = aflow_proto["aflow_label"] 78 | 79 | aflow_label = aflow_label.replace( 80 | "alpha", "A" 81 | ) # to be consistent with spglib and wren embeddings 82 | 83 | # check that multiplicities satisfy original composition 84 | symm = aflow_label.split("_") 85 | spg_no = symm[2] 86 | wyks = symm[3:] 87 | elems = poscar.site_symbols 88 | elem_dict = {} 89 | subst = r"1\g<1>" 90 | for el, wyk in zip(elems, wyks): 91 | wyk = re.sub(r"((? str: 108 | """get aflow prototype label for pymatgen structure 109 | 110 | args: 111 | struct (Structure): pymatgen structure object 112 | 113 | returns: 114 | aflow prototype labels 115 | """ 116 | spga = SpacegroupAnalyzer(struct, symprec=0.1, angle_tolerance=5) 117 | aflow = get_aflow_label_from_spga(spga) 118 | 119 | # try again with refined structure if it initially fails 120 | # NOTE structures with magmoms fail unless all have same magmom 121 | if "Invalid" in aflow: 122 | spga = SpacegroupAnalyzer( 123 | spga.get_refined_structure(), symprec=1e-5, angle_tolerance=-1 124 | ) 125 | aflow = get_aflow_label_from_spga(spga) 126 | 127 | return aflow 128 | 129 | 130 | def get_aflow_label_from_spga(spga): 131 | spg_no = spga.get_space_group_number() 132 | sym_struct = spga.get_symmetrized_structure() 133 | 134 | equivs = [ 135 | (len(s), s[0].species_string, f"{wyk.translate(remove_digits)}") 136 | for s, wyk in zip(sym_struct.equivalent_sites, sym_struct.wyckoff_symbols) 137 | ] 138 | equivs = sorted(equivs, key=lambda x: (x[1], x[2])) 139 | 140 | # check that multiplicities satisfy original composition 141 | elem_dict = {} 142 | elem_wyks = [] 143 | for el, g in groupby(equivs, key=lambda x: x[1]): # sort alphabetically by element 144 | g = list(g) 145 | elem_dict[el] = sum(float(mult_dict[str(spg_no)][e[2]]) for e in g) 146 | wyks = "" 147 | for wyk, w in groupby( 148 | g, key=lambda x: x[2] 149 | ): # sort alphabetically by wyckoff letter 150 | w = list(w) 151 | wyks += f"{len(w)}{wyk}" 152 | elem_wyks.append(wyks) 153 | 154 | # cannonicalise the possible wyckoff letter sequences 155 | elem_wyks = "_".join(elem_wyks) 156 | canonical = canonicalise_elem_wyks(elem_wyks, spg_no) 157 | 158 | # get pearson symbol 159 | cry_sys = spga.get_crystal_system() 160 | spg_sym = spga.get_space_group_symbol() 161 | centering = "C" if spg_sym[0] in ("A", "B", "C", "S") else spg_sym[0] 162 | n_conv = 
len(spga._space_group_data["std_types"]) 163 | pearson = f"{cry_sys_dict[cry_sys]}{centering}{n_conv}" 164 | 165 | prototype_form = prototype_formula(spga._structure.composition) 166 | 167 | aflow_label = ( 168 | f"{prototype_form}_{pearson}_{spg_no}_{canonical}:" 169 | f"{spga._structure.composition.chemical_system}" 170 | ) 171 | 172 | eqi_comp = Composition(elem_dict) 173 | if not eqi_comp.reduced_formula == spga._structure.composition.reduced_formula: 174 | return f"Invalid WP Multiplicities - {aflow_label}" 175 | 176 | return aflow_label 177 | 178 | 179 | def canonicalise_elem_wyks(elem_wyks, spg_no): 180 | """ 181 | Given an element ordering canonicalise the associated wyckoff positions 182 | based on the alphabetical weight of equivalent choices of origin. 183 | """ 184 | 185 | isopointial = [] 186 | 187 | for trans in relab_dict[str(spg_no)]: 188 | t = str.maketrans(trans) 189 | isopointial.append(elem_wyks.translate(t)) 190 | 191 | isopointial = list(set(isopointial)) 192 | 193 | scores = [] 194 | sorted_iso = [] 195 | for wyks in isopointial: 196 | score = 0 197 | sorted_el_wyks = [] 198 | for el_wyks in wyks.split("_"): 199 | sep_el_wyks = ["".join(g) for _, g in groupby(el_wyks, str.isalpha)] 200 | sep_el_wyks = ["" if i == "1" else i for i in sep_el_wyks] 201 | sorted_el_wyks.append( 202 | "".join( 203 | [ 204 | f"{n}{w}" 205 | for n, w in sorted( 206 | zip(sep_el_wyks[0::2], sep_el_wyks[1::2]), 207 | key=lambda x: x[1], 208 | ) 209 | ] 210 | ) 211 | ) 212 | score += sum(0 if el == "A" else ord(el) - 96 for el in sep_el_wyks[1::2]) 213 | 214 | scores.append(score) 215 | sorted_iso.append("_".join(sorted_el_wyks)) 216 | 217 | canonical = sorted(zip(scores, sorted_iso), key=lambda x: (x[0], x[1]))[0][1] 218 | 219 | return canonical 220 | 221 | 222 | def prototype_formula(composition) -> str: 223 | """ 224 | An anonymized formula. Unique species are arranged in alphabetical order 225 | and assigned ascending alphabets. This format is used in the aflow structure 226 | prototype labelling scheme. 227 | """ 228 | reduced = composition.element_composition 229 | if all(x == int(x) for x in composition.values()): 230 | reduced /= gcd(*(int(i) for i in composition.values())) 231 | 232 | amts = [amt for _, amt in sorted(reduced.items(), key=lambda x: str(x[0]))] 233 | 234 | anon = "" 235 | for e, amt in zip(ascii_uppercase, amts): 236 | if amt == 1: 237 | amt_str = "" 238 | elif abs(amt % 1) < 1e-8: 239 | amt_str = str(int(amt)) 240 | else: 241 | amt_str = str(amt) 242 | anon += f"{e}{amt_str}" 243 | return anon 244 | 245 | 246 | def count_wyks(aflow_label): 247 | num_wyk = 0 248 | 249 | aflow_label, _ = aflow_label.split(":") 250 | wyks = aflow_label.split("_")[3:] 251 | 252 | subst = r"1\g<1>" 253 | for wyk in wyks: 254 | wyk = re.sub(r"((?" 274 | for wyk in wyks: 275 | wyk = re.sub(r"((? 
0.75 153 | assert ens_roc_auc > 0.9 154 | 155 | 156 | if __name__ == "__main__": 157 | test_single_roost_clf() 158 | -------------------------------------------------------------------------------- /tests/test_single_roost_regression.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import torch 5 | from sklearn.metrics import r2_score 6 | from sklearn.model_selection import train_test_split as split 7 | 8 | from roost.roost.data import CompositionData, collate_batch 9 | from roost.roost.model import Roost 10 | from roost.utils import results_multitask, train_ensemble 11 | 12 | torch.manual_seed(0) # ensure reproducible results 13 | 14 | 15 | def test_single_roost(): 16 | 17 | data_path = "tests/data/roost-regression.csv" 18 | fea_path = "data/el-embeddings/matscholar-embedding.json" 19 | targets = ["Eg"] 20 | tasks = ["regression"] 21 | losses = ["L1"] 22 | robust = True 23 | model_name = "roost" 24 | elem_fea_len = 64 25 | n_graph = 3 26 | ensemble = 2 27 | run_id = 1 28 | data_seed = 42 29 | epochs = 25 30 | log = False 31 | sample = 1 32 | test_size = 0.2 33 | resume = False 34 | fine_tune = None 35 | transfer = None 36 | optim = "AdamW" 37 | learning_rate = 3e-4 38 | momentum = 0.9 39 | weight_decay = 1e-6 40 | batch_size = 128 41 | workers = 0 42 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 43 | 44 | task_dict = {k: v for k, v in zip(targets, tasks)} 45 | loss_dict = {k: v for k, v in zip(targets, losses)} 46 | 47 | dataset = CompositionData(data_path=data_path, fea_path=fea_path, task_dict=task_dict) 48 | n_targets = dataset.n_targets 49 | elem_emb_len = dataset.elem_emb_len 50 | 51 | train_idx = list(range(len(dataset))) 52 | 53 | print(f"using {test_size} of training set as test set") 54 | train_idx, test_idx = split(train_idx, random_state=data_seed, test_size=test_size) 55 | test_set = torch.utils.data.Subset(dataset, test_idx) 56 | 57 | print("No validation set used, using test set for evaluation purposes") 58 | # NOTE that when using this option care must be taken not to 59 | # peek at the test set. The only valid model to use is the one 60 | # obtained after the final epoch where the epoch count is 61 | # decided in advance of the experiment.
62 | val_set = test_set 63 | 64 | train_set = torch.utils.data.Subset(dataset, train_idx[0::sample]) 65 | 66 | data_params = { 67 | "batch_size": batch_size, 68 | "num_workers": workers, 69 | "pin_memory": False, 70 | "shuffle": True, 71 | "collate_fn": collate_batch, 72 | } 73 | 74 | setup_params = { 75 | "optim": optim, 76 | "learning_rate": learning_rate, 77 | "weight_decay": weight_decay, 78 | "momentum": momentum, 79 | "device": device, 80 | } 81 | 82 | restart_params = { 83 | "resume": resume, 84 | "fine_tune": fine_tune, 85 | "transfer": transfer, 86 | } 87 | 88 | model_params = { 89 | "task_dict": task_dict, 90 | "robust": robust, 91 | "n_targets": n_targets, 92 | "elem_emb_len": elem_emb_len, 93 | "elem_fea_len": elem_fea_len, 94 | "n_graph": n_graph, 95 | "elem_heads": 3, 96 | "elem_gate": [256], 97 | "elem_msg": [256], 98 | "cry_heads": 3, 99 | "cry_gate": [256], 100 | "cry_msg": [256], 101 | "trunk_hidden": [1024, 512], 102 | "out_hidden": [256, 128, 64], 103 | } 104 | 105 | os.makedirs(f"models/{model_name}", exist_ok=True) 106 | os.makedirs(f"results/{model_name}", exist_ok=True) 107 | 108 | train_ensemble( 109 | model_class=Roost, 110 | model_name=model_name, 111 | run_id=run_id, 112 | ensemble_folds=ensemble, 113 | epochs=epochs, 114 | train_set=train_set, 115 | val_set=val_set, 116 | log=log, 117 | data_params=data_params, 118 | setup_params=setup_params, 119 | restart_params=restart_params, 120 | model_params=model_params, 121 | loss_dict=loss_dict, 122 | ) 123 | 124 | data_params["batch_size"] = 64 * batch_size # faster model inference 125 | data_params["shuffle"] = False # need fixed data order due to ensembling 126 | 127 | results_dict = results_multitask( 128 | model_class=Roost, 129 | model_name=model_name, 130 | run_id=run_id, 131 | ensemble_folds=ensemble, 132 | test_set=test_set, 133 | data_params=data_params, 134 | robust=robust, 135 | task_dict=task_dict, 136 | device=device, 137 | eval_type="checkpoint", 138 | ) 139 | 140 | pred = results_dict["Eg"]["pred"] 141 | target = results_dict["Eg"]["target"] 142 | 143 | y_ens = np.mean(pred, axis=0) 144 | 145 | mae = np.abs(target - y_ens).mean() 146 | mse = np.square(target - y_ens).mean() 147 | rmse = np.sqrt(mse) 148 | r2 = r2_score(target, y_ens) 149 | 150 | assert r2 > 0.7 151 | assert mae < 0.6 152 | assert rmse < 0.8 153 | 154 | 155 | if __name__ == "__main__": 156 | test_single_roost() 157 | --------------------------------------------------------------------------------
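A minimal end-to-end sketch of the composition-only pipeline, mirroring the regression test above (illustrative only: Roost forwards extra keyword arguments to BaseModelClass in roost/core.py, which is not shown in this excerpt, so additional arguments may be required).

    import torch
    from torch.utils.data import DataLoader

    from roost.roost.data import CompositionData, collate_batch
    from roost.roost.model import Roost

    task_dict = {"Eg": "regression"}
    dataset = CompositionData(
        data_path="tests/data/roost-regression.csv",
        fea_path="data/el-embeddings/matscholar-embedding.json",
        task_dict=task_dict,
    )
    loader = DataLoader(dataset, batch_size=32, collate_fn=collate_batch)

    model = Roost(
        robust=False,
        n_targets=dataset.n_targets,
        elem_emb_len=dataset.elem_emb_len,
        task_dict=task_dict,
    )
    model.eval()

    inputs, targets, material_ids, compositions = next(iter(loader))
    with torch.no_grad():
        preds = list(model(*inputs))  # one prediction tensor per target

With robust=True each output head doubles in width so that the robust loss functions can also model a per-sample uncertainty.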