├── .gitignore
├── LICENSE
├── README.md
├── configs
│   ├── example
│   │   ├── example_cvae_trainer.yaml
│   │   ├── example_extract_reconstruct.yaml
│   │   ├── example_ff_trainer.yaml
│   │   └── example_generator.yaml
│   └── unittest
│       └── unittest_behler.yaml
├── data
│   ├── raw
│   │   └── data.db
│   ├── representation
│   │   ├── data_x_fps_3.csv
│   │   ├── data_y_fps_3.csv
│   │   ├── x_test.csv
│   │   └── y_test.csv
│   ├── saved_models
│   │   ├── model_FF_saved.pt
│   │   ├── model_saved.pt
│   │   └── scaler.gz
│   └── unittest
│       └── dummy.txt
├── examples
│   ├── example_cvae_trainer.py
│   ├── example_extract.py
│   ├── example_generator.py
│   ├── example_reconstruction.py
│   └── example_surrogate_trainer.py
├── misc.py
├── outputs
│   ├── generated_samples
│   │   ├── gen_samps_cvae_1.9.csv
│   │   ├── gen_samps_cvae_2.2.csv
│   │   ├── gen_samps_cvae_2.5.csv
│   │   ├── gen_samps_cvae_2.8.csv
│   │   ├── gen_samps_cvae_3.1.csv
│   │   ├── gen_samps_cvae_3.4.csv
│   │   ├── gen_samps_cvae_3.7.csv
│   │   ├── gen_samps_x_1.9.csv
│   │   ├── gen_samps_x_2.2.csv
│   │   ├── gen_samps_x_2.5.csv
│   │   ├── gen_samps_x_2.8.csv
│   │   ├── gen_samps_x_3.1.csv
│   │   ├── gen_samps_x_3.4.csv
│   │   └── gen_samps_x_3.7.csv
│   └── reconstructed
│       ├── 0_reconstructed.cif
│       ├── 1_reconstructed.cif
│       ├── 2_reconstructed.cif
│       └── 3_reconstructed.cif
├── quickrun.sh
├── requirements.txt
├── setup.py
├── structrepgen
│   ├── __init__.py
│   ├── descriptors
│   │   ├── __init__.py
│   │   ├── behler.py
│   │   └── generic.py
│   ├── extraction
│   │   ├── __init__.py
│   │   └── representation_extraction.py
│   ├── generators
│   │   ├── __init__.py
│   │   └── generator.py
│   ├── models
│   │   ├── CVAE.py
│   │   ├── __init__.py
│   │   └── models.py
│   ├── reconstruction
│   │   ├── __init__.py
│   │   ├── generic.py
│   │   └── reconstruction.py
│   └── utils
│       ├── __init__.py
│       ├── dotdict.py
│       └── utils.py
└── tests
    ├── original
    │   ├── db_to_data.py
    │   └── original_descriptor.py
    └── test_behler.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
131 | # mac
132 | .DS_Store
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Fung Group
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Structure Representation Generation
2 |
3 |
4 |
5 | **Atomic Structure Generation from Reconstructing Structural Fingerprints**
6 |
7 | [[arXiv](https://arxiv.org/abs/2207.13227)]
8 |
9 |
10 |
11 | ## Requirements
12 |
13 | Git clone this repo and, from the main directory, run
14 |
15 | ```bash
16 | pip install -r requirements.txt
17 | pip install -e .
18 | ```
19 |
20 | The package has been tested on
21 |
22 | - Python 3.6.10 / 3.9.12
23 | - PyTorch 1.10.2 / 1.12.0
24 | - CUDA 11.3
25 |
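Note that `requirements.txt` only pins `ase` and `scikit-optimize`. The example scripts additionally import `torch`, `numpy`, `pandas`, `scikit-learn`, `PyYAML`, and `joblib` (see `examples/`), so install those as well if they are not already in your environment, for instance:

```bash
# unpinned; the Python / PyTorch combinations listed above are known to work
pip install torch numpy pandas scikit-learn pyyaml joblib
```
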
26 | ## How to Use
27 |
28 | Most configurations of the system/code are done through `.yaml` files under `configs/`.
29 |
30 | To load a configuration:
31 |
32 | ```python
33 | # load our config yaml file
34 | stream = open('./configs/example/example_extract_reconstruct.yaml')
35 | CONFIG = yaml.safe_load(stream)
36 | stream.close()
37 | # dotdict for dot operations on Python dict
38 | # e.g., CONFIG.cutoff == CONFIG['cutoff']
39 | CONFIG = dotdict(CONFIG)
40 | ```
41 |
42 | ### Representation Extraction
43 |
44 | Extract representations for training the generative model using the selected descriptor.
45 |
46 | See [`examples/example_extract.py`](https://github.com/Fung-Lab/StructRepGen/blob/main/examples/example_extract.py)
47 |
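In a nutshell (condensed from `examples/example_extract.py`; paths assume you run from the repo root):

```python
import yaml
from structrepgen.extraction.representation_extraction import RepresentationExtraction
from structrepgen.utils.dotdict import dotdict

# descriptor settings, raw-data path, and output CSV paths all come from the config
with open('./configs/example/example_extract_reconstruct.yaml') as stream:
    CONFIG = dotdict(yaml.safe_load(stream))

extractor = RepresentationExtraction(CONFIG)
extractor.extract()  # writes the x_fname / y_fname CSVs under data/representation/
```
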
48 | ### Generative Model
49 |
50 | Train a CVAE (conditional variational auto-encoder) as the generative model.
51 |
52 | See [`examples/example_cvae_trainer.py`](https://github.com/Fung-Lab/StructRepGen/blob/main/examples/example_cvae_trainer.py).
53 |
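The example wraps data loading, the CVAE, and the training loop in a `Trainer` class defined in the script itself; the essential call pattern (condensed from `examples/example_cvae_trainer.py`) is:

```python
import yaml
from structrepgen.utils.dotdict import dotdict
# Trainer is defined in the example script; run from the repo root with examples/
# on PYTHONPATH, or simply execute `python examples/example_cvae_trainer.py`
from example_cvae_trainer import Trainer

with open('./configs/example/example_cvae_trainer.yaml') as stream:
    CONFIG = dotdict(yaml.safe_load(stream))

trainer = Trainer(CONFIG)  # builds the dataloaders, CVAE, optimizer, and LR scheduler
best_epoch, best_train_loss, best_test_loss = trainer.run()  # best model -> CONFIG.model_path
```
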
54 | ### Generation
55 |
56 | Generate representations from a given target value using the decoder part of the CVAE.
57 |
58 | See [`examples/example_generator.py`](https://github.com/Fung-Lab/StructRepGen/blob/main/examples/example_generator.py).
59 |
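Condensed from `examples/example_generator.py` (the trained CVAE, scaler, and FF surrogate are loaded from the paths given in the config):

```python
import yaml
from structrepgen.generators.generator import Generator
from structrepgen.utils.dotdict import dotdict

with open('./configs/example/example_generator.yaml') as stream:
    CONFIG = dotdict(yaml.safe_load(stream))

gen = Generator(CONFIG)
gen.generate()     # sample representations for each value in CONFIG.params.targets
gen.range_check()
# generated samples are written as CSVs under CONFIG.params.save_path
```
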
60 | ### Reconstruction
61 |
62 | Generate atomic structures from the generated representations.
63 |
64 | See [`examples/example_reconstruction.py`](https://github.com/Fung-Lab/StructRepGen/blob/main/examples/example_reconstruction.py).
65 |
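Condensed from `examples/example_reconstruction.py` (reconstruction shares its config with extraction):

```python
import yaml
from structrepgen.reconstruction.reconstruction import Reconstruction
from structrepgen.utils.dotdict import dotdict

with open('./configs/example/example_extract_reconstruct.yaml') as stream:
    CONFIG = dotdict(yaml.safe_load(stream))

constructor = Reconstruction(CONFIG)
# reads generated samples from CONFIG.structure_file_path and writes .cif files
# to CONFIG.reconstructed_file_path (outputs/reconstructed/ in this example)
constructor.main()
```
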
66 | A script to run all of the above examples is provided in `quickrun.sh`. To execute, do
67 |
68 | ```bash
69 | chmod u+x quickrun.sh
70 | ./quickrun.sh
71 | ```
72 | ______________________________________________________________________
73 |
74 | ## Citation
75 |
76 | ```bibtex
77 | @article{fung2022atomic,
78 | title={Atomic structure generation from reconstructing structural fingerprints},
79 | author={Fung, Victor and Jia, Shuyi and Zhang, Jiaxin and Bi, Sirui and Yin, Junqi and Ganesh, P},
80 | journal={arXiv preprint arXiv:2207.13227},
81 | year={2022}
82 | }
83 | ```
--------------------------------------------------------------------------------
/configs/example/example_cvae_trainer.yaml:
--------------------------------------------------------------------------------
1 | gpu: False
2 | params:
3 | seed: 42
4 | split_ratio: 0.2
5 | input_dim: 380
6 | hidden_dim: 128
7 | latent_dim: 10
8 | hidden_layers: 3
9 | y_dim: 1
10 | batch_size: 64
11 | n_epochs: 1000
12 | lr: 0.001
13 | final_decay: 0.2
14 | weight_decay: 0.001
15 | verbosity: 50
16 | kl_weight: 2.5
17 | mc_kl_loss: False
18 | data_x_path: "./data/representation/data_x_fps_3.csv"
19 | data_y_path: "./data/representation/data_y_fps_3.csv"
20 | model_path: "./data/saved_models/model_saved.pt"
21 | scaler_path: "./data/saved_models/scaler.gz"
--------------------------------------------------------------------------------
/configs/example/example_extract_reconstruct.yaml:
--------------------------------------------------------------------------------
1 | ########################################################################
2 | # Parameters and configurations for both extraction and reconstruction #
3 | ########################################################################
4 |
5 | # PyTorch device
6 | gpu: False
7 |
8 | # chemical system
9 | cell: [[20.,0.,0.],[0.,20.,0.],[0.,0.,20.]]
10 | atoms: [78, 78, 78, 78, 78, 78, 78, 78, 78, 78]
11 |
12 | # descriptor configuration
13 | descriptor: "behler"
14 | cutoff: 20.0
15 | mode: "features"
16 | g5: True
17 | average: True
18 | g2_params:
19 | eta: [0.01, 0.06, 0.1, 0.2, 0.4, 0.7, 1.0, 2.0, 3.5, 5.0]
20 | Rs: [0, 1, 1.5, 2, 2.5, 3, 3.5, 4, 5, 6, 7, 8, 9, 10]
21 | g5_params:
22 | lambdas: [-1, 1]
23 | zeta: [1, 2, 4, 16, 64]
24 | eta: [0.06, 0.1, 0.2, 0.4, 1.0]
25 | Rs: [0]
26 |
27 | # raw data file path
28 | data: "./data/raw/data.db"
29 |
30 | # extracted representation save file path
31 | x_fname: "./data/representation/x_test.csv"
32 | y_fname: "./data/representation/y_test.csv"
33 |
34 | # ff_net surrogate model path
35 | ff_model_path: "./data/saved_models/model_FF_saved.pt"
36 |
37 | # generated sample file path
38 | structure_file_path: "./outputs/generated_samples/gen_samps_x_2.2.csv"
39 |
40 | # reconstruction save file path
41 | reconstructed_file_path: "./outputs/reconstructed/"
--------------------------------------------------------------------------------
/configs/example/example_ff_trainer.yaml:
--------------------------------------------------------------------------------
1 | gpu: False
2 | params:
3 | seed: 42
4 | split_ratio: 0.2
5 | input_dim: 380
6 | hidden_dim: 128
7 | latent_dim: 10
8 | hidden_layers: 3
9 | y_dim: 1
10 | batch_size: 64
11 | n_epochs: 1000
12 | lr: 0.001
13 | final_decay: 0.2
14 | weight_decay: 0.001
15 | verbosity: 50
16 | kl_weight: 2.5
17 | mc_kl_loss: False
18 | data_x_path: "./data/representation/data_x_fps_3.csv"
19 | data_y_path: "./data/representation/data_y_fps_3.csv"
20 | model_path: "./data/saved_models/model_saved.pt"
21 | scaler_path: "./data/saved_models/scaler.gz"
22 | ff_path: "./data/saved_models/model_FF_saved.pt"
--------------------------------------------------------------------------------
/configs/example/example_generator.yaml:
--------------------------------------------------------------------------------
1 | gpu: False
2 | params:
3 | input_dim: 380
4 | latent_dim: 10
5 | num_z: 1000
6 | targets: [1.9, 2.2, 2.5, 2.8, 3.1, 3.4, 3.7]
7 | delta: 0.1
8 | model_path: "./data/saved_models/model_saved.pt"
9 | scaler_path: "./data/saved_models/scaler.gz"
10 | ff_model_path: "./data/saved_models/model_FF_saved.pt"
11 | save_path: "./outputs/generated_samples/"
--------------------------------------------------------------------------------
/configs/unittest/unittest_behler.yaml:
--------------------------------------------------------------------------------
1 | descriptor: "behler"
2 | gpu: False
3 | cutoff: 20.0
4 | mode: "features"
5 | g5: True
6 | average: True
7 | g2_params:
8 | eta: [0.01, 0.06, 0.1, 0.2, 0.4, 0.7, 1.0, 2.0, 3.5, 5.0]
9 | Rs: [0, 1, 1.5, 2, 2.5, 3, 3.5, 4, 5, 6, 7, 8, 9, 10]
10 | g5_params:
11 | lambdas: [-1, 1]
12 | zeta: [1, 2, 4, 16, 64]
13 | eta: [0.06, 0.1, 0.2, 0.4, 1.0]
14 | Rs: [0]
15 | data: "./data/raw/data.db"
16 | x_fname: "./data/unittest/unittest_behler_x.csv"
17 | y_fname: "./data/unittest/unittest_behler_y.csv"
--------------------------------------------------------------------------------
/data/raw/data.db:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Fung-Lab/StructRepGen/c000acddbb450e2a4c0a81d09a6dfb07fc675a7d/data/raw/data.db
--------------------------------------------------------------------------------
/data/representation/data_y_fps_3.csv:
--------------------------------------------------------------------------------
1 | 3.629251785500000160e+00
2 | 2.193435603500000219e+00
3 | 2.389190833499999833e+00
4 | 2.590185373500000221e+00
5 | 3.469791547500000295e+00
6 | 3.639065514499999932e+00
7 | 3.031497834500000099e+00
8 | 3.278771844500000032e+00
9 | 3.400559700500000115e+00
10 | 2.569046109499999453e+00
11 | 2.572176612499999404e+00
12 | 1.977554506499999531e+00
13 | 2.726371039500000037e+00
14 | 3.378239648500000136e+00
15 | 1.969304244499999523e+00
16 | 2.482163833500000027e+00
17 | 3.469053469499999487e+00
18 | 2.161180478499999502e+00
19 | 2.427411639499999829e+00
20 | 3.511905419500000125e+00
21 | 2.346946486499999818e+00
22 | 2.863972432500000220e+00
23 | 2.195669181500000011e+00
24 | 3.263429192499999854e+00
25 | 3.634469329499999901e+00
26 | 2.869667623499999820e+00
27 | 3.505030466500000053e+00
28 | 3.201020360500000272e+00
29 | 3.223487879499999931e+00
30 | 2.754639036499999971e+00
31 | 3.446267526499999789e+00
32 | 3.330566205499999821e+00
33 | 3.447257610500000347e+00
34 | 2.853754735499999917e+00
35 | 2.458017109500000075e+00
36 | 3.407018959500000221e+00
37 | 2.034932231500000022e+00
38 | 2.938345457500000091e+00
39 | 3.432077122499999966e+00
40 | 3.260302140499999890e+00
41 | 2.664259684500000169e+00
42 | 2.845696269500000319e+00
43 | 3.068155024500000216e+00
44 | 3.001935866500000216e+00
45 | 1.939045304500000011e+00
46 | 2.059636720499999907e+00
47 | 2.773203739499999987e+00
48 | 2.077227622500000148e+00
49 | 2.370573837499999392e+00
50 | 2.916555953500000076e+00
51 | 3.327417127499999960e+00
52 | 2.841348202499999864e+00
53 | 3.011263771499999908e+00
54 | 2.900439509499999957e+00
55 | 3.488860160499999807e+00
56 | 1.978455072500000078e+00
57 | 2.740035258499999848e+00
58 | 2.553865472500000067e+00
59 | 3.658810885499999888e+00
60 | 2.449129288500000001e+00
61 | 3.219105184500000050e+00
62 | 2.891397128499999969e+00
63 | 3.598062906500000047e+00
64 | 2.725029893499999467e+00
65 | 1.908753786499999938e+00
66 | 2.170524436500000043e+00
67 | 3.667390165500000077e+00
68 | 3.153143658499999891e+00
69 | 3.381735946500000090e+00
70 | 1.981137564500000003e+00
71 | 2.292985748500000032e+00
72 | 2.857798630500000048e+00
73 | 3.370462267500000220e+00
74 | 3.482201046500000174e+00
75 | 3.188101787499999951e+00
76 | 3.583713149499999417e+00
77 | 3.568160926500000052e+00
78 | 3.579448063500000110e+00
79 | 2.934886240500000021e+00
80 | 3.577668529499999472e+00
81 | 3.521518180500000206e+00
82 | 2.892856812499999819e+00
83 | 3.379746976500000333e+00
84 | 3.548381890499999969e+00
85 | 2.467273743499999394e+00
86 | 3.540359935500000166e+00
87 | 2.768394223500000084e+00
88 | 3.678645528500000150e+00
89 | 2.943328566499999965e+00
90 | 3.560178632499999996e+00
91 | 3.078999631499999889e+00
92 | 3.315472485499999955e+00
93 | 2.514673767500000157e+00
94 | 2.053671849500000146e+00
95 | 3.397123429500000125e+00
96 | 1.927875391499999536e+00
97 | 3.254700554499999843e+00
98 | 2.855225439499999851e+00
99 | 3.289567590499999916e+00
100 | 3.080570113499999874e+00
101 | 3.149139033500000018e+00
102 | 3.095625262500000030e+00
103 | 3.606596311500000152e+00
104 | 3.309686274500000192e+00
105 | 3.532479607500000007e+00
106 | 3.335945054499999785e+00
107 | 2.680776439500000219e+00
108 | 3.413721247499999834e+00
109 | 3.230590477500000279e+00
110 | 3.081152116500000204e+00
111 | 2.475544047500000122e+00
112 | 3.687523089499999962e+00
113 | 3.425417975499999379e+00
114 | 3.668337117500000133e+00
115 | 2.644321825500000056e+00
116 | 2.321161881500000135e+00
117 | 2.670165391500000318e+00
118 | 2.509617504499999541e+00
119 | 3.429835049499999400e+00
120 | 2.686018230499999770e+00
121 | 2.274859405499999987e+00
122 | 2.875813564500000030e+00
123 | 2.345672777499999917e+00
124 | 3.149382091500000147e+00
125 | 3.563936150500000011e+00
126 | 3.606921896499999836e+00
127 | 2.239281146500000208e+00
128 | 2.885017470499999792e+00
129 | 2.843746938499999821e+00
130 | 2.151624935500000113e+00
131 | 3.695640772500000004e+00
132 | 2.604716774499999499e+00
133 | 3.533703536500000020e+00
134 | 3.618628567499999171e+00
135 | 2.269472805499999968e+00
136 | 3.094148476500000022e+00
137 | 3.344097396500000041e+00
138 | 3.370727902500000095e+00
139 | 3.695081813499999868e+00
140 | 2.247777477500000121e+00
141 | 2.612466452500000091e+00
142 | 2.596694727500000077e+00
143 | 2.866961518500000139e+00
144 | 3.140304035499999813e+00
145 | 2.724920338500000039e+00
146 | 3.096131675500000124e+00
147 | 2.899930516500000000e+00
148 | 3.690274181499999973e+00
149 | 3.036688120500000032e+00
150 | 3.590003534500000093e+00
151 | 1.916972713499999648e+00
152 | 3.694017793500000035e+00
153 | 3.278344651499999429e+00
154 | 3.311863587499999539e+00
155 | 3.440875531499999695e+00
156 | 2.298078631500000135e+00
157 | 2.154455966500000041e+00
158 | 3.668343183499999327e+00
159 | 3.189560615500000029e+00
160 | 2.954388168499999967e+00
161 | 3.158917279499999786e+00
162 | 3.049848474500000073e+00
163 | 3.121720442500000026e+00
164 | 3.559216489499999803e+00
165 | 2.394893494500000219e+00
166 | 2.870004901500000205e+00
167 | 3.222248899499999819e+00
168 | 3.021361172499999803e+00
169 | 2.997914102499999789e+00
170 | 2.152280122500000115e+00
171 | 3.691511492500000102e+00
172 | 3.534808753500000122e+00
173 | 2.441883251500000185e+00
174 | 3.158438142500000101e+00
175 | 2.714237651499999515e+00
176 | 2.361495079500000038e+00
177 | 3.492186637499999247e+00
178 | 2.444909780500000185e+00
179 | 2.245195300499999824e+00
180 | 3.564829119500000143e+00
181 | 3.422282724500000040e+00
182 | 2.413833691499999823e+00
183 | 3.553694700500000359e+00
184 | 3.133653180500000079e+00
185 | 2.688062402499999504e+00
186 | 3.029967856499999890e+00
187 | 3.699558817499999819e+00
188 | 3.645472044499999953e+00
189 | 2.924037525500000179e+00
190 | 2.707218354500000146e+00
191 | 3.467364557499999389e+00
192 | 3.181638873499999853e+00
193 | 1.991422050499999541e+00
194 | 3.667853380499999982e+00
195 | 3.095820271500000054e+00
196 | 3.568928516500000203e+00
197 | 3.188936304499999874e+00
198 | 3.092066244499999783e+00
199 | 3.420078662499999922e+00
200 | 2.782251731500000158e+00
201 | 3.340617010499999928e+00
202 | 3.609050915500000123e+00
203 | 3.643656063499999931e+00
204 | 3.336802336499999910e+00
205 | 3.369463562500000009e+00
206 | 3.638709875500000024e+00
207 | 2.737987797500000209e+00
208 | 3.662035683500000083e+00
209 | 2.442745851499999787e+00
210 | 2.794192333499999847e+00
211 | 2.130696768500000005e+00
212 | 3.035902088499999874e+00
213 | 2.206589475499999509e+00
214 | 3.474328342500000222e+00
215 | 2.627364993500000079e+00
216 | 3.423192286499999959e+00
217 | 3.219988001499999974e+00
218 | 3.267569387499999145e+00
219 | 2.679220242500000015e+00
220 | 3.675037752500000199e+00
221 | 3.459667313499999786e+00
222 | 2.230246310500000106e+00
223 | 2.850722157499999465e+00
224 | 2.407494965500000195e+00
225 | 2.920090459499999902e+00
226 | 2.943086850499999851e+00
227 | 3.323967807499999871e+00
228 | 3.363145897499999926e+00
229 | 2.684926472499999939e+00
230 | 2.657415704499999975e+00
231 | 3.622536860499999900e+00
232 | 3.540368124500000047e+00
233 | 3.589065706500000008e+00
234 | 2.596920302500000055e+00
235 | 2.730771360499999467e+00
236 | 2.734785122499999943e+00
237 | 3.377389813499999782e+00
238 | 3.045484734500000013e+00
239 | 2.671878994500000104e+00
240 | 2.050986682500000047e+00
241 | 2.817160023500000054e+00
242 | 2.723591580499999942e+00
243 | 3.557056275499999920e+00
244 | 2.685716286499999939e+00
245 | 3.568842027500000125e+00
246 | 2.932958204500000221e+00
247 | 3.069034059499999856e+00
248 | 3.514755606500000518e+00
249 | 3.262418033499999925e+00
250 | 3.652738589499999300e+00
251 | 3.119569415499999998e+00
252 | 3.452809048499999811e+00
253 | 2.740936844500000191e+00
254 | 2.451905307500000131e+00
255 | 3.313156561500000041e+00
256 | 3.511915449500000008e+00
257 | 2.700493722500000082e+00
258 | 3.348573206499999788e+00
259 | 3.468992130500000215e+00
260 | 3.476668341500000370e+00
261 | 3.567517962500000195e+00
262 | 3.528156868499999543e+00
263 | 3.183894331499999897e+00
264 | 3.392475934500000179e+00
265 | 3.289337511499999867e+00
266 | 2.498368390500000036e+00
267 | 2.767821931499999888e+00
268 | 3.603900129499999938e+00
269 | 3.590931797499999245e+00
270 | 3.045904796499999900e+00
271 | 3.689263541499999910e+00
272 | 3.146069863500000174e+00
273 | 3.100778360500000108e+00
274 | 3.672868856500000057e+00
275 | 3.624809072500000173e+00
276 | 3.689044692500000000e+00
277 | 2.818979973500000291e+00
278 | 3.092128980500000068e+00
279 | 3.274181661500000118e+00
280 | 2.443876192500000322e+00
281 | 3.035938488499999810e+00
282 | 2.764125669499999383e+00
283 | 3.678849937499999889e+00
284 | 2.334595027499999809e+00
285 | 3.328708817499999473e+00
286 | 2.041349412499999794e+00
287 | 2.593014625499999504e+00
288 | 3.466235386500000182e+00
289 | 3.377096930500000038e+00
290 | 3.163263784499999787e+00
291 | 3.148148751499999953e+00
292 | 2.798697951500000336e+00
293 | 2.881607269499999902e+00
294 | 2.075533408500000121e+00
295 | 3.648939681499999921e+00
296 | 3.516611950500000194e+00
297 | 3.361316212499999789e+00
298 | 3.014654499500000195e+00
299 | 3.411344701500000021e+00
300 | 2.914182840500000093e+00
301 |
--------------------------------------------------------------------------------
/data/representation/x_test.csv:
--------------------------------------------------------------------------------
1 | 7.6347027,8.037969,8.186526,8.296676,8.366806,8.395889,8.383496,8.329815,8.102355,7.7313023,7.3277287,6.86488,6.306292,5.726998,4.776494,6.401983,7.101475,7.6568522,8.025179,8.177042,8.100437,7.809145,7.3074493,6.369568,5.276666,4.031698,2.7925496,1.7492355,3.390356,5.433555,6.4142046,7.230476,7.786035,8.012206,7.8819127,7.482154,6.9597116,5.6499586,4.4195595,2.9891787,1.7316858,0.8515215,1.5577528,3.7955854,5.158061,6.4083543,7.2968764,7.635884,7.3648577,6.7645187,6.361255,4.677028,3.2875094,1.7794862,0.7073132,0.19914488,0.38233596,2.0855284,3.6734178,5.387841,6.648571,7.0102725,6.4451065,5.6690145,5.7577953,3.6979256,2.4004283,0.92405725,0.17827508,0.01618633,0.052093655,0.9478716,2.4330766,4.526375,6.13364,6.301463,5.3087854,4.5954037,5.3286743,2.934912,1.9198909,0.44170582,0.027728438,0.00045367455,0.0073398,0.44834715,1.6864524,3.9551892,5.8717976,5.7726793,4.440757,3.9239402,5.032765,2.4922295,1.6833789,0.22568327,0.0046401196,1.39647345e-05,1.366814e-05,0.0401704,0.53390944,2.7223392,5.5251403,4.636115,2.8632991,2.8454342,4.2715855,1.7808043,1.2348522,0.028469397,1.5511014e-05,1.7215465e-10,1.7988977e-09,0.0014022979,0.10472332,1.6557811,5.301574,3.6102402,1.9564735,2.1535573,3.6110535,1.2962196,0.85115135,0.0016333024,3.967595e-09,9.399013e-18,2.55697e-13,6.2094536e-05,0.024239682,1.0463419,5.1334095,2.9081228,1.6165028,1.7476071,3.4143996,1.0256757,0.6231437,0.00010555335,1.0957327e-12,5.3302025e-25,9.048065,4.6781635,1.0532091,0.06843075,2.6120613e-05,11.02688,5.3850875,1.065321,0.05888957,1.845249e-05,5.834724,3.0927684,0.7209282,0.048420068,1.9752504e-05,7.423365,3.5952344,0.6997687,0.038108453,1.2066323e-05,3.4506154,1.8724595,0.4534073,0.031738166,1.4180921e-05,4.2822556,1.9768807,0.370257,0.019514933,6.053242e-06,1.1266446,0.65117013,0.17467695,0.01370768,7.650941e-06,1.4605906,0.43000588,0.028385328,0.0008188701,1.8175793e-07,0.5204019,0.31753644,0.09240823,0.007836901,4.8e-06,0.48970345,0.14299357,0.00666044,1.904607e-05,2.0390837e-12,6.160422,6.6661687,6.884791,7.0766244,7.239064,7.369865,7.467193,7.529668,7.5469923,7.4208264,7.158287,6.622072,5.9512777,5.2437177,2.516549,3.700965,4.3307323,4.9500437,5.5276737,6.0315166,6.4313974,6.7019396,6.6555557,5.0780435,3.4668381,2.1184268,1.1587516,0.5673503,1.5057518,2.6438906,3.3143363,4.0107822,4.690665,5.3066983,5.8117385,6.1640244,5.7459326,3.73569,2.0396943,0.93513787,0.3595433,0.11567887,0.60331887,1.5165292,2.144766,2.8380828,3.5416698,4.1985183,4.7550955,5.1627855,4.151659,1.9443294,0.6678895,0.16489673,0.028637297,0.0034390688,0.15375443,0.7934002,1.3708824,2.0203269,2.6254656,3.144864,3.607701,4.0079374,2.582511,0.77285445,0.12180669,0.009218421,0.00032254448,5.1418156e-06,0.022954568,0.38450772,0.9380872,1.6432308,2.1547751,2.3755608,2.654021,3.1652706,1.7264359,0.2890365,0.013413128,0.00015911311,4.721615e-07,3.477529e-10,0.003478007,0.19299978,0.6826151,1.4758364,1.9996482,1.9062647,2.0228968,2.6757653,1.3618618,0.12095639,0.0015862945,2.8853976e-06,7.16319e-10,2.4157986e-14,6.1845753e-06,0.01988394,0.2447432,1.1207368,1.9127657,1.1005502,0.7205751,1.2920241,0.7925133,0.007606483,1.4006755e-06,4.7752136e-12,3.0039723e-19,3.4845793e-28,3.6193123e-10,0.00061199337,0.053782914,0.7538048,1.9004161,0.6558406,0.17283502,0.5879705,0.40321013,0.00012828893,3.83648e-11,1.06804914e-20,2.7603e-33,0.0,2.1907079e-14,1.669212e-05,0.011074061,0.5105281,1.8898233,0.41115925,0.043087196,0.3161983,0.21218473,2.19662e-06,1.0794713e-15,2.49584e-29,0.0,0.0,0.7785383,0.28510672,0.044930883,0.0028594765,1.3748747e-06,4.4133234,1.4070488
,0.17132547,0.009021766,4.237469e-06,0.18193093,0.06960676,0.011042307,0.0006995932,3.3681093e-07,3.8167162,1.1915488,0.13743687,0.0060121035,3.199405e-06,0.019804496,0.007352831,0.00088740233,4.3055396e-05,2.0213296e-08,2.9670405,0.89224577,0.09169072,0.0024531926,1.1060774e-06,2.2397329e-07,3.9957833e-08,1.0153539e-09,2.955518e-11,9.588455e-16,0.61931694,0.2371586,0.018524287,0.00013715237,2.043432e-09,1.7251224e-24,2.8203624e-25,3.0480218e-27,3.5601168e-31,1.64047e-40,0.012272776,0.004336454,0.00032502704,1.4801496e-06,2.8672374e-13
2 |
--------------------------------------------------------------------------------
/data/representation/y_test.csv:
--------------------------------------------------------------------------------
1 | 1.9583510175
2 |
--------------------------------------------------------------------------------
/data/saved_models/model_FF_saved.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Fung-Lab/StructRepGen/c000acddbb450e2a4c0a81d09a6dfb07fc675a7d/data/saved_models/model_FF_saved.pt
--------------------------------------------------------------------------------
/data/saved_models/model_saved.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Fung-Lab/StructRepGen/c000acddbb450e2a4c0a81d09a6dfb07fc675a7d/data/saved_models/model_saved.pt
--------------------------------------------------------------------------------
/data/saved_models/scaler.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Fung-Lab/StructRepGen/c000acddbb450e2a4c0a81d09a6dfb07fc675a7d/data/saved_models/scaler.gz
--------------------------------------------------------------------------------
/data/unittest/dummy.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Fung-Lab/StructRepGen/c000acddbb450e2a4c0a81d09a6dfb07fc675a7d/data/unittest/dummy.txt
--------------------------------------------------------------------------------
/examples/example_cvae_trainer.py:
--------------------------------------------------------------------------------
1 | import torch, yaml
2 | import torch.nn as nn
3 | import torch.optim as optim
4 | import torch.nn.functional as F
5 | from torch.utils.data import DataLoader, TensorDataset
6 | from sklearn.preprocessing import MinMaxScaler
7 | import joblib, time, copy
8 | import pandas as pd
9 | import numpy as np
10 | from sklearn.model_selection import train_test_split
11 |
12 | from structrepgen.models.CVAE import *
13 | from structrepgen.utils.dotdict import dotdict
14 | from structrepgen.utils.utils import torch_device_select
15 |
16 | '''
17 | Example of training a CVAE based on the representation R extracted using Behler descriptors
18 | '''
19 |
20 | class Trainer():
21 | def __init__(self, CONFIG) -> None:
22 | self.CONFIG = CONFIG
23 |
24 | # check GPU availability & set device
25 | self.device = torch_device_select(self.CONFIG.gpu)
26 |
27 | # initialize
28 | self.create_data()
29 | self.initialize()
30 |
31 | def create_data(self):
32 | p = self.CONFIG.params
33 |
34 | data_x = pd.read_csv(self.CONFIG.data_x_path, header=None).values
35 | data_y = pd.read_csv(self.CONFIG.data_y_path, header=None).values
36 |
37 | # scale
38 | scaler = MinMaxScaler()
39 | data_x = scaler.fit_transform(data_x)
40 | joblib.dump(scaler, self.CONFIG.scaler_path)
41 |
42 | # train/test split and create torch dataloader
43 |         xtrain, xtest, ytrain, ytest = train_test_split(data_x, data_y, test_size=p.split_ratio, random_state=p.seed)
44 | self.x_train = torch.tensor(xtrain, dtype=torch.float)
45 | self.y_train = torch.tensor(ytrain, dtype=torch.float)
46 | self.x_test = torch.tensor(xtest, dtype=torch.float)
47 | self.y_test = torch.tensor(ytest, dtype=torch.float)
48 |
49 | self.train_loader = DataLoader(
50 | TensorDataset(self.x_train, self.y_train),
51 | batch_size=p.batch_size, shuffle=True, drop_last=False
52 | )
53 |
54 | self.test_loader = DataLoader(
55 | TensorDataset(self.x_test, self.y_test),
56 | batch_size=p.batch_size, shuffle=False, drop_last=False
57 | )
58 |
59 | def initialize(self):
60 | p = self.CONFIG.params
61 |
62 | # create model
63 | self.model = CVAE(p.input_dim, p.hidden_dim, p.latent_dim, p.hidden_layers, p.y_dim)
64 | self.model.to(self.device)
65 | print(self.model)
66 |
67 | # set up optimizer
68 | gamma = (p.final_decay)**(1./p.n_epochs)
69 | self.optimizer = optim.Adam(self.model.parameters(), lr=p.lr, weight_decay=p.weight_decay)
70 | self.scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=1, gamma=gamma)
71 |
72 | def train(self):
73 | p = self.CONFIG.params
74 | self.model.train()
75 |
76 |         # loss of the epoch
77 | rcl_loss = 0.
78 | kld_loss = 0.
79 |
80 | for i, (x, y) in enumerate(self.train_loader):
81 | x = x.to(self.device)
82 | y = y.to(self.device)
83 |
84 | self.optimizer.zero_grad()
85 |
86 | # forward
87 | reconstructed_x, z_mu, z_var = self.model(x, y)
88 |
89 | rcl, kld = calculate_loss(x, reconstructed_x, z_mu, z_var, p.kl_weight, p.mc_kl_loss)
90 |
91 | # backward
92 | combined_loss = rcl + kld
93 | combined_loss.backward()
94 | rcl_loss += rcl.item()
95 | kld_loss += kld.item()
96 |
97 | # update the weights
98 | self.optimizer.step()
99 |
100 | return rcl_loss, kld_loss
101 |
102 | def test(self):
103 | p = self.CONFIG.params
104 |
105 | self.model.eval()
106 |
107 | # loss of the evaluation
108 | rcl_loss = 0.
109 | kld_loss = 0.
110 |
111 | with torch.no_grad():
112 | for i, (x, y) in enumerate(self.test_loader):
113 | x = x.to(self.device)
114 | y = y.to(self.device)
115 |
116 | # forward pass
117 | reconstructed_x, z_mu, z_var = self.model(x, y)
118 |
119 | # loss
120 | rcl, kld = calculate_loss(x, reconstructed_x, z_mu, z_var, p.kl_weight, p.mc_kl_loss)
121 | rcl_loss += rcl.item()
122 | kld_loss += kld.item()
123 |
124 | return rcl_loss, kld_loss
125 |
126 | def run(self):
127 | p = self.CONFIG.params
128 | best_test_loss = float('inf')
129 | best_train_loss = float('inf')
130 | best_epoch = 0
131 |
132 | for e in range(p.n_epochs):
133 | tic = time.time()
134 |
135 | rcl_train_loss, kld_train_loss = self.train()
136 | rcl_test_loss, kld_test_loss = self.test()
137 |
138 | rcl_train_loss /= len(self.x_train)
139 | kld_train_loss /= len(self.x_train)
140 | train_loss = rcl_train_loss + kld_train_loss
141 | rcl_test_loss /= len(self.x_test)
142 | kld_test_loss /= len(self.x_test)
143 | test_loss = rcl_test_loss + kld_test_loss
144 |
145 | self.scheduler.step()
146 | lr = self.scheduler.optimizer.param_groups[0]["lr"]
147 |
148 | if best_test_loss > test_loss:
149 | best_epoch = e
150 | best_test_loss = test_loss
151 | best_train_loss = train_loss
152 | model_best = copy.deepcopy(self.model)
153 |
154 | elapsed_time = time.time() - tic
155 |             epoch_out = f'Epoch {e:04d}, Train RCL: {rcl_train_loss:.3f}, Train KLD: {kld_train_loss:.3f}, Train: {train_loss:.3f}, Test RCL: {rcl_test_loss:.3f}, Test KLD: {kld_test_loss:.3f}, Test: {test_loss:.3f}, LR: {lr:.5f}, Time/Epoch (s): {elapsed_time:.3f}'
156 | if e % p.verbosity == 0:
157 | print(epoch_out)
158 |
159 | torch.save(model_best, self.CONFIG.model_path)
160 | return best_epoch, best_train_loss, best_test_loss
161 |
162 | if __name__ == "__main__":
163 | # load parameters from yaml file
164 | stream = open('./configs/example/example_cvae_trainer.yaml')
165 | CONFIG = yaml.safe_load(stream)
166 | stream.close()
167 | CONFIG = dotdict(CONFIG)
168 |
169 | trainer = Trainer(CONFIG)
170 | trainer.run()
--------------------------------------------------------------------------------
/examples/example_extract.py:
--------------------------------------------------------------------------------
1 | '''
2 | Use Behler ACSF descriptor to extract features.
3 |
4 | To customize parameters, make a yaml file in configs/
5 | In this example, we will use configs/example/example_extract_reconstruct.yaml
6 |
7 | Output csv files are saved under data/representation/
8 | '''
9 |
10 | import yaml
11 | from structrepgen.extraction.representation_extraction import *
12 | from structrepgen.utils.dotdict import dotdict
13 |
14 | # load our config yaml file
15 | stream = open('./configs/example/example_extract_reconstruct.yaml')
16 | CONFIG = yaml.safe_load(stream)
17 | stream.close()
18 | # dotdict for dot operations on Python dict e.g., CONFIG.cutoff
19 | CONFIG = dotdict(CONFIG)
20 |
21 | # create extractor instance
22 | extractor = RepresentationExtraction(CONFIG)
23 | # extract
24 | extractor.extract()
--------------------------------------------------------------------------------
/examples/example_generator.py:
--------------------------------------------------------------------------------
1 | import yaml, os, unittest, torch
2 | import numpy as np
3 | from structrepgen.generators.generator import Generator
4 | from structrepgen.utils.dotdict import dotdict
5 |
6 | stream = open('./configs/example/example_generator.yaml')
7 | CONFIG = yaml.safe_load(stream)
8 | stream.close()
9 | CONFIG = dotdict(CONFIG)
10 |
11 | gen = Generator(CONFIG)
12 |
13 | gen.generate()
14 | gen.range_check()
--------------------------------------------------------------------------------
/examples/example_reconstruction.py:
--------------------------------------------------------------------------------
1 | import yaml, os, unittest, torch
2 | import numpy as np
3 | from structrepgen.extraction.representation_extraction import *
4 | from structrepgen.utils.dotdict import dotdict
5 | from structrepgen.reconstruction.reconstruction import Reconstruction
6 |
7 | stream = open('./configs/example/example_extract_reconstruct.yaml')
8 | CONFIG = yaml.safe_load(stream)
9 | stream.close()
10 | CONFIG = dotdict(CONFIG)
11 |
12 | constructor = Reconstruction(CONFIG)
13 |
14 | constructor.main()
--------------------------------------------------------------------------------
/examples/example_surrogate_trainer.py:
--------------------------------------------------------------------------------
1 | import torch, yaml
2 | import torch.nn as nn
3 | import torch.optim as optim
4 | import torch.nn.functional as F
5 | from torch.utils.data import DataLoader, TensorDataset
6 | from sklearn.preprocessing import MinMaxScaler
7 | import joblib, time, copy
8 | import pandas as pd
9 | import numpy as np
10 | from sklearn.model_selection import train_test_split
11 |
12 | from structrepgen.models.models import ff_net
13 | from structrepgen.utils.dotdict import dotdict
14 | from structrepgen.utils.utils import torch_device_select
15 |
16 | '''
17 | Example of training a surrogate model based on the representation R extracted using Behler descriptors
18 | The model is trained on MSE loss for y
19 |
20 | Model used:
21 | fully connected feed-forward neural network
22 | '''
23 |
24 | class Trainer():
25 | def __init__(self, CONFIG) -> None:
26 | self.CONFIG = CONFIG
27 |
28 | # check GPU availability & set device
29 | self.device = torch_device_select(self.CONFIG.gpu)
30 |
31 | # initialize
32 | self.create_data()
33 | self.initialize()
34 |
35 | def create_data(self):
36 | p = self.CONFIG.params
37 |
38 | data_x = pd.read_csv(self.CONFIG.data_x_path, header=None).values
39 | data_y = pd.read_csv(self.CONFIG.data_y_path, header=None).values
40 |
41 | # train/test split and create torch dataloader
42 |         xtrain, xtest, ytrain, ytest = train_test_split(data_x, data_y, test_size=p.split_ratio, random_state=p.seed)
43 | self.x_train = torch.tensor(xtrain, dtype=torch.float)
44 | self.y_train = torch.tensor(ytrain, dtype=torch.float)
45 | self.x_test = torch.tensor(xtest, dtype=torch.float)
46 | self.y_test = torch.tensor(ytest, dtype=torch.float)
47 |
48 | self.train_loader = DataLoader(
49 | TensorDataset(self.x_train, self.y_train),
50 | batch_size=p.batch_size, shuffle=True, drop_last=False
51 | )
52 |
53 | self.test_loader = DataLoader(
54 | TensorDataset(self.x_test, self.y_test),
55 | batch_size=p.batch_size, shuffle=False, drop_last=False
56 | )
57 |
58 | def initialize(self):
59 | p = self.CONFIG.params
60 |
61 | # create model
62 | self.model = ff_net(p.input_dim, p.hidden_dim, p.hidden_layers)
63 | self.model.to(self.device)
64 | print(self.model)
65 |
66 | # set up optimizer
67 | gamma = (p.final_decay)**(1./p.n_epochs)
68 | self.optimizer = optim.Adam(self.model.parameters(), lr=p.lr, weight_decay=p.weight_decay)
69 | self.scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=1, gamma=gamma)
70 |
71 | def train(self):
72 | p = self.CONFIG.params
73 | self.model.train()
74 |
75 | # loss of the epoch
76 | total_loss = 0.
77 |
78 | for i, (x, y) in enumerate(self.train_loader):
79 | x = x.to(self.device)
80 | y = y.to(self.device)
81 |
82 | self.optimizer.zero_grad()
83 |
84 | # forward
85 | y_pred = self.model(x)
86 |
87 |             loss = F.mse_loss(y_pred, y, reduction='sum')  # sum over the batch (size_average is deprecated)
88 |
89 | # backward
90 | loss.backward()
91 | total_loss += loss.item()
92 |
93 | # update the weights
94 | self.optimizer.step()
95 |
96 | return total_loss
97 |
98 | def test(self):
99 | p = self.CONFIG.params
100 |
101 | self.model.eval()
102 |
103 | # loss of the evaluation
104 | total_loss = 0.
105 |
106 | with torch.no_grad():
107 | for i, (x, y) in enumerate(self.test_loader):
108 | x = x.to(self.device)
109 | y = y.to(self.device)
110 |
111 | # forward pass
112 | y_pred = self.model(x)
113 |
114 | # loss
115 |                 loss = F.mse_loss(y_pred, y, reduction='sum')  # sum over the batch (size_average is deprecated)
116 | total_loss += loss.item()
117 |
118 | return total_loss
119 |
120 | def run(self):
121 | p = self.CONFIG.params
122 | best_test_loss = float('inf')
123 | best_train_loss = float('inf')
124 | best_epoch = 0
125 |
126 | for e in range(p.n_epochs):
127 | tic = time.time()
128 |
129 | train_loss = self.train()
130 | test_loss = self.test()
131 |
132 | train_loss /= len(self.x_train)
133 | test_loss /= len(self.x_test)
134 |
135 | self.scheduler.step()
136 | lr = self.scheduler.optimizer.param_groups[0]["lr"]
137 |
138 | if best_test_loss > test_loss:
139 | best_epoch = e
140 | best_test_loss = test_loss
141 | best_train_loss = train_loss
142 | model_best = copy.deepcopy(self.model)
143 |
144 | elapsed_time = time.time() - tic
145 | epoch_out = f'Epoch {e:04d}, Train: {train_loss:.4f}, Test: {test_loss:.4f}, LR: {lr:.5f}, Time/Epoch (s): {elapsed_time:.3f}'
146 | if e % p.verbosity == 0:
147 | print(epoch_out)
148 |
149 | torch.save(model_best, self.CONFIG.ff_path)
150 | return best_epoch, best_train_loss, best_test_loss
151 |
152 | if __name__ == "__main__":
153 | # load parameters from yaml file
154 | stream = open('./configs/example/example_ff_trainer.yaml')
155 | CONFIG = yaml.safe_load(stream)
156 | stream.close()
157 | CONFIG = dotdict(CONFIG)
158 |
159 | trainer = Trainer(CONFIG)
160 | trainer.run()
--------------------------------------------------------------------------------
/misc.py:
--------------------------------------------------------------------------------
1 | # this is a placeholder
--------------------------------------------------------------------------------
/outputs/reconstructed/0_reconstructed.cif:
--------------------------------------------------------------------------------
1 | data_image0
2 | _chemical_formula_structural Pt10
3 | _chemical_formula_sum "Pt10"
4 | _cell_length_a 20
5 | _cell_length_b 20
6 | _cell_length_c 20
7 | _cell_angle_alpha 90
8 | _cell_angle_beta 90
9 | _cell_angle_gamma 90
10 |
11 | _space_group_name_H-M_alt "P 1"
12 | _space_group_IT_number 1
13 |
14 | loop_
15 | _space_group_symop_operation_xyz
16 | 'x, y, z'
17 |
18 | loop_
19 | _atom_site_type_symbol
20 | _atom_site_label
21 | _atom_site_symmetry_multiplicity
22 | _atom_site_fract_x
23 | _atom_site_fract_y
24 | _atom_site_fract_z
25 | _atom_site_occupancy
26 | Pt Pt1 1.0 0.59185 0.23819 0.75476 1.0000
27 | Pt Pt2 1.0 0.58611 0.54290 0.61489 1.0000
28 | Pt Pt3 1.0 0.69703 0.27786 0.68786 1.0000
29 | Pt Pt4 1.0 0.50226 0.36483 0.79940 1.0000
30 | Pt Pt5 1.0 0.66070 0.55487 0.53070 1.0000
31 | Pt Pt6 1.0 0.55922 0.23397 0.62426 1.0000
32 | Pt Pt7 1.0 0.63122 0.31871 0.58849 1.0000
33 | Pt Pt8 1.0 0.55219 0.24226 0.85988 1.0000
34 | Pt Pt9 1.0 0.55663 0.33984 0.69276 1.0000
35 | Pt Pt10 1.0 0.57114 0.42699 0.59601 1.0000
36 |
--------------------------------------------------------------------------------
/outputs/reconstructed/1_reconstructed.cif:
--------------------------------------------------------------------------------
1 | data_image0
2 | _chemical_formula_structural Pt10
3 | _chemical_formula_sum "Pt10"
4 | _cell_length_a 20
5 | _cell_length_b 20
6 | _cell_length_c 20
7 | _cell_angle_alpha 90
8 | _cell_angle_beta 90
9 | _cell_angle_gamma 90
10 |
11 | _space_group_name_H-M_alt "P 1"
12 | _space_group_IT_number 1
13 |
14 | loop_
15 | _space_group_symop_operation_xyz
16 | 'x, y, z'
17 |
18 | loop_
19 | _atom_site_type_symbol
20 | _atom_site_label
21 | _atom_site_symmetry_multiplicity
22 | _atom_site_fract_x
23 | _atom_site_fract_y
24 | _atom_site_fract_z
25 | _atom_site_occupancy
26 | Pt Pt1 1.0 0.54359 0.34959 0.50534 1.0000
27 | Pt Pt2 1.0 0.44811 0.48147 0.44699 1.0000
28 | Pt Pt3 1.0 0.46119 0.44006 0.55937 1.0000
29 | Pt Pt4 1.0 0.35195 0.31757 0.37187 1.0000
30 | Pt Pt5 1.0 0.41042 0.60342 0.48576 1.0000
31 | Pt Pt6 1.0 0.59037 0.44773 0.53806 1.0000
32 | Pt Pt7 1.0 0.42613 0.41147 0.35061 1.0000
33 | Pt Pt8 1.0 0.39104 0.37135 0.47558 1.0000
34 | Pt Pt9 1.0 0.38792 0.24888 0.48067 1.0000
35 | Pt Pt10 1.0 0.35573 0.52945 0.55215 1.0000
36 |
--------------------------------------------------------------------------------
/outputs/reconstructed/2_reconstructed.cif:
--------------------------------------------------------------------------------
1 | data_image0
2 | _chemical_formula_structural Pt10
3 | _chemical_formula_sum "Pt10"
4 | _cell_length_a 20
5 | _cell_length_b 20
6 | _cell_length_c 20
7 | _cell_angle_alpha 90
8 | _cell_angle_beta 90
9 | _cell_angle_gamma 90
10 |
11 | _space_group_name_H-M_alt "P 1"
12 | _space_group_IT_number 1
13 |
14 | loop_
15 | _space_group_symop_operation_xyz
16 | 'x, y, z'
17 |
18 | loop_
19 | _atom_site_type_symbol
20 | _atom_site_label
21 | _atom_site_symmetry_multiplicity
22 | _atom_site_fract_x
23 | _atom_site_fract_y
24 | _atom_site_fract_z
25 | _atom_site_occupancy
26 | Pt Pt1 1.0 0.63372 0.44299 0.29739 1.0000
27 | Pt Pt2 1.0 0.58648 0.41319 0.55833 1.0000
28 | Pt Pt3 1.0 0.63758 0.43002 0.45702 1.0000
29 | Pt Pt4 1.0 0.57271 0.54867 0.28353 1.0000
30 | Pt Pt5 1.0 0.56319 0.71164 0.17605 1.0000
31 | Pt Pt6 1.0 0.52529 0.60120 0.18550 1.0000
32 | Pt Pt7 1.0 0.60091 0.45877 0.19391 1.0000
33 | Pt Pt8 1.0 0.44374 0.56779 0.28553 1.0000
34 | Pt Pt9 1.0 0.51231 0.65642 0.29554 1.0000
35 | Pt Pt10 1.0 0.48298 0.37207 0.58142 1.0000
36 |
--------------------------------------------------------------------------------
/outputs/reconstructed/3_reconstructed.cif:
--------------------------------------------------------------------------------
1 | data_image0
2 | _chemical_formula_structural Pt10
3 | _chemical_formula_sum "Pt10"
4 | _cell_length_a 20
5 | _cell_length_b 20
6 | _cell_length_c 20
7 | _cell_angle_alpha 90
8 | _cell_angle_beta 90
9 | _cell_angle_gamma 90
10 |
11 | _space_group_name_H-M_alt "P 1"
12 | _space_group_IT_number 1
13 |
14 | loop_
15 | _space_group_symop_operation_xyz
16 | 'x, y, z'
17 |
18 | loop_
19 | _atom_site_type_symbol
20 | _atom_site_label
21 | _atom_site_symmetry_multiplicity
22 | _atom_site_fract_x
23 | _atom_site_fract_y
24 | _atom_site_fract_z
25 | _atom_site_occupancy
26 | Pt Pt1 1.0 0.78766 0.62397 0.60775 1.0000
27 | Pt Pt2 1.0 0.67944 0.56004 0.64893 1.0000
28 | Pt Pt3 1.0 0.51837 0.54830 0.84224 1.0000
29 | Pt Pt4 1.0 0.53434 0.62492 0.62552 1.0000
30 | Pt Pt5 1.0 0.59688 0.46453 0.81786 1.0000
31 | Pt Pt6 1.0 0.70441 0.67983 0.68149 1.0000
32 | Pt Pt7 1.0 0.56416 0.56277 0.72494 1.0000
33 | Pt Pt8 1.0 0.55373 0.68267 0.73955 1.0000
34 | Pt Pt9 1.0 0.75654 0.57270 0.74149 1.0000
35 | Pt Pt10 1.0 0.60483 0.56585 0.54790 1.0000
36 |
--------------------------------------------------------------------------------
/quickrun.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 |
3 | echo "=== BEHLER (TEST) ==="
4 | python tests/test_behler.py
5 |
6 | echo "=== EXTRACTION ==="
7 | python examples/example_extract.py
8 |
9 | echo "=== CVAE TRAINER ==="
10 | python examples/example_cvae_trainer.py
11 |
12 | echo "=== SURROGATE TRAINER ==="
13 | python examples/example_surrogate_trainer.py
14 |
15 | echo "=== GENERATION ==="
16 | python examples/example_generator.py
17 |
18 | echo "=== RECONSTRUCTION ==="
19 | python examples/example_reconstruction.py
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | ase==3.22.1
2 | scikit-optimize==0.9.0
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | setup(
4 | name='structrepgen',
5 | version='0.0.1',
6 | description="Atomic Structure Generation from Reconstructing Structural Fingerprints",
7 | url="https://github.com/Fung-Lab/StructRepGen",
8 | packages=find_packages(),
9 | include_package_data=True,
10 | )
--------------------------------------------------------------------------------
/structrepgen/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Fung-Lab/StructRepGen/c000acddbb450e2a4c0a81d09a6dfb07fc675a7d/structrepgen/__init__.py
--------------------------------------------------------------------------------
/structrepgen/descriptors/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Fung-Lab/StructRepGen/c000acddbb450e2a4c0a81d09a6dfb07fc675a7d/structrepgen/descriptors/__init__.py
--------------------------------------------------------------------------------
/structrepgen/descriptors/behler.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import itertools, torch, warnings
3 | from itertools import combinations, combinations_with_replacement
4 | from structrepgen.utils.utils import torch_device_select
5 |
6 | class Behler:
7 | '''
8 |     Behler-Parrinello atom-centered symmetry functions (ACSF) for representation extraction
9 | See
10 | 1. Behler, J. Chem. Phys. (2011)
11 | '''
12 | def __init__(self, CONFIG) -> None:
13 | self.CONFIG = CONFIG
14 |
15 | # check GPU availability & set device
16 | self.device = torch_device_select(self.CONFIG.gpu)
17 |
18 | def get_features(self, distances, rij, atomic_numbers):
19 | '''
20 | wrapper method for self.get_ACSF_features()
21 | '''
22 | return self.get_ACSF_features(distances, rij, atomic_numbers)
23 |
24 | def get_ACSF_features(self, distances, rij, atomic_numbers):
25 | '''
26 | Get Atomic-centered symmetry functions (ACSF) features (representation R)
27 |
28 | Parameters:
29 |             distances (torch.Tensor): [n_atoms, n_atoms] pairwise (minimum-image) distances
30 |             rij (torch.Tensor): [n_atoms, n_atoms, 3] pairwise displacement vectors
31 |             atomic_numbers (numpy.ndarray): atomic numbers of all atoms in the structure
32 | 
33 |         Returns:
34 |             torch.Tensor: per-atom features [n_atoms, n_G2 + n_G5]; if CONFIG.average, a single vector of length 2 * (n_G2 + n_G5) holding the feature-wise max and min over all atoms
35 | '''
36 |
37 | n_atoms = distances.shape[0]
38 | type_set1 = self.create_type_set(atomic_numbers, 1)
39 | type_set2 = self.create_type_set(atomic_numbers, 2)
40 |
41 | atomic_numbers = torch.tensor(atomic_numbers, dtype=int)
42 | atomic_numbers = atomic_numbers.view(1, -1).expand(n_atoms, -1)
43 | atomic_numbers = atomic_numbers.reshape(n_atoms, -1).numpy()
44 |
45 | # calculate size of all_features tensor
46 | g2p = self.CONFIG.g2_params
47 | g5p = self.CONFIG.g5_params
48 | G2_out_size = len(g2p.Rs) * len(g2p.eta) * len(type_set1)
49 | G5_out_size = len(g5p.Rs) * len(g5p.eta) * len(g5p.zeta) * len(g5p.lambdas) * len(type_set2) if self.CONFIG.g5 else 0
50 | all_features = torch.zeros((n_atoms, G2_out_size+G5_out_size), device=self.device, dtype=torch.float)
51 |
52 | for i in range(n_atoms):
53 | Rc = self.CONFIG.cutoff
54 | mask = (distances[i, :] <= Rc) & (distances[i, :] != 0.0)
55 | Dij = distances[i, mask]
56 | Rij = rij[i, mask]
57 | IDs = atomic_numbers[i, mask.cpu().numpy()]
58 | jks = np.array(list(combinations(range(len(IDs)), 2)))
59 |
60 | # get G2
61 | G2_i = self.get_G2(Dij, IDs, type_set1)
62 |
63 | all_features[i, :G2_out_size] = G2_i
64 | # get G5
65 | if self.CONFIG.g5:
66 | G5_i = self.get_G5(Rij, IDs, jks, type_set2)
67 | all_features[i, G2_out_size:] = G5_i
68 |
69 | if self.CONFIG.average:
70 | length = all_features.shape[1]
71 | out = torch.zeros(length * 2, device=self.device, dtype=torch.float)
72 | out[:length] = torch.max(all_features, dim=0)[0]
73 | out[length:] = torch.min(all_features, dim=0)[0]
74 | return out
75 |
76 | return all_features
77 |
78 | def get_G2(self, Dij, IDs, type_set):
79 | '''
80 | Get G2 symmetry function (radial):
81 | G_{i}^{2}=\sum_{j} e^{-\eta\left(R_{i j}-R_{s}\right)^{2}} \cdot f_{c}\left(R_{i j}\right)
82 | See Behler, J. Chem. Phys. (2011) Eqn (6)
83 |
84 | Parameters:
85 |             Dij (torch.Tensor): distances from atom i to its neighbors within the cutoff
86 |             IDs (numpy.ndarray): atomic numbers of those neighbors
87 |             type_set (numpy.ndarray): element types (order-1 combinations) to group contributions by
88 | 
89 |         Returns:
90 |             G2 (torch.Tensor): radial features of length len(Rs) * len(eta) * len(type_set)
91 | '''
92 | Rc = self.CONFIG.cutoff
93 | Rs = torch.tensor(self.CONFIG.g2_params.Rs, device=self.device)
94 | eta = torch.tensor(self.CONFIG.g2_params.eta, device=self.device)
95 |
96 | n1, n2, m, l = len(Rs), len(eta), len(Dij), len(type_set)
97 |
98 | d20 = (Dij - Rs.view(-1, 1)) ** 2
99 | term = torch.exp(torch.einsum('i,jk->ijk', -eta, d20))
100 | results = torch.einsum('ijk,k->ijk', term, self.cosine_cutoff(Dij, Rc))
101 | results = results.view(-1, m)
102 |
103 | G2 = torch.zeros([n1*n2*l], device=self.device)
104 |
105 | for id, j_type in enumerate(type_set):
106 | ids = self.select_rows(IDs, j_type)
107 | G2[id::l] = torch.sum(results[:, ids], axis=1)
108 |
109 | return G2
110 |
111 | def get_G5(self, Rij, IDs, jks, type_set):
112 | '''
113 | Get G5 symmetry function (angular)
114 | See Behler, J. Chem. Phys. (2011) Eqn (9)
115 |
116 | Parameters:
117 |             Rij, IDs, jks, type_set: displacement vectors to the neighbors of atom i, their atomic numbers, neighbor index pairs (j, k), and the element-pair type set
118 | 
119 |         Returns:
120 |             G5 (torch.Tensor): angular features of length len(Rs) * len(eta) * len(lambdas) * len(zeta) * len(type_set)
121 | '''
122 |
123 | Rc = self.CONFIG.cutoff
124 | Rs = torch.tensor(self.CONFIG.g5_params.Rs, device=self.device)
125 | eta = torch.tensor(self.CONFIG.g5_params.eta, device=self.device)
126 | lambdas = torch.tensor(self.CONFIG.g5_params.lambdas, device=self.device)
127 | zeta = torch.tensor(self.CONFIG.g5_params.zeta, device=self.device)
128 |
129 | n1, n2, n3, n4, l = len(Rs), len(eta), len(lambdas), len(zeta), len(type_set)
130 | jk = len(jks) # m1
131 | if jk == 0: # For dimer
132 | return torch.zeros([n1*n2*n3*n4*l], device=self.device)
133 |
134 | rij = Rij[jks[:, 0]] # [m1, 3]
135 | rik = Rij[jks[:, 1]] # [m1, 3]
136 | R2ij0 = torch.sum(rij**2., axis=1)
137 | R2ik0 = torch.sum(rik**2., axis=1)
138 | R1ij0 = torch.sqrt(R2ij0) # m1
139 | R1ik0 = torch.sqrt(R2ik0) # m1
140 | R2ij = R2ij0 - Rs.view(-1, 1)**2 # n1*m1
141 | R2ik = R2ik0 - Rs.view(-1, 1)**2 # n1*m1
142 |
143 | R1ij = R1ij0 - Rs.view(-1, 1) # n1*m1
144 | R1ik = R1ik0 - Rs.view(-1, 1) # n1*m1
145 |
146 | powers = 2. ** (1.-zeta) # n4
147 | cos_ijk = torch.sum(rij*rik, axis=1)/R1ij0/R1ik0 # m1 array
148 | term1 = 1. + torch.einsum('i,j->ij', lambdas, cos_ijk) # n3*m1
149 |
150 | zetas1 = zeta.repeat_interleave(n3*jk).reshape([n4, n3, jk]) # n4*n3*m1
151 | term2 = torch.pow(term1, zetas1) # n4*n3*m1
152 | term3 = torch.exp(torch.einsum('i,jk->ijk', -eta, (R2ij+R2ik))) # n2*n1*m1
153 | # * Cosine(R1jk0, Rc) # m1
154 | term4 = self.cosine_cutoff(R1ij0, Rc) * self.cosine_cutoff(R1ik0, Rc)
155 | term5 = torch.einsum('ijk,lmk->ijlmk', term2, term3) # n4*n3*n2*n1*m1
156 | term6 = torch.einsum('ijkml,l->ijkml', term5, term4) # n4*n3*n2*n1*m1
157 | results = torch.einsum('i,ijkml->ijkml', powers, term6) # n4*n3*n2*n1*m1
158 | results = results.reshape([n1*n2*n3*n4, jk])
159 |
160 | G5 = torch.zeros([n1*n2*n3*n4*l], device=self.device)
161 | jk_ids = IDs[jks]
162 | for id, jk_type in enumerate(type_set):
163 | ids = self.select_rows(jk_ids, jk_type)
164 | G5[id::l] = torch.sum(results[:, ids], axis=1)
165 |
166 | return G5
167 |
168 | def create_type_set(self, number_set, order):
169 | '''
170 |         Build all order-sized combinations (with replacement) of the unique atomic numbers in number_set.
171 | '''
172 | types = list(set(number_set))
173 | return np.array(list(combinations_with_replacement(types, order)))
174 |
175 | def select_rows(self, data, row_pattern):
176 | '''
177 |         Return the indices of entries in data whose element type(s) match row_pattern (order-insensitive for pairs).
178 | '''
179 | if len(row_pattern) == 1:
180 | ids = (data == row_pattern)
181 | elif len(row_pattern) == 2:
182 | a, b = row_pattern
183 | if a == b:
184 | ids = [id for id, d in enumerate(data) if d[0] == a and d[1] == a]
185 | else:
186 | ids = [id for id, d in enumerate(data) if (d[0] == a and d[1] == b) or (d[0] == b and d[1] == a)]
187 | return ids
188 |
189 | def cosine_cutoff(self, Rij, Rc):
190 | '''
191 |         Cosine cutoff function: f_c(Rij) = 0.5 * (cos(pi * Rij / Rc) + 1) for Rij <= Rc, and 0 otherwise
192 | See Behler, J. Chem. Phys. (2011) Eqn (4)
193 |
194 | Parameters:
195 | Rij (torch.Tensor): distance between atom i and j
196 | Rc (float): cutoff radius
197 |
198 | Returns:
199 | out (torch.Tensor): cosine cutoff
200 | '''
201 |
202 | out = 0.5 * (torch.cos(np.pi * Rij / Rc) + 1.)
203 |         out[Rij > Rc] = 0.  # cutoff vanishes beyond Rc
204 | return out
--------------------------------------------------------------------------------
/structrepgen/descriptors/generic.py:
--------------------------------------------------------------------------------
1 | import torch, itertools
2 | import numpy as np
3 |
4 | def get_distances(positions, pbc_offsets, device):
5 | '''
6 | Get atomic distances
7 |
8 | Parameters:
9 | positions (numpy.ndarray/torch.Tensor): positions attribute of ase.Atoms
10 | pbc_offsets (numpy.ndarray/torch.Tensor): periodic boundary condition offsets
11 |
12 | Returns:
13 |         atom_distance_sqr_min (torch.Tensor): [n_atoms, n_atoms] minimum-image distances; atom_rij (torch.Tensor): [n_atoms, n_atoms, 3] corresponding displacement vectors
14 | '''
15 |
16 | if isinstance(positions, np.ndarray):
17 | positions = torch.tensor(positions, device=device, dtype=torch.float)
18 |
19 | n_atoms = len(positions)
20 | n_cells = len(pbc_offsets)
21 |
22 | pos1 = positions.view(-1, 1, 1, 3).expand(-1, n_atoms, n_cells, 3)
23 | pos2 = positions.view(1, -1, 1, 3).expand(n_atoms, -1, n_cells, 3)
24 | pbc_offsets = pbc_offsets.view(-1, n_cells, 3).expand(pos2.shape[0], n_cells, 3)
25 | pos2 = pos2 + pbc_offsets
26 |
27 | # calculate the distance between target atom and the periodic images of the other atom
28 | atom_distance_sqr = torch.linalg.norm(pos1 - pos2, dim=-1)
29 | # get the minimum distance
30 | atom_distance_sqr_min, min_indices = torch.min(atom_distance_sqr, dim=-1)
31 |
32 | atom_rij = pos1 - pos2
33 | min_indices = min_indices[..., None, None].expand(-1, -1, 1, atom_rij.size(3))
34 | atom_rij = torch.gather(atom_rij, dim=2, index=min_indices).squeeze()
35 |
36 | return atom_distance_sqr_min, atom_rij
37 |
38 | def get_pbc_offsets(cell, offset_num, device):
39 | '''
40 | Get periodic boundary condition (PBC) offsets
41 |
42 | Parameters:
43 | cell (np.ndarray/torch.Tensor): unit cell vectors of ase.cell.Cell
44 |         offset_num (int): number of periodic image shells to include along each lattice vector
45 |
46 | Returns:
47 |         offsets (torch.Tensor): Cartesian PBC offset vectors, shape ((2*offset_num+1)**3, 3)
48 | '''
49 | if isinstance(cell, np.ndarray):
50 | cell = torch.tensor(np.array(cell), device=device, dtype=torch.float)
51 |
52 | unit_cell = []
53 | offset_range = np.arange(-offset_num, offset_num + 1)
54 |
55 | for prod in itertools.product(offset_range, offset_range, offset_range):
56 | unit_cell.append(list(prod))
57 |
58 | unit_cell = torch.tensor(unit_cell, dtype=torch.float, device=device)
59 |
60 | return torch.mm(unit_cell, cell.to(device))
61 |
62 | # Obtain unit cell offsets for distance calculation
63 | class PBC_offsets():
64 | def __init__(self, cell, device, supercell_max=4):
65 | # set up pbc offsets for minimum distance in pbc
66 | self.pbc_offsets = []
67 |
68 | for offset_num in range(0, supercell_max):
69 | unit_cell = []
70 | offset_range = np.arange(-offset_num, offset_num+1)
71 |
72 | for prod in itertools.product(offset_range, offset_range, offset_range):
73 | unit_cell.append(list(prod))
74 |
75 | unit_cell = torch.tensor(unit_cell, dtype=torch.float, device=device)
76 | self.pbc_offsets.append(torch.mm(unit_cell, cell.to(device)))
77 |
78 | def get_offset(self, offset_num):
79 | return self.pbc_offsets[offset_num]
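
A hypothetical usage sketch of the two helpers above with an ase structure; the copper cell is only an example and assumes ase is installed.

    import numpy as np
    from ase.build import bulk
    from structrepgen.descriptors.generic import get_distances, get_pbc_offsets

    atoms = bulk('Cu', 'fcc', a=3.6, cubic=True)                     # a simple 4-atom periodic cell
    offsets = get_pbc_offsets(np.array(atoms.get_cell()), 1, 'cpu')  # 3x3x3 periodic image shells
    distances, rij = get_distances(atoms.get_positions(), offsets, 'cpu')
    print(distances.shape, rij.shape)  # minimum-image distances and the matching displacement vectors
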
--------------------------------------------------------------------------------
/structrepgen/extraction/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Fung-Lab/StructRepGen/c000acddbb450e2a4c0a81d09a6dfb07fc675a7d/structrepgen/extraction/__init__.py
--------------------------------------------------------------------------------
/structrepgen/extraction/representation_extraction.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | import ase, ase.db, csv, yaml, torch
4 |
5 | from structrepgen.descriptors.behler import Behler
6 | from structrepgen.descriptors.generic import *
7 |
8 | class RepresentationExtraction:
9 | def __init__(self, CONFIG) -> None:
10 | '''
11 | Extract representation R from raw data.
12 |
13 | Parameters:
14 | CONFIG (dict): descriptor configurations for extraction
15 |
16 | Returns:
17 | NIL
18 | '''
19 | self.CONFIG = CONFIG
20 |
21 | def extract(self, fname=None):
22 | descriptor = self.CONFIG.descriptor
23 |
24 | if descriptor == 'behler':
25 | self.behler(fname)
26 | else:
27 | raise Exception("Descriptor currently not supported.")
28 |
29 | def behler(self, fname=None):
30 | behler = Behler(self.CONFIG)
31 |
32 | data_x=[]
33 | data_y=[]
34 |
35 | db = ase.db.connect(self.CONFIG.data)
36 | for row in db.select():
37 | atoms = row.toatoms(add_additional_information=False)
38 | positions = atoms.get_positions()
39 | offsets = get_pbc_offsets(np.array(atoms.get_cell()), 0, behler.device)
40 | distances, rij = get_distances(positions, offsets, behler.device)
41 |
42 | features = behler.get_features(distances, rij, atoms.get_atomic_numbers())
43 | features = features.cpu().numpy()
44 |
45 | data_x.append(list(features))
46 | data_y.append([row.get('target')])
47 | break # should be removed in production
48 |
49 | with open(self.CONFIG.x_fname, 'w', newline='') as f:
50 | wr = csv.writer(f)
51 | wr.writerows(data_x)
52 |
53 | with open(self.CONFIG.y_fname, 'w', newline='') as f:
54 | wr = csv.writer(f)
55 | wr.writerows(data_y)
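
A sketch of driving the extraction from a yaml config, mirroring the setup in tests/test_behler.py; it assumes the example config under configs/example supplies the descriptor settings expected by Behler as well as data, x_fname and y_fname.

    import yaml
    from structrepgen.utils.dotdict import dotdict
    from structrepgen.extraction.representation_extraction import RepresentationExtraction

    with open('./configs/example/example_extract_reconstruct.yaml') as stream:
        CONFIG = dotdict(yaml.safe_load(stream))

    extractor = RepresentationExtraction(CONFIG)
    extractor.extract()  # writes the representations to CONFIG.x_fname / CONFIG.y_fname
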
--------------------------------------------------------------------------------
/structrepgen/generators/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Fung-Lab/StructRepGen/c000acddbb450e2a4c0a81d09a6dfb07fc675a7d/structrepgen/generators/__init__.py
--------------------------------------------------------------------------------
/structrepgen/generators/generator.py:
--------------------------------------------------------------------------------
1 | import torch, joblib, os
2 | import numpy as np
3 |
4 | from structrepgen.models.models import *
5 | from structrepgen.models.CVAE import *
6 | from structrepgen.utils.utils import torch_device_select
7 |
8 | class Generator:
9 | '''
10 | Generator class for generating representation R from decoder
11 |
12 | # TODO
13 | - make this generic such that different decoders can be used
14 |
15 | Parameters:
16 | CONFIG (dict): model configurations for generation
17 |
18 | Returns:
19 | NIL
20 | '''
21 | def __init__(self, CONFIG) -> None:
22 | self.CONFIG = CONFIG
23 |
24 | # check GPU availability & set device
25 | self.device = torch_device_select(self.CONFIG.gpu)
26 |
27 | # load models
28 | if self.device == 'cpu':
29 | self.model = torch.load(CONFIG.model_path, map_location=torch.device(self.device))
30 | self.model_ff = torch.load(CONFIG.ff_model_path, map_location=torch.device(self.device))
31 | else:
32 | self.model = torch.load(CONFIG.model_path)
33 | self.model_ff = torch.load(CONFIG.ff_model_path)
34 |
35 | self.model.to(self.device)
36 | self.model_ff.to(self.device)
37 | self.model.eval()
38 | self.model_ff.eval()
39 |
40 | self.rev_x = torch.zeros((len(self.CONFIG.targets), self.CONFIG.num_z, self.CONFIG.params.input_dim)).to(self.device)
41 |
42 | def generate(self):
43 | '''
44 | Generate representation R using decoder by randomly sampling the latent space
45 | Output csv saved in folder defined in yaml file
46 |
47 | Parameters:
48 | NIL
49 |
50 | Returns:
51 | NIL
52 | '''
53 | p = self.CONFIG.params
54 | for count, y0 in enumerate(self.CONFIG.targets):
55 | for i in range(self.CONFIG.num_z):
56 | z = torch.randn(1, p.latent_dim).to(self.device)
57 | y = torch.tensor([[y0]]).to(self.device)
58 |
59 | z = torch.cat((z, y), dim=1)
60 |
61 | reconstructed_x = self.model.decoder(z)
62 | self.rev_x[count, i, :] = reconstructed_x
63 |
64 | fname = self.CONFIG.save_path + 'gen_samps_cvae_' + str(y0) + '.csv'
65 | np.savetxt(fname, self.rev_x[count, :, :].cpu().data.numpy(), fmt='%.6f', delimiter=',')
66 |
67 | def range_check(self):
68 | '''
69 | Check percentage of generated structures that have y value within +- self.CONFIG.delta of target y value
70 | '''
71 | scaler = joblib.load(self.CONFIG.scaler_path)
72 | rev_x_scaled = scaler.inverse_transform(self.rev_x.reshape(-1, self.rev_x.shape[2]).cpu().data.numpy())
73 | rev_x = torch.tensor(rev_x_scaled).to(self.device).reshape(self.rev_x.shape)
74 |
75 | ratios = 0
76 | avg_diff = 0
77 |
78 | for count, y0 in enumerate(self.CONFIG.targets):
79 | y1 = self.model_ff(rev_x[count, :,:])
80 |
81 | indices = torch.where((y1 > y0 - self.CONFIG.delta) & (y1 < y0 + self.CONFIG.delta))[0]
82 | ratio = len(indices)/len(y1) * 100
83 | ratios += ratio
84 | rev_x_out = rev_x[count, indices, :]
85 |
86 | minn, maxx = min(self.CONFIG.targets), max(self.CONFIG.targets)
87 | y_indices = torch.where((y1 > minn) & (y1 < maxx))[0]
88 | average = torch.mean(y1[y_indices]).item()
89 | avg_diff += abs(y0-average)
90 |
91 | out = f'Target: {y0:.2f}, Average value: {average:.2f}, Percent of samples within range: {ratio:.2f}'
92 | print(out)
93 |
94 | fname = self.CONFIG.save_path + 'gen_samps_x_' + str(y0) + '.csv'
95 | np.savetxt(fname, rev_x_out.cpu().data.numpy(), fmt='%.6f', delimiter=',')
96 |
97 | ratios = ratios/len(self.CONFIG.targets)
98 | avg_diff = avg_diff/len(self.CONFIG.targets)
99 | avg_out = f'Average difference: {avg_diff:.2f}, Average percent: {ratios:.2f}'
100 | print(avg_out)
101 |
102 | return avg_diff, ratios
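
A sketch of generation followed by the surrogate range check, assuming configs/example/example_generator.yaml defines the model paths, targets, num_z, delta, scaler_path and save_path used above.

    import yaml
    from structrepgen.utils.dotdict import dotdict
    from structrepgen.generators.generator import Generator

    with open('./configs/example/example_generator.yaml') as stream:
        CONFIG = dotdict(yaml.safe_load(stream))

    gen = Generator(CONFIG)
    gen.generate()                        # one gen_samps_cvae_<target>.csv per target value
    avg_diff, ratios = gen.range_check()  # screens the samples with the surrogate FF model
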
--------------------------------------------------------------------------------
/structrepgen/models/CVAE.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 |
6 | class Encoder(nn.Module):
7 |     ''' This is the encoder part of the CVAE.
8 | '''
9 | def __init__(self, input_dim, hidden_dim, latent_dim, hidden_layers, y_dim):
10 | '''
11 | Args:
12 |             input_dim: An integer indicating the size of the input representation.
13 |             hidden_dim: An integer indicating the size of the first hidden dimension.
14 |             latent_dim: An integer indicating the size of the latent space.
15 |             hidden_layers, y_dim: integers giving the number of hidden layers and the dimension of the conditioning variable y.
16 | '''
17 | super().__init__()
18 |
19 | self.hidden = nn.Sequential()
20 | intermediate_dimensions = np.linspace(
21 | hidden_dim, latent_dim, hidden_layers+1, dtype=int)[0:-1]
22 | intermediate_dimensions = np.concatenate(
23 | ([input_dim + y_dim], intermediate_dimensions))
24 | for i, (in_size, out_size) in enumerate(zip(intermediate_dimensions[:-1], intermediate_dimensions[1:])):
25 | self.hidden.add_module(name='Linear_'+str(i),
26 | module=nn.Linear(in_size, out_size))
27 | #self.hidden.add_module(name='BN_'+str(i), module=nn.BatchNorm1d(out_size))
28 | self.hidden.add_module(name='Act_'+str(i), module=nn.ELU())
29 | #self.hidden.add_module(name='Drop_'+str(i), module=nn.Dropout(p=0.05, inplace=False))
30 |
31 | self.mu = nn.Linear(intermediate_dimensions[-1], latent_dim)
32 | self.var = nn.Linear(intermediate_dimensions[-1], latent_dim)
33 |
34 | def forward(self, x):
35 |         # x is of shape [batch_size, input_dim + y_dim]
36 |
37 | x = self.hidden(x)
38 | # hidden is of shape [batch_size, hidden_dim]
39 |
40 | # latent parameters
41 | mean = self.mu(x)
42 | # mean is of shape [batch_size, latent_dim]
43 | log_var = self.var(x)
44 | # log_var is of shape [batch_size, latent_dim]
45 |
46 | return mean, log_var
47 |
48 |
49 | class Decoder(nn.Module):
50 |     ''' This is the decoder part of the CVAE.
51 | '''
52 |
53 | def __init__(self, latent_dim, hidden_dim, output_dim, hidden_layers, y_dim):
54 | '''
55 | Args:
56 |             latent_dim: An integer indicating the size of the latent space.
57 |             hidden_dim: An integer indicating the size of the hidden dimension.
58 |             output_dim: An integer indicating the size of the output representation.
59 |             hidden_layers, y_dim: integers giving the number of hidden layers and the dimension of the conditioning variable y.
60 | '''
61 | super().__init__()
62 |
63 | self.hidden = nn.Sequential()
64 | intermediate_dimensions = np.linspace(
65 | latent_dim + y_dim, hidden_dim, hidden_layers+1, dtype=int)
66 | for i, (in_size, out_size) in enumerate(zip(intermediate_dimensions[:-1], intermediate_dimensions[1:])):
67 | self.hidden.add_module(name='Linear_'+str(i),
68 | module=nn.Linear(in_size, out_size))
69 | #self.hidden.add_module(name='BN_'+str(i), module=nn.BatchNorm1d(out_size))
70 | self.hidden.add_module(name='Act_'+str(i), module=nn.ELU())
71 | #self.hidden.add_module(name='Drop_'+str(i), module=nn.Dropout(p=0.05, inplace=False))
72 |
73 | self.hidden_to_out = nn.Linear(hidden_dim, output_dim)
74 |
75 | def forward(self, x):
76 |         # x is of shape [batch_size, latent_dim + y_dim]
77 | x = self.hidden(x)
78 |
79 | generated_x = self.hidden_to_out(x)
80 |
81 | return generated_x
82 |
83 |
84 | class CVAE(nn.Module):
85 |     ''' This is the CVAE, which combines an encoder and a decoder.
86 | '''
87 |
88 | def __init__(self, input_dim, hidden_dim, latent_dim, hidden_layers, y_dim):
89 | '''
90 | Args:
91 |             input_dim: An integer indicating the size of the input representation.
92 |             hidden_dim: An integer indicating the size of the hidden dimension.
93 |             latent_dim: An integer indicating the size of the latent space.
94 |             hidden_layers, y_dim: integers giving the number of hidden layers and the dimension of the conditioning variable y.
95 | '''
96 | super().__init__()
97 |
98 | self.encoder = Encoder(input_dim, hidden_dim,
99 | latent_dim, hidden_layers, y_dim)
100 | self.decoder = Decoder(latent_dim, hidden_dim,
101 | input_dim, hidden_layers, y_dim)
102 |
103 | def forward(self, x, y):
104 |
105 | x = torch.cat((x, y), dim=1)
106 |
107 | # encode
108 | z_mu, z_var = self.encoder(x)
109 |
110 | # sample from the distribution having latent parameters z_mu, z_var
111 | # reparameterize
112 | std = torch.exp(z_var / 2)
113 | eps = torch.randn_like(std)
114 | x_sample = eps.mul(std).add_(z_mu)
115 |
116 | z = torch.cat((x_sample, y), dim=1)
117 |
118 | # decode
119 | generated_x = self.decoder(z)
120 |
121 | return generated_x, z_mu, z_var
122 |
123 | def kl_divergence(z, mu, std):
124 | # --------------------------
125 | # Monte carlo KL divergence
126 | # --------------------------
127 | # https://towardsdatascience.com/variational-autoencoder-demystified-with-pytorch-implementation-3a06bee395ed
128 | # 1. define the first two probabilities (in this case Normal for both)
129 | p = torch.distributions.Normal(torch.zeros_like(mu), torch.ones_like(std))
130 | q = torch.distributions.Normal(mu, std)
131 |
132 | # 2. get the probabilities from the equation
133 | log_qzx = q.log_prob(z)
134 | log_pz = p.log_prob(z)
135 |
136 | # kl
137 | kl = (log_qzx - log_pz)
138 |
139 | # sum over last dim to go from single dim distribution to multi-dim
140 | kl = kl.sum(-1)
141 | return kl
142 |
143 |
144 | def calculate_loss(x, reconstructed_x, mu, log_var, weight, mc_kl_loss):
145 | # reconstruction loss
146 |     rcl = F.mse_loss(reconstructed_x, x, reduction='sum')
147 | # kl divergence loss
148 |
149 | if mc_kl_loss == True:
150 | std = torch.exp(log_var / 2)
151 | q = torch.distributions.Normal(mu, std)
152 | z = q.rsample()
153 | kld = kl_divergence(z, mu, std).sum()
154 | else:
155 | kld = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
156 |
157 | return rcl, kld * weight
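
A minimal, illustrative training step for the CVAE using calculate_loss; the dimensions, batch and hyperparameters below are placeholders rather than the repository's settings.

    import torch
    from structrepgen.models.CVAE import CVAE, calculate_loss

    model = CVAE(input_dim=128, hidden_dim=64, latent_dim=8, hidden_layers=2, y_dim=1)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    x = torch.randn(32, 128)  # a batch of representations R
    y = torch.rand(32, 1)     # conditioning targets

    optimizer.zero_grad()
    reconstructed_x, z_mu, z_var = model(x, y)
    rcl, kld = calculate_loss(x, reconstructed_x, z_mu, z_var, weight=1.0, mc_kl_loss=False)
    (rcl + kld).backward()
    optimizer.step()
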
--------------------------------------------------------------------------------
/structrepgen/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Fung-Lab/StructRepGen/c000acddbb450e2a4c0a81d09a6dfb07fc675a7d/structrepgen/models/__init__.py
--------------------------------------------------------------------------------
/structrepgen/models/models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 |
6 | class ff_net(nn.Module):
7 | def __init__(self, input_dim, hidden_dim, layers):
8 | super(ff_net, self).__init__()
9 | self.lin1 = torch.nn.Linear(input_dim, hidden_dim)
10 | self.lin_list = torch.nn.ModuleList(
11 | [torch.nn.Linear(hidden_dim, hidden_dim) for i in range(layers)]
12 | )
13 |
14 | self.lin2 = torch.nn.Linear(hidden_dim, 1)
15 |
16 | def forward(self, x):
17 | out = torch.nn.functional.relu(self.lin1(x))
18 | for layer in self.lin_list:
19 | out = torch.nn.functional.relu(layer(out))
20 | out = self.lin2(out)
21 | return out
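
An illustrative forward pass through the ff_net surrogate; the dimensions here are placeholders.

    import torch
    from structrepgen.models.models import ff_net

    model_ff = ff_net(input_dim=128, hidden_dim=64, layers=2)
    x = torch.randn(16, 128)  # batch of representations R
    y_pred = model_ff(x)      # predicted target property, shape [16, 1]
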
--------------------------------------------------------------------------------
/structrepgen/reconstruction/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Fung-Lab/StructRepGen/c000acddbb450e2a4c0a81d09a6dfb07fc675a7d/structrepgen/reconstruction/__init__.py
--------------------------------------------------------------------------------
/structrepgen/reconstruction/generic.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Fung-Lab/StructRepGen/c000acddbb450e2a4c0a81d09a6dfb07fc675a7d/structrepgen/reconstruction/generic.py
--------------------------------------------------------------------------------
/structrepgen/reconstruction/reconstruction.py:
--------------------------------------------------------------------------------
1 | import time, ase, torch, os, pickle
2 | import numpy as np
3 | import torch.nn.functional as F
4 | from ase import io, Atoms
5 | from skopt.space import Space
6 | from skopt.sampler import Lhs, Halton, Grid
7 | from structrepgen.models.models import *
8 | from structrepgen.descriptors.behler import Behler
9 | from structrepgen.descriptors.generic import *
10 | from structrepgen.models.CVAE import *
11 | from structrepgen.utils.utils import torch_device_select
12 |
13 | class Reconstruction:
14 | '''
15 | Reconstruction
16 | '''
17 | def __init__(self, CONFIG) -> None:
18 | self.CONFIG = CONFIG
19 |
20 | # select gpu or cpu for pytorch
21 | self.device = torch_device_select(self.CONFIG.gpu)
22 |
23 | # initialize descriptor
24 | if self.CONFIG.descriptor == 'behler':
25 | self.descriptor = Behler(CONFIG)
26 | else:
27 | raise ValueError('Unrecognised descriptor method: {}.'.format(self.CONFIG.descriptor))
28 |
29 |         self.model_ff = torch.load(self.CONFIG.ff_model_path, map_location=torch.device(self.device))
30 | self.model_ff.to(self.device)
31 |
32 | torch.backends.cudnn.benchmark = True
33 |
34 | def main(self):
35 | '''
36 |         Reconstruct atomic structures from the target representations and write them out as CIF files.
37 | '''
38 | tic = time.time()
39 |
40 | data_list = self.create_datalist()
41 | self.model_ff.eval()
42 |
43 | for i in range(len(data_list)):
44 | best_positions, _ = self.basin_hopping(
45 | data_list[i],
46 | total_trials=2,
47 | max_hops=10,
48 | lr=0.2,
49 | displacement_factor=2,
50 | max_loss=0.0001,
51 | verbose=True
52 | )
53 |
54 | optimized_structure = Atoms(
55 | numbers=data_list[i]['atomic_numbers'],
56 | positions=best_positions.detach().cpu().numpy(),
57 | cell=data_list[i]['cell'].cpu().numpy(),
58 | pbc=(True, True, True)
59 | )
60 |
61 | # save reconstructed cifs
62 | filename = self.CONFIG.reconstructed_file_path + str(i) + '_reconstructed.cif'
63 | ase.io.write(filename, optimized_structure)
64 |
65 | pbc_offsets = get_pbc_offsets(data_list[i]['cell'], 0, self.device)
66 | distances, rij = get_distances(best_positions, pbc_offsets, self.device)
67 | features = self.descriptor.get_features(distances, rij, data_list[i]['atomic_numbers'])
68 |
69 | # inference on FF model
70 | y0 = self.model_ff(data_list[i]['representation'])
71 | y1 = self.model_ff(features)
72 | print(y0, y1)
73 |
74 | elapsed_time = time.time() - tic
75 | print('Total elapsed time: {:.3f}'.format(elapsed_time))
76 |
77 | def basin_hopping(
78 | self,
79 | data,
80 | total_trials = 20,
81 | max_hops = 500,
82 | lr = 0.05,
83 | displacement_factor = 2,
84 | max_loss = 0.01,
85 | write = False,
86 | verbose = False
87 | ):
88 | '''
89 |         Basin-hopping search: repeated local optimization with random displacements applied between hops.
90 | '''
91 | # setting for self.optimize()
92 | max_iter = 140
93 |
94 | offset_count = 0
95 | offsets = PBC_offsets(data['cell'], self.device, supercell_max=1)
96 |
97 | best_global_loss = float('inf')
98 | best_global_positions = None
99 | converged = False
100 |
101 | for _ in range(total_trials):
102 | # initialize random structure
103 | generated_pos = self.initialize(
104 | len(data['atomic_numbers']),
105 | data['cell'],
106 | data['atomic_numbers'],
107 | sampling_method='random'
108 | )
109 |
110 | tic = time.time()
111 | best_local_loss = float('inf')
112 | best_local_positions = None
113 |
114 | for hop in range(max_hops):
115 | # TODO: parse loss function, optimizer and scheduler as args
116 | loss_func = F.l1_loss
117 | optimizer = torch.optim.Adam([generated_pos], lr=lr)
118 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
119 | optimizer,
120 | mode='min',
121 | factor=0.8,
122 | patience=15,
123 | min_lr=0.0001,
124 | threshold=0.0001
125 | )
126 |
127 | best_positions, _, loss = self.optimize(
128 | generated_pos,
129 | data['cell'],
130 | offsets.get_offset(offset_count),
131 | data['atomic_numbers'],
132 | data['representation'],
133 | loss_func,
134 | optimizer,
135 | scheduler,
136 | max_iter,
137 | verbosity=999,
138 | early_stopping=100,
139 | threshold=0.0001,
140 | convergence=0.00005
141 | )
142 |
143 | elapsed_time = time.time() - tic
144 | tic = time.time()
145 |
146 | if verbose:
147 | print("BH Hop: {:04d}, Loss: {:.5f}, Time (s): {:.4f}".format(hop, loss.item(), elapsed_time))
148 |
149 | if loss.item() < best_local_loss:
150 | # TODO: implement metropolis criterion
151 | best_local_positions = best_positions.detach().clone()
152 | best_local_loss = loss.item()
153 |
154 | if loss.item() < best_global_loss:
155 | best_global_positions = best_local_positions
156 | best_global_loss = loss.item()
157 |
158 | if best_local_loss < max_loss:
159 | print("Convergence criterion met.")
160 | converged = True
161 | break
162 |
163 | # apply random shift in positions
164 | # TODO: change hardcoded value of 0.5
165 | displacement = (np.random.random_sample(best_positions.shape) - 0.5) * displacement_factor
166 | displacement = torch.tensor(displacement, device=self.device, dtype=torch.float)
167 |
168 | generated_pos = best_local_positions.detach().clone() + displacement
169 | generated_pos.requires_grad_()
170 |
171 | if converged:
172 | break
173 |
174 | if verbose:
175 | print('Ending Basin Hopping, fine optimization of best structure. Best global loss: {:5.3f}'.format(best_global_loss))
176 |
177 | best_global_positions.requires_grad_()
178 | optimizer = torch.optim.Adam([best_global_positions], lr=lr*0.2)
179 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
180 | optimizer,
181 | mode='min',
182 | factor=0.8,
183 | patience=20,
184 | min_lr=0.0001,
185 | threshold=0.0001
186 | )
187 |
188 | best_global_positions, _, loss = self.optimize(
189 | best_global_positions,
190 | data['cell'],
191 | offsets.get_offset(offset_count),
192 | data['atomic_numbers'],
193 | data['representation'],
194 | loss_func,
195 | optimizer,
196 | scheduler,
197 | 900,
198 | verbosity=(10 if verbose == True else 999),
199 | early_stopping=20,
200 | threshold=0.00001,
201 | convergence=0.00005
202 | )
203 |
204 | if verbose:
205 | print('Best Basin Hopping loss: {:.5f}'.format(loss.item()))
206 | return best_global_positions, loss
207 |
208 | def optimize(
209 | self,
210 | positions,
211 | cell,
212 | pbc_offsets,
213 | atomic_numbers,
214 | target_representation,
215 | loss_function,
216 | optimizer,
217 | scheduler,
218 | max_iterations = 100,
219 | verbosity = 5,
220 | early_stopping = 1,
221 | threshold = 0.01,
222 | convergence = 0.001
223 | ):
224 | '''
225 |         Gradient-based optimization of atomic positions so that the descriptor matches the target representation.
226 | '''
227 | tic = time.time()
228 |
229 |
230 | all_positions = [positions.detach().clone()]
231 | count, loss_history = 0, []
232 | best_positions = None
233 | best_loss = float('inf')
234 | min_dist = 1.5
235 |
236 | while count < max_iterations:
237 | optimizer.zero_grad()
238 | lr = scheduler.optimizer.param_groups[0]['lr']
239 |
240 | # get representation / reconstruction loss
241 | distances, rij = get_distances(positions, pbc_offsets, self.device)
242 | features = self.descriptor.get_features(distances, rij, atomic_numbers)
243 |
244 | features = torch.unsqueeze(features, 0)
245 |
246 | loss = loss_function(features, target_representation)
247 | loss.backward(retain_graph=False)
248 | loss_history.append(loss)
249 |
250 | optimizer.step()
251 | scheduler.step(loss)
252 |
253 | #print loss and time taken
254 | if (count + 1) % verbosity == 0:
255 | elapsed_time = time.time() - tic
256 | tic = time.time()
257 | print("System Size: {:04d}, Step: {:04d}, Loss: {:.5f}, LR: {:.5f}, Time/step (s): {:.4f}".format(len(positions), count+1, loss.item(), lr, elapsed_time/verbosity))
258 |
259 | # save best structure
260 | if loss < best_loss:
261 | best_positions = positions.detach().clone()
262 | best_loss = loss
263 |
264 | count += 1
265 |
266 | if best_loss < convergence:
267 | break
268 |
269 | if early_stopping > 1 and count > early_stopping:
270 | if abs((loss.item() - loss_history[-early_stopping]) / loss_history[-early_stopping]) < threshold:
271 | break
272 |
273 | all_positions.append(positions.detach().clone())
274 |
275 | return best_positions, all_positions, best_loss
276 |
277 | def initialize(self, structure_len, cell, atomic_numbers, sampling_method='random', write=False):
278 | '''
279 | Initialize structure
280 |
281 | Parameters:
282 |             structure_len (int), cell (torch.Tensor), atomic_numbers (list), sampling_method (str), write (bool)
283 | Returns:
284 |             positions (torch.Tensor): Cartesian positions with requires_grad=True on the selected device
285 | '''
286 | space = Space([(0., 1.)] * 3)
287 | if sampling_method == 'random':
288 | positions = np.array(space.rvs(structure_len))
289 | else:
290 | raise ValueError("Unrecognised sampling method.")
291 |
292 | generated_structure = Atoms(
293 | numbers=atomic_numbers,
294 | scaled_positions=positions,
295 | cell=cell.cpu().numpy(),
296 | pbc=(True, True, True)
297 | )
298 |
299 | positions = generated_structure.positions
300 | if write == True:
301 | # TODO
302 | print("Implement write")
303 | pass
304 |
305 | return torch.tensor(positions, requires_grad=True, device=self.device, dtype=torch.float)
306 |
307 | def create_datalist(self):
308 | '''
309 | Load crystal structure(s) data from directory
310 |
311 | Parameters:
312 | NIL
313 |
314 | Returns:
315 |             data_list (list): one dict per structure with 'atomic_numbers', 'cell' and 'representation' entries
316 |
317 | '''
318 | representation = np.loadtxt(self.CONFIG.structure_file_path, dtype='float', delimiter=',')
319 | cell = self.CONFIG.cell
320 | data_list = []
321 |
322 | for i in range(len(representation)):
323 | data = {}
324 |
325 | data['atomic_numbers'] = self.CONFIG.atoms
326 | data['cell'] = torch.tensor(np.array(cell), dtype=torch.float)
327 | data['representation'] = torch.unsqueeze(torch.tensor(representation[i], dtype=torch.float, device=self.device), 0)
328 | data_list.append(data)
329 |
330 | return data_list
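
A sketch of running the reconstruction end to end, assuming configs/example/example_extract_reconstruct.yaml provides the descriptor settings, ff_model_path, structure_file_path, cell, atoms and reconstructed_file_path used above.

    import yaml
    from structrepgen.utils.dotdict import dotdict
    from structrepgen.reconstruction.reconstruction import Reconstruction

    with open('./configs/example/example_extract_reconstruct.yaml') as stream:
        CONFIG = dotdict(yaml.safe_load(stream))

    recon = Reconstruction(CONFIG)
    recon.main()  # basin hopping per structure; writes <i>_reconstructed.cif files
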
--------------------------------------------------------------------------------
/structrepgen/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Fung-Lab/StructRepGen/c000acddbb450e2a4c0a81d09a6dfb07fc675a7d/structrepgen/utils/__init__.py
--------------------------------------------------------------------------------
/structrepgen/utils/dotdict.py:
--------------------------------------------------------------------------------
1 | # class dotdict(dict):
2 | # """dot.notation access to dictionary attributes"""
3 | # __getattr__ = dict.get
4 | # __setattr__ = dict.__setitem__
5 | # __delattr__ = dict.__delitem__
6 |
7 |
8 | class dotdict(dict):
9 | """dot.notation access to dictionary attributes"""
10 | def __getattr__(*args):
11 | val = dict.get(*args)
12 | return dotdict(val) if type(val) is dict else val
13 | __setattr__ = dict.__setitem__
14 | __delattr__ = dict.__delitem__
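
A quick illustration of the dot access, including the nested-dict wrapping and the dict.get semantics for missing keys.

    from structrepgen.utils.dotdict import dotdict

    d = dotdict({'params': {'latent_dim': 8}, 'gpu': True})
    print(d.gpu)                 # True
    print(d.params.latent_dim)   # nested dicts are wrapped in dotdict on access -> 8
    print(d.missing)             # missing keys return None instead of raising
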
--------------------------------------------------------------------------------
/structrepgen/utils/utils.py:
--------------------------------------------------------------------------------
1 | import torch, warnings
2 |
3 | def torch_device_select(gpu):
4 | # check GPU availability & return device type
5 | if torch.cuda.is_available() and not gpu:
6 | warnings.warn("GPU is available but not used.")
7 | return 'cpu'
8 | elif not torch.cuda.is_available() and gpu:
9 |         warnings.warn("GPU requested but not available. Using CPU.")
10 | return 'cpu'
11 | elif torch.cuda.is_available() and gpu:
12 | return 'cuda'
13 | else:
14 | return 'cpu'
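
A brief usage example of the device helper; the flag would normally come from the yaml config's gpu entry.

    from structrepgen.utils.utils import torch_device_select

    device = torch_device_select(True)  # 'cuda' if a GPU is available; otherwise warns and returns 'cpu'
    print(device)
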
--------------------------------------------------------------------------------
/tests/original/db_to_data.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import ase, csv
3 | import ase.db
4 | from .original_descriptor import *
5 |
6 | # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
7 | device = 'cpu'
8 |
9 | def db_to_data():
10 | processing_arguments = {}
11 | g2_params = {'eta': torch.tensor([0.01, 0.06, 0.1, 0.2, 0.4, 0.7, 1.0, 2.0, 3.5, 5.0], device=device), 'Rs': torch.tensor([
12 | 0, 1, 1.5, 2, 2.5, 3, 3.5, 4, 5, 6, 7, 8, 9, 10], device=device)}
13 | g5_params = {'lambda': torch.tensor([-1, 1], device=device), 'zeta': torch.tensor([1, 2, 4, 16, 64], device=device),
14 | 'eta': torch.tensor([0.06, 0.1, 0.2, 0.4, 1.0], device=device), 'Rs': torch.tensor([0], device=device)}
15 | processing_arguments['cutoff'] = 20
16 | processing_arguments['mode'] = 'features'
17 | processing_arguments['g2_params'] = g2_params
18 | processing_arguments['g5_params'] = g5_params
19 | processing_arguments['g5'] = True
20 | processing_arguments['average'] = True
21 |
22 | db = ase.db.connect('./data/raw/data.db')
23 | data_x=[]
24 | data_y=[]
25 |
26 | for row in db.select():
27 | ase_structure = row.toatoms(add_additional_information=False)
28 | positions = torch.tensor(
29 | ase_structure.get_positions(), device=device, dtype=torch.float)
30 | distances, rij = get_distances(positions, get_pbc_offsets(torch.tensor(
31 | np.array(ase_structure.get_cell()), device=device, dtype=torch.float), 0))
32 | features = get_ACSF_features(
33 | distances, rij, ase_structure.get_atomic_numbers(), processing_arguments)
34 | features = torch.squeeze(features).cpu().numpy()
35 |
36 | data_x.append(list(features))
37 | data_y.append([row.get('target')])
38 | break
39 |
40 | with open("./data/unittest/unittest_original_x.csv", 'w', newline='') as f:
41 | wr = csv.writer(f)
42 | wr.writerows(data_x)
43 |
44 | with open("./data/unittest/unittest_original_y.csv", 'w', newline='') as f:
45 | wr = csv.writer(f)
46 | wr.writerows(data_y)
47 |
48 | if __name__ == "__main__":
49 | db_to_data()
--------------------------------------------------------------------------------
/tests/original/original_descriptor.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import itertools
3 | import math
4 | from itertools import combinations, combinations_with_replacement
5 |
6 | import torch
7 | import torch.nn.functional as F
8 |
9 | # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
10 | device = 'cpu'
11 |
12 | ###obtains distance with pytorch tensors
13 | def get_distances(positions, pbc_offsets):
14 |
15 | num_atoms = len(positions)
16 | num_cells = len(pbc_offsets)
17 |
18 | pos1 = positions.view(-1, 1, 1, 3).expand(-1, num_atoms, 1, 3)
19 | pos1 = pos1.expand(-1, num_atoms, num_cells, 3)
20 | pos2 = positions.view(1, -1, 1, 3).expand(num_atoms, -1, 1, 3)
21 | pos2 = pos2.expand(num_atoms, -1, num_cells, 3)
22 |
23 | pbc_offsets = pbc_offsets.view(-1, num_cells, 3).expand(pos2.shape[0], num_cells, 3)
24 | pos2 = pos2 + pbc_offsets
25 | ###calculates distance between target atom and the periodic images of the other atom, then gets the minimum distance
26 | atom_distance_sqr = torch.linalg.norm(pos1 - pos2, dim=-1)
27 | atom_distance_sqr_min, min_indices = torch.min(atom_distance_sqr, dim=-1)
28 | #atom_distance_sqr_min = torch.amin(atom_distance_sqr, dim=-1)
29 |
30 | atom_rij = pos1 - pos2
31 | min_indices = min_indices[..., None, None].expand(-1, -1, 1, atom_rij.size(3))
32 | atom_rij = torch.gather(atom_rij, dim=2, index=min_indices).squeeze()
33 |
34 | return atom_distance_sqr_min, atom_rij
35 |
36 |
37 | ###Obtain unit cell offsets for distance calculation
38 | class PBC_offsets():
39 | def __init__(self,cell, supercell_max=4):
40 | #set up pbc offsets for minimum distance in pbc
41 | self.pbc_offsets=[]
42 |
43 | for offset_num in range(0, supercell_max):
44 | unit_cell=[]
45 | offset_range = np.arange(-offset_num, offset_num+1)
46 |
47 | for prod in itertools.product(offset_range, offset_range, offset_range):
48 | unit_cell.append(list(prod))
49 |
50 | unit_cell = torch.tensor(unit_cell, dtype=torch.float, device=device)
51 | self.pbc_offsets.append(torch.mm(unit_cell, cell.to(device)))
52 |
53 | def get_offset(self, offset_num):
54 | return self.pbc_offsets[offset_num]
55 | ###functional form
56 | def get_pbc_offsets(cell, offset_num):
57 |
58 | unit_cell=[]
59 | offset_range = np.arange(-offset_num, offset_num+1)
60 |
61 | ##put this out of loop
62 | for prod in itertools.product(offset_range, offset_range, offset_range):
63 | unit_cell.append(list(prod))
64 |
65 | unit_cell = torch.tensor(unit_cell, dtype=torch.float, device=device)
66 |
67 | return torch.mm(unit_cell, cell.to(device))
68 |
69 |
70 | ###ACSF evaluation function
71 | def get_ACSF_features(distances, rij, atomic_numbers, parameters):
72 |
73 | num_atoms = distances.shape[0]
74 |
75 | #offsets=get_pbc_offsets(cell, 2)
76 | #distances, rij = get_distances(positions, offsets)
77 | #print(distances.shape, rij.shape)
78 |
79 | type_set1 = create_type_set(atomic_numbers, 1)
80 | type_set2 = create_type_set(atomic_numbers, 2)
81 |
82 | atomic_numbers = torch.tensor(atomic_numbers, dtype=int)
83 | atomic_numbers = atomic_numbers.view(1, -1).expand(num_atoms, -1)
84 | atomic_numbers = atomic_numbers.reshape(num_atoms, -1).numpy()
85 |
86 | for i in range(0, num_atoms):
87 |
88 | mask = (distances[i, :] <= parameters['cutoff']) & (distances[i, :] != 0.0)
89 | Dij = distances[i, mask]
90 | Rij = rij[i, mask]
91 | IDs = atomic_numbers[i, mask.cpu().numpy()]
92 | jks = np.array(list(combinations(range(len(IDs)), 2)))
93 |
94 | G2i = calculate_G2_pt(Dij, IDs, type_set1, parameters['cutoff'], parameters['g2_params']['Rs'], parameters['g2_params']['eta'])
95 | if parameters['g5'] == True:
96 | G5i = calculate_G5_pt(Rij, IDs, jks, type_set2, parameters['cutoff'], parameters['g5_params']['Rs'], parameters['g5_params']['eta'], parameters['g5_params']['zeta'], parameters['g5_params']['lambda'])
97 | G_comb=torch.cat((G2i, G5i), dim=0).view(1,-1)
98 | else:
99 | G_comb=G2i.view(1,-1)
100 |
101 | if i == 0:
102 | all_G=G_comb
103 | else:
104 | all_G=torch.cat((all_G, G_comb), dim=0)
105 |
106 | if parameters['average']==True:
107 | all_G_max = torch.max(all_G, dim=0)[0]
108 | all_G_min = torch.min(all_G, dim=0)[0]
109 | all_G = torch.cat((all_G_max, all_G_min)).unsqueeze(0)
110 |
111 | return all_G
112 |
113 |
114 | ### ACSF sub-functions
115 | ###G2 symmetry function (radial)
116 | def calculate_G2_pt(Dij, IDs, type_set, Rc, Rs, etas):
117 |
118 | n1, n2, m, l = len(Rs), len(etas), len(Dij), len(type_set)
119 | d20 = (Dij - Rs.view(-1,1)) ** 2 # n1*m
120 | term = torch.exp(torch.einsum('i,jk->ijk', -etas, d20)) # n2*n1*m
121 | results = torch.einsum('ijk,k->ijk', term, cosine_cutoff_pt(Dij, Rc)) # n2*n1*m
122 | results = results.reshape([n1*n2, m])
123 |
124 | G2 = torch.zeros([n1*n2*l], device=device)
125 | for id, j_type in enumerate(type_set):
126 | ids = select_rows(IDs, j_type)
127 | G2[id::l] = torch.sum(results[:, ids], axis=1)
128 |
129 | return G2
130 |
131 |
132 | ###G5 symmetry function (angular)
133 | def calculate_G5_pt(Rij, IDs, jks, type_set, Rc, Rs, etas, zetas, lambdas):
134 |
135 | n1, n2, n3, n4, l = len(Rs), len(etas), len(lambdas), len(zetas), len(type_set)
136 | jk = len(jks) # m1
137 | if jk == 0: # For dimer
138 | return torch.zeros([n1*n2*n3*n4*l], device=device)
139 |
140 | rij = Rij[jks[:,0]] # [m1, 3]
141 | rik = Rij[jks[:,1]] # [m1, 3]
142 | R2ij0 = torch.sum(rij**2., axis=1)
143 | R2ik0 = torch.sum(rik**2., axis=1)
144 | R1ij0 = torch.sqrt(R2ij0) # m1
145 | R1ik0 = torch.sqrt(R2ik0) # m1
146 | R2ij = R2ij0 - Rs.view(-1,1)**2 # n1*m1
147 | R2ik = R2ik0 - Rs.view(-1,1)**2 # n1*m1
148 |
149 | R1ij = R1ij0 - Rs.view(-1,1) # n1*m1
150 | R1ik = R1ik0 - Rs.view(-1,1) # n1*m1
151 |
152 | powers = 2. ** (1.-zetas) #n4
153 | cos_ijk = torch.sum(rij*rik, axis=1)/R1ij0/R1ik0 # m1 array
154 | term1 = 1. + torch.einsum('i,j->ij', lambdas, cos_ijk) # n3*m1
155 |
156 | zetas1 = zetas.repeat_interleave(n3*jk).reshape([n4, n3, jk]) # n4*n3*m1
157 | term2 = torch.pow(term1, zetas1) # n4*n3*m1
158 | term3 = torch.exp(torch.einsum('i,jk->ijk', -etas, (R2ij+R2ik))) # n2*n1*m1
159 | term4 = cosine_cutoff_pt(R1ij0, Rc) * cosine_cutoff_pt(R1ik0, Rc) #* Cosine(R1jk0, Rc) # m1
160 | term5 = torch.einsum('ijk,lmk->ijlmk', term2, term3) #n4*n3*n2*n1*m1
161 | term6 = torch.einsum('ijkml,l->ijkml', term5, term4) #n4*n3*n2*n1*m1
162 | results = torch.einsum('i,ijkml->ijkml', powers, term6) #n4*n3*n2*n1*m1
163 | results = results.reshape([n1*n2*n3*n4, jk])
164 |
165 | G5 = torch.zeros([n1*n2*n3*n4*l], device=device)
166 | jk_ids = IDs[jks]
167 | for id, jk_type in enumerate(type_set):
168 | ids = select_rows(jk_ids, jk_type)
169 | G5[id::l] = torch.sum(results[:, ids], axis=1)
170 |
171 | return G5
172 |
173 |
174 | ###
175 | def create_type_set(number_set, order):
176 | types = list(set(number_set))
177 | return np.array(list(combinations_with_replacement(types, order)))
178 |
179 |
180 | ###Aggregation function
181 | #slow
182 | def select_rows(data, row_pattern):
183 | if len(row_pattern) == 1:
184 | ids = (data==row_pattern)
185 | elif len(row_pattern) == 2:
186 | a, b = row_pattern
187 | if a==b:
188 | ids = [id for id, d in enumerate(data) if d[0]==a and d[1]==a]
189 | else:
190 | ids = [id for id, d in enumerate(data) if (d[0] == a and d[1]==b) or (d[0] == b and d[1]==a)]
191 | return ids
192 |
193 |
194 | ###Cosine cutoff function
195 | def cosine_cutoff_pt(Rij, Rc):
196 | mask = (Rij > Rc)
197 | result = 0.5 * (torch.cos(torch.tensor(np.pi, device=device) * Rij / Rc) + 1.)
198 | result[mask] = 0
199 | return result
--------------------------------------------------------------------------------
/tests/test_behler.py:
--------------------------------------------------------------------------------
1 | import yaml, os, unittest, torch
2 | import numpy as np
3 | from structrepgen.extraction.representation_extraction import *
4 | from structrepgen.utils.dotdict import dotdict
5 | from original.db_to_data import db_to_data
6 |
7 | class TestBehler(unittest.TestCase):
8 |
9 | @classmethod
10 | def setUpClass(self):
11 | # original implementation
12 | db_to_data()
13 |
14 | # SRG
15 | stream = open('./configs/unittest/unittest_behler.yaml')
16 | CONFIG = yaml.safe_load(stream)
17 | stream.close()
18 | self.CONFIG = dotdict(CONFIG)
19 | extractor = RepresentationExtraction(self.CONFIG)
20 | extractor.extract()
21 |
22 | self.original_x = './data/unittest/unittest_original_x.csv'
23 | self.original_y = './data/unittest/unittest_original_y.csv'
24 | self.srg_x = self.CONFIG.x_fname
25 | self.srg_y = self.CONFIG.y_fname
26 |
27 | def test_file_existence(self):
28 | '''
29 | Test if output files are successfully generated
30 | '''
31 |
32 | for file in [self.original_x, self.original_y, self.srg_x, self.srg_y]:
33 | self.assertTrue(os.path.exists(file))
34 |
35 | def test_x(self):
36 | '''
37 | Test if output x files are the same
38 | '''
39 |
40 | original_x_tensor = torch.from_numpy(np.genfromtxt(self.original_x, delimiter=','))
41 | srg_x_tensor = torch.from_numpy(np.genfromtxt(self.srg_x, delimiter=','))
42 |
43 | same = torch.all(torch.eq(original_x_tensor, srg_x_tensor))
44 | self.assertTrue(same.item())
45 |
46 | def test_y(self):
47 | '''
48 | Test if output y files are the same
49 | '''
50 |
51 | original_y_tensor = torch.from_numpy(np.genfromtxt(self.original_y, delimiter=','))
52 | srg_y_tensor = torch.from_numpy(np.genfromtxt(self.srg_y, delimiter=','))
53 |
54 | same = torch.all(torch.eq(original_y_tensor, srg_y_tensor))
55 | self.assertTrue(same.item())
56 |
57 | @classmethod
58 | def tearDownClass(self):
59 | for file in [self.original_x, self.original_y, self.srg_x, self.srg_y]:
60 | os.remove(file)
61 |
62 | if __name__ == "__main__":
63 | unittest.main(verbosity=2)
--------------------------------------------------------------------------------