├── .gitignore
├── LICENSE
├── README.md
├── conditioning_specifications
│   ├── 1_polarizability.json
│   ├── 2_fingerprint.json
│   ├── 3_gap_comp.json
│   ├── 4_comp_relenergy.json
│   └── 5_gap_relenergy.json
├── display_molecules.py
├── filter_generated.py
├── gschnet_cond_script.py
├── images
│   └── concept_results_scheme.png
├── nn_classes.py
├── preprocess_dataset.py
├── published_data
│   ├── README.md
│   └── data_base_help.py
├── qm9_data.py
├── splits
│   ├── 1_polarizability_split.npz
│   ├── 2_fingerprint_split.npz
│   ├── 3_gap_comp_split.npz
│   ├── 4_comp_relenergy_split.npz
│   ├── 5_gap_relenergy_split.npz
│   └── qm9_invalid.txt
├── utility_classes.py
└── utility_functions.py
/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | .idea 6 | *.DS_Store 7 | 8 | # test data 9 | src/sacred_scripts/data/* 10 | src/sacred_scripts/models/* 11 | src/sacred_scripts/experiments/* 12 | src/scripts/data/ 13 | src/scripts/training/ 14 | 15 | docs/tutorials/*.db 16 | docs/tutorials/*.xyz 17 | docs/tutorials/qm9tut 18 | 19 | # C extensions 20 | *.so 21 | 22 | 23 | # Distribution / packaging 24 | .Python 25 | env/ 26 | build/ 27 | develop-eggs/ 28 | dist/ 29 | downloads/ 30 | eggs/ 31 | .eggs/ 32 | lib/ 33 | lib64/ 34 | parts/ 35 | sdist/ 36 | var/ 37 | wheels/ 38 | *.egg-info/ 39 | .installed.cfg 40 | *.egg 41 | 42 | # PyInstaller 43 | # Usually these files are written by a python script from a template 44 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 45 | *.manifest 46 | *.spec 47 | 48 | # Installer logs 49 | pip-log.txt 50 | pip-delete-this-directory.txt 51 | 52 | # Unit test / coverage reports 53 | htmlcov/ 54 | .tox/ 55 | .coverage 56 | .coverage.* 57 | .cache 58 | nosetests.xml 59 | coverage.xml 60 | *.cover 61 | .hypothesis/ 62 | .pytest_cache 63 | 64 | # Translations 65 | *.mo 66 | *.pot 67 | 68 | # Django stuff: 69 | *.log 70 | local_settings.py 71 | 72 | # Flask stuff: 73 | instance/ 74 | .webassets-cache 75 | 76 | # Scrapy stuff: 77 | .scrapy 78 | 79 | # Sphinx documentation 80 | docs/_build/ 81 | 82 | # PyBuilder 83 | target/ 84 | 85 | # Jupyter Notebook 86 | .ipynb_checkpoints 87 | 88 | # pyenv 89 | .python-version 90 | 91 | # celery beat schedule file 92 | celerybeat-schedule 93 | 94 | # SageMath parsed files 95 | *.sage.py 96 | 97 | # dotenv 98 | .env 99 | 100 | # virtualenv 101 | .venv 102 | venv/ 103 | ENV/ 104 | 105 | # Spyder project settings 106 | .spyderproject 107 | .spyproject 108 | 109 | # Rope project settings 110 | .ropeproject 111 | 112 | # mkdocs documentation 113 | /site 114 | 115 | # mypy 116 | .mypy_cache/ 117 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Niklas Gebauer 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software.
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ⚠️ **_Disclaimer: This repository is deprecated and only meant for reproduction of the published results of cG-SchNet on QM9. If you want to use custom data sets or build on top of our model, please refer to the [up-to-date implementation](https://github.com/atomistic-machine-learning/schnetpack-gschnet)._** 2 | 3 | # cG-SchNet - A Conditional Generative Neural Network for 3d Molecules 4 | 5 | DOI 6 | 7 | Implementation of cG-SchNet - a conditional generative neural network for 3d molecular structures - accompanying the paper [_"Inverse design of 3d molecular structures with conditional generative neural networks"_](https://www.nature.com/articles/s41467-022-28526-y). 8 | If you are using cG-SchNet in your research, please cite the corresponding publication: 9 | 10 | N.W.A. Gebauer, M. Gastegger, S.S.P. Hessmann, K.-R. Müller, and K.T. Schütt. 11 | Inverse design of 3d molecular structures with conditional generative neural networks. 12 | Nature Communications 13, 973 (2022). https://doi.org/10.1038/s41467-022-28526-y 13 | 14 | 15 | @Article{gebauer2022inverse, 16 | author={Gebauer, Niklas W. A. and Gastegger, Michael and Hessmann, Stefaan S. P. and M{\"u}ller, Klaus-Robert and Sch{\"u}tt, Kristof T.}, 17 | title={Inverse design of 3d molecular structures with conditional generative neural networks}, 18 | journal={Nature Communications}, 19 | year={2022}, 20 | volume={13}, 21 | number={1}, 22 | pages={973}, 23 | issn={2041-1723}, 24 | doi={10.1038/s41467-022-28526-y}, 25 | url={https://doi.org/10.1038/s41467-022-28526-y} 26 | } 27 | 28 | 29 | The code provided in this repository allows training of cG-SchNet on the tasks presented in the paper using the QM9 data set, which consists of approximately 130k small molecules with up to nine heavy atoms (fluorine, oxygen, nitrogen, and carbon). 30 | 31 | We also provide links to the molecules generated with cG-SchNet for the paper as well as two pretrained models used therein. For details, please refer to the folder [_published_data_](https://github.com/atomistic-machine-learning/cG-SchNet/tree/main/published_data). 32 | 33 | ### Requirements 34 | - schnetpack 0.3 35 | - pytorch >= 1.2 36 | - python >= 3.7 37 | - ASE >= 3.17.0 38 | - Open Babel 2.4.1 39 | - rdkit >= 2019.03.4.0 40 | 41 | In the following, we describe the setup using Anaconda. 42 | 43 | The following commands will create a new conda environment called _"cgschnet"_ and install all dependencies (tested on Ubuntu 18.04 and 20.04): 44 | 45 | conda create -n cgschnet python=3.7 pytorch=1.5.1 torchvision cudatoolkit=10.2 ase=3.19.0 openbabel=2.4.1 rdkit=2019.09.2.0 -c pytorch -c openbabel -c defaults -c conda-forge 46 | conda activate cgschnet 47 | pip install 'schnetpack==0.3' 48 | 49 | Replace _"cudatoolkit=10.2"_ with _"cpuonly"_ if you do not want to utilize a GPU for training/generation. However, we strongly recommend using a GPU if available.
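If you installed the GPU variant, you can optionally verify that PyTorch detects your GPU before starting a training run (a quick sanity check, not part of the original setup instructions):

    python -c "import torch; print(torch.cuda.is_available())"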
50 | 51 | To observe the training progress, install tensorboard: 52 | 53 | pip install tensorboard 54 | 55 | # Getting started 56 | Clone the repository into your folder of choice: 57 | 58 | git clone https://github.com/atomistic-machine-learning/cG-SchNet.git 59 | 60 | 61 | ### Training a model 62 | A model conditioned on the composition of molecules and their relative atomic energy, with the same settings as described in the paper, can be trained by running gschnet_cond_script.py with the following parameters: 63 | 64 | python ./cG-SchNet/gschnet_cond_script.py train gschnet ./data/ ./models/cgschnet/ --conditioning_json_path ./cG-SchNet/conditioning_specifications/4_comp_relenergy.json --split_path ./cG-SchNet/splits/4_comp_relenergy_split.npz --cuda 65 | 66 | The training data (QM9) is automatically downloaded and preprocessed if not present in ./data/, and the model will be stored in ./models/cgschnet/. 67 | We recommend training on a GPU, but you can remove _--cuda_ from the call to use the CPU instead. If your GPU has less than 16GB VRAM, you need to decrease the number of features (e.g. _--features 64_) or the depth of the network (e.g. _--interactions 6_). 68 | The conditioning network is specified in a json-file. We provide the settings used in experiments throughout the paper in ./cG-SchNet/conditioning_specifications. The exact data splits utilized in our trainings are provided in ./cG-SchNet/splits/. The conditioning network and corresponding split are numbered like the models in Supplementary Table S1, where we list the hyper-parameter settings. Simply replace the paths behind the _--conditioning_json_path_ and _--split_path_ arguments in order to train the desired model. 69 | If you want to train on a new data split, remove _--split_path_ from the call and instead use _--split x y_ to randomly create a new split with x training molecules, y validation molecules, and the remaining molecules in the test set (e.g. with x=50000 and y=5000). 70 | 71 | The training of the full model takes about 40 hours on an A100 GPU. For testing purposes, you can leave the training running for only a couple of epochs. An epoch should take 10-20 min, depending on your hardware. To observe the training progress, use TensorBoard: 72 | 73 | tensorboard --logdir=./models 74 | 75 | 76 | The logs will appear after the first epoch has completed. 77 | 78 | ### Generating molecules 79 | Running the script with the following arguments will generate 1000 molecules using the composition C7O2H10 and a relative atomic energy of -0.1 as conditions with the trained model at ./models/cgschnet/, and store them in ./models/cgschnet/generated/generated.mol_dict: 80 | 81 | python ./cG-SchNet/gschnet_cond_script.py generate gschnet ./models/cgschnet/ 1000 --conditioning "composition 10 7 0 2 0; n_atoms 19; relative_atomic_energy -0.1" --cuda 82 | 83 | Remove _--cuda_ from the call if you want to run on the CPU. Add _--show_gen_ to display the molecules with ASE after generation. If you are running into problems due to small VRAM, decrease the size of mini-batches during generation (e.g. _--chunk_size 500_, default is 1000). 84 | 85 | Note that the conditions for sampling are provided as a string using the _--conditioning_ argument. You have to provide the name of the property (as given in the field _in\_key\_property_ in the conditioning specification json, without the leading underscore) followed by the target value(s). Each property the model was conditioned on, i.e. as listed in one of the layers in the conditioning specification json, has to be provided. Otherwise, the target value will be set to zero, which is often an inappropriate value and may thus lead to invalid molecules. The composition of molecules is given as the number of atoms of type H, C, N, O, F (in that particular order). When conditioning on fingerprints, the value can either be an index (which is used to load the fingerprint of the corresponding molecule from the ```./data/qm9gen.db``` database) or a SMILES string. 86 |
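As a further example, assuming you have trained a model with the conditioning specification _5_gap_relenergy.json_ (i.e. conditioned on the HOMO-LUMO gap and the relative atomic energy) and stored it at ./models/cgschnet_5/ (a placeholder path), a corresponding call could look as follows, where the target values are arbitrary choices within the ranges covered by that specification:

    python ./cG-SchNet/gschnet_cond_script.py generate gschnet ./models/cgschnet_5/ 1000 --conditioning "gap 4.0; relative_atomic_energy -0.1" --cuda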
87 | ### Filtering and analysis of generated molecules 88 | After generation, the generated molecules can be filtered for invalid and duplicate structures by running filter_generated.py: 89 | 90 | python ./cG-SchNet/filter_generated.py ./models/cgschnet/generated/generated.mol_dict --train_data_path ./data/qm9gen.db --model_path ./models/cgschnet 91 | 92 | The script will print its progress and the gathered results. 93 | The script checks the valency constraints (e.g. every hydrogen atom should have exactly one bond) and the connectedness (i.e. all atoms in a molecule should be connected to each other via a path of bonds), and removes duplicates*. The remaining valid structures are stored in an SQLite database with ASE (at ./models/cgschnet/generated/generated_molecules.db) along with an .npz-file that records certain statistics (e.g. the number of rings of certain sizes, the number of single, double, and triple bonds, the index of the matching training/test data molecule, etc. for each molecule). 94 | 95 | In order to match the generated structures to training/test data, the QM9 data set is required. If it hasn't been downloaded before, e.g. because you are using a [pretrained model](https://github.com/atomistic-machine-learning/cG-SchNet/blob/main/published_data/README.md#pretrained-models) for generation, you can initialize the download by starting to train a dummy model for zero epochs and deleting it afterwards as follows: 96 | 97 | python ./cG-SchNet/gschnet_cond_script.py train gschnet ./data/ ./models/_dummy/ --split 1 1 --max_epochs 0 98 | rm -r ./models/_dummy 99 | 100 | *_Please note that, as described in the paper, we use molecular fingerprints and canonical SMILES representations to identify duplicates, which means that different spatial conformers corresponding to the same canonical SMILES string are tagged as duplicates and removed in the process. Add '--filters valence disconnected' to the call in order to keep identified duplicates in the created database instead of removing them._ 101 | 102 | ### Displaying generated molecules 103 | After filtering, all generated molecules stored in the SQLite database can be displayed with ASE as follows: 104 | 105 | python ./cG-SchNet/display_molecules.py --data_path ./models/cgschnet/generated/generated_molecules.db 106 | 107 | # How does it work? 108 | 109 | cG-SchNet is an autoregressive neural network. It builds 3d molecules by placing one atom after another in 3d space. To this end, the joint distribution of all atoms is factorized into single steps, where the position and type of the new atom depend on the preceding atoms (Figure a; a schematic form of this factorization is sketched below). The model also processes conditions, i.e. values of target properties, which enables it to learn a conditional distribution of molecular structures. This distribution allows targeted sampling of molecules that are highly likely to exhibit the specified conditions (see e.g. the distribution of the polarizability of molecules generated with cG-SchNet using five different target values in Figure b). The type and absolute position of new atoms are sampled successively, where the probability of the positions is approximated from predicted pairwise distances to the preceding atoms. In order to improve the accuracy of the approximation and to steer the generation process, the network uses two auxiliary tokens, the focus and the origin. The new atom always has to be a neighbor of the focus, and the origin marks the supposed center of mass of the final structure. A scheme explaining the generation procedure can be seen in Figure c. It uses 2d positional distributions for visualization purposes. For more details, please refer to the publication [_here_](https://www.nature.com/articles/s41467-022-28526-y). 110 |
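For intuition, write x_1, ..., x_n for the types and positions of the n atoms of a molecule and c for the conditions (the target property values). The factorization described above can then be sketched as

    p(x_1, \dots, x_n \mid c) = \prod_{i=1}^{n} p(x_i \mid x_1, \dots, x_{i-1}, c)

where each factor is modeled by the network; the auxiliary focus and origin tokens used during generation are omitted in this simplified form.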
111 | ![generated molecules](./images/concept_results_scheme.png) 112 | -------------------------------------------------------------------------------- /conditioning_specifications/1_polarizability.json: -------------------------------------------------------------------------------- 1 | { 2 | "stack": { 3 | "in_key": "representation", 4 | "layers": { 5 | "isotropic_polarizability": { 6 | "in_key_property": "_isotropic_polarizability", 7 | "n_in": 15, 8 | "n_layers": 3, 9 | "n_neurons": 64, 10 | "n_out": 64, 11 | "start": 5, 12 | "stop": 16 13 | } 14 | }, 15 | "mode": "stack", 16 | "n_global_cond_features": 128, 17 | "n_layers": 5, 18 | "n_neurons": 128, 19 | "out_key": "representation" 20 | } 21 | } 22 | -------------------------------------------------------------------------------- /conditioning_specifications/2_fingerprint.json: -------------------------------------------------------------------------------- 1 | { 2 | "stack": { 3 | "in_key": "representation", 4 | "layers": { 5 | "fingerprint": { 6 | "n_in": 1024, 7 | "n_out": 128, 8 | "n_layers": 3, 9 | "in_key_fingerprint": "_fingerprint" 10 | } 11 | }, 12 | "mode": "stack", 13 | "n_global_cond_features": 128, 14 | "n_layers": 5, 15 | "n_neurons": 128, 16 | "out_key": "representation" 17 | } 18 | } 19 | -------------------------------------------------------------------------------- /conditioning_specifications/3_gap_comp.json: -------------------------------------------------------------------------------- 1 | { 2 | "stack": { 3 | "in_key": "representation", 4 | "layers": { 5 | "composition": { 6 | "embedding": { 7 | "embedding_dim": 16, 8 | "num_embeddings": 20 9 | }, 10 | "in_key_composition": "_composition", 11 | "n_layers": 3, 12 | "n_neurons": 64, 13 | "n_out": 64, 14 | "n_types": 5, 15 | "skip_h": false, 16 | "type_weighting": "relative" 17 | }, 18 | "gap": { 19 | "in_key_property": "_gap", 20 | "n_in": 5, 21 | "n_layers": 3, 22 | "n_neurons": 64, 23 | "n_out": 64, 24 | "start": 2, 25 | "stop": 11 26 | }, 27 | "n_atoms": { 28 | "in_key_property": "_n_atoms", 29 | "n_in": 5, 30 | "n_layers": 3, 31 | "n_neurons": 64, 32 | "n_out": 64, 33 | "start": 0, 34 | "stop": 35 35 | } 36 | }, 37 | "mode": "stack", 38 | "n_global_cond_features": 128, 39 | "n_layers": 5, 40 | "n_neurons": 128, 41 | "out_key": "representation" 42 | } 43 | } 44 | -------------------------------------------------------------------------------- /conditioning_specifications/4_comp_relenergy.json: -------------------------------------------------------------------------------- 1 | { 2 | "stack": { 3 | "in_key": "representation", 4 | "layers": { 5 | "composition": { 6 | "embedding": { 7 | "embedding_dim": 16, 8 | "num_embeddings": 20 9 | }, 10 | "in_key_composition": "_composition", 11 | "n_layers": 3, 12 | "n_neurons": 64, 13 | "n_out": 64, 14 | "n_types": 5, 15 |
"skip_h": false, 16 | "type_weighting": "relative" 17 | }, 18 | "relative_atomic_energy": { 19 | "in_key_property": "_relative_atomic_energy", 20 | "n_in": 5, 21 | "n_layers": 3, 22 | "n_neurons": 64, 23 | "n_out": 64, 24 | "start": -0.2, 25 | "stop": 0.2 26 | }, 27 | "n_atoms": { 28 | "in_key_property": "_n_atoms", 29 | "n_in": 5, 30 | "n_layers": 3, 31 | "n_neurons": 64, 32 | "n_out": 64, 33 | "start": 0, 34 | "stop": 35 35 | } 36 | }, 37 | "mode": "stack", 38 | "n_global_cond_features": 128, 39 | "n_layers": 5, 40 | "n_neurons": 128, 41 | "out_key": "representation" 42 | } 43 | } 44 | -------------------------------------------------------------------------------- /conditioning_specifications/5_gap_relenergy.json: -------------------------------------------------------------------------------- 1 | { 2 | "stack": { 3 | "in_key": "representation", 4 | "layers": { 5 | "gap": { 6 | "in_key_property": "_gap", 7 | "n_in": 5, 8 | "n_layers": 3, 9 | "n_neurons": 64, 10 | "n_out": 64, 11 | "start": 2, 12 | "stop": 11 13 | }, 14 | "relative_atomic_energy": { 15 | "in_key_property": "_relative_atomic_energy", 16 | "n_in": 5, 17 | "n_layers": 3, 18 | "n_neurons": 64, 19 | "n_out": 64, 20 | "start": -0.2, 21 | "stop": 0.2 22 | } 23 | }, 24 | "mode": "stack", 25 | "n_global_cond_features": 128, 26 | "n_layers": 5, 27 | "n_neurons": 128, 28 | "out_key": "representation" 29 | } 30 | } 31 | -------------------------------------------------------------------------------- /display_molecules.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import sys 3 | import os 4 | import subprocess 5 | import numpy as np 6 | import tempfile 7 | 8 | from ase.db import connect 9 | from ase.io import write 10 | from utility_classes import IndexProvider 11 | 12 | 13 | def get_parser(): 14 | """ Setup parser for command line arguments """ 15 | main_parser = argparse.ArgumentParser() 16 | main_parser.add_argument('--data_path', type=str, default=None, 17 | help='Path to database with filtered, generated molecules ' 18 | '(.db format, needs to be provided if generated ' 19 | 'molecules shall be displayed, default: %(default)s)') 20 | main_parser.add_argument('--train_data_path', type=str, 21 | help='Path to training data base (.db format, needs to be ' 22 | 'provided if molecules from the training data set ' 23 | 'shall be displayed, e.g. when using --train or ' 24 | '--test, default: %(default)s)', 25 | default=None) 26 | main_parser.add_argument('--select', type=str, nargs='*', 27 | help='Selection strings that specify which molecules ' 28 | 'shall be shown, if None all molecules from ' 29 | 'data_path and/or train_data_path are shown, ' 30 | 'providing multiple strings' 31 | ' will open multiple windows (one per string), ' 32 | '(default: %(default)s). The selection string has ' 33 | 'the general format "Property,OperatorTarget" (e.g. ' 34 | '"C,>8"to filter for all molecules with more than ' 35 | 'eight carbon atoms where "C" is the statistic ' 36 | 'counting the number of carbon atoms in a molecule, ' 37 | '">" is the operator, and "8" is the target value). ' 38 | 'Multiple conditions can be combined to form one ' 39 | 'selection string using "&" (e.g "C,>8&R5,>0" to ' 40 | 'get all molecules with more than 8 carbon atoms ' 41 | 'and at least 1 ring of size 5). Prepending ' 42 | '"training" to the selection string will filter and ' 43 | 'display molecules from the training data base ' 44 | 'instead of generated molecules (e.g. "training C,>8"' 45 | '). 
An overview of the available properties for ' 46 | 'molecules generated with G-SchNet trained on QM9 can' 47 | ' be found in the README.md.', 48 | default=None) 49 | main_parser.add_argument('--print_indices', 50 | help='For each provided selection print out the indices ' 51 | 'of molecules that match the respective selection ' 52 | 'string', 53 | action='store_true') 54 | main_parser.add_argument('--export_to_dir', type=str, 55 | help='Optionally, provide a path to an directory to which ' 56 | 'indices of molecules matching the corresponding ' 57 | 'query shall be written (one .npy-file (numpy) per ' 58 | 'selection string, if None is provided, the ' 59 | 'indices will not be exported, default: %(default)s)', 60 | default=None) 61 | main_parser.add_argument('--train', 62 | help='Display all generated molecules that match ' 63 | 'structures used during training and the ' 64 | 'corresponding molecules from the training data set.', 65 | action='store_true') 66 | main_parser.add_argument('--test', 67 | help='Display all generated molecules that match ' 68 | 'held out test data structures and the ' 69 | 'corresponding molecules from the training data set.', 70 | action='store_true') 71 | main_parser.add_argument('--novel', 72 | help='Display all generated molecules that match neither ' 73 | 'structures used during training nor those held out ' 74 | 'as test data.', 75 | action='store_true') 76 | main_parser.add_argument('--block', 77 | help='Make the call to ASE GUI blocking (such that the ' 78 | 'script stops until the GUI window is closed).', 79 | action='store_true') 80 | 81 | return main_parser 82 | 83 | 84 | def view_ase(mols, name, block=False): 85 | ''' 86 | Display a list of molecules using the ASE GUI. 87 | 88 | Args: 89 | mols (list of ase.Atoms): molecules as ase.Atoms objects 90 | name (str): the name that shall be displayed in the windows top bar 91 | block (bool, optional): whether the call to ase gui shall block or not block 92 | the script (default: False) 93 | ''' 94 | dir = tempfile.mkdtemp('', 'generated_molecules_') # make temporary directory 95 | filename = os.path.join(dir, name) # path of temporary file 96 | format = 'traj' # use trajectory format for temporary file 97 | command = sys.executable + ' -m ase gui -b' # command to execute ase gui viewer 98 | write(filename, mols, format=format) # write molecules to temporary file 99 | # show molecules in ase gui and remove temporary file and directory afterwards 100 | if block: 101 | subprocess.call(command.split() + [filename]) 102 | os.remove(filename) 103 | os.rmdir(dir) 104 | else: 105 | subprocess.Popen(command.split() + [filename]) 106 | subprocess.Popen(['sleep 60; rm "{0}"'.format(filename)], shell=True) 107 | subprocess.Popen(['sleep 65; rmdir "{0}"'.format(dir)], shell=True) 108 | 109 | 110 | def print_indices(idcs, name='', per_line=10): 111 | ''' 112 | Prints provided indices in a clean formatting. 
113 | 114 | Args: 115 | idcs (list of int): indices that shall be printed 116 | name (str, optional): the selection string that was used to obtain the indices 117 | per_line (int, optional): the number of indices that are printed per line ( 118 | default: 10) 119 | ''' 120 | biggest_count = len(str(len(idcs))) 121 | format_count = f'>{biggest_count}d' 122 | biggest_original_idx = len(str(max(idcs))) 123 | new_line = '\n' 124 | format = f'>{biggest_original_idx}d' 125 | str_idcs = [f'{i+1:{format_count}}: {j:{format}} ' + 126 | (new_line if (i+1) % per_line == 0 else '') 127 | for i, j in enumerate(idcs)] 128 | print(f'\nAll {len(idcs)} indices for selection {name}:') 129 | print(''.join(str_idcs)) 130 | 131 | 132 | if __name__ == '__main__': 133 | parser = get_parser() 134 | args = parser.parse_args() 135 | 136 | # make sure that at least one path was provided 137 | if args.data_path is None and args.train_data_path is None: 138 | print(f'\nPlease specify --data_path to display generated molecules or ' 139 | f'--train_data_path to display training molecules (or both)!') 140 | sys.exit(0) 141 | 142 | # sort queries into those concerning generated structures and those concerning 143 | # training data molecules 144 | train_selections = [] 145 | gen_selections = [] 146 | if args.select is not None: 147 | for selection in args.select: 148 | if selection.startswith('training'): 149 | # put queries concerning training structures aside for later 150 | train_selections += [selection] 151 | else: 152 | gen_selections += [selection] 153 | 154 | # make sure that the required paths were provided 155 | if args.train or args.test: 156 | if args.data_path is None: 157 | print('\nYou need to specify --data_path (and optionally ' 158 | '--train_data_path) if using --train or --test!') 159 | sys.exit(0) 160 | if args.novel: 161 | if args.data_path is None: 162 | print('\nYou need to specify --data_path if you want to display novel ' 163 | 'molecules!') 164 | sys.exit(0) 165 | if len(gen_selections) > 0: 166 | if args.data_path is None: 167 | print(f'\nYou need to specify --data_path to process the selections ' 168 | f'{gen_selections}!') 169 | sys.exit(0) 170 | if len(train_selections) > 0: 171 | if args.train_data_path is None: 172 | print(f'\nYou need to specify --train_data_path to process the selections ' 173 | f'{train_selections}!') 174 | sys.exit(0) 175 | 176 | # check if statistics files are needed 177 | need_gen_stats = (len(gen_selections) > 0) or args.train or args.test or args.novel 178 | need_train_stats = (len(train_selections) > 0) or args.train or args.test 179 | 180 | # check if there is a database with generated molecules at the provided path 181 | # and load accompanying statistics file 182 | if args.data_path is not None: 183 | if not os.path.isfile(args.data_path): 184 | print(f'\nThe specified data path ({args.data_path}) is not a file! Please ' 185 | f'specify a different data path.') 186 | raise FileNotFoundError 187 | elif need_gen_stats: 188 | stats_path = os.path.splitext(args.data_path)[0] + f'_statistics.npz' 189 | if not os.path.isfile(stats_path): 190 | print(f'\nCannot find statistics file belonging to {args.data_path} (' 191 | f'expected it at {stats_path}. 
Please make sure that the file ' 192 | f'exists.') 193 | raise FileNotFoundError 194 | else: 195 | stats_dict = np.load(stats_path) 196 | index_provider = IndexProvider(stats_dict['stats'], 197 | stats_dict['stat_heads']) 198 | 199 | # check if there is a database with training molecules at the provided path 200 | # and load accompanying statistics file 201 | if args.train_data_path is not None: 202 | if not os.path.isfile(args.train_data_path): 203 | print(f'\nThe specified training data path ({args.train_data_path}) is ' 204 | f'not a file! Please specify --train_data_path correctly.') 205 | raise FileNotFoundError 206 | elif need_train_stats: 207 | stats_path = os.path.splitext(args.train_data_path)[0] + f'_statistics.npz' 208 | if not os.path.isfile(stats_path) and len(train_selections) > 0: 209 | print(f'\nCannot find statistics file belonging to ' 210 | f'{args.train_data_path} (expected it at {stats_path}. Please ' 211 | f'make sure that the file exists.') 212 | raise FileNotFoundError 213 | else: 214 | train_stats_dict = np.load(stats_path) 215 | train_index_provider = IndexProvider(train_stats_dict['stats'], 216 | train_stats_dict['stat_heads']) 217 | 218 | # create folder(s) for export of indices if necessary 219 | if args.export_to_dir is not None: 220 | if not os.path.isdir(args.export_to_dir): 221 | print(f'\nDirectory {args.export_to_dir} does not exist, creating ' 222 | f'it to store indices of molecules matching the queries!') 223 | os.makedirs(args.export_to_dir) 224 | else: 225 | print(f'\nWill store indices of molecules matching the queries at ' 226 | f'{args.export_to_dir}!') 227 | 228 | # display all generated molecules if desired 229 | if (len(gen_selections) == 0) and not (args.train or args.test or args.novel) and\ 230 | args.data_path is not None: 231 | with connect(args.data_path) as con: 232 | _ats = [con.get(int(idx) + 1).toatoms() for idx in range(con.count())] 233 | view_ase(_ats, 'all generated molecules', args.block) 234 | 235 | # display generated molecules matching selection strings 236 | if len(gen_selections) > 0: 237 | for selection in gen_selections: 238 | # display queries concerning generated molecules 239 | idcs = index_provider.get_selected(selection) 240 | if len(idcs) == 0: 241 | print(f'\nNo molecules match selection {selection}!') 242 | continue 243 | with connect(args.data_path) as con: 244 | _ats = [con.get(int(idx) + 1).toatoms() for idx in idcs] 245 | if args.print_indices: 246 | print_indices(idcs, selection) 247 | view_ase(_ats, f'generated molecules ({selection})', args.block) 248 | if args.export_to_dir is not None: 249 | np.save(os.path.join(args.export_to_dir, selection), idcs) 250 | 251 | # display all training molecules if desired 252 | if (len(train_selections) == 0) and not (args.train or args.test) and \ 253 | args.train_data_path is not None: 254 | with connect(args.train_data_path) as con: 255 | _ats = [con.get(int(idx) + 1).toatoms() for idx in range(con.count())] 256 | view_ase(_ats, 'all molecules in the training data set', args.block) 257 | 258 | # display training molecules matching selection strings 259 | if len(train_selections) > 0: 260 | # display training molecules that match the selection strings 261 | for selection in train_selections: 262 | _selection = selection.split()[1] 263 | stats_queries = [] 264 | db_queries = [] 265 | # sort into queries handled by looking into the statistics or the db 266 | for _sel_str in _selection.split('&'): 267 | prop = _sel_str.split(',')[0] 268 | if prop in 
train_stats_dict['stat_heads']: 269 | stats_queries += [_sel_str] 270 | elif len(prop.split('+')) > 0: 271 | found = True 272 | for p in prop.split('+'): 273 | if p not in train_stats_dict['stat_heads']: 274 | found = False 275 | break 276 | if found: 277 | stats_queries += [_sel_str] 278 | else: 279 | db_queries += [_sel_str] 280 | else: 281 | db_queries += [_sel_str] 282 | # process queries concerning the statistics 283 | if len(stats_queries) > 0: 284 | idcs = train_index_provider.get_selected('&'.join(stats_queries)) 285 | else: 286 | idcs = range(connect(args.train_data_path).count()) 287 | # process queries concerning the db entries 288 | if len(db_queries) > 0: 289 | with connect(args.train_data_path) as con: 290 | for query in db_queries: 291 | head, condition = query.split(',') 292 | if head not in con.get(1).data: 293 | print(f'Entry {head} not found for molecules in the ' 294 | f'database, skipping query {query}.') 295 | continue 296 | else: 297 | op = train_index_provider.rel_re.search(condition).group(0) 298 | op = train_index_provider.op_dict[op] # extract operator 299 | num = float(train_index_provider.num_re.search( 300 | condition).group(0)) # extract numerical value 301 | remaining_idcs = [] 302 | for idx in idcs: 303 | if op(con.get(int(idx)+1).data[head], num): 304 | remaining_idcs += [idx] 305 | idcs = remaining_idcs 306 | # extract molecules matching the query from db and display them 307 | if len(idcs) == 0: 308 | print(f'\nNo training molecules match selection {_selection}!') 309 | continue 310 | with connect(args.train_data_path) as con: 311 | _ats = [con.get(int(idx)+1).toatoms() for idx in idcs] 312 | if args.print_indices: 313 | print_indices(idcs, selection) 314 | view_ase(_ats, f'training data set molecules ({_selection})', args.block) 315 | if args.export_to_dir is not None: 316 | np.save(os.path.join(args.export_to_dir, selection), idcs) 317 | 318 | # display generated molecules that match structures used for training 319 | if args.train: 320 | idcs = index_provider.get_selected('known,>=1&known,<=2') 321 | if len(idcs) == 0: 322 | print(f'\nNo generated molecules found that match structures used ' 323 | f'during training!') 324 | else: 325 | with connect(args.data_path) as con: 326 | _ats = [con.get(int(idx) + 1).toatoms() for idx in idcs] 327 | if args.print_indices: 328 | print_indices(idcs, 'generated train') 329 | view_ase(_ats, f'generated molecules (matching train structures)', 330 | args.block) 331 | if args.export_to_dir is not None: 332 | np.save(os.path.join(args.export_to_dir, 'generated train'), idcs) 333 | # display corresponding training structures 334 | if args.train_data_path is not None: 335 | _row_idx = list(stats_dict['stat_heads']).index('equals') 336 | t_idcs = stats_dict['stats'][_row_idx, idcs].astype(int) 337 | with connect(args.train_data_path) as con: 338 | _ats = [con.get(int(idx) + 1).toatoms() for idx in t_idcs] 339 | if args.print_indices: 340 | print_indices(t_idcs, 'reference train') 341 | view_ase(_ats, f'training molecules (train structures)', args.block) 342 | if args.export_to_dir is not None: 343 | np.save(os.path.join(args.export_to_dir, 'reference train'), t_idcs) 344 | 345 | # display generated molecules that match held out test structures 346 | if args.test: 347 | idcs = index_provider.get_selected('known,==3') 348 | if len(idcs) == 0: 349 | print(f'\nNo generated molecules found that match held out test ' 350 | f'structures!') 351 | else: 352 | with connect(args.data_path) as con: 353 | _ats = [con.get(int(idx) + 
1).toatoms() for idx in idcs] 354 | if args.print_indices: 355 | print_indices(idcs, 'generated test') 356 | view_ase(_ats, f'generated molecules (matching test structures)', 357 | args.block) 358 | if args.export_to_dir is not None: 359 | np.save(os.path.join(args.export_to_dir, 'generated test'), idcs) 360 | # display corresponding training structures 361 | if args.train_data_path is not None: 362 | _row_idx = list(stats_dict['stat_heads']).index('equals') 363 | t_idcs = stats_dict['stats'][_row_idx, idcs].astype(int) 364 | with connect(args.train_data_path) as con: 365 | _ats = [con.get(int(idx) + 1).toatoms() for idx in t_idcs] 366 | if args.print_indices: 367 | print_indices(t_idcs, 'reference test') 368 | view_ase(_ats, f'training molecules (test structures)', args.block) 369 | if args.export_to_dir is not None: 370 | np.save(os.path.join(args.export_to_dir, 'reference test'), t_idcs) 371 | 372 | # display generated molecules that are novel (i.e. that do not match held out 373 | # test structures or structures used during training) 374 | if args.novel: 375 | idcs = index_provider.get_selected('known,==0') 376 | if len(idcs) == 0: 377 | print(f'\nNo novel molecules found!') 378 | else: 379 | with connect(args.data_path) as con: 380 | _ats = [con.get(int(idx) + 1).toatoms() for idx in idcs] 381 | if args.print_indices: 382 | print_indices(idcs, 'novel') 383 | view_ase(_ats, f'generated molecules (novel)', args.block) 384 | if args.export_to_dir is not None: 385 | np.save(os.path.join(args.export_to_dir, 'generated novel'), idcs) 386 | -------------------------------------------------------------------------------- /gschnet_cond_script.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import os 4 | import pickle 5 | import time 6 | import json 7 | from shutil import copyfile, rmtree 8 | 9 | import numpy as np 10 | import torch 11 | import torch.nn as nn 12 | from torch.optim import Adam 13 | from torch.utils.data.sampler import RandomSampler 14 | from ase import Atoms 15 | from ase.db import connect 16 | import ase.visualize as asv 17 | 18 | import schnetpack as spk 19 | from schnetpack.utils import count_params, to_json, read_from_json 20 | from schnetpack import Properties 21 | from schnetpack.datasets import DownloadableAtomsData 22 | 23 | from nn_classes import AtomwiseWithProcessing, EmbeddingMultiplication,\ 24 | RepresentationConditioning, NormalizeAndAggregate, KLDivergence,\ 25 | FingerprintEmbedding, AtomCompositionEmbedding, PropertyEmbedding 26 | from utility_functions import boolean_string, collate_atoms, generate_molecules, \ 27 | update_dict, get_dict_count, get_composition 28 | 29 | # add your own dataset classes here: 30 | from qm9_data import QM9gen 31 | dataset_name_to_class_mapping = {'qm9': QM9gen} 32 | 33 | logging.basicConfig(level=os.environ.get("LOGLEVEL", "INFO")) 34 | 35 | 36 | def get_parser(): 37 | """ Setup parser for command line arguments """ 38 | main_parser = argparse.ArgumentParser() 39 | 40 | ## command-specific 41 | cmd_parser = argparse.ArgumentParser(add_help=False) 42 | cmd_parser.add_argument('--cuda', help='Set flag to use GPU(s)', 43 | action='store_true') 44 | cmd_parser.add_argument('--parallel', 45 | help='Run data-parallel on all available GPUs ' 46 | '(specify with environment variable' 47 | + ' CUDA_VISIBLE_DEVICES)', 48 | action='store_true') 49 | cmd_parser.add_argument('--batch_size', type=int, 50 | help='Mini-batch size for training and prediction ' 51 | 
'(default: %(default)s)', 52 | default=5) 53 | cmd_parser.add_argument('--draw_random_samples', type=int, default=0, 54 | help='Only draw x generation steps per molecule ' 55 | 'in each batch (if x=0, all generation ' 56 | 'steps are included for each molecule,' 57 | 'default: %(default)s)') 58 | cmd_parser.add_argument('--checkpoint', type=int, default=-1, 59 | help='The checkpoint of the model that is going ' 60 | 'to be loaded for evaluation or generation ' 61 | '(set to -1 to load the best model ' 62 | 'according to validation error, ' 63 | 'default: %(default)s)') 64 | cmd_parser.add_argument('--precompute_distances', type=boolean_string, 65 | default='true', 66 | help='Store precomputed distances in the database ' 67 | 'during pre-processing (caution, has no effect if ' 68 | 'the dataset has already been downloaded, ' 69 | 'pre-processed, and stored before, ' 70 | 'default: %(default)s)') 71 | 72 | ## training 73 | train_parser = argparse.ArgumentParser(add_help=False, 74 | parents=[cmd_parser]) 75 | train_parser.add_argument('datapath', 76 | help='Path / destination of dataset '\ 77 | 'directory') 78 | train_parser.add_argument('modelpath', 79 | help='Destination for models and logs') 80 | train_parser.add_argument('--dataset_name', type=str, default='qm9', 81 | help=f'Name of the dataset used (choose from ' 82 | f'{list(dataset_name_to_class_mapping.keys())}, ' 83 | f'default: %(default)s)'), 84 | train_parser.add_argument('--subset_path', type=str, 85 | help='A path to a npy file containing indices ' 86 | 'of a subset of the data set at datapath ' 87 | '(default: %(default)s)', 88 | default=None) 89 | train_parser.add_argument('--seed', type=int, default=None, 90 | help='Set random seed for torch and numpy.') 91 | train_parser.add_argument('--overwrite', 92 | help='Remove previous model directory.', 93 | action='store_true') 94 | train_parser.add_argument('--pretrained_path', 95 | help='Start training from the pre-trained model at the ' 96 | 'provided path (reset optimizer parameters such as ' 97 | 'best loss and learning rate and create new split)', 98 | default=None) 99 | train_parser.add_argument('--split_path', 100 | help='Path/destination of npz with data splits', 101 | default=None) 102 | train_parser.add_argument('--split', 103 | help='Split into [train] [validation] and use ' 104 | 'remaining for testing', 105 | type=int, nargs=2, default=[None, None]) 106 | train_parser.add_argument('--max_epochs', type=int, 107 | help='Maximum number of training epochs ' 108 | '(default: %(default)s)', 109 | default=500) 110 | train_parser.add_argument('--lr', type=float, 111 | help='Initial learning rate ' 112 | '(default: %(default)s)', 113 | default=1e-4) 114 | train_parser.add_argument('--lr_patience', type=int, 115 | help='Epochs without improvement before reducing' 116 | ' the learning rate (default: %(default)s)', 117 | default=10) 118 | train_parser.add_argument('--lr_decay', type=float, 119 | help='Learning rate decay ' 120 | '(default: %(default)s)', 121 | default=0.5) 122 | train_parser.add_argument('--lr_min', type=float, 123 | help='Minimal learning rate ' 124 | '(default: %(default)s)', 125 | default=1e-6) 126 | train_parser.add_argument('--logger', 127 | help='Choose logger for training process ' 128 | '(default: %(default)s)', 129 | choices=['csv', 'tensorboard'], 130 | default='tensorboard') 131 | train_parser.add_argument('--log_every_n_epochs', type=int, 132 | help='Log metrics every given number of epochs ' 133 | '(default: %(default)s)', 134 | default=1) 135 | 
train_parser.add_argument('--checkpoint_every_n_epochs', type=int, 136 | help='Create checkpoint every given number of ' 137 | 'epochs' 138 | '(default: %(default)s)', 139 | default=25) 140 | train_parser.add_argument('--label_width_factor', type=float, 141 | help='A factor that is multiplied with the ' 142 | 'range between two distance bins in order ' 143 | 'to determine the width of the Gaussians ' 144 | 'used to obtain labels from distances ' 145 | '(set to 0. to use one-hot ' 146 | 'encodings of distances as labels, ' 147 | 'default: %(default)s)', 148 | default=0.1) 149 | train_parser.add_argument('--conditioning_json_path', type=str, 150 | help='Path to .json-file with specification of layers ' 151 | 'used for conditioning of the model, default: ' 152 | '%(default)s)', 153 | default=None) 154 | train_parser.add_argument('--use_embeddings_for_type_predictions', 155 | help='Copy extracted features and multiply them with ' 156 | 'embeddings of all possible types to obtain scores.', 157 | action='store_true') 158 | train_parser.add_argument('--share_embeddings', 159 | help='Share embedding layers in SchNet part and in ' 160 | 'pre-processing before predicting distances and ' 161 | 'types.', 162 | action='store_true') 163 | 164 | ## evaluation 165 | eval_parser = argparse.ArgumentParser(add_help=False, parents=[cmd_parser]) 166 | eval_parser.add_argument('datapath', help='Path of dataset directory') 167 | eval_parser.add_argument('modelpath', help='Path of stored model') 168 | eval_parser.add_argument('--split', 169 | help='Evaluate trained model on given split', 170 | choices=['train', 'validation', 'test'], 171 | default=['test'], nargs='+') 172 | 173 | ## molecule generation 174 | gen_parser = argparse.ArgumentParser(add_help=False, parents=[cmd_parser]) 175 | gen_parser.add_argument('modelpath', help='Path of stored model') 176 | gen_parser.add_argument('amount_gen', type=int, 177 | help='The amount of generated molecules') 178 | gen_parser.add_argument('--show_gen', 179 | help='Whether to open plots of generated ' 180 | 'molecules for visual evaluation', 181 | action='store_true') 182 | gen_parser.add_argument('--chunk_size', type=int, 183 | help='The size of mini batches during generation ' 184 | '(default: %(default)s)', 185 | default=1000) 186 | gen_parser.add_argument('--max_length', type=int, 187 | help='The maximum number of atoms per molecule ' 188 | '(default: %(default)s)', 189 | default=35) 190 | gen_parser.add_argument('--folder_name', type=str, 191 | help='The name of the folder in which generated ' 192 | 'molecules are stored (please note that the folder ' 193 | 'is always inside the model directory and always ' 194 | 'called "generated", but custom extensions may be ' 195 | 'provided here, e.g. 
"--folder_name _10" will place ' 196 | 'the generated molecules in a folder called ' 197 | '"generated_10", default: %(default)s)', 198 | default='') 199 | gen_parser.add_argument('--file_name', type=str, 200 | help='The name of the file in which generated ' 201 | 'molecules are stored (please note that ' 202 | 'increasing numbers are appended to the file name ' 203 | 'if it already exists and that the extension ' 204 | '.mol_dict is automatically added to the chosen ' 205 | 'file name, default: %(default)s)', 206 | default='generated') 207 | gen_parser.add_argument('--store_unfinished', 208 | help='Store molecules which have not been ' 209 | 'finished after sampling max_length atoms', 210 | action='store_true') 211 | gen_parser.add_argument('--store_process', 212 | help='Store information needed to track the generation ' 213 | 'process (i.e. current focus, predicted distributions,' 214 | ' sampled type etc.) in the .mol_dict file', 215 | action='store_true') 216 | gen_parser.add_argument('--print_file', 217 | help='Use to limit the printing if results are ' 218 | 'written to a file instead of the console (' 219 | 'e.g. if running on a cluster)', 220 | action='store_true') 221 | gen_parser.add_argument('--temperature', type=float, 222 | help='The temperature T to use for sampling ' 223 | '(default: %(default)s)', 224 | default=0.1) 225 | gen_parser.add_argument('--conditioning', type=str, default=None, 226 | help='Additional input for conditioning of molecule ' 227 | 'generation. Write "property1 value1; property2 ' 228 | 'value2; ..." to specify the additional information ' 229 | '(where multiple values can be provided per property, ' 230 | 'e.g. the atomic composition, default: None)') 231 | 232 | # model-specific parsers 233 | model_parser = argparse.ArgumentParser(add_help=False) 234 | model_parser.add_argument('--aggregation_mode', type=str, default='sum', 235 | choices=['sum', 'avg'], 236 | help=' (default: %(default)s)') 237 | 238 | ####### G-SchNet ####### 239 | gschnet_parser = argparse.ArgumentParser(add_help=False, 240 | parents=[model_parser]) 241 | gschnet_parser.add_argument('--features', type=int, 242 | help='Size of atom-wise representation ' 243 | '(default: %(default)s)', 244 | default=128) 245 | gschnet_parser.add_argument('--interactions', type=int, 246 | help='Number of regular SchNet interaction ' 247 | 'blocks (default: %(default)s)', 248 | default=9) 249 | gschnet_parser.add_argument('--dense_layers', type=int, 250 | help='Number of layers in the (atom-wise) dense ' 251 | 'output networks to predict next type and ' 252 | 'distances (default: %(default)s)', 253 | default=5) 254 | gschnet_parser.add_argument('--cutoff', type=float, default=10., 255 | help='Cutoff radius of local environment ' 256 | '(default: %(default)s)') 257 | gschnet_parser.add_argument('--num_gaussians', type=int, default=25, 258 | help='Number of Gaussians to expand distances ' 259 | '(default: %(default)s)') 260 | gschnet_parser.add_argument('--max_distance', type=float, default=15., 261 | help='Maximum distance covered by the discrete ' 262 | 'distributions over distances learned by ' 263 | 'the model ' 264 | '(default: %(default)s)') 265 | gschnet_parser.add_argument('--num_distance_bins', type=int, default=300, 266 | help='Number of bins used in the discrete ' 267 | 'distributions over distances learned by ' 268 | 'the model(default: %(default)s)') 269 | 270 | ## setup subparser structure 271 | cmd_subparsers = main_parser.add_subparsers(dest='mode', 272 | help='Command-specific ' 273 | 
'arguments') 274 | cmd_subparsers.required = True 275 | subparser_train = cmd_subparsers.add_parser('train', help='Training help') 276 | subparser_eval = cmd_subparsers.add_parser('eval', help='Eval help') 277 | subparser_gen = cmd_subparsers.add_parser('generate', help='Generate help') 278 | 279 | train_subparsers = subparser_train.add_subparsers(dest='model', 280 | help='Model-specific ' 281 | 'arguments') 282 | train_subparsers.required = True 283 | train_subparsers.add_parser('gschnet', help='G-SchNet help', 284 | parents=[train_parser, gschnet_parser]) 285 | 286 | eval_subparsers = subparser_eval.add_subparsers(dest='model', 287 | help='Model-specific ' 288 | 'arguments') 289 | eval_subparsers.required = True 290 | eval_subparsers.add_parser('gschnet', help='G-SchNet help', 291 | parents=[eval_parser, gschnet_parser]) 292 | 293 | gen_subparsers = subparser_gen.add_subparsers(dest='model', 294 | help='Model-specific ' 295 | 'arguments') 296 | gen_subparsers.required = True 297 | gen_subparsers.add_parser('gschnet', help='G-SchNet help', 298 | parents=[gen_parser, gschnet_parser]) 299 | 300 | return main_parser 301 | 302 | 303 | def get_model(args, conditioning_specification, parallelize=False): 304 | # get information about the atom types available in the data set 305 | dataclass = dataset_name_to_class_mapping[args.dataset_name] 306 | num_types = len(dataclass.available_atom_types) 307 | max_type = max(dataclass.available_atom_types) 308 | 309 | # get SchNet layers for feature extraction 310 | representation =\ 311 | spk.representation.SchNet(n_atom_basis=args.features, 312 | n_filters=args.features, 313 | n_interactions=args.interactions, 314 | cutoff=args.cutoff, 315 | n_gaussians=args.num_gaussians, 316 | max_z=max_type+3) 317 | if args.share_embeddings: 318 | emb_layers = representation.embedding 319 | else: 320 | emb_layers = nn.Embedding(max_type+3, args.features, padding_idx=0) 321 | 322 | # build layers for conditioning according to conditioning specification dictionary 323 | representation_conditioning_blocks = [] 324 | n_features = args.features 325 | if conditioning_specification is not None: 326 | for block in conditioning_specification: 327 | layers = conditioning_specification[block]['layers'] 328 | layers_list = [] 329 | for nn_name in layers: 330 | if 'p' in layers[nn_name]: 331 | layers[nn_name].pop('p') 332 | layer_class = get_conditioning_nn(nn_name, dataclass) 333 | layer_arguments = layers[nn_name] 334 | layers_list += [layer_class(**layer_arguments)] 335 | conditioning_specification[block]['layers'] = layers_list 336 | representation_conditioning_blocks += \ 337 | [RepresentationConditioning(**conditioning_specification[block])] 338 | n_features += representation_conditioning_blocks[-1].n_additional_features 339 | 340 | # get output layers for prediction of next atom type 341 | if args.use_embeddings_for_type_predictions: 342 | preprocess_type = \ 343 | EmbeddingMultiplication(emb_layers, 344 | in_key_types='_all_types', 345 | in_key_representation='representation', 346 | out_key='preprocessed_representation') 347 | _n_out = 1 348 | else: 349 | preprocess_type = None 350 | _n_out = num_types + 1 # number of possible types + stop token 351 | postprocess_type = NormalizeAndAggregate(normalize=True, 352 | normalization_axis=-1, 353 | normalization_mode='logsoftmax', 354 | aggregate=True, 355 | aggregation_axis=-2, 356 | aggregation_mode='sum', 357 | keepdim=False, 358 | mask='_type_mask', 359 | squeeze=True) 360 | out_module_type = \ 361 | 
AtomwiseWithProcessing(n_in=n_features, 362 | n_out=_n_out, 363 | n_layers=args.dense_layers, 364 | preprocess_layers=preprocess_type, 365 | postprocess_layers=postprocess_type, 366 | out_key='type_predictions') 367 | 368 | # get output layers for predictions of distances 369 | preprocess_dist = \ 370 | EmbeddingMultiplication(emb_layers, 371 | in_key_types='_next_types', 372 | in_key_representation='representation', 373 | out_key='preprocessed_representation') 374 | out_module_dist = \ 375 | AtomwiseWithProcessing(n_in=n_features, 376 | n_out=args.num_distance_bins, 377 | n_layers=args.dense_layers, 378 | preprocess_layers=preprocess_dist, 379 | out_key='distance_predictions') 380 | 381 | # combine layers into an atomistic model 382 | model = spk.atomistic.AtomisticModel(representation, 383 | representation_conditioning_blocks + 384 | [out_module_type, out_module_dist]) 385 | 386 | if parallelize: 387 | model = nn.DataParallel(model) 388 | 389 | logging.info("The model you built has: %d parameters" % 390 | count_params(model)) 391 | 392 | return model 393 | 394 | 395 | def train(args, model, train_loader, val_loader, device): 396 | 397 | # setup hooks and logging 398 | hooks = [ 399 | spk.hooks.MaxEpochHook(args.max_epochs) 400 | ] 401 | 402 | # filter for trainable parameters 403 | trainable_params = filter(lambda p: p.requires_grad, model.parameters()) 404 | # setup optimizer 405 | optimizer = Adam(trainable_params, lr=args.lr) 406 | schedule = spk.hooks.ReduceLROnPlateauHook(optimizer, 407 | patience=args.lr_patience, 408 | factor=args.lr_decay, 409 | min_lr=args.lr_min, 410 | window_length=1, 411 | stop_after_min=True) 412 | hooks.append(schedule) 413 | 414 | # set up metrics to log KL divergence on distributions of types and distances 415 | metrics = [KLDivergence(target='_type_labels', 416 | model_output='type_predictions', 417 | name='KLD_types'), 418 | KLDivergence(target='_labels', 419 | model_output='distance_predictions', 420 | mask='_dist_mask', 421 | name='KLD_dists')] 422 | 423 | if args.logger == 'csv': 424 | logger =\ 425 | spk.hooks.CSVHook(os.path.join(args.modelpath, 'log'), 426 | metrics, 427 | every_n_epochs=args.log_every_n_epochs) 428 | hooks.append(logger) 429 | elif args.logger == 'tensorboard': 430 | logger =\ 431 | spk.hooks.TensorboardHook(os.path.join(args.modelpath, 'log'), 432 | metrics, 433 | every_n_epochs=args.log_every_n_epochs) 434 | hooks.append(logger) 435 | 436 | norm_layer = nn.LogSoftmax(-1).to(device) 437 | loss_layer = nn.KLDivLoss(reduction='none').to(device) 438 | 439 | # setup loss function 440 | def loss(batch, result): 441 | # loss for type predictions (KLD) 442 | out_type = norm_layer(result['type_predictions']) 443 | loss_type = loss_layer(out_type, batch['_type_labels']) 444 | loss_type = torch.sum(loss_type, -1) 445 | loss_type = torch.mean(loss_type) 446 | 447 | # loss for distance predictions (KLD) 448 | mask_dist = batch['_dist_mask'] 449 | N = torch.sum(mask_dist) 450 | out_dist = norm_layer(result['distance_predictions']) 451 | loss_dist = loss_layer(out_dist, batch['_labels']) 452 | loss_dist = torch.sum(loss_dist, -1) 453 | loss_dist = torch.sum(loss_dist * mask_dist) / torch.max(N, torch.ones_like(N)) 454 | 455 | return loss_type + loss_dist 456 | 457 | # initialize trainer 458 | trainer = spk.train.Trainer(args.modelpath, 459 | model, 460 | loss, 461 | optimizer, 462 | train_loader, 463 | val_loader, 464 | hooks=hooks, 465 | checkpoint_interval=args.checkpoint_every_n_epochs, 466 | keep_n_checkpoints=3) 467 | 468 | # reset 
optimizer and hooks if starting from pre-trained model (e.g. for 469 | # fine-tuning) 470 | if args.pretrained_path is not None: 471 | logging.info('starting from pre-trained model...') 472 | # reset epoch and step 473 | trainer.epoch = 0 474 | trainer.step = 0 475 | trainer.best_loss = float('inf') 476 | # reset optimizer 477 | trainable_params = filter(lambda p: p.requires_grad, model.parameters()) 478 | optimizer = Adam(trainable_params, lr=args.lr) 479 | trainer.optimizer = optimizer 480 | # reset scheduler 481 | schedule =\ 482 | spk.hooks.ReduceLROnPlateauHook(optimizer, 483 | patience=args.lr_patience, 484 | factor=args.lr_decay, 485 | min_lr=args.lr_min, 486 | window_length=1, 487 | stop_after_min=True) 488 | trainer.hooks[1] = schedule 489 | # remove checkpoints of pre-trained model 490 | rmtree(os.path.join(args.modelpath, 'checkpoints')) 491 | os.makedirs(os.path.join(args.modelpath, 'checkpoints')) 492 | # store first checkpoint 493 | trainer.store_checkpoint() 494 | 495 | # start training 496 | trainer.train(device) 497 | 498 | 499 | def evaluate(args, model, train_loader, val_loader, test_loader, device): 500 | header = ['Subset', 'distances KLD', 'types KLD'] 501 | 502 | metrics = [KLDivergence(target='_labels', 503 | model_output='distance_predictions', 504 | mask='_dist_mask'), 505 | KLDivergence(target='_type_labels', 506 | model_output='type_predictions')] 507 | 508 | results = [] 509 | if 'train' in args.split: 510 | results.append(['training'] + 511 | ['%.5f' % i for i in 512 | evaluate_dataset(metrics, model, 513 | train_loader, device)]) 514 | 515 | if 'validation' in args.split: 516 | results.append(['validation'] + 517 | ['%.5f' % i for i in 518 | evaluate_dataset(metrics, model, 519 | val_loader, device)]) 520 | 521 | if 'test' in args.split: 522 | results.append(['test'] + ['%.5f' % i for i in evaluate_dataset( 523 | metrics, model, test_loader, device)]) 524 | 525 | header = ','.join(header) 526 | results = np.array(results) 527 | 528 | np.savetxt(os.path.join(args.modelpath, 'evaluation.csv'), results, 529 | header=header, fmt='%s', delimiter=',') 530 | 531 | 532 | def evaluate_dataset(metrics, model, loader, device): 533 | for metric in metrics: 534 | metric.reset() 535 | 536 | for batch in loader: 537 | batch = { 538 | k: v.to(device) 539 | for k, v in batch.items() 540 | } 541 | result = model(batch) 542 | 543 | for metric in metrics: 544 | metric.add_batch(batch, result) 545 | 546 | results = [ 547 | metric.aggregate() for metric in metrics 548 | ] 549 | return results 550 | 551 | 552 | def generate(args, train_args, model, device, conditioning_layer_list): 553 | # generate molecules (in chunks) and print progress 554 | 555 | dataclass = dataset_name_to_class_mapping[train_args.dataset_name] 556 | types = sorted(dataclass.available_atom_types) # retrieve available atom types 557 | all_types = types + [types[-1] + 1] # add stop token to list (largest type + 1) 558 | start_token = types[-1] + 2 # define start token (largest type + 2) 559 | amount = args.amount_gen 560 | chunk_size = args.chunk_size 561 | if chunk_size >= amount: 562 | chunk_size = amount 563 | 564 | # set parameters for printing progress 565 | if int(amount / 10.) < chunk_size: 566 | step = chunk_size 567 | else: 568 | step = int(amount / 10.) 
569 | increase = lambda x, y: y + step if x >= y else y 570 | thresh = step 571 | if args.print_file: 572 | progress = lambda x, y: print(f'Generated {x}.', flush=True) \ 573 | if x >= y else print('', end='', flush=True) 574 | else: 575 | progress = lambda x, y: print(f'\x1b[2K\rSuccessfully generated' 576 | f' {x}', end='', flush=True) 577 | 578 | # extract conditioning information 579 | conditioning = {} 580 | if args.conditioning is not None: 581 | conds = args.conditioning.split('; ') 582 | for cond_list in conds: 583 | cond_list = cond_list.split(' ') 584 | cond_name = cond_list[0] 585 | if cond_name not in conditioning_layer_list: 586 | logging.info(f'The provided model was not trained to condition on ' 587 | f'{cond_name}! The condition will be ignored during ' 588 | f'generation.') 589 | continue 590 | cond_vals = cond_list[1:] 591 | if cond_name == 'fingerprint': 592 | if not os.path.isfile('./data/qm9gen.db'): 593 | logging.error(f'could not find database with fingerprints at ./data/qm9gen.db!') 594 | logging.error(f'stopping generation!') 595 | return 596 | with connect('./data/qm9gen.db') as conn: 597 | if cond_vals[0].isdigit(): 598 | if 'fingerprint_format' not in conn.metadata: 599 | logging.error(f'fingerprints not found in database!') 600 | logging.error(f'please re-download data when training a model with fingerprints as condions!') 601 | logging.error(f'stopping generation!') 602 | return 603 | fp = np.array(conn.get(int(cond_vals[0])+1).data['fingerprint'], 604 | dtype=conn.metadata['fingerprint_format']) 605 | else: 606 | import pybel 607 | fp = \ 608 | np.array(pybel.readstring('smi', cond_vals[0]).calcfp().fp, 609 | dtype=conn.metadata['fingerprint_format']) 610 | from collections import Counter 611 | _ct = Counter(cond_vals[0].lower()) 612 | conditioning['_composition'] = \ 613 | np.array([_ct['c'], _ct['n'], _ct['o'], _ct['f']]) 614 | conditioning['_fingerprint'] = \ 615 | np.unpackbits(fp.view(np.uint8), bitorder='little')[None, ...] 
616 | else: 617 | if not cond_vals[0].isdigit(): 618 | conditioning['_' + cond_name] = np.array([cond_vals], dtype=float) 619 | else: 620 | conditioning['_' + cond_name] = np.array([cond_vals], dtype=int) 621 | cond_mask = np.ones((1, len(conditioning_layer_list))) 622 | for i, condition_layer in enumerate(conditioning_layer_list): 623 | # remove not provided conditional information from mask and provide dummy input 624 | if f'_{condition_layer}' not in conditioning: 625 | conditioning[f'_{condition_layer}'] = np.array([[0]], dtype=float) 626 | cond_mask[0, i] = 0 627 | conditioning[f'_cond_mask'] = cond_mask 628 | 629 | # generate 630 | generated = {} 631 | left = args.amount_gen 632 | done = 0 633 | start_time = time.time() 634 | with torch.no_grad(): 635 | while left > 0: 636 | if left - chunk_size < 0: 637 | batch = left 638 | else: 639 | batch = chunk_size 640 | update_dict(generated, 641 | generate_molecules( 642 | batch, 643 | model, 644 | all_types=all_types, 645 | start_token=start_token, 646 | max_length=args.max_length, 647 | save_unfinished=args.store_unfinished, 648 | device=device, 649 | max_dist=train_args.max_distance, 650 | n_bins=train_args.num_distance_bins, 651 | radial_limits=dataclass.radial_limits, 652 | t=args.temperature, 653 | conditioning=conditioning, 654 | store_process=args.store_process 655 | ) 656 | ) 657 | left -= batch 658 | done += batch 659 | n = np.sum(get_dict_count(generated, args.max_length)) 660 | progress(n, thresh) 661 | thresh = increase(n, thresh) 662 | print('') 663 | end_time = time.time() - start_time 664 | m, s = divmod(end_time, 60) 665 | h, m = divmod(m, 60) 666 | h, m, s = int(h), int(m), int(s) 667 | print(f'Time consumed: {h:d}:{m:02d}:{s:02d}') 668 | 669 | # sort keys in resulting dictionary 670 | generated = dict(sorted(generated.items())) 671 | 672 | # show generated molecules and print some statistics if desired 673 | if args.show_gen: 674 | ats = [] 675 | n_total_atoms = 0 676 | n_molecules = 0 677 | for key in generated: 678 | n = 0 679 | for i in range(len(generated[key][Properties.Z])): 680 | at = Atoms(generated[key][Properties.Z][i], 681 | positions=generated[key][Properties.R][i]) 682 | ats += [at] 683 | n += 1 684 | n_molecules += 1 685 | n_total_atoms += n * key 686 | asv.view(ats) 687 | print(f'Total number of atoms placed: {n_total_atoms} ' 688 | f'(avg {n_total_atoms / n_molecules:.2f})', flush=True) 689 | 690 | return generated 691 | 692 | 693 | def prepare_conditioning(conditioning_specification): 694 | if conditioning_specification is None: 695 | return [], {}, [] 696 | load_additionally = [] 697 | layer_list = [] 698 | conditioning_extractors = {} 699 | for block in conditioning_specification: 700 | layers = conditioning_specification[block]['layers'] 701 | replace_dict_argument_recursively('args.features', args.features, layers) 702 | for layer_name in layers: 703 | layer_list += [layer_name] 704 | if 'p' in layers[layer_name]: 705 | p = layers[layer_name]['p'] 706 | else: 707 | p = 1.0 708 | if layer_name == 'composition': 709 | all_types = [6, 7, 8, 9] 710 | if 'skip_h' in layers['composition']: 711 | if not layers['composition']['skip_h']: 712 | all_types = [1] + all_types 713 | conditioning_extractors.update( 714 | {'_composition': lambda x, prob=p: 715 | (torch.FloatTensor(get_composition(x, all_types)), prob)} 716 | ) 717 | else: 718 | load_additionally += [layer_name] 719 | conditioning_extractors.update( 720 | {'_' + layer_name: 721 | lambda x, prob=p, n=layer_name: (x.pop(n), prob)}) 722 | 723 | return 
load_additionally, conditioning_extractors, layer_list 724 | 725 | 726 | def get_conditioning_nn(name, dataclass): 727 | layers = None 728 | if name == 'fingerprint': 729 | layers = FingerprintEmbedding 730 | elif name == 'composition': 731 | layers = lambda embedding, **kwargs: \ 732 | AtomCompositionEmbedding(nn.Embedding(**embedding), **kwargs) 733 | elif name in dataclass.properties: 734 | layers = PropertyEmbedding 735 | return layers 736 | 737 | 738 | def replace_dict_argument_recursively(argument, value, d): 739 | for key in d: 740 | if isinstance(d[key], dict): 741 | replace_dict_argument_recursively(argument, value, d[key]) 742 | elif d[key] == argument: 743 | d[key] = value 744 | 745 | 746 | def main(args): 747 | # set device (cpu or gpu) 748 | device = torch.device('cuda' if args.cuda else 'cpu') 749 | 750 | # store (or load) arguments 751 | argparse_dict = vars(args) 752 | jsonpath = os.path.join(args.modelpath, 'args.json') 753 | 754 | if args.mode == 'train': 755 | # overwrite existing model if desired 756 | if args.overwrite and os.path.exists(args.modelpath): 757 | rmtree(args.modelpath) 758 | logging.info('existing model will be overwritten...') 759 | 760 | # create model directory if it does not exist 761 | if not os.path.exists(args.modelpath): 762 | os.makedirs(args.modelpath) 763 | 764 | # get latest checkpoint of pre-trained model if a path was provided 765 | if args.pretrained_path is not None: 766 | model_chkpt_path = os.path.join(args.modelpath, 'checkpoints') 767 | pretrained_chkpt_path = os.path.join(args.pretrained_path, 'checkpoints') 768 | if os.path.exists(model_chkpt_path) \ 769 | and len(os.listdir(model_chkpt_path)) > 0: 770 | logging.info(f'found existing checkpoints in model directory ' 771 | f'({model_chkpt_path}), please use --overwrite or choose ' 772 | f'empty model directory to start from a pre-trained ' 773 | f'model...') 774 | logging.warning(f'will ignore pre-trained model and start from latest ' 775 | f'checkpoint at {model_chkpt_path}...') 776 | args.pretrained_path = None 777 | else: 778 | logging.info(f'fetching latest checkpoint from pre-trained model at ' 779 | f'{pretrained_chkpt_path}...') 780 | if not os.path.exists(pretrained_chkpt_path): 781 | logging.warning(f'did not find checkpoints of pre-trained model, ' 782 | f'will train from scratch...') 783 | args.pretrained_path = None 784 | else: 785 | chkpt_files = [f for f in os.listdir(pretrained_chkpt_path) 786 | if f.startswith("checkpoint")] 787 | if len(chkpt_files) == 0: 788 | logging.warning(f'did not find checkpoints of pre-trained ' 789 | f'model, will train from scratch...') 790 | args.pretrained_path = None 791 | else: 792 | epoch = max([int(f.split(".")[0].split("-")[-1]) 793 | for f in chkpt_files]) 794 | chkpt = os.path.join(pretrained_chkpt_path, 795 | "checkpoint-" + str(epoch) + ".pth.tar") 796 | if not os.path.exists(model_chkpt_path): 797 | os.makedirs(model_chkpt_path) 798 | copyfile(chkpt, os.path.join(model_chkpt_path, 799 | f'checkpoint-{epoch}.pth.tar')) 800 | 801 | # store arguments for training in model directory 802 | to_json(jsonpath, argparse_dict) 803 | train_args = args 804 | 805 | # set seed 806 | spk.utils.set_random_seed(args.seed) 807 | else: 808 | # load arguments used for training from model directory 809 | train_args = read_from_json(jsonpath) 810 | 811 | # read in conditioning layers from file and set arguments accordingly 812 | conditioning_specification = None 813 | conditioning_json_path = None 814 | conditioning_specification_path = \ 815 | 
os.path.join(args.modelpath, 'conditioning_specification.json') 816 | # look for existing specification in model directory 817 | if os.path.isfile(conditioning_specification_path): 818 | conditioning_json_path = conditioning_specification_path 819 | # else look for specification at path provided with the training arguments 820 | elif args.mode == 'train' and args.conditioning_json_path is not None: 821 | if os.path.isfile(args.conditioning_json_path): 822 | conditioning_json_path = args.conditioning_json_path 823 | else: 824 | logging.error(f'The provided conditioning specification file ' 825 | f'({args.conditioning_json_path}) does not exist!') 826 | raise FileNotFoundError 827 | # read file 828 | if conditioning_json_path is not None: 829 | with open(conditioning_json_path) as handle: 830 | conditioning_specification = json.loads(handle.read()) 831 | # process information 832 | load_additionally, conditioning_extractors, cond_layer_list = \ 833 | prepare_conditioning(conditioning_specification) 834 | # store specification of conditioning layers in model folder 835 | if conditioning_specification is not None and \ 836 | conditioning_json_path != conditioning_specification_path: 837 | to_json(conditioning_specification_path, conditioning_specification) 838 | 839 | # load data for training/evaluation 840 | if args.mode in ['train', 'eval']: 841 | # find correct data class 842 | assert train_args.dataset_name in dataset_name_to_class_mapping, \ 843 | f'Could not find data class for dataset {train_args.dataset}. Please ' \ 844 | f'specify a correct dataset name!' 845 | dataclass = dataset_name_to_class_mapping[train_args.dataset_name] 846 | 847 | # load the dataset 848 | logging.info(f'{train_args.dataset_name} will be loaded...') 849 | subset = None 850 | if train_args.subset_path is not None: 851 | logging.info(f'Using subset from {train_args.subset_path}') 852 | subset = np.load(train_args.subset_path) 853 | subset = [int(i) for i in subset] 854 | if issubclass(dataclass, DownloadableAtomsData): 855 | data = dataclass(args.datapath, 856 | subset=subset, 857 | precompute_distances=args.precompute_distances, 858 | download=True if args.mode == 'train' else False, 859 | load_additionally=load_additionally) 860 | else: 861 | data = dataclass(args.datapath, 862 | subset=subset, 863 | precompute_distances=args.precompute_distances, 864 | load_additionally=load_additionally) 865 | 866 | # splits the dataset in test, val, train sets 867 | split_path = os.path.join(args.modelpath, 'split.npz') 868 | if args.mode == 'train': 869 | if args.split_path is not None: 870 | copyfile(args.split_path, split_path) 871 | 872 | logging.info('create splits...') 873 | data_train, data_val, data_test = data.create_splits(*train_args.split, 874 | split_file=split_path) 875 | 876 | logging.info('load data...') 877 | types = sorted(dataclass.available_atom_types) 878 | max_type = types[-1] 879 | # set up collate function according to args 880 | collate = lambda x: \ 881 | collate_atoms(x, 882 | all_types=types + [max_type+1], 883 | start_token=max_type+2, 884 | draw_samples=args.draw_random_samples, 885 | label_width_scaling=train_args.label_width_factor, 886 | max_dist=train_args.max_distance, 887 | n_bins=train_args.num_distance_bins, 888 | conditioning_extractors=conditioning_extractors) 889 | 890 | train_loader = spk.data.AtomsLoader(data_train, batch_size=args.batch_size, 891 | sampler=RandomSampler(data_train), 892 | num_workers=4, pin_memory=True, 893 | collate_fn=collate) 894 | val_loader = 
spk.data.AtomsLoader(data_val, batch_size=args.batch_size, 895 | num_workers=2, pin_memory=True, 896 | collate_fn=collate) 897 | 898 | # construct the model 899 | if args.mode == 'train' or args.checkpoint >= 0: 900 | model = get_model(train_args, conditioning_specification, 901 | parallelize=args.parallel).to(device) 902 | logging.info(f'running on {device}') 903 | 904 | # load model or checkpoint for evaluation or generation 905 | if args.mode in ['eval', 'generate']: 906 | if args.checkpoint < 0: # load best model 907 | logging.info(f'restoring best model') 908 | model = torch.load(os.path.join(args.modelpath, 'best_model')).to(device) 909 | else: 910 | logging.info(f'restoring checkpoint {args.checkpoint}') 911 | chkpt = os.path.join(args.modelpath, 'checkpoints', 912 | 'checkpoint-' + str(args.checkpoint) + '.pth.tar') 913 | state_dict = torch.load(chkpt) 914 | model.load_state_dict(state_dict['model'], strict=True) 915 | 916 | # execute training, evaluation, or generation 917 | if args.mode == 'train': 918 | logging.info("training...") 919 | train(args, model, train_loader, val_loader, device) 920 | logging.info("...training done!") 921 | 922 | elif args.mode == 'eval': 923 | logging.info("evaluating...") 924 | test_loader = spk.data.AtomsLoader(data_test, 925 | batch_size=args.batch_size, 926 | num_workers=2, 927 | pin_memory=True, 928 | collate_fn=collate) 929 | with torch.no_grad(): 930 | evaluate(args, model, train_loader, val_loader, test_loader, device) 931 | logging.info("... done!") 932 | 933 | elif args.mode == 'generate': 934 | logging.info(f'generating {args.amount_gen} molecules...') 935 | generated = generate(args, train_args, model, device, cond_layer_list) 936 | gen_path = os.path.join(args.modelpath, f'generated{args.folder_name}/') 937 | if not os.path.exists(gen_path): 938 | os.makedirs(gen_path) 939 | # get untaken filename and store results 940 | file_name = os.path.join(gen_path, args.file_name) 941 | if os.path.isfile(file_name + '.mol_dict'): 942 | expand = 0 943 | while True: 944 | expand += 1 945 | new_file_name = file_name + '_' + str(expand) 946 | if os.path.isfile(new_file_name + '.mol_dict'): 947 | continue 948 | else: 949 | file_name = new_file_name 950 | break 951 | with open(file_name + '.mol_dict', 'wb') as f: 952 | pickle.dump(generated, f) 953 | logging.info('...done!') 954 | else: 955 | logging.info(f'Unknown mode: {args.mode}') 956 | 957 | 958 | if __name__ == '__main__': 959 | parser = get_parser() 960 | args = parser.parse_args() 961 | main(args) 962 | -------------------------------------------------------------------------------- /images/concept_results_scheme.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/atomistic-machine-learning/cG-SchNet/d107fa20563db348c668fc2dd29d843010bd1174/images/concept_results_scheme.png -------------------------------------------------------------------------------- /nn_classes.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn.functional as F 4 | import torch.nn as nn 5 | from collections import Iterable 6 | 7 | import schnetpack as spk 8 | from schnetpack.nn import MLP 9 | from schnetpack.metrics import Metric 10 | 11 | 12 | ### OUTPUT MODULE ### 13 | class AtomwiseWithProcessing(nn.Module): 14 | r""" 15 | Atom-wise dense layers that allow to use additional pre- and post-processing layers. 
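    In cG-SchNet, two instances of this module provide the output distributions:
    one predicts the distribution over the type of the next atom
    ('type_predictions') and one predicts the distributions over distances between
    the next atom and the already placed atoms ('distance_predictions').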
16 | 17 | Args: 18 | n_in (int, optional): input dimension of representation (default: 128) 19 | n_out (int, optional): output dimension (default: 1) 20 | n_layers (int, optional): number of atom-wise dense layers in output network 21 | (default: 5) 22 | n_neurons (list of int or int or None, optional): number of neurons in each 23 | layer of the output network. If a single int is provided, all layers will 24 | have that number of neurons, if `None`, interpolate linearly between n_in 25 | and n_out (default: None). 26 | activation (function, optional): activation function for hidden layers 27 | (default: spk.nn.activations.shifted_softplus). 28 | preprocess_layers (nn.Module, optional): a torch.nn.Module or list of Modules 29 | for preprocessing the representation given by the first part of the network 30 | (default: None). 31 | postprocess_layers (nn.Module, optional): a torch.nn.Module or list of Modules 32 | for postprocessing the output given by the second part of the network 33 | (default: None). 34 | in_key (str, optional): keyword to access the representation in the inputs 35 | dictionary, it is automatically inferred from the preprocessing layers, if 36 | at least one is given (default: 'representation'). 37 | out_key (str, optional): a string as key to the output dictionary (if set to 38 | 'None', the output will not be wrapped into a dictionary, default: 'y') 39 | 40 | Returns: 41 | result: dictionary with predictions stored in result[out_key] 42 | """ 43 | 44 | def __init__(self, n_in=128, n_out=1, n_layers=5, n_neurons=None, 45 | activation=spk.nn.activations.shifted_softplus, 46 | preprocess_layers=None, postprocess_layers=None, 47 | in_key='representation', out_key='y'): 48 | 49 | super(AtomwiseWithProcessing, self).__init__() 50 | 51 | self.n_in = n_in 52 | self.n_out = n_out 53 | self.n_layers = n_layers 54 | self.in_key = in_key 55 | self.out_key = out_key 56 | 57 | if isinstance(preprocess_layers, Iterable): 58 | self.preprocess_layers = nn.ModuleList(preprocess_layers) 59 | self.in_key = self.preprocess_layers[-1].out_key 60 | elif preprocess_layers is not None: 61 | self.preprocess_layers = preprocess_layers 62 | self.in_key = self.preprocess_layers.out_key 63 | else: 64 | self.preprocess_layers = None 65 | 66 | if isinstance(postprocess_layers, Iterable): 67 | self.postprocess_layers = nn.ModuleList(postprocess_layers) 68 | else: 69 | self.postprocess_layers = postprocess_layers 70 | 71 | if n_neurons is None: 72 | # linearly interpolate between n_in and n_out 73 | n_neurons = list(np.linspace(n_in, n_out, n_layers + 1).astype(int)[1:-1]) 74 | self.out_net = MLP(n_in, n_out, n_neurons, n_layers, activation) 75 | 76 | self.derivative = None # don't compute derivative w.r.t. inputs 77 | 78 | def forward(self, inputs): 79 | """ 80 | Compute layer output and apply pre-/postprocessing if specified. 81 | 82 | Args: 83 | inputs (dict of torch.Tensor): batch of input values. 84 | Returns: 85 | torch.Tensor: layer output. 
86 | """ 87 | # apply pre-processing layers 88 | if self.preprocess_layers is not None: 89 | if isinstance(self.preprocess_layers, Iterable): 90 | for pre_layer in self.preprocess_layers: 91 | inputs = pre_layer(inputs) 92 | else: 93 | inputs = self.preprocess_layers(inputs) 94 | 95 | # get (pre-processed) representation 96 | if isinstance(inputs[self.in_key], tuple): 97 | repr = inputs[self.in_key][0] 98 | else: 99 | repr = inputs[self.in_key] 100 | 101 | # apply output network 102 | result = self.out_net(repr) 103 | 104 | # apply post-processing layers 105 | if self.postprocess_layers is not None: 106 | if isinstance(self.postprocess_layers, Iterable): 107 | for post_layer in self.postprocess_layers: 108 | result = post_layer(inputs, result) 109 | else: 110 | result = self.postprocess_layers(inputs, result) 111 | 112 | # use provided key to store result 113 | if self.out_key is not None: 114 | result = {self.out_key: result} 115 | 116 | return result 117 | 118 | 119 | class RepresentationConditioning(nn.Module): 120 | r""" 121 | Layer that allows to alter the extracted feature representations in order to 122 | condition generation. Takes multiple networks that provide conditioning 123 | information as vectors, stacks these vectors and processes them in a fully 124 | connected MLP to get a global conditioning vector that is incorporated into 125 | the extracted feature representation. 126 | 127 | Args: 128 | layers (nn.Module): a torch.nn.Module or list of Modules that each provide a 129 | vector representing information for conditioning. 130 | mode (str, optional): how to incorporate the global conditioning vector in 131 | the extracted feature representation (can either be 'multiplication', 132 | 'addition', or 'stack', default: 'stack'). 133 | n_global_cond_features (int, optional): number of features in the global 134 | conditioning vector (i.e. output dimension for the MLP used to aggregate 135 | the stacked separate conditioning vectors). 136 | n_layers (int, optional): number of dense layers in the MLP used to get the 137 | global conditioning vector (default: 5). 138 | n_neurons (list of int or int or None, optional): number of neurons in each 139 | layer of the MLP. If a single int is provided, all layers will have that 140 | number of neurons, if `None`, interpolate linearly between n_in and n_out 141 | (default: None). 142 | activation (function, optional): activation function for hidden layers in the 143 | aggregation MLP (default: spk.nn.activations.shifted_softplus). 144 | in_key (str, optional): keyword to access the representation in the inputs 145 | dictionary, it is automatically inferred from the preprocessing layers, if 146 | at least one is given (default: 'representation'). 
147 | out_key (str, optional): a string as key to the output dictionary (if set to 148 | 'None', the output will not be wrapped into a dictionary, default: 149 | 'representation') 150 | 151 | Returns: 152 | result: dictionary with predictions stored in result[out_key] 153 | """ 154 | 155 | def __init__(self, 156 | layers, 157 | mode='stack', 158 | n_global_cond_features=128, 159 | n_layers=5, 160 | n_neurons=None, 161 | activation=spk.nn.activations.shifted_softplus, 162 | in_key='representation', 163 | out_key='representation'): 164 | 165 | super(RepresentationConditioning, self).__init__() 166 | 167 | if type(layers) not in [list, nn.ModuleList]: 168 | layers = [layers] 169 | if type(layers) == list: 170 | layers = nn.ModuleList(layers) 171 | self.layers = layers 172 | self.mode = mode 173 | self.in_key = in_key 174 | self.out_key = out_key 175 | self.n_global_cond_features = n_global_cond_features 176 | 177 | self.derivative = None # don't compute derivative w.r.t. inputs 178 | 179 | # set number of additional features 180 | self.n_additional_features = 0 181 | if self.mode == 'stack': 182 | self.n_additional_features = self.n_global_cond_features 183 | 184 | # compute number of inputs to the MLP processing stacked conditioning vectors 185 | n_in = 0 186 | for layer in self.layers: 187 | n_in += layer.n_out 188 | n_out = n_global_cond_features 189 | 190 | # initialize MLP processing stacked conditioning vectors 191 | if n_neurons is None: 192 | # linearly interpolate between n_in and n_out 193 | n_neurons = list(np.linspace(n_in, n_out, n_layers + 1).astype(int)[1:-1]) 194 | self.cond_mlp = MLP(n_in, n_out, n_neurons, n_layers, activation) 195 | 196 | def forward(self, inputs): 197 | """ 198 | Update representation in the inputs according to conditioning information and 199 | return empty dictionary since no proper network output is computed in this 200 | module. 201 | 202 | Args: 203 | inputs (dict of torch.Tensor): batch of input values. 204 | Returns: 205 | dict: An empty dictionary. 206 | """ 207 | # get (pre-processed) representation 208 | if isinstance(inputs[self.in_key], tuple): 209 | repr = inputs[self.in_key][0] 210 | else: 211 | repr = inputs[self.in_key] 212 | 213 | # get mask that (potentially) hides conditional information 214 | _size = [1, len(self.layers)] + [1 for _ in repr.size()[1:]] 215 | if '_cond_mask' in inputs: 216 | cond_mask = inputs['_cond_mask'] 217 | cond_mask = cond_mask.reshape([cond_mask.shape[0]] + _size[1:]) 218 | else: 219 | cond_mask = torch.ones(_size, dtype=repr.dtype, device=repr.device) 220 | 221 | # get conditioning information vectors from layers and include them in 222 | # representation 223 | cond_vecs = [] 224 | for i, layer in enumerate(self.layers): 225 | cond_vecs += [cond_mask[:, i] * layer(inputs)] 226 | 227 | cond_vecs = torch.cat(cond_vecs, dim=-1) 228 | final_cond_vec = self.cond_mlp(cond_vecs) 229 | 230 | if self.mode == 'addition': 231 | repr = repr + final_cond_vec 232 | elif self.mode == 'multiplication': 233 | repr = repr * final_cond_vec 234 | elif self.mode == 'stack': 235 | repr = torch.cat([repr, final_cond_vec.expand(*repr.size()[:-1], -1)], -1) 236 | 237 | inputs.update({self.out_key: repr}) 238 | 239 | return {} 240 | 241 | 242 | ### METRICS ### 243 | class KLDivergence(Metric): 244 | r""" 245 | Metric for mean KL-Divergence. 
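    The summed KL divergence over all (optionally masked) predicted distributions
    in the examined batches is accumulated and divided by the number of
    distributions upon aggregation, i.e. the reported value is the mean KL
    divergence per distribution.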
246 | 247 | Args: 248 | target (str, optional): name of target property (default: '_labels') 249 | model_output (list of int or list of str, optional): indices or keys to unpack 250 | the desired output from the model in case of multiple outputs, e.g. 251 | ['x', 'y'] to get output['x']['y'] (default: 'y'). 252 | name (str, optional): name used in logging for this metric. If set to `None`, 253 | `KLD_[target]` will be used (default: None). 254 | mask (str, optional): key for a mask in the examined batch which hides 255 | irrelevant output values. If 'None' is provided, no mask will be applied 256 | (default: None). 257 | inverse_mask (bool, optional): whether the mask needs to be inverted prior to 258 | application (default: False). 259 | """ 260 | 261 | def __init__(self, target='_labels', model_output='y', name=None, 262 | mask=None, inverse_mask=False): 263 | name = 'KLD_' + target if name is None else name 264 | super(KLDivergence, self).__init__(name) 265 | self.target = target 266 | self.model_output = model_output 267 | self.loss = 0. 268 | self.n_entries = 0. 269 | self.mask_str = mask 270 | self.inverse_mask = inverse_mask 271 | 272 | def reset(self): 273 | self.loss = 0. 274 | self.n_entries = 0. 275 | 276 | def add_batch(self, batch, result): 277 | # extract true labels 278 | y = batch[self.target] 279 | 280 | # extract predictions 281 | yp = result 282 | if self.model_output is not None: 283 | if isinstance(self.model_output, list): 284 | for key in self.model_output: 285 | yp = yp[key] 286 | else: 287 | yp = yp[self.model_output] 288 | 289 | # normalize output 290 | log_yp = F.log_softmax(yp, -1) 291 | 292 | # apply KL divergence formula entry-wise 293 | loss = F.kl_div(log_yp, y, reduction='none') 294 | 295 | # sum over last dimension to get KL divergence per distribution 296 | loss = torch.sum(loss, -1) 297 | 298 | # apply mask to filter padded dimensions 299 | if self.mask_str is not None: 300 | atom_mask = batch[self.mask_str] 301 | if self.inverse_mask: 302 | atom_mask = 1.-atom_mask 303 | loss = torch.where(atom_mask > 0, loss, torch.zeros_like(loss)) 304 | n_entries = torch.sum(atom_mask > 0) 305 | else: 306 | n_entries = torch.prod(torch.tensor(loss.size())) 307 | 308 | # calculate loss and n_entries 309 | self.n_entries += n_entries.detach().cpu().data.numpy() 310 | self.loss += torch.sum(loss).detach().cpu().data.numpy() 311 | 312 | def aggregate(self): 313 | return self.loss / max(self.n_entries, 1.) 314 | 315 | 316 | ### PRE- AND POST-PROCESSING LAYERS ### 317 | class EmbeddingMultiplication(nn.Module): 318 | r""" 319 | Layer that multiplies embeddings of given types with the representation. 320 | 321 | Args: 322 | embedding (torch.nn.Embedding instance): the embedding layer used to embed atom 323 | types. 324 | in_key_types (str, optional): the keyword to obtain types for embedding from 325 | inputs. 326 | in_key_representation (str, optional): the keyword to obtain the representation 327 | from inputs. 328 | out_key (str, optional): the keyword used to store the calculated product in 329 | the inputs dictionary. 
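
    The forward pass stores the calculated product in the inputs dictionary under
    out_key and returns the updated dictionary so that subsequent layers can access
    the result.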
330 | """ 331 | 332 | def __init__(self, embedding, in_key_types='_next_types', 333 | in_key_representation='representation', 334 | out_key='preprocessed_representation'): 335 | super(EmbeddingMultiplication, self).__init__() 336 | self.embedding = embedding 337 | self.in_key_types = in_key_types 338 | self.in_key_representation = in_key_representation 339 | self.out_key = out_key 340 | 341 | def forward(self, inputs): 342 | """ 343 | Compute layer output. 344 | 345 | Args: 346 | inputs (dict of torch.Tensor): batch of input values containing the atomic 347 | numbers for embedding as well as the representation. 348 | Returns: 349 | torch.Tensor: layer output. 350 | """ 351 | # get types to embed from inputs 352 | types = inputs[self.in_key_types] 353 | st = types.size() 354 | 355 | # embed types 356 | if len(st) == 1: 357 | emb = self.embedding(types.view(st[0], 1)) 358 | elif len(st) == 2: 359 | emb = self.embedding(types.view(*st[:-1], 1, st[-1])) 360 | 361 | # get representation 362 | if isinstance(inputs[self.in_key_representation], tuple): 363 | repr = inputs[self.in_key_representation][0] 364 | else: 365 | repr = inputs[self.in_key_representation] 366 | if len(st) == 2: 367 | # if multiple types are provided per molecule, expand 368 | # dimensionality of representation 369 | repr = repr.view(*repr.size()[:-1], 1, repr.size()[-1]) 370 | 371 | # if representation is larger than the embedding, pad embedding with ones 372 | if repr.size()[-1] != emb.size()[-1]: 373 | _emb = torch.ones([*emb.size()[:-1], repr.size()[-1]], device=emb.device) 374 | _emb[..., :emb.size()[-1]] = emb 375 | emb = _emb 376 | 377 | # multiply embedded types with representation 378 | features = repr * emb 379 | 380 | # store result in input dictionary 381 | inputs.update({self.out_key: features}) 382 | 383 | return inputs 384 | 385 | 386 | class NormalizeAndAggregate(nn.Module): 387 | r""" 388 | Layer that normalizes and aggregates given input along specifiable axes. 389 | 390 | Args: 391 | normalize (bool, optional): set True to normalize the input (default: True). 392 | normalization_axis (int, optional): axis along which normalization is applied 393 | (default: -1). 394 | normalization_mode (str, optional): which normalization to apply (currently 395 | only 'logsoftmax' is supported, default: 'logsoftmax'). 396 | aggregate (bool, optional): set True to aggregate the input (default: True). 397 | aggregation_axis (int, optional): axis along which aggregation is applied 398 | (default: -1). 399 | aggregation_mode (str, optional): which aggregation to apply (currently 'sum' 400 | and 'mean' are supported, default: 'sum'). 401 | keepdim (bool, optional): set True to keep the number of dimensions after 402 | aggregation (default: True). 403 | in_key_mask (str, optional): key to extract a mask from the inputs dictionary, 404 | which hides values during aggregation (default: None). 405 | squeeze (bool, optional): whether to squeeze the input before applying 406 | normalization (default: False). 407 | 408 | Returns: 409 | torch.Tensor: input after normalization and aggregation along specified axes. 
410 | """ 411 | 412 | def __init__(self, normalize=True, normalization_axis=-1, 413 | normalization_mode='logsoftmax', aggregate=True, 414 | aggregation_axis=-1, aggregation_mode='sum', keepdim=True, 415 | mask=None, squeeze=False): 416 | 417 | super(NormalizeAndAggregate, self).__init__() 418 | 419 | if normalize: 420 | if normalization_mode.lower() == 'logsoftmax': 421 | self.normalization = nn.LogSoftmax(normalization_axis) 422 | else: 423 | self.normalization = None 424 | 425 | if aggregate: 426 | if aggregation_mode.lower() == 'sum': 427 | self.aggregation =\ 428 | spk.nn.base.Aggregate(aggregation_axis, mean=False, 429 | keepdim=keepdim) 430 | elif aggregation_mode.lower() == 'mean': 431 | self.aggregation =\ 432 | spk.nn.base.Aggregate(aggregation_axis, mean=True, 433 | keepdim=keepdim) 434 | else: 435 | self.aggregation = None 436 | 437 | self.mask = mask 438 | self.squeeze = squeeze 439 | 440 | def forward(self, inputs, result): 441 | """ 442 | Compute layer output. 443 | 444 | Args: 445 | inputs (dict of torch.Tensor): batch of input values containing the mask 446 | result (torch.Tensor): batch of result values to which normalization and 447 | aggregation is applied 448 | Returns: 449 | torch.Tensor: normalized and aggregated result. 450 | """ 451 | 452 | res = result 453 | 454 | if self.squeeze: 455 | res = torch.squeeze(res) 456 | 457 | if self.normalization is not None: 458 | res = self.normalization(res) 459 | 460 | if self.aggregation is not None: 461 | if self.mask is not None: 462 | mask = inputs[self.mask] 463 | else: 464 | mask = None 465 | res = self.aggregation(res, mask) 466 | 467 | return res 468 | 469 | 470 | class AtomCompositionEmbedding(nn.Module): 471 | r""" 472 | Layer that embeds all atom types in a molecule and aggregates them into a single 473 | representation of the composition using a fully connected MLP. 474 | 475 | Args: 476 | embedding (torch.nn.Embedding instance): an embedding layer used to embed atom 477 | types separately. 478 | n_out (int, optional): number of features in the final, global embedding (i.e. 479 | output dimension for the MLP used to aggregate the separate, stacked atom 480 | type embeddings). 481 | n_layers (int, optional): number of dense layers used to get the global 482 | embedding (default: 5). 483 | n_neurons (list of int or int or None, optional): number of neurons in each 484 | layer of the aggregation MLP. If a single int is provided, all layers will 485 | have that number of neurons, if `None`, interpolate linearly between n_in 486 | and n_out (default: None). 487 | activation (function, optional): activation function for hidden layers in the 488 | aggregation MLP (default: spk.nn.activations.shifted_softplus). 489 | type_weighting (str, optional): how to weight the individual atom type 490 | embeddings (choose from 'absolute' to multiply each embedding with the 491 | absolute number of atoms of that type, 'relative' to multiply with the 492 | fraction of atoms of that type, and 'existence' to multiply with one if the 493 | type is present in the composition and zero otherwise, default: 'absolute') 494 | in_key_composition (str, optional): the keyword to obtain the global 495 | composition of molecules (i.e. a list of all atom types, default: 496 | 'composition'). 497 | n_types (int, optional): total number of available atom types (default: 5). 
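        skip_h (bool, optional): set True if hydrogen is excluded from the provided
            composition, i.e. the embedding of the first atom type is skipped
            (default: True).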
498 | """ 499 | 500 | def __init__(self, 501 | embedding, 502 | n_out=128, 503 | n_layers=5, 504 | n_neurons=None, 505 | activation=spk.nn.activations.shifted_softplus, 506 | type_weighting='exact', 507 | in_key_composition='composition', 508 | n_types=5, 509 | skip_h=True): 510 | 511 | super(AtomCompositionEmbedding, self).__init__() 512 | 513 | self.embedding = embedding 514 | self.in_key_composition = in_key_composition 515 | self.type_weighting = type_weighting 516 | self.n_types = n_types 517 | self.skip_h = skip_h 518 | if self.skip_h: 519 | self.n_types -= 1 520 | self.n_out = n_out 521 | 522 | # compute number of features in stacked embeddings 523 | n_in = self.n_types * self.embedding.embedding_dim 524 | 525 | if n_neurons is None: 526 | # linearly interpolate between n_in and n_out 527 | n_neurons = list(np.linspace(n_in, n_out, n_layers + 1).astype(int)[1:-1]) 528 | self.aggregation_mlp = MLP(n_in, n_out, n_neurons, n_layers, activation) 529 | 530 | def forward(self, inputs): 531 | """ 532 | Compute layer output. 533 | 534 | Args: 535 | inputs (dict of torch.Tensor): batch of input values containing the atomic 536 | numbers for embedding as well as the representation. 537 | Returns: 538 | torch.Tensor: batch of vectors representing the global composition of 539 | each molecule. 540 | """ 541 | # get composition to embed from inputs 542 | compositions = inputs[self.in_key_composition][..., None] 543 | if self.skip_h: 544 | embeded_types = self.embedding(inputs['_all_types'][0, 1:-1])[None, ...] 545 | else: 546 | embeded_types = self.embedding(inputs['_all_types'][0, :-1])[None, ...] 547 | 548 | # get global representation 549 | if self.type_weighting == 'relative': 550 | compositions = compositions/torch.sum(compositions, dim=-2, keepdim=True) 551 | elif self.type_weighting == 'existence': 552 | compositions = (compositions > 0).float() 553 | 554 | # multiply embedding with (weighted) composition 555 | embedding = embeded_types * compositions 556 | 557 | # aggregate embeddings to global representation 558 | sizes = embedding.size() 559 | embedding = embedding.view([*sizes[:-2], 1, sizes[-2]*sizes[-1]]) # stack 560 | embedding = self.aggregation_mlp(embedding) # aggregate 561 | 562 | return embedding 563 | 564 | 565 | class FingerprintEmbedding(nn.Module): 566 | r""" 567 | Layers that map the fingerprint of a molecule to a feature vector used for 568 | conditioning. 569 | 570 | Args: 571 | n_in (int): number of inputs (bits in the fingerprint). 572 | n_out (str): number of features in the embedding. 573 | n_layers (int, optional): number of dense layers used to embed the fingerprint 574 | (default: 5). 575 | n_neurons (list of int or int or None, optional): number of neurons in each 576 | layer of the output network. If a single int is provided, all layers will 577 | have that number of neurons, if `None`, interpolate linearly between n_in 578 | and n_out (default: None). 579 | in_key_fingerprint (str, optional): the keyword to obtain the fingerprint 580 | (default: 'fingerprint'). 581 | activation (function, optional): activation function for hidden layers 582 | (default: spk.nn.activations.shifted_softplus). 
583 | """ 584 | 585 | def __init__(self, n_in, n_out, n_layers=5, n_neurons=None, 586 | in_key_fingerprint='fingerprint', 587 | activation=spk.nn.activations.shifted_softplus): 588 | 589 | super(FingerprintEmbedding, self).__init__() 590 | 591 | self.in_key_fingerprint = in_key_fingerprint 592 | self.n_in = n_in 593 | self.n_out = n_out 594 | 595 | if n_neurons is None: 596 | # linearly interpolate between n_in and n_out 597 | n_neurons = list(np.linspace(n_in, n_out, n_layers + 1).astype(int)[1:-1]) 598 | self.out_net = MLP(n_in, n_out, n_neurons, n_layers, activation) 599 | 600 | def forward(self, inputs): 601 | """ 602 | Compute layer output. 603 | 604 | Args: 605 | inputs (dict of torch.Tensor): batch of input values containing the 606 | fingerprints. 607 | Returns: 608 | torch.Tensor: batch of vectors representing the fingerprint of each 609 | molecule. 610 | """ 611 | fingerprints = inputs[self.in_key_fingerprint] 612 | 613 | return self.out_net(fingerprints)[:, None, :] 614 | 615 | 616 | class PropertyEmbedding(nn.Module): 617 | r""" 618 | Layers that map the property (e.g. HOMO-LUMO gap, electronic spatial extent etc.) 619 | of a molecule to a feature vector used for conditioning. Properties are first 620 | expanded using Gaussian basis functions before being processed by a fully 621 | connected MLP. 622 | 623 | Args: 624 | n_in (int): number of inputs (Gaussians used for expansion of the property). 625 | n_out (int): number of features in the embedding. 626 | in_key_property (str): the keyword to obtain the property. 627 | start (float): center of first Gaussian function, :math:`\mu_0` for expansion. 628 | stop (float): center of last Gaussian function, :math:`\mu_{N_g}` for expansion 629 | (the remaining centers will be placed linearly spaced between start and 630 | stop). 631 | n_layers (int, optional): number of dense layers used to embed the property 632 | (default: 5). 633 | n_neurons (list of int or int or None, optional): number of neurons in each 634 | layer of the output network. If a single int is provided, all layers will 635 | have that number of neurons, if `None`, interpolate linearly between n_in 636 | and n_out (default: None). 637 | activation (function, optional): activation function for hidden layers 638 | (default: spk.nn.activations.shifted_softplus). 639 | trainable_gaussians (bool, optional): if True, widths and offset of Gaussian 640 | functions for expansion are adjusted during training process (default: 641 | False). 642 | widths (float, optional): width value of Gaussian functions for expansion 643 | (provide None to set the width to the distance between two centers 644 | :math:`\mu`, default: None). 
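        no_expansion (bool, optional): if True, the property values are fed into the
            MLP directly, without prior expansion in a Gaussian basis
            (default: False).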
645 | """ 646 | 647 | def __init__(self, n_in, n_out, in_key_property, start, stop, n_layers=5, 648 | n_neurons=None, activation=spk.nn.activations.shifted_softplus, 649 | trainable_gaussians=False, width=None, no_expansion=False): 650 | 651 | super(PropertyEmbedding, self).__init__() 652 | 653 | self.in_key_property = in_key_property 654 | self.n_in = n_in 655 | self.n_out = n_out 656 | if not no_expansion: 657 | self.expansion_net = GaussianExpansion(start, stop, self.n_in, 658 | trainable_gaussians, width) 659 | else: 660 | self.expansion_net = None 661 | 662 | if n_neurons is None: 663 | # linearly interpolate between n_in and n_out 664 | n_neurons = list(np.linspace(n_in, n_out, n_layers + 1).astype(int)[1:-1]) 665 | self.out_net = MLP(n_in, n_out, n_neurons, n_layers, activation) 666 | 667 | def forward(self, inputs): 668 | """ 669 | Compute layer output. 670 | 671 | Args: 672 | inputs (dict of torch.Tensor): batch of input values containing the 673 | fingerprints. 674 | Returns: 675 | torch.Tensor: batch of vectors representing the fingerprint of each 676 | molecule. 677 | """ 678 | property = inputs[self.in_key_property] 679 | if self.expansion_net is None: 680 | expanded = property 681 | else: 682 | expanded = self.expansion_net(property) 683 | 684 | return self.out_net(expanded)[:, None, :] 685 | 686 | 687 | ### MISC 688 | class GaussianExpansion(nn.Module): 689 | r"""Expansion layer using a set of Gaussian functions. 690 | 691 | Args: 692 | start (float): center of first Gaussian function, :math:`\mu_0`. 693 | stop (float): center of last Gaussian function, :math:`\mu_{N_g}`. 694 | n_gaussians (int, optional): total number of Gaussian functions, :math:`N_g` 695 | (default: 50). 696 | trainable (bool, optional): if True, widths and offset of Gaussian functions 697 | are adjusted during training process (default: False). 698 | widths (float, optional): width value of Gaussian functions (provide None to 699 | set the width to the distance between two centers :math:`\mu`, default: 700 | None). 701 | 702 | """ 703 | 704 | def __init__(self, start, stop, n_gaussians=50, trainable=False, 705 | width=None): 706 | super(GaussianExpansion, self).__init__() 707 | # compute offset and width of Gaussian functions 708 | offset = torch.linspace(start, stop, n_gaussians) 709 | if width is None: 710 | widths = torch.FloatTensor((offset[1] - offset[0]) * 711 | torch.ones_like(offset)) 712 | else: 713 | widths = torch.FloatTensor(width * torch.ones_like(offset)) 714 | if trainable: 715 | self.widths = nn.Parameter(widths) 716 | self.offsets = nn.Parameter(offset) 717 | else: 718 | self.register_buffer("widths", widths) 719 | self.register_buffer("offsets", offset) 720 | 721 | def forward(self, property): 722 | """Compute expanded gaussian property values. 723 | 724 | Args: 725 | property (torch.Tensor): property values of (N_b x 1) shape. 726 | 727 | Returns: 728 | torch.Tensor: layer output of (N_b x N_g) shape. 
729 | 730 | """ 731 | # compute width of Gaussian functions (using an overlap of 1 STDDEV) 732 | coeff = -0.5 / torch.pow(self.widths, 2)[None, :] 733 | # Use advanced indexing to compute the individual components 734 | diff = property - self.offsets[None, :] 735 | # compute expanded property values 736 | return torch.exp(coeff * torch.pow(diff, 2)) 737 | -------------------------------------------------------------------------------- /preprocess_dataset.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import argparse 3 | import sys 4 | import time 5 | import numpy as np 6 | import logging 7 | from ase.db import connect 8 | from scipy.spatial.distance import pdist 9 | from utility_classes import ConnectivityCompressor, Molecule 10 | from multiprocessing import Process, Queue 11 | from pathlib import Path 12 | 13 | 14 | def get_parser(): 15 | """ Setup parser for command line arguments """ 16 | main_parser = argparse.ArgumentParser() 17 | main_parser.add_argument('datapath', help='Full path to dataset (e.g. ' 18 | '/home/qm9.db)') 19 | main_parser.add_argument('--valence_list', 20 | default=[1, 1, 6, 4, 7, 3, 8, 2, 9, 1], type=int, 21 | nargs='+', 22 | help='The valence of atom types in the form ' 23 | '[type1 valence type2 valence ...] ' 24 | '(default: %(default)s)') 25 | main_parser.add_argument('--n_threads', type=int, default=16, 26 | help='Number of extra threads used while ' 27 | 'processing the data') 28 | main_parser.add_argument('--n_mols_per_thread', type=int, default=100, 29 | help='Number of molecules processed by each ' 30 | 'thread in one iteration') 31 | return main_parser 32 | 33 | 34 | def is_disconnected(connectivity): 35 | ''' 36 | Assess whether all atoms of a molecule are connected using a connectivity matrix 37 | 38 | Args: 39 | connectivity (numpy.ndarray): matrix (n_atoms x n_atoms) indicating bonds 40 | between atoms 41 | 42 | Returns 43 | bool: True if the molecule consists of at least two disconnected graphs, 44 | False if all atoms are connected by some path 45 | ''' 46 | con_mat = connectivity 47 | seen, queue = {0}, collections.deque([0]) # start at node (atom) 0 48 | while queue: 49 | vertex = queue.popleft() 50 | # iterate over (bonded) neighbors of current node 51 | for node in np.argwhere(con_mat[vertex] > 0).flatten(): 52 | # add node to queue and list of seen nodes if it has not been seen before 53 | if node not in seen: 54 | seen.add(node) 55 | queue.append(node) 56 | # if the seen nodes do not include all nodes, there are disconnected parts 57 | return seen != {*range(len(con_mat))} 58 | 59 | 60 | def get_count_statistics(mol=None, get_stat_heads=False): 61 | ''' 62 | Collects atom, bond, and ring count statistics of a provided molecule 63 | 64 | Args: 65 | mol (utility_classes.Molecule): Molecule to be examined 66 | get_stat_heads (bool, optional): set True to only return the headers of 67 | gathered statistics (default: False) 68 | 69 | Returns: 70 | numpy.ndarray: (n_statistics x 1) array containing the gathered statistics. Use 71 | get_stat_heads parameter to obtain the corresponding row headers (where RX 72 | describes number of X-membered rings and CXC indicates the number of 73 | carbon-carbon bonds of order X etc.). 
74 | ''' 75 | stat_heads = ['n_atoms', 'C', 'N', 'O', 'F', 'H', 'H1C', 'H1N', 76 | 'H1O', 'C1C', 'C2C', 'C3C', 'C1N', 'C2N', 'C3N', 'C1O', 77 | 'C2O', 'C1F', 'N1N', 'N2N', 'N1O', 'N2O', 'N1F', 'O1O', 78 | 'O1F', 'R3', 'R4', 'R5', 'R6', 'R7', 'R8', 'R>8'] 79 | if get_stat_heads: 80 | return stat_heads 81 | if mol is None: 82 | return None 83 | key_idx_dict = dict(zip(stat_heads, range(len(stat_heads)))) 84 | stats = np.zeros((len(stat_heads), 1)) 85 | # process all bonds and store statistics about bond and ring counts 86 | bond_stats = mol.get_bond_stats() 87 | for key, value in bond_stats.items(): 88 | if key in key_idx_dict: 89 | idx = key_idx_dict[key] 90 | stats[idx, 0] = value 91 | # store simple statistics about number of atoms 92 | stats[key_idx_dict['n_atoms'], 0] = mol.n_atoms 93 | for key in ['C', 'N', 'O', 'F', 'H']: 94 | idx = key_idx_dict[key] 95 | charge = mol.type_charges[key] 96 | if charge in mol._unique_numbers: 97 | stats[idx, 0] = np.sum(mol.numbers == charge) 98 | return stats 99 | 100 | 101 | def preprocess_molecules(mol_idcs, source_db, valence, 102 | precompute_distances=True, precompute_fingerprint=False, 103 | remove_invalid=True, invalid_list=None, print_progress=False): 104 | ''' 105 | Checks the validity of selected molecules and collects atom, bond, 106 | and ring count statistics for the valid structures. Molecules are classified as 107 | invalid if they consist of disconnected parts or fail a valence check, where the 108 | valency constraints of all atoms in a molecule have to be satisfied (e.g. carbon 109 | has four bonds, nitrogen has three bonds etc.) 110 | 111 | Args: 112 | mol_idcs (array): the indices of molecules from the source database that 113 | shall be examined 114 | source_db (str): full path to the source database (in ase.db sqlite format) 115 | valence (array): an array where the i-th entry contains the valency 116 | constraint of atoms with atomic charge i (e.g. 
a valency of 4 at array 117 | position 6 representing carbon) 118 | precompute_distances (bool, optional): if True, the pairwise distances between 119 | atoms in each molecule are computed and stored in the database (default: 120 | True) 121 | precompute_fingerprint (bool, optional): if True, the fingerprint of each 122 | molecule is computed and stored in the database (default: False) 123 | remove_invalid (bool, optional): if True, molecules that do not pass the 124 | valency or connectivity checks (or are on the invalid_list) are removed from 125 | the new database (default: True) 126 | invalid_list (list of int, optional): precomputed list containing indices of 127 | molecules that are marked as invalid (because they did not pass the 128 | valency or connectivity checks in earlier runs, default: None) 129 | print_progress (bool, optional): set True to print the progress in percent 130 | (default: False) 131 | 132 | Returns 133 | list of ase.Atoms: list of all valid molecules 134 | list of dict: list of corresponding dictionaries with data of each molecule 135 | numpy.ndarray: (n_statistics x n_valid_molecules) matrix with atom, bond, 136 | and ring count statistics 137 | list of int: list with indices of molecules that failed the valency check 138 | list of int: list with indices of molecules that consist of disconnected parts 139 | int: number of molecules processed 140 | ''' 141 | # initial setup 142 | count = 0 # count the number of invalid molecules 143 | disc = [] # store indices of disconnected molecules 144 | inval = [] # store indices of invalid molecules 145 | data_list = [] # store data fields of molecules for new db 146 | mols = [] # store molecules (as ase.Atoms objects) 147 | compressor = ConnectivityCompressor() # (de)compress sparse connectivity matrices 148 | stats = np.empty((len(get_count_statistics(get_stat_heads=True)), 0)) 149 | n_all = len(mol_idcs) 150 | 151 | with connect(source_db) as source_db: 152 | # iterate over provided indices 153 | for i in mol_idcs: 154 | i = int(i) 155 | # skip molecule if present in invalid_list and remove_invalid is True 156 | if remove_invalid and invalid_list is not None: 157 | if i in invalid_list: 158 | continue 159 | # get molecule from database 160 | row = source_db.get(i + 1) 161 | data = row.data 162 | at = row.toatoms() 163 | # get positions and atomic numbers 164 | pos = at.positions 165 | numbers = at.numbers 166 | # center positions (using center of mass) 167 | pos = pos - at.get_center_of_mass() 168 | # order atoms by distance to center of mass 169 | center_dists = np.sqrt(np.maximum(np.sum(pos ** 2, axis=1), 0)) 170 | idcs_sorted = np.argsort(center_dists) 171 | pos = pos[idcs_sorted] 172 | numbers = numbers[idcs_sorted] 173 | # update positions and atomic numbers accordingly in Atoms object 174 | at.positions = pos 175 | at.numbers = numbers 176 | # instantiate utility_classes.Molecule object 177 | mol = Molecule(pos, numbers) 178 | # get connectivity matrix (detecting bond orders with Open Babel) 179 | con_mat = mol.get_connectivity() 180 | # stop if molecule is disconnected (and therefore invalid) 181 | if remove_invalid: 182 | if is_disconnected(con_mat): 183 | count += 1 184 | disc += [i] 185 | continue 186 | 187 | # check if valency constraints of all atoms in molecule are satisfied: 188 | # since the detection of bond orders for the connectivity matrix with Open 189 | # Babel is unreliable for certain cases (e.g. 
some aromatic rings) we 190 | # try to fix it manually (with heuristics) or by reshuffling the atom 191 | # order (as the bond order detection of Open Babel is sensitive to the 192 | # order of atoms) 193 | nums = numbers 194 | random_ord = np.arange(len(numbers)) 195 | for _ in range(10): # try 10 times before dismissing as invalid 196 | if np.all(np.sum(con_mat, axis=0) == valence[nums]): 197 | # valency is correct -> mark as valid and stop check 198 | val = True 199 | break 200 | else: 201 | # try to fix bond orders using heuristics 202 | val = False 203 | con_mat = mol.get_fixed_connectivity() 204 | if np.all(np.sum(con_mat, axis=0) == valence[nums]): 205 | # valency is now correct -> mark as valid and stop check 206 | val = True 207 | break 208 | # shuffle atom order before checking valency again 209 | random_ord = np.random.permutation(range(len(pos))) 210 | mol = Molecule(pos[random_ord], numbers[random_ord]) 211 | con_mat = mol.get_connectivity() 212 | nums = numbers[random_ord] 213 | if remove_invalid: 214 | if not val: 215 | # stop if molecule is invalid (it failed the repeated valence checks) 216 | count += 1 217 | inval += [i] 218 | continue 219 | 220 | if precompute_distances: 221 | # calculate pairwise distances of atoms and store them in data 222 | dists = pdist(pos)[:, None] 223 | data.update({'dists': dists}) 224 | if precompute_fingerprint: 225 | fp = np.array(mol.get_fp().fp, dtype=np.uint32) 226 | data.update({'fingerprint': fp}) 227 | 228 | # store compressed connectivity matrix in data 229 | rand_ord_rev = np.argsort(random_ord) 230 | con_mat = con_mat[rand_ord_rev][:, rand_ord_rev] 231 | data.update( 232 | {'con_mat': compressor.compress(con_mat)}) 233 | 234 | # update atom, bond, and ring count statistics 235 | stats = np.hstack((stats, get_count_statistics(mol=mol))) 236 | 237 | # add results to the lists 238 | mols += [at] 239 | data_list += [data] 240 | 241 | # print progress if desired 242 | if print_progress: 243 | if i % 100 == 0: 244 | print('\033[K', end='\r', flush=True) 245 | print(f'{100 * (i + 1) / n_all:.2f}%', end='\r', flush=True) 246 | 247 | return mols, data_list, stats, inval, disc, count 248 | 249 | 250 | def _processing_worker(q_in, q_out, task): 251 | ''' 252 | Simple worker function that repeatedly fulfills a task using transmitted input and 253 | sends back the results until a stop signal is received. Can be used as target in 254 | a multiprocessing.Process object. 255 | 256 | Args: 257 | q_in (multiprocessing.Queue): queue to receive a list with data. The first 258 | entry signals whether worker can stop and the remaining entries are used as 259 | input arguments to the task function 260 | q_out (multiprocessing.Queue): queue to send results from task back 261 | task (callable function): function that is called using the received data 262 | ''' 263 | while True: 264 | data = q_in.get(True) # receive data 265 | if data[0]: # stop if stop signal is received 266 | break 267 | results = task(*data[1:]) # fulfill task with received data 268 | q_out.put(results) # send back results 269 | 270 | 271 | def _submit_jobs(qs_out, count, chunk_size, n_all, working_flag, 272 | n_per_thread): 273 | ''' 274 | Function that submits a job to preprocess molecules to every provided worker. 
275 | 276 | Args: 277 | qs_out (list of multiprocessing.Queue): queues used to send data to workers (one 278 | queue per worker) 279 | count (int): index of the earliest, not yet preprocessed molecule in the db 280 | chunk_size (int): number of molecules to be divided amongst workers 281 | n_all (int): total number of molecules in the db 282 | working_flag (array): flags indicating whether workers are running 283 | n_per_thread (int): number of molecules to be given to each thread 284 | 285 | Returns: 286 | numpy.ndarray: array with flags indicating whether workers got 287 | a job 288 | int: index of the new earliest, not yet preprocessed molecule in 289 | the db (after the submitted preprocessing jobs have been done) 290 | ''' 291 | # calculate indices of molecules that shall be preprocessed by workers 292 | idcs = np.arange(count, min(n_all, count + chunk_size)) 293 | start = 0 294 | for i, q in enumerate(qs_out): 295 | if start >= len(idcs): 296 | # stop if no more indices are left to submit 297 | break 298 | end = start + n_per_thread 299 | q.put((False, idcs[start:end])) # submit indices (and signal to not stop) 300 | working_flag[i] = 1 # set flag that current worker got a job 301 | start = end 302 | new_count = count + len(idcs) 303 | return working_flag, new_count 304 | 305 | 306 | def preprocess_dataset(datapath, valence_list, n_threads, n_mols_per_thread=100, 307 | logging_print=True, new_db_path=None, precompute_distances=True, 308 | precompute_fingerprint=False, remove_invalid=True, 309 | invalid_list=None): 310 | ''' 311 | Pre-processes all molecules of a dataset using the provided valency information. 312 | Multi-threading is used to speed up the process. 313 | Along with a new database containing the pre-processed molecules, a 314 | "input_db_invalid.txt" file holding the indices of removed molecules (which 315 | do not pass the valence or connectivity checks, omitted if remove_invalid is False) 316 | and a "new_db_statistics.npz" file (containing atom, bond, and ring count statistics 317 | for all molecules in the new database) are stored. 318 | 319 | Args: 320 | datapath (str): full path to dataset (ase.db database) 321 | valence_list (list): the valence of atom types in the form 322 | [type1 valence type2 valence ...] 
323 | n_threads (int): number of threads used (0 for no extra threads) 324 | n_mols_per_thread (int, optional): number of molecules processed by each 325 | thread at each iteration (default: 100) 326 | logging_print (bool, optional): set True to show output with logging.info 327 | instead of standard printing (default: True) 328 | new_db_path (str, optional): full path to new database where pre-processed 329 | molecules shall be stored (None to simply append "gen" to the name in 330 | datapath, default: None) 331 | precompute_distances (bool, optional): if True, the pairwise distances between 332 | atoms in each molecule are computed and stored in the database (default: 333 | True) 334 | precompute_fingerprint (bool, optional): if True, the fingerprint of each 335 | molecule is computed and stored in the database (default: False) 336 | remove_invalid (bool, optional): if True, molecules that do not pass the 337 | valency or connectivity check are removed from the new database (note 338 | that a precomputed list of invalid molecules determined using the code in 339 | this file is fetched from our repository if possible, default: True) 340 | invalid_list (list of int, optional): precomputed list containing indices of 341 | molecules that are marked as invalid (because they did not pass the 342 | valency or connectivity checks in earlier runs, default: None) 343 | ''' 344 | # convert paths 345 | datapath = Path(datapath) 346 | if new_db_path is None: 347 | new_db_path = datapath.parent / (datapath.stem + 'gen.db') 348 | else: 349 | new_db_path = Path(new_db_path) 350 | 351 | # compute array where the valency constraint of atom type i is stored at entry i 352 | max_type = max(valence_list[::2]) 353 | valence = np.zeros(max_type + 1, dtype=int) 354 | valence[valence_list[::2]] = valence_list[1::2] 355 | 356 | def _print(x, end='\n', flush=False): 357 | if logging_print: 358 | logging.info(x) 359 | else: 360 | print(x, end=end, flush=flush) 361 | 362 | with connect(datapath) as db: 363 | n_all = db.count() 364 | if n_all == 0: 365 | _print('No molecules found in data base!') 366 | sys.exit(0) 367 | _print('\nPre-processing data...') 368 | if logging_print: 369 | _print(f'Processed: 0 / {n_all}...') 370 | else: 371 | _print(f'0.00%', end='', flush=True) 372 | 373 | # initial setup 374 | n_iterations = 0 375 | chunk_size = n_threads * n_mols_per_thread 376 | current = 0 377 | count = 0 # count number of discarded (invalid etc.) 
molecules 378 | disc = [] 379 | inval = [] 380 | stats = np.empty((len(get_count_statistics(get_stat_heads=True)), 0)) 381 | working_flag = np.zeros(n_threads, dtype=bool) 382 | start_time = time.time() 383 | if invalid_list is not None and remove_invalid: 384 | invalid_list = {*invalid_list} 385 | n_inval = len(invalid_list) 386 | else: 387 | n_inval = 0 388 | 389 | with connect(new_db_path) as new_db: 390 | 391 | if precompute_fingerprint: 392 | # store the system's byte order of unsigned integers (for the fingerprints) 393 | new_db.metadata = \ 394 | {'fingerprint_format': '>u4' if sys.byteorder == 'big' else '<u4'} 395 | 396 | if n_threads >= 1: 397 | # set up threads and queues 398 | threads = [] 399 | qs_in = [] 400 | qs_out = [] 401 | for i in range(n_threads): 402 | qs_in += [Queue(1)] 403 | qs_out += [Queue(1)] 404 | threads += \ 405 | [Process(target=_processing_worker, 406 | name=str(i), 407 | args=(qs_out[-1], 408 | qs_in[-1], 409 | lambda x: 410 | preprocess_molecules(x, 411 | datapath, 412 | valence, 413 | precompute_distances, 414 | precompute_fingerprint, 415 | remove_invalid, 416 | invalid_list)))] 417 | threads[-1].start() 418 | 419 | # submit first round of jobs 420 | working_flag, current = \ 421 | _submit_jobs(qs_out, current, chunk_size, n_all, 422 | working_flag, n_mols_per_thread) 423 | 424 | while np.any(working_flag == 1): 425 | n_iterations += 1 426 | 427 | # initialize new iteration 428 | results = [] 429 | 430 | # gather results 431 | for i, q in enumerate(qs_in): 432 | if working_flag[i]: 433 | results += [q.get()] 434 | working_flag[i] = 0 435 | 436 | # submit new jobs 437 | working_flag, current_new = \ 438 | _submit_jobs(qs_out, current, chunk_size, n_all, working_flag, 439 | n_mols_per_thread) 440 | 441 | # store gathered results 442 | for res in results: 443 | mols, data_list, _stats, _inval, _disc, _c = res 444 | for (at, data) in zip(mols, data_list): 445 | new_db.write(at, data=data) 446 | stats = np.hstack((stats, _stats)) 447 | inval += _inval 448 | disc += _disc 449 | count += _c 450 | 451 | # print progress 452 | if logging_print and n_iterations % 10 == 0: 453 | _print(f'Processed: {current:6d} / {n_all}...') 454 | elif not logging_print: 455 | _print('\033[K', end='\r', flush=True) 456 | _print(f'{100 * current / n_all:.2f}%', end='\r', 457 | flush=True) 458 | current = current_new # update current position in database 459 | 460 | # stop worker threads and join 461 | for i, q_out in enumerate(qs_out): 462 | q_out.put((True,)) 463 | threads[i].join() 464 | threads[i].terminate() 465 | if logging_print: 466 | _print(f'Processed: {n_all} / {n_all}...') 467 | 468 | else: 469 | results = preprocess_molecules(range(n_all), datapath, valence, 470 | precompute_distances, 471 | precompute_fingerprint, remove_invalid, 472 | invalid_list, print_progress=True) 473 | mols, data_list, stats, inval, disc, count = results 474 | for (at, data) in zip(mols, data_list): 475 | new_db.write(at, data=data) 476 | 477 | if not logging_print: 478 | _print('\033[K', end='\n', flush=True) 479 | _print(f'... successfully validated {n_all - count - n_inval} data ' 480 | f'points!', flush=True) 481 | if invalid_list is not None: 482 | _print(f'{n_inval} structures were removed because they are on the ' 483 | f'pre-computed list of invalid molecules!', flush=True) 484 | if len(disc)+len(inval) > 0: 485 | _print(f'CAUTION: Could not validate {len(disc)+len(inval)} additional ' 486 | f'molecules. 
These were also removed and their indices are ' 487 | f'appended to the list of invalid molecules stored at ' 488 | f'{datapath.parent / (datapath.stem + f"_invalid.txt")}', 489 | flush=True) 490 | np.savetxt(datapath.parent / (datapath.stem + f'_invalid.txt'), 491 | np.append(np.sort(list(invalid_list)), np.sort(inval + disc)), 492 | fmt='%d') 493 | elif remove_invalid: 494 | _print(f'Identified {len(disc)} disconnected structures, and {len(inval)} ' 495 | f'structures with invalid valence!', flush=True) 496 | np.savetxt(datapath.parent / (datapath.stem + f'_invalid.txt'), 497 | np.sort(inval + disc), fmt='%d') 498 | _print('\nCompressing and storing statistics with numpy...') 499 | np.savez_compressed(new_db_path.parent/(new_db_path.stem+f'_statistics.npz'), 500 | stats=stats, 501 | stat_heads=get_count_statistics(get_stat_heads=True)) 502 | 503 | end_time = time.time() - start_time 504 | m, s = divmod(end_time, 60) 505 | h, m = divmod(m, 60) 506 | h, m, s = int(h), int(m), int(s) 507 | _print(f'Done! Pre-processing needed {h:d}h{m:02d}m{s:02d}s.') 508 | 509 | 510 | if __name__ == '__main__': 511 | parser = get_parser() 512 | args = parser.parse_args() 513 | preprocess_dataset(**vars(args)) 514 | -------------------------------------------------------------------------------- /published_data/README.md: -------------------------------------------------------------------------------- 1 | # Generated molecules and pretrained models 2 | 3 | Here we provide links to the molecules generated with cG-SchNet in the studies reported in the paper. This allows for reproduction of the shown graphs and statistics, as well as further analysis of the obtained molecules. Furthermore, we provide two pretrained cG-SchNet models that were used in our experiments, which enables sampling of additional molecules. 4 | 5 | ## Data bases with generated molecules 6 | A zip-file containing the generated molecules can be found under [DOI 10.14279/depositonce-14977](http://dx.doi.org/10.14279/depositonce-14977). 7 | It includes five folders with molecules generated by five cG-SchNet models conditioned on different combinations of properties. 8 | They appear in the same order as in the paper, i.e. the first model is conditioned on isotropic polarizability, the second is conditioned on fingerprints, the third is conditioned on HOMO-LUMO gap and atomic composition, the fourth is conditioned on relative atomic energy and atomic composition, and the fifth is conditioned on HOMO-LUMO gap and relative atomic energy. Each folder contains several data bases following the naming convention _\<target values\>\_\<generated/relaxed\>.db_, where _\<target values\>_ lists the conditioning target values (separated with underscores) and _\<generated/relaxed\>_ is either _generated_ for raw, generated structures or _relaxed_ for the relaxed counterparts (wherever we computed them). For example _/4\_comp\_relenergy/c7o2h10\_-0.1\_relaxed.db_ contains relaxed molecules that were generated by the fourth cG-SchNet model (conditioned on atomic composition and relative atomic energy) using the composition c7o2h10 and a relative atomic energy of -0.1 eV as targets during sampling. Relaxation was carried out at the same level of theory used for the QM9 training data. Please refer to the publication for details on the procedure. 
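If you want to aggregate statistics over a whole data base (e.g. to reproduce histograms of the conditioning targets), you can iterate over all rows with ASE. The following is a minimal sketch for the relaxed C7O2H10 data base; the path is a placeholder that you need to adapt to where you extracted the zip-file, and the data fields (`computed_relative_atomic_energy`, `rmsd`, `changed`) are those described in the example metadata output further below:

```
from ase.db import connect
import numpy as np

# placeholder path, adapt to the location of the extracted data bases
db_path = './4_comp_relenergy/c7o2h10_-0.1_relaxed.db'

energies, rmsds, n_changed = [], [], 0
with connect(db_path) as con:
    for row in con.select():  # iterate over all molecules in the data base
        energies += [row.data['computed_relative_atomic_energy']]
        rmsds += [row.data['rmsd']]
        n_changed += row.data['changed']

print(f'{len(energies)} molecules, '
      f'mean relative atomic energy: {np.mean(energies):.3f} eV, '
      f'mean RMSD: {np.mean(rmsds):.3f} Å, '
      f'changed connectivity: {n_changed}')
```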
9 | 10 | For help with the content of each data base, please call the script _data\_base\_help.py_ which will print the metadata of a selected data base and help to get started: 11 | 12 | ``` 13 | python ./cG-SchNet/published_data/data_base_help.py 14 | ``` 15 | 16 | For example, the output for the data base _/4\_comp\_relenergy/c7o2h10\_-0.1\_relaxed.db_ is as following: 17 | 18 | ``` 19 | INFO for data base at 20 | =========================================================================================================== 21 | Contains 2719 C7O2H10 isomers generated by cG-SchNet and subsequently relaxed with DFT calculations. 22 | All reported energy values are in eV and a systematic offset between the ORCA reference calculations and the QM9 reference values was approximated and compensated. For reference, the uncompensated energy values directly obtained with ORCA are also provided. 23 | The corresponding raw, generated version of each structure can be found in the data base of generated molecules using the provided "gen_idx" (caution, it starts from 0 whereas ase loads data bases counting from 1). 24 | The field "rmsd" holds the root mean square deviation between the atom positions of the raw, generated structure and the relaxed equilibrium molecule in Å (it includes hydrogen atoms). 25 | The field "changed" is 0 if the connectivity of the molecule before and after relaxation is identical (i.e. no bonds were broken or newly formed) and 1 if it did change. 26 | The field "known_relaxed" is 0, 3, or 6 to mark novel isomers, novel stereo-isomers, and unseen isomers (i.e. isomers resembling test data), respectively. 27 | Originally, all 3349 unique and valid C7O2H10 isomers among 100k molecules generated by cG-SchNet were chosen for relaxation. 7 of these did not converge to a valid configuration and 630 converged to equilibrium configurations already covered by other generated isomers (i.e. they were duplicate structures) and were therefore removed. 28 | =========================================================================================================== 29 | 30 | For example, here is the data stored with the first three molecules: 31 | 0: {'gen_idx': 0, 'computed_relative_atomic_energy': -0.11081359089212128, 'computed_energy_U0': -11512.699587841627, 'computed_energy_U0_uncompensated': -11512.578122651728, 'rmsd': 0.2492194704871389, 'changed': 0, 'known_relaxed': 3} 32 | 1: {'gen_idx': 1, 'computed_relative_atomic_energy': -0.13234655336327705, 'computed_energy_U0': -11513.108714128579, 'computed_energy_U0_uncompensated': -11512.98724893868, 'rmsd': 0.48381233235411153, 'changed': 0, 'known_relaxed': 3} 33 | 2: {'gen_idx': 2, 'computed_relative_atomic_energy': -0.11825264882338615, 'computed_energy_U0': -11512.84092994232, 'computed_energy_U0_uncompensated': -11512.719464752421, 'rmsd': 0.21804191287675742, 'changed': 0, 'known_relaxed': 3} 34 | 35 | You can load and access the molecules and accompanying data by connecting to the data base with ASE, e.g. 
using the following python code snippet: 36 | from ase.db import connect 37 | with connect() as con: 38 | row = con.get(1) # load the first molecule, 1-based indexing 39 | R = row.positions # positions of atoms as 3d coordinates 40 | Z = row.numbers # list of atomic numbers 41 | data = row.data # dictionary of data stored with the molecule 42 | 43 | You can visualize the molecules in the data base with ASE from the command line by calling: 44 | ase gui 45 | ``` 46 | 47 | Note that references to molecules from the QM9 data set always correspond to the indices _after_ removing the invalid molecules listed in [the file of invalid structures](https://github.com/atomistic-machine-learning/cG-SchNet/blob/main/splits/qm9_invalid.txt). These structures were automatically removed if QM9 was downloaded with the data script in this repository (e.g. by starting [model training](https://github.com/atomistic-machine-learning/cG-SchNet#training-a-model)). The resulting data base with corresponding indices can be found in your data directory: ```/qm9gen.db```. 48 | 49 | ## Pretrained models 50 | A zip-file containing two pretrained cG-SchNet models can be found under [DOI 10.14279/depositonce-14978](http://dx.doi.org/10.14279/depositonce-14978). The archive consists of two folders, where _comp\_relenergy_ hosts the model that was conditioned on atomic composition and relative atomic energy and used for the study described in Fig. 4 in the paper. The other model was conditioned on the HOMO-LUMO gap and relative atomic energy and used in the study described in Fig. 5 in the paper. 51 | 52 | In order to generate molecules with the pretrained models, simply extract the folders into your model directory and adapt the call for generating molecules described in the [main readme](https://github.com/atomistic-machine-learning/cG-SchNet#generating-molecules) accordingly. For example, you can generate 20k molecules with the composition c7o2h10 and a relative atomic energy of -0.1 eV as targets with: 53 | 54 | ``` 55 | python ./cG-SchNet/gschnet_cond_script.py generate gschnet ./models/comp_relenergy/ 20000 --conditioning "composition 10 7 0 2 0; n_atoms 19; relative_atomic_energy -0.1" --cuda 56 | ``` 57 | 58 | Note that the model takes a string with three conditions as input to the --conditioning argument: the number of atoms of each type in the order h c n o f, the total number of atoms, and the relative atomic energy value, each separated with a semicolon. 59 | Similarly, you can generate 20k molecules with a HOMO-LUMO gap of 4 eV and relative atomic energy of -0.2 eV as targets with the other model: 60 | 61 | ``` 62 | python ./cG-SchNet/gschnet_cond_script.py generate gschnet ./models/gap_relenergy/ 20000 --conditioning "gap 4.0; relative_atomic_energy -0.2" --cuda 63 | ``` 64 | 65 | The second model takes two conditions, the gap value and the energy value, as targets. 66 | For more details on the generation of molecules and subsequent filtering, please refer to the [main readme](https://github.com/atomistic-machine-learning/cG-SchNet#filtering-and-analysis-of-generated-molecules). 
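If you want to sample other target compositions with the first pretrained model, the string for the --conditioning argument can also be assembled programmatically. The snippet below is a small sketch (not part of the repository) that builds the string from a dictionary of element counts following the format described above, i.e. the counts in the order h c n o f, the total number of atoms, and the relative atomic energy, separated with semicolons:

```
# sketch: build the --conditioning string for the comp_relenergy model
counts = {'H': 10, 'C': 7, 'N': 0, 'O': 2, 'F': 0}  # e.g. the composition c7o2h10
relative_atomic_energy = -0.1  # target value in eV

composition = ' '.join(str(counts[el]) for el in ['H', 'C', 'N', 'O', 'F'])
n_atoms = sum(counts.values())
conditioning = (f'composition {composition}; '
                f'n_atoms {n_atoms}; '
                f'relative_atomic_energy {relative_atomic_energy}')
print(conditioning)
# composition 10 7 0 2 0; n_atoms 19; relative_atomic_energy -0.1
```

The printed string can then be passed to the --conditioning argument of _gschnet\_cond\_script.py_ as in the example calls above.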
67 | -------------------------------------------------------------------------------- /published_data/data_base_help.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import sys 3 | from pathlib import Path 4 | from ase.db import connect 5 | 6 | 7 | def get_parser(): 8 | """ Setup parser for command line arguments """ 9 | main_parser = argparse.ArgumentParser() 10 | main_parser.add_argument('data_base_path', type=str, 11 | help='Path to data base (.db file) with molecules.') 12 | 13 | return main_parser 14 | 15 | 16 | if __name__ == '__main__': 17 | parser = get_parser() 18 | args = parser.parse_args() 19 | 20 | path = Path(args.data_base_path).resolve() 21 | if not path.exists(): 22 | print(f'Argument error! There is no data base at "{path}"!') 23 | sys.exit(0) 24 | 25 | with connect(args.data_base_path) as con: 26 | if con.count() == 0: 27 | print(f'Error: The data base at "{path}" is empty!') 28 | sys.exit(0) 29 | elif 'info' not in con.metadata: 30 | print(f'Error: The metadata of data base at "{path}" does not contain the field "info"!') 31 | print(f'However, this is the data stored for the first molecule:\n{con.get(1).data}') 32 | sys.exit(0) 33 | print(f'\nINFO for data base at "{path}"') 34 | print('===========================================================================================================') 35 | print(con.metadata['info']) 36 | print('===========================================================================================================') 37 | print('\nFor example, here is the data stored with the first three molecules:') 38 | for i in range(3): 39 | print(f'{i}: {con.get(i+1).data}') 40 | print('\nYou can load and access the molecules and accompanying data by connecting to the data base with ASE, e.g. using the following python code snippet:') 41 | print(f'from ase.db import connect') 42 | print(f'with connect({path}) as con:\n', 43 | f'\trow = con.get(1) # load the first molecule, 1-based indexing\n', 44 | f'\tR = row.positions # positions of atoms as 3d coordinates\n', 45 | f'\tZ = row.numbers # list of atomic numbers\n', 46 | f'\tdata = row.data # dictionary of data stored with the molecule\n') 47 | print(f'You can visualize the molecules in the data base with ASE from the command line by calling:') 48 | print(f'ase gui {path}') 49 | 50 | -------------------------------------------------------------------------------- /qm9_data.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import re 4 | import shutil 5 | import tarfile 6 | import tempfile 7 | from pathlib import Path 8 | from urllib import request as request 9 | from urllib.error import HTTPError, URLError 10 | from base64 import b64encode, b64decode 11 | 12 | import numpy as np 13 | import torch 14 | from ase.db import connect 15 | from ase.io.extxyz import read_xyz 16 | from ase.units import Debye, Bohr, Hartree, eV 17 | 18 | from schnetpack import Properties 19 | from schnetpack.datasets import DownloadableAtomsData 20 | from utility_classes import ConnectivityCompressor 21 | from preprocess_dataset import preprocess_dataset 22 | 23 | 24 | class QM9gen(DownloadableAtomsData): 25 | """ QM9 benchmark dataset for organic molecules with up to nine non-hydrogen atoms 26 | from {C, O, N, F}. 
27 | 28 | This class adds convenience functions to download QM9 from figshare, 29 | pre-process the data such that it can be used for moleculec generation with the 30 | G-SchNet model, and load the data into pytorch. 31 | 32 | Args: 33 | path (str): path to directory containing qm9 database 34 | subset (list, optional): indices of subset, set to None for entire dataset 35 | (default: None). 36 | download (bool, optional): enable downloading if qm9 database does not 37 | exists (default: True) 38 | precompute_distances (bool, optional): if True and the pre-processed 39 | database does not yet exist, the pairwise distances of atoms in the 40 | dataset's molecules will be computed during pre-processing and stored in 41 | the database (increases storage demand of the dataset but decreases 42 | computational cost during training as otherwise the distances will be 43 | computed once in every epoch, default: True) 44 | remove_invalid (bool, optional): if True QM9 molecules that do not pass the 45 | valence check will be removed from the training data (note 1: the 46 | validity is per default inferred from a pre-computed list in our 47 | repository but will be assessed locally if the download fails, 48 | note2: only works if the pre-processed database does not yet exist, 49 | default: True) 50 | 51 | References: 52 | .. [#qm9_1] https://ndownloader.figshare.com/files/3195404 53 | """ 54 | 55 | # general settings for the dataset 56 | available_atom_types = [1, 6, 7, 8, 9] # all atom types found in the dataset 57 | atom_types_valence = [1, 4, 3, 2, 1] # valence constraints of the atom types 58 | radial_limits = [0.9, 1.7] # minimum and maximum distance between neighboring atoms 59 | 60 | # properties 61 | A = 'rotational_constant_A' 62 | B = 'rotational_constant_B' 63 | C = 'rotational_constant_C' 64 | mu = 'dipole_moment' 65 | alpha = 'isotropic_polarizability' 66 | homo = 'homo' 67 | lumo = 'lumo' 68 | gap = 'gap' 69 | r2 = 'electronic_spatial_extent' 70 | zpve = 'zpve' 71 | U0 = 'energy_U0' 72 | U = 'energy_U' 73 | H = 'enthalpy_H' 74 | G = 'free_energy' 75 | Cv = 'heat_capacity' 76 | 77 | properties = [ 78 | A, B, C, mu, alpha, 79 | homo, lumo, gap, r2, zpve, 80 | U0, U, H, G, Cv, 'n_atoms', 'relative_atomic_energy' 81 | ] 82 | 83 | units = [1., 1., 1., Debye, Bohr ** 3, 84 | Hartree, Hartree, Hartree, 85 | Bohr ** 2, Hartree, 86 | Hartree, Hartree, Hartree, 87 | Hartree, 1., 88 | ] 89 | 90 | units_dict = dict(zip(properties, units)) 91 | 92 | connectivity_compressor = ConnectivityCompressor() 93 | 94 | def __init__(self, path, subset=None, download=True, precompute_distances=True, 95 | remove_invalid=True, load_additionally=None): 96 | self.path = path 97 | self.dbpath = os.path.join(self.path, f'qm9gen.db') 98 | self.precompute_distances = precompute_distances 99 | self.remove_invalid = remove_invalid 100 | self.load_additionally = [] if load_additionally is None else load_additionally 101 | self.db_metadata = None 102 | # hard coded weights for regression of energy per atom from the concentration of atoms (i.e. 
the composition divided by the absolute number of atoms) 103 | self.energy_regression = lambda x: x.dot(np.array([-1032.61411992, -2052.26704777, -2506.01057452, -3063.2081989, -3733.79421864])) + 1016.2017264092275 104 | 105 | super().__init__(self.dbpath, subset=subset, 106 | available_properties=self.properties, 107 | units=self.units, download=download) 108 | 109 | with connect(self.dbpath) as db: 110 | if db.count() <= 0: 111 | logging.error('Error: Data base is empty, please provide a path ' 112 | 'to a proper data base or an empty path where the ' 113 | 'data can be downloaded to!') 114 | raise FileExistsError() 115 | if self.precompute_distances and 'dists' not in db.get(1).data: 116 | logging.info('Caution: Existing data base does not contain ' 117 | 'pre-computed distances, distances will be computed ' 118 | 'on the fly in every epoch.') 119 | if 'fingerprint' in self.load_additionally: 120 | if 'fingerprint' not in db.get(1).data: 121 | logging.error('Error: Fingerprints not found in the provided data ' 122 | 'base, please provide another path to the correct ' 123 | 'data base or an empty directory where a new data ' 124 | 'base with pre-computed fingerprints can be ' 125 | 'downloaded to.') 126 | raise FileExistsError() 127 | else: 128 | self.db_metadata = db.metadata 129 | for prop in self.load_additionally: 130 | if prop not in db.get(1).data and prop not in self.properties: 131 | logging.error(f'Error: Unknown property ({prop}) requested for ' 132 | f'conditioning. Cannot obtain that property from ' 133 | f'the database!') 134 | raise FileExistsError() 135 | 136 | def create_subset(self, idx): 137 | """ 138 | Returns a new dataset that only consists of provided indices. 139 | 140 | Args: 141 | idx (numpy.ndarray): subset indices 142 | 143 | Returns: 144 | schnetpack.data.AtomsData: dataset with subset of original data 145 | """ 146 | idx = np.array(idx) 147 | subidx = idx if self.subset is None or len(idx) == 0 \ 148 | else np.array(self.subset)[idx] 149 | return type(self)(self.path, subidx, download=False, 150 | load_additionally=self.load_additionally) 151 | 152 | def get_properties(self, idx): 153 | _idx = self._subset_index(idx) 154 | with connect(self.dbpath) as db: 155 | row = db.get(_idx + 1) 156 | at = row.toatoms() 157 | 158 | # extract/calculate structure 159 | properties = {} 160 | properties[Properties.Z] = torch.LongTensor(at.numbers.astype(np.int)) 161 | positions = at.positions.astype(np.float32) 162 | positions -= at.get_center_of_mass() # center positions 163 | properties[Properties.R] = torch.FloatTensor(positions) 164 | properties[Properties.cell] = torch.FloatTensor(at.cell.astype(np.float32)) 165 | 166 | # recover connectivity matrix from compressed format 167 | con_mat = self.connectivity_compressor.decompress(row.data['con_mat']) 168 | # save in dictionary 169 | properties['_con_mat'] = torch.FloatTensor(con_mat.astype(np.float32)) 170 | 171 | # extract pre-computed distances (if they exist) 172 | if 'dists' in row.data: 173 | properties['dists'] = row.data['dists'] 174 | 175 | # extract additional information 176 | for add_prop in self.load_additionally: 177 | if add_prop == 'fingerprint': 178 | fp = np.array(row.data[add_prop], 179 | dtype=self.db_metadata['fingerprint_format']) 180 | properties[add_prop] = torch.FloatTensor( 181 | np.unpackbits(fp.view(np.uint8), bitorder='little')) 182 | elif add_prop == 'n_atoms': 183 | properties['n_atoms'] = torch.FloatTensor([len(at.numbers)]) 184 | elif add_prop == 'relative_atomic_energy': 185 | types = 
at.numbers.astype(np.int) 186 | composition = np.array([np.sum(types == 1), 187 | np.sum(types == 6), 188 | np.sum(types == 7), 189 | np.sum(types == 8), 190 | np.sum(types == 9)], 191 | dtype=np.float32) 192 | concentration = composition/np.sum(composition) 193 | energy = row.data['energy_U0'] 194 | energy_per_atom = energy/len(types) 195 | relative_atomic_energy = energy_per_atom - self.energy_regression(concentration) 196 | properties[add_prop] = torch.FloatTensor([relative_atomic_energy]) 197 | else: 198 | properties[add_prop] = torch.FloatTensor([row.data[add_prop]]) 199 | 200 | # get atom environment 201 | nbh_idx, offsets = self.environment_provider.get_environment(at) 202 | # store neighbors, cell, and index 203 | properties[Properties.neighbors] = torch.LongTensor(nbh_idx.astype(np.int)) 204 | properties[Properties.cell_offset] = torch.FloatTensor( 205 | offsets.astype(np.float32)) 206 | properties["_idx"] = torch.LongTensor(np.array([idx], dtype=np.int)) 207 | 208 | return at, properties 209 | 210 | def _download(self): 211 | works = True 212 | if not os.path.exists(self.dbpath): 213 | qm9_path = os.path.join(self.path, f'qm9.db') 214 | if not os.path.exists(qm9_path): 215 | works = works and self._load_data() 216 | works = works and self._preprocess_qm9() 217 | return works 218 | 219 | def _load_data(self): 220 | logging.info('Downloading GDB-9 data...') 221 | tmpdir = tempfile.mkdtemp('gdb9') 222 | tar_path = os.path.join(tmpdir, 'gdb9.tar.gz') 223 | raw_path = os.path.join(tmpdir, 'gdb9_xyz') 224 | url = 'https://ndownloader.figshare.com/files/3195389' 225 | 226 | try: 227 | request.urlretrieve(url, tar_path) 228 | logging.info('Done.') 229 | except HTTPError as e: 230 | logging.error('HTTP Error:', e.code, url) 231 | return False 232 | except URLError as e: 233 | logging.error('URL Error:', e.reason, url) 234 | return False 235 | 236 | logging.info('Extracting data from tar file...') 237 | tar = tarfile.open(tar_path) 238 | tar.extractall(raw_path) 239 | tar.close() 240 | logging.info('Done.') 241 | 242 | logging.info('Parsing xyz files...') 243 | with connect(os.path.join(self.path, 'qm9.db')) as db: 244 | ordered_files = sorted(os.listdir(raw_path), 245 | key=lambda x: (int(re.sub('\D', '', x)), x)) 246 | for i, xyzfile in enumerate(ordered_files): 247 | xyzfile = os.path.join(raw_path, xyzfile) 248 | 249 | if (i + 1) % 10000 == 0: 250 | logging.info('Parsed: {:6d} / 133885'.format(i + 1)) 251 | properties = {} 252 | tmp = os.path.join(tmpdir, 'tmp.xyz') 253 | 254 | with open(xyzfile, 'r') as f: 255 | lines = f.readlines() 256 | l = lines[1].split()[2:] 257 | for pn, p in zip(self.properties, l): 258 | properties[pn] = float(p) * self.units[pn] 259 | with open(tmp, "wt") as fout: 260 | for line in lines: 261 | fout.write(line.replace('*^', 'e')) 262 | 263 | with open(tmp, 'r') as f: 264 | ats = list(read_xyz(f, 0))[0] 265 | db.write(ats, data=properties) 266 | logging.info('Done.') 267 | 268 | shutil.rmtree(tmpdir) 269 | 270 | return True 271 | 272 | def _preprocess_qm9(self): 273 | # try to download pre-computed list of invalid molecules 274 | raw_path = os.path.join(self.path, 'qm9_invalid.txt') 275 | if os.path.exists(raw_path): 276 | logging.info(f'Found existing list with indices of molecules in QM9 that are invalid at "{raw_path}".' 
277 | f' Please manually delete the file and restart training if you want to use the default list instead.') 278 | invalid_list = np.loadtxt(raw_path) 279 | else: 280 | logging.info('Downloading pre-computed list of invalid QM9 molecules...') 281 | # url = 'https://github.com/atomistic-machine-learning/cG-SchNet/blob/main/splits/' \ 282 | # 'qm9_invalid.txt?raw=true' 283 | try: 284 | url = Path(__file__).parent.resolve() / 'splits/qm9_invalid.txt' 285 | request.urlretrieve(url.as_uri(), raw_path) 286 | logging.info('Done.') 287 | invalid_list = np.loadtxt(raw_path) 288 | except HTTPError as e: 289 | logging.error('HTTP Error:', e.code, url) 290 | logging.info('CAUTION: Could not download pre-computed list, will assess ' 291 | 'validity during pre-processing.') 292 | invalid_list = None 293 | except URLError as e: 294 | logging.error('URL Error:', e.reason, url) 295 | logging.info('CAUTION: Could not download pre-computed list, will assess ' 296 | 'validity during pre-processing.') 297 | invalid_list = None 298 | except ValueError as e: 299 | logging.error('Value Error:', e) 300 | logging.info('CAUTION: Could not download pre-computed list, will assess ' 301 | 'validity during pre-processing.') 302 | invalid_list = None 303 | # check validity of molecules and store connectivity matrices and interatomic 304 | # distances in database as a pre-processing step 305 | qm9_db = os.path.join(self.path, f'qm9.db') 306 | valence_list = \ 307 | np.array([self.available_atom_types, self.atom_types_valence]).flatten('F') 308 | precompute_fingerprint = 'fingerprint' in self.load_additionally 309 | preprocess_dataset(datapath=qm9_db, valence_list=list(valence_list), 310 | n_threads=8, n_mols_per_thread=125, logging_print=True, 311 | new_db_path=self.dbpath, 312 | precompute_distances=self.precompute_distances, 313 | precompute_fingerprint=precompute_fingerprint, 314 | remove_invalid=self.remove_invalid, 315 | invalid_list=invalid_list) 316 | return True 317 | 318 | def get_available_properties(self, available_properties): 319 | # we don't use properties other than stored connectivity matrices (and 320 | # distances, if they were precomputed) so we skip this part 321 | return available_properties 322 | 323 | def create_splits(self, num_train=None, num_val=None, split_file=None): 324 | """ 325 | Splits the dataset into train/validation/test splits, writes split to 326 | an npz file and returns subsets. Either the sizes of training and 327 | validation split or an existing split file with split indices have to 328 | be supplied. The remaining data will be used in the test dataset. 329 | Args: 330 | num_train (int): number of training examples 331 | num_val (int): number of validation examples 332 | split_file (str): Path to split file. If file exists, splits will 333 | be loaded. Otherwise, a new file will be created 334 | where the generated split is stored. 
335 | Returns: 336 | schnetpack.data.AtomsData: training dataset 337 | schnetpack.data.AtomsData: validation dataset 338 | schnetpack.data.AtomsData: test dataset 339 | """ 340 | invalid_file_path = os.path.join(self.path, f"qm9_invalid.txt") 341 | if not os.path.exists(invalid_file_path): 342 | raise ValueError(f"Cannot find required file with indices of QM9 molecules that are invalid at {invalid_file_path}!") 343 | removed_idx = {*np.loadtxt(invalid_file_path).astype(int)} 344 | if split_file is not None and os.path.exists(split_file): 345 | S = np.load(split_file) 346 | train_idx = S["train_idx"].tolist() 347 | val_idx = S["val_idx"].tolist() 348 | test_idx = S["test_idx"].tolist() 349 | invalid_idx = {*S["invalid_idx"]} 350 | else: 351 | if num_train is None or num_val is None: 352 | raise ValueError( 353 | "You have to supply either split sizes (num_train /" 354 | + " num_val) or an npz file with splits." 355 | ) 356 | 357 | assert num_train + num_val <= len( 358 | self 359 | ), "Dataset is smaller than num_train + num_val!" 360 | 361 | num_train = num_train if num_train > 1 else num_train * len(self) 362 | num_val = num_val if num_val > 1 else num_val * len(self) 363 | num_train = int(num_train) 364 | num_val = int(num_val) 365 | 366 | idx = np.random.permutation(len(self)) 367 | train_idx = idx[:num_train].tolist() 368 | val_idx = idx[num_train : num_train + num_val].tolist() 369 | test_idx = idx[num_train + num_val :].tolist() 370 | invalid_idx = removed_idx 371 | 372 | if split_file is not None: 373 | np.savez( 374 | split_file, train_idx=train_idx, val_idx=val_idx, test_idx=test_idx, 375 | invalid_idx=sorted(list(invalid_idx)) 376 | ) 377 | 378 | if len(removed_idx) != len(invalid_idx) or len(removed_idx.difference(invalid_idx)) != 0: 379 | raise ValueError(f"Mismatch between the data base used to generate the provided split file and your local database. 
" 380 | + f"Please specify an empty data directory to re-download QM9 and try again.") 381 | train = self.create_subset(train_idx) 382 | val = self.create_subset(val_idx) 383 | test = self.create_subset(test_idx) 384 | return train, val, test 385 | -------------------------------------------------------------------------------- /splits/1_polarizability_split.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/atomistic-machine-learning/cG-SchNet/d107fa20563db348c668fc2dd29d843010bd1174/splits/1_polarizability_split.npz -------------------------------------------------------------------------------- /splits/2_fingerprint_split.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/atomistic-machine-learning/cG-SchNet/d107fa20563db348c668fc2dd29d843010bd1174/splits/2_fingerprint_split.npz -------------------------------------------------------------------------------- /splits/3_gap_comp_split.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/atomistic-machine-learning/cG-SchNet/d107fa20563db348c668fc2dd29d843010bd1174/splits/3_gap_comp_split.npz -------------------------------------------------------------------------------- /splits/4_comp_relenergy_split.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/atomistic-machine-learning/cG-SchNet/d107fa20563db348c668fc2dd29d843010bd1174/splits/4_comp_relenergy_split.npz -------------------------------------------------------------------------------- /splits/5_gap_relenergy_split.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/atomistic-machine-learning/cG-SchNet/d107fa20563db348c668fc2dd29d843010bd1174/splits/5_gap_relenergy_split.npz -------------------------------------------------------------------------------- /splits/qm9_invalid.txt: -------------------------------------------------------------------------------- 1 | 270 2 | 281 3 | 1116 4 | 1641 5 | 1647 6 | 1669 7 | 3837 8 | 3894 9 | 3898 10 | 4009 11 | 4014 12 | 4806 13 | 4870 14 | 4871 15 | 5033 16 | 5556 17 | 5598 18 | 5803 19 | 6032 20 | 6058 21 | 6074 22 | 6169 23 | 6329 24 | 6345 25 | 6682 26 | 6725 27 | 6732 28 | 7325 29 | 7334 30 | 7374 31 | 7375 32 | 7409 33 | 7416 34 | 7418 35 | 7428 36 | 7601 37 | 7653 38 | 7656 39 | 8397 40 | 9985 41 | 10015 42 | 10275 43 | 10294 44 | 10348 45 | 10784 46 | 11387 47 | 13480 48 | 13804 49 | 13832 50 | 14120 51 | 14317 52 | 15511 53 | 17807 54 | 17808 55 | 18307 56 | 20389 57 | 20397 58 | 20481 59 | 20489 60 | 20641 61 | 20678 62 | 20682 63 | 20691 64 | 21118 65 | 21126 66 | 21164 67 | 21363 68 | 21463 69 | 21520 70 | 21533 71 | 21534 72 | 21554 73 | 21566 74 | 21610 75 | 21619 76 | 21702 77 | 21705 78 | 21716 79 | 21717 80 | 21724 81 | 21740 82 | 21747 83 | 21757 84 | 21763 85 | 21837 86 | 21856 87 | 21858 88 | 21869 89 | 21874 90 | 21967 91 | 21968 92 | 21969 93 | 21970 94 | 21971 95 | 21972 96 | 21973 97 | 21974 98 | 21975 99 | 21976 100 | 21977 101 | 21978 102 | 21979 103 | 21980 104 | 21981 105 | 21982 106 | 21983 107 | 21984 108 | 21985 109 | 21986 110 | 21987 111 | 22008 112 | 22053 113 | 22089 114 | 22092 115 | 22096 116 | 22110 117 | 22115 118 | 22194 119 | 22202 120 | 22231 121 | 22237 122 | 22248 123 | 22264 124 | 22451 125 | 22459 126 | 22464 127 | 22468 128 | 22470 129 | 22471 130 | 22481 131 | 
22494 132 | 22497 133 | 22498 134 | 22503 135 | 22542 136 | 22680 137 | 22685 138 | 22971 139 | 22973 140 | 22979 141 | 23148 142 | 23149 143 | 23371 144 | 23792 145 | 23798 146 | 23818 147 | 23821 148 | 23827 149 | 25530 150 | 25764 151 | 25811 152 | 25828 153 | 25829 154 | 25859 155 | 25868 156 | 25892 157 | 25914 158 | 26121 159 | 26152 160 | 26153 161 | 26186 162 | 26228 163 | 26229 164 | 26538 165 | 27271 166 | 27293 167 | 27322 168 | 27335 169 | 27388 170 | 27860 171 | 28082 172 | 28250 173 | 28383 174 | 28401 175 | 29149 176 | 29167 177 | 29539 178 | 29557 179 | 29563 180 | 30525 181 | 30526 182 | 30528 183 | 30529 184 | 30537 185 | 30539 186 | 30545 187 | 30546 188 | 30548 189 | 30550 190 | 30551 191 | 30705 192 | 30712 193 | 30760 194 | 30761 195 | 30762 196 | 30786 197 | 30787 198 | 30797 199 | 30901 200 | 30902 201 | 30903 202 | 30993 203 | 30994 204 | 30995 205 | 30999 206 | 31012 207 | 31106 208 | 31108 209 | 31109 210 | 31110 211 | 31111 212 | 31170 213 | 31502 214 | 31598 215 | 32413 216 | 32464 217 | 32759 218 | 32813 219 | 32865 220 | 32884 221 | 32941 222 | 32942 223 | 33399 224 | 36994 225 | 36995 226 | 37991 227 | 38082 228 | 42423 229 | 43212 230 | 43242 231 | 43474 232 | 43519 233 | 45540 234 | 45544 235 | 45545 236 | 45926 237 | 46610 238 | 49722 239 | 50308 240 | 50449 241 | 50619 242 | 50735 243 | 51245 244 | 51246 245 | 52007 246 | 53819 247 | 53820 248 | 53844 249 | 53891 250 | 53895 251 | 53938 252 | 53940 253 | 53942 254 | 53943 255 | 53953 256 | 54077 257 | 54078 258 | 54101 259 | 54118 260 | 54123 261 | 54125 262 | 54228 263 | 54243 264 | 54295 265 | 54383 266 | 54386 267 | 54399 268 | 54408 269 | 54409 270 | 54411 271 | 54421 272 | 54447 273 | 54448 274 | 54486 275 | 54537 276 | 54568 277 | 54581 278 | 54610 279 | 54614 280 | 54617 281 | 54618 282 | 54623 283 | 54628 284 | 54656 285 | 54690 286 | 54691 287 | 54765 288 | 54793 289 | 54794 290 | 54795 291 | 54810 292 | 54873 293 | 54895 294 | 54899 295 | 54903 296 | 54993 297 | 55144 298 | 55145 299 | 55186 300 | 55189 301 | 55266 302 | 55399 303 | 55407 304 | 55409 305 | 55437 306 | 55449 307 | 55475 308 | 55476 309 | 55478 310 | 55483 311 | 55498 312 | 55517 313 | 55557 314 | 55609 315 | 55610 316 | 55618 317 | 55619 318 | 55620 319 | 55700 320 | 55702 321 | 55790 322 | 55909 323 | 55943 324 | 56015 325 | 56054 326 | 56071 327 | 56240 328 | 56342 329 | 56343 330 | 57735 331 | 57736 332 | 57944 333 | 58280 334 | 58612 335 | 58613 336 | 58981 337 | 59826 338 | 59848 339 | 59965 340 | 59976 341 | 60659 342 | 60717 343 | 60779 344 | 61434 345 | 61439 346 | 61450 347 | 62028 348 | 62083 349 | 66510 350 | 66602 351 | 66603 352 | 71535 353 | 72316 354 | 72318 355 | 74136 356 | 74175 357 | 74199 358 | 74201 359 | 74240 360 | 74241 361 | 74242 362 | 74312 363 | 75052 364 | 75169 365 | 76134 366 | 76135 367 | 76142 368 | 76371 369 | 76372 370 | 76379 371 | 76393 372 | 76394 373 | 76396 374 | 77141 375 | 77459 376 | 80207 377 | 80594 378 | 80596 379 | 81048 380 | 81053 381 | 81056 382 | 81566 383 | 81567 384 | 81572 385 | 81577 386 | 81578 387 | 81579 388 | 81580 389 | 82081 390 | 83400 391 | 83410 392 | 83413 393 | 83414 394 | 83416 395 | 84309 396 | 84799 397 | 85156 398 | 85354 399 | 85487 400 | 85779 401 | 85951 402 | 85961 403 | 86562 404 | 86587 405 | 86635 406 | 86738 407 | 86741 408 | 87034 409 | 87036 410 | 89621 411 | 89625 412 | 89627 413 | 90286 414 | 90692 415 | 90693 416 | 90695 417 | 91152 418 | 91257 419 | 91258 420 | 91518 421 | 92759 422 | 93323 423 | 93346 424 | 93566 425 | 93571 426 | 93940 427 | 
93941 428 | 93985 429 | 93987 430 | 93996 431 | 94181 432 | 94603 433 | 94605 434 | 95437 435 | 96611 436 | 96612 437 | 96636 438 | 96637 439 | 96639 440 | 96678 441 | 97115 442 | 97259 443 | 97324 444 | 97357 445 | 97362 446 | 97454 447 | 97457 448 | 97475 449 | 97528 450 | 97529 451 | 98010 452 | 98232 453 | 98233 454 | 98234 455 | 99224 456 | 99716 457 | 99725 458 | 99727 459 | 99730 460 | 99732 461 | 99744 462 | 99808 463 | 100075 464 | 100091 465 | 100442 466 | 100456 467 | 100514 468 | 100518 469 | 100625 470 | 100626 471 | 100709 472 | 100733 473 | 101806 474 | 101940 475 | 102014 476 | 102130 477 | 102224 478 | 102627 479 | 102633 480 | 102793 481 | 102795 482 | 102796 483 | 103797 484 | 103798 485 | 103812 486 | 103820 487 | 104600 488 | 104601 489 | 105193 490 | 105210 491 | 105214 492 | 105578 493 | 108409 494 | 108890 495 | 110173 496 | 112229 497 | 112337 498 | 112354 499 | 112496 500 | 112945 501 | 112946 502 | 112954 503 | 112989 504 | 113156 505 | 113160 506 | 113173 507 | 113174 508 | 113175 509 | 113183 510 | 115697 511 | 115698 512 | 116536 513 | 116638 514 | 116798 515 | 116943 516 | 117294 517 | 117522 518 | 117629 519 | 117642 520 | 118440 521 | 118447 522 | 119757 523 | 120430 524 | 120722 525 | 121012 526 | 121588 527 | 121595 528 | 121599 529 | 121610 530 | 121612 531 | 121779 532 | 121863 533 | 121881 534 | 122766 535 | 123125 536 | 123128 537 | 123544 538 | 123567 539 | 123588 540 | 123592 541 | 123615 542 | 123619 543 | 123629 544 | 123641 545 | 123654 546 | 123673 547 | 123685 548 | 123698 549 | 123901 550 | 123907 551 | 123964 552 | 123997 553 | 124017 554 | 124033 555 | 124121 556 | 124204 557 | 124221 558 | 124249 559 | 124709 560 | 124711 561 | 124713 562 | 124721 563 | 124722 564 | 124723 565 | 124730 566 | 124731 567 | 124736 568 | 124934 569 | 125054 570 | 125099 571 | 125275 572 | 125283 573 | 125360 574 | 125388 575 | 125470 576 | 125618 577 | 125629 578 | 125758 579 | 125792 580 | 125904 581 | 125916 582 | 126007 583 | 126024 584 | 126080 585 | 126088 586 | 126092 587 | 126291 588 | 126346 589 | 126350 590 | 126359 591 | 126864 592 | 126872 593 | 127082 594 | 127323 595 | 127355 596 | 127394 597 | 127406 598 | 127542 599 | 127605 600 | 127633 601 | 127777 602 | 127838 603 | 127892 604 | 127893 605 | 127894 606 | 128141 607 | 128142 608 | 128146 609 | 128170 610 | 128182 611 | 128194 612 | 128251 613 | 128259 614 | 128391 615 | 128393 616 | 128396 617 | 128406 618 | 128417 619 | 128421 620 | 128498 621 | 128527 622 | 128528 623 | 128557 624 | 128567 625 | 128618 626 | 128626 627 | 128932 628 | 128947 629 | 129099 630 | 129105 631 | 129113 632 | 129135 633 | 129136 634 | 129144 635 | 129145 636 | 129146 637 | 129148 638 | 129149 639 | 129150 640 | 129155 641 | 129156 642 | 129158 643 | 129169 644 | 129174 645 | 129176 646 | 129181 647 | 129242 648 | 129249 649 | 129316 650 | 129335 651 | 129336 652 | 129339 653 | 129392 654 | 129400 655 | 129405 656 | 129409 657 | 129410 658 | 129411 659 | 129577 660 | 129578 661 | 129580 662 | 129653 663 | 129735 664 | 129859 665 | 129867 666 | 129914 667 | 129939 668 | 129993 669 | 129995 670 | 129996 671 | 129998 672 | 130006 673 | 130008 674 | 130035 675 | 130037 676 | 130120 677 | 130181 678 | 130296 679 | 130335 680 | 130336 681 | 130337 682 | 130338 683 | 130344 684 | 130345 685 | 130354 686 | 130355 687 | 130356 688 | 130357 689 | 130365 690 | 130369 691 | 130373 692 | 130376 693 | 130381 694 | 130382 695 | 130384 696 | 130385 697 | 130386 698 | 130387 699 | 130392 700 | 130403 701 | 130405 702 | 130406 703 | 
130415 704 | 130423 705 | 130434 706 | 130437 707 | 130439 708 | 130440 709 | 130449 710 | 130452 711 | 130453 712 | 130462 713 | 130466 714 | 130469 715 | 130475 716 | 130479 717 | 130530 718 | 130536 719 | 130537 720 | 130582 721 | 130583 722 | 130587 723 | 130591 724 | 130602 725 | 130619 726 | 130629 727 | 130634 728 | 130661 729 | 130663 730 | 130664 731 | 130665 732 | 130666 733 | 130668 734 | 130669 735 | 130679 736 | 130683 737 | 130685 738 | 130691 739 | 130740 740 | 130746 741 | 130793 742 | 130860 743 | 130878 744 | 130882 745 | 130918 746 | 131091 747 | 131164 748 | 131199 749 | 131224 750 | 131513 751 | 131541 752 | 131554 753 | 131658 754 | 131693 755 | 131695 756 | 131704 757 | 131881 758 | 131882 759 | 131883 760 | 131884 761 | 131885 762 | 131886 763 | 131887 764 | 131888 765 | 131889 766 | 131890 767 | 131891 768 | 131892 769 | 131893 770 | 131894 771 | 131895 772 | 131896 773 | 131897 774 | 131898 775 | 131899 776 | 131900 777 | 131901 778 | 131902 779 | 131903 780 | 131904 781 | 131905 782 | 131906 783 | 131907 784 | 131908 785 | 131909 786 | 131910 787 | 131911 788 | 131912 789 | 131913 790 | 131914 791 | 131915 792 | 131916 793 | 131917 794 | 131918 795 | 131919 796 | 131920 797 | 131921 798 | 131922 799 | 131923 800 | 131924 801 | 131925 802 | 131926 803 | 131927 804 | 131928 805 | 131929 806 | 131930 807 | 131931 808 | 131932 809 | 131933 810 | 131934 811 | 131935 812 | 131936 813 | 131937 814 | 131938 815 | 131939 816 | 131940 817 | 131941 818 | 131942 819 | 131943 820 | 131944 821 | 131945 822 | 131946 823 | 131947 824 | 131948 825 | 131949 826 | 131950 827 | 131951 828 | 131952 829 | 131953 830 | 131954 831 | 131955 832 | 131956 833 | 131957 834 | 131958 835 | 131959 836 | 131960 837 | 131961 838 | 131962 839 | 131963 840 | 131964 841 | 131965 842 | 131966 843 | 131967 844 | 131968 845 | 131969 846 | 131970 847 | 131971 848 | 131972 849 | 131973 850 | 131974 851 | 131975 852 | 131976 853 | 131977 854 | 131978 855 | 131979 856 | 131980 857 | 131981 858 | 131982 859 | 131983 860 | 131984 861 | 131985 862 | 131986 863 | 131987 864 | 131988 865 | 131989 866 | 131990 867 | 131991 868 | 131992 869 | 131993 870 | 131994 871 | 131995 872 | 131996 873 | 131997 874 | 131998 875 | 131999 876 | 132000 877 | 132001 878 | 132071 879 | 132883 880 | 133142 881 | 133166 882 | 133167 883 | 133262 884 | 133273 885 | 133310 886 | 133336 887 | 133337 888 | 133395 889 | 133396 890 | 133402 891 | 133403 892 | 133812 893 | 133815 894 | 133819 895 | 133821 896 | 133825 897 | 133827 898 | 133828 899 | 133831 900 | 133832 901 | 133833 902 | 133839 903 | 133842 904 | 133843 905 | 133844 906 | 133845 907 | 133846 908 | 133848 909 | 133850 910 | 133851 911 | 133853 912 | 133857 913 | 133863 914 | 133864 915 | 133865 916 | -------------------------------------------------------------------------------- /utility_classes.py: -------------------------------------------------------------------------------- 1 | import operator 2 | import re 3 | import numpy as np 4 | import openbabel as ob 5 | import pybel 6 | from rdkit import Chem 7 | from multiprocessing import Process 8 | from scipy.spatial.distance import squareform 9 | 10 | 11 | class Molecule: 12 | ''' 13 | Molecule class that allows to get statistics such as the connectivity matrix, 14 | molecular fingerprint, canonical smiles representation, or ring count given 15 | positions of atoms and their atomic numbers. 
Currently supports molecules made of 16 | carbon, nitrogen, oxygen, fluorine, and hydrogen (such as in the QM9 benchmark 17 | dataset). Mainly relies on routines from Open Babel and RdKit. 18 | 19 | Args: 20 | pos (numpy.ndarray): positions of atoms in euclidean space (n_atoms x 3) 21 | atomic_numbers (numpy.ndarray): list with nuclear charge/type of each atom 22 | (e.g. 1 for hydrogens, 6 for carbons etc.). 23 | connectivity_matrix (numpy.ndarray, optional): optionally, a pre-calculated 24 | connectivity matrix (n_atoms x n_atoms) containing the bond order between 25 | atom pairs can be provided (default: None). 26 | store_positions (bool, optional): set True to store the positions of atoms in 27 | self.positions (only for convenience, not needed for computations, default: 28 | False). 29 | ''' 30 | 31 | type_infos = {1: {'name': 'H', 32 | 'n_bonds': 1}, 33 | 6: {'name': 'C', 34 | 'n_bonds': 4}, 35 | 7: {'name': 'N', 36 | 'n_bonds': 3}, 37 | 8: {'name': 'O', 38 | 'n_bonds': 2}, 39 | 9: {'name': 'F', 40 | 'n_bonds': 1}, 41 | } 42 | type_charges = {'H': 1, 'C': 6, 'N': 7, 'O': 8, 'F': 9} 43 | 44 | def __init__(self, pos, atomic_numbers, connectivity_matrix=None, 45 | store_positions=False): 46 | # was causing problems with other scripts if imported outside of this class 47 | # import openbabel as ob 48 | # import pybel 49 | # from rdkit import Chem 50 | 51 | # set comparison metrics to None (will be computed just in time) 52 | self._fp = None 53 | self._fp_bits = None 54 | self._can = None 55 | self._mirror_can = None 56 | self._inchi_key = None 57 | self._bond_stats = None 58 | self._fixed_connectivity = False 59 | self._row_indices = {} 60 | self._obmol = None 61 | self._rings = None 62 | self._n_atoms_per_type = None 63 | self._connectivity = connectivity_matrix 64 | 65 | # set statistics 66 | self.n_atoms = len(pos) 67 | self.numbers = atomic_numbers 68 | self._unique_numbers = {*self.numbers} # set for fast query 69 | self.positions = pos 70 | if not store_positions: 71 | self._obmol = self.get_obmol() # create obmol before removing pos 72 | self.positions = None 73 | 74 | def sanity_check(self): 75 | ''' 76 | Check whether the sum of valence of all atoms can be divided by 2. 77 | 78 | Returns: 79 | bool: True if the test is passed, False otherwise 80 | ''' 81 | count = 0 82 | for atom in self.numbers: 83 | count += self.type_infos[atom]['n_bonds'] 84 | if count % 2 == 0: 85 | return True 86 | else: 87 | return False 88 | 89 | def get_obmol(self): 90 | ''' 91 | Retrieve the underlying Open Babel OBMol object. 
92 | 93 | Returns: 94 | OBMol object: Open Babel OBMol representation 95 | ''' 96 | if self._obmol is None: 97 | if self.positions is None: 98 | print('Error, cannot create obmol without positions!') 99 | return 100 | if self.numbers is None: 101 | print('Error, cannot create obmol without atomic numbers!') 102 | return 103 | # use openbabel to infer bonds and bond order: 104 | obmol = ob.OBMol() 105 | obmol.BeginModify() 106 | 107 | # set positions and atomic numbers of all atoms in the molecule 108 | for p, n in zip(self.positions, self.numbers): 109 | obatom = obmol.NewAtom() 110 | obatom.SetAtomicNum(int(n)) 111 | obatom.SetVector(*p.tolist()) 112 | 113 | # infer bonds and bond order 114 | obmol.ConnectTheDots() 115 | obmol.PerceiveBondOrders() 116 | 117 | obmol.EndModify() 118 | self._obmol = obmol 119 | return self._obmol 120 | 121 | def get_fp(self): 122 | ''' 123 | Retrieve the molecular fingerprint (the path-based FP2 from Open Babel is used, 124 | which means that paths of length up to 7 are considered). 125 | 126 | Returns: 127 | pybel.Fingerprint object: moleculer fingerprint (use "fp1 | fp2" to 128 | calculate the Tanimoto coefficient of two fingerprints) 129 | ''' 130 | if self._fp is None: 131 | # calculate fingerprint 132 | self._fp = pybel.Molecule(self.get_obmol()).calcfp() 133 | return self._fp 134 | 135 | def get_fp_bits(self): 136 | ''' 137 | Retrieve the bits set in the molecular fingerprint. 138 | 139 | Returns: 140 | Set of int: object containing the bits set in the molecular fingerprint 141 | ''' 142 | if self._fp_bits is None: 143 | self._fp_bits = {*self.get_fp().bits} 144 | return self._fp_bits 145 | 146 | def get_can(self): 147 | ''' 148 | Retrieve the canonical SMILES representation of the molecule. 149 | 150 | Returns: 151 | String: canonical SMILES string 152 | ''' 153 | if self._can is None: 154 | # calculate canonical SMILES 155 | self._can = pybel.Molecule(self.get_obmol()).write('can') 156 | return self._can 157 | 158 | def get_mirror_can(self): 159 | ''' 160 | Retrieve the canonical SMILES representation of the mirrored molecule (the 161 | z-coordinates are flipped). 162 | 163 | Returns: 164 | String: canonical SMILES string of the mirrored molecule 165 | ''' 166 | if self._mirror_can is None: 167 | # calculate canonical SMILES of mirrored molecule 168 | self._flip_z() # flip z to mirror molecule using x-y plane 169 | self._mirror_can = pybel.Molecule(self.get_obmol()).write('can') 170 | self._flip_z() # undo mirroring 171 | return self._mirror_can 172 | 173 | def get_inchi_key(self): 174 | ''' 175 | Retrieve the InChI-key of the molecule. 176 | 177 | Returns: 178 | String: InChI-key 179 | ''' 180 | if self._inchi_key is None: 181 | # calculate inchi key 182 | self._inchi_key = pybel.Molecule(self.get_obmol()).\ 183 | write('inchikey') 184 | return self._inchi_key 185 | 186 | def _flip_z(self): 187 | ''' 188 | Flips the z-coordinates of atom positions (to get a mirrored version of the 189 | molecule). 190 | ''' 191 | if self._obmol is None: 192 | self.get_obmol() 193 | for atom in ob.OBMolAtomIter(self._obmol): 194 | x, y, z = atom.x(), atom.y(), atom.z() 195 | atom.SetVector(x, y, -z) 196 | self._obmol.ConnectTheDots() 197 | self._obmol.PerceiveBondOrders() 198 | 199 | def get_connectivity(self): 200 | ''' 201 | Retrieve the connectivity matrix of the molecule. 202 | 203 | Returns: 204 | numpy.ndarray: (n_atoms x n_atoms) array containing the pairwise bond orders 205 | between atoms (0 for no bond). 
206 | ''' 207 | if self._connectivity is None: 208 | # get connectivity matrix 209 | connectivity = np.zeros((self.n_atoms, len(self.numbers))) 210 | for atom in ob.OBMolAtomIter(self.get_obmol()): 211 | index = atom.GetIdx() - 1 212 | # loop over all neighbors of atom 213 | for neighbor in ob.OBAtomAtomIter(atom): 214 | idx = neighbor.GetIdx() - 1 215 | bond_order = neighbor.GetBond(atom).GetBO() 216 | #print(f'{index}-{idx}: {bond_order}') 217 | # do not count bonds between two hydrogen atoms 218 | if (self.numbers[index] == 1 and self.numbers[idx] == 1 219 | and bond_order > 0): 220 | bond_order = 0 221 | connectivity[index, idx] = bond_order 222 | self._connectivity = connectivity 223 | return self._connectivity 224 | 225 | def get_ring_counts(self): 226 | ''' 227 | Retrieve a list containing the sizes of rings in the symmetric smallest set 228 | of smallest rings (S-SSSR from RdKit) in the molecule (e.g. [5, 6, 5] for two 229 | rings of size 5 and one ring of size 6). 230 | 231 | Returns: 232 | List of int: list with ring sizes 233 | ''' 234 | if self._rings is None: 235 | # calculate symmetric SSSR with RdKit using the canonical smiles 236 | # representation as input 237 | can = self.get_can() 238 | mol = Chem.MolFromSmiles(can) 239 | if mol is not None: 240 | ssr = Chem.GetSymmSSSR(mol) 241 | self._rings = [len(ssr[i]) for i in range(len(ssr))] 242 | else: 243 | self._rings = [] # cannot count rings 244 | return self._rings 245 | 246 | def get_n_atoms_per_type(self): 247 | ''' 248 | Retrieve the number of atoms in the molecule per type. 249 | 250 | Returns: 251 | numpy.ndarray: number of atoms in the molecule per type, where the order 252 | corresponds to the order specified in Molecule.type_infos 253 | ''' 254 | if self._n_atoms_per_type is None: 255 | _types = np.array(list(self.type_infos.keys()), dtype=int) 256 | self._n_atoms_per_type =\ 257 | np.bincount(self.numbers, minlength=np.max(_types)+1)[_types] 258 | return self._n_atoms_per_type 259 | 260 | def remove_unpicklable_attributes(self, restorable=True): 261 | ''' 262 | Some attributes of the class cannot be processed by pickle. This method 263 | allows to remove these attributes prior to pickling. 264 | 265 | Args: 266 | restorable (bool, optional): Set True to allow restoring the deleted 267 | attributes later on (default: True) 268 | ''' 269 | # set attributes which are not picklable (SwigPyObjects) to None 270 | if restorable and self.positions is None and self._obmol is not None: 271 | # store positions to allow restoring obmol object later on 272 | pos = [atom.coords for atom in pybel.Molecule(self._obmol).atoms] 273 | self.positions = np.array(pos) 274 | self._obmol = None 275 | self._fp = None 276 | 277 | def tanimoto_similarity(self, other_mol, use_bits=True): 278 | ''' 279 | Get the Tanimoto (fingerprint) similarity to another molecule. 280 | 281 | Args: 282 | other_mol (Molecule or pybel.Fingerprint/list of bits set): 283 | representation of the second molecule (if it is not a Molecule object, 284 | it needs to be a pybel.Fingerprint if use_bits is False and a list of bits 285 | set in the fingerprint if use_bits is True). 
286 | use_bits (bool, optional): set True to calculate Tanimoto similarity 287 | from bits set in the fingerprint (default: True) 288 | 289 | Returns: 290 | float: Tanimoto similarity to the other molecule 291 | ''' 292 | if use_bits: 293 | a = self.get_fp_bits() 294 | b = other_mol.get_fp_bits() if isinstance(other_mol, Molecule) \ 295 | else other_mol 296 | n_equal = len(a.intersection(b)) 297 | if len(a) + len(b) == 0: # edge case with no set bits 298 | return 1. 299 | return n_equal / (len(a)+len(b)-n_equal) 300 | else: 301 | fp_other = other_mol.get_fp() if isinstance(other_mol, Molecule)\ 302 | else other_mol 303 | return self.get_fp() | fp_other 304 | 305 | def _update_bond_orders(self, idc_lists): 306 | ''' 307 | Updates the bond orders in the underlying OBMol object. 308 | 309 | Args: 310 | idc_lists (list of list of int): nested list containing bonds, i.e. pairs 311 | of row indices (list1) and column indices (list2) which shall be updated 312 | ''' 313 | con_mat = self.get_connectivity() 314 | self._obmol.BeginModify() 315 | for i in range(len(idc_lists[0])): 316 | idx1 = idc_lists[0][i] 317 | idx2 = idc_lists[1][i] 318 | obbond = self._obmol.GetBond(int(idx1+1), int(idx2+1)) 319 | obbond.SetBO(int(con_mat[idx1, idx2])) 320 | self._obmol.EndModify() 321 | 322 | # reset fingerprints etc 323 | self._fp = None 324 | self._can = None 325 | self._mirror_can = None 326 | self._inchi_key = None 327 | 328 | def get_fixed_connectivity(self, recursive_call=False): 329 | ''' 330 | Attempts to fix the connectivity matrix using some heuristics (as some valid 331 | QM9 molecules do not pass the valency check using the connectivity matrix 332 | obtained with Open Babel, which seems to have problems with assigning correct 333 | bond orders to aromatic rings containing Nitrogen). 334 | 335 | Args: 336 | recursive_call (bool, do not set True): flag that indicates a recursive 337 | call (used internally, do not set to True) 338 | 339 | Returns: 340 | numpy.ndarray: (n_atoms x n_atoms) array containing the pairwise bond orders 341 | between atoms (0 for no bond) after the attempted fix. 
342 | ''' 343 | 344 | # if fix has already been attempted, return the connectivity matrix 345 | if self._fixed_connectivity: 346 | return self._connectivity 347 | 348 | # define helpers: 349 | # increases bond order between two atoms in connectivity matrix 350 | def increase_bond(con_mat, idx1, idx2): 351 | con_mat[idx1, idx2] += 1 352 | con_mat[idx2, idx1] += 1 353 | return con_mat 354 | 355 | # decreases bond order between two atoms in connectivity matrix 356 | def decrease_bond(con_mat, idx1, idx2): 357 | con_mat[idx1, idx2] -= 1 358 | con_mat[idx2, idx1] -= 1 359 | return con_mat 360 | 361 | # returns only the rows of the connectivity matrix corresponding to atoms of 362 | # certain types (and the indices of these atoms) 363 | def get_typewise_connectivity(con_mat, types): 364 | idcs = [] 365 | for type in types: 366 | idcs += list(self._get_row_idcs(type)) 367 | return con_mat[idcs], np.array(idcs).astype(int) 368 | 369 | # store old connectivity matrix for later comparison 370 | old_mat = self.get_connectivity().copy() 371 | 372 | # get connectivity matrix and find indices of N and C atoms 373 | con_mat = self.get_connectivity() 374 | if 6 not in self._unique_numbers and 7 not in self._unique_numbers: 375 | # do not attempt fixing if there is no carbon and no nitrogen 376 | return con_mat 377 | N_mat, N_idcs = get_typewise_connectivity(con_mat, [7]) 378 | C_mat, C_idcs = get_typewise_connectivity(con_mat, [6]) 379 | NC_idcs = np.hstack((N_idcs, C_idcs)) # indices of all N and C atoms 380 | NC_valences = self._get_valences()[NC_idcs] # array with valency constraints 381 | 382 | # return connectivity if valency constraints of N and C atoms are already met 383 | if np.all(np.sum(con_mat[NC_idcs], axis=1) == NC_valences): 384 | return con_mat 385 | 386 | # if a C or N atom is "overcharged" (total bond order too high) we decrease 387 | # double to single bonds between N-N or N-C until it is not overcharged anymore 388 | # (e.g. C=N=C -> C=N-C) 389 | if 7 in self._unique_numbers: # only necessary if molecule contains N 390 | for cur in NC_idcs: 391 | type = self.numbers[cur] 392 | if np.sum(con_mat[cur]) <= self.type_infos[type]['n_bonds']: 393 | continue 394 | if type == 6: # for carbon look only at nitrogen neighbors 395 | neighbors = self._get_neighbors(cur, types=[7], strength=2) 396 | else: 397 | neighbors = self._get_neighbors(cur, types=[6, 7], 398 | strength=2) 399 | for neighbor in neighbors: 400 | con_mat = decrease_bond(con_mat, cur, neighbor) 401 | self._connectivity = con_mat 402 | if np.sum(con_mat[cur]) == \ 403 | self.type_infos[type]['n_bonds']: 404 | break 405 | 406 | # get updated partial connectivity matrices for N and C 407 | N_mat, _ = get_typewise_connectivity(con_mat, [7]) 408 | C_mat, _ = get_typewise_connectivity(con_mat, [6]) 409 | 410 | # increase total number of bonds by transferring the strength of a 411 | # double C-N bond to two neighboring bonds, if the involved atoms 412 | # are not yet saturated (e.g. 
H2C-H2C=N-H2C -> H2C=H2C-N=H2C) 413 | if (np.sum(N_mat) < len(N_idcs) * 3 or np.sum(C_mat) < len(C_idcs) * 4) \ 414 | and 7 in self._unique_numbers: 415 | for cur in NC_idcs: 416 | type = self.numbers[cur] 417 | if sum(con_mat[cur]) >= self.type_infos[type]['n_bonds']: 418 | continue 419 | CN_nbors = self._get_CN_neighbors(cur, order_1=1, order_2=2) 420 | for nbor_1, nbor_2 in CN_nbors: 421 | nbor_2_nbors = np.where(con_mat[nbor_2] == 1)[0] 422 | for nbor_2_nbor in nbor_2_nbors: 423 | nbor_2_nbor_type = self.numbers[nbor_2_nbor] 424 | if (np.sum(con_mat[nbor_2_nbor]) < 425 | self.type_infos[nbor_2_nbor_type]['n_bonds']): 426 | con_mat = increase_bond(con_mat, cur, nbor_1) 427 | con_mat = increase_bond(con_mat, nbor_2, nbor_2_nbor) 428 | con_mat = decrease_bond(con_mat, nbor_1, nbor_2) 429 | self._connectivity = con_mat 430 | 431 | # increase total number of bonds by transferring the strength of a 432 | # triple C-C bond to two neighboring bonds, if the involved atoms 433 | # are not yet saturated 434 | if (np.sum(N_mat) < len(N_idcs) * 3 or np.sum(C_mat) < len(C_idcs) * 4): 435 | for cur in NC_idcs: 436 | type = self.numbers[cur] 437 | if sum(con_mat[cur]) >= self.type_infos[type]['n_bonds']: 438 | continue 439 | CC_nbors = self._get_CN_neighbors(cur, order_1=1, order_2=3, 440 | only_CC=True, include_CCC=True) 441 | for nbor_1, nbor_2 in CC_nbors: 442 | nbor_2_nbors = np.where(con_mat[nbor_2] == 1)[0] 443 | for nbor_2_nbor in nbor_2_nbors: 444 | nbor_2_nbor_type = self.numbers[nbor_2_nbor] 445 | if (np.sum(con_mat[nbor_2_nbor]) < 446 | self.type_infos[nbor_2_nbor_type]['n_bonds']): 447 | con_mat = increase_bond(con_mat, cur, nbor_1) 448 | con_mat = increase_bond(con_mat, nbor_2, nbor_2_nbor) 449 | con_mat = decrease_bond(con_mat, nbor_1, nbor_2) 450 | self._connectivity = con_mat 451 | 452 | # increase bond strength between two undercharged neighbors C-N, 453 | # C-C or N-N (e.g HN-CH2 -> HN=CH2, starting from those atoms with least 454 | # available neighbors if there are multiple undercharged neighbors) 455 | undercharged_pairs = True 456 | while (undercharged_pairs): 457 | NC_charges = np.sum(con_mat[NC_idcs], axis=1) 458 | undercharged = NC_idcs[np.where(NC_charges < NC_valences)[0]] 459 | partial_con_mat = con_mat[undercharged][:, undercharged] 460 | # if non of the undercharged atoms are neighbors, stop 461 | if np.sum(partial_con_mat) == 0: 462 | break 463 | # sort by number of undercharged neighbors 464 | n_nbors = np.sum(partial_con_mat > 0, axis=0) 465 | # mask indices with zero undercharged neighbors to ignore them when sorting 466 | n_nbors[np.where(n_nbors == 0)[0]] = 1000 467 | cur = np.argmin(n_nbors) 468 | cur_nbor = np.where(partial_con_mat[cur] > 0)[0][0] 469 | con_mat = increase_bond(con_mat, undercharged[cur], undercharged[cur_nbor]) 470 | self._connectivity = con_mat 471 | 472 | # if the molecule still is not valid, try to flip double bonds if an atom 473 | # forms a double bond and has at least one other neighbor that has too few bonds 474 | # (e.g. 
C-N=C -> C=N-C) and repeat above heuristics with a recursive call of 475 | # this function 476 | if not recursive_call and \ 477 | not np.all(np.sum(con_mat[NC_idcs], axis=1) == NC_valences): 478 | changed = False 479 | candidates = np.where(np.any(con_mat[NC_idcs][:, NC_idcs] == 2, axis=0))[0] 480 | for cand in NC_idcs[candidates]: 481 | if np.sum(con_mat[cand, NC_idcs] == 2) == 0: 482 | continue 483 | NC_charges = np.sum(con_mat[NC_idcs], axis=1) 484 | undercharged = NC_charges < NC_valences 485 | uc_neighbors = np.logical_and(con_mat[cand, NC_idcs] == 1, undercharged) 486 | if np.any(uc_neighbors): 487 | uc_neighbor = NC_idcs[np.where(uc_neighbors)[0][0]] 488 | oc_neighbor = NC_idcs[ 489 | np.where(con_mat[cand, NC_idcs] == 2)[0][0]] 490 | con_mat = increase_bond(con_mat, cand, uc_neighbor) 491 | con_mat = decrease_bond(con_mat, cand, oc_neighbor) 492 | self._connectivity = con_mat 493 | changed = True 494 | if changed: 495 | self._connectivity = self.get_fixed_connectivity( 496 | recursive_call=True) 497 | 498 | # store that fixing the connectivity matrix has already been attempted 499 | if not recursive_call: 500 | self._fixed_connectivity = True 501 | if np.any(old_mat != self._connectivity): 502 | # update bond orders in underlying OBMol object (where they changed) 503 | self._update_bond_orders(np.where(old_mat != self._connectivity)) 504 | 505 | return self._connectivity 506 | 507 | def _get_valences(self): 508 | ''' 509 | Retrieve the valency constraints of all atoms in the molecule. 510 | 511 | Returns: 512 | numpy.ndarray: valency constraints (one per atom) 513 | ''' 514 | valence = [] 515 | for atom in self.numbers: 516 | valence += [self.type_infos[atom]['n_bonds']] 517 | return np.array(valence) 518 | 519 | def _get_CN_neighbors(self, idx, order_1=1, order_2=1, 520 | only_CC=False, include_CCC=False): 521 | ''' 522 | For a focus atom of type K returns indices of atoms C (carbon) and N (nitrogen) 523 | on two-step paths of the form K-C-N (and K-C-C only if K is nitrogen since one 524 | atom needs to be nitrogen, except if include_CCC is set True). The required 525 | bond order on the path can be constraint using the parameters order_1 and 526 | order_2. 527 | 528 | Args: 529 | idx (int): the index of the focus atom from which paths are examined 530 | order_1 (int, optional): the minimum bond order that neighbors must share 531 | with the focus atom in order to be added to the results (allows to 532 | constrain the paths to e.g. K=C-N or K=C-C with order_1=2, default: 1) 533 | order_2 (int, optional): the minimum bond order that second degree 534 | neighbors must share with the first degree neighbors in order to be 535 | added to the results (allows to constrain the paths to e.g. K-C=N or 536 | K-C=C with order_2=2, default: 1) 537 | only_CC (bool, optional): include only K-C-C paths (i.e. 
ignore K-C-N,
538 |                 default: False)
539 |             include_CCC (bool, optional): also include K-C-C paths if K is carbon
540 |                 (default: False)
541 | 
542 |         Returns:
543 |             list of lists: list1[i] contains an index of a direct neighbor of the
544 |             focus atom and list2[i] contains the index of a second neighbor on the
545 |             i-th identified two-step path
546 |         '''
547 |         con_mat = self.get_connectivity()
548 |         nbors = con_mat[idx] >= order_1
549 |         C_nbors = np.where(np.logical_and(self.numbers == 6, nbors))[0]
550 |         type = self.numbers[idx]
551 |         # mask types to exclude idx from neighborhood
552 |         _numbers = self.numbers.copy()
553 |         _numbers[idx] = 0
554 |         CN_nbors = []
555 |         CC_nbors = []
556 |         if not only_CC:
557 |             CN_nbors = np.where(np.logical_and(_numbers == 7,
558 |                                                con_mat[C_nbors] >= order_2))
559 |             CN_nbors = [(C_nbors[CN_nbors[0][i]], CN_nbors[1][i])
560 |                         for i in range(len(CN_nbors[0]))]
561 |         if type == 7 or include_CCC:  # for N atoms, also add C-C neighbors
562 |             CC_nbors = np.where(np.logical_and(_numbers == 6,
563 |                                                con_mat[C_nbors] >= order_2))
564 |             CC_nbors = [(C_nbors[CC_nbors[0][i]], CC_nbors[1][i])
565 |                         for i in range(len(CC_nbors[0]))]
566 |         return CN_nbors + CC_nbors
567 | 
568 |     def _get_neighbors(self, idx, types=None, strength=1):
569 |         '''
570 |         Retrieve the indices of neighbors of an atom.
571 | 
572 |         Args:
573 |             idx (int): index of the atom
574 |             types (list of int, optional): restrict the returned neighbors to
575 |                 contain only atoms of the specified types (set None to apply no type
576 |                 filter, default: None)
577 |             strength (int, optional): restrict the returned neighbors to contain
578 |                 only atoms with a certain minimal bond order to the atom at idx
579 |                 (default: 1)
580 | 
581 |         Returns:
582 |             list of int: indices of all neighbors that meet the requirements
583 |         '''
584 |         con_mat = self.get_connectivity()
585 |         neighbors = con_mat[idx] >= strength
586 |         if types is not None:
587 |             type_arr = np.zeros(len(neighbors)).astype(bool)
588 |             for type in types:
589 |                 type_arr = np.logical_or(type_arr, self.numbers == type)
590 |             return np.where(np.logical_and(neighbors, type_arr))[0]
591 |         return np.where(neighbors)[0]  # types is None: apply no type filter
592 |     def get_bond_stats(self):
593 |         '''
594 |         Retrieve the bond and ring count of the molecule. The bond count is
595 |         calculated for every pair of types (e.g. C1N are all single bonds between
596 |         carbon and nitrogen atoms in the molecule, C2N are all double bonds between
597 |         such atoms etc.). The ring count is provided for rings from size 3 to 8 (R3,
598 |         R4, ..., R8) and for rings greater than size eight (R>8).
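As an editorial illustration of the key naming described above (a hypothetical call: `mol` stands for a Molecule built from pyridine, the counts assume one Kekulé structure of its aromatic ring, and zero-count pairs that the method also reports are left out):

    # pyridine: five carbons, one nitrogen, one six-membered ring
    mol.get_bond_stats()
    # -> {'H1C': 5, 'C1C': 2, 'C2C': 2, 'C1N': 1, 'C2N': 1, 'R6': 1, ...}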
599 | 600 | Returns: 601 | dict (str->int): bond and ring counts 602 | ''' 603 | if self._bond_stats is None: 604 | 605 | # 1st analyze bonds 606 | unique_types = np.sort(list(self._unique_numbers)) 607 | # get connectivity and read bonds from matrix 608 | con_mat = self.get_connectivity() 609 | d = {} 610 | for i, type1 in enumerate(unique_types): 611 | row_idcs = self._get_row_idcs(type1) 612 | n_bonds1 = self.type_infos[type1]['n_bonds'] 613 | for type2 in unique_types[i:]: 614 | col_idcs = self._get_row_idcs(type2) 615 | n_bonds2 = self.type_infos[type2]['n_bonds'] 616 | max_bond_strength = min(n_bonds1, n_bonds2) 617 | if n_bonds1 == n_bonds2: # exclude small trivial molecules 618 | max_bond_strength -= 1 619 | for n in range(1, max_bond_strength + 1): 620 | id = self.type_infos[type1]['name'] + str(n) + \ 621 | self.type_infos[type2]['name'] 622 | d[id] = np.sum(con_mat[row_idcs][:, col_idcs] == n) 623 | if type1 == type2: 624 | d[id] = int(d[id]/2) # remove twice counted bonds 625 | 626 | # 2nd analyze rings 627 | ring_counts = self.get_ring_counts() 628 | if len(ring_counts) > 0: 629 | ring_counts = np.bincount(np.array(ring_counts)) 630 | n_bigger_8 = 0 631 | for i in np.nonzero(ring_counts)[0]: 632 | if i < 9: 633 | d[f'R{i}'] = ring_counts[i] 634 | else: 635 | n_bigger_8 += ring_counts[i] 636 | if n_bigger_8 > 0: 637 | d[f'R>8'] = n_bigger_8 638 | self._bond_stats = d 639 | 640 | return self._bond_stats 641 | 642 | def _get_row_idcs(self, type): 643 | ''' 644 | Retrieve the indices of all atoms in the molecule corresponding to a selected 645 | type. 646 | 647 | Args: 648 | type (int): the atom type (atomic number, e.g. 6 for carbon) 649 | 650 | Returns: 651 | list of int: indices of all atoms with the selected type 652 | ''' 653 | if type not in self._row_indices: 654 | self._row_indices[type] = np.where(self.numbers == type)[0] 655 | return self._row_indices[type] 656 | 657 | 658 | class ConnectivityCompressor(): 659 | ''' 660 | Utility class that provides methods to compress and decompress connectivity 661 | matrices. 662 | ''' 663 | 664 | def __init__(self): 665 | pass 666 | 667 | def compress(self, connectivity_matrix): 668 | ''' 669 | Compresses a single connectivity matrix. 670 | 671 | Args: 672 | connectivity_matrix (numpy.ndarray): array (n_atoms x n_atoms) 673 | containing the bond orders of bonds between atoms of a molecule 674 | 675 | Returns: 676 | dict (str/int->int): the length of the non-redundant connectivity 677 | matrix (list with upper triangular part) and the indices of that list for 678 | bond orders > 0 679 | ''' 680 | smaller = squareform(connectivity_matrix) # get list of upper triangular part 681 | d = {'n_entries': len(smaller)} # store length of list 682 | for i in np.unique(smaller): # store indices per bond order > 0 683 | if i > 0: 684 | d[float(i)] = np.where(smaller == i)[0] 685 | return d 686 | 687 | def decompress(self, idcs_dict): 688 | ''' 689 | Retrieve the full (n_atoms x n_atoms) connectivity matrix from compressed 690 | format. 
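As an editorial illustration, a minimal round-trip sketch of this compressed format (the 3-atom connectivity matrix below is made up):

    import numpy as np

    compressor = ConnectivityCompressor()
    # toy molecule with a double bond between atoms 0-1 and a single bond between 1-2
    con_mat = np.array([[0., 2., 0.],
                        [2., 0., 1.],
                        [0., 1., 0.]])
    compressed = compressor.compress(con_mat)
    # {'n_entries': 3, 1.0: array([2]), 2.0: array([0])}
    restored = compressor.decompress(compressed)
    assert np.array_equal(restored, con_mat)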
691 | 692 | Args: 693 | idcs_dict (dict str/int->int): compressed connectivity matrix 694 | (obtained with the compress method) 695 | 696 | Returns: 697 | numpy.ndarray: full connectivity matrix as an array of shape (n_atoms x 698 | n_atoms) 699 | ''' 700 | n_entries = idcs_dict['n_entries'] 701 | con_mat = np.zeros(n_entries) 702 | for i in idcs_dict: 703 | if i != 'n_entries': 704 | con_mat[idcs_dict[i]] = float(i) 705 | return squareform(con_mat) 706 | 707 | def compress_batch(self, connectivity_batch): 708 | ''' 709 | Compress a batch of connectivity matrices. 710 | 711 | Args: 712 | connectivity_batch (list of numpy.ndarray): list of connectivity matrices 713 | 714 | Returns: 715 | list of dict: batch of compressed connectivity matrices (see compress) 716 | ''' 717 | dict_list = [] 718 | for matrix in connectivity_batch: 719 | dict_list += [self.compress(matrix)] 720 | return dict_list 721 | 722 | def decompress_batch(self, idcs_dict_batch): 723 | ''' 724 | Retrieve a list of full connectivity matrices from a batch of compressed 725 | connectivity matrices. 726 | 727 | Args: 728 | idcs_dict_batch (list of dict): list with compressed connectivity 729 | matrices 730 | 731 | Return: 732 | list numpy.ndarray: batch of full connectivity matrices (see decompress) 733 | ''' 734 | matrix_list = [] 735 | for idcs_dict in idcs_dict_batch: 736 | matrix_list += [self.decompress(idcs_dict)] 737 | return matrix_list 738 | 739 | 740 | class IndexProvider(): 741 | ''' 742 | Class which allows to filter a large set of molecules for desired structures 743 | according to provided statistics. The filtering is done using a selection string 744 | of the general format 'Statistics_nameDelimiterOperatorTarget_value' 745 | (e.g. 'C,>8' to filter for all molecules with more than eight carbon atoms where 746 | 'C' is the statistic counting the number of carbon atoms in a molecule, ',' is the 747 | delimiter, '>' is the operator, and '8' is the target value). 748 | 749 | Args: 750 | statistics (numpy.ndarray): 751 | statistics of all molecules where columns correspond to molecules and rows 752 | correspond to available statistics (n_statistics x n_molecules) 753 | row_headlines (numpy.ndarray): 754 | the names of the statistics stored in each row (e.g. 'F' for the number of 755 | fluorine atoms or 'R5' for the number of rings of size 5) 756 | default_filter (str, optional): 757 | the default behaviour of the filter if no operator and target value are 758 | given (e.g. 
filtering for 'F' will give all molecules with at least 1 759 | fluorine atom if default_filter='>0' or all molecules with exactly 2 760 | fluorine atoms if default_filter='==2', default: '>0') 761 | delimiter (str, optional): 762 | the delimiter used to separate names of statistics from the operator and 763 | target value in the selection strings (default: ',') 764 | ''' 765 | 766 | # dictionary mapping strings of available operators to corresponding function: 767 | op_dict = {'<': operator.lt, 768 | '<=': operator.le, 769 | '==': operator.eq, 770 | '=': operator.eq, 771 | '!=': operator.ne, 772 | '>': operator.gt, 773 | '>=': operator.ge} 774 | 775 | rel_re = re.compile('<=|<|={1,2}|!=|>=|>') # regular expression for operators 776 | num_re = re.compile('[\-]*[0-9]+[.]*[0-9]*') # regular expression for target values 777 | 778 | def __init__(self, statistics, row_headlines, default_filter='>0', delimiter=','): 779 | self.statistics = np.array(statistics) 780 | self.headlines = list(row_headlines) 781 | self.default_relation = self.rel_re.search(default_filter).group(0) 782 | self.default_number = float(self.num_re.search(default_filter).group(0)) 783 | self.delimiter = delimiter 784 | 785 | def get_selected(self, selection_str, idcs=None): 786 | ''' 787 | Retrieve the indices of all molecules which fulfill the selection criteria. 788 | The selection string is of the general format 789 | 'Statistics_nameDelimiterOperatorTarget_value' (e.g. 'C,>8' to filter for all 790 | molecules with more than eight carbon atoms where 'C' is the statistic counting 791 | the number of carbon atoms in a molecule, ',' is the delimiter, '>' is the 792 | operator, and '8' is the target value). 793 | 794 | The following operators are available: 795 | '<' 796 | '<=' 797 | '==' 798 | '!=' 799 | '>=' 800 | '>' 801 | 802 | The target value can be any positive or negative integer or float value. 803 | 804 | Multiple statistics can be summed using '+' (e.g. 'F+N,=0' gives all 805 | molecules that have no fluorine and no nitrogen atoms). 806 | 807 | Multiple filters can be concatenated using '&' (e.g. 'H,>8&C,=5' gives all 808 | molecules that have more than 8 hydrogen atoms and exactly 5 carbon atoms). 
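As an editorial illustration, a hedged usage sketch of this selection syntax with made-up statistics for four molecules (the headline names follow the conventions described above):

    import numpy as np

    statistics = np.array([[6, 9, 5, 5],     # 'C':  number of carbon atoms
                           [0, 1, 0, 0],     # 'F':  number of fluorine atoms
                           [10, 9, 8, 12],   # 'H':  number of hydrogen atoms
                           [1, 0, 2, 0]])    # 'R5': number of rings of size 5
    provider = IndexProvider(statistics, row_headlines=['C', 'F', 'H', 'R5'])

    provider.get_selected('C,>8')        # -> array([1]): more than eight carbons
    provider.get_selected('F+R5,=0')     # -> array([3]): no fluorine and no 5-ring
    provider.get_selected('H,>8&C,=5')   # -> array([3]): >8 hydrogens, exactly 5 carbons
    provider.get_selected('R5')          # -> array([0, 2]): default filter '>0'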
809 | 
810 |         Args:
811 |             selection_str (str): string describing the criterion(s) for filtering (built
812 |                 as described above)
813 |             idcs (numpy.ndarray, optional): if provided, only this subset of all
814 |                 molecules is filtered for structures fulfilling the selection criteria
815 | 
816 |         Returns:
817 |             list of int: indices of all the molecules in the dataset that fulfill the
818 |             selection criterion(s)
819 |         '''
820 | 
821 |         delimiter = self.delimiter
822 |         if idcs is None:
823 |             idcs = np.arange(len(self.statistics[0]))  # take all to begin with
824 |         criterions = selection_str.split('&')  # split criteria
825 |         for criterion in criterions:
826 |             rel_strs = criterion.split(delimiter)
827 | 
828 |             # add multiple statistics if specified
829 |             heads = rel_strs[0].split('+')
830 |             statistics = self.statistics[self.headlines.index(heads[0])][idcs]
831 |             for head in heads[1:]:
832 |                 statistics += self.statistics[self.headlines.index(head)][idcs]
833 | 
834 |             if len(rel_strs) == 1:
835 |                 relation = self.op_dict[self.default_relation](
836 |                     statistics, self.default_number)
837 |             elif len(rel_strs) == 2:
838 |                 rel = self.rel_re.search(rel_strs[1]).group(0)
839 |                 num = float(self.num_re.search(rel_strs[1]).group(0))
840 |                 relation = self.op_dict[rel](statistics, num)
841 |             new_idcs = np.where(relation)[0]
842 |             idcs = idcs[new_idcs]
843 | 
844 |         return idcs
845 | 
846 | 
847 | class ProcessQ(Process):
848 |     '''
849 |     multiprocessing.Process subclass that runs a provided function using provided
850 |     (keyword) arguments and puts the result into a provided multiprocessing.Queue
851 |     object (such that the result of the function can easily be obtained by the host
852 |     process).
853 | 
854 |     Args:
855 |         queue (multiprocessing.Queue): the queue into which the results of running
856 |             the target function will be put (the object in the queue will be a tuple
857 |             containing the provided name as first entry and the function return as
858 |             second entry).
859 |         name (str): name of the object (is returned as first value in the tuple put
860 |             into the queue).
861 |         target (callable object): the function that is executed in the process's run
862 |             method
863 |         args (list of any): sequential arguments target is called with
864 |         kwargs (dict (str->any)): keyword arguments target is called with
865 |     '''
866 | 
867 |     def __init__(self, queue, name=None, target=None, args=(), kwargs={}):
868 |         super(ProcessQ, self).__init__(None, target, name, args, kwargs)
869 |         self._name = name
870 |         self._q = queue
871 |         self._target = target
872 |         self._args = args
873 |         self._kwargs = kwargs
874 | 
875 |     def run(self):
876 |         '''
877 |         Method representing the process's activity.
878 | 
879 |         Invokes the callable object passed as the target argument, if any, with
880 |         sequential and keyword arguments taken from the args and kwargs arguments,
881 |         respectively. Puts the string passed as name argument and the returned result
882 |         of the callable object into the queue as (name, result).
883 |         '''
884 |         if self._target is not None:
885 |             res = (self.name, self._target(*self._args, **self._kwargs))
886 |             self._q.put(res)
887 | 
888 | 
--------------------------------------------------------------------------------
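As an editorial illustration, a minimal, hedged usage sketch for ProcessQ (the worker function, names, and data are invented; the import assumes the classes above live in utility_classes.py, as the repository layout suggests):

    from multiprocessing import Queue

    from utility_classes import ProcessQ

    def count_heavy_atoms(numbers):
        # hypothetical worker: count all non-hydrogen atoms
        return sum(1 for n in numbers if n != 1)

    if __name__ == '__main__':
        queue = Queue()
        jobs = {'mol_0': [6, 6, 8, 1, 1, 1], 'mol_1': [7, 6, 1, 1]}
        processes = [ProcessQ(queue, name=name, target=count_heavy_atoms, args=(numbers,))
                     for name, numbers in jobs.items()]
        for p in processes:
            p.start()
        # each finished worker puts (name, result) into the queue
        results = dict(queue.get() for _ in processes)   # e.g. {'mol_0': 3, 'mol_1': 2}
        for p in processes:
            p.join()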