├── .github
└── workflows
│ └── ci.yml
├── .gitignore
├── .pre-commit-config.yaml
├── CITATION.cff
├── LICENSE
├── README.md
├── data
├── datasets
│ └── roost
│ │ ├── README.md
│ │ ├── expt-non-metals.csv
│ │ ├── mp-non-metals.csv
│ │ └── oqmd-form-enthalpy.csv
├── el-embeddings
│ ├── README.md
│ ├── cgcnn-embedding.json
│ ├── matscholar-embedding.json
│ ├── megnet16-embedding.json
│ └── onehot-embedding.json
└── wp-embeddings
│ ├── README.md
│ ├── bra-alg-off.json
│ └── spg-alg-off.json
├── environment-gpu-cu111.yml
├── examples
├── cgcnn-example.py
├── roost-example.py
└── wren-example.py
├── requirements.txt
├── roost
├── cgcnn
│ ├── data.py
│ ├── model.py
│ └── utils.py
├── core.py
├── pretrain
│ ├── dist_data.py
│ ├── dist_model.py
│ ├── ele_data.py
│ ├── ele_model.py
│ ├── ener_data.py
│ └── ener_model.py
├── roost
│ ├── data.py
│ └── model.py
├── segments.py
├── utils.py
└── wren
│ ├── allowed-wp-mult.json
│ ├── data.py
│ ├── model.py
│ ├── relab.json
│ ├── utils.py
│ └── wp-params.json
├── setup.cfg
├── setup.py
└── tests
├── data
├── roost-classification.csv
└── roost-regression.csv
├── test_single_roost_classification.py
└── test_single_roost_regression.py
/.github/workflows/ci.yml:
--------------------------------------------------------------------------------
1 | name: Tests
2 |
3 | on:
4 | push:
5 | paths:
6 | - '**.py'
7 | - '.github/workflows/ci.yml'
8 | pull_request:
9 | paths:
10 | - '**.py'
11 | - '.github/workflows/ci.yml'
12 |
13 | jobs:
14 | tests:
15 | runs-on: ubuntu-latest
16 | steps:
17 | - name: Checkout repo
18 | uses: actions/checkout@v2
19 |
20 | - name: Setup Python
21 | uses: actions/setup-python@v2
22 | with:
23 | python-version: 3.8
24 |
25 | - name: Install dependencies
26 | run: |
27 | # install torch first because torch_scatter needs it
28 | pip install torch==1.10.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
29 | pip install torch-scatter -f https://data.pyg.org/whl/torch-1.10.0+cpu.html
30 | pip install -r requirements.txt
31 | pip install .
32 |
33 | - name: Run Tests
34 | run: |
35 | python -m pytest
36 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | *.pyc
6 |
7 | # C extensions
8 | *.so
9 |
10 | # Distribution / packaging
11 | .Python
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .coverage
43 | .coverage.*
44 | .cache
45 | nosetests.xml
46 | coverage.xml
47 | *.cover
48 | .hypothesis/
49 | .pytest_cache/
50 |
51 | # Translations
52 | *.mo
53 | *.pot
54 |
55 | # Django stuff:
56 | *.log
57 | local_settings.py
58 | db.sqlite3
59 |
60 | # Flask stuff:
61 | instance/
62 | .webassets-cache
63 |
64 | # Scrapy stuff:
65 | .scrapy
66 |
67 | # Sphinx documentation
68 | docs/_build/
69 |
70 | # PyBuilder
71 | target/
72 |
73 | # Jupyter Notebook
74 | .ipynb_checkpoints
75 |
76 | # pyenv
77 | .python-version
78 |
79 | # celery beat schedule file
80 | celerybeat-schedule
81 |
82 | # SageMath parsed files
83 | *.sage.py
84 |
85 | # Environments
86 | .env
87 | .venv
88 | env/
89 | venv/
90 | ENV/
91 | env.bak/
92 | venv.bak/
93 |
94 | # Spyder project settings
95 | .spyderproject
96 | .spyproject
97 |
98 | # Rope project settings
99 | .ropeproject
100 |
101 | # mkdocs documentation
102 | /site
103 |
104 | # mypy
105 | .mypy_cache/
106 |
107 | # vscode
108 | .vscode/
109 |
110 | # notebooks
111 | .ipynb_checkpoints/
112 |
113 | # exclude model outputs from github
114 | runs/
115 | results/
116 | models/
117 |
118 | # exclude local slurm and plotting scripts
119 | process/
120 | papers/
121 | cds3/
122 |
123 | # exclude local data and cached graphs
124 | *.csv
125 | *.pkl
126 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | # Install these hooks with `pre-commit install`.
2 |
3 | ci:
4 | autoupdate_schedule: quarterly
5 |
6 | default_stages: [commit]
7 |
8 | repos:
9 | - repo: https://github.com/PyCQA/isort
10 | rev: 5.10.1
11 | hooks:
12 | - id: isort
13 |
14 | - repo: https://github.com/PyCQA/flake8
15 | rev: 5.0.4
16 | hooks:
17 | - id: flake8
18 |
19 | - repo: https://github.com/asottile/pyupgrade
20 | rev: v2.38.2
21 | hooks:
22 | - id: pyupgrade
23 | args: [--py37-plus]
24 |
25 | - repo: https://github.com/pre-commit/pre-commit-hooks
26 | rev: v4.3.0
27 | hooks:
28 | - id: end-of-file-fixer
29 | - id: forbid-new-submodules
30 | - id: mixed-line-ending
31 | - id: trailing-whitespace
32 |
33 | - repo: https://github.com/codespell-project/codespell
34 | rev: v2.2.1
35 | hooks:
36 | - id: codespell
37 | stages: [commit, commit-msg]
38 | exclude_types: [csv, json]
39 |
--------------------------------------------------------------------------------
/CITATION.cff:
--------------------------------------------------------------------------------
1 | cff-version: 1.2.0
2 | message: "If you use this software, please cite it as below."
3 | authors:
4 | - family-names: "Goodall"
5 | given-names: "Rhys"
6 | orcid: "https://orcid.org/0000-0002-6589-1700"
7 | title: "roost"
8 | version: 0.0.2
9 | date-released: 2021-07-28
10 | url: "https://github.com/CompRhys/roost"
11 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019-2020 Rhys Goodall
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
Roost
2 | Representation Learning from Stoichiometry
3 |
4 |
5 |
6 | 
7 | [](https://github.com/comprhys/roost/graphs/contributors)
8 | [](https://github.com/comprhys/roost/commits)
9 | [](https://github.com/CompRhys/roost/actions)
10 | [](https://results.pre-commit.ci/latest/github/CompRhys/roost/main)
11 |
12 |
13 |
14 |
15 | :warning: UNSUPPORTED NOTICE :warning:
16 | Development of `roost` has moved from this repository to the https://github.com/CompRhys/aviary repository.
17 |
18 | ## Premise
19 |
20 | In materials discovery applications often we know the composition of trial materials but have little knowledge about the structure.
21 |
22 | Many current SOTA results within the field of machine learning for materials discovery are reliant on knowledge of the structure of the material. This means that such models can only be applied to systems that have undergone structural characterisation. As structural characterisation is a time-consuming process whether done experimentally or via the use of ab-initio methods the use of structures as our model inputs is a prohibitive bottleneck to many materials screening applications we would like to pursue.
23 |
24 | One approach for avoiding the structure bottleneck is to develop models that learn from the stoichiometry alone. In this work, we show that via a novel recasting of how we view the stoichiometry of a material we can leverage a message-passing neural network to learn materials properties whilst remaining agnostic to the structure. The proposed model exhibits increased sample efficiency compared to more widely used descriptor-based approaches. This work draws inspiration from recent progress in using graph-based methods for the study of small molecules and crystalline materials.
25 |
26 | ## Environment Setup
27 |
28 | To use `roost` you need to create an environment with the correct dependencies. The easiest way to get up and running is to use `Anaconda`.
29 | A `cudatoolkit=11.1` environment file is provided `environment-gpu-cu111.yml` allowing a working environment to be created with:
30 |
31 | ```bash
32 | conda env create -f environment-gpu-cu111.yml
33 | ```
34 |
35 | If you are not using `cudatoolkit=11.1` or do not have access to a GPU this setup will not work for you. In that case, please check the following pages for how to install the core packages: [PyTorch](https://pytorch.org/get-started/locally/), [PyTorch-Scatter](https://pytorch-geometric.readthedocs.io/en/latest/notes/installation.html). Then install the remaining requirements as detailed in `requirements.txt`.
36 |
37 | The code was developed and tested on Linux Mint 19.1 Tessa. It should work with other operating systems but it has not been tested for such use.
38 |
39 | ## Roost Setup
40 |
41 | Once you have setup an environment with the correct dependencies you can install `roost` using the following commands:
42 |
43 | ```bash
44 | conda activate roost
45 | git clone https://github.com/CompRhys/roost
46 | cd roost
47 | python setup.py sdist
48 | pip install -e .
49 | ```
50 |
51 | This will install the library in an editable state allowing for advanced users to make changes as desired.
52 |
53 | ## Example Use
54 |
55 | In order to test your installation you can do so by running the following example from the top of your `roost` directory:
56 |
57 | ```sh
58 | cd /path/to/roost/
59 | python examples/roost-example.py --train --evaluate --epochs 10 --tasks regression --targets Eg --losses L2
60 | ```
61 |
62 | This command runs a default task for 10 epochs -- experimental band gap regression using the data from Zhou et al. (see the `data/` folder for reference). This default task has been set up to work out of the box without any changes and to give a flavour of how the model can be used. The demo task should take less than a minute when a GPU is available and give a test set MAE of 0.42-0.45 eV after 10 epochs.
63 |
64 | If you want to use your own data set on a regression task this can be done with:
65 |
66 | ```sh
67 | python examples/roost-example.py --data-path /path/to/your/data/data.csv --train --tasks [regression/classification, ...] --targets [, ..] --losses [L1/L2, ..]
68 | ```
69 |
70 | You can then test your model with:
71 |
72 | ```sh
73 | python examples/roost-example.py --test-path /path/to/testset.csv --evaluate --tasks [regression/classification, ...] --targets [, ..] --losses [L1/L2, ..]
74 | ```
75 |
76 | The model takes input in the form of csv files with material-ids, composition strings and target values as the columns.
77 |
78 | | material-id | composition | target |
79 | | ----------- | ----------- | ------ |
80 | | foo-1 | Fe2O3 | 2.3 |
81 | | foo-2 | La2CuO4 | 4.3 |
82 |
83 | Basic hints about more advanced use of the model (i.e. classification, robust losses, ensembles, tensorboard logging etc..)
84 | are available via the command:
85 |
86 | ```sh
87 | python examples/roost-example.py --help
88 | ```
89 |
90 | This will output the various command-line flags that can be used to control the code.
91 |
92 | ## Cite This Work
93 |
94 | If you use this code please cite our work for which this model was built:
95 |
96 | Predicting materials properties without crystal structure: Deep representation learning from stoichiometry. [[Paper]](https://doi.org/10.1038/s41467-020-19964-7) [[arXiv](https://arxiv.org/abs/1910.00617)]
97 |
98 | ```tex
99 | @article{goodall2020predicting,
100 | title={Predicting materials properties without crystal structure: Deep representation learning from stoichiometry},
101 | author={Goodall, Rhys EA and Lee, Alpha A},
102 | journal={Nature Communications},
103 | volume={11},
104 | number={1},
105 | pages={1--9},
106 | year={2020},
107 | publisher={Nature Publishing Group}
108 | }
109 | ```
110 |
111 | ## Work Featuring Roost
112 |
113 | Work using Roost as presented:
114 |
115 | * A critical examination of compound stability predictions from machine-learned formation energies [[Paper]](https://www.nature.com/articles/s41524-020-00362-y) [[arXiv]](https://arxiv.org/abs/2001.10591)
116 |
117 | * Active learning based generative design for the discovery of wide bandgap materials. [[arXiv]](https://arxiv.org/abs/2103.00608)
118 |
119 | * MaterialsAtlas formation energy WebApp -
120 |
121 | Work building-on/using-parts-of the code shared here:
122 |
123 | * Predicting the Outcomes of Material Syntheses with Deep Learning [[Paper]](https://pubs.acs.org/doi/abs/10.1021/acs.chemmater.0c03885)
124 |
125 | * Compositionally restricted attention-based network for materials property predictions [[Paper]](https://www.nature.com/articles/s41524-021-00545-1)
126 |
127 | * Materials Representation and Transfer Learning for Multi-Property Prediction [[arXiv]](https://arxiv.org/abs/2106.02225)
128 |
129 | If you have used Roost in your work please contact me and I will add your paper here.
130 |
131 | ## Disclaimer
132 |
133 | This is research code shared without support or any guarantee on its quality. However, please do raise an issue or submit a pull request if you spot something wrong or that could be improved and I will try my best to solve it.
134 |
--------------------------------------------------------------------------------
/data/datasets/roost/README.md:
--------------------------------------------------------------------------------
1 | # Dataset sources
2 |
3 | ## OQMD
4 | `oqmd-form-enthalpy.csv` - [The Open Quantum Materials Database (OQMD)](http://oqmd.org/)
5 |
6 | ## Materials Project
7 | `mp-non-metals.csv` - [The Materials Project (MP)](https://materialsproject.org/)
8 |
9 | ## Experimental Bandgaps
10 | `expt-non-metals.csv` - [Predicting the Band Gaps of Inorganic Solids by Machine Learning](https://doi.org/10.1021/acs.jpclett.8b00124)
11 |
--------------------------------------------------------------------------------
/data/el-embeddings/README.md:
--------------------------------------------------------------------------------
1 | # Element - Embeddings
2 |
3 | ## CGCNN
4 |
5 | The following paper describes the details of the CGCNN framework, information about how the one-hot atom embedding was constructed is available in the supplementary materials:
6 |
7 | [Crystal Graph Convolutional Neural Networks for an Accurate and Interpretable Prediction of Material Properties](https://link.aps.org/doi/10.1103/PhysRevLett.120.145301)
8 |
9 | ## MEGnet
10 |
11 | The following paper describes the details of the MEGnet framework, the embedding is generated from the atomic weights through one MEGnet layer on a training task to predict the computed formation energies of ∼69,000 materials from the [Materials Project](https://materialsproject.org/) Database:
12 |
13 | [Graph Networks as a Universal Machine Learning Framework for Molecules and Crystals](https://arxiv.org/abs/1812.05055)
14 |
15 | ## MatScholar
16 |
17 | This is an experimental NLP embedding based on the data mining of definite compositions and structure prototypes.
18 |
19 | [Unsupervised word embeddings capture latent knowledge from materials science literature](https://www.nature.com/articles/s41586-019-1335-8)
20 |
21 | ## One-hot
22 |
23 | This is a simple one-hot encoding for the elements
24 |
--------------------------------------------------------------------------------
/data/wp-embeddings/README.md:
--------------------------------------------------------------------------------
1 | # Wyckoff - Embeddings
2 |
3 | ## wyk-ohe
4 |
5 | A 1731 dimension one-hot encoding (OHE) of the possible wyckoff positions.
6 |
7 | ## spg-letter
8 |
9 | A 230 + 27 dimension embedding consisting of an OHE of the spacegroup and an OHE of the wyckoff position letter.
10 |
11 | ## alg
12 |
13 | A 185 dimension embedding of the sum of whether sites in a wyckoff position sit on lines, planes or in volumes.
14 |
15 | ## alg-off
16 |
17 | A 185 + 248 dimension embedding of the sum of whether sites in a wyckoff position sit on lines, planes or in
18 | volumes and the specific offset of that wyckoff position within the unit cell.
19 |
20 | ## bra-alg-off
21 |
22 | A 6 + 5 + 185 + 248 dimension embedding of the crystal system, unit cell centering, the sum of whether sites
23 | in a wyckoff position sit on lines, planes or in volumes, and the specific offset of that wyckoff position within
24 | the unit cell.
25 |
26 | ## spg-alg-off
27 |
28 | A 230 + 185 + 248 dimension embedding consisting of an OHE of the spacegroup, the sum of whether sites
29 | in a wyckoff position sit on lines, planes or in volumes, and the specific offset of that wyckoff position within
30 | the unit cell.
31 |
--------------------------------------------------------------------------------
/environment-gpu-cu111.yml:
--------------------------------------------------------------------------------
1 | name: roost
2 | channels:
3 | - pytorch
4 | - pyg
5 | - nvidia
6 | - conda-forge
7 | - defaults
8 | dependencies:
9 | - _libgcc_mutex=0.1=main
10 | - _openmp_mutex=4.5=1_gnu
11 | - absl-py=0.13.0=py37h06a4308_0
12 | - aiohttp=3.7.4=py37h27cfd23_1
13 | - async-timeout=3.0.1=py37h06a4308_0
14 | - attrs=21.2.0=pyhd3eb1b0_0
15 | - blas=1.0=mkl
16 | - blinker=1.4=py37h06a4308_0
17 | - brotlipy=0.7.0=py37h5e8e339_1001
18 | - c-ares=1.17.1=h27cfd23_0
19 | - ca-certificates=2021.7.5=h06a4308_1
20 | - cachetools=4.2.2=pyhd3eb1b0_0
21 | - certifi=2021.5.30=py37h06a4308_0
22 | - cffi=1.14.6=py37hc58025e_0
23 | - chardet=3.0.4=py37h06a4308_1003
24 | - charset-normalizer=2.0.0=pyhd8ed1ab_0
25 | - click=8.0.1=pyhd3eb1b0_0
26 | - colorama=0.4.4=pyh9f0ad1d_0
27 | - coverage=5.5=py37h27cfd23_2
28 | - cryptography=3.4.7=py37h5d9358c_0
29 | - cudatoolkit=11.1.74=h6bb024c_0
30 | - cython=0.29.24=py37h295c915_0
31 | - decorator=4.4.2=py_0
32 | - future=0.18.2=py37h89c1867_3
33 | - google-auth=1.33.0=pyhd3eb1b0_0
34 | - google-auth-oauthlib=0.4.1=py_2
35 | - googledrivedownloader=0.4=pyhd3deb0d_1
36 | - grpcio=1.36.1=py37h2157cd5_1
37 | - idna=3.1=pyhd3deb0d_0
38 | - importlib-metadata=4.6.4=py37h06a4308_0
39 | - intel-openmp=2021.3.0=h06a4308_3350
40 | - jinja2=3.0.1=pyhd8ed1ab_0
41 | - joblib=1.0.1=pyhd8ed1ab_0
42 | - ld_impl_linux-64=2.35.1=h7274673_9
43 | - libblas=3.9.0=11_linux64_mkl
44 | - libcblas=3.9.0=11_linux64_mkl
45 | - libffi=3.3=he6710b0_2
46 | - libgcc-ng=9.3.0=h5101ec6_17
47 | - libgfortran-ng=11.1.0=h69a702a_8
48 | - libgfortran5=11.1.0=h6c583b3_8
49 | - libgomp=9.3.0=h5101ec6_17
50 | - liblapack=3.9.0=11_linux64_mkl
51 | - libprotobuf=3.17.2=h4ff587b_1
52 | - libstdcxx-ng=9.3.0=hd4cf53a_17
53 | - libuv=1.40.0=h7b6447c_0
54 | - markdown=3.3.4=py37h06a4308_0
55 | - markupsafe=2.0.1=py37h5e8e339_0
56 | - mkl=2021.3.0=h06a4308_520
57 | - multidict=5.1.0=py37h27cfd23_2
58 | - ncurses=6.2=he6710b0_1
59 | - networkx=2.5.1=pyhd8ed1ab_0
60 | - ninja=1.10.2=h4bd325d_0
61 | - numpy=1.20.3=py37h038b26d_1
62 | - oauthlib=3.1.1=pyhd3eb1b0_0
63 | - openssl=1.1.1l=h7f8727e_0
64 | - pandas=1.3.0=py37h219a48f_0
65 | - pip=21.0.1=py37h06a4308_0
66 | - protobuf=3.17.2=py37h295c915_0
67 | - pyasn1=0.4.8=pyhd3eb1b0_0
68 | - pyasn1-modules=0.2.8=py_0
69 | - pycparser=2.20=pyh9f0ad1d_2
70 | - pyg=2.0.0=py37_torch_1.9.0_cu111
71 | - pyjwt=2.1.0=py37h06a4308_0
72 | - pyopenssl=20.0.1=pyhd8ed1ab_0
73 | - pyparsing=2.4.7=pyh9f0ad1d_0
74 | - pysocks=1.7.1=py37h89c1867_3
75 | - python=3.7.10=h12debd9_4
76 | - python-dateutil=2.8.2=pyhd8ed1ab_0
77 | - python-louvain=0.15=pyhd3deb0d_0
78 | - python_abi=3.7=2_cp37m
79 | - pytorch=1.9.0=py3.7_cuda11.1_cudnn8.0.5_0
80 | - pytorch-cluster=1.5.9=py37_torch_1.9.0_cu111
81 | - pytorch-scatter=2.0.8=py37_torch_1.9.0_cu111
82 | - pytorch-sparse=0.6.12=py37_torch_1.9.0_cu111
83 | - pytorch-spline-conv=1.2.1=py37_torch_1.9.0_cu111
84 | - pytz=2021.1=pyhd8ed1ab_0
85 | - pyyaml=5.4.1=py37h5e8e339_0
86 | - readline=8.1=h27cfd23_0
87 | - requests=2.26.0=pyhd8ed1ab_0
88 | - requests-oauthlib=1.3.0=py_0
89 | - rsa=4.7.2=pyhd3eb1b0_1
90 | - scikit-learn=0.24.2=py37h18a542f_0
91 | - scipy=1.6.3=py37h29e03ee_0
92 | - setuptools=52.0.0=py37h06a4308_0
93 | - six=1.16.0=pyh6c4a22f_0
94 | - sleef=3.5.1=h7f98852_1
95 | - sqlite=3.36.0=hc218d9a_0
96 | - tensorboard=2.5.0=py_0
97 | - tensorboard-plugin-wit=1.6.0=py_0
98 | - threadpoolctl=2.2.0=pyh8a188c0_0
99 | - tk=8.6.10=hbc83047_0
100 | - tqdm=4.62.2=pyhd8ed1ab_0
101 | - typing-extensions=3.10.0.0=hd3eb1b0_0
102 | - typing_extensions=3.10.0.0=pyh06a4308_0
103 | - urllib3=1.26.6=pyhd8ed1ab_0
104 | - werkzeug=1.0.1=pyhd3eb1b0_0
105 | - wheel=0.37.0=pyhd3eb1b0_1
106 | - xz=5.2.5=h7b6447c_0
107 | - yacs=0.1.6=py_0
108 | - yaml=0.2.5=h516909a_0
109 | - yarl=1.6.3=py37h27cfd23_0
110 | - zipp=3.5.0=pyhd3eb1b0_0
111 | - zlib=1.2.11=h7b6447c_3
112 | - pip:
113 | - cycler==0.10.0
114 | - kiwisolver==1.3.2
115 | - matplotlib==3.4.3
116 | - monty==2021.8.17
117 | - mpmath==1.2.1
118 | - palettable==3.3.0
119 | - pillow==8.3.2
120 | - plotly==5.3.1
121 | - pymatgen==2022.0.14
122 | - ruamel-yaml==0.17.16
123 | - ruamel-yaml-clib==0.2.6
124 | - spglib==1.16.2
125 | - sympy==1.8
126 | - tabulate==0.8.9
127 | - tenacity==8.0.1
128 | - uncertainties==3.1.6
129 | prefix: /home/reag2/miniconda3/envs/roost
130 |
--------------------------------------------------------------------------------
/examples/roost-example.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import sys
4 |
5 | import torch
6 | from sklearn.model_selection import train_test_split as split
7 |
8 | from roost.roost.data import CompositionData, collate_batch
9 | from roost.roost.model import Roost
10 | from roost.utils import results_multitask, train_ensemble
11 |
12 |
def main(
    data_path,
    fea_path,
    targets,
    tasks,
    losses,
    robust,
    model_name="roost",
    elem_fea_len=64,
    n_graph=3,
    ensemble=1,
    run_id=1,
    data_seed=42,
    epochs=100,
    patience=None,
    log=True,
    sample=1,
    test_size=0.2,
    test_path=None,
    val_size=0.0,
    val_path=None,
    resume=None,
    fine_tune=None,
    transfer=None,
    train=True,
    evaluate=True,
    optim="AdamW",
    learning_rate=3e-4,
    momentum=0.9,
    weight_decay=1e-6,
    batch_size=128,
    workers=0,
    device=torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"),
    **kwargs,
):
    """Train and/or evaluate a Roost model (or ensemble) on composition data.

    The arguments are validated, the data sets are assembled, the output
    directories ``models/<model_name>/``, ``runs/`` (only if ``log``) and
    ``results/`` are created, and then ``train_ensemble`` and/or
    ``results_multitask`` are called.

    Args:
        data_path: Path handed to ``CompositionData`` for the main data set.
        fea_path: Element-embedding path handed to every ``CompositionData``.
        targets: Target names; must have the same length as ``tasks`` and
            ``losses``.
        tasks: Task type for each target (zipped with ``targets``).
        losses: Loss name for each target (zipped with ``targets``).
        robust: Forwarded to the model parameters and to
            ``results_multitask``.
        model_name: Sub-directory of ``models/`` used for this run.
        elem_fea_len: Forwarded to the model parameters.
        n_graph: Forwarded to the model parameters.
        ensemble: Number of ensemble folds passed to the train/evaluate
            helpers.
        run_id: Run index; also used to build the resume checkpoint path.
        data_seed: ``random_state`` used for the train/test and train/val
            splits.
        epochs: Forwarded to ``train_ensemble``.
        patience: Forwarded to ``train_ensemble``.
        log: Forwarded to ``train_ensemble``; also controls creation of
            ``runs/``.
        sample: Stride used to sub-sample the training indices.
        test_size: Fraction of the main data set split off for testing;
            forced to ``0.0`` when ``test_path`` is given.
        test_path: Optional path to an independent test set.
        val_size: Fraction of the main data set used for validation.
        val_path: Optional path to an independent validation set.
        resume: If truthy, resume from
            ``models/<model_name>/checkpoint-r<run_id>.pth.tar``.
        fine_tune: Forwarded in the restart parameters; mutually exclusive
            with ``transfer``.
        transfer: Forwarded in the restart parameters; mutually exclusive
            with ``fine_tune``.
        train: Whether to run training.
        evaluate: Whether to run evaluation.
        optim: Forwarded in the setup parameters.
        learning_rate: Forwarded in the setup parameters.
        momentum: Forwarded in the setup parameters.
        weight_decay: Forwarded in the setup parameters.
        batch_size: Training batch size; evaluation uses ``16 * batch_size``.
        workers: ``num_workers`` entry of the data-loader parameters.
        device: Forwarded in the setup parameters and to
            ``results_multitask``.
        **kwargs: Accepted and ignored, so a full argparse namespace can be
            unpacked into this function.

    Raises:
        AssertionError: If the argument lists differ in length, neither
            ``train`` nor ``evaluate`` is set, the split sizes sum to 1 or
            more, or both ``fine_tune`` and ``transfer`` are given.
        NotImplementedError: If ``ensemble > 1`` is combined with
            ``fine_tune`` or ``transfer``.
        ValueError: If evaluation is requested with neither a test path nor
            a non-zero ``test_size``.
    """
    # ---- argument validation -------------------------------------------
    assert len(targets) == len(tasks) == len(losses)

    assert (
        evaluate or train
    ), "No action given - At least one of 'train' or 'evaluate' cli flags required"

    # An independent test set means nothing is split off the main data set.
    if test_path:
        test_size = 0.0

    if not test_path or not val_path:
        assert (
            test_size + val_size < 1.0
        ), f"'test_size'({test_size}) plus 'val_size'({val_size}) must be less than 1"

    if ensemble > 1 and (fine_tune or transfer):
        raise NotImplementedError(
            "If training an ensemble with fine tuning or transferring"
            " options the models must be trained one by one using the"
            " run-id flag."
        )

    assert not (
        fine_tune and transfer
    ), "Cannot fine-tune and transfer checkpoint(s) at the same time."

    # ---- data sets -----------------------------------------------------
    task_dict = dict(zip(targets, tasks))
    loss_dict = dict(zip(targets, losses))

    def _load_whole_set(path):
        """Load the csv at ``path`` and wrap every row of it in a ``Subset``."""
        loaded = CompositionData(
            data_path=path, fea_path=fea_path, task_dict=task_dict
        )
        return torch.utils.data.Subset(loaded, range(len(loaded)))

    dataset = CompositionData(
        data_path=data_path, fea_path=fea_path, task_dict=task_dict
    )
    n_targets = dataset.n_targets
    elem_emb_len = dataset.elem_emb_len

    # Start from every row; the splits below carve indices out of this list.
    train_idx = list(range(len(dataset)))

    if evaluate:
        if test_path:
            print(f"using independent test set: {test_path}")
            test_set = _load_whole_set(test_path)
        elif test_size == 0.0:
            raise ValueError("test-size must be non-zero to evaluate model")
        else:
            print(f"using {test_size} of training set as test set")
            train_idx, held_out_idx = split(
                train_idx, random_state=data_seed, test_size=test_size
            )
            test_set = torch.utils.data.Subset(dataset, held_out_idx)

    if train:
        if val_path:
            print(f"using independent validation set: {val_path}")
            val_set = _load_whole_set(val_path)
        elif val_size != 0.0:
            print(f"using {val_size} of training set as validation set")
            # val_size is a fraction of the full data set, so it is rescaled
            # to a fraction of what is left after the test split.
            # NOTE(review): when ``evaluate`` is False and no ``test_path`` is
            # given, no test split was made above but ``test_size`` still
            # enters this ratio - confirm whether that is intended.
            train_idx, val_idx = split(
                train_idx,
                random_state=data_seed,
                test_size=val_size / (1 - test_size),
            )
            val_set = torch.utils.data.Subset(dataset, val_idx)
        elif evaluate:
            print("No validation set used, using test set for evaluation purposes")
            # NOTE that when using this option care must be taken not to
            # peek at the test-set. The only valid model to use is the one
            # obtained after the final epoch where the epoch count is
            # decided in advance of the experiment.
            val_set = test_set
        else:
            val_set = None

    # Keep every ``sample``-th training index (sub-sampling).
    train_set = torch.utils.data.Subset(dataset, train_idx[::sample])

    # ---- parameter bundles for the training / evaluation helpers ------
    data_params = {
        "batch_size": batch_size,
        "num_workers": workers,
        "pin_memory": False,
        "shuffle": True,
        "collate_fn": collate_batch,
    }

    setup_params = {
        "optim": optim,
        "learning_rate": learning_rate,
        "weight_decay": weight_decay,
        "momentum": momentum,
        "device": device,
    }

    # A truthy ``resume`` is replaced by the checkpoint path of this run.
    checkpoint = (
        f"models/{model_name}/checkpoint-r{run_id}.pth.tar" if resume else resume
    )
    restart_params = {
        "resume": checkpoint,
        "fine_tune": fine_tune,
        "transfer": transfer,
    }

    model_params = {
        "task_dict": task_dict,
        "robust": robust,
        "n_targets": n_targets,
        "elem_emb_len": elem_emb_len,
        "elem_fea_len": elem_fea_len,
        "n_graph": n_graph,
        "elem_heads": 3,
        "elem_gate": [256],
        "elem_msg": [256],
        "cry_heads": 3,
        "cry_gate": [256],
        "cry_msg": [256],
        "trunk_hidden": [1024, 512],
        "out_hidden": [256, 128, 64],
    }

    # ---- output directories --------------------------------------------
    os.makedirs(f"models/{model_name}/", exist_ok=True)
    if log:
        os.makedirs("runs/", exist_ok=True)
    os.makedirs("results/", exist_ok=True)

    # TODO dump all args/kwargs to a file for reproducibility.

    if train:
        train_ensemble(
            model_class=Roost,
            model_name=model_name,
            run_id=run_id,
            ensemble_folds=ensemble,
            epochs=epochs,
            patience=patience,
            train_set=train_set,
            val_set=val_set,
            log=log,
            data_params=data_params,
            setup_params=setup_params,
            restart_params=restart_params,
            model_params=model_params,
            loss_dict=loss_dict,
        )

    if evaluate:
        # Larger batches give faster model inference; a fixed data order is
        # needed due to ensembling.
        data_params["batch_size"] = 16 * batch_size
        data_params["shuffle"] = False

        results_multitask(
            model_class=Roost,
            model_name=model_name,
            run_id=run_id,
            ensemble_folds=ensemble,
            test_set=test_set,
            data_params=data_params,
            robust=robust,
            task_dict=task_dict,
            device=device,
            eval_type="checkpoint",
        )
216 |
217 |
218 | def input_parser():
219 | """
220 | parse input
221 | """
222 | parser = argparse.ArgumentParser(
223 | description=(
224 | "Roost - a Structure Agnostic Message Passing "
225 | "Neural Network for Inorganic Materials"
226 | )
227 | )
228 |
229 | # data inputs
230 | parser.add_argument(
231 | "--data-path",
232 | type=str,
233 | default="data/datasets/roost/expt-non-metals.csv",
234 | metavar="PATH",
235 | help="Path to main data set/training set",
236 | )
237 | valid_group = parser.add_mutually_exclusive_group()
238 | valid_group.add_argument(
239 | "--val-path",
240 | type=str,
241 | metavar="PATH",
242 | help="Path to independent validation set",
243 | )
244 | valid_group.add_argument(
245 | "--val-size",
246 | default=0.0,
247 | type=float,
248 | metavar="FLOAT",
249 | help="Proportion of data used for validation",
250 | )
251 | test_group = parser.add_mutually_exclusive_group()
252 | test_group.add_argument(
253 | "--test-path", type=str, metavar="PATH", help="Path to independent test set"
254 | )
255 | test_group.add_argument(
256 | "--test-size",
257 | default=0.2,
258 | type=float,
259 | metavar="FLOAT",
260 | help="Proportion of data set for testing",
261 | )
262 |
263 | # data embeddings
264 | parser.add_argument(
265 | "--fea-path",
266 | type=str,
267 | default="data/el-embeddings/matscholar-embedding.json",
268 | metavar="PATH",
269 | help="Element embedding feature path",
270 | )
271 |
272 | # dataloader inputs
273 | parser.add_argument(
274 | "--workers",
275 | default=0,
276 | type=int,
277 | metavar="INT",
278 | help="Number of data loading workers (default: 0)",
279 | )
280 | parser.add_argument(
281 | "--batch-size",
282 | "--bsize",
283 | default=128,
284 | type=int,
285 | metavar="INT",
286 | help="Mini-batch size (default: 128)",
287 | )
288 | parser.add_argument(
289 | "--data-seed",
290 | default=0,
291 | type=int,
292 | metavar="INT",
293 | help="Seed used when splitting data sets (default: 0)",
294 | )
295 | parser.add_argument(
296 | "--sample",
297 | default=1,
298 | type=int,
299 | metavar="INT",
300 | help="Sub-sample the training set for learning curves",
301 | )
302 |
303 | # task inputs
304 | parser.add_argument(
305 | "--targets",
306 | nargs="*",
307 | type=str,
308 | metavar="STR",
309 | help="Task types for targets",
310 | )
311 |
312 | parser.add_argument(
313 | "--tasks",
314 | nargs="*",
315 | default=["regression"],
316 | type=str,
317 | metavar="STR",
318 | help="Task types for targets",
319 | )
320 |
321 | parser.add_argument(
322 | "--losses",
323 | nargs="*",
324 | default=["L1"],
325 | type=str,
326 | metavar="STR",
327 | help="Loss function if regression (default: 'L1')",
328 | )
329 |
330 | # optimiser inputs
331 | parser.add_argument(
332 | "--epochs",
333 | default=100,
334 | type=int,
335 | metavar="INT",
336 | help="Number of training epochs to run (default: 100)",
337 | )
338 | parser.add_argument(
339 | "--robust",
340 | action="store_true",
341 | help="Specifies whether to use hetroskedastic loss variants",
342 | )
343 | parser.add_argument(
344 | "--optim",
345 | default="AdamW",
346 | type=str,
347 | metavar="STR",
348 | help="Optimizer used for training (default: 'AdamW')",
349 | )
350 | parser.add_argument(
351 | "--learning-rate",
352 | "--lr",
353 | default=3e-4,
354 | type=float,
355 | metavar="FLOAT",
356 | help="Initial learning rate (default: 3e-4)",
357 | )
358 | parser.add_argument(
359 | "--momentum",
360 | default=0.9,
361 | type=float,
362 | metavar="FLOAT [0,1]",
363 | help="Optimizer momentum (default: 0.9)",
364 | )
365 | parser.add_argument(
366 | "--weight-decay",
367 | default=1e-6,
368 | type=float,
369 | metavar="FLOAT [0,1]",
370 | help="Optimizer weight decay (default: 1e-6)",
371 | )
372 |
373 | # graph inputs
374 | parser.add_argument(
375 | "--elem-fea-len",
376 | default=64,
377 | type=int,
378 | metavar="INT",
379 | help="Number of hidden features for elements (default: 64)",
380 | )
381 | parser.add_argument(
382 | "--n-graph",
383 | default=3,
384 | type=int,
385 | metavar="INT",
386 | help="Number of message passing layers (default: 3)",
387 | )
388 |
389 | # ensemble inputs
390 | parser.add_argument(
391 | "--ensemble",
392 | default=1,
393 | type=int,
394 | metavar="INT",
395 | help="Number models to ensemble",
396 | )
397 | name_group = parser.add_mutually_exclusive_group()
398 | name_group.add_argument(
399 | "--model-name",
400 | type=str,
401 | default=None,
402 | metavar="STR",
403 | help="Name for sub-directory where models will be stored",
404 | )
405 | name_group.add_argument(
406 | "--data-id",
407 | default="roost",
408 | type=str,
409 | metavar="STR",
410 | help="Partial identifier for sub-directory where models will be stored",
411 | )
412 | parser.add_argument(
413 | "--run-id",
414 | default=0,
415 | type=int,
416 | metavar="INT",
417 | help="Index for model in an ensemble of models",
418 | )
419 |
420 | # restart inputs
421 | use_group = parser.add_mutually_exclusive_group()
422 | use_group.add_argument(
423 | "--fine-tune", type=str, metavar="PATH", help="Checkpoint path for fine tuning"
424 | )
425 | use_group.add_argument(
426 | "--transfer",
427 | type=str,
428 | metavar="PATH",
429 | help="Checkpoint path for transfer learning",
430 | )
431 | use_group.add_argument(
432 | "--resume", action="store_true", help="Resume from previous checkpoint"
433 | )
434 |
435 | # task type
436 | parser.add_argument(
437 | "--evaluate",
438 | action="store_true",
439 | help="Evaluate the model/ensemble",
440 | )
441 | parser.add_argument("--train", action="store_true", help="Train the model/ensemble")
442 |
443 | # misc
444 | parser.add_argument("--disable-cuda", action="store_true", help="Disable CUDA")
445 | parser.add_argument(
446 | "--log", action="store_true", help="Log training metrics to tensorboard"
447 | )
448 |
449 | args = parser.parse_args(sys.argv[1:])
450 |
451 | assert all(
452 | [i in ["regression", "classification"] for i in args.tasks]
453 | ), "Only `regression` and `classification` are allowed as tasks"
454 |
455 | if args.model_name is None:
456 | args.model_name = f"{args.data_id}_s-{args.data_seed}_t-{args.sample}"
457 |
458 | args.device = (
459 | torch.device("cuda")
460 | if (not args.disable_cuda) and torch.cuda.is_available()
461 | else torch.device("cpu")
462 | )
463 |
464 | return args
465 |
466 |
if __name__ == "__main__":
    # Script entry point: parse the command line into a namespace.
    args = input_parser()

    print(f"The model will run on the {args.device} device")

    # Every parsed option is forwarded to main() as a keyword argument.
    main(**vars(args))
472 | main(**vars(args))
473 |
--------------------------------------------------------------------------------
/examples/wren-example.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import sys
4 |
5 | import torch
6 | from sklearn.model_selection import train_test_split as split
7 |
8 | from roost.utils import results_multitask, train_ensemble
9 | from roost.wren.data import WyckoffData, collate_batch
10 | from roost.wren.model import Wren
11 |
12 |
def main(
    data_path,
    fea_path,
    sym_path,
    targets,
    tasks,
    losses,
    robust,
    model_name="wren",
    sym_fea_len=32,
    elem_fea_len=32,
    n_graph=3,
    ensemble=1,
    run_id=1,
    data_seed=42,
    epochs=100,
    patience=None,
    log=True,
    sample=1,
    test_size=0.2,
    test_path=None,
    val_size=0.0,
    val_path=None,
    resume=None,
    fine_tune=None,
    transfer=None,
    train=True,
    evaluate=True,
    optim="AdamW",
    learning_rate=3e-4,
    momentum=0.9,
    weight_decay=1e-6,
    batch_size=128,
    workers=0,
    device=torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"),
    **kwargs,
):
    """Train and/or evaluate a Wren model (or ensemble) on a Wyckoff data set.

    Builds the train/validation/test subsets, assembles the parameter
    dictionaries and hands them to ``train_ensemble`` and/or
    ``results_multitask``. Creates the ``models/{model_name}/``, ``results/``
    and (when ``log`` is truthy) ``runs/`` directories.

    Args:
        data_path: Path of the main data set handed to ``WyckoffData``.
        fea_path: Element embedding path handed to ``WyckoffData``.
        sym_path: Symmetry embedding path handed to ``WyckoffData``.
        targets: Target names; zipped with ``tasks`` and ``losses``.
        tasks: Task type per target (same length as ``targets``).
        losses: Loss name per target (same length as ``targets``).
        robust: Forwarded to the model parameters and to the evaluation.
        model_name: Sub-directory name under ``models/``.
        sym_fea_len, elem_fea_len, n_graph: Model hyperparameters.
        ensemble: Number of ensemble folds.
        run_id: Run index; also used in the resume checkpoint file name.
        data_seed: ``random_state`` for the train/test/validation splits.
        epochs, patience: Forwarded to ``train_ensemble``.
        log: Forwarded to ``train_ensemble``; also triggers ``runs/`` creation.
        sample: Stride used to sub-sample the training indices.
        test_size: Fraction split off as test set when ``test_path`` is unset.
        test_path: Independent test set path; forces ``test_size`` to 0.
        val_size: Fraction of the data used for validation.
        val_path: Independent validation set path.
        resume: If truthy, replaced by this run's checkpoint path.
        fine_tune, transfer: Checkpoint paths (mutually exclusive).
        train, evaluate: Which actions to perform; at least one is required.
        optim, learning_rate, momentum, weight_decay, device: Optimiser setup.
        batch_size, workers: Data loader settings.
        **kwargs: Accepted and ignored, so that a full CLI namespace can be
            splatted into this function.

    Raises:
        AssertionError: On inconsistent list lengths, no action requested,
            split fractions summing to 1 or more, or both ``fine_tune`` and
            ``transfer`` given.
        NotImplementedError: If ``ensemble > 1`` is combined with
            ``fine_tune`` or ``transfer``.
        ValueError: If evaluating with ``test_size == 0.0`` and no
            ``test_path``.
    """

    assert len(targets) == len(tasks) == len(losses)

    assert (
        evaluate or train
    ), "No action given - At least one of 'train' or 'evaluate' cli flags required"

    # An independent test set means nothing is held out of the main data set.
    if test_path:
        test_size = 0.0

    # The fractions only need checking when at least one split is taken from
    # the main data set.
    if not (test_path and val_path):
        assert test_size + val_size < 1.0, (
            f"'test_size'({test_size}) "
            f"plus 'val_size'({val_size}) must be less than 1"
        )

    if ensemble > 1 and (fine_tune or transfer):
        raise NotImplementedError(
            "If training an ensemble with fine tuning or transferring"
            " options the models must be trained one by one using the"
            " run-id flag."
        )

    assert not (fine_tune and transfer), (
        "Cannot fine-tune and" " transfer checkpoint(s) at the same time."
    )

    # TODO CLI controls for loss dict.

    # Map each target name to its task type and to its loss name.
    task_dict = {k: v for k, v in zip(targets, tasks)}
    loss_dict = {k: v for k, v in zip(targets, losses)}

    dataset = WyckoffData(
        data_path=data_path, fea_path=fea_path, sym_path=sym_path, task_dict=task_dict
    )
    n_targets = dataset.n_targets
    sym_emb_len = dataset.sym_fea_dim
    elem_emb_len = dataset.elem_fea_dim

    # Start from every index; the test and validation splits below shrink it.
    train_idx = list(range(len(dataset)))

    if evaluate:
        if test_path:
            print(f"using independent test set: {test_path}")
            test_set = WyckoffData(
                data_path=test_path,
                fea_path=fea_path,
                sym_path=sym_path,
                task_dict=task_dict,
            )
            test_set = torch.utils.data.Subset(test_set, range(len(test_set)))
        elif test_size == 0.0:
            raise ValueError("test-size must be non-zero to evaluate model")
        else:
            print(f"using {test_size} of training set as test set")
            train_idx, test_idx = split(
                train_idx, random_state=data_seed, test_size=test_size
            )
            test_set = torch.utils.data.Subset(dataset, test_idx)

    if train:
        if val_path:
            print(f"using independent validation set: {val_path}")
            val_set = WyckoffData(
                data_path=val_path,
                fea_path=fea_path,
                sym_path=sym_path,
                task_dict=task_dict,
            )
            val_set = torch.utils.data.Subset(val_set, range(len(val_set)))
        else:
            if val_size == 0.0 and evaluate:
                print("No validation set used, using test set for evaluation purposes")
                # NOTE that when using this option care must be taken not to
                # peek at the test-set. The only valid model to use is the one
                # obtained after the final epoch where the epoch count is
                # decided in advance of the experiment.
                val_set = test_set
            elif val_size == 0.0:
                val_set = None
            else:
                print(f"using {val_size} of training set as validation set")
                # val_size is a fraction of the full data set, so it is
                # rescaled relative to what is left after the test split.
                train_idx, val_idx = split(
                    train_idx,
                    random_state=data_seed,
                    test_size=val_size / (1 - test_size),
                )
                val_set = torch.utils.data.Subset(dataset, val_idx)

    # Keep every `sample`-th remaining index as the training set.
    train_set = torch.utils.data.Subset(dataset, train_idx[0::sample])

    data_params = {
        "batch_size": batch_size,
        "num_workers": workers,
        "pin_memory": False,
        "shuffle": True,
        "collate_fn": collate_batch,
    }

    setup_params = {
        "optim": optim,
        "learning_rate": learning_rate,
        "weight_decay": weight_decay,
        "momentum": momentum,
        "device": device,
    }

    # The resume flag is turned into the checkpoint path of this run.
    if resume:
        resume = f"models/{model_name}/checkpoint-r{run_id}.pth.tar"

    restart_params = {
        "resume": resume,
        "fine_tune": fine_tune,
        "transfer": transfer,
    }

    model_params = {
        "task_dict": task_dict,
        "robust": robust,
        "n_targets": n_targets,
        "elem_emb_len": elem_emb_len,
        "sym_emb_len": sym_emb_len,
        "elem_fea_len": elem_fea_len,
        "sym_fea_len": sym_fea_len,
        "n_graph": n_graph,
        "elem_heads": 1,
        "elem_gate": [256],
        "elem_msg": [256],
        "cry_heads": 1,
        "cry_gate": [256],
        "cry_msg": [256],
        # "out_hidden": [256] * 6,
        # "out_hidden": [1024, 512, 256, 128, 64],
        "out_hidden": [256, 256],
        "trunk_hidden": [128, 64],
    }

    os.makedirs(f"models/{model_name}/", exist_ok=True)

    if log:
        os.makedirs("runs/", exist_ok=True)

    os.makedirs("results/", exist_ok=True)

    if train:
        train_ensemble(
            model_class=Wren,
            model_name=model_name,
            run_id=run_id,
            ensemble_folds=ensemble,
            epochs=epochs,
            patience=patience,
            train_set=train_set,
            val_set=val_set,
            log=log,
            data_params=data_params,
            setup_params=setup_params,
            restart_params=restart_params,
            model_params=model_params,
            loss_dict=loss_dict,
        )

    if evaluate:

        # Reuse the training loader settings with inference-friendly overrides.
        data_reset = {
            "batch_size": 16 * batch_size,  # faster model inference
            "shuffle": False,  # need fixed data order due to ensembling
        }
        data_params.update(data_reset)

        results_multitask(
            model_class=Wren,
            model_name=model_name,
            run_id=run_id,
            ensemble_folds=ensemble,
            test_set=test_set,
            data_params=data_params,
            robust=robust,
            task_dict=task_dict,
            device=device,
            eval_type="checkpoint",
        )
232 |
233 |
def input_parser(argv=None):
    """Parse the command-line options of the Wren example script.

    Args:
        argv (list[str], optional): Argument strings to parse. Defaults to
            ``sys.argv[1:]`` so that existing zero-argument calls are unchanged.

    Returns:
        argparse.Namespace: The parsed options, with two derived attributes:
        ``model_name`` (built from ``data_id``, ``data_seed`` and ``sample``
        when not given explicitly) and ``device`` (CUDA when available and not
        disabled, otherwise CPU).

    Raises:
        SystemExit: Raised by argparse on malformed input, including task
            types other than ``regression`` or ``classification``.
    """
    parser = argparse.ArgumentParser(description=("Wren"))

    # data inputs
    # NOTE(review): this default is an absolute path on the original author's
    # machine; it is kept for backward compatibility but callers should
    # always pass --data-path explicitly.
    parser.add_argument(
        "--data-path",
        type=str,
        default="/home/reag2/PhD/roost/data/datasets/wren/taata-c-spglib-test.csv",
        metavar="PATH",
        help="Path to main data set/training set",
    )
    valid_group = parser.add_mutually_exclusive_group()
    valid_group.add_argument(
        "--val-path",
        type=str,
        metavar="PATH",
        help="Path to independent validation set",
    )
    valid_group.add_argument(
        "--val-size",
        default=0.0,
        type=float,
        metavar="FLOAT",
        help="Proportion of data used for validation",
    )
    test_group = parser.add_mutually_exclusive_group()
    test_group.add_argument(
        "--test-path", type=str, metavar="PATH", help="Path to independent test set"
    )
    test_group.add_argument(
        "--test-size",
        default=0.2,
        type=float,
        metavar="FLOAT",
        help="Proportion of data set for testing",
    )

    # data embeddings
    parser.add_argument(
        "--fea-path",
        type=str,
        default="data/el-embeddings/matscholar-embedding.json",
        metavar="PATH",
        help="Element embedding feature path",
    )
    parser.add_argument(
        "--sym-path",
        type=str,
        default="data/wp-embeddings/bra-alg-off.json",
        metavar="PATH",
        help="Wyckoff position embedding feature path",
    )

    # dataloader inputs
    parser.add_argument(
        "--workers",
        default=0,
        type=int,
        metavar="INT",
        help="Number of data loading workers (default: 0)",
    )
    parser.add_argument(
        "--batch-size",
        "--bsize",
        default=128,
        type=int,
        metavar="INT",
        help="Mini-batch size (default: 128)",
    )
    parser.add_argument(
        "--data-seed",
        default=0,
        type=int,
        metavar="INT",
        help="Seed used when splitting data sets (default: 0)",
    )
    parser.add_argument(
        "--sample",
        default=1,
        type=int,
        metavar="INT",
        help="Sub-sample the training set for learning curves",
    )

    # task inputs
    parser.add_argument(
        "--targets",
        nargs="*",
        type=str,
        metavar="STR",
        help="Names of the targets to learn, one per entry in --tasks",
    )
    parser.add_argument(
        "--tasks",
        nargs="*",
        default=["regression"],
        type=str,
        metavar="STR",
        help="Task types for targets",
    )
    parser.add_argument(
        "--losses",
        nargs="*",
        default=["L1"],
        type=str,
        metavar="STR",
        help="Loss function if regression (default: 'L1')",
    )

    # optimiser inputs
    parser.add_argument(
        "--epochs",
        default=100,
        type=int,
        metavar="INT",
        help="Number of training epochs to run (default: 100)",
    )
    parser.add_argument(
        "--robust",
        action="store_true",
        help="Specifies whether to use hetroskedastic loss variants",
    )
    parser.add_argument(
        "--optim",
        default="AdamW",
        type=str,
        metavar="STR",
        help="Optimizer used for training (default: 'AdamW')",
    )
    parser.add_argument(
        "--learning-rate",
        "--lr",
        default=3e-4,
        type=float,
        metavar="FLOAT",
        help="Initial learning rate (default: 3e-4)",
    )
    parser.add_argument(
        "--momentum",
        default=0.9,
        type=float,
        metavar="FLOAT [0,1]",
        help="Optimizer momentum (default: 0.9)",
    )
    parser.add_argument(
        "--weight-decay",
        default=1e-6,
        type=float,
        metavar="FLOAT [0,1]",
        help="Optimizer weight decay (default: 1e-6)",
    )

    # graph inputs
    parser.add_argument(
        "--elem-fea-len",
        default=32,
        type=int,
        metavar="INT",
        help="Number of hidden features for elements (default: 32)",
    )
    parser.add_argument(
        "--sym-fea-len",
        default=32,
        type=int,
        metavar="INT",
        help="Number of hidden features for Wyckoff positions (default: 32)",
    )
    parser.add_argument(
        "--n-graph",
        default=3,
        type=int,
        metavar="INT",
        help="Number of message passing layers (default: 3)",
    )

    # ensemble inputs
    parser.add_argument(
        "--ensemble",
        default=1,
        type=int,
        metavar="INT",
        help="Number models to ensemble",
    )
    name_group = parser.add_mutually_exclusive_group()
    name_group.add_argument(
        "--model-name",
        type=str,
        default=None,
        metavar="STR",
        help="Name for sub-directory where models will be stored",
    )
    name_group.add_argument(
        "--data-id",
        default="wren",
        type=str,
        metavar="STR",
        help="Partial identifier for sub-directory where models will be stored",
    )
    parser.add_argument(
        "--run-id",
        default=0,
        type=int,
        metavar="INT",
        help="Index for model in an ensemble of models",
    )

    # restart inputs
    use_group = parser.add_mutually_exclusive_group()
    use_group.add_argument(
        "--fine-tune", type=str, metavar="PATH", help="Checkpoint path for fine tuning"
    )
    use_group.add_argument(
        "--transfer",
        type=str,
        metavar="PATH",
        help="Checkpoint path for transfer learning",
    )
    use_group.add_argument(
        "--resume", action="store_true", help="Resume from previous checkpoint"
    )

    # task type
    parser.add_argument(
        "--evaluate",
        action="store_true",
        help="Evaluate the model/ensemble",
    )
    parser.add_argument("--train", action="store_true", help="Train the model/ensemble")

    # misc
    parser.add_argument("--disable-cuda", action="store_true", help="Disable CUDA")
    parser.add_argument(
        "--log", action="store_true", help="Log training metrics to tensorboard"
    )

    args = parser.parse_args(sys.argv[1:] if argv is None else argv)

    # Validate explicitly rather than with `assert`, which is stripped when
    # Python runs with -O; parser.error reports the problem like any other
    # argparse usage error.
    allowed_tasks = ("regression", "classification")
    if not all(task in allowed_tasks for task in args.tasks):
        parser.error("Only `regression` and `classification` are allowed as tasks")

    # Derive a model name from the data identifier when none was given.
    if args.model_name is None:
        args.model_name = f"{args.data_id}_s-{args.data_seed}_t-{args.sample}"

    args.device = (
        torch.device("cuda")
        if (not args.disable_cuda) and torch.cuda.is_available()
        else torch.device("cpu")
    )

    return args
488 |
489 |
if __name__ == "__main__":
    # Script entry point: parse the command line into a namespace.
    args = input_parser()

    print(f"The model will run on the {args.device} device")

    # Every parsed option is forwarded to main() as a keyword argument;
    # options main() does not name are absorbed by its **kwargs.
    main(**vars(args))
495 | main(**vars(args))
496 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | scipy
2 | tqdm
3 | torch
4 | numpy
5 | pymatgen
6 | torch_scatter
7 | pandas
8 | scikit_learn
9 | pytest
10 | tensorboard
11 |
--------------------------------------------------------------------------------
/roost/cgcnn/data.py:
--------------------------------------------------------------------------------
1 | import ast
2 | import functools
3 | import os
4 | from itertools import groupby
5 |
6 | import numpy as np
7 | import pandas as pd
8 | import torch
9 | from pymatgen.core.structure import Structure
10 | from torch.utils.data import Dataset
11 |
12 | from roost.core import Featurizer
13 |
14 |
class CrystalGraphData(Dataset):
    """Dataset that turns rows of a CSV file into crystal neighbourhood graphs."""

    def __init__(
        self,
        data_path,
        fea_path,
        task_dict,
        inputs=["lattice", "sites"],
        identifiers=["material_id", "composition"],
        radius=5,
        max_num_nbr=12,
        dmin=0,
        step=0.2,
    ):
        """CrystalGraphData returns neighbourhood graphs

        Args:
            data_path (str): The path to the dataset
            fea_path (str): The path to the element embedding
            task_dict ({target: task}): task dict for multi-task learning
            inputs (list, optional): df columns for lattice and sites.
                Defaults to ["lattice", "sites"].
            identifiers (list, optional): df columns for distinguishing data points.
                The first one is used to identify invalid structures.
                Defaults to ["material_id", "composition"].
            radius (int, optional): cut-off radius for neighbourhood.
                Defaults to 5.
            max_num_nbr (int, optional): maximum number of neighbours to consider.
                ``None`` keeps every neighbour inside the radius.
                Defaults to 12.
            dmin (int, optional): minimum distance in gaussian basis.
                Defaults to 0.
            step (float, optional): increment size of gaussian basis.
                Defaults to 0.2.

        Raises:
            AssertionError: if the number of inputs/identifiers is not two or
                if one of the paths does not exist.
        """
        assert len(identifiers) == 2, "Two identifiers are required"
        assert len(inputs) == 2, "Two input columns (lattice and sites) are required"

        # Copy so the shared default lists are never aliased, and so that any
        # sequence type works as a pandas column selector.
        self.inputs = list(inputs)
        self.task_dict = task_dict
        self.identifiers = list(identifiers)

        self.radius = radius
        self.max_num_nbr = max_num_nbr

        assert os.path.exists(fea_path), f"{fea_path} does not exist!"
        self.ari = Featurizer.from_json(fea_path)
        self.elem_fea_dim = self.ari.embedding_size

        self.gdf = GaussianDistance(dmin=dmin, dmax=self.radius, step=step)
        self.nbr_fea_dim = self.gdf.embedding_size

        assert os.path.exists(data_path), f"{data_path} does not exist!"
        # NOTE make sure to use dense datasets, here do not use the default na
        # as they can clash with "NaN" which is a valid material
        self.df = pd.read_csv(
            data_path, keep_default_na=False, na_values=[], comment="#"
        )

        self.df["Structure_obj"] = self.df[self.inputs].apply(get_structure, axis=1)

        self._pre_check()

        # One output per regression target, one output per class otherwise.
        self.n_targets = []
        for target, task in self.task_dict.items():
            if task == "regression":
                self.n_targets.append(1)
            elif task == "classification":
                # class labels are taken to be 0..max, hence max + 1 classes
                n_classes = np.max(self.df[target].values) + 1
                self.n_targets.append(n_classes)

    def __len__(self):
        """Return the number of structures left after the pre-check."""
        return len(self.df)

    def _get_nbr_data(self, crystal):
        """get neighbours for every site

        Args:
            crystal ([Structure]): pymatgen structure to get neighbours for

        Returns:
            tuple: (self_idx, nbr_idx, nbr_dist) arrays with one entry per
            retained (site, neighbour) pair.
        """
        self_idx, nbr_idx, _, nbr_dist = crystal.get_neighbor_list(
            self.radius,
            numerical_tol=1e-8,
        )

        if self.max_num_nbr is not None:
            _self_idx, _nbr_idx, _nbr_dist = [], [], []

            # NOTE(review): groupby only merges consecutive entries, so this
            # assumes the neighbour list arrives grouped by centre site.
            for _, g in groupby(zip(self_idx, nbr_idx, nbr_dist), key=lambda x: x[0]):
                # keep only the closest `max_num_nbr` neighbours of each site
                s, n, d = zip(*sorted(g, key=lambda x: x[2]))
                _self_idx.extend(s[: self.max_num_nbr])
                _nbr_idx.extend(n[: self.max_num_nbr])
                _nbr_dist.extend(d[: self.max_num_nbr])

            self_idx = np.array(_self_idx)
            nbr_idx = np.array(_nbr_idx)
            nbr_dist = np.array(_nbr_dist)

        return self_idx, nbr_idx, nbr_dist

    def _pre_check(self):
        """Drop structures that contain isolated atoms.

        Structures where no site has a neighbour and structures where only
        some sites have neighbours are removed from ``self.df``; their
        identifiers are printed.
        """
        print("Precheck that all structures are valid")
        all_iso = []
        some_iso = []

        # Use the configured identifier column rather than a hard-coded name
        # so that custom `identifiers` work.
        id_col = self.identifiers[0]

        for cif_id, crystal in zip(self.df[id_col], self.df["Structure_obj"]):
            self_idx, nbr_idx, _ = self._get_nbr_data(crystal)

            if len(self_idx) == 0 or len(nbr_idx) == 0:
                all_iso.append(cif_id)
            elif set(self_idx) != set(range(crystal.num_sites)):
                some_iso.append(cif_id)

        if all_iso or some_iso:
            # drop the data points that do not give rise to dense crystal graphs
            self.df = self.df[~self.df[id_col].isin(all_iso + some_iso)]

        print(all_iso)
        print(some_iso)

    # Cache loaded structures. NOTE the cache is keyed on (self, idx) and is
    # unbounded, so it keeps the dataset alive for the life of the process.
    @functools.lru_cache(maxsize=None)
    def __getitem__(self, idx):
        """Return the graph inputs, targets and identifiers of one structure.

        Args:
            idx (int): positional index into the data frame

        Returns:
            tuple: ((atom_fea, nbr_dist, self_idx, nbr_idx), targets, comp, cif_id)
            where ``targets`` holds one tensor per entry of ``task_dict``.
        """
        # NOTE sites must be given in fractional coordinates
        df_idx = self.df.iloc[idx]
        crystal = df_idx["Structure_obj"]
        cif_id, comp = df_idx[self.identifiers]

        # atom features for disordered sites: occupancy-weighted sum of the
        # element embeddings present on each site
        site_atoms = [atom.species.as_dict() for atom in crystal]
        atom_fea = np.vstack(
            [
                np.sum([self.ari.get_fea(el) * amt for el, amt in site.items()], axis=0)
                for site in site_atoms
            ]
        )

        # neighbours
        self_idx, nbr_idx, nbr_dist = self._get_nbr_data(crystal)

        assert len(self_idx), f"All atoms in {cif_id} are isolated"
        assert len(nbr_idx), f"This should not be triggered but was for {cif_id}"
        assert set(self_idx) == set(
            range(crystal.num_sites)
        ), f"At least one atom in {cif_id} is isolated"

        nbr_dist = self.gdf.expand(nbr_dist)

        atom_fea = torch.Tensor(atom_fea)
        nbr_dist = torch.Tensor(nbr_dist)
        self_idx = torch.LongTensor(self_idx)
        nbr_idx = torch.LongTensor(nbr_idx)

        targets = []
        for target in self.task_dict:
            if self.task_dict[target] == "regression":
                targets.append(torch.Tensor([df_idx[target]]))
            elif self.task_dict[target] == "classification":
                targets.append(torch.LongTensor([df_idx[target]]))

        return ((atom_fea, nbr_dist, self_idx, nbr_idx), targets, comp, cif_id)
180 |
181 |
class GaussianDistance:
    """
    Gaussian basis expansion of interatomic distances.

    Unit: angstrom
    """

    def __init__(self, dmin, dmax, step, var=None):
        """
        Args:
            dmin (float): Centre of the first Gaussian
            dmax (float): Upper end of the range covered by the centres
            step (float): Spacing between neighbouring Gaussian centres
            var (float, optional): Width of each Gaussian; it enters the
                exponent squared. Defaults to step if not given
        """
        assert dmin < dmax
        assert dmax - dmin > step

        # Evenly spaced centres starting at dmin; the stop is pushed out by
        # one step so the range up to dmax is covered.
        centres = np.arange(dmin, dmax + step, step)
        self.filter = centres
        self.embedding_size = len(centres)

        self.var = step if var is None else var

    def expand(self, distances):
        """Project distances onto the Gaussian basis

        Args:
            distances (ArrayLike): Distances of any shape

        Returns:
            Array with one extra trailing dimension of length
            len(self.filter) holding the response of every Gaussian
        """
        dist_arr = np.array(distances)
        # broadcast each distance against every centre
        offsets = dist_arr[..., np.newaxis] - self.filter

        return np.exp(-(offsets ** 2) / self.var ** 2)
223 |
224 |
def collate_batch(dataset_list):
    """
    Merge individual crystal graphs into one batched graph.

    Parameters
    ----------

    dataset_list: list of tuples, one per data point, of the form
      ((atom_fea, nbr_dist, self_idx, nbr_idx), target, comp, cif_id)

      atom_fea: torch.Tensor with one row per atom
      nbr_dist: torch.Tensor with one row per (atom, neighbour) pair
      self_idx: torch.LongTensor, index of the central atom of each pair
      nbr_idx: torch.LongTensor, index of the neighbour atom of each pair
      target: sequence of tensors, one per task
      comp: composition identifier
      cif_id: material identifier

    Returns
    -------

    inputs: tuple (atom_fea, nbr_dist, self_idx, nbr_idx, cry_idx)
      The per-crystal tensors concatenated along the first dimension. Atom
      indices are shifted so that they address rows of the batched atom_fea;
      cry_idx gives, for every atom, the position of its crystal in the batch.
    targets: tuple of torch.Tensor
      One stacked tensor per task
    batch_comps: list
    batch_cif_ids: list
    """
    atom_feas, nbr_dists, self_idxs, nbr_idxs = [], [], [], []
    atom_to_crystal = []
    targets, comps, cif_ids = [], [], []

    # number of atoms already placed in the batch
    offset = 0

    for crystal_no, sample in enumerate(dataset_list):
        (atom_fea, nbr_dist, self_idx, nbr_idx), target, comp, cif_id = sample
        n_atoms = atom_fea.shape[0]

        atom_feas.append(atom_fea)
        nbr_dists.append(nbr_dist)
        # shift local atom indices into the batch-wide numbering
        self_idxs.append(self_idx + offset)
        nbr_idxs.append(nbr_idx + offset)

        atom_to_crystal.extend([crystal_no] * n_atoms)
        targets.append(target)
        comps.append(comp)
        cif_ids.append(cif_id)

        offset += n_atoms

    batched_inputs = (
        torch.cat(atom_feas, dim=0),
        torch.cat(nbr_dists, dim=0),
        torch.cat(self_idxs, dim=0),
        torch.cat(nbr_idxs, dim=0),
        torch.LongTensor(atom_to_crystal),
    )
    # regroup the per-sample targets task by task
    batched_targets = tuple(
        torch.stack(per_task, dim=0) for per_task in zip(*targets)
    )

    return (batched_inputs, batched_targets, comps, cif_ids)
295 |
296 |
def get_structure(cols):
    """Build a pymatgen structure from a (lattice, sites) pair of strings"""
    lattice_str, sites_str = cols
    lattice, species, positions = parse_cgcnn(lattice_str, sites_str)
    # NOTE getting primitive structure before constructing graph
    # significantly harms the performance of this model.
    return Structure(
        lattice=lattice,
        species=species,
        coords=positions,
        to_unit_cell=True,
    )
304 |
305 |
def parse_cgcnn(cell, sites):
    """Decode the string forms of a lattice and its sites.

    ``cell`` is the literal of a nested list of numbers; ``sites`` is the
    literal of a list of strings shaped like ``"<element> @ <x> <y> <z>"``.
    Returns the lattice as a float array, the element symbols as a list and
    the coordinates as a float array.
    """
    lattice = np.array(ast.literal_eval(cell), dtype=float)

    # each site splits into exactly an element part and a position part
    parsed = [site.split(" @ ") for site in ast.literal_eval(sites)]
    elems = [ele for ele, _ in parsed]
    coords = np.array([pos.split(" ") for _, pos in parsed], dtype=float)

    return lattice, elems, coords
318 |
--------------------------------------------------------------------------------
/roost/cgcnn/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from roost.core import BaseModelClass
6 | from roost.segments import MeanPooling, SimpleNetwork, SumPooling
7 |
8 |
class CrystalGraphConvNet(BaseModelClass):
    """
    Create a crystal graph convolutional neural network for predicting total
    material properties.

    This model is based on: https://github.com/txie-93/cgcnn [MIT License].
    Changes to the code were made to allow for the removal of zero-padding
    and to benefit from the BaseModelClass functionality. The architectural
    choices of the model remain unchanged.
    """

    def __init__(
        self,
        robust,
        n_targets,
        elem_emb_len,
        nbr_fea_len,
        elem_fea_len=64,
        n_graph=4,
        h_fea_len=128,
        n_trunk=1,
        n_hidden=1,
        **kwargs,
    ):
        """
        Initialize CrystalGraphConvNet.

        Parameters
        ----------

        robust: bool
            Whether each output head predicts twice as many values per target
        n_targets: list of int
            Number of outputs for each target
        elem_emb_len: int
            Number of atom features in the input.
        nbr_fea_len: int
            Number of bond features.
        elem_fea_len: int
            Number of hidden atom features in the convolutional layers
        n_graph: int
            Number of convolutional layers
        h_fea_len: int
            Number of hidden features after pooling
        n_trunk: int
            Number of h_fea_len-wide entries passed to the trunk network
        n_hidden: int
            Number of h_fea_len-wide entries passed to each output network
        """
        super().__init__(robust=robust, **kwargs)

        desc_dict = {
            "elem_emb_len": elem_emb_len,
            "nbr_fea_len": nbr_fea_len,
            "elem_fea_len": elem_fea_len,
            "n_graph": n_graph,
        }

        self.node_nn = DescriptorNetwork(**desc_dict)

        # Record every constructor hyperparameter. n_trunk was previously
        # left out, so anything that rebuilds the model from model_params
        # would silently fall back to the default trunk depth.
        self.model_params.update(
            {
                "robust": robust,
                "n_targets": n_targets,
                "h_fea_len": h_fea_len,
                "n_trunk": n_trunk,
                "n_hidden": n_hidden,
            }
        )

        self.model_params.update(desc_dict)

        self.pooling = MeanPooling()

        # define an output neural network
        if self.robust:
            # robust heads emit two values per output
            n_targets = [2 * n for n in n_targets]

        out_hidden = [h_fea_len] * n_hidden
        trunk_hidden = [h_fea_len] * n_trunk
        self.trunk_nn = SimpleNetwork(elem_fea_len, h_fea_len, trunk_hidden)

        self.output_nns = nn.ModuleList(
            SimpleNetwork(h_fea_len, n, out_hidden) for n in n_targets
        )

    def forward(self, atom_fea, nbr_fea, self_idx, nbr_idx, crystal_atom_idx):
        """
        Forward pass

        Parameters
        ----------

        atom_fea: torch.Tensor
            Atom features from atom type, one row per atom in the batch
        nbr_fea: torch.Tensor
            Bond features
        self_idx: torch.LongTensor
            Indices of the central atom of each bond
        nbr_idx: torch.LongTensor
            Indices of the neighbouring atom of each bond
        crystal_atom_idx: torch.LongTensor
            Mapping used by the pooling layer to assign atoms to crystals

        Returns
        -------

        generator yielding one prediction tensor per target, in the order of
        the output networks. The heads are evaluated lazily, when the
        generator is consumed.

        """
        atom_fea = self.node_nn(atom_fea, nbr_fea, self_idx, nbr_idx)

        crys_fea = self.pooling(atom_fea, crystal_atom_idx)

        # NOTE required to match the reference implementation
        crys_fea = F.softplus(crys_fea)

        crys_fea = F.relu(self.trunk_nn(crys_fea))

        # apply neural network to map from learned features to target
        return (output_nn(crys_fea) for output_nn in self.output_nns)
128 |
129 |
class DescriptorNetwork(nn.Module):
    """
    Message passing part of the CrystalGraphConvNet model: an input embedding
    followed by a stack of graph convolutions.
    """

    def __init__(
        self, elem_emb_len, nbr_fea_len, elem_fea_len=64, n_graph=4,
    ):
        """
        Parameters
        ----------

        elem_emb_len: int
            Number of atom features in the input.
        nbr_fea_len: int
            Number of bond features.
        elem_fea_len: int
            Number of hidden atom features in the convolutional layers
        n_graph: int
            Number of convolutional layers
        """
        super().__init__()

        self.embedding = nn.Linear(elem_emb_len, elem_fea_len)

        conv_layers = []
        for _ in range(n_graph):
            conv_layers.append(
                CGCNNConv(elem_fea_len=elem_fea_len, nbr_fea_len=nbr_fea_len)
            )
        self.convs = nn.ModuleList(conv_layers)

    def forward(self, atom_fea, nbr_fea, self_fea_idx, nbr_fea_idx):
        """
        Forward pass

        Parameters
        ----------

        atom_fea: torch.Tensor
            Atom features from atom type
        nbr_fea: torch.Tensor
            Bond features
        self_fea_idx: torch.LongTensor
            Indices of the central atom of each bond
        nbr_fea_idx: torch.LongTensor
            Indices of the neighbouring atom of each bond

        Returns
        -------

        torch.Tensor
            Atom hidden features after the last convolution

        """
        hidden = self.embedding(atom_fea)

        # every layer sees the same bond features and index tensors
        for layer in self.convs:
            hidden = layer(hidden, nbr_fea, self_fea_idx, nbr_fea_idx)

        return hidden
185 |
186 |
class CGCNNConv(nn.Module):
    """
    Gated graph convolution that updates atom features from their bonds
    """

    def __init__(self, elem_fea_len, nbr_fea_len):
        """
        Initialize CGCNNConv.

        Parameters
        ----------

        elem_fea_len: int
            Number of atom hidden features.
        nbr_fea_len: int
            Number of bond features.
        """
        super().__init__()
        self.elem_fea_len = elem_fea_len
        self.nbr_fea_len = nbr_fea_len

        # one linear map yields both the gate ("filter") and the "core" half
        edge_in_len = 2 * self.elem_fea_len + self.nbr_fea_len
        edge_out_len = 2 * self.elem_fea_len
        self.fc_full = nn.Linear(edge_in_len, edge_out_len)

        self.sigmoid = nn.Sigmoid()
        self.softplus1 = nn.Softplus()
        self.bn1 = nn.BatchNorm1d(edge_out_len)
        self.bn2 = nn.BatchNorm1d(self.elem_fea_len)
        self.softplus2 = nn.Softplus()
        self.pooling = SumPooling()

    def forward(self, atom_in_fea, nbr_fea, self_fea_idx, nbr_fea_idx):
        """
        Forward pass

        Parameters
        ----------

        atom_in_fea: torch.Tensor shape (N, elem_fea_len)
            Atom hidden features before convolution
        nbr_fea: torch.Tensor
            Bond features, one row per bond (concatenated with the atom
            features along dim=1)
        self_fea_idx: torch.LongTensor
            Index of the central atom of each bond
        nbr_fea_idx: torch.LongTensor
            Index of the neighbouring atom of each bond

        Returns
        -------

        atom_out_fea: torch.Tensor shape (N, elem_fea_len)
            Atom hidden features after convolution

        """
        # gather the features of both atoms taking part in each bond
        centre_fea = atom_in_fea[self_fea_idx, :]
        neighbour_fea = atom_in_fea[nbr_fea_idx, :]
        edge_fea = torch.cat([centre_fea, neighbour_fea, nbr_fea], dim=1)

        edge_fea = self.bn1(self.fc_full(edge_fea))

        # gate the core message elementwise with the sigmoid filter
        gate, core = edge_fea.chunk(2, dim=1)
        messages = self.sigmoid(gate) * self.softplus1(core)

        # pool the messages with the central-atom index, then add them
        # residually to the incoming atom features
        pooled = self.bn2(self.pooling(messages, self_fea_idx))

        return self.softplus2(atom_in_fea + pooled)
262 |
--------------------------------------------------------------------------------
/roost/cgcnn/utils.py:
--------------------------------------------------------------------------------
def get_cgcnn_input(struct):
    """Return the cgcnn inputs (lattice and site strings) for a structure.

    TODO update wren to use a more standard convention

    Args:
        struct (Structure): input structure to get inputs for

    Returns:
        cgcnn inputs as a lattice matrix (nested lists) and a list of sites of
        the form "{el} @ {x} {y} {z}" using the fractional coordinates
    """
    symbols = [atom.specie.symbol for atom in struct]
    cell = struct.lattice.matrix.tolist()

    sites = []
    for symbol, frac in zip(symbols, struct.frac_coords):
        position = " ".join(str(component) for component in frac)
        sites.append(" @ ".join((symbol, position)))

    return cell, sites
18 |
--------------------------------------------------------------------------------
/roost/pretrain/dist_data.py:
--------------------------------------------------------------------------------
1 | import ast
2 | import os
3 | from itertools import groupby
4 |
5 | import numpy as np
6 | import pandas as pd
7 | import torch
8 | from pymatgen.core.structure import Structure
9 | from torch.utils.data import Dataset
10 |
11 | from roost.core import Featurizer
12 |
13 |
class CrystalGraphData(Dataset):
    """Dataset of crystal graphs with randomly masked bond distances.

    Each item is a neighbourhood graph in which a random subset of the bond
    distances is selected; the true distances of the selected bonds are the
    labels of the "dist" task.
    """

    def __init__(
        self,
        data_path,
        fea_path,
        task_dict,
        inputs=["lattice", "sites"],
        identifiers=["material_id", "composition"],
        radius=5,
        max_num_nbr=12,
        dmin=0,
        step=0.2,
        p_mask=0.15,
        p_zero=0.8,
    ):
        """CrystalGraphData returns neighbourhood graphs

        Args:
            data_path (str): The path to the dataset
            fea_path (str): The path to the element embedding
            task_dict ({target: task}): task dict for multi-task learning,
                supported tasks are "dist" and "regression"
            inputs (list, optional): df columns for lattice and sites.
                Defaults to ["lattice", "sites"].
            identifiers (list, optional): df columns for distinguishing data points.
                Defaults to ["material_id", "composition"].
            radius (int, optional): cut-off radius for neighbourhood. Defaults to 5.
            max_num_nbr (int, optional): maximum number of neighbours to consider.
                Defaults to 12.
            dmin (int, optional): minimum distance in gaussian basis. Defaults to 0.
            step (float, optional): increment size of gaussian basis. Defaults to 0.2.
            p_mask (float, optional): fraction of the bonds selected for masking in
                each sample (at least one bond is always selected). Defaults to 0.15.
            p_zero (float, optional): probability that a selected bond has its
                distance set to zero before the gaussian expansion. Defaults to 0.8.

        Raises:
            NotImplementedError: if task_dict contains an unsupported task
        """
        assert len(identifiers) == 2, "Two identifiers are required"
        assert len(inputs) == 2, "Two input columns are required"

        self.inputs = inputs
        self.task_dict = task_dict
        self.identifiers = identifiers
        self.radius = radius
        self.max_num_nbr = max_num_nbr
        self.p_mask = p_mask
        self.p_zero = p_zero

        # names of the df columns that hold the pre-computed graph
        self.graph = ["self", "nbr", "dist"]

        assert os.path.exists(fea_path), f"{fea_path} does not exist!"
        self.ari = Featurizer.from_json(fea_path)
        self.elem_fea_dim = self.ari.embedding_size

        self.gdf = GaussianDistance(dmin=dmin, dmax=self.radius, step=step)
        self.nbr_fea_dim = self.gdf.embedding_size

        assert os.path.exists(data_path), f"{data_path} does not exist!"

        # NOTE make sure to use dense datasets, here do not use the default na
        # as they can clash with "NaN" which is a valid material
        # NOTE(review): only the first 1000 rows are kept; this looks like a
        # debugging limit - confirm before training on a full dataset
        self.df = pd.read_csv(
            data_path, keep_default_na=False, na_values=[], comment="#"
        )[:1000]

        self.df["Structure_obj"] = self.df[inputs].apply(get_structure, axis=1)

        self._pre_check()

        self.n_targets = []
        for target in self.task_dict:
            if self.task_dict[target] == "dist":
                self.n_targets.append(1)
            elif self.task_dict[target] == "regression":
                self.n_targets.append(1)
            else:
                raise NotImplementedError("bad user")

    def __len__(self):
        """Return the number of structures kept in the dataframe."""
        return len(self.df)

    # @functools.lru_cache(maxsize=None)  # Cache loaded structures
    def _get_nbr_data(self, crystal):
        """get neighbours for every site

        Args:
            crystal ([Structure]): pymatgen structure to get neighbours for

        Returns:
            self_idx, nbr_idx, nbr_dist: arrays with one entry per bond. If
            max_num_nbr is set, only the max_num_nbr shortest bonds of each
            central atom are kept.
        """
        self_idx, nbr_idx, _, nbr_dist = crystal.get_neighbor_list(
            self.radius,
            numerical_tol=1e-8,
        )

        if self.max_num_nbr is not None:
            _self_idx, _nbr_idx, _nbr_dist = [], [], []

            # groupby relies on the bonds of one central atom being contiguous
            for _, g in groupby(zip(self_idx, nbr_idx, nbr_dist), key=lambda x: x[0]):
                # sort the bonds of this atom by distance, keep the shortest
                s, n, d = zip(*sorted(g, key=lambda x: x[2]))
                _self_idx.extend(s[: self.max_num_nbr])
                _nbr_idx.extend(n[: self.max_num_nbr])
                _nbr_dist.extend(d[: self.max_num_nbr])

            self_idx = np.array(_self_idx)
            nbr_idx = np.array(_nbr_idx)
            nbr_dist = np.array(_nbr_dist)

        return self_idx, nbr_idx, nbr_dist

    def _pre_check(self):
        """Build the graph columns and drop structures with isolated atoms.

        Structures in which an atom is bonded to its own index are expanded
        in place to a 2x2x2 supercell until no such bond remains. Entries
        with no bonds at all, or in which some site has no bond, are removed
        from the dataframe and their ids are printed; no error is raised.
        """
        print("checking all structures valid")
        all_iso = []
        some_iso = []

        # initialise empty columns of objects to insert the lists
        self.df[self.graph[0]] = [[] for _ in range(len(self.df))]
        self.df[self.graph[1]] = [[] for _ in range(len(self.df))]
        self.df[self.graph[2]] = [[] for _ in range(len(self.df))]

        for index, row in self.df.iterrows():
            image = 0

            cif_id = row["material_id"]
            crystal = row["Structure_obj"]

            while image == 0:
                self_idx, nbr_idx, nbr_dist = self._get_nbr_data(crystal)

                if np.any(self_idx == nbr_idx):
                    # TODO only double along the shortest dimension
                    crystal.make_supercell([2, 2, 2])
                else:
                    image = 1

            if len(self_idx) == 0:
                all_iso.append(cif_id)
            elif len(nbr_idx) == 0:
                all_iso.append(cif_id)
            elif set(self_idx) != set(range(crystal.num_sites)):
                some_iso.append(cif_id)

            # distances are stored raw: masking and the gaussian expansion
            # are applied per sample in __getitem__
            self_idx = torch.LongTensor(self_idx)
            nbr_idx = torch.LongTensor(nbr_idx)

            self.df.at[index, self.graph[0]] = self_idx
            self.df.at[index, self.graph[1]] = nbr_idx
            self.df.at[index, self.graph[2]] = nbr_dist

        # TODO have the option for the pre-check to delete non-valid entries from the df?

        if (len(all_iso) > 0) or (len(some_iso) > 0):
            self.df = self.df.drop(self.df[self.df["material_id"].isin(all_iso)].index)
            self.df = self.df.drop(self.df[self.df["material_id"].isin(some_iso)].index)

            print(all_iso)
            print(some_iso)

    # NOTE do not cache the pre-training structures as we want to see new sets of
    # masked structures each epoch as this effectively expands the training set
    def __getitem__(self, idx):
        """Return one graph with a fresh random set of masked bond distances.

        Args:
            idx (int): positional index of the structure in the dataframe

        Returns:
            ((atom_fea, nbr_dist, self_idx, nbr_idx, mask_ids), targets, comp, cif_id)
            where mask_ids index the bonds (not the atoms) of this crystal and
            targets holds one tensor per task in task_dict.
        """
        # NOTE sites must be given in fractional coordinates
        df_idx = self.df.iloc[idx]
        crystal = df_idx["Structure_obj"]
        cif_id, comp = df_idx[self.identifiers]
        self_idx, nbr_idx, nbr_dist = df_idx[self.graph]

        # The array held in the dataframe is shared between calls. Mask a
        # private copy so the stored distances stay intact; otherwise masked
        # bonds remain zero forever and later labels would be zeros.
        nbr_dist = np.array(nbr_dist, copy=True)

        # atom features
        # TODO can this be vectorised with numpy?

        # handle disordered structures (multiple fractional elements per site)
        site_atoms = [atom.species.as_dict() for atom in crystal]
        atom_fea = np.vstack(
            [
                np.sum([self.ari.get_fea(el) * amt for el, amt in site.items()], axis=0)
                for site in site_atoms
            ]
        )

        # select at least one bond to mask (sampled with replacement)
        mask_ids = np.sort(
            np.random.choice(
                np.arange(len(self_idx), dtype=int),
                max(1, int(self.p_mask * len(self_idx))),
            )
        )

        # labels are the true distances of the selected bonds, as a column
        mask_labels = np.atleast_2d(nbr_dist[mask_ids]).T

        # TODO currently 20% no mask -> 10% no mask, 10% random mask
        mask_filter = np.random.rand(len(mask_ids))
        nbr_dist[mask_ids[np.where(mask_filter < self.p_zero)]] = 0

        nbr_dist = self.gdf.expand(nbr_dist)

        atom_fea = torch.Tensor(atom_fea)
        mask_ids = torch.LongTensor(mask_ids)
        nbr_dist = torch.Tensor(nbr_dist)

        targets = []
        for target in self.task_dict:
            if self.task_dict[target] == "dist":
                targets.append(torch.Tensor(mask_labels))
            elif self.task_dict[target] == "regression":
                # NOTE the regression label is always read from "e_form"
                targets.append(torch.Tensor([[df_idx["e_form"]]]))
            else:
                raise NotImplementedError("bad user")

        return (
            (atom_fea, nbr_dist, self_idx, nbr_idx, mask_ids),
            targets,
            comp,
            cif_id,
        )
240 |
241 |
class GaussianDistance:
    """
    Gaussian basis expansion of interatomic distances.

    Unit: angstrom
    """

    def __init__(self, dmin, dmax, step, var=None):
        """
        Args:
            dmin (float): Minimum interatomic distance
            dmax (float): Maximum interatomic distance
            step (float): Spacing between the centres of the Gaussians
            var (float, optional): Width of the Gaussians; it enters the
                exponent squared. Defaults to step if not given
        """
        assert dmin < dmax
        assert dmax - dmin > step

        # centres of the Gaussians, from dmin up to dmax + step (exclusive)
        centres = np.arange(dmin, dmax + step, step)
        self.filter = centres
        self.embedding_size = len(centres)

        self.var = step if var is None else var

    def expand(self, distances):
        """Apply Gaussian distance filter to a numpy distance array

        Args:
            distances (ArrayLike): A distance matrix of any shape

        Returns:
            Expanded distance matrix with the last dimension of length
            len(self.filter)
        """
        # a trailing axis broadcasts every distance against all centres
        offsets = np.array(distances)[..., np.newaxis] - self.filter

        return np.exp(-(offsets ** 2) / self.var ** 2)
283 |
284 |
def collate_batch(dataset_list):
    """
    Collate a list of samples into one batched graph for the masked
    distance task.

    Parameters
    ----------

    dataset_list: list of tuples for each data point.
        ((atom_fea, nbr_dist, self_idx, nbr_idx, mask_idx), target, comp, cif_id)

        atom_fea: torch.Tensor, one row per atom
        nbr_dist: torch.Tensor, one row per bond
        self_idx: torch.LongTensor, central atom of each bond
        nbr_idx: torch.LongTensor, neighbouring atom of each bond
        mask_idx: torch.LongTensor, indices of the masked bonds
        target: list of torch.Tensor, one per task
        comp: composition identifier
        cif_id: str or int

    Returns
    -------

    inputs: tuple (atom_fea, nbr_dist, self_idx, nbr_idx, mask_idx, cry_idx)
        Concatenated tensors. Atom indices are shifted by the number of
        atoms in the preceding crystals, mask indices by the number of
        bonds. cry_idx maps every atom to the position of its crystal in
        the batch.
    targets: tuple of torch.Tensor
        One tensor per task, concatenated along dim 0
    batch_comps: list
    batch_cif_ids: list
    """
    atom_feas, bond_feas = [], []
    self_idxs, nbr_idxs, mask_idxs = [], [], []
    atom_to_crystal = []
    all_targets, all_comps, all_ids = [], [], []

    atom_offset = 0
    bond_offset = 0

    for crystal_no, (inputs, target, comp, cif_id) in enumerate(dataset_list):
        atom_fea, nbr_dist, self_idx, nbr_idx, mask_idx = inputs
        n_atoms = atom_fea.shape[0]  # number of atoms for this crystal
        n_bonds = self_idx.shape[0]  # number of bonds for this crystal

        atom_feas.append(atom_fea)
        bond_feas.append(nbr_dist)

        # atom indices become global by shifting with the atoms seen so far
        self_idxs.append(self_idx + atom_offset)
        nbr_idxs.append(nbr_idx + atom_offset)

        # masks refer to bonds, so they shift with the bonds seen so far
        mask_idxs.append(mask_idx + bond_offset)

        atom_to_crystal += [crystal_no] * n_atoms

        all_targets.append(target)
        all_comps.append(comp)
        all_ids.append(cif_id)

        atom_offset += n_atoms
        bond_offset += n_bonds

    batch_inputs = (
        torch.cat(atom_feas, dim=0),
        torch.cat(bond_feas, dim=0),
        torch.cat(self_idxs, dim=0),
        torch.cat(nbr_idxs, dim=0),
        torch.cat(mask_idxs, dim=0),
        torch.LongTensor(atom_to_crystal),
    )

    # regroup the per-sample target lists into one tensor per task
    batch_targets = tuple(
        torch.cat(task_targets, dim=0) for task_targets in zip(*all_targets)
    )

    return batch_inputs, batch_targets, all_comps, all_ids
366 |
367 |
def get_structure(cols):
    """Build a pymatgen structure from the lattice and sites columns of a row"""
    lattice_str, sites_str = cols
    lattice, species, frac_coords = parse_cgcnn(lattice_str, sites_str)

    # NOTE getting primitive structure before constructing graph
    # significantly harms the performance of this model.
    structure = Structure(
        lattice=lattice, species=species, coords=frac_coords, to_unit_cell=True
    )

    # Structures with only a few sites are expanded in place so that
    # masking ~15% of the sites does not have a disproportionate impact
    # on small unit cells.
    if structure.num_sites < 7:
        structure.make_supercell([2, 2, 2])

    return structure
384 |
385 |
def parse_cgcnn(cell, sites):
    """Parse the str representations of the lattice and sites.

    Each site is a string of the form "{el} @ {x} {y} {z}". Returns the
    lattice as a float array, the list of element strings and the
    coordinates as a float array.
    """
    lattice = np.array(ast.literal_eval(cell), dtype=float)

    # each site splits into exactly an element part and a position part
    parts = [site.split(" @ ") for site in ast.literal_eval(sites)]

    elems = [ele for ele, _ in parts]
    coords = np.array([pos.split(" ") for _, pos in parts], dtype=float)

    return lattice, elems, coords
398 |
--------------------------------------------------------------------------------
/roost/pretrain/dist_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from roost.cgcnn.model import DescriptorNetwork
5 | from roost.core import BaseModelClass
6 | from roost.segments import MeanPooling
7 |
8 |
class CrystalGraphPreNet(BaseModelClass):
    """
    Crystal graph convolutional network with two pre-training heads: one
    prediction per masked bond and one prediction per crystal.

    This model is based on: https://github.com/txie-93/cgcnn [MIT License].
    Changes to the code were made to allow for the removal of zero-padding
    and to benefit from the BaseModelClass functionality.
    """

    def __init__(
        self,
        robust,
        n_targets,
        elem_emb_len,
        nbr_fea_len,
        elem_fea_len=64,
        n_graph=4,
        **kwargs,
    ):
        """
        Initialize CrystalGraphPreNet.

        Parameters
        ----------

        robust: bool
            Forwarded to BaseModelClass and recorded in model_params.
        n_targets: list
            Recorded in model_params; not used to size the output heads.
        elem_emb_len: int
            Number of atom features in the input.
        nbr_fea_len: int
            Number of bond features.
        elem_fea_len: int
            Number of hidden atom features in the convolutional layers
        n_graph: int
            Number of convolutional layers
        **kwargs:
            Forwarded to BaseModelClass.
        """
        super().__init__(robust=robust, **kwargs)

        desc_dict = dict(
            elem_emb_len=elem_emb_len,
            nbr_fea_len=nbr_fea_len,
            elem_fea_len=elem_fea_len,
            n_graph=n_graph,
        )

        self.node_nn = DescriptorNetwork(**desc_dict)

        self.model_params.update({"robust": robust, "n_targets": n_targets})

        # bond head: acts on the concatenated features of both bonded atoms
        self.node_linear = nn.Linear(2 * elem_fea_len, 1)

        # crystal head: acts on the mean-pooled atom features
        self.global_pool = MeanPooling()
        self.global_linear = nn.Linear(elem_fea_len, 1)

        self.model_params.update(desc_dict)

    def forward(self, atom_fea, nbr_fea, self_idx, nbr_idx, mask_idx, cry_idx):
        """
        Forward pass

        Parameters
        ----------

        atom_fea: torch.Tensor
            Atom features from atom type, one row per atom in the batch
        nbr_fea: torch.Tensor
            Bond features, one row per bond
        self_idx: torch.LongTensor
            Index of the central atom of each bond
        nbr_idx: torch.LongTensor
            Index of the neighbouring atom of each bond
        mask_idx: torch.LongTensor
            Indices of the masked bonds (positions in self_idx / nbr_idx)
        cry_idx: torch.LongTensor
            Index of the crystal each atom belongs to

        Returns
        -------

        [nodes, glob]: list of torch.Tensor
            One prediction per masked bond and one prediction per crystal

        """
        node_fea = self.node_nn(atom_fea, nbr_fea, self_idx, nbr_idx)

        # features of the two atoms forming each masked bond
        masked_self = node_fea[self_idx[mask_idx], :]
        masked_nbr = node_fea[nbr_idx[mask_idx], :]
        pair_fea = torch.cat([masked_self, masked_nbr], dim=1)

        bond_pred = self.node_linear(pair_fea)

        crystal_pred = self.global_linear(self.global_pool(node_fea, cry_idx))

        return [bond_pred, crystal_pred]
113 |
--------------------------------------------------------------------------------
/roost/pretrain/ele_data.py:
--------------------------------------------------------------------------------
1 | import ast
2 | import os
3 | from itertools import groupby
4 |
5 | import numpy as np
6 | import pandas as pd
7 | import torch
8 | from pymatgen.core.structure import Structure
9 | from torch.utils.data import Dataset
10 |
11 | from roost.core import Featurizer
12 |
13 |
class CrystalGraphData(Dataset):
    """Dataset of crystal graphs with randomly masked atom features.

    Each item is a neighbourhood graph in which a random subset of the
    sites is selected; the one-hot element encodings of the selected sites
    are the labels of the "mask" task.
    """

    def __init__(
        self,
        data_path,
        fea_path,
        task_dict,
        inputs=["lattice", "sites"],
        identifiers=["material_id", "composition"],
        radius=5,
        max_num_nbr=12,
        dmin=0,
        step=0.2,
        p_mask=0.15,
        p_zero=0.8,
    ):
        """CrystalGraphData returns neighbourhood graphs

        Args:
            data_path (str): The path to the dataset
            fea_path (str): The path to the element embedding
            task_dict ({target: task}): task dict for multi-task learning,
                supported tasks are "mask" and "regression"
            inputs (list, optional): df columns for lattice and sites.
                Defaults to ["lattice", "sites"].
            identifiers (list, optional): df columns for distinguishing data points.
                Defaults to ["material_id", "composition"].
            radius (int, optional): cut-off radius for neighbourhood. Defaults to 5.
            max_num_nbr (int, optional): maximum number of neighbours to consider.
                Defaults to 12.
            dmin (int, optional): minimum distance in gaussian basis. Defaults to 0.
            step (float, optional): increment size of gaussian basis. Defaults to 0.2.
            p_mask (float, optional): fraction of the sites selected for masking in
                each sample (at least one site is always selected). Defaults to 0.15.
            p_zero (float, optional): probability that a selected site has its
                features set to zero. Defaults to 0.8.

        Raises:
            NotImplementedError: if task_dict contains an unsupported task
        """
        assert len(identifiers) == 2, "Two identifiers are required"
        assert len(inputs) == 2, "One input column required are required"

        self.inputs = inputs
        self.task_dict = task_dict
        self.identifiers = identifiers
        self.radius = radius
        self.max_num_nbr = max_num_nbr
        self.p_mask = p_mask
        self.p_zero = p_zero

        # names of the df columns that hold the pre-computed graph
        self.graph = ["self", "nbr", "dist"]

        assert os.path.exists(fea_path), f"{fea_path} does not exist!"
        self.ari = Featurizer.from_json(fea_path)
        # NOTE the one-hot embedding path is relative to the working directory
        self.ohe = Featurizer.from_json("data/el-embeddings/onehot-embedding.json")
        self.elem_fea_dim = self.ari.embedding_size

        self.gdf = GaussianDistance(dmin=dmin, dmax=self.radius, step=step)
        self.nbr_fea_dim = self.gdf.embedding_size

        assert os.path.exists(data_path), f"{data_path} does not exist!"

        # NOTE make sure to use dense datasets, here do not use the default na
        # as they can clash with "NaN" which is a valid material
        # NOTE(review): only the first 1000 rows are kept - confirm intent
        self.df = pd.read_csv(
            data_path, keep_default_na=False, na_values=[], comment="#"
        )[:1000]

        self.df["Structure_obj"] = self.df[inputs].apply(get_structure, axis=1)

        self._pre_check()

        # output size of every task, in task_dict order
        self.n_targets = []
        for task in self.task_dict.values():
            if task == "mask":
                self.n_targets.append(self.ohe.embedding_size)
            elif task == "regression":
                self.n_targets.append(1)
            else:
                raise NotImplementedError("bad user")

    def __len__(self):
        """Return the number of structures kept in the dataframe."""
        return len(self.df)

    # @functools.lru_cache(maxsize=None)  # Cache loaded structures
    def _get_nbr_data(self, crystal):
        """get neighbours for every site

        Args:
            crystal ([Structure]): pymatgen structure to get neighbours for

        Returns:
            self_idx, nbr_idx, nbr_dist: arrays with one entry per bond. If
            max_num_nbr is set, only the max_num_nbr shortest bonds of each
            central atom are kept.
        """
        self_idx, nbr_idx, _, nbr_dist = crystal.get_neighbor_list(
            self.radius,
            numerical_tol=1e-8,
        )

        if self.max_num_nbr is None:
            return self_idx, nbr_idx, nbr_dist

        kept_bonds = []
        # groupby relies on the bonds of one central atom being contiguous
        for _, bonds in groupby(zip(self_idx, nbr_idx, nbr_dist), key=lambda b: b[0]):
            shortest_first = sorted(bonds, key=lambda b: b[2])
            kept_bonds.extend(shortest_first[: self.max_num_nbr])

        self_idx = np.array([bond[0] for bond in kept_bonds])
        nbr_idx = np.array([bond[1] for bond in kept_bonds])
        nbr_dist = np.array([bond[2] for bond in kept_bonds])

        return self_idx, nbr_idx, nbr_dist

    def _pre_check(self):
        """Build the graph columns and drop structures with isolated atoms.

        Structures in which an atom is bonded to its own index are expanded
        in place to a 2x2x2 supercell until no such bond remains. Entries
        with no bonds at all, or in which some site has no bond, are removed
        from the dataframe and their ids are printed; no error is raised.
        """
        print("checking all structures valid")
        all_iso = []
        some_iso = []

        # initialise empty columns of objects to insert the lists
        for column in self.graph:
            self.df[column] = [[] for _ in range(len(self.df))]

        for index, row in self.df.iterrows():
            cif_id = row["material_id"]
            crystal = row["Structure_obj"]

            while True:
                self_idx, nbr_idx, nbr_dist = self._get_nbr_data(crystal)

                if not np.any(self_idx == nbr_idx):
                    break

                # TODO only double along the shortest dimension
                crystal.make_supercell([2, 2, 2])

            if len(self_idx) == 0 or len(nbr_idx) == 0:
                all_iso.append(cif_id)
            elif set(self_idx) != set(range(crystal.num_sites)):
                some_iso.append(cif_id)

            # bond features are expanded once here and reused by every sample
            graph_values = (
                torch.LongTensor(self_idx),
                torch.LongTensor(nbr_idx),
                torch.Tensor(self.gdf.expand(nbr_dist)),
            )
            for column, value in zip(self.graph, graph_values):
                self.df.at[index, column] = value

        # TODO have the option for the pre-check to delete non-valid entries from the df?

        if all_iso or some_iso:
            for bad_ids in (all_iso, some_iso):
                bad_rows = self.df[self.df["material_id"].isin(bad_ids)]
                self.df = self.df.drop(bad_rows.index)

            print(all_iso)
            print(some_iso)

    # NOTE do not cache the pre-training structures as we want to see new sets of
    # masked structures each epoch as this effectively expands the training set
    def __getitem__(self, idx):
        """Return one graph with a fresh random set of masked sites.

        Args:
            idx (int): positional index of the structure in the dataframe

        Returns:
            ((atom_fea, nbr_dist, self_idx, nbr_idx, mask_ids), targets, comp, cif_id)
            where mask_ids index the sites of this crystal and targets holds
            one tensor per task in task_dict.
        """
        # NOTE sites must be given in fractional coordinates
        row = self.df.iloc[idx]
        crystal = row["Structure_obj"]
        cif_id, comp = row[self.identifiers]
        self_idx, nbr_idx, nbr_dist = row[self.graph]

        # handle disordered structures (multiple fractional elements per site)
        # TODO can this be vectorised with numpy?
        site_atoms = [atom.species.as_dict() for atom in crystal]
        site_feas = []
        for site in site_atoms:
            weighted = [self.ari.get_fea(el) * amt for el, amt in site.items()]
            site_feas.append(np.sum(weighted, axis=0))
        atom_fea = np.vstack(site_feas)

        # select at least one site in the crystal to mask
        n_masked = max(1, int(self.p_mask * crystal.num_sites))
        candidates = np.arange(crystal.num_sites, dtype=int)
        mask_ids = np.sort(np.random.choice(candidates, n_masked))

        # get the mask labels via use of a OHE of elements to handle disordered structures
        label_rows = []
        for site_no in mask_ids:
            weighted = [
                self.ohe.get_fea(el) * amt for el, amt in site_atoms[site_no].items()
            ]
            label_rows.append(np.sum(weighted, axis=0))
        mask_labels = np.vstack(label_rows)

        # TODO currently 20% no mask -> 10% no mask, 10% random mask
        mask_filter = np.random.rand(len(mask_ids))
        zeroed_sites = mask_ids[mask_filter < self.p_zero]
        atom_fea[zeroed_sites, :] = 0

        atom_fea = torch.Tensor(atom_fea)
        mask_ids = torch.LongTensor(mask_ids)

        targets = []
        for task in self.task_dict.values():
            if task == "mask":
                targets.append(torch.Tensor(mask_labels))
            elif task == "regression":
                # NOTE the regression label is always read from "e_form"
                targets.append(torch.Tensor([[row["e_form"]]]))
            else:
                raise NotImplementedError("bad user")

        return (
            (atom_fea, nbr_dist, self_idx, nbr_idx, mask_ids),
            targets,
            comp,
            cif_id,
        )
246 |
247 |
class GaussianDistance:
    """
    Gaussian basis expansion of interatomic distances.

    Unit: angstrom
    """

    def __init__(self, dmin, dmax, step, var=None):
        """
        Args:
            dmin (float): Minimum interatomic distance
            dmax (float): Maximum interatomic distance
            step (float): Spacing between the centres of the Gaussians
            var (float, optional): Width of the Gaussians; it enters the
                exponent squared. Defaults to step if not given
        """
        assert dmin < dmax
        assert dmax - dmin > step

        # centres of the Gaussians, from dmin up to dmax + step (exclusive)
        self.filter = np.arange(dmin, dmax + step, step)
        self.embedding_size = len(self.filter)

        if var is not None:
            self.var = var
        else:
            self.var = step

    def expand(self, distances):
        """Apply Gaussian distance filter to a numpy distance array

        Args:
            distances (ArrayLike): A distance matrix of any shape

        Returns:
            Expanded distance matrix with the last dimension of length
            len(self.filter)
        """
        as_array = np.array(distances)

        # a trailing axis broadcasts every distance against all centres
        deltas = as_array[..., np.newaxis] - self.filter

        return np.exp(-(deltas ** 2) / self.var ** 2)
289 |
290 |
def collate_batch(dataset_list):
    """
    Collate a list of samples into one batched graph for the masked
    element task.

    Parameters
    ----------

    dataset_list: list of tuples for each data point.
        ((atom_fea, nbr_dist, self_idx, nbr_idx, mask_idx), target, comp, cif_id)

        atom_fea: torch.Tensor, one row per atom
        nbr_dist: torch.Tensor, one row per bond
        self_idx: torch.LongTensor, central atom of each bond
        nbr_idx: torch.LongTensor, neighbouring atom of each bond
        mask_idx: torch.LongTensor, indices of the masked atoms
        target: list of torch.Tensor, one per task
        comp: composition identifier
        cif_id: str or int

    Returns
    -------

    inputs: tuple (atom_fea, nbr_dist, self_idx, nbr_idx, mask_idx, cry_idx)
        Concatenated tensors. All index tensors are shifted by the number
        of atoms in the preceding crystals. cry_idx maps every atom to the
        position of its crystal in the batch.
    targets: tuple of torch.Tensor
        One tensor per task, concatenated along dim 0
    batch_comps: list
    batch_cif_ids: list
    """
    atom_feas, bond_feas = [], []
    self_idxs, nbr_idxs, mask_idxs = [], [], []
    atom_to_crystal = []
    all_targets, all_comps, all_ids = [], [], []

    atom_offset = 0

    for crystal_no, (inputs, target, comp, cif_id) in enumerate(dataset_list):
        atom_fea, nbr_dist, self_idx, nbr_idx, mask_idx = inputs
        n_atoms = atom_fea.shape[0]  # number of atoms for this crystal

        atom_feas.append(atom_fea)
        bond_feas.append(nbr_dist)

        # every index refers to atoms, so all shift with the atoms seen so far
        self_idxs.append(self_idx + atom_offset)
        nbr_idxs.append(nbr_idx + atom_offset)
        mask_idxs.append(mask_idx + atom_offset)

        atom_to_crystal += [crystal_no] * n_atoms

        all_targets.append(target)
        all_comps.append(comp)
        all_ids.append(cif_id)

        atom_offset += n_atoms

    batch_inputs = (
        torch.cat(atom_feas, dim=0),
        torch.cat(bond_feas, dim=0),
        torch.cat(self_idxs, dim=0),
        torch.cat(nbr_idxs, dim=0),
        torch.cat(mask_idxs, dim=0),
        torch.LongTensor(atom_to_crystal),
    )

    # regroup the per-sample target lists into one tensor per task
    batch_targets = tuple(
        torch.cat(task_targets, dim=0) for task_targets in zip(*all_targets)
    )

    return batch_inputs, batch_targets, all_comps, all_ids
367 |
368 |
def get_structure(cols):
    """Build a pymatgen structure from the lattice and sites columns of a row"""
    raw_cell, raw_sites = cols
    cell, elems, coords = parse_cgcnn(raw_cell, raw_sites)

    # NOTE getting primitive structure before constructing graph
    # significantly harms the performance of this model.
    structure_kwargs = {
        "lattice": cell,
        "species": elems,
        "coords": coords,
        "to_unit_cell": True,
    }
    crystal = Structure(**structure_kwargs)

    # Structures with only a few sites are expanded in place so that
    # masking ~15% of the sites does not have a disproportionate impact
    # on small unit cells.
    if crystal.num_sites < 7:
        crystal.make_supercell([2, 2, 2])

    return crystal
385 |
386 |
def parse_cgcnn(cell, sites):
    """Parse the string-encoded lattice and site list into arrays/lists.

    Args:
        cell (str): Python-literal text for the lattice (e.g. a nested list
            of numbers); decoded with ``ast.literal_eval``.
        sites (str): Python-literal text for a list of strings, each of the
            form ``"<element> @ <x> <y> <z>"``.

    Returns:
        tuple: ``(cell, elems, coords)`` where ``cell`` is a float numpy
        array, ``elems`` is the list of element strings in site order and
        ``coords`` is a float numpy array with one row per site.

    Raises:
        ValueError: if a site string does not contain exactly one ``" @ "``
            separator or a coordinate token is not a valid float.
    """
    cell = np.array(ast.literal_eval(cell), dtype=float)
    elems = []
    coords = []
    for site in ast.literal_eval(sites):
        ele, pos = site.split(" @ ")
        elems.append(ele)
        # Split on any run of whitespace: splitting on a single " " yields
        # empty tokens for doubled/trailing spaces, which then break the
        # float conversion below.
        coords.append(pos.split())

    coords = np.array(coords, dtype=float)
    return cell, elems, coords
399 |
--------------------------------------------------------------------------------
/roost/pretrain/ele_model.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 | from roost.cgcnn.model import DescriptorNetwork
4 | from roost.core import BaseModelClass
5 | from roost.segments import MeanPooling
6 |
7 |
class CrystalGraphPreNet(BaseModelClass):
    """
    Crystal graph network used for pre-training. It produces an output for
    each selected (masked) atom and a single pooled output per crystal.

    This model is based on: https://github.com/txie-93/cgcnn [MIT License].
    Changes to the code were made to allow for the removal of zero-padding
    and to benefit from the BaseModelClass functionality.
    """

    def __init__(
        self,
        robust,
        n_targets,
        elem_emb_len,
        nbr_fea_len,
        elem_fea_len=64,
        n_graph=4,
        **kwargs,
    ):
        """
        Build the descriptor network and the two linear output heads.

        Parameters
        ----------

        robust:
            Forwarded to BaseModelClass and recorded in model_params.
        n_targets: sequence of int
            Recorded in model_params; its first entry is the output width of
            the per-atom head.
        elem_emb_len: int
            Passed to DescriptorNetwork as ``elem_emb_len``.
        nbr_fea_len: int
            Passed to DescriptorNetwork as ``nbr_fea_len``.
        elem_fea_len: int
            Passed to DescriptorNetwork and used as the input width of both
            linear heads.
        n_graph: int
            Passed to DescriptorNetwork as ``n_graph``.
        **kwargs:
            Forwarded unchanged to BaseModelClass.
        """
        super().__init__(robust=robust, **kwargs)

        descriptor_kwargs = dict(
            elem_emb_len=elem_emb_len,
            nbr_fea_len=nbr_fea_len,
            elem_fea_len=elem_fea_len,
            n_graph=n_graph,
        )

        # NOTE sub-modules are created in a fixed order so that parameter
        # registration (and random initialisation) order is stable.
        self.node_nn = DescriptorNetwork(**descriptor_kwargs)
        self.node_linear = nn.Linear(elem_fea_len, n_targets[0])
        self.global_pool = MeanPooling()
        self.global_linear = nn.Linear(elem_fea_len, 1)

        # remember everything needed to rebuild the model
        self.model_params.update(
            {
                "robust": robust,
                "n_targets": n_targets,
                **descriptor_kwargs,
            }
        )

    def forward(self, atom_fea, nbr_fea, self_idx, nbr_idx, mask_idx, cry_idx):
        """
        Forward pass

        Parameters
        ----------

        atom_fea, nbr_fea, self_idx, nbr_idx:
            Graph inputs handed straight to the descriptor network.
        mask_idx:
            Row indices into the descriptor output selecting the atoms the
            per-atom head is applied to.
        cry_idx:
            Index handed to the pooling layer together with the descriptor
            output (presumably the crystal each atom belongs to).

        Returns
        -------

        list
            ``[nodes, glob]``: the per-atom head applied to the selected
            rows, and the global head applied to the pooled features.
        """
        atom_hidden = self.node_nn(atom_fea, nbr_fea, self_idx, nbr_idx)

        # per-atom prediction, only for the selected rows
        selected = atom_hidden[mask_idx, :]
        node_out = self.node_linear(selected)

        # per-crystal prediction from the pooled atom features
        pooled = self.global_pool(atom_hidden, cry_idx)
        global_out = self.global_linear(pooled)

        return [node_out, global_out]
107 |
--------------------------------------------------------------------------------
/roost/pretrain/ener_data.py:
--------------------------------------------------------------------------------
1 | import ast
2 | import os
3 | from itertools import groupby
4 |
5 | import numpy as np
6 | import pandas as pd
7 | import torch
8 | from pymatgen.core.structure import Structure
9 | from torch.utils.data import Dataset
10 |
11 | from roost.core import Featurizer
12 |
13 |
class CrystalGraphData(Dataset):
    """Dataset of crystal neighbourhood graphs with regression targets.

    Structures are parsed and their neighbour lists are computed once when
    the dataset is constructed; atom features and targets are assembled on
    item access.
    """

    def __init__(
        self,
        data_path,
        fea_path,
        task_dict,
        inputs=["lattice", "sites"],
        identifiers=["material_id", "composition"],
        radius=5,
        max_num_nbr=12,
        dmin=0,
        step=0.2,
        p_mask=0.15,
        p_zero=0.8,
    ):
        """CrystalGraphData returns neighbourhood graphs

        Args:
            data_path (str): The path to the dataset
            fea_path (str): The path to the element embedding
            task_dict ({target: task}): task dict for multi-task learning
            inputs (list, optional): df columns for lattice and sites.
                Defaults to ["lattice", "sites"].
            identifiers (list, optional): df columns for distinguishing data points.
                The first entry is the id column used when dropping invalid
                structures. Defaults to ["material_id", "composition"].
            radius (int, optional): cut-off radius for neighbourhood. Defaults to 5.
            max_num_nbr (int, optional): maximum number of neighbours to consider.
                ``None`` keeps every neighbour within the radius. Defaults to 12.
            dmin (int, optional): minimum distance in gaussian basis. Defaults to 0.
            step (float, optional): increment size of gaussian basis. Defaults to 0.2.
            p_mask (float, optional): stored on the instance; not otherwise
                used by this class. Defaults to 0.15.
            p_zero (float, optional): stored on the instance; not otherwise
                used by this class. Defaults to 0.8.

        Raises:
            AssertionError: if the number of identifiers/inputs is not two or
                a given path does not exist.
            NotImplementedError: if a task is neither "mask" nor "regression".
        """
        assert len(identifiers) == 2, "Two identifiers are required"
        assert len(inputs) == 2, "Two input columns are required"

        self.inputs = inputs
        self.task_dict = task_dict
        self.identifiers = identifiers
        self.radius = radius
        self.max_num_nbr = max_num_nbr
        self.p_mask = p_mask
        self.p_zero = p_zero

        # names of the df columns that hold the precomputed graph
        self.graph = ["self", "nbr", "dist"]

        assert os.path.exists(fea_path), f"{fea_path} does not exist!"
        self.ari = Featurizer.from_json(fea_path)
        # NOTE(review): this path is relative to the working directory
        self.ohe = Featurizer.from_json("data/el-embeddings/onehot-embedding.json")
        self.elem_fea_dim = self.ari.embedding_size

        self.gdf = GaussianDistance(dmin=dmin, dmax=self.radius, step=step)
        self.nbr_fea_dim = self.gdf.embedding_size

        assert os.path.exists(data_path), f"{data_path} does not exist!"

        # NOTE make sure to use dense datasets, here do not use the default na
        # as they can clash with "NaN" which is a valid material
        # NOTE(review): only the first 1000 rows are kept - this looks like a
        # debugging limit; confirm before relying on the full dataset.
        self.df = pd.read_csv(
            data_path, keep_default_na=False, na_values=[], comment="#"
        )[:1000]

        self.df["Structure_obj"] = self.df[inputs].apply(get_structure, axis=1)

        self._pre_check()

        self.n_targets = []
        for task in self.task_dict.values():
            if task == "mask":
                self.n_targets.append(self.ohe.embedding_size)
            elif task == "regression":
                self.n_targets.append(1)
            else:
                raise NotImplementedError("bad user")

    def __len__(self):
        """Return the number of rows left in the dataframe."""
        return len(self.df)

    def _get_nbr_data(self, crystal):
        """get neighbours for every site

        Args:
            crystal ([Structure]): pymatgen structure to get neighbours for

        Returns:
            tuple: ``(self_idx, nbr_idx, nbr_dist)`` - centre index, neighbour
            index and distance of every retained pair. When ``max_num_nbr``
            is set only the closest ``max_num_nbr`` pairs per centre are kept.
        """
        self_idx, nbr_idx, _, nbr_dist = crystal.get_neighbor_list(
            self.radius,
            numerical_tol=1e-8,
        )

        if self.max_num_nbr is not None:
            _self_idx, _nbr_idx, _nbr_dist = [], [], []

            # NOTE groupby only merges consecutive items, so this relies on
            # the neighbour list being ordered by centre index
            for _, g in groupby(zip(self_idx, nbr_idx, nbr_dist), key=lambda x: x[0]):
                s, n, d = zip(*sorted(g, key=lambda x: x[2]))
                _self_idx.extend(s[: self.max_num_nbr])
                _nbr_idx.extend(n[: self.max_num_nbr])
                _nbr_dist.extend(d[: self.max_num_nbr])

            self_idx = np.array(_self_idx)
            nbr_idx = np.array(_nbr_idx)
            nbr_dist = np.array(_nbr_dist)

        return self_idx, nbr_idx, nbr_dist

    def _pre_check(self):
        """Precompute the graph of every structure and drop invalid ones.

        For each row the neighbour list is computed (enlarging the cell while
        an atom is listed as its own neighbour), expanded in the Gaussian
        basis and stored in the ``self.graph`` columns. Rows whose structure
        has no neighbours at all, or has sites without neighbours, are
        dropped from the dataframe and their ids printed.
        """
        print("checking all structures valid")
        all_iso = []
        some_iso = []

        # use the configured id column instead of a hard-coded name so that
        # custom ``identifiers`` work
        id_col = self.identifiers[0]

        # initialise empty columns of objects to insert the lists
        for col in self.graph:
            self.df[col] = [[] for _ in range(len(self.df))]

        for index, row in self.df.iterrows():
            cif_id = row[id_col]
            crystal = row["Structure_obj"]

            # the supercell is made in place, so the structure stored in the
            # dataframe is the enlarged one
            while True:
                self_idx, nbr_idx, nbr_dist = self._get_nbr_data(crystal)

                if not np.any(self_idx == nbr_idx):
                    break

                # TODO only double along the shortest dimension
                crystal.make_supercell([2, 2, 2])

            if len(self_idx) == 0 or len(nbr_idx) == 0:
                all_iso.append(cif_id)
            elif set(self_idx) != set(range(crystal.num_sites)):
                some_iso.append(cif_id)

            nbr_dist = self.gdf.expand(nbr_dist)

            nbr_dist = torch.Tensor(nbr_dist)
            self_idx = torch.LongTensor(self_idx)
            nbr_idx = torch.LongTensor(nbr_idx)

            self.df.at[index, self.graph[0]] = self_idx
            self.df.at[index, self.graph[1]] = nbr_idx
            self.df.at[index, self.graph[2]] = nbr_dist

        if all_iso or some_iso:
            self.df = self.df.drop(self.df[self.df[id_col].isin(all_iso)].index)
            self.df = self.df.drop(self.df[self.df[id_col].isin(some_iso)].index)

            print(all_iso)
            print(some_iso)

    # NOTE do not cache the pre-training structures as we want to see new sets of
    # masked structures each epoch as this effectively expands the training set
    def __getitem__(self, idx):
        """Return the graph inputs, targets and identifiers of one row.

        Args:
            idx (int): positional index into the dataframe

        Returns:
            tuple: ``((atom_fea, nbr_dist, self_idx, nbr_idx), targets, comp,
            cif_id)`` where ``targets`` holds one ``(1, 1)`` tensor per
            regression task.

        Raises:
            NotImplementedError: if a task other than "regression" is requested.
        """
        # NOTE sites must be given in fractional coordinates
        df_idx = self.df.iloc[idx]
        crystal = df_idx["Structure_obj"]
        cif_id, comp = df_idx[self.identifiers]
        self_idx, nbr_idx, nbr_dist = df_idx[self.graph]

        # atom features
        # TODO can this be vectorised with numpy?

        # handle disordered structures (multiple fractional elements per site)
        site_atoms = [atom.species.as_dict() for atom in crystal]
        atom_fea = np.vstack(
            [
                np.sum([self.ari.get_fea(el) * amt for el, amt in site.items()], axis=0)
                for site in site_atoms
            ]
        )

        atom_fea = torch.Tensor(atom_fea)

        targets = []
        for target, task in self.task_dict.items():
            if task == "regression":
                targets.append(torch.Tensor([[df_idx[target]]]))
            else:
                raise NotImplementedError("bad user")

        return ((atom_fea, nbr_dist, self_idx, nbr_idx), targets, comp, cif_id)
215 |
216 |
class GaussianDistance:
    """
    Gaussian basis expansion of interatomic distances.

    Unit: angstrom
    """

    def __init__(self, dmin, dmax, step, var=None):
        """
        Args:
            dmin (float): Minimum interatomic distance
            dmax (float): Maximum interatomic distance
            step (float): Spacing between consecutive Gaussian centres
            var (float, optional): Width parameter of the Gaussians; it is
                squared when applied in ``expand``. Defaults to ``step``.

        Raises:
            AssertionError: if ``dmin >= dmax`` or the range is not larger
                than ``step``.
        """
        assert dmin < dmax
        assert dmax - dmin > step

        # centres run from dmin upwards in increments of step; the stop
        # value dmax + step itself is excluded by np.arange
        centres = np.arange(dmin, dmax + step, step)
        self.filter = centres
        self.embedding_size = len(centres)

        self.var = step if var is None else var

    def expand(self, distances):
        """Apply Gaussian distance filter to a numpy distance array

        Args:
            distances (ArrayLike): A distance matrix of any shape

        Returns:
            Expanded distance matrix with the last dimension of length
            len(self.filter)
        """
        # add a trailing axis so every distance broadcasts against all centres
        offsets = np.array(distances)[..., np.newaxis] - self.filter

        return np.exp(-(offsets ** 2) / self.var ** 2)
258 |
259 |
def collate_batch(dataset_list):
    """
    Merge a list of per-crystal samples into one batch by concatenating the
    graphs and shifting their atom indices.

    Parameters
    ----------

    dataset_list: list of tuples, one per data point
        ((atom_fea, nbr_dist, self_idx, nbr_idx), target, comp, cif_id)

        atom_fea: torch.Tensor with one row per atom
        nbr_dist: torch.Tensor of per-pair distance features
        self_idx: torch.LongTensor of per-pair centre atom indices
        nbr_idx: torch.LongTensor of per-pair neighbour atom indices
        target: sequence of torch.Tensor, one per task
        comp, cif_id: identifiers passed through unchanged

    Returns
    -------

    inputs: tuple
        (atom_fea, nbr_dist, self_idx, nbr_idx, cry_idx) concatenated along
        dim 0. The index tensors are offset by the number of atoms in the
        preceding crystals; cry_idx holds, for every atom, the position of
        its crystal within the batch.
    targets: tuple of torch.Tensor
        One tensor per task, concatenated along dim 0 over the batch.
    batch_comps: list
    batch_cif_ids: list
    """
    fea_chunks, dist_chunks = [], []
    centre_chunks, neighbour_chunks = [], []
    atom_owner = []
    target_chunks, comps, cif_ids = [], [], []

    # running count of atoms already placed in the batch
    offset = 0

    for crystal_no, sample in enumerate(dataset_list):
        (atom_fea, nbr_dist, self_idx, nbr_idx), target, comp, cif_id = sample
        n_atoms = atom_fea.shape[0]

        fea_chunks.append(atom_fea)
        dist_chunks.append(nbr_dist)

        # shift the local atom indices into batch-wide numbering
        centre_chunks.append(self_idx + offset)
        neighbour_chunks.append(nbr_idx + offset)

        atom_owner += [crystal_no] * n_atoms

        target_chunks.append(target)
        comps.append(comp)
        cif_ids.append(cif_id)

        offset += n_atoms

    batched_inputs = (
        torch.cat(fea_chunks, dim=0),
        torch.cat(dist_chunks, dim=0),
        torch.cat(centre_chunks, dim=0),
        torch.cat(neighbour_chunks, dim=0),
        torch.LongTensor(atom_owner),
    )

    # regroup the targets task-wise before concatenating
    batched_targets = tuple(
        torch.cat(per_task, dim=0) for per_task in zip(*target_chunks)
    )

    return (batched_inputs, batched_targets, comps, cif_ids)
333 |
334 |
def get_structure(cols):
    """Build a pymatgen ``Structure`` from a (lattice, sites) pair.

    Args:
        cols: Two-item iterable holding the string-encoded lattice followed
            by the string-encoded site list; both are handed to
            ``parse_cgcnn`` for decoding.

    Returns:
        The constructed ``Structure``. Structures with fewer than 7 sites
        are expanded in place to a 2x2x2 supercell before being returned.
    """
    lattice_str, sites_str = cols
    lattice, species, positions = parse_cgcnn(lattice_str, sites_str)

    # NOTE reducing to the primitive structure before constructing the graph
    # significantly harms the performance of this model, so it is not done.
    structure = Structure(
        lattice=lattice,
        species=species,
        coords=positions,
        to_unit_cell=True,
    )

    # Enlarge very small cells so that masking ~15% of the sites does not
    # have a disproportionate impact on small unit cell structures.
    needs_supercell = structure.num_sites < 7
    if needs_supercell:
        structure.make_supercell([2, 2, 2])

    return structure
351 |
352 |
def parse_cgcnn(cell, sites):
    """Parse the string-encoded lattice and site list into arrays/lists.

    Args:
        cell (str): Python-literal text for the lattice (e.g. a nested list
            of numbers); decoded with ``ast.literal_eval``.
        sites (str): Python-literal text for a list of strings, each of the
            form ``"<element> @ <x> <y> <z>"``.

    Returns:
        tuple: ``(cell, elems, coords)`` where ``cell`` is a float numpy
        array, ``elems`` is the list of element strings in site order and
        ``coords`` is a float numpy array with one row per site.

    Raises:
        ValueError: if a site string does not contain exactly one ``" @ "``
            separator or a coordinate token is not a valid float.
    """
    cell = np.array(ast.literal_eval(cell), dtype=float)
    elems = []
    coords = []
    for site in ast.literal_eval(sites):
        ele, pos = site.split(" @ ")
        elems.append(ele)
        # Split on any run of whitespace: splitting on a single " " yields
        # empty tokens for doubled/trailing spaces, which then break the
        # float conversion below.
        coords.append(pos.split())

    coords = np.array(coords, dtype=float)
    return cell, elems, coords
365 |
--------------------------------------------------------------------------------
/roost/pretrain/ener_model.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 | from roost.cgcnn.model import DescriptorNetwork
4 | from roost.core import BaseModelClass
5 | from roost.segments import MeanPooling
6 |
7 |
class CrystalGraphPreNet(BaseModelClass):
    """
    Crystal graph network used for pre-training. It produces a single pooled
    output per crystal.

    This model is based on: https://github.com/txie-93/cgcnn [MIT License].
    Changes to the code were made to allow for the removal of zero-padding
    and to benefit from the BaseModelClass functionality.
    """

    def __init__(
        self,
        robust,
        n_targets,
        elem_emb_len,
        nbr_fea_len,
        elem_fea_len=64,
        n_graph=4,
        **kwargs,
    ):
        """
        Build the descriptor network and the global linear output head.

        Parameters
        ----------

        robust:
            Forwarded to BaseModelClass and recorded in model_params.
        n_targets:
            Recorded in model_params; it does not change the layers built
            here (the global head always has a single output).
        elem_emb_len: int
            Passed to DescriptorNetwork as ``elem_emb_len``.
        nbr_fea_len: int
            Passed to DescriptorNetwork as ``nbr_fea_len``.
        elem_fea_len: int
            Passed to DescriptorNetwork and used as the input width of the
            global linear head.
        n_graph: int
            Passed to DescriptorNetwork as ``n_graph``.
        **kwargs:
            Forwarded unchanged to BaseModelClass.
        """
        super().__init__(robust=robust, **kwargs)

        descriptor_kwargs = dict(
            elem_emb_len=elem_emb_len,
            nbr_fea_len=nbr_fea_len,
            elem_fea_len=elem_fea_len,
            n_graph=n_graph,
        )

        # NOTE sub-modules are created in a fixed order so that parameter
        # registration (and random initialisation) order is stable.
        self.node_nn = DescriptorNetwork(**descriptor_kwargs)
        self.global_pool = MeanPooling()
        self.global_linear = nn.Linear(elem_fea_len, 1)

        # remember everything needed to rebuild the model
        self.model_params.update(
            {
                "robust": robust,
                "n_targets": n_targets,
                **descriptor_kwargs,
            }
        )

    def forward(self, atom_fea, nbr_fea, self_idx, nbr_idx, cry_idx):
        """
        Forward pass

        Parameters
        ----------

        atom_fea, nbr_fea, self_idx, nbr_idx:
            Graph inputs handed straight to the descriptor network.
        cry_idx:
            Index handed to the pooling layer together with the descriptor
            output (presumably the crystal each atom belongs to).

        Returns
        -------

        list
            ``[glob]``: a one-element list holding the global head applied
            to the pooled features.
        """
        atom_hidden = self.node_nn(atom_fea, nbr_fea, self_idx, nbr_idx)

        # per-crystal prediction from the pooled atom features
        pooled = self.global_pool(atom_hidden, cry_idx)
        global_out = self.global_linear(pooled)

        return [global_out]
103 |
--------------------------------------------------------------------------------
/roost/roost/data.py:
--------------------------------------------------------------------------------
1 | import functools
2 | import os
3 |
4 | import numpy as np
5 | import pandas as pd
6 | import torch
7 | from pymatgen.core.composition import Composition
8 | from torch.utils.data import Dataset
9 |
10 | from roost.core import Featurizer
11 |
12 |
class CompositionData(Dataset):
    """
    Dataset that builds its data points from the composition strings stored
    in a CSV file: element weights, element features and the indices of all
    ordered element pairs of each composition.
    """

    def __init__(
        self,
        data_path,
        fea_path,
        task_dict,
        inputs=["composition"],
        identifiers=["material_id", "composition"],
    ):
        """Load the dataframe and the element embedding.

        Args:
            data_path (str): path to the CSV file holding the dataset
            fea_path (str): path to the JSON element embedding
            task_dict ({name: task}): maps each target column to its task,
                "regression" or "classification"
            inputs (list, optional): column name for compositions.
                Defaults to ["composition"].
            identifiers (list, optional): column names for unique identifier
                and pretty name. Defaults to ["material_id", "composition"].

        Raises:
            AssertionError: if the number of identifiers is not two, the
                number of inputs is not one, or a given path does not exist.
        """

        assert len(identifiers) == 2, "Two identifiers are required"
        assert len(inputs) == 1, "One input column is required"

        self.inputs = inputs
        self.task_dict = task_dict
        self.identifiers = identifiers

        assert os.path.exists(data_path), f"{data_path} does not exist!"
        # NOTE make sure to use dense datasets,
        # NOTE do not use default_na as "NaN" is a valid material
        self.df = pd.read_csv(data_path, keep_default_na=False, na_values=[])

        assert os.path.exists(fea_path), f"{fea_path} does not exist!"
        self.elem_features = Featurizer.from_json(fea_path)
        self.elem_emb_len = self.elem_features.embedding_size

        # one output per regression target, one per class for classification;
        # class labels are taken to be 0 .. max(label)
        self.n_targets = []
        for target, task in self.task_dict.items():
            if task == "regression":
                self.n_targets.append(1)
            elif task == "classification":
                n_classes = np.max(self.df[target].values) + 1
                self.n_targets.append(n_classes)

    def __len__(self):
        """Return the number of rows in the dataframe."""
        return len(self.df)

    # NOTE the cache is keyed on (self, idx) and therefore keeps the dataset
    # instance referenced for as long as the cache lives
    @functools.lru_cache(maxsize=None)  # Cache data for faster training
    def __getitem__(self, idx):
        """Build the dense element graph of one composition.

        Args:
            idx (int): dataset index

        Raises:
            AssertionError: re-raised with the offending id and composition
                when the element feature lookup fails an assertion
            ValueError: re-raised with the offending id and composition when
                stacking the element features fails

        Returns:
            tuple: ``((atom_weights, atom_fea, self_fea_idx, nbr_fea_idx),
            targets, *cry_ids)``

            atom_weights: torch.Tensor shape (M, 1)
                amounts of the M elements normalised to sum to one
            atom_fea: torch.Tensor with one row per element
                features of the elements in the material
            self_fea_idx: torch.LongTensor shape (M*M,)
                first index of every ordered element pair
            nbr_fea_idx: torch.LongTensor shape (M*M,)
                second index of every ordered element pair
            targets: list of torch.Tensor shape (1,)
                one per task; float for regression, long for classification
            cry_ids:
                the values of the identifier columns for this row
        """
        df_idx = self.df.iloc[idx]
        # look the column up by label; integer positions on a label-indexed
        # Series (``df_idx[self.inputs][0]``) are deprecated in pandas
        composition = df_idx[self.inputs[0]]
        cry_ids = df_idx[self.identifiers].values

        comp_dict = Composition(composition).get_el_amt_dict()
        elements = list(comp_dict.keys())

        weights = list(comp_dict.values())
        weights = np.atleast_2d(weights).T / np.sum(weights)

        try:
            atom_fea = np.vstack(
                [self.elem_features.get_fea(element) for element in elements]
            )
        except AssertionError:
            raise AssertionError(
                f"cry-id {cry_ids[0]} [{composition}] contains element types not in embedding"
            )
        except ValueError:
            raise ValueError(
                f"cry-id {cry_ids[0]} [{composition}] composition cannot be parsed into elements"
            )

        # fully connected graph: every element is paired with every element
        # (itself included)
        nele = len(elements)
        self_fea_idx = []
        nbr_fea_idx = []
        for i in range(nele):
            self_fea_idx += [i] * nele
            nbr_fea_idx += list(range(nele))

        # convert all data to tensors
        atom_weights = torch.Tensor(weights)
        atom_fea = torch.Tensor(atom_fea)
        self_fea_idx = torch.LongTensor(self_fea_idx)
        nbr_fea_idx = torch.LongTensor(nbr_fea_idx)

        targets = []
        for target, task in self.task_dict.items():
            if task == "regression":
                targets.append(torch.Tensor([df_idx[target]]))
            elif task == "classification":
                targets.append(torch.LongTensor([df_idx[target]]))

        return (
            (atom_weights, atom_fea, self_fea_idx, nbr_fea_idx),
            targets,
            *cry_ids,
        )
141 |
142 |
def collate_batch(dataset_list):
    """
    Merge a list of per-material samples into one batch by concatenating the
    element graphs and shifting their element indices.

    Parameters
    ----------

    dataset_list: list of tuples, one per data point
        ((atom_weights, atom_fea, self_fea_idx, nbr_fea_idx), target, *ids)

        atom_weights: torch.Tensor with one row per element
        atom_fea: torch.Tensor with one row per element
        self_fea_idx: torch.LongTensor of per-pair first element indices
        nbr_fea_idx: torch.LongTensor of per-pair second element indices
        target: sequence of torch.Tensor, one per task
        ids: any number of identifiers passed through

    Returns
    -------

    A tuple ``(inputs, targets, *ids)`` where

    inputs: tuple
        (atom_weights, atom_fea, self_fea_idx, nbr_fea_idx, crystal_idx)
        concatenated along dim 0. The pair indices are offset by the number
        of elements in the preceding materials; crystal_idx holds, for every
        element, the position of its material within the batch.
    targets: tuple of torch.Tensor
        One tensor per task, stacked along a new leading batch dimension.
    ids: one tuple per identifier position, each of batch length
    """
    weight_chunks, fea_chunks = [], []
    first_chunks, second_chunks = [], []
    owner_chunks = []
    target_chunks, id_chunks = [], []

    # running count of elements already placed in the batch
    offset = 0

    for material_no, (inputs, target, *cry_ids) in enumerate(dataset_list):
        atom_weights, atom_fea, self_fea_idx, nbr_fea_idx = inputs
        n_elems = atom_fea.shape[0]

        weight_chunks.append(atom_weights)
        fea_chunks.append(atom_fea)

        # shift the local element indices into batch-wide numbering
        first_chunks.append(self_fea_idx + offset)
        second_chunks.append(nbr_fea_idx + offset)

        # which material each element belongs to
        owner_chunks.append(torch.tensor([material_no] * n_elems))

        target_chunks.append(target)
        id_chunks.append(cry_ids)

        offset += n_elems

    batched_inputs = (
        torch.cat(weight_chunks, dim=0),
        torch.cat(fea_chunks, dim=0),
        torch.cat(first_chunks, dim=0),
        torch.cat(second_chunks, dim=0),
        torch.cat(owner_chunks),
    )

    # regroup the targets task-wise before stacking
    batched_targets = tuple(
        torch.stack(per_task, dim=0) for per_task in zip(*target_chunks)
    )

    return (batched_inputs, batched_targets, *zip(*id_chunks))
224 |
--------------------------------------------------------------------------------
/roost/roost/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from roost.core import BaseModelClass
6 | from roost.segments import ResidualNetwork, SimpleNetwork, WeightedAttentionPooling
7 |
8 |
class Roost(BaseModelClass):
    """
    The Roost model is comprised of a fully connected network
    and message passing graph layers.

    The message passing layers are used to determine a descriptor set
    for the fully connected network. The graphs are used to represent
    the stoichiometry of inorganic materials in a trainable manner.
    This makes them systematically improvable with more data.
    """

    def __init__(
        self,
        robust,
        n_targets,
        elem_emb_len,
        elem_fea_len=64,
        n_graph=3,
        elem_heads=3,
        elem_gate=[256],
        elem_msg=[256],
        cry_heads=3,
        cry_gate=[256],
        cry_msg=[256],
        trunk_hidden=[1024, 512],
        out_hidden=[256, 128, 64],
        **kwargs
    ):
        """
        Build the descriptor network, the shared trunk and one output
        network per target.

        Parameters
        ----------
        robust:
            Forwarded to BaseModelClass; when ``self.robust`` is truthy the
            width of every output network is doubled.
        n_targets: sequence of int
            Number of outputs for each task.
        elem_emb_len, elem_fea_len, n_graph, elem_heads, elem_gate,
        elem_msg, cry_heads, cry_gate, cry_msg:
            Passed unchanged to DescriptorNetwork.
        trunk_hidden: list of int
            Hidden layer sizes of the shared trunk network.
        out_hidden: list of int
            The first entry is the trunk output width, the remaining entries
            are the hidden layer sizes of every output network.
        **kwargs:
            Forwarded unchanged to BaseModelClass.

        Raises
        ------
        ValueError
            If ``out_hidden`` is a list of lists (per-target layer sizes are
            not supported).
        """
        if isinstance(out_hidden[0], list):
            raise ValueError("boo hiss bad user")

        super().__init__(robust=robust, **kwargs)

        desc_dict = {
            "elem_emb_len": elem_emb_len,
            "elem_fea_len": elem_fea_len,
            "n_graph": n_graph,
            "elem_heads": elem_heads,
            "elem_gate": elem_gate,
            "elem_msg": elem_msg,
            "cry_heads": cry_heads,
            "cry_gate": cry_gate,
            "cry_msg": cry_msg,
        }

        self.material_nn = DescriptorNetwork(**desc_dict)

        # record the un-doubled n_targets so the model can be rebuilt
        self.model_params.update(
            {
                "robust": robust,
                "n_targets": n_targets,
                "out_hidden": out_hidden,
                "trunk_hidden": trunk_hidden,
            }
        )

        self.model_params.update(desc_dict)

        # define an output neural network
        if self.robust:
            n_targets = [2 * n for n in n_targets]

        self.trunk_nn = ResidualNetwork(elem_fea_len, out_hidden[0], trunk_hidden)

        self.output_nns = nn.ModuleList(
            [ResidualNetwork(out_hidden[0], n, out_hidden[1:]) for n in n_targets]
        )

    def forward(self, elem_weights, elem_fea, self_fea_idx, nbr_fea_idx, cry_elem_idx):
        """
        Forward pass through the material_nn, the trunk and the output_nns.

        All arguments are handed unchanged to the descriptor network.

        Returns
        -------
        tuple
            One output per entry of ``self.output_nns``, in order.
        """
        crys_fea = self.material_nn(
            elem_weights, elem_fea, self_fea_idx, nbr_fea_idx, cry_elem_idx
        )

        crys_fea = F.relu(self.trunk_nn(crys_fea))

        # apply neural network to map from learned features to target.
        # Evaluate eagerly: a lazy generator would only run the output
        # networks when the caller iterates it (possibly outside of the
        # no_grad/autocast context of this call) and could be consumed once.
        return tuple(output_nn(crys_fea) for output_nn in self.output_nns)

    def __repr__(self):
        return self.__class__.__name__
96 |
97 |
98 | class DescriptorNetwork(nn.Module):
99 | """
100 | The Descriptor Network is the message passing section of the
101 | Roost Model.
102 | """
103 |
104 | def __init__(
105 | self,
106 | elem_emb_len,
107 | elem_fea_len=64,
108 | n_graph=3,
109 | elem_heads=3,
110 | elem_gate=[256],
111 | elem_msg=[256],
112 | cry_heads=3,
113 | cry_gate=[256],
114 | cry_msg=[256],
115 | ):
116 | """
117 | """
118 | super().__init__()
119 |
120 | # apply linear transform to the input to get a trainable embedding
121 | # NOTE -1 here so we can add the weights as a node feature
122 | self.embedding = nn.Linear(elem_emb_len, elem_fea_len - 1)
123 |
124 | # create a list of Message passing layers
125 | self.graphs = nn.ModuleList(
126 | [
127 | MessageLayer(
128 | elem_fea_len=elem_fea_len,
129 | elem_heads=elem_heads,
130 | elem_gate=elem_gate,
131 | elem_msg=elem_msg,
132 | )
133 | for i in range(n_graph)
134 | ]
135 | )
136 |
137 | # define a global pooling function for materials
138 | self.cry_pool = nn.ModuleList(
139 | [
140 | WeightedAttentionPooling(
141 | gate_nn=SimpleNetwork(elem_fea_len, 1, cry_gate),
142 | message_nn=SimpleNetwork(elem_fea_len, elem_fea_len, cry_msg),
143 | )
144 | for _ in range(cry_heads)
145 | ]
146 | )
147 |
def forward(self, elem_weights, elem_fea, self_fea_idx, nbr_fea_idx, cry_elem_idx):
    """
    Forward pass

    Parameters
    ----------
    N: Total number of elements (nodes) in the batch
    M: Total number of pairs (edges) in the batch
    C: Total number of crystals (graphs) in the batch

    Inputs
    ----------
    elem_weights: torch.Tensor shape (N, 1)
        Fractional weight of each Element in its stoichiometry. Must be
        2-dimensional because it is concatenated to the embedding on dim=1.
    elem_fea: torch.Tensor shape (N, orig_elem_fea_len)
        Element features of each of the N elems in the batch
    self_fea_idx: torch.Tensor shape (M,)
        Indices of the first element in each of the M pairs
    nbr_fea_idx: torch.Tensor shape (M,)
        Indices of the second element in each of the M pairs
    cry_elem_idx: torch.LongTensor
        Mapping from the elem idx to crystal idx; handed to every pooling
        head as its `index` argument (presumably shape (N,) - confirm
        against the collate function).

    Returns
    -------
    cry_fea: torch.Tensor
        Material representation after message passing, averaged over the
        crystal pooling heads (presumably shape (C, elem_fea_len)).
    """
    # embed the original features into a trainable embedding space
    elem_fea = self.embedding(elem_fea)

    # add weights as a node feature
    elem_fea = torch.cat([elem_fea, elem_weights], dim=1)

    # apply the message passing functions
    for graph_func in self.graphs:
        elem_fea = graph_func(elem_weights, elem_fea, self_fea_idx, nbr_fea_idx)

    # generate crystal features by pooling the elemental features
    head_fea = [
        attnhead(elem_fea, index=cry_elem_idx, weights=elem_weights)
        for attnhead in self.cry_pool
    ]

    # average the attention heads
    return torch.mean(torch.stack(head_fea), dim=0)
200 |
def __repr__(self):
    """Represent the instance by nothing but its class name."""
    return type(self).__name__
203 |
204 |
class MessageLayer(nn.Module):
    """
    Message Layers are used to propagate information between nodes in
    the stoichiometry graph.
    """

    def __init__(self, elem_fea_len, elem_heads, elem_gate, elem_msg):
        """
        Inputs
        ----------
        elem_fea_len: int
            Length of the element feature vectors; each attention head
            consumes concatenated (self, neighbour) features of length
            2 * elem_fea_len and emits messages of length elem_fea_len.
        elem_heads: int
            Number of attention heads whose outputs are averaged
        elem_gate: list(int)
            Hidden layer sizes of each head's gate network
        elem_msg: list(int)
            Hidden layer sizes of each head's message network
        """
        super().__init__()

        # Pooling and Output
        self.pooling = nn.ModuleList(
            [
                WeightedAttentionPooling(
                    gate_nn=SimpleNetwork(2 * elem_fea_len, 1, elem_gate),
                    message_nn=SimpleNetwork(2 * elem_fea_len, elem_fea_len, elem_msg),
                )
                for _ in range(elem_heads)
            ]
        )

    def forward(self, elem_weights, elem_in_fea, self_fea_idx, nbr_fea_idx):
        """
        Forward pass

        Parameters
        ----------
        N: Total number of elements (nodes) in the batch
        M: Total number of pairs (edges) in the batch
        C: Total number of crystals (graphs) in the batch

        Inputs
        ----------
        elem_weights: torch.Tensor shape (N, 1)
            The fractional weights of elems in their materials. Must be
            2-dimensional as it is indexed with [nbr_fea_idx, :].
        elem_in_fea: torch.Tensor shape (N, elem_fea_len)
            Element hidden features before message passing
        self_fea_idx: torch.Tensor shape (M,)
            Indices of the first element in each of the M pairs
        nbr_fea_idx: torch.Tensor shape (M,)
            Indices of the second element in each of the M pairs

        Returns
        -------
        elem_out_fea: torch.Tensor shape (N, elem_fea_len)
            Element hidden features after message passing
        """
        # construct the total features for passing
        elem_nbr_weights = elem_weights[nbr_fea_idx, :]
        elem_nbr_fea = elem_in_fea[nbr_fea_idx, :]
        elem_self_fea = elem_in_fea[self_fea_idx, :]
        fea = torch.cat([elem_self_fea, elem_nbr_fea], dim=1)

        # sum selectivity over the neighbours to get elems
        head_fea = [
            attnhead(fea, index=self_fea_idx, weights=elem_nbr_weights)
            for attnhead in self.pooling
        ]

        # average the attention heads
        fea = torch.mean(torch.stack(head_fea), dim=0)

        # residual connection around the message passing step
        return fea + elem_in_fea

    def __repr__(self):
        return self.__class__.__name__
273 |
--------------------------------------------------------------------------------
/roost/segments.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch_scatter import scatter_add, scatter_max, scatter_mean
4 |
5 |
class MeanPooling(nn.Module):
    """Pool rows of a tensor by averaging those that share an index."""

    def __init__(self):
        super().__init__()

    def forward(self, x, index):
        """Return `scatter_mean` of `x` grouped by `index` along dim 0."""
        pooled = scatter_mean(x, index, dim=0)
        return pooled

    def __repr__(self):
        return type(self).__name__
17 |
18 |
class SumPooling(nn.Module):
    """Pool rows of a tensor by summing those that share an index."""

    def __init__(self):
        super().__init__()

    def forward(self, x, index):
        """Return `scatter_add` of `x` grouped by `index` along dim 0."""
        pooled = scatter_add(x, index, dim=0)
        return pooled

    def __repr__(self):
        return type(self).__name__
30 |
31 |
class AttentionPooling(nn.Module):
    """
    Softmax attention pooling over groups of rows sharing an index
    """

    def __init__(self, gate_nn, message_nn):
        """
        Args:
            gate_nn: nn.Module applied to the inputs to produce the
                attention logits
            message_nn: nn.Module applied to the inputs to produce the
                messages that get pooled
        """
        super().__init__()
        self.gate_nn = gate_nn
        self.message_nn = message_nn

    def forward(self, x, index):
        """Return the attention-weighted sum of messages for each index."""
        logits = self.gate_nn(x)

        # shift by the per-group maximum before exponentiating
        group_max = scatter_max(logits, index, dim=0)[0]
        unnorm = (logits - group_max[index]).exp()

        # normalise within each group; 1e-10 guards the division
        attn = unnorm / (scatter_add(unnorm, index, dim=0)[index] + 1e-10)

        messages = self.message_nn(x)
        return scatter_add(attn * messages, index, dim=0)

    def __repr__(self):
        return type(self).__name__
61 |
62 |
class WeightedAttentionPooling(nn.Module):
    """
    Softmax attention pooling in which every row's attention is further
    scaled by a per-row weight raised to a learnable power
    """

    def __init__(self, gate_nn, message_nn):
        """
        Inputs
        ----------
        gate_nn: nn.Module
            Applied to the inputs to produce the attention logits
        message_nn: nn.Module
            Applied to the inputs to produce the messages that get pooled
        """
        super().__init__()
        self.gate_nn = gate_nn
        self.message_nn = message_nn
        # learnable exponent applied to the weights, randomly initialised
        self.pow = torch.nn.Parameter(torch.randn(1))

    def forward(self, x, index, weights):
        """Return the weighted-attention sum of messages for each index."""
        logits = self.gate_nn(x)

        # shift by the per-group maximum before exponentiating
        group_max = scatter_max(logits, index, dim=0)[0]
        shifted = logits - group_max[index]
        unnorm = (weights ** self.pow) * shifted.exp()

        # normalise within each group; 1e-10 guards the division
        attn = unnorm / (scatter_add(unnorm, index, dim=0)[index] + 1e-10)

        messages = self.message_nn(x)
        return scatter_add(attn * messages, index, dim=0)

    def __repr__(self):
        return type(self).__name__
95 |
96 |
class SimpleNetwork(nn.Module):
    """
    Simple Feed Forward Neural Network
    """

    def __init__(
        self,
        input_dim,
        output_dim,
        hidden_layer_dims,
        activation=nn.LeakyReLU,
        batchnorm=False,
    ):
        """
        Inputs
        ----------
        input_dim: int
            Size of the input features
        output_dim: int
            Size of the output produced by the final linear layer
        hidden_layer_dims: sequence of int
            Sizes of the hidden layers (any sequence, e.g. list or tuple)
        activation: callable returning an nn.Module
            Factory called once per hidden layer to build its activation
        batchnorm: bool
            If True apply BatchNorm1d after each hidden linear layer,
            otherwise an Identity placeholder is used
        """
        super().__init__()

        # unpacking accepts tuples as well as lists
        dims = [input_dim, *hidden_layer_dims]
        n_hidden = len(dims) - 1

        self.fcs = nn.ModuleList(
            [nn.Linear(dims[i], dims[i + 1]) for i in range(n_hidden)]
        )

        if batchnorm:
            self.bns = nn.ModuleList(
                [nn.BatchNorm1d(dims[i + 1]) for i in range(n_hidden)]
            )
        else:
            self.bns = nn.ModuleList([nn.Identity() for _ in range(n_hidden)])

        self.acts = nn.ModuleList([activation() for _ in range(n_hidden)])

        self.fc_out = nn.Linear(dims[-1], output_dim)

    def forward(self, x):
        """Apply each (linear, norm, activation) stage, then the output layer."""
        for fc, bn, act in zip(self.fcs, self.bns, self.acts):
            x = act(bn(fc(x)))

        return self.fc_out(x)

    def __repr__(self):
        return self.__class__.__name__

    def reset_parameters(self):
        """Re-initialise all learnable layers, including any batch norms."""
        for fc in self.fcs:
            fc.reset_parameters()

        # nn.Identity placeholders have nothing to reset, BatchNorm1d does
        for bn in self.bns:
            if hasattr(bn, "reset_parameters"):
                bn.reset_parameters()

        self.fc_out.reset_parameters()
151 |
152 |
class ResidualNetwork(nn.Module):
    """
    Feed forward Residual Neural Network
    """

    def __init__(
        self,
        input_dim,
        output_dim,
        hidden_layer_dims,
        activation=nn.ReLU,
        batchnorm=False,
        return_features=False,
    ):
        """
        Inputs
        ----------
        input_dim: int
            Size of the input features
        output_dim: int
            Size of the output of the final linear layer (unused when
            return_features is True)
        hidden_layer_dims: sequence of int
            Sizes of the hidden layers (any sequence, e.g. list or tuple)
        activation: callable returning an nn.Module
            Factory called once per hidden layer to build its activation
        batchnorm: bool
            If True apply BatchNorm1d after each hidden linear layer,
            otherwise an Identity placeholder is used
        return_features: bool
            If True no output layer is created and forward returns the
            last hidden representation
        """
        super().__init__()

        # unpacking accepts tuples as well as lists
        dims = [input_dim, *hidden_layer_dims]
        n_hidden = len(dims) - 1

        self.fcs = nn.ModuleList(
            [nn.Linear(dims[i], dims[i + 1]) for i in range(n_hidden)]
        )

        if batchnorm:
            self.bns = nn.ModuleList(
                [nn.BatchNorm1d(dims[i + 1]) for i in range(n_hidden)]
            )
        else:
            self.bns = nn.ModuleList([nn.Identity() for _ in range(n_hidden)])

        # skip connections: a bias-free projection only where the width changes
        self.res_fcs = nn.ModuleList(
            [
                nn.Linear(dims[i], dims[i + 1], bias=False)
                if (dims[i] != dims[i + 1])
                else nn.Identity()
                for i in range(n_hidden)
            ]
        )
        self.acts = nn.ModuleList([activation() for _ in range(n_hidden)])

        self.return_features = return_features
        if not self.return_features:
            self.fc_out = nn.Linear(dims[-1], output_dim)

    def forward(self, x):
        """Apply the residual stages and, unless disabled, the output layer."""
        for fc, bn, res_fc, act in zip(self.fcs, self.bns, self.res_fcs, self.acts):
            x = act(bn(fc(x))) + res_fc(x)

        if self.return_features:
            return x

        return self.fc_out(x)

    def __repr__(self):
        return self.__class__.__name__
215 |
--------------------------------------------------------------------------------
/roost/wren/data.py:
--------------------------------------------------------------------------------
1 | import ast
2 | import functools
3 | import json
4 | import os
5 | import re
6 | from itertools import groupby
7 |
8 | import numpy as np
9 | import pandas as pd
10 | import torch
11 | from torch.utils.data import Dataset
12 |
13 | from roost.wren.utils import mult_dict, relab_dict
14 |
15 |
class WyckoffData(Dataset):
    """
    The WyckoffData dataset is a wrapper for a CSV dataset whose data points
    are automatically constructed from Wyckoff representation strings.
    """

    def __init__(
        self,
        data_path,
        sym_path,
        fea_path,
        task_dict,
        inputs=None,
        identifiers=None,
    ):
        """
        Args:
            data_path (str): Path to the CSV file holding the dataset.
            sym_path (str): Path to a JSON file with a two-level mapping of
                symmetry features (indexed as [spg_no][wyckoff] below).
            fea_path (str): Path to a JSON file mapping element symbols to
                feature vectors.
            task_dict (dict): Maps each target column name to either
                "regression" or "classification".
            inputs (sequence of str, optional): The single input column.
                Defaults to ["wyckoff"].
            identifiers (sequence of str, optional): Columns returned as
                identifiers with every sample (at least two).
                Defaults to ["material_id", "composition", "wyckoff"].
        """
        # None sentinels avoid sharing mutable default lists between
        # instances; list() keeps column selection list-based for pandas
        inputs = ["wyckoff"] if inputs is None else list(inputs)
        identifiers = (
            ["material_id", "composition", "wyckoff"]
            if identifiers is None
            else list(identifiers)
        )

        assert len(identifiers) >= 2, "Two identifiers are required"
        assert len(inputs) == 1, "One input column required"

        self.inputs = inputs
        self.task_dict = task_dict
        self.identifiers = identifiers

        assert os.path.exists(data_path), f"{data_path} does not exist!"
        # NOTE make sure to use dense datasets,
        # NOTE do not use default_na as "NaN" is a valid material composition
        self.df = pd.read_csv(data_path, keep_default_na=False, na_values=[])

        assert os.path.exists(fea_path), f"{fea_path} does not exist!"

        # TODO now using 2 level dicts so can't use featuriser, can this be standardised?
        with open(fea_path) as f:
            self.atom_features = json.load(f)

        assert os.path.exists(sym_path), f"{sym_path} does not exist!"
        with open(sym_path) as f:
            self.sym_features = json.load(f)

        # feature lengths are read off the first entry of each embedding
        self.elem_fea_dim = len(list(self.atom_features.values())[0])
        self.sym_fea_dim = len(list(list(self.sym_features.values())[0].values())[0])

        self.n_targets = []
        for target, task in self.task_dict.items():
            if task == "regression":
                self.n_targets.append(1)
            elif task == "classification":
                # class labels are taken to be 0 ... max
                n_classes = np.max(self.df[target].values) + 1
                self.n_targets.append(n_classes)

    def __len__(self):
        return len(self.df)

    # NOTE(review): the cache is keyed on (self, idx) and therefore keeps
    # the dataset instance alive for as long as the cache exists.
    @functools.lru_cache(maxsize=None)  # Cache loaded structures
    def __getitem__(self, idx):
        """Build the model inputs, targets and identifiers for one row.

        Args:
            idx (int): Positional row index into the dataframe.

        Raises:
            KeyError: If an element or (space group, wyckoff) pair has no
                entry in the loaded embeddings; the offending material is
                printed before re-raising.

        Returns:
            inputs: tuple of
                atom_weights: torch.Tensor shape (n_el, 1)
                    normalised weights of the elements/positions
                atom_fea: torch.Tensor shape (n_el, elem_fea_dim)
                    element features
                sym_fea: torch.Tensor, one row per wyckoff label of every
                    augmentation (presumably n_aug * n_el rows)
                self_fea_idx: torch.LongTensor shape (n_aug * n_el * n_el,)
                    self indices of the fully connected graphs
                nbr_fea_idx: torch.LongTensor shape (n_aug * n_el * n_el,)
                    neighbour indices of the fully connected graphs
            targets: list of torch.Tensor each of shape (1,)
                one target per entry of task_dict
            *cry_ids: the values of the identifier columns
        """
        df_idx = self.df.iloc[idx]
        # label-based lookup; avoids the deprecated positional fallback of
        # integer keys on a label-indexed Series
        swyks = df_idx[self.inputs[0]]
        cry_ids = df_idx[self.identifiers].values

        spg_no, weights, elements, aug_wyks = parse_aflow(swyks)
        weights = np.atleast_2d(weights).T / np.sum(weights)

        try:
            atom_fea = np.vstack([self.atom_features[el] for el in elements])
            sym_fea = np.vstack(
                [self.sym_features[spg_no][wyk] for wyks in aug_wyks for wyk in wyks]
            )
        except (AssertionError, KeyError):
            # plain dict lookups raise KeyError for unknown elements or
            # wyckoff positions, so it must be caught to report the sample
            print(f"failed to process {cry_ids[0]}: {cry_ids[1]}-{swyks}")
            raise

        # fully connected graph over the positions (self loops included)
        n_wyks = len(elements)
        self_fea_idx = [i for i in range(n_wyks) for _ in range(n_wyks)]
        nbr_fea_idx = [j for _ in range(n_wyks) for j in range(n_wyks)]

        # replicate the graph once per augmentation, offsetting the indices
        n_aug = len(aug_wyks)
        self_aug_fea_idx = [
            x + i * n_wyks for i in range(n_aug) for x in self_fea_idx
        ]
        nbr_aug_fea_idx = [
            x + i * n_wyks for i in range(n_aug) for x in nbr_fea_idx
        ]

        # convert all data to tensors
        atom_weights = torch.Tensor(weights)
        atom_fea = torch.Tensor(atom_fea)
        sym_fea = torch.Tensor(sym_fea)
        self_fea_idx = torch.LongTensor(self_aug_fea_idx)
        nbr_fea_idx = torch.LongTensor(nbr_aug_fea_idx)

        targets = []
        for name, task in self.task_dict.items():
            if task == "regression":
                targets.append(torch.Tensor([float(df_idx[name])]))
            elif task == "classification":
                targets.append(torch.LongTensor([int(df_idx[name])]))

        return (
            (atom_weights, atom_fea, sym_fea, self_fea_idx, nbr_fea_idx),
            targets,
            *cry_ids,
        )
158 |
159 |
def collate_batch(dataset_list):
    """Collate a list of samples into a single batch of disjoint graphs.

    Each sample may hold several augmentations of the same crystal; every
    augmentation becomes its own graph in the batch.

    Args:
        dataset_list (list of tuple): one tuple per data point of the form
            (inputs, target, *cry_ids) where inputs is

            atom_weights: torch.Tensor shape (n_el, 1)
            atom_fea: torch.Tensor shape (n_el, atom_fea_len)
            sym_fea: torch.Tensor shape (n_aug * n_el, sym_fea_len)
            self_fea_idx: torch.LongTensor of edge self indices
            nbr_fea_idx: torch.LongTensor of edge neighbour indices

            target is a list of tensors (one per task) and cry_ids are the
            identifiers of the data point.

    Returns:
        inputs: tuple of
            batch_atom_weights: weights tiled n_aug times per sample
            batch_atom_fea: element features tiled n_aug times per sample
            batch_sym_fea: concatenated symmetry features
            batch_self_fea_idx: edge self indices shifted into batch numbering
            batch_nbr_fea_idx: edge neighbour indices shifted likewise
            crystal_atom_idx: torch.LongTensor mapping each node to its
                augmented graph
            aug_cry_idx: torch.LongTensor mapping each augmented graph to
                the position of its sample in dataset_list
        targets: tuple with one stacked tensor per task
        *ids: one tuple per identifier field
    """
    # define the lists
    batch_atom_weights = []
    batch_atom_fea = []
    batch_sym_fea = []
    batch_self_fea_idx = []
    batch_nbr_fea_idx = []
    crystal_atom_idx = []
    aug_cry_idx = []
    batch_targets = []
    batch_cry_ids = []

    aug_count = 0  # number of augmented graphs batched so far
    cry_base_idx = 0  # number of nodes batched so far
    for i, (inputs, target, *cry_ids) in enumerate(dataset_list):
        atom_weights, atom_fea, sym_fea, self_fea_idx, nbr_fea_idx = inputs

        # nodes per graph, total nodes and number of augmentations
        n_el = atom_fea.shape[0]
        n_i = sym_fea.shape[0]
        n_aug = n_i // n_el

        # batch the features together
        batch_atom_weights.append(atom_weights.repeat((n_aug, 1)))
        batch_atom_fea.append(atom_fea.repeat((n_aug, 1)))
        batch_sym_fea.append(sym_fea)

        # mappings from bonds to atoms
        batch_self_fea_idx.append(self_fea_idx + cry_base_idx)
        batch_nbr_fea_idx.append(nbr_fea_idx + cry_base_idx)

        # mapping from atoms to (augmented) crystals
        crystal_atom_idx.append(
            torch.arange(aug_count, aug_count + n_aug).repeat_interleave(n_el)
        )
        # mapping from augmented crystals back to the original sample
        aug_cry_idx.append(torch.full((n_aug,), i, dtype=torch.long))

        # batch the targets and ids
        batch_targets.append(target)
        batch_cry_ids.append(cry_ids)

        # increment the id counters
        aug_count += n_aug
        cry_base_idx += n_i

    return (
        (
            torch.cat(batch_atom_weights, dim=0),
            torch.cat(batch_atom_fea, dim=0),
            torch.cat(batch_sym_fea, dim=0),
            torch.cat(batch_self_fea_idx, dim=0),
            torch.cat(batch_nbr_fea_idx, dim=0),
            torch.cat(crystal_atom_idx),
            torch.cat(aug_cry_idx),
        ),
        tuple(torch.stack(b_target, dim=0) for b_target in zip(*batch_targets)),
        *zip(*batch_cry_ids),
    )
247 |
248 |
def parse_wren(swyk_list):
    """Parse the wren wyckoff format.

    Args:
        swyk_list (str): String form of a Python list whose entries look
            like "<element>-<multiplicity> @ <letter>-<spg_no>" (this is
            the structure implied by the splits performed below).

    Returns:
        spg_no (str): space group number taken from the first entry
        mult_list (list of float): multiplicity of each entry
        ele_list (list of str): element of each entry
        aug_wyks (list of tuple of str): the distinct relabellings of the
            wyckoff letters given by relab_dict[spg_no], in the order in
            which relab_dict first produces them
    """
    swyk_list = ast.literal_eval(swyk_list)

    mult_list = []
    ele_list = []
    wyk_list = []

    spg_no = swyk_list[0].split("-")[-1]

    for swyk in swyk_list:
        ele_mult, wyk = swyk.split(" @ ")
        ele, mult = ele_mult.split("-")
        let, _ = wyk.split("-")

        mult_list.append(float(mult))
        ele_list.append(ele)
        wyk_list.append(let)

    # translate all letters at once, then split back into one per position
    joined_wyks = ",".join(wyk_list)
    aug_wyks = [
        tuple(joined_wyks.translate(str.maketrans(trans)).split(","))
        for trans in relab_dict[spg_no]
    ]

    # drop duplicates while keeping first-seen order; a set would order the
    # augmentations by randomised string hashes, differing between runs
    aug_wyks = list(dict.fromkeys(aug_wyks))

    return spg_no, mult_list, ele_list, aug_wyks
289 |
290 |
291 | def parse_aflow(aflow_label):
292 | """parse the wyckoff format
293 |
294 | Args:
295 | swyk_list ([type]): [description]
296 | relab_dict ([type]): [description]
297 |
298 | Returns:
299 | mult_list, ele_list, aug_wyks
300 | """
301 | proto, chemsys = aflow_label.split(":")
302 | elems = chemsys.split("-")
303 | _, _, spg_no, *wyks = proto.split("_")
304 |
305 | mult_list = []
306 | ele_list = []
307 | wyk_list = []
308 |
309 | subst = r"1\g<1>"
310 | for el, wyk in zip(elems, wyks):
311 | wyk = re.sub(r"((? str:
58 | """get aflow prototype label for pymatgen structure
59 |
60 | args:
61 | struct (Structure): pymatgen structure object
62 |
63 | returns:
64 | aflow prototype labels
65 | """
66 |
67 | poscar = Poscar(struct)
68 |
69 | cmd = f"{aflow_executable} --prototype --print=json cat"
70 |
71 | output = subprocess.run(
72 | cmd, input=poscar.get_string(), text=True, capture_output=True, shell=True
73 | )
74 |
75 | aflow_proto = json.loads(output.stdout)
76 |
77 | aflow_label = aflow_proto["aflow_label"]
78 |
79 | aflow_label = aflow_label.replace(
80 | "alpha", "A"
81 | ) # to be consistent with spglib and wren embeddings
82 |
83 | # check that multiplicities satisfy original composition
84 | symm = aflow_label.split("_")
85 | spg_no = symm[2]
86 | wyks = symm[3:]
87 | elems = poscar.site_symbols
88 | elem_dict = {}
89 | subst = r"1\g<1>"
90 | for el, wyk in zip(elems, wyks):
91 | wyk = re.sub(r"((? str:
108 | """get aflow prototype label for pymatgen structure
109 |
110 | args:
111 | struct (Structure): pymatgen structure object
112 |
113 | returns:
114 | aflow prototype labels
115 | """
116 | spga = SpacegroupAnalyzer(struct, symprec=0.1, angle_tolerance=5)
117 | aflow = get_aflow_label_from_spga(spga)
118 |
119 | # try again with refined structure if it initially fails
120 | # NOTE structures with magmoms fail unless all have same magmom
121 | if "Invalid" in aflow:
122 | spga = SpacegroupAnalyzer(
123 | spga.get_refined_structure(), symprec=1e-5, angle_tolerance=-1
124 | )
125 | aflow = get_aflow_label_from_spga(spga)
126 |
127 | return aflow
128 |
129 |
def get_aflow_label_from_spga(spga):
    """Build an aflow-style prototype label from a symmetry analyser.

    args:
        spga: symmetry analyser object; the sibling caller in this file
            passes a SpacegroupAnalyzer

    returns:
        the label "<formula>_<pearson>_<spg_no>_<wyckoffs>:<chemsys>", or
        that label prefixed with "Invalid WP Multiplicities - " when the
        wyckoff multiplicities do not reproduce the reduced formula
    """
    spg_no = spga.get_space_group_number()
    sym_struct = spga.get_symmetrized_structure()

    # one (site count, species, digit-free wyckoff letter) record per orbit
    orbits = [
        (len(sites), sites[0].species_string, wyk.translate(remove_digits))
        for sites, wyk in zip(
            sym_struct.equivalent_sites, sym_struct.wyckoff_symbols
        )
    ]
    # groupby below only merges adjacent records, so sort by
    # (element, wyckoff letter) first
    orbits.sort(key=lambda rec: (rec[1], rec[2]))

    # check that multiplicities satisfy original composition
    elem_dict = {}
    per_elem_wyks = []
    for el, el_group in groupby(orbits, key=lambda rec: rec[1]):
        el_group = list(el_group)
        elem_dict[el] = sum(
            float(mult_dict[str(spg_no)][rec[2]]) for rec in el_group
        )
        # "<count><letter>" for every wyckoff letter this element occupies
        per_elem_wyks.append(
            "".join(
                f"{len(list(same_letter))}{letter}"
                for letter, same_letter in groupby(el_group, key=lambda rec: rec[2])
            )
        )

    # canonicalise the possible wyckoff letter sequences
    canonical = canonicalise_elem_wyks("_".join(per_elem_wyks), spg_no)

    # get pearson symbol
    cry_sys = spga.get_crystal_system()
    spg_sym = spga.get_space_group_symbol()
    first = spg_sym[0]
    centering = "C" if first in ("A", "B", "C", "S") else first
    n_conv = len(spga._space_group_data["std_types"])
    pearson = f"{cry_sys_dict[cry_sys]}{centering}{n_conv}"

    composition = spga._structure.composition
    prototype_form = prototype_formula(composition)

    aflow_label = (
        f"{prototype_form}_{pearson}_{spg_no}_{canonical}:"
        f"{composition.chemical_system}"
    )

    eqi_comp = Composition(elem_dict)
    if eqi_comp.reduced_formula != composition.reduced_formula:
        return f"Invalid WP Multiplicities - {aflow_label}"

    return aflow_label
177 |
178 |
def canonicalise_elem_wyks(elem_wyks, spg_no):
    """
    Given an element ordering canonicalise the associated wyckoff positions
    based on the alphabetical weight of equivalent choices of origin.

    Every relabelling in relab_dict[str(spg_no)] is applied to elem_wyks;
    each candidate is scored by summing the alphabet positions of its
    letters ("A" counts 0) and the candidate with the lowest
    (score, string) is returned.
    """
    # distinct relabelled versions of the input
    isopointal = {
        elem_wyks.translate(str.maketrans(trans))
        for trans in relab_dict[str(spg_no)]
    }

    candidates = []
    for wyks in isopointal:
        score = 0
        per_elem = []
        for el_wyks in wyks.split("_"):
            # alternate runs of digits and letters: count, letter, count, ...
            tokens = ["".join(run) for _, run in groupby(el_wyks, str.isalpha)]
            # a count of one is left implicit
            tokens = ["" if tok == "1" else tok for tok in tokens]
            counts = tokens[0::2]
            letters = tokens[1::2]

            by_letter = sorted(zip(counts, letters), key=lambda pair: pair[1])
            per_elem.append("".join(f"{n}{w}" for n, w in by_letter))
            score += sum(0 if w == "A" else ord(w) - 96 for w in letters)

        candidates.append((score, "_".join(per_elem)))

    # tuples order by score first and then by the string itself
    return min(candidates)[1]
220 |
221 |
def prototype_formula(composition) -> str:
    """
    Return the anonymized formula of a composition. Unique species are
    arranged in alphabetical order and assigned ascending uppercase letters.
    This format is used in the aflow structure prototype labelling scheme.
    """
    reduced = composition.element_composition
    if all(x == int(x) for x in composition.values()):
        # integer amounts: divide out their greatest common divisor
        divisor = gcd(*(int(i) for i in composition.values()))
        reduced /= divisor

    # amounts ordered alphabetically by the string form of their species
    ordered = sorted(reduced.items(), key=lambda item: str(item[0]))

    pieces = []
    for letter, (_, amt) in zip(ascii_uppercase, ordered):
        if amt == 1:
            suffix = ""  # an amount of one is left implicit
        elif abs(amt % 1) < 1e-8:
            suffix = str(int(amt))
        else:
            suffix = str(amt)
        pieces.append(f"{letter}{suffix}")

    return "".join(pieces)
244 |
245 |
246 | def count_wyks(aflow_label):
247 | num_wyk = 0
248 |
249 | aflow_label, _ = aflow_label.split(":")
250 | wyks = aflow_label.split("_")[3:]
251 |
252 | subst = r"1\g<1>"
253 | for wyk in wyks:
254 | wyk = re.sub(r"((?"
274 | for wyk in wyks:
275 | wyk = re.sub(r"((? 0.75
153 | assert ens_roc_auc > 0.9
154 |
155 |
# Allow running this test module directly as a script, without pytest.
if __name__ == "__main__":
    test_single_roost_clf()
158 |
--------------------------------------------------------------------------------
/tests/test_single_roost_regression.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import numpy as np
4 | import torch
5 | from sklearn.metrics import r2_score
6 | from sklearn.model_selection import train_test_split as split
7 |
8 | from roost.roost.data import CompositionData, collate_batch
9 | from roost.roost.model import Roost
10 | from roost.utils import results_multitask, train_ensemble
11 |
12 | torch.manual_seed(0) # ensure reproducible results
13 |
14 |
def test_single_roost():
    """Train a small Roost ensemble on the regression fixture and check it.

    An ensemble of two models is trained for 25 epochs on 80% of
    tests/data/roost-regression.csv and evaluated on the remaining 20%.
    The ensemble-mean prediction must reach r2 > 0.7, MAE < 0.6 and
    RMSE < 0.8.
    """
    # ---- experiment configuration ----
    data_path = "tests/data/roost-regression.csv"
    fea_path = "data/el-embeddings/matscholar-embedding.json"
    target_name = "Eg"
    task_dict = {target_name: "regression"}
    loss_dict = {target_name: "L1"}
    robust = True
    model_name = "roost"
    ensemble = 2
    run_id = 1
    data_seed = 42
    epochs = 25
    sample = 1
    test_size = 0.2
    batch_size = 128
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # ---- data ----
    dataset = CompositionData(
        data_path=data_path, fea_path=fea_path, task_dict=task_dict
    )

    all_idx = list(range(len(dataset)))

    print(f"using {test_size} of training set as test set")
    train_idx, test_idx = split(
        all_idx, random_state=data_seed, test_size=test_size
    )
    test_set = torch.utils.data.Subset(dataset, test_idx)

    print("No validation set used, using test set for evaluation purposes")
    # NOTE that when using this option care must be taken not to
    # peek at the test-set. The only valid model to use is the one
    # obtained after the final epoch where the epoch count is
    # decided in advance of the experiment.
    val_set = test_set

    train_set = torch.utils.data.Subset(dataset, train_idx[0::sample])

    # ---- parameter bundles handed to the training utilities ----
    data_params = {
        "batch_size": batch_size,
        "num_workers": 0,
        "pin_memory": False,
        "shuffle": True,
        "collate_fn": collate_batch,
    }

    setup_params = {
        "optim": "AdamW",
        "learning_rate": 3e-4,
        "weight_decay": 1e-6,
        "momentum": 0.9,
        "device": device,
    }

    restart_params = {
        "resume": False,
        "fine_tune": None,
        "transfer": None,
    }

    model_params = {
        "task_dict": task_dict,
        "robust": robust,
        "n_targets": dataset.n_targets,
        "elem_emb_len": dataset.elem_emb_len,
        "elem_fea_len": 64,
        "n_graph": 3,
        "elem_heads": 3,
        "elem_gate": [256],
        "elem_msg": [256],
        "cry_heads": 3,
        "cry_gate": [256],
        "cry_msg": [256],
        "trunk_hidden": [1024, 512],
        "out_hidden": [256, 128, 64],
    }

    for folder in ("models", "results"):
        os.makedirs(f"{folder}/{model_name}", exist_ok=True)

    # ---- training ----
    train_ensemble(
        model_class=Roost,
        model_name=model_name,
        run_id=run_id,
        ensemble_folds=ensemble,
        epochs=epochs,
        train_set=train_set,
        val_set=val_set,
        log=False,
        data_params=data_params,
        setup_params=setup_params,
        restart_params=restart_params,
        model_params=model_params,
        loss_dict=loss_dict,
    )

    # ---- evaluation ----
    data_params["batch_size"] = 64 * batch_size  # faster model inference
    data_params["shuffle"] = False  # need fixed data order due to ensembling

    results_dict = results_multitask(
        model_class=Roost,
        model_name=model_name,
        run_id=run_id,
        ensemble_folds=ensemble,
        test_set=test_set,
        data_params=data_params,
        robust=robust,
        task_dict=task_dict,
        device=device,
        eval_type="checkpoint",
    )

    pred = results_dict[target_name]["pred"]
    target = results_dict[target_name]["target"]

    # ensemble prediction is the mean over the individual models
    y_ens = np.mean(pred, axis=0)

    residual = target - y_ens
    mae = np.abs(residual).mean()
    rmse = np.sqrt(np.square(residual).mean())
    r2 = r2_score(target, y_ens)

    assert r2 > 0.7
    assert mae < 0.6
    assert rmse < 0.8
153 |
154 |
# Allow running this test module directly as a script, without pytest.
if __name__ == "__main__":
    test_single_roost()
157 |
--------------------------------------------------------------------------------