├── .gitignore ├── LICENSE ├── README.md ├── configs ├── example │ ├── example_cvae_trainer.yaml │ ├── example_extract_reconstruct.yaml │ ├── example_ff_trainer.yaml │ └── example_generator.yaml └── unittest │ └── unittest_behler.yaml ├── data ├── raw │ └── data.db ├── representation │ ├── data_x_fps_3.csv │ ├── data_y_fps_3.csv │ ├── x_test.csv │ └── y_test.csv ├── saved_models │ ├── model_FF_saved.pt │ ├── model_saved.pt │ └── scaler.gz └── unittest │ └── dummy.txt ├── examples ├── example_cvae_trainer.py ├── example_extract.py ├── example_generator.py ├── example_reconstruction.py └── example_surrogate_trainer.py ├── misc.py ├── outputs ├── generated_samples │ ├── gen_samps_cvae_1.9.csv │ ├── gen_samps_cvae_2.2.csv │ ├── gen_samps_cvae_2.5.csv │ ├── gen_samps_cvae_2.8.csv │ ├── gen_samps_cvae_3.1.csv │ ├── gen_samps_cvae_3.4.csv │ ├── gen_samps_cvae_3.7.csv │ ├── gen_samps_x_1.9.csv │ ├── gen_samps_x_2.2.csv │ ├── gen_samps_x_2.5.csv │ ├── gen_samps_x_2.8.csv │ ├── gen_samps_x_3.1.csv │ ├── gen_samps_x_3.4.csv │ └── gen_samps_x_3.7.csv └── reconstructed │ ├── 0_reconstructed.cif │ ├── 1_reconstructed.cif │ ├── 2_reconstructed.cif │ └── 3_reconstructed.cif ├── quickrun.sh ├── requirements.txt ├── setup.py ├── structrepgen ├── __init__.py ├── descriptors │ ├── __init__.py │ ├── behler.py │ └── generic.py ├── extraction │ ├── __init__.py │ └── representation_extraction.py ├── generators │ ├── __init__.py │ └── generator.py ├── models │ ├── CVAE.py │ ├── __init__.py │ └── models.py ├── reconstruction │ ├── __init__.py │ ├── generic.py │ └── reconstruction.py └── utils │ ├── __init__.py │ ├── dotdict.py │ └── utils.py └── tests ├── original ├── db_to_data.py └── original_descriptor.py └── test_behler.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # mac 132 | .DS_Store -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Fung Group 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Structure Representation Generation 2 | 3 |
4 | 5 | **Atomic Structure Generation from Reconstructing Structural Fingerprints** 6 | 7 | [[arXiv](https://arxiv.org/abs/2207.13227)] 8 | 9 |
10 | 11 | ## Requirements 12 | 13 | Clone this repo and, from the main directory, run 14 | 15 | ```bash 16 | pip install -r requirements.txt 17 | pip install -e . 18 | ``` 19 | 20 | The package has been tested on 21 | 22 | - Python 3.6.10 / 3.9.12 23 | - PyTorch 1.10.2 / 1.12.0 24 | - CUDA 11.3 25 | 26 | ## How to Use 27 | 28 | Most configuration of the system/code is done through `.yaml` files under `configs/`. 29 | 30 | To load a configuration: 31 | 32 | ```python 33 | # load our config yaml file 34 | stream = open('./configs/example/example_extract_reconstruct.yaml') 35 | CONFIG = yaml.safe_load(stream) 36 | stream.close() 37 | # dotdict for dot operations on Python dict 38 | # e.g., CONFIG.cutoff == CONFIG['cutoff'] 39 | CONFIG = dotdict(CONFIG) 40 | ``` 41 | 42 | ### Representation Extraction 43 | 44 | Extract representations for training the generative model using the selected descriptor. 45 | 46 | See [`examples/example_extract.py`](https://github.com/Fung-Lab/StructRepGen/blob/main/examples/example_extract.py). 47 | 48 | ### Generative Model 49 | 50 | Train a CVAE (conditional variational auto-encoder) as the generative model. 51 | 52 | See [`examples/example_cvae_trainer.py`](https://github.com/Fung-Lab/StructRepGen/blob/main/examples/example_cvae_trainer.py). 53 | 54 | ### Generation 55 | 56 | Generate representations for a given target value using the decoder of the trained CVAE. 57 | 58 | See [`examples/example_generator.py`](https://github.com/Fung-Lab/StructRepGen/blob/main/examples/example_generator.py). 59 | 60 | ### Reconstruction 61 | 62 | Generate atomic structures from the generated representations. 63 | 64 | See [`examples/example_reconstruction.py`](https://github.com/Fung-Lab/StructRepGen/blob/main/examples/example_reconstruction.py). 65 | 66 | A script that runs all of the above examples in sequence is provided in `quickrun.sh`; a minimal Python sketch of the same workflow is shown below.
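For orientation, the sketch below strings the same steps together from a single Python session. It only uses classes that already exist in the package and example scripts (`RepresentationExtraction`, `Generator`, `Reconstruction`); the `load_config` helper is not part of the package and is shown purely for illustration, and the CVAE/surrogate training steps are left to the example trainer scripts.

```python
import yaml

from structrepgen.utils.dotdict import dotdict
from structrepgen.extraction.representation_extraction import RepresentationExtraction
from structrepgen.generators.generator import Generator
from structrepgen.reconstruction.reconstruction import Reconstruction

def load_config(path):
    # illustrative helper (not in the package): read a yaml config into a dot-accessible dict
    with open(path) as stream:
        return dotdict(yaml.safe_load(stream))

# 1. extract representations R from the raw ASE database (writes the x/y csv files)
extract_cfg = load_config('./configs/example/example_extract_reconstruct.yaml')
RepresentationExtraction(extract_cfg).extract()

# 2. train the CVAE and the feed-forward surrogate; their Trainer classes live in the
#    example scripts, so the simplest route is to run those scripts directly:
#      python examples/example_cvae_trainer.py
#      python examples/example_surrogate_trainer.py

# 3. sample new representations for the target property values with the CVAE decoder
gen = Generator(load_config('./configs/example/example_generator.yaml'))
gen.generate()
gen.range_check()

# 4. reconstruct atomic structures (.cif files) from the generated representations
Reconstruction(extract_cfg).main()
```

The bundled `quickrun.sh` chains these same steps (plus the Behler descriptor unit test).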
To execute, do 67 | 68 | ```bash 69 | chmod u+x quickrun.sh 70 | ./quickrun.sh 71 | ``` 72 | ______________________________________________________________________ 73 | 74 | ## Citation 75 | 76 | ```bash 77 | @article{fung2022atomic, 78 | title={Atomic structure generation from reconstructing structural fingerprints}, 79 | author={Fung, Victor and Jia, Shuyi and Zhang, Jiaxin and Bi, Sirui and Yin, Junqi and Ganesh, P}, 80 | journal={arXiv preprint arXiv:2207.13227}, 81 | year={2022} 82 | } 83 | ``` -------------------------------------------------------------------------------- /configs/example/example_cvae_trainer.yaml: -------------------------------------------------------------------------------- 1 | gpu: False 2 | params: 3 | seed: 42 4 | split_ratio: 0.2 5 | input_dim: 380 6 | hidden_dim: 128 7 | latent_dim: 10 8 | hidden_layers: 3 9 | y_dim: 1 10 | batch_size: 64 11 | n_epochs: 1000 12 | lr: 0.001 13 | final_decay: 0.2 14 | weight_decay: 0.001 15 | verbosity: 50 16 | kl_weight: 2.5 17 | mc_kl_loss: False 18 | data_x_path: "./data/representation/data_x_fps_3.csv" 19 | data_y_path: "./data/representation/data_y_fps_3.csv" 20 | model_path: "./data/saved_models/model_saved.pt" 21 | scaler_path: "./data/saved_models/scaler.gz" -------------------------------------------------------------------------------- /configs/example/example_extract_reconstruct.yaml: -------------------------------------------------------------------------------- 1 | ######################################################################## 2 | # Parameters and configurations for both extraction and reconstruction # 3 | ######################################################################## 4 | 5 | # PyTorch device 6 | gpu: False 7 | 8 | # chemical system 9 | cell: [[20.,0.,0.],[0.,20.,0.],[0.,0.,20.]] 10 | atoms: [78, 78, 78, 78, 78, 78, 78, 78, 78, 78] 11 | 12 | # descriptor configuration 13 | descriptor: "behler" 14 | cutoff: 20.0 15 | mode: "features" 16 | g5: True 17 | average: True 18 | g2_params: 19 | eta: [0.01, 0.06, 0.1, 0.2, 0.4, 0.7, 1.0, 2.0, 3.5, 5.0] 20 | Rs: [0, 1, 1.5, 2, 2.5, 3, 3.5, 4, 5, 6, 7, 8, 9, 10] 21 | g5_params: 22 | lambdas: [-1, 1] 23 | zeta: [1, 2, 4, 16, 64] 24 | eta: [0.06, 0.1, 0.2, 0.4, 1.0] 25 | Rs: [0] 26 | 27 | # raw data file path 28 | data: "./data/raw/data.db" 29 | 30 | # extracted representation save file path 31 | x_fname: "./data/representation/x_test.csv" 32 | y_fname: "./data/representation/y_test.csv" 33 | 34 | # ff_net surrogate model path 35 | ff_model_path: "./data/saved_models/model_FF_saved.pt" 36 | 37 | # generated sample file path 38 | structure_file_path: "./outputs/generated_samples/gen_samps_x_2.2.csv" 39 | 40 | # reconstruction save file path 41 | reconstructed_file_path: "./outputs/reconstructed/" -------------------------------------------------------------------------------- /configs/example/example_ff_trainer.yaml: -------------------------------------------------------------------------------- 1 | gpu: False 2 | params: 3 | seed: 42 4 | split_ratio: 0.2 5 | input_dim: 380 6 | hidden_dim: 128 7 | latent_dim: 10 8 | hidden_layers: 3 9 | y_dim: 1 10 | batch_size: 64 11 | n_epochs: 1000 12 | lr: 0.001 13 | final_decay: 0.2 14 | weight_decay: 0.001 15 | verbosity: 50 16 | kl_weight: 2.5 17 | mc_kl_loss: False 18 | data_x_path: "./data/representation/data_x_fps_3.csv" 19 | data_y_path: "./data/representation/data_y_fps_3.csv" 20 | model_path: "./data/saved_models/model_saved.pt" 21 | scaler_path: "./data/saved_models/scaler.gz" 22 | ff_path: 
"./data/saved_models/model_FF_saved.pt" -------------------------------------------------------------------------------- /configs/example/example_generator.yaml: -------------------------------------------------------------------------------- 1 | gpu: False 2 | params: 3 | input_dim: 380 4 | latent_dim: 10 5 | num_z: 1000 6 | targets: [1.9, 2.2, 2.5, 2.8, 3.1, 3.4, 3.7] 7 | delta: 0.1 8 | model_path: "./data/saved_models/model_saved.pt" 9 | scaler_path: "./data/saved_models/scaler.gz" 10 | ff_model_path: "./data/saved_models/model_FF_saved.pt" 11 | save_path: "./outputs/generated_samples/" -------------------------------------------------------------------------------- /configs/unittest/unittest_behler.yaml: -------------------------------------------------------------------------------- 1 | descriptor: "behler" 2 | gpu: False 3 | cutoff: 20.0 4 | mode: "features" 5 | g5: True 6 | average: True 7 | g2_params: 8 | eta: [0.01, 0.06, 0.1, 0.2, 0.4, 0.7, 1.0, 2.0, 3.5, 5.0] 9 | Rs: [0, 1, 1.5, 2, 2.5, 3, 3.5, 4, 5, 6, 7, 8, 9, 10] 10 | g5_params: 11 | lambdas: [-1, 1] 12 | zeta: [1, 2, 4, 16, 64] 13 | eta: [0.06, 0.1, 0.2, 0.4, 1.0] 14 | Rs: [0] 15 | data: "./data/raw/data.db" 16 | x_fname: "./data/unittest/unittest_behler_x.csv" 17 | y_fname: "./data/unittest/unittest_behler_y.csv" -------------------------------------------------------------------------------- /data/raw/data.db: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Fung-Lab/StructRepGen/c000acddbb450e2a4c0a81d09a6dfb07fc675a7d/data/raw/data.db -------------------------------------------------------------------------------- /data/representation/data_y_fps_3.csv: -------------------------------------------------------------------------------- 1 | 3.629251785500000160e+00 2 | 2.193435603500000219e+00 3 | 2.389190833499999833e+00 4 | 2.590185373500000221e+00 5 | 3.469791547500000295e+00 6 | 3.639065514499999932e+00 7 | 3.031497834500000099e+00 8 | 3.278771844500000032e+00 9 | 3.400559700500000115e+00 10 | 2.569046109499999453e+00 11 | 2.572176612499999404e+00 12 | 1.977554506499999531e+00 13 | 2.726371039500000037e+00 14 | 3.378239648500000136e+00 15 | 1.969304244499999523e+00 16 | 2.482163833500000027e+00 17 | 3.469053469499999487e+00 18 | 2.161180478499999502e+00 19 | 2.427411639499999829e+00 20 | 3.511905419500000125e+00 21 | 2.346946486499999818e+00 22 | 2.863972432500000220e+00 23 | 2.195669181500000011e+00 24 | 3.263429192499999854e+00 25 | 3.634469329499999901e+00 26 | 2.869667623499999820e+00 27 | 3.505030466500000053e+00 28 | 3.201020360500000272e+00 29 | 3.223487879499999931e+00 30 | 2.754639036499999971e+00 31 | 3.446267526499999789e+00 32 | 3.330566205499999821e+00 33 | 3.447257610500000347e+00 34 | 2.853754735499999917e+00 35 | 2.458017109500000075e+00 36 | 3.407018959500000221e+00 37 | 2.034932231500000022e+00 38 | 2.938345457500000091e+00 39 | 3.432077122499999966e+00 40 | 3.260302140499999890e+00 41 | 2.664259684500000169e+00 42 | 2.845696269500000319e+00 43 | 3.068155024500000216e+00 44 | 3.001935866500000216e+00 45 | 1.939045304500000011e+00 46 | 2.059636720499999907e+00 47 | 2.773203739499999987e+00 48 | 2.077227622500000148e+00 49 | 2.370573837499999392e+00 50 | 2.916555953500000076e+00 51 | 3.327417127499999960e+00 52 | 2.841348202499999864e+00 53 | 3.011263771499999908e+00 54 | 2.900439509499999957e+00 55 | 3.488860160499999807e+00 56 | 1.978455072500000078e+00 57 | 2.740035258499999848e+00 58 | 2.553865472500000067e+00 59 | 
3.658810885499999888e+00 60 | 2.449129288500000001e+00 61 | 3.219105184500000050e+00 62 | 2.891397128499999969e+00 63 | 3.598062906500000047e+00 64 | 2.725029893499999467e+00 65 | 1.908753786499999938e+00 66 | 2.170524436500000043e+00 67 | 3.667390165500000077e+00 68 | 3.153143658499999891e+00 69 | 3.381735946500000090e+00 70 | 1.981137564500000003e+00 71 | 2.292985748500000032e+00 72 | 2.857798630500000048e+00 73 | 3.370462267500000220e+00 74 | 3.482201046500000174e+00 75 | 3.188101787499999951e+00 76 | 3.583713149499999417e+00 77 | 3.568160926500000052e+00 78 | 3.579448063500000110e+00 79 | 2.934886240500000021e+00 80 | 3.577668529499999472e+00 81 | 3.521518180500000206e+00 82 | 2.892856812499999819e+00 83 | 3.379746976500000333e+00 84 | 3.548381890499999969e+00 85 | 2.467273743499999394e+00 86 | 3.540359935500000166e+00 87 | 2.768394223500000084e+00 88 | 3.678645528500000150e+00 89 | 2.943328566499999965e+00 90 | 3.560178632499999996e+00 91 | 3.078999631499999889e+00 92 | 3.315472485499999955e+00 93 | 2.514673767500000157e+00 94 | 2.053671849500000146e+00 95 | 3.397123429500000125e+00 96 | 1.927875391499999536e+00 97 | 3.254700554499999843e+00 98 | 2.855225439499999851e+00 99 | 3.289567590499999916e+00 100 | 3.080570113499999874e+00 101 | 3.149139033500000018e+00 102 | 3.095625262500000030e+00 103 | 3.606596311500000152e+00 104 | 3.309686274500000192e+00 105 | 3.532479607500000007e+00 106 | 3.335945054499999785e+00 107 | 2.680776439500000219e+00 108 | 3.413721247499999834e+00 109 | 3.230590477500000279e+00 110 | 3.081152116500000204e+00 111 | 2.475544047500000122e+00 112 | 3.687523089499999962e+00 113 | 3.425417975499999379e+00 114 | 3.668337117500000133e+00 115 | 2.644321825500000056e+00 116 | 2.321161881500000135e+00 117 | 2.670165391500000318e+00 118 | 2.509617504499999541e+00 119 | 3.429835049499999400e+00 120 | 2.686018230499999770e+00 121 | 2.274859405499999987e+00 122 | 2.875813564500000030e+00 123 | 2.345672777499999917e+00 124 | 3.149382091500000147e+00 125 | 3.563936150500000011e+00 126 | 3.606921896499999836e+00 127 | 2.239281146500000208e+00 128 | 2.885017470499999792e+00 129 | 2.843746938499999821e+00 130 | 2.151624935500000113e+00 131 | 3.695640772500000004e+00 132 | 2.604716774499999499e+00 133 | 3.533703536500000020e+00 134 | 3.618628567499999171e+00 135 | 2.269472805499999968e+00 136 | 3.094148476500000022e+00 137 | 3.344097396500000041e+00 138 | 3.370727902500000095e+00 139 | 3.695081813499999868e+00 140 | 2.247777477500000121e+00 141 | 2.612466452500000091e+00 142 | 2.596694727500000077e+00 143 | 2.866961518500000139e+00 144 | 3.140304035499999813e+00 145 | 2.724920338500000039e+00 146 | 3.096131675500000124e+00 147 | 2.899930516500000000e+00 148 | 3.690274181499999973e+00 149 | 3.036688120500000032e+00 150 | 3.590003534500000093e+00 151 | 1.916972713499999648e+00 152 | 3.694017793500000035e+00 153 | 3.278344651499999429e+00 154 | 3.311863587499999539e+00 155 | 3.440875531499999695e+00 156 | 2.298078631500000135e+00 157 | 2.154455966500000041e+00 158 | 3.668343183499999327e+00 159 | 3.189560615500000029e+00 160 | 2.954388168499999967e+00 161 | 3.158917279499999786e+00 162 | 3.049848474500000073e+00 163 | 3.121720442500000026e+00 164 | 3.559216489499999803e+00 165 | 2.394893494500000219e+00 166 | 2.870004901500000205e+00 167 | 3.222248899499999819e+00 168 | 3.021361172499999803e+00 169 | 2.997914102499999789e+00 170 | 2.152280122500000115e+00 171 | 3.691511492500000102e+00 172 | 3.534808753500000122e+00 173 | 2.441883251500000185e+00 174 | 3.158438142500000101e+00 175 
| 2.714237651499999515e+00 176 | 2.361495079500000038e+00 177 | 3.492186637499999247e+00 178 | 2.444909780500000185e+00 179 | 2.245195300499999824e+00 180 | 3.564829119500000143e+00 181 | 3.422282724500000040e+00 182 | 2.413833691499999823e+00 183 | 3.553694700500000359e+00 184 | 3.133653180500000079e+00 185 | 2.688062402499999504e+00 186 | 3.029967856499999890e+00 187 | 3.699558817499999819e+00 188 | 3.645472044499999953e+00 189 | 2.924037525500000179e+00 190 | 2.707218354500000146e+00 191 | 3.467364557499999389e+00 192 | 3.181638873499999853e+00 193 | 1.991422050499999541e+00 194 | 3.667853380499999982e+00 195 | 3.095820271500000054e+00 196 | 3.568928516500000203e+00 197 | 3.188936304499999874e+00 198 | 3.092066244499999783e+00 199 | 3.420078662499999922e+00 200 | 2.782251731500000158e+00 201 | 3.340617010499999928e+00 202 | 3.609050915500000123e+00 203 | 3.643656063499999931e+00 204 | 3.336802336499999910e+00 205 | 3.369463562500000009e+00 206 | 3.638709875500000024e+00 207 | 2.737987797500000209e+00 208 | 3.662035683500000083e+00 209 | 2.442745851499999787e+00 210 | 2.794192333499999847e+00 211 | 2.130696768500000005e+00 212 | 3.035902088499999874e+00 213 | 2.206589475499999509e+00 214 | 3.474328342500000222e+00 215 | 2.627364993500000079e+00 216 | 3.423192286499999959e+00 217 | 3.219988001499999974e+00 218 | 3.267569387499999145e+00 219 | 2.679220242500000015e+00 220 | 3.675037752500000199e+00 221 | 3.459667313499999786e+00 222 | 2.230246310500000106e+00 223 | 2.850722157499999465e+00 224 | 2.407494965500000195e+00 225 | 2.920090459499999902e+00 226 | 2.943086850499999851e+00 227 | 3.323967807499999871e+00 228 | 3.363145897499999926e+00 229 | 2.684926472499999939e+00 230 | 2.657415704499999975e+00 231 | 3.622536860499999900e+00 232 | 3.540368124500000047e+00 233 | 3.589065706500000008e+00 234 | 2.596920302500000055e+00 235 | 2.730771360499999467e+00 236 | 2.734785122499999943e+00 237 | 3.377389813499999782e+00 238 | 3.045484734500000013e+00 239 | 2.671878994500000104e+00 240 | 2.050986682500000047e+00 241 | 2.817160023500000054e+00 242 | 2.723591580499999942e+00 243 | 3.557056275499999920e+00 244 | 2.685716286499999939e+00 245 | 3.568842027500000125e+00 246 | 2.932958204500000221e+00 247 | 3.069034059499999856e+00 248 | 3.514755606500000518e+00 249 | 3.262418033499999925e+00 250 | 3.652738589499999300e+00 251 | 3.119569415499999998e+00 252 | 3.452809048499999811e+00 253 | 2.740936844500000191e+00 254 | 2.451905307500000131e+00 255 | 3.313156561500000041e+00 256 | 3.511915449500000008e+00 257 | 2.700493722500000082e+00 258 | 3.348573206499999788e+00 259 | 3.468992130500000215e+00 260 | 3.476668341500000370e+00 261 | 3.567517962500000195e+00 262 | 3.528156868499999543e+00 263 | 3.183894331499999897e+00 264 | 3.392475934500000179e+00 265 | 3.289337511499999867e+00 266 | 2.498368390500000036e+00 267 | 2.767821931499999888e+00 268 | 3.603900129499999938e+00 269 | 3.590931797499999245e+00 270 | 3.045904796499999900e+00 271 | 3.689263541499999910e+00 272 | 3.146069863500000174e+00 273 | 3.100778360500000108e+00 274 | 3.672868856500000057e+00 275 | 3.624809072500000173e+00 276 | 3.689044692500000000e+00 277 | 2.818979973500000291e+00 278 | 3.092128980500000068e+00 279 | 3.274181661500000118e+00 280 | 2.443876192500000322e+00 281 | 3.035938488499999810e+00 282 | 2.764125669499999383e+00 283 | 3.678849937499999889e+00 284 | 2.334595027499999809e+00 285 | 3.328708817499999473e+00 286 | 2.041349412499999794e+00 287 | 2.593014625499999504e+00 288 | 3.466235386500000182e+00 289 | 
3.377096930500000038e+00 290 | 3.163263784499999787e+00 291 | 3.148148751499999953e+00 292 | 2.798697951500000336e+00 293 | 2.881607269499999902e+00 294 | 2.075533408500000121e+00 295 | 3.648939681499999921e+00 296 | 3.516611950500000194e+00 297 | 3.361316212499999789e+00 298 | 3.014654499500000195e+00 299 | 3.411344701500000021e+00 300 | 2.914182840500000093e+00 301 | -------------------------------------------------------------------------------- /data/representation/x_test.csv: -------------------------------------------------------------------------------- 1 | 7.6347027,8.037969,8.186526,8.296676,8.366806,8.395889,8.383496,8.329815,8.102355,7.7313023,7.3277287,6.86488,6.306292,5.726998,4.776494,6.401983,7.101475,7.6568522,8.025179,8.177042,8.100437,7.809145,7.3074493,6.369568,5.276666,4.031698,2.7925496,1.7492355,3.390356,5.433555,6.4142046,7.230476,7.786035,8.012206,7.8819127,7.482154,6.9597116,5.6499586,4.4195595,2.9891787,1.7316858,0.8515215,1.5577528,3.7955854,5.158061,6.4083543,7.2968764,7.635884,7.3648577,6.7645187,6.361255,4.677028,3.2875094,1.7794862,0.7073132,0.19914488,0.38233596,2.0855284,3.6734178,5.387841,6.648571,7.0102725,6.4451065,5.6690145,5.7577953,3.6979256,2.4004283,0.92405725,0.17827508,0.01618633,0.052093655,0.9478716,2.4330766,4.526375,6.13364,6.301463,5.3087854,4.5954037,5.3286743,2.934912,1.9198909,0.44170582,0.027728438,0.00045367455,0.0073398,0.44834715,1.6864524,3.9551892,5.8717976,5.7726793,4.440757,3.9239402,5.032765,2.4922295,1.6833789,0.22568327,0.0046401196,1.39647345e-05,1.366814e-05,0.0401704,0.53390944,2.7223392,5.5251403,4.636115,2.8632991,2.8454342,4.2715855,1.7808043,1.2348522,0.028469397,1.5511014e-05,1.7215465e-10,1.7988977e-09,0.0014022979,0.10472332,1.6557811,5.301574,3.6102402,1.9564735,2.1535573,3.6110535,1.2962196,0.85115135,0.0016333024,3.967595e-09,9.399013e-18,2.55697e-13,6.2094536e-05,0.024239682,1.0463419,5.1334095,2.9081228,1.6165028,1.7476071,3.4143996,1.0256757,0.6231437,0.00010555335,1.0957327e-12,5.3302025e-25,9.048065,4.6781635,1.0532091,0.06843075,2.6120613e-05,11.02688,5.3850875,1.065321,0.05888957,1.845249e-05,5.834724,3.0927684,0.7209282,0.048420068,1.9752504e-05,7.423365,3.5952344,0.6997687,0.038108453,1.2066323e-05,3.4506154,1.8724595,0.4534073,0.031738166,1.4180921e-05,4.2822556,1.9768807,0.370257,0.019514933,6.053242e-06,1.1266446,0.65117013,0.17467695,0.01370768,7.650941e-06,1.4605906,0.43000588,0.028385328,0.0008188701,1.8175793e-07,0.5204019,0.31753644,0.09240823,0.007836901,4.8e-06,0.48970345,0.14299357,0.00666044,1.904607e-05,2.0390837e-12,6.160422,6.6661687,6.884791,7.0766244,7.239064,7.369865,7.467193,7.529668,7.5469923,7.4208264,7.158287,6.622072,5.9512777,5.2437177,2.516549,3.700965,4.3307323,4.9500437,5.5276737,6.0315166,6.4313974,6.7019396,6.6555557,5.0780435,3.4668381,2.1184268,1.1587516,0.5673503,1.5057518,2.6438906,3.3143363,4.0107822,4.690665,5.3066983,5.8117385,6.1640244,5.7459326,3.73569,2.0396943,0.93513787,0.3595433,0.11567887,0.60331887,1.5165292,2.144766,2.8380828,3.5416698,4.1985183,4.7550955,5.1627855,4.151659,1.9443294,0.6678895,0.16489673,0.028637297,0.0034390688,0.15375443,0.7934002,1.3708824,2.0203269,2.6254656,3.144864,3.607701,4.0079374,2.582511,0.77285445,0.12180669,0.009218421,0.00032254448,5.1418156e-06,0.022954568,0.38450772,0.9380872,1.6432308,2.1547751,2.3755608,2.654021,3.1652706,1.7264359,0.2890365,0.013413128,0.00015911311,4.721615e-07,3.477529e-10,0.003478007,0.19299978,0.6826151,1.4758364,1.9996482,1.9062647,2.0228968,2.6757653,1.3618618,0.12095639,0.0015862945,2.8853976e-06,7.16319e
-10,2.4157986e-14,6.1845753e-06,0.01988394,0.2447432,1.1207368,1.9127657,1.1005502,0.7205751,1.2920241,0.7925133,0.007606483,1.4006755e-06,4.7752136e-12,3.0039723e-19,3.4845793e-28,3.6193123e-10,0.00061199337,0.053782914,0.7538048,1.9004161,0.6558406,0.17283502,0.5879705,0.40321013,0.00012828893,3.83648e-11,1.06804914e-20,2.7603e-33,0.0,2.1907079e-14,1.669212e-05,0.011074061,0.5105281,1.8898233,0.41115925,0.043087196,0.3161983,0.21218473,2.19662e-06,1.0794713e-15,2.49584e-29,0.0,0.0,0.7785383,0.28510672,0.044930883,0.0028594765,1.3748747e-06,4.4133234,1.4070488,0.17132547,0.009021766,4.237469e-06,0.18193093,0.06960676,0.011042307,0.0006995932,3.3681093e-07,3.8167162,1.1915488,0.13743687,0.0060121035,3.199405e-06,0.019804496,0.007352831,0.00088740233,4.3055396e-05,2.0213296e-08,2.9670405,0.89224577,0.09169072,0.0024531926,1.1060774e-06,2.2397329e-07,3.9957833e-08,1.0153539e-09,2.955518e-11,9.588455e-16,0.61931694,0.2371586,0.018524287,0.00013715237,2.043432e-09,1.7251224e-24,2.8203624e-25,3.0480218e-27,3.5601168e-31,1.64047e-40,0.012272776,0.004336454,0.00032502704,1.4801496e-06,2.8672374e-13 2 | -------------------------------------------------------------------------------- /data/representation/y_test.csv: -------------------------------------------------------------------------------- 1 | 1.9583510175 2 | -------------------------------------------------------------------------------- /data/saved_models/model_FF_saved.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Fung-Lab/StructRepGen/c000acddbb450e2a4c0a81d09a6dfb07fc675a7d/data/saved_models/model_FF_saved.pt -------------------------------------------------------------------------------- /data/saved_models/model_saved.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Fung-Lab/StructRepGen/c000acddbb450e2a4c0a81d09a6dfb07fc675a7d/data/saved_models/model_saved.pt -------------------------------------------------------------------------------- /data/saved_models/scaler.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Fung-Lab/StructRepGen/c000acddbb450e2a4c0a81d09a6dfb07fc675a7d/data/saved_models/scaler.gz -------------------------------------------------------------------------------- /data/unittest/dummy.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Fung-Lab/StructRepGen/c000acddbb450e2a4c0a81d09a6dfb07fc675a7d/data/unittest/dummy.txt -------------------------------------------------------------------------------- /examples/example_cvae_trainer.py: -------------------------------------------------------------------------------- 1 | import torch, yaml 2 | import torch.nn as nn 3 | import torch.optim as optim 4 | import torch.nn.functional as F 5 | from torch.utils.data import DataLoader, TensorDataset 6 | from sklearn.preprocessing import MinMaxScaler 7 | import joblib, time, copy 8 | import pandas as pd 9 | import numpy as np 10 | from sklearn.model_selection import train_test_split 11 | 12 | from structrepgen.models.CVAE import * 13 | from structrepgen.utils.dotdict import dotdict 14 | from structrepgen.utils.utils import torch_device_select 15 | 16 | ''' 17 | Example of training a CVAE based on the representation R extracted using Behler descriptors 18 | ''' 19 | 20 | class Trainer(): 21 | def __init__(self, CONFIG) -> None: 22 | self.CONFIG = CONFIG 23 
| 24 | # check GPU availability & set device 25 | self.device = torch_device_select(self.CONFIG.gpu) 26 | 27 | # initialize 28 | self.create_data() 29 | self.initialize() 30 | 31 | def create_data(self): 32 | p = self.CONFIG.params 33 | 34 | data_x = pd.read_csv(self.CONFIG.data_x_path, header=None).values 35 | data_y = pd.read_csv(self.CONFIG.data_y_path, header=None).values 36 | 37 | # scale 38 | scaler = MinMaxScaler() 39 | data_x = scaler.fit_transform(data_x) 40 | joblib.dump(scaler, self.CONFIG.scaler_path) 41 | 42 | # train/test split and create torch dataloader 43 | xtrain, xtest, ytrain, ytest = train_test_split(data_x, data_y, test_size=self.CONFIG.split_ratio, random_state=p.seed) 44 | self.x_train = torch.tensor(xtrain, dtype=torch.float) 45 | self.y_train = torch.tensor(ytrain, dtype=torch.float) 46 | self.x_test = torch.tensor(xtest, dtype=torch.float) 47 | self.y_test = torch.tensor(ytest, dtype=torch.float) 48 | 49 | self.train_loader = DataLoader( 50 | TensorDataset(self.x_train, self.y_train), 51 | batch_size=p.batch_size, shuffle=True, drop_last=False 52 | ) 53 | 54 | self.test_loader = DataLoader( 55 | TensorDataset(self.x_test, self.y_test), 56 | batch_size=p.batch_size, shuffle=False, drop_last=False 57 | ) 58 | 59 | def initialize(self): 60 | p = self.CONFIG.params 61 | 62 | # create model 63 | self.model = CVAE(p.input_dim, p.hidden_dim, p.latent_dim, p.hidden_layers, p.y_dim) 64 | self.model.to(self.device) 65 | print(self.model) 66 | 67 | # set up optimizer 68 | gamma = (p.final_decay)**(1./p.n_epochs) 69 | self.optimizer = optim.Adam(self.model.parameters(), lr=p.lr, weight_decay=p.weight_decay) 70 | self.scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=1, gamma=gamma) 71 | 72 | def train(self): 73 | p = self.CONFIG.params 74 | self.model.train() 75 | 76 | # loss of the epoch 77 | rcl_loss = 0. 78 | kld_loss = 0. 79 | 80 | for i, (x, y) in enumerate(self.train_loader): 81 | x = x.to(self.device) 82 | y = y.to(self.device) 83 | 84 | self.optimizer.zero_grad() 85 | 86 | # forward 87 | reconstructed_x, z_mu, z_var = self.model(x, y) 88 | 89 | rcl, kld = calculate_loss(x, reconstructed_x, z_mu, z_var, p.kl_weight, p.mc_kl_loss) 90 | 91 | # backward 92 | combined_loss = rcl + kld 93 | combined_loss.backward() 94 | rcl_loss += rcl.item() 95 | kld_loss += kld.item() 96 | 97 | # update the weights 98 | self.optimizer.step() 99 | 100 | return rcl_loss, kld_loss 101 | 102 | def test(self): 103 | p = self.CONFIG.params 104 | 105 | self.model.eval() 106 | 107 | # loss of the evaluation 108 | rcl_loss = 0. 109 | kld_loss = 0.
110 | 111 | with torch.no_grad(): 112 | for i, (x, y) in enumerate(self.test_loader): 113 | x = x.to(self.device) 114 | y = y.to(self.device) 115 | 116 | # forward pass 117 | reconstructed_x, z_mu, z_var = self.model(x, y) 118 | 119 | # loss 120 | rcl, kld = calculate_loss(x, reconstructed_x, z_mu, z_var, p.kl_weight, p.mc_kl_loss) 121 | rcl_loss += rcl.item() 122 | kld_loss += kld.item() 123 | 124 | return rcl_loss, kld_loss 125 | 126 | def run(self): 127 | p = self.CONFIG.params 128 | best_test_loss = float('inf') 129 | best_train_loss = float('inf') 130 | best_epoch = 0 131 | 132 | for e in range(p.n_epochs): 133 | tic = time.time() 134 | 135 | rcl_train_loss, kld_train_loss = self.train() 136 | rcl_test_loss, kld_test_loss = self.test() 137 | 138 | rcl_train_loss /= len(self.x_train) 139 | kld_train_loss /= len(self.x_train) 140 | train_loss = rcl_train_loss + kld_train_loss 141 | rcl_test_loss /= len(self.x_test) 142 | kld_test_loss /= len(self.x_test) 143 | test_loss = rcl_test_loss + kld_test_loss 144 | 145 | self.scheduler.step() 146 | lr = self.scheduler.optimizer.param_groups[0]["lr"] 147 | 148 | if best_test_loss > test_loss: 149 | best_epoch = e 150 | best_test_loss = test_loss 151 | best_train_loss = train_loss 152 | model_best = copy.deepcopy(self.model) 153 | 154 | elapsed_time = time.time() - tic 155 | epoch_out = f'Epoch {e:04d}, Train RCL: {rcl_train_loss:.3f}, Train KLD: {kld_train_loss:.3f}, Train: {train_loss:.3f}, Test RCL: {rcl_test_loss:.3f}, Test KLD: {kld_test_loss:.3f}, Test: {test_loss:.3f}, LR: {lr:.5f}, Time/Epoch (s): {elapsed_time:.3f}' 156 | if e % p.verbosity == 0: 157 | print(epoch_out) 158 | 159 | torch.save(model_best, self.CONFIG.model_path) 160 | return best_epoch, best_train_loss, best_test_loss 161 | 162 | if __name__ == "__main__": 163 | # load parameters from yaml file 164 | stream = open('./configs/example/example_cvae_trainer.yaml') 165 | CONFIG = yaml.safe_load(stream) 166 | stream.close() 167 | CONFIG = dotdict(CONFIG) 168 | 169 | trainer = Trainer(CONFIG) 170 | trainer.run() -------------------------------------------------------------------------------- /examples/example_extract.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Use Behler ACSF descriptor to extract features.
3 | 4 | To customize parameters, make a yaml file in configs/ 5 | In this example, we will use configs/example_extract_reconstruct.yaml 6 | 7 | Output csv files are saved under data/representation/ 8 | ''' 9 | 10 | import yaml 11 | from structrepgen.extraction.representation_extraction import * 12 | from structrepgen.utils.dotdict import dotdict 13 | 14 | # load our config yaml file 15 | stream = open('./configs/example/example_extract_reconstruct.yaml') 16 | CONFIG = yaml.safe_load(stream) 17 | stream.close() 18 | # dotdict for dot operations on Python dict e.g., CONFIG.cutoff 19 | CONFIG = dotdict(CONFIG) 20 | 21 | # create extractor instance 22 | extractor = RepresentationExtraction(CONFIG) 23 | # extract 24 | extractor.extract() -------------------------------------------------------------------------------- /examples/example_generator.py: -------------------------------------------------------------------------------- 1 | import yaml, os, unittest, torch 2 | import numpy as np 3 | from structrepgen.generators.generator import Generator 4 | from structrepgen.utils.dotdict import dotdict 5 | 6 | stream = open('./configs/example/example_generator.yaml') 7 | CONFIG = yaml.safe_load(stream) 8 | stream.close() 9 | CONFIG = dotdict(CONFIG) 10 | 11 | gen = Generator(CONFIG) 12 | 13 | gen.generate() 14 | gen.range_check() -------------------------------------------------------------------------------- /examples/example_reconstruction.py: -------------------------------------------------------------------------------- 1 | import yaml, os, unittest, torch 2 | import numpy as np 3 | from structrepgen.extraction.representation_extraction import * 4 | from structrepgen.utils.dotdict import dotdict 5 | from structrepgen.reconstruction.reconstruction import Reconstruction 6 | 7 | stream = open('./configs/example/example_extract_reconstruct.yaml') 8 | CONFIG = yaml.safe_load(stream) 9 | stream.close() 10 | CONFIG = dotdict(CONFIG) 11 | 12 | constructor = Reconstruction(CONFIG) 13 | 14 | constructor.main() -------------------------------------------------------------------------------- /examples/example_surrogate_trainer.py: -------------------------------------------------------------------------------- 1 | import torch, yaml 2 | import torch.nn as nn 3 | import torch.optim as optim 4 | import torch.nn.functional as F 5 | from torch.utils.data import DataLoader, TensorDataset 6 | from sklearn.preprocessing import MinMaxScaler 7 | import joblib, time, copy 8 | import pandas as pd 9 | import numpy as np 10 | from sklearn.model_selection import train_test_split 11 | 12 | from structrepgen.models.models import ff_net 13 | from structrepgen.utils.dotdict import dotdict 14 | from structrepgen.utils.utils import torch_device_select 15 | 16 | ''' 17 | Example of training a surrogate model based on the representation R extracted using Behler descriptors 18 | The model is trained on MSE loss for y 19 | 20 | Model used: 21 | fully connected feed-forward neural network 22 | ''' 23 | 24 | class Trainer(): 25 | def __init__(self, CONFIG) -> None: 26 | self.CONFIG = CONFIG 27 | 28 | # check GPU availability & set device 29 | self.device = torch_device_select(self.CONFIG.gpu) 30 | 31 | # initialize 32 | self.create_data() 33 | self.initialize() 34 | 35 | def create_data(self): 36 | p = self.CONFIG.params 37 | 38 | data_x = pd.read_csv(self.CONFIG.data_x_path, header=None).values 39 | data_y = pd.read_csv(self.CONFIG.data_y_path, header=None).values 40 | 41 | # train/test split and create torch dataloader 42 | xtrain, 
xtest, ytrain, ytest = train_test_split(data_x, data_y, test_size=self.CONFIG.split_ratio, random_state=p.seed) 43 | self.x_train = torch.tensor(xtrain, dtype=torch.float) 44 | self.y_train = torch.tensor(ytrain, dtype=torch.float) 45 | self.x_test = torch.tensor(xtest, dtype=torch.float) 46 | self.y_test = torch.tensor(ytest, dtype=torch.float) 47 | 48 | self.train_loader = DataLoader( 49 | TensorDataset(self.x_train, self.y_train), 50 | batch_size=p.batch_size, shuffle=True, drop_last=False 51 | ) 52 | 53 | self.test_loader = DataLoader( 54 | TensorDataset(self.x_test, self.y_test), 55 | batch_size=p.batch_size, shuffle=False, drop_last=False 56 | ) 57 | 58 | def initialize(self): 59 | p = self.CONFIG.params 60 | 61 | # create model 62 | self.model = ff_net(p.input_dim, p.hidden_dim, p.hidden_layers) 63 | self.model.to(self.device) 64 | print(self.model) 65 | 66 | # set up optimizer 67 | gamma = (p.final_decay)**(1./p.n_epochs) 68 | self.optimizer = optim.Adam(self.model.parameters(), lr=p.lr, weight_decay=p.weight_decay) 69 | self.scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=1, gamma=gamma) 70 | 71 | def train(self): 72 | p = self.CONFIG.params 73 | self.model.train() 74 | 75 | # loss of the epoch 76 | total_loss = 0. 77 | 78 | for i, (x, y) in enumerate(self.train_loader): 79 | x = x.to(self.device) 80 | y = y.to(self.device) 81 | 82 | self.optimizer.zero_grad() 83 | 84 | # forward 85 | y_pred = self.model(x) 86 | 87 | loss = F.mse_loss(y_pred, y, size_average=False) 88 | 89 | # backward 90 | loss.backward() 91 | total_loss += loss.item() 92 | 93 | # update the weights 94 | self.optimizer.step() 95 | 96 | return total_loss 97 | 98 | def test(self): 99 | p = self.CONFIG.params 100 | 101 | self.model.eval() 102 | 103 | # loss of the evaluation 104 | total_loss = 0. 
105 | 106 | with torch.no_grad(): 107 | for i, (x, y) in enumerate(self.test_loader): 108 | x = x.to(self.device) 109 | y = y.to(self.device) 110 | 111 | # forward pass 112 | y_pred = self.model(x) 113 | 114 | # loss 115 | loss = F.mse_loss(y_pred, y, size_average=False) 116 | total_loss += loss.item() 117 | 118 | return total_loss 119 | 120 | def run(self): 121 | p = self.CONFIG.params 122 | best_test_loss = float('inf') 123 | best_train_loss = float('inf') 124 | best_epoch = 0 125 | 126 | for e in range(p.n_epochs): 127 | tic = time.time() 128 | 129 | train_loss = self.train() 130 | test_loss = self.test() 131 | 132 | train_loss /= len(self.x_train) 133 | test_loss /= len(self.x_test) 134 | 135 | self.scheduler.step() 136 | lr = self.scheduler.optimizer.param_groups[0]["lr"] 137 | 138 | if best_test_loss > test_loss: 139 | best_epoch = e 140 | best_test_loss = test_loss 141 | best_train_loss = train_loss 142 | model_best = copy.deepcopy(self.model) 143 | 144 | elapsed_time = time.time() - tic 145 | epoch_out = f'Epoch {e:04d}, Train: {train_loss:.4f}, Test: {test_loss:.4f}, LR: {lr:.5f}, Time/Epoch (s): {elapsed_time:.3f}' 146 | if e % p.verbosity == 0: 147 | print(epoch_out) 148 | 149 | torch.save(model_best, self.CONFIG.ff_path) 150 | return best_epoch, best_train_loss, best_test_loss 151 | 152 | if __name__ == "__main__": 153 | # load parameters from yaml file 154 | stream = open('./configs/example/example_ff_trainer.yaml') 155 | CONFIG = yaml.safe_load(stream) 156 | stream.close() 157 | CONFIG = dotdict(CONFIG) 158 | 159 | trainer = Trainer(CONFIG) 160 | trainer.run() -------------------------------------------------------------------------------- /misc.py: -------------------------------------------------------------------------------- 1 | # this is a placeholder -------------------------------------------------------------------------------- /outputs/reconstructed/0_reconstructed.cif: -------------------------------------------------------------------------------- 1 | data_image0 2 | _chemical_formula_structural Pt10 3 | _chemical_formula_sum "Pt10" 4 | _cell_length_a 20 5 | _cell_length_b 20 6 | _cell_length_c 20 7 | _cell_angle_alpha 90 8 | _cell_angle_beta 90 9 | _cell_angle_gamma 90 10 | 11 | _space_group_name_H-M_alt "P 1" 12 | _space_group_IT_number 1 13 | 14 | loop_ 15 | _space_group_symop_operation_xyz 16 | 'x, y, z' 17 | 18 | loop_ 19 | _atom_site_type_symbol 20 | _atom_site_label 21 | _atom_site_symmetry_multiplicity 22 | _atom_site_fract_x 23 | _atom_site_fract_y 24 | _atom_site_fract_z 25 | _atom_site_occupancy 26 | Pt Pt1 1.0 0.59185 0.23819 0.75476 1.0000 27 | Pt Pt2 1.0 0.58611 0.54290 0.61489 1.0000 28 | Pt Pt3 1.0 0.69703 0.27786 0.68786 1.0000 29 | Pt Pt4 1.0 0.50226 0.36483 0.79940 1.0000 30 | Pt Pt5 1.0 0.66070 0.55487 0.53070 1.0000 31 | Pt Pt6 1.0 0.55922 0.23397 0.62426 1.0000 32 | Pt Pt7 1.0 0.63122 0.31871 0.58849 1.0000 33 | Pt Pt8 1.0 0.55219 0.24226 0.85988 1.0000 34 | Pt Pt9 1.0 0.55663 0.33984 0.69276 1.0000 35 | Pt Pt10 1.0 0.57114 0.42699 0.59601 1.0000 36 | -------------------------------------------------------------------------------- /outputs/reconstructed/1_reconstructed.cif: -------------------------------------------------------------------------------- 1 | data_image0 2 | _chemical_formula_structural Pt10 3 | _chemical_formula_sum "Pt10" 4 | _cell_length_a 20 5 | _cell_length_b 20 6 | _cell_length_c 20 7 | _cell_angle_alpha 90 8 | _cell_angle_beta 90 9 | _cell_angle_gamma 90 10 | 11 | _space_group_name_H-M_alt "P 1" 12 | 
_space_group_IT_number 1 13 | 14 | loop_ 15 | _space_group_symop_operation_xyz 16 | 'x, y, z' 17 | 18 | loop_ 19 | _atom_site_type_symbol 20 | _atom_site_label 21 | _atom_site_symmetry_multiplicity 22 | _atom_site_fract_x 23 | _atom_site_fract_y 24 | _atom_site_fract_z 25 | _atom_site_occupancy 26 | Pt Pt1 1.0 0.54359 0.34959 0.50534 1.0000 27 | Pt Pt2 1.0 0.44811 0.48147 0.44699 1.0000 28 | Pt Pt3 1.0 0.46119 0.44006 0.55937 1.0000 29 | Pt Pt4 1.0 0.35195 0.31757 0.37187 1.0000 30 | Pt Pt5 1.0 0.41042 0.60342 0.48576 1.0000 31 | Pt Pt6 1.0 0.59037 0.44773 0.53806 1.0000 32 | Pt Pt7 1.0 0.42613 0.41147 0.35061 1.0000 33 | Pt Pt8 1.0 0.39104 0.37135 0.47558 1.0000 34 | Pt Pt9 1.0 0.38792 0.24888 0.48067 1.0000 35 | Pt Pt10 1.0 0.35573 0.52945 0.55215 1.0000 36 | -------------------------------------------------------------------------------- /outputs/reconstructed/2_reconstructed.cif: -------------------------------------------------------------------------------- 1 | data_image0 2 | _chemical_formula_structural Pt10 3 | _chemical_formula_sum "Pt10" 4 | _cell_length_a 20 5 | _cell_length_b 20 6 | _cell_length_c 20 7 | _cell_angle_alpha 90 8 | _cell_angle_beta 90 9 | _cell_angle_gamma 90 10 | 11 | _space_group_name_H-M_alt "P 1" 12 | _space_group_IT_number 1 13 | 14 | loop_ 15 | _space_group_symop_operation_xyz 16 | 'x, y, z' 17 | 18 | loop_ 19 | _atom_site_type_symbol 20 | _atom_site_label 21 | _atom_site_symmetry_multiplicity 22 | _atom_site_fract_x 23 | _atom_site_fract_y 24 | _atom_site_fract_z 25 | _atom_site_occupancy 26 | Pt Pt1 1.0 0.63372 0.44299 0.29739 1.0000 27 | Pt Pt2 1.0 0.58648 0.41319 0.55833 1.0000 28 | Pt Pt3 1.0 0.63758 0.43002 0.45702 1.0000 29 | Pt Pt4 1.0 0.57271 0.54867 0.28353 1.0000 30 | Pt Pt5 1.0 0.56319 0.71164 0.17605 1.0000 31 | Pt Pt6 1.0 0.52529 0.60120 0.18550 1.0000 32 | Pt Pt7 1.0 0.60091 0.45877 0.19391 1.0000 33 | Pt Pt8 1.0 0.44374 0.56779 0.28553 1.0000 34 | Pt Pt9 1.0 0.51231 0.65642 0.29554 1.0000 35 | Pt Pt10 1.0 0.48298 0.37207 0.58142 1.0000 36 | -------------------------------------------------------------------------------- /outputs/reconstructed/3_reconstructed.cif: -------------------------------------------------------------------------------- 1 | data_image0 2 | _chemical_formula_structural Pt10 3 | _chemical_formula_sum "Pt10" 4 | _cell_length_a 20 5 | _cell_length_b 20 6 | _cell_length_c 20 7 | _cell_angle_alpha 90 8 | _cell_angle_beta 90 9 | _cell_angle_gamma 90 10 | 11 | _space_group_name_H-M_alt "P 1" 12 | _space_group_IT_number 1 13 | 14 | loop_ 15 | _space_group_symop_operation_xyz 16 | 'x, y, z' 17 | 18 | loop_ 19 | _atom_site_type_symbol 20 | _atom_site_label 21 | _atom_site_symmetry_multiplicity 22 | _atom_site_fract_x 23 | _atom_site_fract_y 24 | _atom_site_fract_z 25 | _atom_site_occupancy 26 | Pt Pt1 1.0 0.78766 0.62397 0.60775 1.0000 27 | Pt Pt2 1.0 0.67944 0.56004 0.64893 1.0000 28 | Pt Pt3 1.0 0.51837 0.54830 0.84224 1.0000 29 | Pt Pt4 1.0 0.53434 0.62492 0.62552 1.0000 30 | Pt Pt5 1.0 0.59688 0.46453 0.81786 1.0000 31 | Pt Pt6 1.0 0.70441 0.67983 0.68149 1.0000 32 | Pt Pt7 1.0 0.56416 0.56277 0.72494 1.0000 33 | Pt Pt8 1.0 0.55373 0.68267 0.73955 1.0000 34 | Pt Pt9 1.0 0.75654 0.57270 0.74149 1.0000 35 | Pt Pt10 1.0 0.60483 0.56585 0.54790 1.0000 36 | -------------------------------------------------------------------------------- /quickrun.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | echo "=== BEHLER (TEST) ===" 4 | python tests/test_behler.py 5 | 6 | echo "=== 
EXTRACTION ===" 7 | python examples/example_extract.py 8 | 9 | echo "=== CVAE TRAINER ===" 10 | python examples/example_cvae_trainer.py 11 | 12 | echo "=== SURROGATE TRAINER ===" 13 | python examples/example_surrogate_trainer.py 14 | 15 | echo "=== GENERATION ===" 16 | python examples/example_generator.py 17 | 18 | echo "=== RECONSTRUCTION ===" 19 | python examples/example_reconstruction.py -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | ase==3.22.1 2 | scikit-optimize==0.9.0 -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name='structrepgen', 5 | version='0.0.1', 6 | description="Atomic Structure Generation from Reconstructing Structural Fingerprints", 7 | url="https://github.com/Fung-Lab/StructRepGen", 8 | packages=find_packages(), 9 | include_package_data=True, 10 | ) -------------------------------------------------------------------------------- /structrepgen/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Fung-Lab/StructRepGen/c000acddbb450e2a4c0a81d09a6dfb07fc675a7d/structrepgen/__init__.py -------------------------------------------------------------------------------- /structrepgen/descriptors/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Fung-Lab/StructRepGen/c000acddbb450e2a4c0a81d09a6dfb07fc675a7d/structrepgen/descriptors/__init__.py -------------------------------------------------------------------------------- /structrepgen/descriptors/behler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import itertools, torch, warnings 3 | from itertools import combinations, combinations_with_replacement 4 | from structrepgen.utils.utils import torch_device_select 5 | 6 | class Behler: 7 | ''' 8 | Behler-Parinello Atomic-centered symmetry functions (ACSF) for representation extraction 9 | See 10 | 1. Behler, J. Chem. Phys. 
(2011) 11 | ''' 12 | def __init__(self, CONFIG) -> None: 13 | self.CONFIG = CONFIG 14 | 15 | # check GPU availability & set device 16 | self.device = torch_device_select(self.CONFIG.gpu) 17 | 18 | def get_features(self, distances, rij, atomic_numbers): 19 | ''' 20 | wrapper method for self.get_ACSF_features() 21 | ''' 22 | return self.get_ACSF_features(distances, rij, atomic_numbers) 23 | 24 | def get_ACSF_features(self, distances, rij, atomic_numbers): 25 | ''' 26 | Get Atomic-centered symmetry functions (ACSF) features (representation R) 27 | 28 | Parameters: 29 | distances (): 30 | rij (): 31 | atomic_numbers (): 32 | 33 | Returns: 34 | #TODO 35 | ''' 36 | 37 | n_atoms = distances.shape[0] 38 | type_set1 = self.create_type_set(atomic_numbers, 1) 39 | type_set2 = self.create_type_set(atomic_numbers, 2) 40 | 41 | atomic_numbers = torch.tensor(atomic_numbers, dtype=int) 42 | atomic_numbers = atomic_numbers.view(1, -1).expand(n_atoms, -1) 43 | atomic_numbers = atomic_numbers.reshape(n_atoms, -1).numpy() 44 | 45 | # calculate size of all_features tensor 46 | g2p = self.CONFIG.g2_params 47 | g5p = self.CONFIG.g5_params 48 | G2_out_size = len(g2p.Rs) * len(g2p.eta) * len(type_set1) 49 | G5_out_size = len(g5p.Rs) * len(g5p.eta) * len(g5p.zeta) * len(g5p.lambdas) * len(type_set2) if self.CONFIG.g5 else 0 50 | all_features = torch.zeros((n_atoms, G2_out_size+G5_out_size), device=self.device, dtype=torch.float) 51 | 52 | for i in range(n_atoms): 53 | Rc = self.CONFIG.cutoff 54 | mask = (distances[i, :] <= Rc) & (distances[i, :] != 0.0) 55 | Dij = distances[i, mask] 56 | Rij = rij[i, mask] 57 | IDs = atomic_numbers[i, mask.cpu().numpy()] 58 | jks = np.array(list(combinations(range(len(IDs)), 2))) 59 | 60 | # get G2 61 | G2_i = self.get_G2(Dij, IDs, type_set1) 62 | 63 | all_features[i, :G2_out_size] = G2_i 64 | # get G5 65 | if self.CONFIG.g5: 66 | G5_i = self.get_G5(Rij, IDs, jks, type_set2) 67 | all_features[i, G2_out_size:] = G5_i 68 | 69 | if self.CONFIG.average: 70 | length = all_features.shape[1] 71 | out = torch.zeros(length * 2, device=self.device, dtype=torch.float) 72 | out[:length] = torch.max(all_features, dim=0)[0] 73 | out[length:] = torch.min(all_features, dim=0)[0] 74 | return out 75 | 76 | return all_features 77 | 78 | def get_G2(self, Dij, IDs, type_set): 79 | ''' 80 | Get G2 symmetry function (radial): 81 | G_{i}^{2}=\sum_{j} e^{-\eta\left(R_{i j}-R_{s}\right)^{2}} \cdot f_{c}\left(R_{i j}\right) 82 | See Behler, J. Chem. Phys. (2011) Eqn (6) 83 | 84 | Parameters: 85 | Dij (): 86 | IDs (): 87 | type_set (): 88 | 89 | Returns: 90 | TODO 91 | ''' 92 | Rc = self.CONFIG.cutoff 93 | Rs = torch.tensor(self.CONFIG.g2_params.Rs, device=self.device) 94 | eta = torch.tensor(self.CONFIG.g2_params.eta, device=self.device) 95 | 96 | n1, n2, m, l = len(Rs), len(eta), len(Dij), len(type_set) 97 | 98 | d20 = (Dij - Rs.view(-1, 1)) ** 2 99 | term = torch.exp(torch.einsum('i,jk->ijk', -eta, d20)) 100 | results = torch.einsum('ijk,k->ijk', term, self.cosine_cutoff(Dij, Rc)) 101 | results = results.view(-1, m) 102 | 103 | G2 = torch.zeros([n1*n2*l], device=self.device) 104 | 105 | for id, j_type in enumerate(type_set): 106 | ids = self.select_rows(IDs, j_type) 107 | G2[id::l] = torch.sum(results[:, ids], axis=1) 108 | 109 | return G2 110 | 111 | def get_G5(self, Rij, IDs, jks, type_set): 112 | ''' 113 | Get G5 symmetry function (angular) 114 | See Behler, J. Chem. Phys. 
(2011) Eqn (9) 115 | 116 | Parameters: 117 | TODO 118 | 119 | Returns: 120 | TODO 121 | ''' 122 | 123 | Rc = self.CONFIG.cutoff 124 | Rs = torch.tensor(self.CONFIG.g5_params.Rs, device=self.device) 125 | eta = torch.tensor(self.CONFIG.g5_params.eta, device=self.device) 126 | lambdas = torch.tensor(self.CONFIG.g5_params.lambdas, device=self.device) 127 | zeta = torch.tensor(self.CONFIG.g5_params.zeta, device=self.device) 128 | 129 | n1, n2, n3, n4, l = len(Rs), len(eta), len(lambdas), len(zeta), len(type_set) 130 | jk = len(jks) # m1 131 | if jk == 0: # For dimer 132 | return torch.zeros([n1*n2*n3*n4*l], device=self.device) 133 | 134 | rij = Rij[jks[:, 0]] # [m1, 3] 135 | rik = Rij[jks[:, 1]] # [m1, 3] 136 | R2ij0 = torch.sum(rij**2., axis=1) 137 | R2ik0 = torch.sum(rik**2., axis=1) 138 | R1ij0 = torch.sqrt(R2ij0) # m1 139 | R1ik0 = torch.sqrt(R2ik0) # m1 140 | R2ij = R2ij0 - Rs.view(-1, 1)**2 # n1*m1 141 | R2ik = R2ik0 - Rs.view(-1, 1)**2 # n1*m1 142 | 143 | R1ij = R1ij0 - Rs.view(-1, 1) # n1*m1 144 | R1ik = R1ik0 - Rs.view(-1, 1) # n1*m1 145 | 146 | powers = 2. ** (1.-zeta) # n4 147 | cos_ijk = torch.sum(rij*rik, axis=1)/R1ij0/R1ik0 # m1 array 148 | term1 = 1. + torch.einsum('i,j->ij', lambdas, cos_ijk) # n3*m1 149 | 150 | zetas1 = zeta.repeat_interleave(n3*jk).reshape([n4, n3, jk]) # n4*n3*m1 151 | term2 = torch.pow(term1, zetas1) # n4*n3*m1 152 | term3 = torch.exp(torch.einsum('i,jk->ijk', -eta, (R2ij+R2ik))) # n2*n1*m1 153 | # * Cosine(R1jk0, Rc) # m1 154 | term4 = self.cosine_cutoff(R1ij0, Rc) * self.cosine_cutoff(R1ik0, Rc) 155 | term5 = torch.einsum('ijk,lmk->ijlmk', term2, term3) # n4*n3*n2*n1*m1 156 | term6 = torch.einsum('ijkml,l->ijkml', term5, term4) # n4*n3*n2*n1*m1 157 | results = torch.einsum('i,ijkml->ijkml', powers, term6) # n4*n3*n2*n1*m1 158 | results = results.reshape([n1*n2*n3*n4, jk]) 159 | 160 | G5 = torch.zeros([n1*n2*n3*n4*l], device=self.device) 161 | jk_ids = IDs[jks] 162 | for id, jk_type in enumerate(type_set): 163 | ids = self.select_rows(jk_ids, jk_type) 164 | G5[id::l] = torch.sum(results[:, ids], axis=1) 165 | 166 | return G5 167 | 168 | def create_type_set(self, number_set, order): 169 | ''' 170 | TODO 171 | ''' 172 | types = list(set(number_set)) 173 | return np.array(list(combinations_with_replacement(types, order))) 174 | 175 | def select_rows(self, data, row_pattern): 176 | ''' 177 | TODO 178 | ''' 179 | if len(row_pattern) == 1: 180 | ids = (data == row_pattern) 181 | elif len(row_pattern) == 2: 182 | a, b = row_pattern 183 | if a == b: 184 | ids = [id for id, d in enumerate(data) if d[0] == a and d[1] == a] 185 | else: 186 | ids = [id for id, d in enumerate(data) if (d[0] == a and d[1] == b) or (d[0] == b and d[1] == a)] 187 | return ids 188 | 189 | def cosine_cutoff(self, Rij, Rc): 190 | ''' 191 | Cosine cutoff function 192 | See Behler, J. Chem. Phys. (2011) Eqn (4) 193 | 194 | Parameters: 195 | Rij (torch.Tensor): distance between atom i and j 196 | Rc (float): cutoff radius 197 | 198 | Returns: 199 | out (torch.Tensor): cosine cutoff 200 | ''' 201 | 202 | out = 0.5 * (torch.cos(np.pi * Rij / Rc) + 1.) 203 | out[out > Rc] = 0. 
204 | return out -------------------------------------------------------------------------------- /structrepgen/descriptors/generic.py: -------------------------------------------------------------------------------- 1 | import torch, itertools 2 | import numpy as np 3 | 4 | def get_distances(positions, pbc_offsets, device): 5 | ''' 6 | Get atomic distances 7 | 8 | Parameters: 9 | positions (numpy.ndarray/torch.Tensor): positions attribute of ase.Atoms 10 | pbc_offsets (numpy.ndarray/torch.Tensor): periodic boundary condition offsets 11 | 12 | Returns: 13 | TODO 14 | ''' 15 | 16 | if isinstance(positions, np.ndarray): 17 | positions = torch.tensor(positions, device=device, dtype=torch.float) 18 | 19 | n_atoms = len(positions) 20 | n_cells = len(pbc_offsets) 21 | 22 | pos1 = positions.view(-1, 1, 1, 3).expand(-1, n_atoms, n_cells, 3) 23 | pos2 = positions.view(1, -1, 1, 3).expand(n_atoms, -1, n_cells, 3) 24 | pbc_offsets = pbc_offsets.view(-1, n_cells, 3).expand(pos2.shape[0], n_cells, 3) 25 | pos2 = pos2 + pbc_offsets 26 | 27 | # calculate the distance between target atom and the periodic images of the other atom 28 | atom_distance_sqr = torch.linalg.norm(pos1 - pos2, dim=-1) 29 | # get the minimum distance 30 | atom_distance_sqr_min, min_indices = torch.min(atom_distance_sqr, dim=-1) 31 | 32 | atom_rij = pos1 - pos2 33 | min_indices = min_indices[..., None, None].expand(-1, -1, 1, atom_rij.size(3)) 34 | atom_rij = torch.gather(atom_rij, dim=2, index=min_indices).squeeze() 35 | 36 | return atom_distance_sqr_min, atom_rij 37 | 38 | def get_pbc_offsets(cell, offset_num, device): 39 | ''' 40 | Get periodic boundary condition (PBC) offsets 41 | 42 | Parameters: 43 | cell (np.ndarray/torch.Tensor): unit cell vectors of ase.cell.Cell 44 | offset_num: 45 | 46 | Returns: 47 | TODO 48 | ''' 49 | if isinstance(cell, np.ndarray): 50 | cell = torch.tensor(np.array(cell), device=device, dtype=torch.float) 51 | 52 | unit_cell = [] 53 | offset_range = np.arange(-offset_num, offset_num + 1) 54 | 55 | for prod in itertools.product(offset_range, offset_range, offset_range): 56 | unit_cell.append(list(prod)) 57 | 58 | unit_cell = torch.tensor(unit_cell, dtype=torch.float, device=device) 59 | 60 | return torch.mm(unit_cell, cell.to(device)) 61 | 62 | # Obtain unit cell offsets for distance calculation 63 | class PBC_offsets(): 64 | def __init__(self, cell, device, supercell_max=4): 65 | # set up pbc offsets for minimum distance in pbc 66 | self.pbc_offsets = [] 67 | 68 | for offset_num in range(0, supercell_max): 69 | unit_cell = [] 70 | offset_range = np.arange(-offset_num, offset_num+1) 71 | 72 | for prod in itertools.product(offset_range, offset_range, offset_range): 73 | unit_cell.append(list(prod)) 74 | 75 | unit_cell = torch.tensor(unit_cell, dtype=torch.float, device=device) 76 | self.pbc_offsets.append(torch.mm(unit_cell, cell.to(device))) 77 | 78 | def get_offset(self, offset_num): 79 | return self.pbc_offsets[offset_num] -------------------------------------------------------------------------------- /structrepgen/extraction/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Fung-Lab/StructRepGen/c000acddbb450e2a4c0a81d09a6dfb07fc675a7d/structrepgen/extraction/__init__.py -------------------------------------------------------------------------------- /structrepgen/extraction/representation_extraction.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import 
pandas as pd 3 | import ase, ase.db, csv, yaml, torch 4 | 5 | from structrepgen.descriptors.behler import Behler 6 | from structrepgen.descriptors.generic import * 7 | 8 | class RepresentationExtraction: 9 | def __init__(self, CONFIG) -> None: 10 | ''' 11 | Extract representation R from raw data. 12 | 13 | Parameters: 14 | CONFIG (dict): descriptor configurations for extraction 15 | 16 | Returns: 17 | NIL 18 | ''' 19 | self.CONFIG = CONFIG 20 | 21 | def extract(self, fname=None): 22 | descriptor = self.CONFIG.descriptor 23 | 24 | if descriptor == 'behler': 25 | self.behler(fname) 26 | else: 27 | raise Exception("Descriptor currently not supported.") 28 | 29 | def behler(self, fname=None): 30 | behler = Behler(self.CONFIG) 31 | 32 | data_x=[] 33 | data_y=[] 34 | 35 | db = ase.db.connect(self.CONFIG.data) 36 | for row in db.select(): 37 | atoms = row.toatoms(add_additional_information=False) 38 | positions = atoms.get_positions() 39 | offsets = get_pbc_offsets(np.array(atoms.get_cell()), 0, behler.device) 40 | distances, rij = get_distances(positions, offsets, behler.device) 41 | 42 | features = behler.get_features(distances, rij, atoms.get_atomic_numbers()) 43 | features = features.cpu().numpy() 44 | 45 | data_x.append(list(features)) 46 | data_y.append([row.get('target')]) 47 | break # should be removed in production 48 | 49 | with open(self.CONFIG.x_fname, 'w', newline='') as f: 50 | wr = csv.writer(f) 51 | wr.writerows(data_x) 52 | 53 | with open(self.CONFIG.y_fname, 'w', newline='') as f: 54 | wr = csv.writer(f) 55 | wr.writerows(data_y) -------------------------------------------------------------------------------- /structrepgen/generators/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Fung-Lab/StructRepGen/c000acddbb450e2a4c0a81d09a6dfb07fc675a7d/structrepgen/generators/__init__.py -------------------------------------------------------------------------------- /structrepgen/generators/generator.py: -------------------------------------------------------------------------------- 1 | import torch, joblib, os 2 | import numpy as np 3 | 4 | from structrepgen.models.models import * 5 | from structrepgen.models.CVAE import * 6 | from structrepgen.utils.utils import torch_device_select 7 | 8 | class Generator: 9 | ''' 10 | Generator class for generating representation R from decoder 11 | 12 | # TODO 13 | - make this generic such that different decoders can be used 14 | 15 | Parameters: 16 | CONFIG (dict): model configurations for generation 17 | 18 | Returns: 19 | NIL 20 | ''' 21 | def __init__(self, CONFIG) -> None: 22 | self.CONFIG = CONFIG 23 | 24 | # check GPU availability & set device 25 | self.device = torch_device_select(self.CONFIG.gpu) 26 | 27 | # load models 28 | if self.device == 'cpu': 29 | self.model = torch.load(CONFIG.model_path, map_location=torch.device(self.device)) 30 | self.model_ff = torch.load(CONFIG.ff_model_path, map_location=torch.device(self.device)) 31 | else: 32 | self.model = torch.load(CONFIG.model_path) 33 | self.model_ff = torch.load(CONFIG.ff_model_path) 34 | 35 | self.model.to(self.device) 36 | self.model_ff.to(self.device) 37 | self.model.eval() 38 | self.model_ff.eval() 39 | 40 | self.rev_x = torch.zeros((len(self.CONFIG.targets), self.CONFIG.num_z, self.CONFIG.params.input_dim)).to(self.device) 41 | 42 | def generate(self): 43 | ''' 44 | Generate representation R using decoder by randomly sampling the latent space 45 | Output csv saved in folder defined in yaml file 46 | 47 
| Parameters: 48 | NIL 49 | 50 | Returns: 51 | NIL 52 | ''' 53 | p = self.CONFIG.params 54 | for count, y0 in enumerate(self.CONFIG.targets): 55 | for i in range(self.CONFIG.num_z): 56 | z = torch.randn(1, p.latent_dim).to(self.device) 57 | y = torch.tensor([[y0]]).to(self.device) 58 | 59 | z = torch.cat((z, y), dim=1) 60 | 61 | reconstructed_x = self.model.decoder(z) 62 | self.rev_x[count, i, :] = reconstructed_x 63 | 64 | fname = self.CONFIG.save_path + 'gen_samps_cvae_' + str(y0) + '.csv' 65 | np.savetxt(fname, self.rev_x[count, :, :].cpu().data.numpy(), fmt='%.6f', delimiter=',') 66 | 67 | def range_check(self): 68 | ''' 69 | Check percentage of generated structures that have y value within +- self.CONFIG.delta of target y value 70 | ''' 71 | scaler = joblib.load(self.CONFIG.scaler_path) 72 | rev_x_scaled = scaler.inverse_transform(self.rev_x.reshape(-1, self.rev_x.shape[2]).cpu().data.numpy()) 73 | rev_x = torch.tensor(rev_x_scaled).to(self.device).reshape(self.rev_x.shape) 74 | 75 | ratios = 0 76 | avg_diff = 0 77 | 78 | for count, y0 in enumerate(self.CONFIG.targets): 79 | y1 = self.model_ff(rev_x[count, :,:]) 80 | 81 | indices = torch.where((y1 > y0 - self.CONFIG.delta) & (y1 < y0 + self.CONFIG.delta))[0] 82 | ratio = len(indices)/len(y1) * 100 83 | ratios += ratio 84 | rev_x_out = rev_x[count, indices, :] 85 | 86 | minn, maxx = min(self.CONFIG.targets), max(self.CONFIG.targets) 87 | y_indices = torch.where((y1 > minn) & (y1 < maxx))[0] 88 | average = torch.mean(y1[y_indices]).item() 89 | avg_diff += abs(y0-average) 90 | 91 | out = f'Target: {y0:.2f}, Average value: {average:.2f}, Percent of samples within range: {ratio:.2f}' 92 | print(out) 93 | 94 | fname = self.CONFIG.save_path + 'gen_samps_x_' + str(y0) + '.csv' 95 | np.savetxt(fname, rev_x_out.cpu().data.numpy(), fmt='%.6f', delimiter=',') 96 | 97 | ratios = ratios/len(self.CONFIG.targets) 98 | avg_diff = avg_diff/len(self.CONFIG.targets) 99 | avg_out = f'Average difference: {avg_diff:.2f}, Average percent: {ratios:.2f}' 100 | print(avg_out) 101 | 102 | return avg_diff, ratios -------------------------------------------------------------------------------- /structrepgen/models/CVAE.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | class Encoder(nn.Module): 7 | ''' This the encoder part of VAE 8 | ''' 9 | def __init__(self, input_dim, hidden_dim, latent_dim, hidden_layers, y_dim): 10 | ''' 11 | Args: 12 | input_dim: A integer indicating the size of input (in case of MNIST 28 * 28). 13 | hidden_dim: A integer indicating the size of hidden dimension. 14 | latent_dim: A integer indicating the latent size. 15 | n_classes: A integer indicating the number of classes. 
(dimension of one-hot representation of labels) 16 | ''' 17 | super().__init__() 18 | 19 | self.hidden = nn.Sequential() 20 | intermediate_dimensions = np.linspace( 21 | hidden_dim, latent_dim, hidden_layers+1, dtype=int)[0:-1] 22 | intermediate_dimensions = np.concatenate( 23 | ([input_dim + y_dim], intermediate_dimensions)) 24 | for i, (in_size, out_size) in enumerate(zip(intermediate_dimensions[:-1], intermediate_dimensions[1:])): 25 | self.hidden.add_module(name='Linear_'+str(i), 26 | module=nn.Linear(in_size, out_size)) 27 | #self.hidden.add_module(name='BN_'+str(i), module=nn.BatchNorm1d(out_size)) 28 | self.hidden.add_module(name='Act_'+str(i), module=nn.ELU()) 29 | #self.hidden.add_module(name='Drop_'+str(i), module=nn.Dropout(p=0.05, inplace=False)) 30 | 31 | self.mu = nn.Linear(intermediate_dimensions[-1], latent_dim) 32 | self.var = nn.Linear(intermediate_dimensions[-1], latent_dim) 33 | 34 | def forward(self, x): 35 | # x is of shape [batch_size, input_dim + n_classes] 36 | 37 | x = self.hidden(x) 38 | # hidden is of shape [batch_size, hidden_dim] 39 | 40 | # latent parameters 41 | mean = self.mu(x) 42 | # mean is of shape [batch_size, latent_dim] 43 | log_var = self.var(x) 44 | # log_var is of shape [batch_size, latent_dim] 45 | 46 | return mean, log_var 47 | 48 | 49 | class Decoder(nn.Module): 50 | ''' This the decoder part of VAE 51 | ''' 52 | 53 | def __init__(self, latent_dim, hidden_dim, output_dim, hidden_layers, y_dim): 54 | ''' 55 | Args: 56 | latent_dim: A integer indicating the latent size. 57 | hidden_dim: A integer indicating the size of hidden dimension. 58 | output_dim: A integer indicating the size of output. 59 | n_classes: A integer indicating the number of classes. (dimension of one-hot representation of labels) 60 | ''' 61 | super().__init__() 62 | 63 | self.hidden = nn.Sequential() 64 | intermediate_dimensions = np.linspace( 65 | latent_dim + y_dim, hidden_dim, hidden_layers+1, dtype=int) 66 | for i, (in_size, out_size) in enumerate(zip(intermediate_dimensions[:-1], intermediate_dimensions[1:])): 67 | self.hidden.add_module(name='Linear_'+str(i), 68 | module=nn.Linear(in_size, out_size)) 69 | #self.hidden.add_module(name='BN_'+str(i), module=nn.BatchNorm1d(out_size)) 70 | self.hidden.add_module(name='Act_'+str(i), module=nn.ELU()) 71 | #self.hidden.add_module(name='Drop_'+str(i), module=nn.Dropout(p=0.05, inplace=False)) 72 | 73 | self.hidden_to_out = nn.Linear(hidden_dim, output_dim) 74 | 75 | def forward(self, x): 76 | # x is of shape [batch_size, latent_dim + num_classes] 77 | x = self.hidden(x) 78 | 79 | generated_x = self.hidden_to_out(x) 80 | 81 | return generated_x 82 | 83 | 84 | class CVAE(nn.Module): 85 | ''' This the VAE, which takes a encoder and decoder. 86 | ''' 87 | 88 | def __init__(self, input_dim, hidden_dim, latent_dim, hidden_layers, y_dim): 89 | ''' 90 | Args: 91 | input_dim: A integer indicating the size of input. 92 | hidden_dim: A integer indicating the size of hidden dimension. 93 | latent_dim: A integer indicating the latent size. 94 | n_classes: A integer indicating the number of classes. 
(dimension of one-hot representation of labels) 95 | ''' 96 | super().__init__() 97 | 98 | self.encoder = Encoder(input_dim, hidden_dim, 99 | latent_dim, hidden_layers, y_dim) 100 | self.decoder = Decoder(latent_dim, hidden_dim, 101 | input_dim, hidden_layers, y_dim) 102 | 103 | def forward(self, x, y): 104 | 105 | x = torch.cat((x, y), dim=1) 106 | 107 | # encode 108 | z_mu, z_var = self.encoder(x) 109 | 110 | # sample from the distribution having latent parameters z_mu, z_var 111 | # reparameterize 112 | std = torch.exp(z_var / 2) 113 | eps = torch.randn_like(std) 114 | x_sample = eps.mul(std).add_(z_mu) 115 | 116 | z = torch.cat((x_sample, y), dim=1) 117 | 118 | # decode 119 | generated_x = self.decoder(z) 120 | 121 | return generated_x, z_mu, z_var 122 | 123 | def kl_divergence(z, mu, std): 124 | # -------------------------- 125 | # Monte Carlo KL divergence 126 | # -------------------------- 127 | # https://towardsdatascience.com/variational-autoencoder-demystified-with-pytorch-implementation-3a06bee395ed 128 | # 1. define the first two probabilities (in this case Normal for both) 129 | p = torch.distributions.Normal(torch.zeros_like(mu), torch.ones_like(std)) 130 | q = torch.distributions.Normal(mu, std) 131 | 132 | # 2. get the probabilities from the equation 133 | log_qzx = q.log_prob(z) 134 | log_pz = p.log_prob(z) 135 | 136 | # kl 137 | kl = (log_qzx - log_pz) 138 | 139 | # sum over last dim to go from single dim distribution to multi-dim 140 | kl = kl.sum(-1) 141 | return kl 142 | 143 | 144 | def calculate_loss(x, reconstructed_x, mu, log_var, weight, mc_kl_loss): 145 | # reconstruction loss (summed over all elements) 146 | rcl = F.mse_loss(reconstructed_x, x, reduction='sum') 147 | # kl divergence loss 148 | 149 | if mc_kl_loss == True: 150 | std = torch.exp(log_var / 2) 151 | q = torch.distributions.Normal(mu, std) 152 | z = q.rsample() 153 | kld = kl_divergence(z, mu, std).sum() 154 | else: 155 | kld = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp()) 156 | 157 | return rcl, kld * weight -------------------------------------------------------------------------------- /structrepgen/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Fung-Lab/StructRepGen/c000acddbb450e2a4c0a81d09a6dfb07fc675a7d/structrepgen/models/__init__.py -------------------------------------------------------------------------------- /structrepgen/models/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | class ff_net(nn.Module): 7 | def __init__(self, input_dim, hidden_dim, layers): 8 | super(ff_net, self).__init__() 9 | self.lin1 = torch.nn.Linear(input_dim, hidden_dim) 10 | self.lin_list = torch.nn.ModuleList( 11 | [torch.nn.Linear(hidden_dim, hidden_dim) for i in range(layers)] 12 | ) 13 | 14 | self.lin2 = torch.nn.Linear(hidden_dim, 1) 15 | 16 | def forward(self, x): 17 | out = torch.nn.functional.relu(self.lin1(x)) 18 | for layer in self.lin_list: 19 | out = torch.nn.functional.relu(layer(out)) 20 | out = self.lin2(out) 21 | return out -------------------------------------------------------------------------------- /structrepgen/reconstruction/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Fung-Lab/StructRepGen/c000acddbb450e2a4c0a81d09a6dfb07fc675a7d/structrepgen/reconstruction/__init__.py 
-------------------------------------------------------------------------------- /structrepgen/reconstruction/generic.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Fung-Lab/StructRepGen/c000acddbb450e2a4c0a81d09a6dfb07fc675a7d/structrepgen/reconstruction/generic.py -------------------------------------------------------------------------------- /structrepgen/reconstruction/reconstruction.py: -------------------------------------------------------------------------------- 1 | import time, ase, torch, os, pickle 2 | import numpy as np 3 | import torch.nn.functional as F 4 | from ase import io, Atoms 5 | from skopt.space import Space 6 | from skopt.sampler import Lhs, Halton, Grid 7 | from structrepgen.models.models import * 8 | from structrepgen.descriptors.behler import Behler 9 | from structrepgen.descriptors.generic import * 10 | from structrepgen.models.CVAE import * 11 | from structrepgen.utils.utils import torch_device_select 12 | 13 | class Reconstruction: 14 | ''' 15 | Reconstruction 16 | ''' 17 | def __init__(self, CONFIG) -> None: 18 | self.CONFIG = CONFIG 19 | 20 | # select gpu or cpu for pytorch 21 | self.device = torch_device_select(self.CONFIG.gpu) 22 | 23 | # initialize descriptor 24 | if self.CONFIG.descriptor == 'behler': 25 | self.descriptor = Behler(CONFIG) 26 | else: 27 | raise ValueError('Unrecognised descriptor method: {}.'.format(self.CONFIG.descriptor)) 28 | 29 | self.model_ff = torch.load(self.CONFIG.ff_model_path) 30 | self.model_ff.to(self.device) 31 | 32 | torch.backends.cudnn.benchmark = True 33 | 34 | def main(self): 35 | ''' 36 | 37 | ''' 38 | tic = time.time() 39 | 40 | data_list = self.create_datalist() 41 | self.model_ff.eval() 42 | 43 | for i in range(len(data_list)): 44 | best_positions, _ = self.basin_hopping( 45 | data_list[i], 46 | total_trials=2, 47 | max_hops=10, 48 | lr=0.2, 49 | displacement_factor=2, 50 | max_loss=0.0001, 51 | verbose=True 52 | ) 53 | 54 | optimized_structure = Atoms( 55 | numbers=data_list[i]['atomic_numbers'], 56 | positions=best_positions.detach().cpu().numpy(), 57 | cell=data_list[i]['cell'].cpu().numpy(), 58 | pbc=(True, True, True) 59 | ) 60 | 61 | # save reconstructed cifs 62 | filename = self.CONFIG.reconstructed_file_path + str(i) + '_reconstructed.cif' 63 | ase.io.write(filename, optimized_structure) 64 | 65 | pbc_offsets = get_pbc_offsets(data_list[i]['cell'], 0, self.device) 66 | distances, rij = get_distances(best_positions, pbc_offsets, self.device) 67 | features = self.descriptor.get_features(distances, rij, data_list[i]['atomic_numbers']) 68 | 69 | # inference on FF model 70 | y0 = self.model_ff(data_list[i]['representation']) 71 | y1 = self.model_ff(features) 72 | print(y0, y1) 73 | 74 | elapsed_time = time.time() - tic 75 | print('Total elapsed time: {:.3f}'.format(elapsed_time)) 76 | 77 | def basin_hopping( 78 | self, 79 | data, 80 | total_trials = 20, 81 | max_hops = 500, 82 | lr = 0.05, 83 | displacement_factor = 2, 84 | max_loss = 0.01, 85 | write = False, 86 | verbose = False 87 | ): 88 | ''' 89 | 90 | ''' 91 | # setting for self.optimize() 92 | max_iter = 140 93 | 94 | offset_count = 0 95 | offsets = PBC_offsets(data['cell'], self.device, supercell_max=1) 96 | 97 | best_global_loss = float('inf') 98 | best_global_positions = None 99 | converged = False 100 | 101 | for _ in range(total_trials): 102 | # initialize random structure 103 | generated_pos = self.initialize( 104 | len(data['atomic_numbers']), 105 | data['cell'], 106 | 
data['atomic_numbers'], 107 | sampling_method='random' 108 | ) 109 | 110 | tic = time.time() 111 | best_local_loss = float('inf') 112 | best_local_positions = None 113 | 114 | for hop in range(max_hops): 115 | # TODO: parse loss function, optimizer and scheduler as args 116 | loss_func = F.l1_loss 117 | optimizer = torch.optim.Adam([generated_pos], lr=lr) 118 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( 119 | optimizer, 120 | mode='min', 121 | factor=0.8, 122 | patience=15, 123 | min_lr=0.0001, 124 | threshold=0.0001 125 | ) 126 | 127 | best_positions, _, loss = self.optimize( 128 | generated_pos, 129 | data['cell'], 130 | offsets.get_offset(offset_count), 131 | data['atomic_numbers'], 132 | data['representation'], 133 | loss_func, 134 | optimizer, 135 | scheduler, 136 | max_iter, 137 | verbosity=999, 138 | early_stopping=100, 139 | threshold=0.0001, 140 | convergence=0.00005 141 | ) 142 | 143 | elapsed_time = time.time() - tic 144 | tic = time.time() 145 | 146 | if verbose: 147 | print("BH Hop: {:04d}, Loss: {:.5f}, Time (s): {:.4f}".format(hop, loss.item(), elapsed_time)) 148 | 149 | if loss.item() < best_local_loss: 150 | # TODO: implement metropolis criterion 151 | best_local_positions = best_positions.detach().clone() 152 | best_local_loss = loss.item() 153 | 154 | if loss.item() < best_global_loss: 155 | best_global_positions = best_local_positions 156 | best_global_loss = loss.item() 157 | 158 | if best_local_loss < max_loss: 159 | print("Convergence criterion met.") 160 | converged = True 161 | break 162 | 163 | # apply random shift in positions 164 | # TODO: change hardcoded value of 0.5 165 | displacement = (np.random.random_sample(best_positions.shape) - 0.5) * displacement_factor 166 | displacement = torch.tensor(displacement, device=self.device, dtype=torch.float) 167 | 168 | generated_pos = best_local_positions.detach().clone() + displacement 169 | generated_pos.requires_grad_() 170 | 171 | if converged: 172 | break 173 | 174 | if verbose: 175 | print('Ending Basin Hopping, fine optimization of best structure. 
Best global loss: {:5.3f}'.format(best_global_loss)) 176 | 177 | best_global_positions.requires_grad_() 178 | optimizer = torch.optim.Adam([best_global_positions], lr=lr*0.2) 179 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( 180 | optimizer, 181 | mode='min', 182 | factor=0.8, 183 | patience=20, 184 | min_lr=0.0001, 185 | threshold=0.0001 186 | ) 187 | 188 | best_global_positions, _, loss = self.optimize( 189 | best_global_positions, 190 | data['cell'], 191 | offsets.get_offset(offset_count), 192 | data['atomic_numbers'], 193 | data['representation'], 194 | loss_func, 195 | optimizer, 196 | scheduler, 197 | 900, 198 | verbosity=(10 if verbose == True else 999), 199 | early_stopping=20, 200 | threshold=0.00001, 201 | convergence=0.00005 202 | ) 203 | 204 | if verbose: 205 | print('Best Basin Hopping loss: {:.5f}'.format(loss.item())) 206 | return best_global_positions, loss 207 | 208 | def optimize( 209 | self, 210 | positions, 211 | cell, 212 | pbc_offsets, 213 | atomic_numbers, 214 | target_representation, 215 | loss_function, 216 | optimizer, 217 | scheduler, 218 | max_iterations = 100, 219 | verbosity = 5, 220 | early_stopping = 1, 221 | threshold = 0.01, 222 | convergence = 0.001 223 | ): 224 | ''' 225 | 226 | ''' 227 | tic = time.time() 228 | 229 | 230 | all_positions = [positions.detach().clone()] 231 | count, loss_history = 0, [] 232 | best_positions = None 233 | best_loss = float('inf') 234 | min_dist = 1.5 235 | 236 | while count < max_iterations: 237 | optimizer.zero_grad() 238 | lr = scheduler.optimizer.param_groups[0]['lr'] 239 | 240 | # get representation / reconstruction loss 241 | distances, rij = get_distances(positions, pbc_offsets, self.device) 242 | features = self.descriptor.get_features(distances, rij, atomic_numbers) 243 | 244 | features = torch.unsqueeze(features, 0) 245 | 246 | loss = loss_function(features, target_representation) 247 | loss.backward(retain_graph=False) 248 | loss_history.append(loss) 249 | 250 | optimizer.step() 251 | scheduler.step(loss) 252 | 253 | #print loss and time taken 254 | if (count + 1) % verbosity == 0: 255 | elapsed_time = time.time() - tic 256 | tic = time.time() 257 | print("System Size: {:04d}, Step: {:04d}, Loss: {:.5f}, LR: {:.5f}, Time/step (s): {:.4f}".format(len(positions), count+1, loss.item(), lr, elapsed_time/verbosity)) 258 | 259 | # save best structure 260 | if loss < best_loss: 261 | best_positions = positions.detach().clone() 262 | best_loss = loss 263 | 264 | count += 1 265 | 266 | if best_loss < convergence: 267 | break 268 | 269 | if early_stopping > 1 and count > early_stopping: 270 | if abs((loss.item() - loss_history[-early_stopping]) / loss_history[-early_stopping]) < threshold: 271 | break 272 | 273 | all_positions.append(positions.detach().clone()) 274 | 275 | return best_positions, all_positions, best_loss 276 | 277 | def initialize(self, structure_len, cell, atomic_numbers, sampling_method='random', write=False): 278 | ''' 279 | Initialize structure 280 | 281 | Parameters: 282 | 283 | Returns: 284 | 285 | ''' 286 | space = Space([(0., 1.)] * 3) 287 | if sampling_method == 'random': 288 | positions = np.array(space.rvs(structure_len)) 289 | else: 290 | raise ValueError("Unrecognised sampling method.") 291 | 292 | generated_structure = Atoms( 293 | numbers=atomic_numbers, 294 | scaled_positions=positions, 295 | cell=cell.cpu().numpy(), 296 | pbc=(True, True, True) 297 | ) 298 | 299 | positions = generated_structure.positions 300 | if write == True: 301 | # TODO 302 | print("Implement write") 303 | 
pass 304 | 305 | return torch.tensor(positions, requires_grad=True, device=self.device, dtype=torch.float) 306 | 307 | def create_datalist(self): 308 | ''' 309 | Load crystal structure(s) data from directory 310 | 311 | Parameters: 312 | NIL 313 | 314 | Returns: 315 | TODO 316 | 317 | ''' 318 | representation = np.loadtxt(self.CONFIG.structure_file_path, dtype='float', delimiter=',') 319 | cell = self.CONFIG.cell 320 | data_list = [] 321 | 322 | for i in range(len(representation)): 323 | data = {} 324 | 325 | data['atomic_numbers'] = self.CONFIG.atoms 326 | data['cell'] = torch.tensor(np.array(cell), dtype=torch.float) 327 | data['representation'] = torch.unsqueeze(torch.tensor(representation[i], dtype=torch.float, device=self.device), 0) 328 | data_list.append(data) 329 | 330 | return data_list -------------------------------------------------------------------------------- /structrepgen/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Fung-Lab/StructRepGen/c000acddbb450e2a4c0a81d09a6dfb07fc675a7d/structrepgen/utils/__init__.py -------------------------------------------------------------------------------- /structrepgen/utils/dotdict.py: -------------------------------------------------------------------------------- 1 | # class dotdict(dict): 2 | # """dot.notation access to dictionary attributes""" 3 | # __getattr__ = dict.get 4 | # __setattr__ = dict.__setitem__ 5 | # __delattr__ = dict.__delitem__ 6 | 7 | 8 | class dotdict(dict): 9 | """dot.notation access to dictionary attributes""" 10 | def __getattr__(*args): 11 | val = dict.get(*args) 12 | return dotdict(val) if type(val) is dict else val 13 | __setattr__ = dict.__setitem__ 14 | __delattr__ = dict.__delitem__ -------------------------------------------------------------------------------- /structrepgen/utils/utils.py: -------------------------------------------------------------------------------- 1 | import torch, warnings 2 | 3 | def torch_device_select(gpu): 4 | # check GPU availability & return device type 5 | if torch.cuda.is_available() and not gpu: 6 | warnings.warn("GPU is available but not used.") 7 | return 'cpu' 8 | elif not torch.cuda.is_available() and gpu: 9 | warnings.warn("GPU is not available but set to used. 
Using CPU.") 10 | return 'cpu' 11 | elif torch.cuda.is_available() and gpu: 12 | return 'cuda' 13 | else: 14 | return 'cpu' -------------------------------------------------------------------------------- /tests/original/db_to_data.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import ase, csv 3 | import ase.db 4 | from .original_descriptor import * 5 | 6 | # device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 7 | device = 'cpu' 8 | 9 | def db_to_data(): 10 | processing_arguments = {} 11 | g2_params = {'eta': torch.tensor([0.01, 0.06, 0.1, 0.2, 0.4, 0.7, 1.0, 2.0, 3.5, 5.0], device=device), 'Rs': torch.tensor([ 12 | 0, 1, 1.5, 2, 2.5, 3, 3.5, 4, 5, 6, 7, 8, 9, 10], device=device)} 13 | g5_params = {'lambda': torch.tensor([-1, 1], device=device), 'zeta': torch.tensor([1, 2, 4, 16, 64], device=device), 14 | 'eta': torch.tensor([0.06, 0.1, 0.2, 0.4, 1.0], device=device), 'Rs': torch.tensor([0], device=device)} 15 | processing_arguments['cutoff'] = 20 16 | processing_arguments['mode'] = 'features' 17 | processing_arguments['g2_params'] = g2_params 18 | processing_arguments['g5_params'] = g5_params 19 | processing_arguments['g5'] = True 20 | processing_arguments['average'] = True 21 | 22 | db = ase.db.connect('./data/raw/data.db') 23 | data_x=[] 24 | data_y=[] 25 | 26 | for row in db.select(): 27 | ase_structure = row.toatoms(add_additional_information=False) 28 | positions = torch.tensor( 29 | ase_structure.get_positions(), device=device, dtype=torch.float) 30 | distances, rij = get_distances(positions, get_pbc_offsets(torch.tensor( 31 | np.array(ase_structure.get_cell()), device=device, dtype=torch.float), 0)) 32 | features = get_ACSF_features( 33 | distances, rij, ase_structure.get_atomic_numbers(), processing_arguments) 34 | features = torch.squeeze(features).cpu().numpy() 35 | 36 | data_x.append(list(features)) 37 | data_y.append([row.get('target')]) 38 | break 39 | 40 | with open("./data/unittest/unittest_original_x.csv", 'w', newline='') as f: 41 | wr = csv.writer(f) 42 | wr.writerows(data_x) 43 | 44 | with open("./data/unittest/unittest_original_y.csv", 'w', newline='') as f: 45 | wr = csv.writer(f) 46 | wr.writerows(data_y) 47 | 48 | if __name__ == "__main__": 49 | db_to_data() -------------------------------------------------------------------------------- /tests/original/original_descriptor.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import itertools 3 | import math 4 | from itertools import combinations, combinations_with_replacement 5 | 6 | import torch 7 | import torch.nn.functional as F 8 | 9 | # device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 10 | device = 'cpu' 11 | 12 | ###obtains distance with pytorch tensors 13 | def get_distances(positions, pbc_offsets): 14 | 15 | num_atoms = len(positions) 16 | num_cells = len(pbc_offsets) 17 | 18 | pos1 = positions.view(-1, 1, 1, 3).expand(-1, num_atoms, 1, 3) 19 | pos1 = pos1.expand(-1, num_atoms, num_cells, 3) 20 | pos2 = positions.view(1, -1, 1, 3).expand(num_atoms, -1, 1, 3) 21 | pos2 = pos2.expand(num_atoms, -1, num_cells, 3) 22 | 23 | pbc_offsets = pbc_offsets.view(-1, num_cells, 3).expand(pos2.shape[0], num_cells, 3) 24 | pos2 = pos2 + pbc_offsets 25 | ###calculates distance between target atom and the periodic images of the other atom, then gets the minimum distance 26 | atom_distance_sqr = torch.linalg.norm(pos1 - pos2, dim=-1) 27 | atom_distance_sqr_min, 
min_indices = torch.min(atom_distance_sqr, dim=-1) 28 | #atom_distance_sqr_min = torch.amin(atom_distance_sqr, dim=-1) 29 | 30 | atom_rij = pos1 - pos2 31 | min_indices = min_indices[..., None, None].expand(-1, -1, 1, atom_rij.size(3)) 32 | atom_rij = torch.gather(atom_rij, dim=2, index=min_indices).squeeze() 33 | 34 | return atom_distance_sqr_min, atom_rij 35 | 36 | 37 | ###Obtain unit cell offsets for distance calculation 38 | class PBC_offsets(): 39 | def __init__(self,cell, supercell_max=4): 40 | #set up pbc offsets for minimum distance in pbc 41 | self.pbc_offsets=[] 42 | 43 | for offset_num in range(0, supercell_max): 44 | unit_cell=[] 45 | offset_range = np.arange(-offset_num, offset_num+1) 46 | 47 | for prod in itertools.product(offset_range, offset_range, offset_range): 48 | unit_cell.append(list(prod)) 49 | 50 | unit_cell = torch.tensor(unit_cell, dtype=torch.float, device=device) 51 | self.pbc_offsets.append(torch.mm(unit_cell, cell.to(device))) 52 | 53 | def get_offset(self, offset_num): 54 | return self.pbc_offsets[offset_num] 55 | ###functional form 56 | def get_pbc_offsets(cell, offset_num): 57 | 58 | unit_cell=[] 59 | offset_range = np.arange(-offset_num, offset_num+1) 60 | 61 | ##put this out of loop 62 | for prod in itertools.product(offset_range, offset_range, offset_range): 63 | unit_cell.append(list(prod)) 64 | 65 | unit_cell = torch.tensor(unit_cell, dtype=torch.float, device=device) 66 | 67 | return torch.mm(unit_cell, cell.to(device)) 68 | 69 | 70 | ###ACSF evaluation function 71 | def get_ACSF_features(distances, rij, atomic_numbers, parameters): 72 | 73 | num_atoms = distances.shape[0] 74 | 75 | #offsets=get_pbc_offsets(cell, 2) 76 | #distances, rij = get_distances(positions, offsets) 77 | #print(distances.shape, rij.shape) 78 | 79 | type_set1 = create_type_set(atomic_numbers, 1) 80 | type_set2 = create_type_set(atomic_numbers, 2) 81 | 82 | atomic_numbers = torch.tensor(atomic_numbers, dtype=int) 83 | atomic_numbers = atomic_numbers.view(1, -1).expand(num_atoms, -1) 84 | atomic_numbers = atomic_numbers.reshape(num_atoms, -1).numpy() 85 | 86 | for i in range(0, num_atoms): 87 | 88 | mask = (distances[i, :] <= parameters['cutoff']) & (distances[i, :] != 0.0) 89 | Dij = distances[i, mask] 90 | Rij = rij[i, mask] 91 | IDs = atomic_numbers[i, mask.cpu().numpy()] 92 | jks = np.array(list(combinations(range(len(IDs)), 2))) 93 | 94 | G2i = calculate_G2_pt(Dij, IDs, type_set1, parameters['cutoff'], parameters['g2_params']['Rs'], parameters['g2_params']['eta']) 95 | if parameters['g5'] == True: 96 | G5i = calculate_G5_pt(Rij, IDs, jks, type_set2, parameters['cutoff'], parameters['g5_params']['Rs'], parameters['g5_params']['eta'], parameters['g5_params']['zeta'], parameters['g5_params']['lambda']) 97 | G_comb=torch.cat((G2i, G5i), dim=0).view(1,-1) 98 | else: 99 | G_comb=G2i.view(1,-1) 100 | 101 | if i == 0: 102 | all_G=G_comb 103 | else: 104 | all_G=torch.cat((all_G, G_comb), dim=0) 105 | 106 | if parameters['average']==True: 107 | all_G_max = torch.max(all_G, dim=0)[0] 108 | all_G_min = torch.min(all_G, dim=0)[0] 109 | all_G = torch.cat((all_G_max, all_G_min)).unsqueeze(0) 110 | 111 | return all_G 112 | 113 | 114 | ### ACSF sub-functions 115 | ###G2 symmetry function (radial) 116 | def calculate_G2_pt(Dij, IDs, type_set, Rc, Rs, etas): 117 | 118 | n1, n2, m, l = len(Rs), len(etas), len(Dij), len(type_set) 119 | d20 = (Dij - Rs.view(-1,1)) ** 2 # n1*m 120 | term = torch.exp(torch.einsum('i,jk->ijk', -etas, d20)) # n2*n1*m 121 | results = torch.einsum('ijk,k->ijk', term, 
cosine_cutoff_pt(Dij, Rc)) # n2*n1*m 122 | results = results.reshape([n1*n2, m]) 123 | 124 | G2 = torch.zeros([n1*n2*l], device=device) 125 | for id, j_type in enumerate(type_set): 126 | ids = select_rows(IDs, j_type) 127 | G2[id::l] = torch.sum(results[:, ids], axis=1) 128 | 129 | return G2 130 | 131 | 132 | ###G5 symmetry function (angular) 133 | def calculate_G5_pt(Rij, IDs, jks, type_set, Rc, Rs, etas, zetas, lambdas): 134 | 135 | n1, n2, n3, n4, l = len(Rs), len(etas), len(lambdas), len(zetas), len(type_set) 136 | jk = len(jks) # m1 137 | if jk == 0: # For dimer 138 | return torch.zeros([n1*n2*n3*n4*l], device=device) 139 | 140 | rij = Rij[jks[:,0]] # [m1, 3] 141 | rik = Rij[jks[:,1]] # [m1, 3] 142 | R2ij0 = torch.sum(rij**2., axis=1) 143 | R2ik0 = torch.sum(rik**2., axis=1) 144 | R1ij0 = torch.sqrt(R2ij0) # m1 145 | R1ik0 = torch.sqrt(R2ik0) # m1 146 | R2ij = R2ij0 - Rs.view(-1,1)**2 # n1*m1 147 | R2ik = R2ik0 - Rs.view(-1,1)**2 # n1*m1 148 | 149 | R1ij = R1ij0 - Rs.view(-1,1) # n1*m1 150 | R1ik = R1ik0 - Rs.view(-1,1) # n1*m1 151 | 152 | powers = 2. ** (1.-zetas) #n4 153 | cos_ijk = torch.sum(rij*rik, axis=1)/R1ij0/R1ik0 # m1 array 154 | term1 = 1. + torch.einsum('i,j->ij', lambdas, cos_ijk) # n3*m1 155 | 156 | zetas1 = zetas.repeat_interleave(n3*jk).reshape([n4, n3, jk]) # n4*n3*m1 157 | term2 = torch.pow(term1, zetas1) # n4*n3*m1 158 | term3 = torch.exp(torch.einsum('i,jk->ijk', -etas, (R2ij+R2ik))) # n2*n1*m1 159 | term4 = cosine_cutoff_pt(R1ij0, Rc) * cosine_cutoff_pt(R1ik0, Rc) #* Cosine(R1jk0, Rc) # m1 160 | term5 = torch.einsum('ijk,lmk->ijlmk', term2, term3) #n4*n3*n2*n1*m1 161 | term6 = torch.einsum('ijkml,l->ijkml', term5, term4) #n4*n3*n2*n1*m1 162 | results = torch.einsum('i,ijkml->ijkml', powers, term6) #n4*n3*n2*n1*m1 163 | results = results.reshape([n1*n2*n3*n4, jk]) 164 | 165 | G5 = torch.zeros([n1*n2*n3*n4*l], device=device) 166 | jk_ids = IDs[jks] 167 | for id, jk_type in enumerate(type_set): 168 | ids = select_rows(jk_ids, jk_type) 169 | G5[id::l] = torch.sum(results[:, ids], axis=1) 170 | 171 | return G5 172 | 173 | 174 | ### 175 | def create_type_set(number_set, order): 176 | types = list(set(number_set)) 177 | return np.array(list(combinations_with_replacement(types, order))) 178 | 179 | 180 | ###Aggregation function 181 | #slow 182 | def select_rows(data, row_pattern): 183 | if len(row_pattern) == 1: 184 | ids = (data==row_pattern) 185 | elif len(row_pattern) == 2: 186 | a, b = row_pattern 187 | if a==b: 188 | ids = [id for id, d in enumerate(data) if d[0]==a and d[1]==a] 189 | else: 190 | ids = [id for id, d in enumerate(data) if (d[0] == a and d[1]==b) or (d[0] == b and d[1]==a)] 191 | return ids 192 | 193 | 194 | ###Cosine cutoff function 195 | def cosine_cutoff_pt(Rij, Rc): 196 | mask = (Rij > Rc) 197 | result = 0.5 * (torch.cos(torch.tensor(np.pi, device=device) * Rij / Rc) + 1.) 
198 | result[mask] = 0 199 | return result -------------------------------------------------------------------------------- /tests/test_behler.py: -------------------------------------------------------------------------------- 1 | import yaml, os, unittest, torch 2 | import numpy as np 3 | from structrepgen.extraction.representation_extraction import * 4 | from structrepgen.utils.dotdict import dotdict 5 | from original.db_to_data import db_to_data 6 | 7 | class TestBehler(unittest.TestCase): 8 | 9 | @classmethod 10 | def setUpClass(self): 11 | # original implementation 12 | db_to_data() 13 | 14 | # SRG 15 | stream = open('./configs/unittest/unittest_behler.yaml') 16 | CONFIG = yaml.safe_load(stream) 17 | stream.close() 18 | self.CONFIG = dotdict(CONFIG) 19 | extractor = RepresentationExtraction(self.CONFIG) 20 | extractor.extract() 21 | 22 | self.original_x = './data/unittest/unittest_original_x.csv' 23 | self.original_y = './data/unittest/unittest_original_y.csv' 24 | self.srg_x = self.CONFIG.x_fname 25 | self.srg_y = self.CONFIG.y_fname 26 | 27 | def test_file_existence(self): 28 | ''' 29 | Test if output files are successfully generated 30 | ''' 31 | 32 | for file in [self.original_x, self.original_y, self.srg_x, self.srg_y]: 33 | self.assertTrue(os.path.exists(file)) 34 | 35 | def test_x(self): 36 | ''' 37 | Test if output x files are the same 38 | ''' 39 | 40 | original_x_tensor = torch.from_numpy(np.genfromtxt(self.original_x, delimiter=',')) 41 | srg_x_tensor = torch.from_numpy(np.genfromtxt(self.srg_x, delimiter=',')) 42 | 43 | same = torch.all(torch.eq(original_x_tensor, srg_x_tensor)) 44 | self.assertTrue(same.item()) 45 | 46 | def test_y(self): 47 | ''' 48 | Test if output y files are the same 49 | ''' 50 | 51 | original_y_tensor = torch.from_numpy(np.genfromtxt(self.original_y, delimiter=',')) 52 | srg_y_tensor = torch.from_numpy(np.genfromtxt(self.srg_y, delimiter=',')) 53 | 54 | same = torch.all(torch.eq(original_y_tensor, srg_y_tensor)) 55 | self.assertTrue(same.item()) 56 | 57 | @classmethod 58 | def tearDownClass(self): 59 | for file in [self.original_x, self.original_y, self.srg_x, self.srg_y]: 60 | os.remove(file) 61 | 62 | if __name__ == "__main__": 63 | unittest.main(verbosity=2) --------------------------------------------------------------------------------