├── .gitignore
├── LICENSE
├── README.md
├── conditioning_specifications
│   ├── 1_polarizability.json
│   ├── 2_fingerprint.json
│   ├── 3_gap_comp.json
│   ├── 4_comp_relenergy.json
│   └── 5_gap_relenergy.json
├── display_molecules.py
├── filter_generated.py
├── gschnet_cond_script.py
├── images
│   └── concept_results_scheme.png
├── nn_classes.py
├── preprocess_dataset.py
├── published_data
│   ├── README.md
│   └── data_base_help.py
├── qm9_data.py
├── splits
│   ├── 1_polarizability_split.npz
│   ├── 2_fingerprint_split.npz
│   ├── 3_gap_comp_split.npz
│   ├── 4_comp_relenergy_split.npz
│   ├── 5_gap_relenergy_split.npz
│   └── qm9_invalid.txt
├── utility_classes.py
└── utility_functions.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | .idea
6 | *.DS_Store
7 |
8 | # test data
9 | src/sacred_scripts/data/*
10 | src/sacred_scripts/models/*
11 | src/sacred_scripts/experiments/*
12 | src/scripts/data/
13 | src/scripts/training/
14 |
15 | docs/tutorials/*.db
16 | docs/tutorials/*.xyz
17 | docs/tutorials/qm9tut
18 |
19 | # C extensions
20 | *.so
21 |
22 |
23 | # Distribution / packaging
24 | .Python
25 | env/
26 | build/
27 | develop-eggs/
28 | dist/
29 | downloads/
30 | eggs/
31 | .eggs/
32 | lib/
33 | lib64/
34 | parts/
35 | sdist/
36 | var/
37 | wheels/
38 | *.egg-info/
39 | .installed.cfg
40 | *.egg
41 |
42 | # PyInstaller
43 | # Usually these files are written by a python script from a template
44 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
45 | *.manifest
46 | *.spec
47 |
48 | # Installer logs
49 | pip-log.txt
50 | pip-delete-this-directory.txt
51 |
52 | # Unit test / coverage reports
53 | htmlcov/
54 | .tox/
55 | .coverage
56 | .coverage.*
57 | .cache
58 | nosetests.xml
59 | coverage.xml
60 | *.cover
61 | .hypothesis/
62 | .pytest_cache
63 |
64 | # Translations
65 | *.mo
66 | *.pot
67 |
68 | # Django stuff:
69 | *.log
70 | local_settings.py
71 |
72 | # Flask stuff:
73 | instance/
74 | .webassets-cache
75 |
76 | # Scrapy stuff:
77 | .scrapy
78 |
79 | # Sphinx documentation
80 | docs/_build/
81 |
82 | # PyBuilder
83 | target/
84 |
85 | # Jupyter Notebook
86 | .ipynb_checkpoints
87 |
88 | # pyenv
89 | .python-version
90 |
91 | # celery beat schedule file
92 | celerybeat-schedule
93 |
94 | # SageMath parsed files
95 | *.sage.py
96 |
97 | # dotenv
98 | .env
99 |
100 | # virtualenv
101 | .venv
102 | venv/
103 | ENV/
104 |
105 | # Spyder project settings
106 | .spyderproject
107 | .spyproject
108 |
109 | # Rope project settings
110 | .ropeproject
111 |
112 | # mkdocs documentation
113 | /site
114 |
115 | # mypy
116 | .mypy_cache/
117 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Niklas Gebauer
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ⚠️ **_Disclaimer: This repository is deprecated and only meant for reproduction of the published results of cG-SchNet on QM9. If you want to use custom data sets or build on top of our model, please refer to the [up-to-date implementation](https://github.com/atomistic-machine-learning/schnetpack-gschnet)._**
2 |
3 | # cG-SchNet - A Conditional Generative Neural Network for 3d Molecules
4 |
5 |
6 |
7 | Implementation of cG-SchNet - a conditional generative neural network for 3d molecular structures - accompanying the paper [_"Inverse design of 3d molecular structures with conditional generative neural networks"_](https://www.nature.com/articles/s41467-022-28526-y).
8 | If you are using cG-SchNet in your research, please cite the corresponding publication:
9 |
10 | N.W.A. Gebauer, M. Gastegger, S.S.P. Hessmann, K.-R. Müller, and K.T. Schütt.
11 | Inverse design of 3d molecular structures with conditional generative neural networks.
12 | Nature Communications 13, 973 (2022). https://doi.org/10.1038/s41467-022-28526-y
13 |
14 |
15 | @Article{gebauer2022inverse,
16 | author={Gebauer, Niklas W. A. and Gastegger, Michael and Hessmann, Stefaan S. P. and M{\"u}ller, Klaus-Robert and Sch{\"u}tt, Kristof T.},
17 | title={Inverse design of 3d molecular structures with conditional generative neural networks},
18 | journal={Nature Communications},
19 | year={2022},
20 | volume={13},
21 | number={1},
22 | pages={973},
23 | issn={2041-1723},
24 | doi={10.1038/s41467-022-28526-y},
25 | url={https://doi.org/10.1038/s41467-022-28526-y}
26 | }
27 |
28 |
29 | The code provided in this repository allows training cG-SchNet on the tasks presented in the paper using the QM9 data set, which consists of approximately 130k small molecules with up to nine heavy atoms (carbon, nitrogen, oxygen, and fluorine).
30 |
31 | We also provide links to the molecules generated with cG-SchNet for the paper as well as two pretrained models used therein. For details, please refer to the folder [_published_data_](https://github.com/atomistic-machine-learning/cG-SchNet/tree/main/published_data).
32 |
33 | ### Requirements
34 | - schnetpack 0.3
35 | - pytorch >= 1.2
36 | - python >= 3.7
37 | - ASE >= 3.17.0
38 | - Open Babel 2.4.1
39 | - rdkit >= 2019.03.4.0
40 |
41 | In the following, we describe the setup using Anaconda.
42 |
43 | The following commands will create a new conda environment called _"cgschnet"_ and install all dependencies (tested on Ubuntu 18.04 and 20.04):
44 |
45 | conda create -n cgschnet python=3.7 pytorch=1.5.1 torchvision cudatoolkit=10.2 ase=3.19.0 openbabel=2.4.1 rdkit=2019.09.2.0 -c pytorch -c openbabel -c defaults -c conda-forge
46 | conda activate cgschnet
47 | pip install 'schnetpack==0.3'
48 |
49 | Replace _"cudatoolkit=10.2"_ with _"cpuonly"_ if you do not want to utilize a GPU for training/generation. However, we strongly recommend using a GPU if available.
50 |
51 | To observe the training progress, install tensorboard:
52 |
53 | pip install tensorboard
54 |
55 | # Getting started
56 | Clone the repository into your folder of choice:
57 |
58 | git clone https://github.com/atomistic-machine-learning/cG-SchNet.git
59 |
60 |
61 | ### Training a model
62 | A model conditioned on the composition of molecules and their relative atomic energy with the same settings as described in the paper can be trained by running gschnet_cond_script.py with the following parameters:
63 |
64 | python ./cG-SchNet/gschnet_cond_script.py train gschnet ./data/ ./models/cgschnet/ --conditioning_json_path ./cG-SchNet/conditioning_specifications/4_comp_relenergy.json --split_path ./cG-SchNet/splits/4_comp_relenergy_split.npz --cuda
65 |
66 | The training data (QM9) is automatically downloaded and preprocessed if not present in ./data/ and the model will be stored in ./models/cgschnet/.
67 | We recommend training on a GPU, but you can remove _--cuda_ from the call to use the CPU instead. If your GPU has less than 16GB VRAM, you need to decrease the number of features (e.g. _--features 64_) or the depth of the network (e.g. _--interactions 6_).
68 | The conditioning network is specified in a json-file. We provide the settings used in experiments throughout the paper in ./cG-SchNet/conditioning_specifications. The exact data splits utilized in our trainings are provided in ./cG-SchNet/splits/. The conditioning network and corresponding split are numbered like the models in Supplementary Table S1, where we list the hyper-parameter settings. Simply replace the paths behind the _--conditioning_json_path_ and _--split_path_ arguments in order to train the desired model.
69 | If you want to train on a new data split, remove _--split_path_ from the call and instead use _--split x y_ to randomly create a new split with x training molecules, y validation molecules, and the remaining molecules in the test set (e.g. with x=50000 and y=5000).
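
The random split described above can be pictured with the following minimal sketch. It is not the code used by gschnet_cond_script.py; the function name `random_split`, the seed handling, and the total of 130000 molecules are assumptions made for illustration only.

```python
import random


def random_split(n_molecules, n_train, n_val, seed=None):
    """Randomly partition the indices 0..n_molecules-1 into train/val/test lists."""
    if n_train + n_val > n_molecules:
        raise ValueError("n_train + n_val exceeds the number of molecules")
    indices = list(range(n_molecules))
    random.Random(seed).shuffle(indices)  # seeded shuffle for reproducibility
    train_idx = sorted(indices[:n_train])
    val_idx = sorted(indices[n_train:n_train + n_val])
    test_idx = sorted(indices[n_train + n_val:])  # all remaining molecules
    return train_idx, val_idx, test_idx


# x=50000 and y=5000 as in the example above; 130000 is a hypothetical data set size
train_idx, val_idx, test_idx = random_split(130000, 50000, 5000, seed=0)
print(len(train_idx), len(val_idx), len(test_idx))
```

Every molecule lands in exactly one of the three sets, so the test set size is simply whatever remains after the training and validation molecules are drawn.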
70 |
71 | The training of the full model takes about 40 hours on an A100 GPU. For testing purposes, you can leave the training running for only a couple of epochs. An epoch should take 10-20 min, depending on your hardware. To observe the training progress, use TensorBoard:
72 |
73 | tensorboard --logdir=./models
74 |
75 |
76 | The logs will appear after the first epoch has completed.
77 |
78 | ### Generating molecules
79 | Running the script with the following arguments will generate 1000 molecules using the composition C7O2H10 and relative atomic energy of -0.1 as conditions using the trained model at ./models/cgschnet/ and store them in ./models/cgschnet/generated/generated.mol_dict:
80 |
81 | python ./cG-SchNet/gschnet_cond_script.py generate gschnet ./models/cgschnet/ 1000 --conditioning "composition 10 7 0 2 0; n_atoms 19; relative_atomic_energy -0.1" --cuda
82 |
83 | Remove _--cuda_ from the call if you want to run on the CPU. Add _--show_gen_ to display the molecules with ASE after generation. If you are running into problems due to small VRAM, decrease the size of mini-batches during generation (e.g. _--chunk_size 500_, default is 1000).
84 |
85 | Note that the conditions for sampling are provided as a string using the _--conditioning_ argument. You have to provide the name of the property (as given in the field _in\_key\_property_ in the conditioning specification json, without the leading underscore) followed by the target value(s). Each property the model was conditioned on, i.e. each one listed in the layers of the conditioning specification json, has to be provided. Otherwise, the target value will be set to zero, which is often an inappropriate value and thus may lead to invalid molecules. The composition of molecules is given as the number of atoms of type h, c, n, o, f (in that particular order). When conditioning on fingerprints, the value can either be an index (which is used to load the fingerprint of the corresponding molecule in the ```./data/qm9gen.db``` database) or a SMILES string.
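
As an illustration of the string format, the following sketch assembles the _--conditioning_ value for a model conditioned on composition and relative atomic energy. The helpers `composition_vector` and `conditioning_string` are hypothetical and not part of this repository; they only handle simple formulas without brackets.

```python
import re

# Order of atom types expected by the composition condition (h, c, n, o, f).
TYPE_ORDER = ("H", "C", "N", "O", "F")


def composition_vector(formula):
    """Count the atoms per type in a simple formula such as 'C7O2H10'."""
    counts = dict.fromkeys(TYPE_ORDER, 0)
    for symbol, number in re.findall(r"([A-Z][a-z]?)(\d*)", formula):
        if symbol not in counts:
            raise ValueError(f"unsupported atom type: {symbol}")
        counts[symbol] += int(number) if number else 1  # no digit means one atom
    return [counts[t] for t in TYPE_ORDER]


def conditioning_string(formula, relative_atomic_energy):
    """Build the string for --conditioning (hypothetical helper)."""
    comp = composition_vector(formula)
    return (f"composition {' '.join(map(str, comp))}; "
            f"n_atoms {sum(comp)}; "
            f"relative_atomic_energy {relative_atomic_energy}")


print(conditioning_string("C7O2H10", -0.1))
# prints: composition 10 7 0 2 0; n_atoms 19; relative_atomic_energy -0.1
```

The printed string matches the one used in the generation command above, including the n_atoms condition, which is the sum of the composition entries.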
86 |
87 | ### Filtering and analysis of generated molecules
88 | After generation, the generated molecules can be filtered for invalid and duplicate structures by running filter_generated.py:
89 |
90 | python ./cG-SchNet/filter_generated.py ./models/cgschnet/generated/generated.mol_dict --train_data_path ./data/qm9gen.db --model_path ./models/cgschnet
91 |
92 | The script will print its progress and the gathered results.
93 | The script checks the valency constraints (e.g. every hydrogen atom should have exactly one bond), the connectedness (i.e. all atoms in a molecule should be connected to each other via a path over bonds), and removes duplicates*. The remaining valid structures are stored in an sqlite database with ASE (at ./models/cgschnet/generated/generated_molecules.db) along with an .npz-file that records certain statistics (e.g. the number of rings of certain sizes, the number of single, double, and triple bonds, the index of the matching training/test data molecule etc. for each molecule).
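
The two structural checks can be sketched as follows. This is not the implementation in filter_generated.py: the bond list is assumed to be given already, the function names are hypothetical, and the valence table is an assumption (neutral atoms with standard valences).

```python
from collections import deque

# Assumed standard valences for the atom types occurring in QM9.
VALENCE = {"H": 1, "C": 4, "N": 3, "O": 2, "F": 1}


def valences_ok(symbols, bonds):
    """True if the bond orders at every atom sum to its assumed valence.

    bonds is a list of (i, j, order) tuples with atom indices i and j.
    """
    totals = [0] * len(symbols)
    for i, j, order in bonds:
        totals[i] += order
        totals[j] += order
    return all(total == VALENCE[s] for s, total in zip(symbols, totals))


def is_connected(n_atoms, bonds):
    """True if every atom can be reached from atom 0 via a path over bonds."""
    neighbors = [[] for _ in range(n_atoms)]
    for i, j, _ in bonds:
        neighbors[i].append(j)
        neighbors[j].append(i)
    seen = {0}
    queue = deque([0])
    while queue:  # breadth-first search starting at atom 0
        atom = queue.popleft()
        for nb in neighbors[atom]:
            if nb not in seen:
                seen.add(nb)
                queue.append(nb)
    return len(seen) == n_atoms


water = (["O", "H", "H"], [(0, 1, 1), (0, 2, 1)])
two_h2 = (["H", "H", "H", "H"], [(0, 1, 1), (2, 3, 1)])
print(valences_ok(*water), is_connected(3, water[1]))    # True True
print(valences_ok(*two_h2), is_connected(4, two_h2[1]))  # True False
```

The second example shows why both checks are needed: two separate H2 molecules satisfy the valence constraints but do not form a single connected structure.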
94 |
95 | In order to match the generated structures to training/test data, the QM9 data set is required. If it hasn't been downloaded before, e.g. because you are using a [pretrained model](https://github.com/atomistic-machine-learning/cG-SchNet/blob/main/published_data/README.md#pretrained-models) for generation, you can initialize the download by starting to train a dummy model for zero epochs and deleting it afterwards as follows:
96 |
97 | python ./cG-SchNet/gschnet_cond_script.py train gschnet ./data/ ./models/_dummy/ --split 1 1 --max_epochs 0
98 | rm -r ./models/_dummy
99 |
100 | *_Please note that, as described in the paper, we use molecular fingerprints and canonical smiles representations to identify duplicates which means that different spatial conformers corresponding to the same canonical smiles string are tagged as duplicates and removed in the process. Add '--filters valence disconnected' to the call in order to not remove but keep identified duplicates in the created database._
101 |
102 | ### Displaying generated molecules
103 | After filtering, all generated molecules stored in the sqlite database can be displayed with ASE as follows:
104 |
105 | python ./cG-SchNet/display_molecules.py --data_path ./models/cgschnet/generated/generated_molecules.db
106 |
107 | # How does it work?
108 |
109 | cG-SchNet is an autoregressive neural network. It builds 3d molecules by placing one atom after another in 3d space. To this end, the joint distribution of all atoms is factorized into single steps, where the position and type of the new atom depend on the preceding atoms (Figure a). The model also processes conditions, i.e. values of target properties, which enable it to learn a conditional distribution of molecular structures. This distribution allows targeted sampling of molecules that are highly likely to exhibit specified conditions (see e.g. the distribution of the polarizability of molecules generated with cG-SchNet using five different target values in Figure b). The type and absolute position of new atoms are sampled successively, where the probability of the positions is approximated from predicted pairwise distances to preceding atoms. In order to improve the accuracy of the approximation and steer the generation process, the network uses two auxiliary tokens, the focus and the origin. The new atom always has to be a neighbor of the focus and the origin marks the supposed center of mass of the final structure. A scheme explaining the generation procedure can be seen in Figure c. It uses 2d positional distributions for visualization purposes. For more details, please refer to the publication [_here_](https://www.nature.com/articles/s41467-022-28526-y).
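
The factorization described above can be written down roughly as follows. The notation is an assumption made for this sketch and is simplified: the focus and origin tokens are omitted, and the publication should be consulted for the exact formulation.

```latex
% Assumed notation: r_i and Z_i are the position and type of atom i,
% R_{<i} and Z_{<i} collect all preceding atoms, \Lambda denotes the conditions.
p(\mathbf{R}, \mathbf{Z} \mid \Lambda)
  = \prod_{i=1}^{n} p(\mathbf{r}_i, Z_i \mid \mathbf{R}_{<i}, \mathbf{Z}_{<i}, \Lambda)

% Each step samples the type first and then the position given the type:
p(\mathbf{r}_i, Z_i \mid \mathbf{R}_{<i}, \mathbf{Z}_{<i}, \Lambda)
  = p(Z_i \mid \mathbf{R}_{<i}, \mathbf{Z}_{<i}, \Lambda)\,
    p(\mathbf{r}_i \mid Z_i, \mathbf{R}_{<i}, \mathbf{Z}_{<i}, \Lambda)

% The position is approximated from pairwise distances
% d_{ij} = \lVert \mathbf{r}_i - \mathbf{r}_j \rVert, with normalization constant \alpha:
p(\mathbf{r}_i \mid Z_i, \mathbf{R}_{<i}, \mathbf{Z}_{<i}, \Lambda)
  \approx \frac{1}{\alpha} \prod_{j<i}
    p(d_{ij} \mid Z_i, \mathbf{R}_{<i}, \mathbf{Z}_{<i}, \Lambda)
```

Because every factor is conditioned on \Lambda, changing the target values changes the distribution at each placement step rather than filtering finished molecules afterwards.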
110 |
111 | 
112 |
--------------------------------------------------------------------------------
/conditioning_specifications/1_polarizability.json:
--------------------------------------------------------------------------------
1 | {
2 | "stack": {
3 | "in_key": "representation",
4 | "layers": {
5 | "isotropic_polarizability": {
6 | "in_key_property": "_isotropic_polarizability",
7 | "n_in": 15,
8 | "n_layers": 3,
9 | "n_neurons": 64,
10 | "n_out": 64,
11 | "start": 5,
12 | "stop": 16
13 | }
14 | },
15 | "mode": "stack",
16 | "n_global_cond_features": 128,
17 | "n_layers": 5,
18 | "n_neurons": 128,
19 | "out_key": "representation"
20 | }
21 | }
22 |
--------------------------------------------------------------------------------
/conditioning_specifications/2_fingerprint.json:
--------------------------------------------------------------------------------
1 | {
2 | "stack": {
3 | "in_key": "representation",
4 | "layers": {
5 | "fingerprint": {
6 | "n_in": 1024,
7 | "n_out": 128,
8 | "n_layers": 3,
9 | "in_key_fingerprint": "_fingerprint"
10 | }
11 | },
12 | "mode": "stack",
13 | "n_global_cond_features": 128,
14 | "n_layers": 5,
15 | "n_neurons": 128,
16 | "out_key": "representation"
17 | }
18 | }
19 |
--------------------------------------------------------------------------------
/conditioning_specifications/3_gap_comp.json:
--------------------------------------------------------------------------------
1 | {
2 | "stack": {
3 | "in_key": "representation",
4 | "layers": {
5 | "composition": {
6 | "embedding": {
7 | "embedding_dim": 16,
8 | "num_embeddings": 20
9 | },
10 | "in_key_composition": "_composition",
11 | "n_layers": 3,
12 | "n_neurons": 64,
13 | "n_out": 64,
14 | "n_types": 5,
15 | "skip_h": false,
16 | "type_weighting": "relative"
17 | },
18 | "gap": {
19 | "in_key_property": "_gap",
20 | "n_in": 5,
21 | "n_layers": 3,
22 | "n_neurons": 64,
23 | "n_out": 64,
24 | "start": 2,
25 | "stop": 11
26 | },
27 | "n_atoms": {
28 | "in_key_property": "_n_atoms",
29 | "n_in": 5,
30 | "n_layers": 3,
31 | "n_neurons": 64,
32 | "n_out": 64,
33 | "start": 0,
34 | "stop": 35
35 | }
36 | },
37 | "mode": "stack",
38 | "n_global_cond_features": 128,
39 | "n_layers": 5,
40 | "n_neurons": 128,
41 | "out_key": "representation"
42 | }
43 | }
44 |
--------------------------------------------------------------------------------
/conditioning_specifications/4_comp_relenergy.json:
--------------------------------------------------------------------------------
1 | {
2 | "stack": {
3 | "in_key": "representation",
4 | "layers": {
5 | "composition": {
6 | "embedding": {
7 | "embedding_dim": 16,
8 | "num_embeddings": 20
9 | },
10 | "in_key_composition": "_composition",
11 | "n_layers": 3,
12 | "n_neurons": 64,
13 | "n_out": 64,
14 | "n_types": 5,
15 | "skip_h": false,
16 | "type_weighting": "relative"
17 | },
18 | "relative_atomic_energy": {
19 | "in_key_property": "_relative_atomic_energy",
20 | "n_in": 5,
21 | "n_layers": 3,
22 | "n_neurons": 64,
23 | "n_out": 64,
24 | "start": -0.2,
25 | "stop": 0.2
26 | },
27 | "n_atoms": {
28 | "in_key_property": "_n_atoms",
29 | "n_in": 5,
30 | "n_layers": 3,
31 | "n_neurons": 64,
32 | "n_out": 64,
33 | "start": 0,
34 | "stop": 35
35 | }
36 | },
37 | "mode": "stack",
38 | "n_global_cond_features": 128,
39 | "n_layers": 5,
40 | "n_neurons": 128,
41 | "out_key": "representation"
42 | }
43 | }
44 |
--------------------------------------------------------------------------------
/conditioning_specifications/5_gap_relenergy.json:
--------------------------------------------------------------------------------
1 | {
2 | "stack": {
3 | "in_key": "representation",
4 | "layers": {
5 | "gap": {
6 | "in_key_property": "_gap",
7 | "n_in": 5,
8 | "n_layers": 3,
9 | "n_neurons": 64,
10 | "n_out": 64,
11 | "start": 2,
12 | "stop": 11
13 | },
14 | "relative_atomic_energy": {
15 | "in_key_property": "_relative_atomic_energy",
16 | "n_in": 5,
17 | "n_layers": 3,
18 | "n_neurons": 64,
19 | "n_out": 64,
20 | "start": -0.2,
21 | "stop": 0.2
22 | }
23 | },
24 | "mode": "stack",
25 | "n_global_cond_features": 128,
26 | "n_layers": 5,
27 | "n_neurons": 128,
28 | "out_key": "representation"
29 | }
30 | }
31 |
--------------------------------------------------------------------------------
/display_molecules.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import sys
3 | import os
4 | import subprocess
5 | import numpy as np
6 | import tempfile
7 |
8 | from ase.db import connect
9 | from ase.io import write
10 | from utility_classes import IndexProvider
11 |
12 |
13 | def get_parser():
14 | """ Setup parser for command line arguments """
15 | main_parser = argparse.ArgumentParser()
16 | main_parser.add_argument('--data_path', type=str, default=None,
17 | help='Path to database with filtered, generated molecules '
18 | '(.db format, needs to be provided if generated '
19 | 'molecules shall be displayed, default: %(default)s)')
20 | main_parser.add_argument('--train_data_path', type=str,
21 | help='Path to training data base (.db format, needs to be '
22 | 'provided if molecules from the training data set '
23 | 'shall be displayed, e.g. when using --train or '
24 | '--test, default: %(default)s)',
25 | default=None)
26 | main_parser.add_argument('--select', type=str, nargs='*',
27 | help='Selection strings that specify which molecules '
28 | 'shall be shown, if None all molecules from '
29 | 'data_path and/or train_data_path are shown, '
30 | 'providing multiple strings'
31 | ' will open multiple windows (one per string), '
32 | '(default: %(default)s). The selection string has '
33 | 'the general format "Property,OperatorTarget" (e.g. '
34 | '"C,>8" to filter for all molecules with more than '
35 | 'eight carbon atoms where "C" is the statistic '
36 | 'counting the number of carbon atoms in a molecule, '
37 | '">" is the operator, and "8" is the target value). '
38 | 'Multiple conditions can be combined to form one '
39 | 'selection string using "&" (e.g "C,>8&R5,>0" to '
40 | 'get all molecules with more than 8 carbon atoms '
41 | 'and at least 1 ring of size 5). Prepending '
42 | '"training" to the selection string will filter and '
43 | 'display molecules from the training data base '
44 | 'instead of generated molecules (e.g. "training C,>8"'
45 | '). An overview of the available properties for '
46 | 'molecules generated with G-SchNet trained on QM9 can'
47 | ' be found in the README.md.',
48 | default=None)
49 | main_parser.add_argument('--print_indices',
50 | help='For each provided selection print out the indices '
51 | 'of molecules that match the respective selection '
52 | 'string',
53 | action='store_true')
54 | main_parser.add_argument('--export_to_dir', type=str,
55 | help='Optionally, provide a path to a directory to which '
56 | 'indices of molecules matching the corresponding '
57 | 'query shall be written (one .npy-file (numpy) per '
58 | 'selection string, if None is provided, the '
59 | 'indices will not be exported, default: %(default)s)',
60 | default=None)
61 | main_parser.add_argument('--train',
62 | help='Display all generated molecules that match '
63 | 'structures used during training and the '
64 | 'corresponding molecules from the training data set.',
65 | action='store_true')
66 | main_parser.add_argument('--test',
67 | help='Display all generated molecules that match '
68 | 'held out test data structures and the '
69 | 'corresponding molecules from the training data set.',
70 | action='store_true')
71 | main_parser.add_argument('--novel',
72 | help='Display all generated molecules that match neither '
73 | 'structures used during training nor those held out '
74 | 'as test data.',
75 | action='store_true')
76 | main_parser.add_argument('--block',
77 | help='Make the call to ASE GUI blocking (such that the '
78 | 'script stops until the GUI window is closed).',
79 | action='store_true')
80 |
81 | return main_parser
82 |
83 |
84 | def view_ase(mols, name, block=False):
85 | '''
86 | Display a list of molecules using the ASE GUI.
87 |
88 | Args:
89 | mols (list of ase.Atoms): molecules as ase.Atoms objects
90 | name (str): the name that shall be displayed in the windows top bar
91 | block (bool, optional): whether the call to ase gui shall block or not block
92 | the script (default: False)
93 | '''
94 | dir = tempfile.mkdtemp('', 'generated_molecules_') # make temporary directory
95 | filename = os.path.join(dir, name) # path of temporary file
96 | format = 'traj' # use trajectory format for temporary file
97 | command = sys.executable + ' -m ase gui -b' # command to execute ase gui viewer
98 | write(filename, mols, format=format) # write molecules to temporary file
99 | # show molecules in ase gui and remove temporary file and directory afterwards
100 | if block:
101 | subprocess.call(command.split() + [filename])
102 | os.remove(filename)
103 | os.rmdir(dir)
104 | else:
105 | subprocess.Popen(command.split() + [filename])
106 | subprocess.Popen(['sleep 60; rm "{0}"'.format(filename)], shell=True)
107 | subprocess.Popen(['sleep 65; rmdir "{0}"'.format(dir)], shell=True)
108 |
109 |
110 | def print_indices(idcs, name='', per_line=10):
111 | '''
112 | Prints provided indices in a clean formatting.
113 |
114 | Args:
115 | idcs (list of int): indices that shall be printed
116 | name (str, optional): the selection string that was used to obtain the indices
117 | per_line (int, optional): the number of indices that are printed per line (
118 | default: 10)
119 | '''
120 | biggest_count = len(str(len(idcs)))
121 | format_count = f'>{biggest_count}d'
122 | biggest_original_idx = len(str(max(idcs)))
123 | new_line = '\n'
124 | format = f'>{biggest_original_idx}d'
125 | str_idcs = [f'{i+1:{format_count}}: {j:{format}} ' +
126 | (new_line if (i+1) % per_line == 0 else '')
127 | for i, j in enumerate(idcs)]
128 | print(f'\nAll {len(idcs)} indices for selection {name}:')
129 | print(''.join(str_idcs))
130 |
131 |
132 | if __name__ == '__main__':
133 | parser = get_parser()
134 | args = parser.parse_args()
135 |
136 | # make sure that at least one path was provided
137 | if args.data_path is None and args.train_data_path is None:
138 | print(f'\nPlease specify --data_path to display generated molecules or '
139 | f'--train_data_path to display training molecules (or both)!')
140 | sys.exit(0)
141 |
142 | # sort queries into those concerning generated structures and those concerning
143 | # training data molecules
144 | train_selections = []
145 | gen_selections = []
146 | if args.select is not None:
147 | for selection in args.select:
148 | if selection.startswith('training'):
149 | # put queries concerning training structures aside for later
150 | train_selections += [selection]
151 | else:
152 | gen_selections += [selection]
153 |
154 | # make sure that the required paths were provided
155 | if args.train or args.test:
156 | if args.data_path is None:
157 | print('\nYou need to specify --data_path (and optionally '
158 | '--train_data_path) if using --train or --test!')
159 | sys.exit(0)
160 | if args.novel:
161 | if args.data_path is None:
162 | print('\nYou need to specify --data_path if you want to display novel '
163 | 'molecules!')
164 | sys.exit(0)
165 | if len(gen_selections) > 0:
166 | if args.data_path is None:
167 | print(f'\nYou need to specify --data_path to process the selections '
168 | f'{gen_selections}!')
169 | sys.exit(0)
170 | if len(train_selections) > 0:
171 | if args.train_data_path is None:
172 | print(f'\nYou need to specify --train_data_path to process the selections '
173 | f'{train_selections}!')
174 | sys.exit(0)
175 |
176 | # check if statistics files are needed
177 | need_gen_stats = (len(gen_selections) > 0) or args.train or args.test or args.novel
178 | need_train_stats = (len(train_selections) > 0) or args.train or args.test
179 |
180 | # check if there is a database with generated molecules at the provided path
181 | # and load accompanying statistics file
182 | if args.data_path is not None:
183 | if not os.path.isfile(args.data_path):
184 | print(f'\nThe specified data path ({args.data_path}) is not a file! Please '
185 | f'specify a different data path.')
186 | raise FileNotFoundError
187 | elif need_gen_stats:
188 | stats_path = os.path.splitext(args.data_path)[0] + f'_statistics.npz'
189 | if not os.path.isfile(stats_path):
190 | print(f'\nCannot find statistics file belonging to {args.data_path} ('
191 | f'expected it at {stats_path}). Please make sure that the file '
192 | f'exists.')
193 | raise FileNotFoundError
194 | else:
195 | stats_dict = np.load(stats_path)
196 | index_provider = IndexProvider(stats_dict['stats'],
197 | stats_dict['stat_heads'])
198 |
199 | # check if there is a database with training molecules at the provided path
200 | # and load accompanying statistics file
201 | if args.train_data_path is not None:
202 | if not os.path.isfile(args.train_data_path):
203 | print(f'\nThe specified training data path ({args.train_data_path}) is '
204 | f'not a file! Please specify --train_data_path correctly.')
205 | raise FileNotFoundError
206 | elif need_train_stats:
207 | stats_path = os.path.splitext(args.train_data_path)[0] + f'_statistics.npz'
208 | if not os.path.isfile(stats_path) and len(train_selections) > 0:
209 | print(f'\nCannot find statistics file belonging to '
210 | f'{args.train_data_path} (expected it at {stats_path}). Please '
211 | f'make sure that the file exists.')
212 | raise FileNotFoundError
213 | elif os.path.isfile(stats_path):
214 | train_stats_dict = np.load(stats_path)
215 | train_index_provider = IndexProvider(train_stats_dict['stats'],
216 | train_stats_dict['stat_heads'])
217 |
218 | # create folder(s) for export of indices if necessary
219 | if args.export_to_dir is not None:
220 | if not os.path.isdir(args.export_to_dir):
221 | print(f'\nDirectory {args.export_to_dir} does not exist, creating '
222 | f'it to store indices of molecules matching the queries!')
223 | os.makedirs(args.export_to_dir)
224 | else:
225 | print(f'\nWill store indices of molecules matching the queries at '
226 | f'{args.export_to_dir}!')
227 |
228 | # display all generated molecules if desired
229 | if (len(gen_selections) == 0) and not (args.train or args.test or args.novel) and\
230 | args.data_path is not None:
231 | with connect(args.data_path) as con:
232 | _ats = [con.get(int(idx) + 1).toatoms() for idx in range(con.count())]
233 | view_ase(_ats, 'all generated molecules', args.block)
234 |
235 | # display generated molecules matching selection strings
236 | if len(gen_selections) > 0:
237 | for selection in gen_selections:
238 | # display queries concerning generated molecules
239 | idcs = index_provider.get_selected(selection)
240 | if len(idcs) == 0:
241 | print(f'\nNo molecules match selection {selection}!')
242 | continue
243 | with connect(args.data_path) as con:
244 | _ats = [con.get(int(idx) + 1).toatoms() for idx in idcs]
245 | if args.print_indices:
246 | print_indices(idcs, selection)
247 | view_ase(_ats, f'generated molecules ({selection})', args.block)
248 | if args.export_to_dir is not None:
249 | np.save(os.path.join(args.export_to_dir, selection), idcs)
250 |
251 | # display all training molecules if desired
252 | if (len(train_selections) == 0) and not (args.train or args.test) and \
253 | args.train_data_path is not None:
254 | with connect(args.train_data_path) as con:
255 | _ats = [con.get(int(idx) + 1).toatoms() for idx in range(con.count())]
256 | view_ase(_ats, 'all molecules in the training data set', args.block)
257 |
258 | # display training molecules matching selection strings
259 | if len(train_selections) > 0:
260 | # display training molecules that match the selection strings
261 | for selection in train_selections:
262 | _selection = selection.split()[1]
263 | stats_queries = []
264 | db_queries = []
265 | # sort into queries handled by looking into the statistics or the db
266 | for _sel_str in _selection.split('&'):
267 | prop = _sel_str.split(',')[0]
268 | if prop in train_stats_dict['stat_heads']:
269 | stats_queries += [_sel_str]
270 | elif len(prop.split('+')) > 0:
271 | found = True
272 | for p in prop.split('+'):
273 | if p not in train_stats_dict['stat_heads']:
274 | found = False
275 | break
276 | if found:
277 | stats_queries += [_sel_str]
278 | else:
279 | db_queries += [_sel_str]
280 | else:
281 | db_queries += [_sel_str]
282 | # process queries concerning the statistics
283 | if len(stats_queries) > 0:
284 | idcs = train_index_provider.get_selected('&'.join(stats_queries))
285 | else:
286 | idcs = range(connect(args.train_data_path).count())
287 | # process queries concerning the db entries
288 | if len(db_queries) > 0:
289 | with connect(args.train_data_path) as con:
290 | for query in db_queries:
291 | head, condition = query.split(',')
292 | if head not in con.get(1).data:
293 | print(f'Entry {head} not found for molecules in the '
294 | f'database, skipping query {query}.')
295 | continue
296 | else:
297 | op = train_index_provider.rel_re.search(condition).group(0)
298 | op = train_index_provider.op_dict[op] # extract operator
299 | num = float(train_index_provider.num_re.search(
300 | condition).group(0)) # extract numerical value
301 | remaining_idcs = []
302 | for idx in idcs:
303 | if op(con.get(int(idx)+1).data[head], num):
304 | remaining_idcs += [idx]
305 | idcs = remaining_idcs
306 | # extract molecules matching the query from db and display them
307 | if len(idcs) == 0:
308 | print(f'\nNo training molecules match selection {_selection}!')
309 | continue
310 | with connect(args.train_data_path) as con:
311 | _ats = [con.get(int(idx)+1).toatoms() for idx in idcs]
312 | if args.print_indices:
313 | print_indices(idcs, selection)
314 | view_ase(_ats, f'training data set molecules ({_selection})', args.block)
315 | if args.export_to_dir is not None:
316 | np.save(os.path.join(args.export_to_dir, selection), idcs)
317 |
318 | # display generated molecules that match structures used for training
319 | if args.train:
320 | idcs = index_provider.get_selected('known,>=1&known,<=2')
321 | if len(idcs) == 0:
322 | print(f'\nNo generated molecules found that match structures used '
323 | f'during training!')
324 | else:
325 | with connect(args.data_path) as con:
326 | _ats = [con.get(int(idx) + 1).toatoms() for idx in idcs]
327 | if args.print_indices:
328 | print_indices(idcs, 'generated train')
329 | view_ase(_ats, f'generated molecules (matching train structures)',
330 | args.block)
331 | if args.export_to_dir is not None:
332 | np.save(os.path.join(args.export_to_dir, 'generated train'), idcs)
333 | # display corresponding training structures
334 | if args.train_data_path is not None:
335 | _row_idx = list(stats_dict['stat_heads']).index('equals')
336 | t_idcs = stats_dict['stats'][_row_idx, idcs].astype(int)
337 | with connect(args.train_data_path) as con:
338 | _ats = [con.get(int(idx) + 1).toatoms() for idx in t_idcs]
339 | if args.print_indices:
340 | print_indices(t_idcs, 'reference train')
341 | view_ase(_ats, f'training molecules (train structures)', args.block)
342 | if args.export_to_dir is not None:
343 | np.save(os.path.join(args.export_to_dir, 'reference train'), t_idcs)
344 |
345 | # display generated molecules that match held out test structures
346 | if args.test:
347 | idcs = index_provider.get_selected('known,==3')
348 | if len(idcs) == 0:
349 | print(f'\nNo generated molecules found that match held out test '
350 | f'structures!')
351 | else:
352 | with connect(args.data_path) as con:
353 | _ats = [con.get(int(idx) + 1).toatoms() for idx in idcs]
354 | if args.print_indices:
355 | print_indices(idcs, 'generated test')
356 | view_ase(_ats, f'generated molecules (matching test structures)',
357 | args.block)
358 | if args.export_to_dir is not None:
359 | np.save(os.path.join(args.export_to_dir, 'generated test'), idcs)
360 | # display corresponding held out test structures from the training database
361 | if args.train_data_path is not None:
362 | _row_idx = list(stats_dict['stat_heads']).index('equals')
363 | t_idcs = stats_dict['stats'][_row_idx, idcs].astype(int)
364 | with connect(args.train_data_path) as con:
365 | _ats = [con.get(int(idx) + 1).toatoms() for idx in t_idcs]
366 | if args.print_indices:
367 | print_indices(t_idcs, 'reference test')
368 | view_ase(_ats, f'training molecules (test structures)', args.block)
369 | if args.export_to_dir is not None:
370 | np.save(os.path.join(args.export_to_dir, 'reference test'), t_idcs)
371 |
372 | # display generated molecules that are novel (i.e. that do not match held out
373 | # test structures or structures used during training)
374 | if args.novel:
375 | idcs = index_provider.get_selected('known,==0')
376 | if len(idcs) == 0:
377 | print(f'\nNo novel molecules found!')
378 | else:
379 | with connect(args.data_path) as con:
380 | _ats = [con.get(int(idx) + 1).toatoms() for idx in idcs]
381 | if args.print_indices:
382 | print_indices(idcs, 'novel')
383 | view_ase(_ats, f'generated molecules (novel)', args.block)
384 | if args.export_to_dir is not None:
385 | np.save(os.path.join(args.export_to_dir, 'generated novel'), idcs)
386 |
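The selection handling in display_molecules.py splits a selection string at `&` into queries of the form `head,condition`, then pulls a relational operator and a number out of each condition. The sketch below illustrates that parsing step in isolation. `REL_RE`, `NUM_RE`, `OP_DICT`, `parse_selection`, and `matches` are hypothetical stand-ins written for this example; the actual `rel_re`, `num_re`, and `op_dict` live on the index provider object and may differ in detail.

```python
import operator
import re

# Hypothetical stand-ins for index_provider.rel_re / num_re / op_dict.
# Two-character operators are listed first so '<=' is not read as '<'.
REL_RE = re.compile(r'<=|>=|==|!=|<|>')
NUM_RE = re.compile(r'[-+]?\d*\.?\d+(?:[eE][-+]?\d+)?')
OP_DICT = {
    '<': operator.lt, '<=': operator.le, '==': operator.eq,
    '!=': operator.ne, '>=': operator.ge, '>': operator.gt,
}


def parse_selection(selection):
    """Turn 'head,cond&head,cond' into a list of (head, op, num) triples."""
    parsed = []
    for query in selection.split('&'):
        head, condition = query.split(',')
        op = OP_DICT[REL_RE.search(condition).group(0)]  # extract operator
        num = float(NUM_RE.search(condition).group(0))   # extract number
        parsed.append((head, op, num))
    return parsed


def matches(entry, parsed):
    """Check whether a dict of values satisfies all parsed conditions."""
    return all(op(entry[head], num) for head, op, num in parsed)


# the selection used above for generated molecules that match training data
train_like = parse_selection('known,>=1&known,<=2')
print(matches({'known': 2}, train_like))
print(matches({'known': 3}, train_like))
```

Because all conditions are combined with a logical AND, an entry is kept only if every `head,condition` pair holds, which mirrors how `idcs` is narrowed query by query in the loop over `db_queries`.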
--------------------------------------------------------------------------------
/gschnet_cond_script.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import logging
3 | import os
4 | import pickle
5 | import time
6 | import json
7 | from shutil import copyfile, rmtree
8 |
9 | import numpy as np
10 | import torch
11 | import torch.nn as nn
12 | from torch.optim import Adam
13 | from torch.utils.data.sampler import RandomSampler
14 | from ase import Atoms
15 | from ase.db import connect
16 | import ase.visualize as asv
17 |
18 | import schnetpack as spk
19 | from schnetpack.utils import count_params, to_json, read_from_json
20 | from schnetpack import Properties
21 | from schnetpack.datasets import DownloadableAtomsData
22 |
23 | from nn_classes import AtomwiseWithProcessing, EmbeddingMultiplication,\
24 | RepresentationConditioning, NormalizeAndAggregate, KLDivergence,\
25 | FingerprintEmbedding, AtomCompositionEmbedding, PropertyEmbedding
26 | from utility_functions import boolean_string, collate_atoms, generate_molecules, \
27 | update_dict, get_dict_count, get_composition
28 |
29 | # add your own dataset classes here:
30 | from qm9_data import QM9gen
31 | dataset_name_to_class_mapping = {'qm9': QM9gen}
32 |
33 | logging.basicConfig(level=os.environ.get("LOGLEVEL", "INFO"))
34 |
35 |
36 | def get_parser():
37 | """ Set up parser for command line arguments """
38 | main_parser = argparse.ArgumentParser()
39 |
40 | ## command-specific
41 | cmd_parser = argparse.ArgumentParser(add_help=False)
42 | cmd_parser.add_argument('--cuda', help='Set flag to use GPU(s)',
43 | action='store_true')
44 | cmd_parser.add_argument('--parallel',
45 | help='Run data-parallel on all available GPUs '
46 | '(specify with environment variable'
47 | + ' CUDA_VISIBLE_DEVICES)',
48 | action='store_true')
49 | cmd_parser.add_argument('--batch_size', type=int,
50 | help='Mini-batch size for training and prediction '
51 | '(default: %(default)s)',
52 | default=5)
53 | cmd_parser.add_argument('--draw_random_samples', type=int, default=0,
54 | help='Only draw x generation steps per molecule '
55 | 'in each batch (if x=0, all generation '
56 | 'steps are included for each molecule, '
57 | 'default: %(default)s)')
58 | cmd_parser.add_argument('--checkpoint', type=int, default=-1,
59 | help='The checkpoint of the model that is going '
60 | 'to be loaded for evaluation or generation '
61 | '(set to -1 to load the best model '
62 | 'according to validation error, '
63 | 'default: %(default)s)')
64 | cmd_parser.add_argument('--precompute_distances', type=boolean_string,
65 | default='true',
66 | help='Store precomputed distances in the database '
67 | 'during pre-processing (caution, has no effect if '
68 | 'the dataset has already been downloaded, '
69 | 'pre-processed, and stored before, '
70 | 'default: %(default)s)')
71 |
72 | ## training
73 | train_parser = argparse.ArgumentParser(add_help=False,
74 | parents=[cmd_parser])
75 | train_parser.add_argument('datapath',
76 | help='Path / destination of dataset '\
77 | 'directory')
78 | train_parser.add_argument('modelpath',
79 | help='Destination for models and logs')
80 | train_parser.add_argument('--dataset_name', type=str, default='qm9',
81 | help=f'Name of the dataset used (choose from '
82 | f'{list(dataset_name_to_class_mapping.keys())}, '
83 | f'default: %(default)s)')
84 | train_parser.add_argument('--subset_path', type=str,
85 | help='A path to a npy file containing indices '
86 | 'of a subset of the data set at datapath '
87 | '(default: %(default)s)',
88 | default=None)
89 | train_parser.add_argument('--seed', type=int, default=None,
90 | help='Set random seed for torch and numpy.')
91 | train_parser.add_argument('--overwrite',
92 | help='Remove previous model directory.',
93 | action='store_true')
94 | train_parser.add_argument('--pretrained_path',
95 | help='Start training from the pre-trained model at the '
96 | 'provided path (reset optimizer parameters such as '
97 | 'best loss and learning rate and create new split)',
98 | default=None)
99 | train_parser.add_argument('--split_path',
100 | help='Path/destination of npz with data splits',
101 | default=None)
102 | train_parser.add_argument('--split',
103 | help='Split into [train] [validation] and use '
104 | 'remaining for testing',
105 | type=int, nargs=2, default=[None, None])
106 | train_parser.add_argument('--max_epochs', type=int,
107 | help='Maximum number of training epochs '
108 | '(default: %(default)s)',
109 | default=500)
110 | train_parser.add_argument('--lr', type=float,
111 | help='Initial learning rate '
112 | '(default: %(default)s)',
113 | default=1e-4)
114 | train_parser.add_argument('--lr_patience', type=int,
115 | help='Epochs without improvement before reducing'
116 | ' the learning rate (default: %(default)s)',
117 | default=10)
118 | train_parser.add_argument('--lr_decay', type=float,
119 | help='Learning rate decay '
120 | '(default: %(default)s)',
121 | default=0.5)
122 | train_parser.add_argument('--lr_min', type=float,
123 | help='Minimal learning rate '
124 | '(default: %(default)s)',
125 | default=1e-6)
126 | train_parser.add_argument('--logger',
127 | help='Choose logger for training process '
128 | '(default: %(default)s)',
129 | choices=['csv', 'tensorboard'],
130 | default='tensorboard')
131 | train_parser.add_argument('--log_every_n_epochs', type=int,
132 | help='Log metrics every given number of epochs '
133 | '(default: %(default)s)',
134 | default=1)
135 | train_parser.add_argument('--checkpoint_every_n_epochs', type=int,
136 | help='Create checkpoint every given number of '
137 | 'epochs '
138 | '(default: %(default)s)',
139 | default=25)
140 | train_parser.add_argument('--label_width_factor', type=float,
141 | help='A factor that is multiplied with the '
142 | 'range between two distance bins in order '
143 | 'to determine the width of the Gaussians '
144 | 'used to obtain labels from distances '
145 | '(set to 0. to use one-hot '
146 | 'encodings of distances as labels, '
147 | 'default: %(default)s)',
148 | default=0.1)
149 | train_parser.add_argument('--conditioning_json_path', type=str,
150 | help='Path to .json-file with specification of layers '
151 | 'used for conditioning of the model (default: '
152 | '%(default)s)',
153 | default=None)
154 | train_parser.add_argument('--use_embeddings_for_type_predictions',
155 | help='Copy extracted features and multiply them with '
156 | 'embeddings of all possible types to obtain scores.',
157 | action='store_true')
158 | train_parser.add_argument('--share_embeddings',
159 | help='Share embedding layers in SchNet part and in '
160 | 'pre-processing before predicting distances and '
161 | 'types.',
162 | action='store_true')
163 |
164 | ## evaluation
165 | eval_parser = argparse.ArgumentParser(add_help=False, parents=[cmd_parser])
166 | eval_parser.add_argument('datapath', help='Path of dataset directory')
167 | eval_parser.add_argument('modelpath', help='Path of stored model')
168 | eval_parser.add_argument('--split',
169 | help='Evaluate trained model on given split',
170 | choices=['train', 'validation', 'test'],
171 | default=['test'], nargs='+')
172 |
173 | ## molecule generation
174 | gen_parser = argparse.ArgumentParser(add_help=False, parents=[cmd_parser])
175 | gen_parser.add_argument('modelpath', help='Path of stored model')
176 | gen_parser.add_argument('amount_gen', type=int,
177 | help='The amount of generated molecules')
178 | gen_parser.add_argument('--show_gen',
179 | help='Whether to open plots of generated '
180 | 'molecules for visual evaluation',
181 | action='store_true')
182 | gen_parser.add_argument('--chunk_size', type=int,
183 | help='The size of mini batches during generation '
184 | '(default: %(default)s)',
185 | default=1000)
186 | gen_parser.add_argument('--max_length', type=int,
187 | help='The maximum number of atoms per molecule '
188 | '(default: %(default)s)',
189 | default=35)
190 | gen_parser.add_argument('--folder_name', type=str,
191 | help='The name of the folder in which generated '
192 | 'molecules are stored (please note that the folder '
193 | 'is always inside the model directory and always '
194 | 'called "generated", but custom extensions may be '
195 | 'provided here, e.g. "--folder_name _10" will place '
196 | 'the generated molecules in a folder called '
197 | '"generated_10", default: %(default)s)',
198 | default='')
199 | gen_parser.add_argument('--file_name', type=str,
200 | help='The name of the file in which generated '
201 | 'molecules are stored (please note that '
202 | 'increasing numbers are appended to the file name '
203 | 'if it already exists and that the extension '
204 | '.mol_dict is automatically added to the chosen '
205 | 'file name, default: %(default)s)',
206 | default='generated')
207 | gen_parser.add_argument('--store_unfinished',
208 | help='Store molecules which have not been '
209 | 'finished after sampling max_length atoms',
210 | action='store_true')
211 | gen_parser.add_argument('--store_process',
212 | help='Store information needed to track the generation '
213 | 'process (i.e. current focus, predicted distributions,'
214 | ' sampled type etc.) in the .mol_dict file',
215 | action='store_true')
216 | gen_parser.add_argument('--print_file',
217 | help='Use to limit the printing if results are '
218 | 'written to a file instead of the console ('
219 | 'e.g. if running on a cluster)',
220 | action='store_true')
221 | gen_parser.add_argument('--temperature', type=float,
222 | help='The temperature T to use for sampling '
223 | '(default: %(default)s)',
224 | default=0.1)
225 | gen_parser.add_argument('--conditioning', type=str, default=None,
226 | help='Additional input for conditioning of molecule '
227 | 'generation. Write "property1 value1; property2 '
228 | 'value2; ..." to specify the additional information '
229 | '(where multiple values can be provided per property, '
230 | 'e.g. the atomic composition, default: None)')
231 |
232 | # model-specific parsers
233 | model_parser = argparse.ArgumentParser(add_help=False)
234 | model_parser.add_argument('--aggregation_mode', type=str, default='sum',
235 | choices=['sum', 'avg'],
236 | help=' (default: %(default)s)')
237 |
238 | ####### G-SchNet #######
239 | gschnet_parser = argparse.ArgumentParser(add_help=False,
240 | parents=[model_parser])
241 | gschnet_parser.add_argument('--features', type=int,
242 | help='Size of atom-wise representation '
243 | '(default: %(default)s)',
244 | default=128)
245 | gschnet_parser.add_argument('--interactions', type=int,
246 | help='Number of regular SchNet interaction '
247 | 'blocks (default: %(default)s)',
248 | default=9)
249 | gschnet_parser.add_argument('--dense_layers', type=int,
250 | help='Number of layers in the (atom-wise) dense '
251 | 'output networks to predict next type and '
252 | 'distances (default: %(default)s)',
253 | default=5)
254 | gschnet_parser.add_argument('--cutoff', type=float, default=10.,
255 | help='Cutoff radius of local environment '
256 | '(default: %(default)s)')
257 | gschnet_parser.add_argument('--num_gaussians', type=int, default=25,
258 | help='Number of Gaussians to expand distances '
259 | '(default: %(default)s)')
260 | gschnet_parser.add_argument('--max_distance', type=float, default=15.,
261 | help='Maximum distance covered by the discrete '
262 | 'distributions over distances learned by '
263 | 'the model '
264 | '(default: %(default)s)')
265 | gschnet_parser.add_argument('--num_distance_bins', type=int, default=300,
266 | help='Number of bins used in the discrete '
267 | 'distributions over distances learned by '
268 | 'the model (default: %(default)s)')
269 |
270 | ## setup subparser structure
271 | cmd_subparsers = main_parser.add_subparsers(dest='mode',
272 | help='Command-specific '
273 | 'arguments')
274 | cmd_subparsers.required = True
275 | subparser_train = cmd_subparsers.add_parser('train', help='Training help')
276 | subparser_eval = cmd_subparsers.add_parser('eval', help='Eval help')
277 | subparser_gen = cmd_subparsers.add_parser('generate', help='Generate help')
278 |
279 | train_subparsers = subparser_train.add_subparsers(dest='model',
280 | help='Model-specific '
281 | 'arguments')
282 | train_subparsers.required = True
283 | train_subparsers.add_parser('gschnet', help='G-SchNet help',
284 | parents=[train_parser, gschnet_parser])
285 |
286 | eval_subparsers = subparser_eval.add_subparsers(dest='model',
287 | help='Model-specific '
288 | 'arguments')
289 | eval_subparsers.required = True
290 | eval_subparsers.add_parser('gschnet', help='G-SchNet help',
291 | parents=[eval_parser, gschnet_parser])
292 |
293 | gen_subparsers = subparser_gen.add_subparsers(dest='model',
294 | help='Model-specific '
295 | 'arguments')
296 | gen_subparsers.required = True
297 | gen_subparsers.add_parser('gschnet', help='G-SchNet help',
298 | parents=[gen_parser, gschnet_parser])
299 |
300 | return main_parser
301 |
302 |
303 | def get_model(args, conditioning_specification, parallelize=False):
304 | # get information about the atom types available in the data set
305 | dataclass = dataset_name_to_class_mapping[args.dataset_name]
306 | num_types = len(dataclass.available_atom_types)
307 | max_type = max(dataclass.available_atom_types)
308 |
309 | # get SchNet layers for feature extraction
310 | representation =\
311 | spk.representation.SchNet(n_atom_basis=args.features,
312 | n_filters=args.features,
313 | n_interactions=args.interactions,
314 | cutoff=args.cutoff,
315 | n_gaussians=args.num_gaussians,
316 | max_z=max_type+3)
317 | if args.share_embeddings:
318 | emb_layers = representation.embedding
319 | else:
320 | emb_layers = nn.Embedding(max_type+3, args.features, padding_idx=0)
321 |
322 | # build layers for conditioning according to conditioning specification dictionary
323 | representation_conditioning_blocks = []
324 | n_features = args.features
325 | if conditioning_specification is not None:
326 | for block in conditioning_specification:
327 | layers = conditioning_specification[block]['layers']
328 | layers_list = []
329 | for nn_name in layers:
330 | if 'p' in layers[nn_name]:
331 | layers[nn_name].pop('p')
332 | layer_class = get_conditioning_nn(nn_name, dataclass)
333 | layer_arguments = layers[nn_name]
334 | layers_list += [layer_class(**layer_arguments)]
335 | conditioning_specification[block]['layers'] = layers_list
336 | representation_conditioning_blocks += \
337 | [RepresentationConditioning(**conditioning_specification[block])]
338 | n_features += representation_conditioning_blocks[-1].n_additional_features
339 |
340 | # get output layers for prediction of next atom type
341 | if args.use_embeddings_for_type_predictions:
342 | preprocess_type = \
343 | EmbeddingMultiplication(emb_layers,
344 | in_key_types='_all_types',
345 | in_key_representation='representation',
346 | out_key='preprocessed_representation')
347 | _n_out = 1
348 | else:
349 | preprocess_type = None
350 | _n_out = num_types + 1 # number of possible types + stop token
351 | postprocess_type = NormalizeAndAggregate(normalize=True,
352 | normalization_axis=-1,
353 | normalization_mode='logsoftmax',
354 | aggregate=True,
355 | aggregation_axis=-2,
356 | aggregation_mode='sum',
357 | keepdim=False,
358 | mask='_type_mask',
359 | squeeze=True)
360 | out_module_type = \
361 | AtomwiseWithProcessing(n_in=n_features,
362 | n_out=_n_out,
363 | n_layers=args.dense_layers,
364 | preprocess_layers=preprocess_type,
365 | postprocess_layers=postprocess_type,
366 | out_key='type_predictions')
367 |
368 | # get output layers for predictions of distances
369 | preprocess_dist = \
370 | EmbeddingMultiplication(emb_layers,
371 | in_key_types='_next_types',
372 | in_key_representation='representation',
373 | out_key='preprocessed_representation')
374 | out_module_dist = \
375 | AtomwiseWithProcessing(n_in=n_features,
376 | n_out=args.num_distance_bins,
377 | n_layers=args.dense_layers,
378 | preprocess_layers=preprocess_dist,
379 | out_key='distance_predictions')
380 |
381 | # combine layers into an atomistic model
382 | model = spk.atomistic.AtomisticModel(representation,
383 | representation_conditioning_blocks +
384 | [out_module_type, out_module_dist])
385 |
386 | if parallelize:
387 | model = nn.DataParallel(model)
388 |
389 | logging.info("The model you built has: %d parameters" %
390 | count_params(model))
391 |
392 | return model
393 |
394 |
395 | def train(args, model, train_loader, val_loader, device):
396 |
397 | # setup hooks and logging
398 | hooks = [
399 | spk.hooks.MaxEpochHook(args.max_epochs)
400 | ]
401 |
402 | # filter for trainable parameters
403 | trainable_params = filter(lambda p: p.requires_grad, model.parameters())
404 | # setup optimizer
405 | optimizer = Adam(trainable_params, lr=args.lr)
406 | schedule = spk.hooks.ReduceLROnPlateauHook(optimizer,
407 | patience=args.lr_patience,
408 | factor=args.lr_decay,
409 | min_lr=args.lr_min,
410 | window_length=1,
411 | stop_after_min=True)
412 | hooks.append(schedule)
413 |
414 | # set up metrics to log KL divergence on distributions of types and distances
415 | metrics = [KLDivergence(target='_type_labels',
416 | model_output='type_predictions',
417 | name='KLD_types'),
418 | KLDivergence(target='_labels',
419 | model_output='distance_predictions',
420 | mask='_dist_mask',
421 | name='KLD_dists')]
422 |
423 | if args.logger == 'csv':
424 | logger =\
425 | spk.hooks.CSVHook(os.path.join(args.modelpath, 'log'),
426 | metrics,
427 | every_n_epochs=args.log_every_n_epochs)
428 | hooks.append(logger)
429 | elif args.logger == 'tensorboard':
430 | logger =\
431 | spk.hooks.TensorboardHook(os.path.join(args.modelpath, 'log'),
432 | metrics,
433 | every_n_epochs=args.log_every_n_epochs)
434 | hooks.append(logger)
435 |
436 | norm_layer = nn.LogSoftmax(-1).to(device)
437 | loss_layer = nn.KLDivLoss(reduction='none').to(device)
438 |
439 | # setup loss function
440 | def loss(batch, result):
441 | # loss for type predictions (KLD)
442 | out_type = norm_layer(result['type_predictions'])
443 | loss_type = loss_layer(out_type, batch['_type_labels'])
444 | loss_type = torch.sum(loss_type, -1)
445 | loss_type = torch.mean(loss_type)
446 |
447 | # loss for distance predictions (KLD)
448 | mask_dist = batch['_dist_mask']
449 | N = torch.sum(mask_dist)
450 | out_dist = norm_layer(result['distance_predictions'])
451 | loss_dist = loss_layer(out_dist, batch['_labels'])
452 | loss_dist = torch.sum(loss_dist, -1)
453 | loss_dist = torch.sum(loss_dist * mask_dist) / torch.max(N, torch.ones_like(N))
454 |
455 | return loss_type + loss_dist
456 |
457 | # initialize trainer
458 | trainer = spk.train.Trainer(args.modelpath,
459 | model,
460 | loss,
461 | optimizer,
462 | train_loader,
463 | val_loader,
464 | hooks=hooks,
465 | checkpoint_interval=args.checkpoint_every_n_epochs,
466 | keep_n_checkpoints=3)
467 |
468 | # reset optimizer and hooks if starting from pre-trained model (e.g. for
469 | # fine-tuning)
470 | if args.pretrained_path is not None:
471 | logging.info('starting from pre-trained model...')
472 | # reset epoch and step
473 | trainer.epoch = 0
474 | trainer.step = 0
475 | trainer.best_loss = float('inf')
476 | # reset optimizer
477 | trainable_params = filter(lambda p: p.requires_grad, model.parameters())
478 | optimizer = Adam(trainable_params, lr=args.lr)
479 | trainer.optimizer = optimizer
480 | # reset scheduler
481 | schedule =\
482 | spk.hooks.ReduceLROnPlateauHook(optimizer,
483 | patience=args.lr_patience,
484 | factor=args.lr_decay,
485 | min_lr=args.lr_min,
486 | window_length=1,
487 | stop_after_min=True)
488 | trainer.hooks[1] = schedule
489 | # remove checkpoints of pre-trained model
490 | rmtree(os.path.join(args.modelpath, 'checkpoints'))
491 | os.makedirs(os.path.join(args.modelpath, 'checkpoints'))
492 | # store first checkpoint
493 | trainer.store_checkpoint()
494 |
495 | # start training
496 | trainer.train(device)
497 |
498 |
499 | def evaluate(args, model, train_loader, val_loader, test_loader, device):
500 | header = ['Subset', 'distances KLD', 'types KLD']
501 |
502 | metrics = [KLDivergence(target='_labels',
503 | model_output='distance_predictions',
504 | mask='_dist_mask'),
505 | KLDivergence(target='_type_labels',
506 | model_output='type_predictions')]
507 |
508 | results = []
509 | if 'train' in args.split:
510 | results.append(['training'] +
511 | ['%.5f' % i for i in
512 | evaluate_dataset(metrics, model,
513 | train_loader, device)])
514 |
515 | if 'validation' in args.split:
516 | results.append(['validation'] +
517 | ['%.5f' % i for i in
518 | evaluate_dataset(metrics, model,
519 | val_loader, device)])
520 |
521 | if 'test' in args.split:
522 | results.append(['test'] + ['%.5f' % i for i in evaluate_dataset(
523 | metrics, model, test_loader, device)])
524 |
525 | header = ','.join(header)
526 | results = np.array(results)
527 |
528 | np.savetxt(os.path.join(args.modelpath, 'evaluation.csv'), results,
529 | header=header, fmt='%s', delimiter=',')
530 |
531 |
532 | def evaluate_dataset(metrics, model, loader, device):
533 | for metric in metrics:
534 | metric.reset()
535 |
536 | for batch in loader:
537 | batch = {
538 | k: v.to(device)
539 | for k, v in batch.items()
540 | }
541 | result = model(batch)
542 |
543 | for metric in metrics:
544 | metric.add_batch(batch, result)
545 |
546 | results = [
547 | metric.aggregate() for metric in metrics
548 | ]
549 | return results
550 |
551 |
552 | def generate(args, train_args, model, device, conditioning_layer_list):
553 | # generate molecules (in chunks) and print progress
554 |
555 | dataclass = dataset_name_to_class_mapping[train_args.dataset_name]
556 | types = sorted(dataclass.available_atom_types) # retrieve available atom types
557 | all_types = types + [types[-1] + 1] # add stop token to list (largest type + 1)
558 | start_token = types[-1] + 2 # define start token (largest type + 2)
559 | amount = args.amount_gen
560 | chunk_size = args.chunk_size
561 | if chunk_size >= amount:
562 | chunk_size = amount
563 |
564 | # set parameters for printing progress
565 | if int(amount / 10.) < chunk_size:
566 | step = chunk_size
567 | else:
568 | step = int(amount / 10.)
569 | increase = lambda x, y: y + step if x >= y else y
570 | thresh = step
571 | if args.print_file:
572 | progress = lambda x, y: print(f'Generated {x}.', flush=True) \
573 | if x >= y else print('', end='', flush=True)
574 | else:
575 | progress = lambda x, y: print(f'\x1b[2K\rSuccessfully generated'
576 | f' {x}', end='', flush=True)
577 |
578 | # extract conditioning information
579 | conditioning = {}
580 | if args.conditioning is not None:
581 | conds = args.conditioning.split('; ')
582 | for cond_list in conds:
583 | cond_list = cond_list.split(' ')
584 | cond_name = cond_list[0]
585 | if cond_name not in conditioning_layer_list:
586 | logging.info(f'The provided model was not trained to condition on '
587 | f'{cond_name}! The condition will be ignored during '
588 | f'generation.')
589 | continue
590 | cond_vals = cond_list[1:]
591 | if cond_name == 'fingerprint':
592 | if not os.path.isfile('./data/qm9gen.db'):
593 | logging.error(f'could not find database with fingerprints at ./data/qm9gen.db!')
594 | logging.error(f'stopping generation!')
595 | return
596 | with connect('./data/qm9gen.db') as conn:
597 | if cond_vals[0].isdigit():
598 | if 'fingerprint_format' not in conn.metadata:
599 | logging.error(f'fingerprints not found in database!')
600 | logging.error(f'please re-download data when training a model with fingerprints as conditions!')
601 | logging.error(f'stopping generation!')
602 | return
603 | fp = np.array(conn.get(int(cond_vals[0])+1).data['fingerprint'],
604 | dtype=conn.metadata['fingerprint_format'])
605 | else:
606 | import pybel
607 | fp = \
608 | np.array(pybel.readstring('smi', cond_vals[0]).calcfp().fp,
609 | dtype=conn.metadata['fingerprint_format'])
610 | from collections import Counter
611 | _ct = Counter(cond_vals[0].lower())
612 | conditioning['_composition'] = \
613 | np.array([_ct['c'], _ct['n'], _ct['o'], _ct['f']])
614 | conditioning['_fingerprint'] = \
615 | np.unpackbits(fp.view(np.uint8), bitorder='little')[None, ...]
616 | else:
617 | if not cond_vals[0].isdigit():
618 | conditioning['_' + cond_name] = np.array([cond_vals], dtype=float)
619 | else:
620 | conditioning['_' + cond_name] = np.array([cond_vals], dtype=int)
621 | cond_mask = np.ones((1, len(conditioning_layer_list)))
622 | for i, condition_layer in enumerate(conditioning_layer_list):
623 | # mask out conditions that were not provided and supply dummy input for them
624 | if f'_{condition_layer}' not in conditioning:
625 | conditioning[f'_{condition_layer}'] = np.array([[0]], dtype=float)
626 | cond_mask[0, i] = 0
627 | conditioning[f'_cond_mask'] = cond_mask
628 |
629 | # generate
630 | generated = {}
631 | left = args.amount_gen
632 | done = 0
633 | start_time = time.time()
634 | with torch.no_grad():
635 | while left > 0:
636 | if left - chunk_size < 0:
637 | batch = left
638 | else:
639 | batch = chunk_size
640 | update_dict(generated,
641 | generate_molecules(
642 | batch,
643 | model,
644 | all_types=all_types,
645 | start_token=start_token,
646 | max_length=args.max_length,
647 | save_unfinished=args.store_unfinished,
648 | device=device,
649 | max_dist=train_args.max_distance,
650 | n_bins=train_args.num_distance_bins,
651 | radial_limits=dataclass.radial_limits,
652 | t=args.temperature,
653 | conditioning=conditioning,
654 | store_process=args.store_process
655 | )
656 | )
657 | left -= batch
658 | done += batch
659 | n = np.sum(get_dict_count(generated, args.max_length))
660 | progress(n, thresh)
661 | thresh = increase(n, thresh)
662 | print('')
663 | end_time = time.time() - start_time
664 | m, s = divmod(end_time, 60)
665 | h, m = divmod(m, 60)
666 | h, m, s = int(h), int(m), int(s)
667 | print(f'Time consumed: {h:d}:{m:02d}:{s:02d}')
668 |
669 | # sort keys in resulting dictionary
670 | generated = dict(sorted(generated.items()))
671 |
672 | # show generated molecules and print some statistics if desired
673 | if args.show_gen:
674 | ats = []
675 | n_total_atoms = 0
676 | n_molecules = 0
677 | for key in generated:
678 | n = 0
679 | for i in range(len(generated[key][Properties.Z])):
680 | at = Atoms(generated[key][Properties.Z][i],
681 | positions=generated[key][Properties.R][i])
682 | ats += [at]
683 | n += 1
684 | n_molecules += 1
685 | n_total_atoms += n * key
686 | asv.view(ats)
687 | print(f'Total number of atoms placed: {n_total_atoms} '
688 | f'(avg {n_total_atoms / n_molecules:.2f})', flush=True)
689 |
690 | return generated
691 |
692 |
693 | def prepare_conditioning(conditioning_specification):
694 | if conditioning_specification is None:
695 | return [], {}, []
696 | load_additionally = []
697 | layer_list = []
698 | conditioning_extractors = {}
699 | for block in conditioning_specification:
700 | layers = conditioning_specification[block]['layers']
701 | replace_dict_argument_recursively('args.features', args.features, layers)  # `args` is read from module scope, not passed in
702 | for layer_name in layers:
703 | layer_list += [layer_name]
704 | if 'p' in layers[layer_name]:
705 | p = layers[layer_name]['p']
706 | else:
707 | p = 1.0
708 | if layer_name == 'composition':
709 | all_types = [6, 7, 8, 9]
710 | if 'skip_h' in layers['composition']:
711 | if not layers['composition']['skip_h']:
712 | all_types = [1] + all_types
713 | conditioning_extractors.update(
714 | {'_composition': lambda x, prob=p:
715 | (torch.FloatTensor(get_composition(x, all_types)), prob)}
716 | )
717 | else:
718 | load_additionally += [layer_name]
719 | conditioning_extractors.update(
720 | {'_' + layer_name:
721 | lambda x, prob=p, n=layer_name: (x.pop(n), prob)})
722 |
723 | return load_additionally, conditioning_extractors, layer_list
724 |
725 |
726 | def get_conditioning_nn(name, dataclass):
727 | layers = None
728 | if name == 'fingerprint':
729 | layers = FingerprintEmbedding
730 | elif name == 'composition':
731 | layers = lambda embedding, **kwargs: \
732 | AtomCompositionEmbedding(nn.Embedding(**embedding), **kwargs)
733 | elif name in dataclass.properties:
734 | layers = PropertyEmbedding
735 | return layers
736 |
737 |
738 | def replace_dict_argument_recursively(argument, value, d):
739 | for key in d:
740 | if isinstance(d[key], dict):
741 | replace_dict_argument_recursively(argument, value, d[key])
742 | elif d[key] == argument:
743 | d[key] = value
744 |
745 |
746 | def main(args):
747 | # set device (cpu or gpu)
748 | device = torch.device('cuda' if args.cuda else 'cpu')
749 |
750 | # store (or load) arguments
751 | argparse_dict = vars(args)
752 | jsonpath = os.path.join(args.modelpath, 'args.json')
753 |
754 | if args.mode == 'train':
755 | # overwrite existing model if desired
756 | if args.overwrite and os.path.exists(args.modelpath):
757 | rmtree(args.modelpath)
758 | logging.info('existing model will be overwritten...')
759 |
760 | # create model directory if it does not exist
761 | if not os.path.exists(args.modelpath):
762 | os.makedirs(args.modelpath)
763 |
764 | # get latest checkpoint of pre-trained model if a path was provided
765 | if args.pretrained_path is not None:
766 | model_chkpt_path = os.path.join(args.modelpath, 'checkpoints')
767 | pretrained_chkpt_path = os.path.join(args.pretrained_path, 'checkpoints')
768 | if os.path.exists(model_chkpt_path) \
769 | and len(os.listdir(model_chkpt_path)) > 0:
770 | logging.info(f'found existing checkpoints in model directory '
771 | f'({model_chkpt_path}), please use --overwrite or choose '
772 | f'empty model directory to start from a pre-trained '
773 | f'model...')
774 | logging.warning(f'will ignore pre-trained model and start from latest '
775 | f'checkpoint at {model_chkpt_path}...')
776 | args.pretrained_path = None
777 | else:
778 | logging.info(f'fetching latest checkpoint from pre-trained model at '
779 | f'{pretrained_chkpt_path}...')
780 | if not os.path.exists(pretrained_chkpt_path):
781 | logging.warning(f'did not find checkpoints of pre-trained model, '
782 | f'will train from scratch...')
783 | args.pretrained_path = None
784 | else:
785 | chkpt_files = [f for f in os.listdir(pretrained_chkpt_path)
786 | if f.startswith("checkpoint")]
787 | if len(chkpt_files) == 0:
788 | logging.warning(f'did not find checkpoints of pre-trained '
789 | f'model, will train from scratch...')
790 | args.pretrained_path = None
791 | else:
792 | epoch = max([int(f.split(".")[0].split("-")[-1])
793 | for f in chkpt_files])
794 | chkpt = os.path.join(pretrained_chkpt_path,
795 | "checkpoint-" + str(epoch) + ".pth.tar")
796 | if not os.path.exists(model_chkpt_path):
797 | os.makedirs(model_chkpt_path)
798 | copyfile(chkpt, os.path.join(model_chkpt_path,
799 | f'checkpoint-{epoch}.pth.tar'))
800 |
801 | # store arguments for training in model directory
802 | to_json(jsonpath, argparse_dict)
803 | train_args = args
804 |
805 | # set seed
806 | spk.utils.set_random_seed(args.seed)
807 | else:
808 | # load arguments used for training from model directory
809 | train_args = read_from_json(jsonpath)
810 |
811 | # read in conditioning layers from file and set arguments accordingly
812 | conditioning_specification = None
813 | conditioning_json_path = None
814 | conditioning_specification_path = \
815 | os.path.join(args.modelpath, 'conditioning_specification.json')
816 | # look for existing specification in model directory
817 | if os.path.isfile(conditioning_specification_path):
818 | conditioning_json_path = conditioning_specification_path
819 | # else look for specification at path provided with the training arguments
820 | elif args.mode == 'train' and args.conditioning_json_path is not None:
821 | if os.path.isfile(args.conditioning_json_path):
822 | conditioning_json_path = args.conditioning_json_path
823 | else:
824 | logging.error(f'The provided conditioning specification file '
825 | f'({args.conditioning_json_path}) does not exist!')
826 | raise FileNotFoundError
827 | # read file
828 | if conditioning_json_path is not None:
829 | with open(conditioning_json_path) as handle:
830 | conditioning_specification = json.loads(handle.read())
831 | # process information
832 | load_additionally, conditioning_extractors, cond_layer_list = \
833 | prepare_conditioning(conditioning_specification)
834 | # store specification of conditioning layers in model folder
835 | if conditioning_specification is not None and \
836 | conditioning_json_path != conditioning_specification_path:
837 | to_json(conditioning_specification_path, conditioning_specification)
838 |
839 | # load data for training/evaluation
840 | if args.mode in ['train', 'eval']:
841 | # find correct data class
842 | assert train_args.dataset_name in dataset_name_to_class_mapping, \
843 | f'Could not find data class for dataset {train_args.dataset_name}. Please ' \
844 | f'specify a correct dataset name!'
845 | dataclass = dataset_name_to_class_mapping[train_args.dataset_name]
846 |
847 | # load the dataset
848 | logging.info(f'{train_args.dataset_name} will be loaded...')
849 | subset = None
850 | if train_args.subset_path is not None:
851 | logging.info(f'Using subset from {train_args.subset_path}')
852 | subset = np.load(train_args.subset_path)
853 | subset = [int(i) for i in subset]
854 | if issubclass(dataclass, DownloadableAtomsData):
855 | data = dataclass(args.datapath,
856 | subset=subset,
857 | precompute_distances=args.precompute_distances,
858 | download=args.mode == 'train',
859 | load_additionally=load_additionally)
860 | else:
861 | data = dataclass(args.datapath,
862 | subset=subset,
863 | precompute_distances=args.precompute_distances,
864 | load_additionally=load_additionally)
865 |
866 | # split the dataset into train, val, and test sets
867 | split_path = os.path.join(args.modelpath, 'split.npz')
868 | if args.mode == 'train':
869 | if args.split_path is not None:
870 | copyfile(args.split_path, split_path)
871 |
872 | logging.info('create splits...')
873 | data_train, data_val, data_test = data.create_splits(*train_args.split,
874 | split_file=split_path)
875 |
876 | logging.info('load data...')
877 | types = sorted(dataclass.available_atom_types)
878 | max_type = types[-1]
879 | # set up collate function according to args
880 | collate = lambda x: \
881 | collate_atoms(x,
882 | all_types=types + [max_type+1],
883 | start_token=max_type+2,
884 | draw_samples=args.draw_random_samples,
885 | label_width_scaling=train_args.label_width_factor,
886 | max_dist=train_args.max_distance,
887 | n_bins=train_args.num_distance_bins,
888 | conditioning_extractors=conditioning_extractors)
889 |
890 | train_loader = spk.data.AtomsLoader(data_train, batch_size=args.batch_size,
891 | sampler=RandomSampler(data_train),
892 | num_workers=4, pin_memory=True,
893 | collate_fn=collate)
894 | val_loader = spk.data.AtomsLoader(data_val, batch_size=args.batch_size,
895 | num_workers=2, pin_memory=True,
896 | collate_fn=collate)
897 |
898 | # construct the model
899 | if args.mode == 'train' or args.checkpoint >= 0:
900 | model = get_model(train_args, conditioning_specification,
901 | parallelize=args.parallel).to(device)
902 | logging.info(f'running on {device}')
903 |
904 | # load model or checkpoint for evaluation or generation
905 | if args.mode in ['eval', 'generate']:
906 | if args.checkpoint < 0: # load best model
907 | logging.info(f'restoring best model')
908 | model = torch.load(os.path.join(args.modelpath, 'best_model')).to(device)
909 | else:
910 | logging.info(f'restoring checkpoint {args.checkpoint}')
911 | chkpt = os.path.join(args.modelpath, 'checkpoints',
912 | 'checkpoint-' + str(args.checkpoint) + '.pth.tar')
913 | state_dict = torch.load(chkpt, map_location=device)
914 | model.load_state_dict(state_dict['model'], strict=True)
915 |
916 | # execute training, evaluation, or generation
917 | if args.mode == 'train':
918 | logging.info("training...")
919 | train(args, model, train_loader, val_loader, device)
920 | logging.info("...training done!")
921 |
922 | elif args.mode == 'eval':
923 | logging.info("evaluating...")
924 | test_loader = spk.data.AtomsLoader(data_test,
925 | batch_size=args.batch_size,
926 | num_workers=2,
927 | pin_memory=True,
928 | collate_fn=collate)
929 | with torch.no_grad():
930 | evaluate(args, model, train_loader, val_loader, test_loader, device)
931 | logging.info("... done!")
932 |
933 | elif args.mode == 'generate':
934 | logging.info(f'generating {args.amount_gen} molecules...')
935 | generated = generate(args, train_args, model, device, cond_layer_list)
936 | gen_path = os.path.join(args.modelpath, f'generated{args.folder_name}/')
937 | if not os.path.exists(gen_path):
938 | os.makedirs(gen_path)
939 | # get untaken filename and store results
940 | file_name = os.path.join(gen_path, args.file_name)
941 | if os.path.isfile(file_name + '.mol_dict'):
942 | expand = 0
943 | while True:
944 | expand += 1
945 | new_file_name = file_name + '_' + str(expand)
946 | if os.path.isfile(new_file_name + '.mol_dict'):
947 | continue
948 | else:
949 | file_name = new_file_name
950 | break
951 | with open(file_name + '.mol_dict', 'wb') as f:
952 | pickle.dump(generated, f)
953 | logging.info('...done!')
954 | else:
955 | logging.info(f'Unknown mode: {args.mode}')
956 |
957 |
958 | if __name__ == '__main__':
959 | parser = get_parser()
960 | args = parser.parse_args()
961 | main(args)
962 |
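Usage note for `replace_dict_argument_recursively` above: the following is a minimal, self-contained sketch of how the `'args.features'` placeholder in a conditioning specification gets substituted. The function body is copied from the script; the `layers` dictionary and the value 128 are made up for illustration and are not taken from the shipped JSON files.

```python
def replace_dict_argument_recursively(argument, value, d):
    # walk the (possibly nested) dictionary and substitute the placeholder in place
    for key in d:
        if isinstance(d[key], dict):
            replace_dict_argument_recursively(argument, value, d[key])
        elif d[key] == argument:
            d[key] = value


# hypothetical layer specification (keys and numbers chosen for illustration only)
layers = {
    'composition': {
        'embedding': {'num_embeddings': 10, 'embedding_dim': 'args.features'},
        'n_out': 'args.features',
        'skip_h': True,
    },
}

# 128 stands in for the value of args.features
replace_dict_argument_recursively('args.features', 128, layers)
print(layers['composition']['n_out'])
print(layers['composition']['embedding']['embedding_dim'])
```

Only dictionary values are visited, so a placeholder stored inside a list would be left untouched by this function.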
--------------------------------------------------------------------------------
/images/concept_results_scheme.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/atomistic-machine-learning/cG-SchNet/d107fa20563db348c668fc2dd29d843010bd1174/images/concept_results_scheme.png
--------------------------------------------------------------------------------
/nn_classes.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn.functional as F
4 | import torch.nn as nn
5 | from collections.abc import Iterable
6 |
7 | import schnetpack as spk
8 | from schnetpack.nn import MLP
9 | from schnetpack.metrics import Metric
10 |
11 |
12 | ### OUTPUT MODULE ###
13 | class AtomwiseWithProcessing(nn.Module):
14 | r"""
15 | Atom-wise dense layers that allow the use of additional pre- and post-processing layers.
16 |
17 | Args:
18 | n_in (int, optional): input dimension of representation (default: 128)
19 | n_out (int, optional): output dimension (default: 1)
20 | n_layers (int, optional): number of atom-wise dense layers in output network
21 | (default: 5)
22 | n_neurons (list of int or int or None, optional): number of neurons in each
23 | layer of the output network. If a single int is provided, all layers will
24 | have that number of neurons, if `None`, interpolate linearly between n_in
25 | and n_out (default: None).
26 | activation (function, optional): activation function for hidden layers
27 | (default: spk.nn.activations.shifted_softplus).
28 | preprocess_layers (nn.Module, optional): a torch.nn.Module or list of Modules
29 | for preprocessing the representation given by the first part of the network
30 | (default: None).
31 | postprocess_layers (nn.Module, optional): a torch.nn.Module or list of Modules
32 | for postprocessing the output given by the second part of the network
33 | (default: None).
34 | in_key (str, optional): keyword to access the representation in the inputs
35 | dictionary, it is automatically inferred from the preprocessing layers, if
36 | at least one is given (default: 'representation').
37 | out_key (str, optional): a string as key to the output dictionary (if set to
38 | `None`, the output will not be wrapped into a dictionary, default: 'y')
39 |
40 | Returns:
41 | result: dictionary with predictions stored in result[out_key]
42 | """
43 |
44 | def __init__(self, n_in=128, n_out=1, n_layers=5, n_neurons=None,
45 | activation=spk.nn.activations.shifted_softplus,
46 | preprocess_layers=None, postprocess_layers=None,
47 | in_key='representation', out_key='y'):
48 |
49 | super(AtomwiseWithProcessing, self).__init__()
50 |
51 | self.n_in = n_in
52 | self.n_out = n_out
53 | self.n_layers = n_layers
54 | self.in_key = in_key
55 | self.out_key = out_key
56 |
57 | if isinstance(preprocess_layers, Iterable):
58 | self.preprocess_layers = nn.ModuleList(preprocess_layers)
59 | self.in_key = self.preprocess_layers[-1].out_key
60 | elif preprocess_layers is not None:
61 | self.preprocess_layers = preprocess_layers
62 | self.in_key = self.preprocess_layers.out_key
63 | else:
64 | self.preprocess_layers = None
65 |
66 | if isinstance(postprocess_layers, Iterable):
67 | self.postprocess_layers = nn.ModuleList(postprocess_layers)
68 | else:
69 | self.postprocess_layers = postprocess_layers
70 |
71 | if n_neurons is None:
72 | # linearly interpolate between n_in and n_out
73 | n_neurons = list(np.linspace(n_in, n_out, n_layers + 1).astype(int)[1:-1])
74 | self.out_net = MLP(n_in, n_out, n_neurons, n_layers, activation)
75 |
76 | self.derivative = None # don't compute derivative w.r.t. inputs
77 |
78 | def forward(self, inputs):
79 | """
80 | Compute layer output and apply pre-/postprocessing if specified.
81 |
82 | Args:
83 | inputs (dict of torch.Tensor): batch of input values.
84 | Returns:
85 | torch.Tensor: layer output.
86 | """
87 | # apply pre-processing layers
88 | if self.preprocess_layers is not None:
89 | if isinstance(self.preprocess_layers, Iterable):
90 | for pre_layer in self.preprocess_layers:
91 | inputs = pre_layer(inputs)
92 | else:
93 | inputs = self.preprocess_layers(inputs)
94 |
95 | # get (pre-processed) representation
96 | if isinstance(inputs[self.in_key], tuple):
97 | repr = inputs[self.in_key][0]
98 | else:
99 | repr = inputs[self.in_key]
100 |
101 | # apply output network
102 | result = self.out_net(repr)
103 |
104 | # apply post-processing layers
105 | if self.postprocess_layers is not None:
106 | if isinstance(self.postprocess_layers, Iterable):
107 | for post_layer in self.postprocess_layers:
108 | result = post_layer(inputs, result)
109 | else:
110 | result = self.postprocess_layers(inputs, result)
111 |
112 | # use provided key to store result
113 | if self.out_key is not None:
114 | result = {self.out_key: result}
115 |
116 | return result
117 |
118 |
119 | class RepresentationConditioning(nn.Module):
120 | r"""
121 | Layer that allows altering the extracted feature representations in order to
122 | condition generation. Takes multiple networks that provide conditioning
123 | information as vectors, stacks these vectors and processes them in a fully
124 | connected MLP to get a global conditioning vector that is incorporated into
125 | the extracted feature representation.
126 |
127 | Args:
128 | layers (nn.Module): a torch.nn.Module or list of Modules that each provide a
129 | vector representing information for conditioning.
130 | mode (str, optional): how to incorporate the global conditioning vector in
131 | the extracted feature representation (can either be 'multiplication',
132 | 'addition', or 'stack', default: 'stack').
133 | n_global_cond_features (int, optional): number of features in the global
134 | conditioning vector (i.e. output dimension for the MLP used to aggregate
135 | the stacked separate conditioning vectors, default: 128).
136 | n_layers (int, optional): number of dense layers in the MLP used to get the
137 | global conditioning vector (default: 5).
138 | n_neurons (list of int or int or None, optional): number of neurons in each
139 | layer of the MLP. If a single int is provided, all layers will have that
140 | number of neurons, if `None`, interpolate linearly between n_in and n_out
141 | (default: None).
142 | activation (function, optional): activation function for hidden layers in the
143 | aggregation MLP (default: spk.nn.activations.shifted_softplus).
144 | in_key (str, optional): keyword to access the representation in the inputs
145 | dictionary (default: 'representation').
146 | out_key (str, optional): keyword under which the conditioned representation
147 | is stored in the inputs dictionary. If it equals in_key, the original
148 | representation is overwritten (default:
149 | 'representation').
150 |
151 | Returns:
152 | dict: an empty dictionary; the result is written to inputs[out_key].
153 | """
154 |
155 | def __init__(self,
156 | layers,
157 | mode='stack',
158 | n_global_cond_features=128,
159 | n_layers=5,
160 | n_neurons=None,
161 | activation=spk.nn.activations.shifted_softplus,
162 | in_key='representation',
163 | out_key='representation'):
164 |
165 | super(RepresentationConditioning, self).__init__()
166 |
167 | if type(layers) not in [list, nn.ModuleList]:
168 | layers = [layers]
169 | if type(layers) == list:
170 | layers = nn.ModuleList(layers)
171 | self.layers = layers
172 | self.mode = mode
173 | self.in_key = in_key
174 | self.out_key = out_key
175 | self.n_global_cond_features = n_global_cond_features
176 |
177 | self.derivative = None # don't compute derivative w.r.t. inputs
178 |
179 | # set number of additional features
180 | self.n_additional_features = 0
181 | if self.mode == 'stack':
182 | self.n_additional_features = self.n_global_cond_features
183 |
184 | # compute number of inputs to the MLP processing stacked conditioning vectors
185 | n_in = 0
186 | for layer in self.layers:
187 | n_in += layer.n_out
188 | n_out = n_global_cond_features
189 |
190 | # initialize MLP processing stacked conditioning vectors
191 | if n_neurons is None:
192 | # linearly interpolate between n_in and n_out
193 | n_neurons = list(np.linspace(n_in, n_out, n_layers + 1).astype(int)[1:-1])
194 | self.cond_mlp = MLP(n_in, n_out, n_neurons, n_layers, activation)
195 |
196 | def forward(self, inputs):
197 | """
198 | Update representation in the inputs according to conditioning information and
199 | return empty dictionary since no proper network output is computed in this
200 | module.
201 |
202 | Args:
203 | inputs (dict of torch.Tensor): batch of input values.
204 | Returns:
205 | dict: An empty dictionary.
206 | """
207 | # get (pre-processed) representation
208 | if isinstance(inputs[self.in_key], tuple):
209 | repr = inputs[self.in_key][0]
210 | else:
211 | repr = inputs[self.in_key]
212 |
213 | # get mask that (potentially) hides conditional information
214 | _size = [1, len(self.layers)] + [1 for _ in repr.size()[1:]]
215 | if '_cond_mask' in inputs:
216 | cond_mask = inputs['_cond_mask']
217 | cond_mask = cond_mask.reshape([cond_mask.shape[0]] + _size[1:])
218 | else:
219 | cond_mask = torch.ones(_size, dtype=repr.dtype, device=repr.device)
220 |
221 | # get conditioning information vectors from layers and include them in
222 | # representation
223 | cond_vecs = []
224 | for i, layer in enumerate(self.layers):
225 | cond_vecs += [cond_mask[:, i] * layer(inputs)]
226 |
227 | cond_vecs = torch.cat(cond_vecs, dim=-1)
228 | final_cond_vec = self.cond_mlp(cond_vecs)
229 |
230 | if self.mode == 'addition':
231 | repr = repr + final_cond_vec
232 | elif self.mode == 'multiplication':
233 | repr = repr * final_cond_vec
234 | elif self.mode == 'stack':
235 | repr = torch.cat([repr, final_cond_vec.expand(*repr.size()[:-1], -1)], -1)
236 |
237 | inputs.update({self.out_key: repr})
238 |
239 | return {}
240 |
241 |
242 | ### METRICS ###
243 | class KLDivergence(Metric):
244 | r"""
245 | Metric for mean KL-Divergence.
246 |
247 | Args:
248 | target (str, optional): name of target property (default: '_labels')
249 | model_output (list of int or list of str, optional): indices or keys to unpack
250 | the desired output from the model in case of multiple outputs, e.g.
251 | ['x', 'y'] to get output['x']['y'] (default: 'y').
252 | name (str, optional): name used in logging for this metric. If set to `None`,
253 | `KLD_[target]` will be used (default: None).
254 | mask (str, optional): key for a mask in the examined batch which hides
255 | irrelevant output values. If `None` is provided, no mask will be applied
256 | (default: None).
257 | inverse_mask (bool, optional): whether the mask needs to be inverted prior to
258 | application (default: False).
259 | """
260 |
261 | def __init__(self, target='_labels', model_output='y', name=None,
262 | mask=None, inverse_mask=False):
263 | name = 'KLD_' + target if name is None else name
264 | super(KLDivergence, self).__init__(name)
265 | self.target = target
266 | self.model_output = model_output
267 | self.loss = 0.
268 | self.n_entries = 0.
269 | self.mask_str = mask
270 | self.inverse_mask = inverse_mask
271 |
272 | def reset(self):
273 | self.loss = 0.
274 | self.n_entries = 0.
275 |
276 | def add_batch(self, batch, result):
277 | # extract true labels
278 | y = batch[self.target]
279 |
280 | # extract predictions
281 | yp = result
282 | if self.model_output is not None:
283 | if isinstance(self.model_output, list):
284 | for key in self.model_output:
285 | yp = yp[key]
286 | else:
287 | yp = yp[self.model_output]
288 |
289 | # normalize output
290 | log_yp = F.log_softmax(yp, -1)
291 |
292 | # apply KL divergence formula entry-wise
293 | loss = F.kl_div(log_yp, y, reduction='none')
294 |
295 | # sum over last dimension to get KL divergence per distribution
296 | loss = torch.sum(loss, -1)
297 |
298 | # apply mask to filter padded dimensions
299 | if self.mask_str is not None:
300 | atom_mask = batch[self.mask_str]
301 | if self.inverse_mask:
302 | atom_mask = 1.-atom_mask
303 | loss = torch.where(atom_mask > 0, loss, torch.zeros_like(loss))
304 | n_entries = torch.sum(atom_mask > 0)
305 | else:
306 | n_entries = torch.prod(torch.tensor(loss.size()))
307 |
308 | # calculate loss and n_entries
309 | self.n_entries += n_entries.detach().cpu().data.numpy()
310 | self.loss += torch.sum(loss).detach().cpu().data.numpy()
311 |
312 | def aggregate(self):
313 | return self.loss / max(self.n_entries, 1.)
314 |
315 |
316 | ### PRE- AND POST-PROCESSING LAYERS ###
317 | class EmbeddingMultiplication(nn.Module):
318 | r"""
319 | Layer that multiplies embeddings of given types with the representation.
320 |
321 | Args:
322 | embedding (torch.nn.Embedding instance): the embedding layer used to embed atom
323 | types.
324 | in_key_types (str, optional): the keyword to obtain types for embedding from
325 | inputs.
326 | in_key_representation (str, optional): the keyword to obtain the representation
327 | from inputs.
328 | out_key (str, optional): the keyword used to store the calculated product in
329 | the inputs dictionary.
330 | """
331 |
332 | def __init__(self, embedding, in_key_types='_next_types',
333 | in_key_representation='representation',
334 | out_key='preprocessed_representation'):
335 | super(EmbeddingMultiplication, self).__init__()
336 | self.embedding = embedding
337 | self.in_key_types = in_key_types
338 | self.in_key_representation = in_key_representation
339 | self.out_key = out_key
340 |
341 | def forward(self, inputs):
342 | """
343 | Compute layer output.
344 |
345 | Args:
346 | inputs (dict of torch.Tensor): batch of input values containing the atomic
347 | numbers for embedding as well as the representation.
348 | Returns:
349 | torch.Tensor: layer output.
350 | """
351 | # get types to embed from inputs
352 | types = inputs[self.in_key_types]
353 | st = types.size()
354 |
355 | # embed types
356 | if len(st) == 1:
357 | emb = self.embedding(types.view(st[0], 1))
358 | elif len(st) == 2:
359 | emb = self.embedding(types.view(*st[:-1], 1, st[-1]))
360 |
361 | # get representation
362 | if isinstance(inputs[self.in_key_representation], tuple):
363 | repr = inputs[self.in_key_representation][0]
364 | else:
365 | repr = inputs[self.in_key_representation]
366 | if len(st) == 2:
367 | # if multiple types are provided per molecule, expand
368 | # dimensionality of representation
369 | repr = repr.view(*repr.size()[:-1], 1, repr.size()[-1])
370 |
371 | # if representation is larger than the embedding, pad embedding with ones
372 | if repr.size()[-1] != emb.size()[-1]:
373 | _emb = torch.ones([*emb.size()[:-1], repr.size()[-1]], device=emb.device)
374 | _emb[..., :emb.size()[-1]] = emb
375 | emb = _emb
376 |
377 | # multiply embedded types with representation
378 | features = repr * emb
379 |
380 | # store result in input dictionary
381 | inputs.update({self.out_key: features})
382 |
383 | return inputs
384 |
385 |
386 | class NormalizeAndAggregate(nn.Module):
387 | r"""
388 | Layer that normalizes and aggregates given input along specifiable axes.
389 |
390 | Args:
391 | normalize (bool, optional): set True to normalize the input (default: True).
392 | normalization_axis (int, optional): axis along which normalization is applied
393 | (default: -1).
394 | normalization_mode (str, optional): which normalization to apply (currently
395 | only 'logsoftmax' is supported, default: 'logsoftmax').
396 | aggregate (bool, optional): set True to aggregate the input (default: True).
397 | aggregation_axis (int, optional): axis along which aggregation is applied
398 | (default: -1).
399 | aggregation_mode (str, optional): which aggregation to apply (currently 'sum'
400 | and 'mean' are supported, default: 'sum').
401 | keepdim (bool, optional): set True to keep the number of dimensions after
402 | aggregation (default: True).
403 | mask (str, optional): key to extract a mask from the inputs dictionary,
404 | which hides values during aggregation (default: None).
405 | squeeze (bool, optional): whether to squeeze the input before applying
406 | normalization (default: False).
407 |
408 | Returns:
409 | torch.Tensor: input after normalization and aggregation along specified axes.
410 | """
411 |
412 | def __init__(self, normalize=True, normalization_axis=-1,
413 | normalization_mode='logsoftmax', aggregate=True,
414 | aggregation_axis=-1, aggregation_mode='sum', keepdim=True,
415 | mask=None, squeeze=False):
416 |
417 | super(NormalizeAndAggregate, self).__init__()
418 |
419 | if normalize:
420 | if normalization_mode.lower() == 'logsoftmax':
421 | self.normalization = nn.LogSoftmax(normalization_axis)
422 | else:
423 | self.normalization = None
424 |
425 | if aggregate:
426 | if aggregation_mode.lower() == 'sum':
427 | self.aggregation =\
428 | spk.nn.base.Aggregate(aggregation_axis, mean=False,
429 | keepdim=keepdim)
430 | elif aggregation_mode.lower() == 'mean':
431 | self.aggregation =\
432 | spk.nn.base.Aggregate(aggregation_axis, mean=True,
433 | keepdim=keepdim)
434 | else:
435 | self.aggregation = None
436 |
437 | self.mask = mask
438 | self.squeeze = squeeze
439 |
440 | def forward(self, inputs, result):
441 | """
442 | Compute layer output.
443 |
444 | Args:
445 | inputs (dict of torch.Tensor): batch of input values containing the mask
446 | result (torch.Tensor): batch of result values to which normalization and
447 | aggregation is applied
448 | Returns:
449 | torch.Tensor: normalized and aggregated result.
450 | """
451 |
452 | res = result
453 |
454 | if self.squeeze:
455 | res = torch.squeeze(res)
456 |
457 | if self.normalization is not None:
458 | res = self.normalization(res)
459 |
460 | if self.aggregation is not None:
461 | if self.mask is not None:
462 | mask = inputs[self.mask]
463 | else:
464 | mask = None
465 | res = self.aggregation(res, mask)
466 |
467 | return res
468 |
469 |
470 | class AtomCompositionEmbedding(nn.Module):
471 | r"""
472 | Layer that embeds all atom types in a molecule and aggregates them into a single
473 | representation of the composition using a fully connected MLP.
474 |
475 | Args:
476 | embedding (torch.nn.Embedding instance): an embedding layer used to embed atom
477 | types separately.
478 | n_out (int, optional): number of features in the final, global embedding (i.e.
479 | output dimension for the MLP used to aggregate the separate, stacked atom
480 | type embeddings).
481 | n_layers (int, optional): number of dense layers used to get the global
482 | embedding (default: 5).
483 | n_neurons (list of int or int or None, optional): number of neurons in each
484 | layer of the aggregation MLP. If a single int is provided, all layers will
485 | have that number of neurons, if `None`, interpolate linearly between n_in
486 | and n_out (default: None).
487 | activation (function, optional): activation function for hidden layers in the
488 | aggregation MLP (default: spk.nn.activations.shifted_softplus).
489 | type_weighting (str, optional): how to weight the individual atom type
490 | embeddings (choose from 'relative' to multiply with the fraction of atoms
491 | of that type and 'existence' to multiply with one if the type is present
492 | in the composition and zero otherwise; any other value multiplies with the
493 | absolute number of atoms of that type, default: 'exact')
494 | in_key_composition (str, optional): the keyword to obtain the global
495 | composition of molecules (i.e. a list of all atom types, default:
496 | 'composition').
497 | n_types (int, optional): total number of available atom types (default: 5).
498 | """
499 |
500 | def __init__(self,
501 | embedding,
502 | n_out=128,
503 | n_layers=5,
504 | n_neurons=None,
505 | activation=spk.nn.activations.shifted_softplus,
506 | type_weighting='exact',
507 | in_key_composition='composition',
508 | n_types=5,
509 | skip_h=True):
510 |
511 | super(AtomCompositionEmbedding, self).__init__()
512 |
513 | self.embedding = embedding
514 | self.in_key_composition = in_key_composition
515 | self.type_weighting = type_weighting
516 | self.n_types = n_types
517 | self.skip_h = skip_h
518 | if self.skip_h:
519 | self.n_types -= 1
520 | self.n_out = n_out
521 |
522 | # compute number of features in stacked embeddings
523 | n_in = self.n_types * self.embedding.embedding_dim
524 |
525 | if n_neurons is None:
526 | # linearly interpolate between n_in and n_out
527 | n_neurons = list(np.linspace(n_in, n_out, n_layers + 1).astype(int)[1:-1])
528 | self.aggregation_mlp = MLP(n_in, n_out, n_neurons, n_layers, activation)
529 |
530 | def forward(self, inputs):
531 | """
532 | Compute layer output.
533 |
534 | Args:
535 | inputs (dict of torch.Tensor): batch of input values containing the
536 | composition of each molecule and the available atom types.
537 | Returns:
538 | torch.Tensor: batch of vectors representing the global composition of
539 | each molecule.
540 | """
541 | # get composition to embed from inputs
542 | compositions = inputs[self.in_key_composition][..., None]
543 | if self.skip_h:
544 | embedded_types = self.embedding(inputs['_all_types'][0, 1:-1])[None, ...]
545 | else:
546 | embedded_types = self.embedding(inputs['_all_types'][0, :-1])[None, ...]
547 |
548 | # get global representation
549 | if self.type_weighting == 'relative':
550 | compositions = compositions/torch.sum(compositions, dim=-2, keepdim=True)
551 | elif self.type_weighting == 'existence':
552 | compositions = (compositions > 0).float()
553 |
554 | # multiply embedding with (weighted) composition
555 | embedding = embedded_types * compositions
556 |
557 | # aggregate embeddings to global representation
558 | sizes = embedding.size()
559 | embedding = embedding.view([*sizes[:-2], 1, sizes[-2]*sizes[-1]]) # stack
560 | embedding = self.aggregation_mlp(embedding) # aggregate
561 |
562 | return embedding
563 |
564 |
565 | class FingerprintEmbedding(nn.Module):
566 | r"""
567 | Layers that map the fingerprint of a molecule to a feature vector used for
568 | conditioning.
569 |
570 | Args:
571 | n_in (int): number of inputs (bits in the fingerprint).
572 | n_out (int): number of features in the embedding.
573 | n_layers (int, optional): number of dense layers used to embed the fingerprint
574 | (default: 5).
575 | n_neurons (list of int or int or None, optional): number of neurons in each
576 | layer of the output network. If a single int is provided, all layers will
577 | have that number of neurons, if `None`, interpolate linearly between n_in
578 | and n_out (default: None).
579 | in_key_fingerprint (str, optional): the keyword to obtain the fingerprint
580 | (default: 'fingerprint').
581 | activation (function, optional): activation function for hidden layers
582 | (default: spk.nn.activations.shifted_softplus).
583 | """
584 |
585 | def __init__(self, n_in, n_out, n_layers=5, n_neurons=None,
586 | in_key_fingerprint='fingerprint',
587 | activation=spk.nn.activations.shifted_softplus):
588 |
589 | super(FingerprintEmbedding, self).__init__()
590 |
591 | self.in_key_fingerprint = in_key_fingerprint
592 | self.n_in = n_in
593 | self.n_out = n_out
594 |
595 | if n_neurons is None:
596 | # linearly interpolate between n_in and n_out
597 | n_neurons = list(np.linspace(n_in, n_out, n_layers + 1).astype(int)[1:-1])
598 | self.out_net = MLP(n_in, n_out, n_neurons, n_layers, activation)
599 |
600 | def forward(self, inputs):
601 | """
602 | Compute layer output.
603 |
604 | Args:
605 | inputs (dict of torch.Tensor): batch of input values containing the
606 | fingerprints.
607 | Returns:
608 | torch.Tensor: batch of vectors representing the fingerprint of each
609 | molecule.
610 | """
611 | fingerprints = inputs[self.in_key_fingerprint]
612 |
613 | return self.out_net(fingerprints)[:, None, :]
614 |
615 |
616 | class PropertyEmbedding(nn.Module):
617 | r"""
618 | Layers that map the property (e.g. HOMO-LUMO gap, electronic spatial extent etc.)
619 | of a molecule to a feature vector used for conditioning. Properties are first
620 | expanded using Gaussian basis functions before being processed by a fully
621 | connected MLP.
622 |
623 | Args:
624 | n_in (int): number of inputs (Gaussians used for expansion of the property).
625 | n_out (int): number of features in the embedding.
626 | in_key_property (str): the keyword to obtain the property.
627 | start (float): center of first Gaussian function, :math:`\mu_0` for expansion.
628 | stop (float): center of last Gaussian function, :math:`\mu_{N_g}` for expansion
629 | (the remaining centers will be placed linearly spaced between start and
630 | stop).
631 | n_layers (int, optional): number of dense layers used to embed the property
632 | (default: 5).
633 | n_neurons (list of int or int or None, optional): number of neurons in each
634 | layer of the output network. If a single int is provided, all layers will
635 | have that number of neurons, if `None`, interpolate linearly between n_in
636 | and n_out (default: None).
637 | activation (function, optional): activation function for hidden layers
638 | (default: spk.nn.activations.shifted_softplus).
639 | trainable_gaussians (bool, optional): if True, widths and offset of Gaussian
640 | functions for expansion are adjusted during training process (default:
641 | False).
642 | width (float, optional): width value of Gaussian functions for expansion
643 | (provide None to set the width to the distance between two centers
644 | :math:`\mu`, default: None).
645 | """
646 |
647 | def __init__(self, n_in, n_out, in_key_property, start, stop, n_layers=5,
648 | n_neurons=None, activation=spk.nn.activations.shifted_softplus,
649 | trainable_gaussians=False, width=None, no_expansion=False):
650 |
651 | super(PropertyEmbedding, self).__init__()
652 |
653 | self.in_key_property = in_key_property
654 | self.n_in = n_in
655 | self.n_out = n_out
656 | if not no_expansion:
657 | self.expansion_net = GaussianExpansion(start, stop, self.n_in,
658 | trainable_gaussians, width)
659 | else:
660 | self.expansion_net = None
661 |
662 | if n_neurons is None:
663 | # linearly interpolate between n_in and n_out
664 | n_neurons = list(np.linspace(n_in, n_out, n_layers + 1).astype(int)[1:-1])
665 | self.out_net = MLP(n_in, n_out, n_neurons, n_layers, activation)
666 |
667 | def forward(self, inputs):
668 | """
669 | Compute layer output.
670 |
671 | Args:
672 | inputs (dict of torch.Tensor): batch of input values containing the
673 | property values.
674 | Returns:
675 | torch.Tensor: batch of vectors representing the property of each
676 | molecule.
677 | """
678 | property = inputs[self.in_key_property]
679 | if self.expansion_net is None:
680 | expanded = property
681 | else:
682 | expanded = self.expansion_net(property)
683 |
684 | return self.out_net(expanded)[:, None, :]
685 |
686 |
687 | ### MISC
688 | class GaussianExpansion(nn.Module):
689 | r"""Expansion layer using a set of Gaussian functions.
690 |
691 | Args:
692 | start (float): center of first Gaussian function, :math:`\mu_0`.
693 | stop (float): center of last Gaussian function, :math:`\mu_{N_g}`.
694 | n_gaussians (int, optional): total number of Gaussian functions, :math:`N_g`
695 | (default: 50).
696 | trainable (bool, optional): if True, widths and offset of Gaussian functions
697 | are adjusted during training process (default: False).
698 | width (float, optional): width value of Gaussian functions (provide None to
699 | set the width to the distance between two centers :math:`\mu`, default:
700 | None).
701 |
702 | """
703 |
704 | def __init__(self, start, stop, n_gaussians=50, trainable=False,
705 | width=None):
706 | super(GaussianExpansion, self).__init__()
707 | # compute offset and width of Gaussian functions
708 | offset = torch.linspace(start, stop, n_gaussians)
709 | if width is None:
710 | widths = torch.FloatTensor((offset[1] - offset[0]) *
711 | torch.ones_like(offset))
712 | else:
713 | widths = torch.FloatTensor(width * torch.ones_like(offset))
714 | if trainable:
715 | self.widths = nn.Parameter(widths)
716 | self.offsets = nn.Parameter(offset)
717 | else:
718 | self.register_buffer("widths", widths)
719 | self.register_buffer("offsets", offset)
720 |
721 | def forward(self, property):
722 | """Compute expanded gaussian property values.
723 |
724 | Args:
725 | property (torch.Tensor): property values of (N_b x 1) shape.
726 |
727 | Returns:
728 | torch.Tensor: layer output of (N_b x N_g) shape.
729 |
730 | """
731 | # compute coefficient -1 / (2 * width^2) of the Gaussian exponent
732 | coeff = -0.5 / torch.pow(self.widths, 2)[None, :]
733 | # broadcast to get the difference between each property value and each center
734 | diff = property - self.offsets[None, :]
735 | # compute expanded property values
736 | return torch.exp(coeff * torch.pow(diff, 2))
737 |
--------------------------------------------------------------------------------
/preprocess_dataset.py:
--------------------------------------------------------------------------------
1 | import collections
2 | import argparse
3 | import sys
4 | import time
5 | import numpy as np
6 | import logging
7 | from ase.db import connect
8 | from scipy.spatial.distance import pdist
9 | from utility_classes import ConnectivityCompressor, Molecule
10 | from multiprocessing import Process, Queue
11 | from pathlib import Path
12 |
13 |
14 | def get_parser():
15 | """ Setup parser for command line arguments """
16 | main_parser = argparse.ArgumentParser()
17 | main_parser.add_argument('datapath', help='Full path to dataset (e.g. '
18 | '/home/qm9.db)')
19 | main_parser.add_argument('--valence_list',
20 | default=[1, 1, 6, 4, 7, 3, 8, 2, 9, 1], type=int,
21 | nargs='+',
22 | help='The valence of atom types in the form '
23 | '[type1 valence type2 valence ...] '
24 | '(default: %(default)s)')
25 | main_parser.add_argument('--n_threads', type=int, default=16,
26 | help='Number of extra threads used while '
27 | 'processing the data')
28 | main_parser.add_argument('--n_mols_per_thread', type=int, default=100,
29 | help='Number of molecules processed by each '
30 | 'thread in one iteration')
31 | return main_parser
32 |
33 |
34 | def is_disconnected(connectivity):
35 | '''
36 | Assess whether all atoms of a molecule are connected using a connectivity matrix
37 |
38 | Args:
39 | connectivity (numpy.ndarray): matrix (n_atoms x n_atoms) indicating bonds
40 | between atoms
41 |
42 | Returns:
43 | bool: True if the molecule consists of at least two disconnected graphs,
44 | False if all atoms are connected by some path
45 | '''
46 | con_mat = connectivity
47 | seen, queue = {0}, collections.deque([0]) # start at node (atom) 0
48 | while queue:
49 | vertex = queue.popleft()
50 | # iterate over (bonded) neighbors of current node
51 | for node in np.argwhere(con_mat[vertex] > 0).flatten():
52 | # add node to queue and list of seen nodes if it has not been seen before
53 | if node not in seen:
54 | seen.add(node)
55 | queue.append(node)
56 | # if the seen nodes do not include all nodes, there are disconnected parts
57 | return seen != {*range(len(con_mat))}
58 |
59 |
60 | def get_count_statistics(mol=None, get_stat_heads=False):
61 | '''
62 | Collects atom, bond, and ring count statistics of a provided molecule
63 |
64 | Args:
65 | mol (utility_classes.Molecule): Molecule to be examined
66 | get_stat_heads (bool, optional): set True to only return the headers of
67 | gathered statistics (default: False)
68 |
69 | Returns:
70 | numpy.ndarray: (n_statistics x 1) array containing the gathered statistics. Use
71 | get_stat_heads parameter to obtain the corresponding row headers (where RX
72 | describes number of X-membered rings and CXC indicates the number of
73 | carbon-carbon bonds of order X etc.).
74 | '''
75 | stat_heads = ['n_atoms', 'C', 'N', 'O', 'F', 'H', 'H1C', 'H1N',
76 | 'H1O', 'C1C', 'C2C', 'C3C', 'C1N', 'C2N', 'C3N', 'C1O',
77 | 'C2O', 'C1F', 'N1N', 'N2N', 'N1O', 'N2O', 'N1F', 'O1O',
78 | 'O1F', 'R3', 'R4', 'R5', 'R6', 'R7', 'R8', 'R>8']
79 | if get_stat_heads:
80 | return stat_heads
81 | if mol is None:
82 | return None
83 | key_idx_dict = dict(zip(stat_heads, range(len(stat_heads))))
84 | stats = np.zeros((len(stat_heads), 1))
85 | # process all bonds and store statistics about bond and ring counts
86 | bond_stats = mol.get_bond_stats()
87 | for key, value in bond_stats.items():
88 | if key in key_idx_dict:
89 | idx = key_idx_dict[key]
90 | stats[idx, 0] = value
91 | # store simple statistics about number of atoms
92 | stats[key_idx_dict['n_atoms'], 0] = mol.n_atoms
93 | for key in ['C', 'N', 'O', 'F', 'H']:
94 | idx = key_idx_dict[key]
95 | charge = mol.type_charges[key]
96 | if charge in mol._unique_numbers:
97 | stats[idx, 0] = np.sum(mol.numbers == charge)
98 | return stats
99 |
100 |
101 | def preprocess_molecules(mol_idcs, source_db, valence,
102 | precompute_distances=True, precompute_fingerprint=False,
103 | remove_invalid=True, invalid_list=None, print_progress=False):
104 | '''
105 | Checks the validity of selected molecules and collects atom, bond,
106 | and ring count statistics for the valid structures. Molecules are classified as
107 | invalid if they consist of disconnected parts or fail a valence check, where the
108 | valency constraints of all atoms in a molecule have to be satisfied (e.g. carbon
109 | has four bonds, nitrogen has three bonds etc.)
110 |
111 | Args:
112 | mol_idcs (array): the indices of molecules from the source database that
113 | shall be examined
114 | source_db (str): full path to the source database (in ase.db sqlite format)
115 | valence (array): an array where the i-th entry contains the valency
116 | constraint of atoms with atomic charge i (e.g. a valency of 4 at array
117 | position 6 representing carbon)
118 | precompute_distances (bool, optional): if True, the pairwise distances between
119 | atoms in each molecule are computed and stored in the database (default:
120 | True)
121 | precompute_fingerprint (bool, optional): if True, the fingerprint of each
122 | molecule is computed and stored in the database (default: False)
123 | remove_invalid (bool, optional): if True, molecules that do not pass the
124 | valency or connectivity checks (or are on the invalid_list) are removed from
125 | the new database (default: True)
126 | invalid_list (list of int, optional): precomputed list containing indices of
127 | molecules that are marked as invalid (because they did not pass the
128 | valency or connectivity checks in earlier runs, default: None)
129 | print_progress (bool, optional): set True to print the progress in percent
130 | (default: False)
131 |
132 | Returns:
133 | list of ase.Atoms: list of all valid molecules
134 | list of dict: list of corresponding dictionaries with data of each molecule
135 | numpy.ndarray: (n_statistics x n_valid_molecules) matrix with atom, bond,
136 | and ring count statistics
137 | list of int: list with indices of molecules that failed the valency check
138 | list of int: list with indices of molecules that consist of disconnected parts
139 | int: number of molecules discarded (disconnected or failing the valency check)
140 | '''
141 | # initial setup
142 | count = 0 # count the number of invalid molecules
143 | disc = [] # store indices of disconnected molecules
144 | inval = [] # store indices of invalid molecules
145 | data_list = [] # store data fields of molecules for new db
146 | mols = [] # store molecules (as ase.Atoms objects)
147 | compressor = ConnectivityCompressor() # (de)compress sparse connectivity matrices
148 | stats = np.empty((len(get_count_statistics(get_stat_heads=True)), 0))
149 | n_all = len(mol_idcs)
150 |
151 | with connect(source_db) as source_db:
152 | # iterate over provided indices
153 | for i in mol_idcs:
154 | i = int(i)
155 | # skip molecule if present in invalid_list and remove_invalid is True
156 | if remove_invalid and invalid_list is not None:
157 | if i in invalid_list:
158 | continue
159 | # get molecule from database
160 | row = source_db.get(i + 1)
161 | data = row.data
162 | at = row.toatoms()
163 | # get positions and atomic numbers
164 | pos = at.positions
165 | numbers = at.numbers
166 | # center positions (using center of mass)
167 | pos = pos - at.get_center_of_mass()
168 | # order atoms by distance to center of mass
169 | center_dists = np.sqrt(np.maximum(np.sum(pos ** 2, axis=1), 0))
170 | idcs_sorted = np.argsort(center_dists)
171 | pos = pos[idcs_sorted]
172 | numbers = numbers[idcs_sorted]
173 | # update positions and atomic numbers accordingly in Atoms object
174 | at.positions = pos
175 | at.numbers = numbers
176 | # instantiate utility_classes.Molecule object
177 | mol = Molecule(pos, numbers)
178 | # get connectivity matrix (detecting bond orders with Open Babel)
179 | con_mat = mol.get_connectivity()
180 | # stop if molecule is disconnected (and therefore invalid)
181 | if remove_invalid:
182 | if is_disconnected(con_mat):
183 | count += 1
184 | disc += [i]
185 | continue
186 |
187 | # check if valency constraints of all atoms in molecule are satisfied:
188 | # since the detection of bond orders for the connectivity matrix with Open
189 | # Babel is unreliable for certain cases (e.g. some aromatic rings) we
190 | # try to fix it manually (with heuristics) or by reshuffling the atom
191 | # order (as the bond order detection of Open Babel is sensitive to the
192 | # order of atoms)
193 | nums = numbers
194 | random_ord = np.arange(len(numbers))
195 | for _ in range(10): # try 10 times before dismissing as invalid
196 | if np.all(np.sum(con_mat, axis=0) == valence[nums]):
197 | # valency is correct -> mark as valid and stop check
198 | val = True
199 | break
200 | else:
201 | # try to fix bond orders using heuristics
202 | val = False
203 | con_mat = mol.get_fixed_connectivity()
204 | if np.all(np.sum(con_mat, axis=0) == valence[nums]):
205 | # valency is now correct -> mark as valid and stop check
206 | val = True
207 | break
208 | # shuffle atom order before checking valency again
209 | random_ord = np.random.permutation(range(len(pos)))
210 | mol = Molecule(pos[random_ord], numbers[random_ord])
211 | con_mat = mol.get_connectivity()
212 | nums = numbers[random_ord]
213 | if remove_invalid:
214 | if not val:
215 | # stop if molecule is invalid (it failed the repeated valence checks)
216 | count += 1
217 | inval += [i]
218 | continue
219 |
220 | if precompute_distances:
221 | # calculate pairwise distances of atoms and store them in data
222 | dists = pdist(pos)[:, None]
223 | data.update({'dists': dists})
224 | if precompute_fingerprint:
225 | fp = np.array(mol.get_fp().fp, dtype=np.uint32)
226 | data.update({'fingerprint': fp})
227 |
228 | # store compressed connectivity matrix in data
229 | rand_ord_rev = np.argsort(random_ord)
230 | con_mat = con_mat[rand_ord_rev][:, rand_ord_rev]
231 | data.update(
232 | {'con_mat': compressor.compress(con_mat)})
233 |
234 | # update atom, bond, and ring count statistics
235 | stats = np.hstack((stats, get_count_statistics(mol=mol)))
236 |
237 | # add results to the lists
238 | mols += [at]
239 | data_list += [data]
240 |
241 | # print progress if desired
242 | if print_progress:
243 | if i % 100 == 0:
244 | print('\033[K', end='\r', flush=True)
245 | print(f'{100 * (i + 1) / n_all:.2f}%', end='\r', flush=True)
246 |
247 | return mols, data_list, stats, inval, disc, count
248 |
249 |
250 | def _processing_worker(q_in, q_out, task):
251 | '''
252 | Simple worker function that repeatedly fulfills a task using transmitted input and
253 | sends back the results until a stop signal is received. Can be used as target in
254 | a multiprocessing.Process object.
255 |
256 | Args:
257 | q_in (multiprocessing.Queue): queue to receive a list with data. The first
258 | entry signals whether worker can stop and the remaining entries are used as
259 | input arguments to the task function
260 | q_out (multiprocessing.Queue): queue to send results from task back
261 | task (callable function): function that is called using the received data
262 | '''
263 | while True:
264 | data = q_in.get(True) # receive data
265 | if data[0]: # stop if stop signal is received
266 | break
267 | results = task(*data[1:]) # fulfill task with received data
268 | q_out.put(results) # send back results
269 |
270 |
271 | def _submit_jobs(qs_out, count, chunk_size, n_all, working_flag,
272 | n_per_thread):
273 | '''
274 | Function that submits a job to preprocess molecules to every provided worker.
275 |
276 | Args:
277 | qs_out (list of multiprocessing.Queue): queues used to send data to workers (one
278 | queue per worker)
279 | count (int): index of the earliest, not yet preprocessed molecule in the db
280 | chunk_size (int): number of molecules to be divided amongst workers
281 | n_all (int): total number of molecules in the db
282 | working_flag (array): flags indicating whether workers are running
283 | n_per_thread (int): number of molecules to be given to each thread
284 |
285 | Returns:
286 | numpy.ndarray: array with flags indicating whether workers got
287 | a job
288 | int: index of the new earliest, not yet preprocessed molecule in
289 | the db (after the submitted preprocessing jobs have been done)
290 | '''
291 | # calculate indices of molecules that shall be preprocessed by workers
292 | idcs = np.arange(count, min(n_all, count + chunk_size))
293 | start = 0
294 | for i, q in enumerate(qs_out):
295 | if start >= len(idcs):
296 | # stop if no more indices are left to submit
297 | break
298 | end = start + n_per_thread
299 | q.put((False, idcs[start:end])) # submit indices (and signal to not stop)
300 | working_flag[i] = 1 # set flag that current worker got a job
301 | start = end
302 | new_count = count + len(idcs)
303 | return working_flag, new_count
304 |
305 |
306 | def preprocess_dataset(datapath, valence_list, n_threads, n_mols_per_thread=100,
307 | logging_print=True, new_db_path=None, precompute_distances=True,
308 | precompute_fingerprint=False, remove_invalid=True,
309 | invalid_list=None):
310 | '''
311 | Pre-processes all molecules of a dataset using the provided valency information.
312 | Multiple worker processes (called threads here) are used to speed up the process.
313 | Along with a new database containing the pre-processed molecules, an
314 | "input_db_invalid.txt" file holding the indices of removed molecules (which
315 | do not pass the valence or connectivity checks, omitted if remove_invalid is False)
316 | and a "new_db_statistics.npz" file (containing atom, bond, and ring count statistics
317 | for all molecules in the new database) are stored.
318 |
319 | Args:
320 | datapath (str): full path to dataset (ase.db database)
321 | valence_list (list): the valence of atom types in the form
322 | [type1 valence type2 valence ...]
323 | n_threads (int): number of threads used (0 for no extra threads)
324 | n_mols_per_thread (int, optional): number of molecules processed by each
325 | thread at each iteration (default: 100)
326 | logging_print (bool, optional): set True to show output with logging.info
327 | instead of standard printing (default: True)
328 | new_db_path (str, optional): full path to new database where pre-processed
329 | molecules shall be stored (None to simply append "gen" to the name in
330 | datapath, default: None)
331 | precompute_distances (bool, optional): if True, the pairwise distances between
332 | atoms in each molecule are computed and stored in the database (default:
333 | True)
334 | precompute_fingerprint (bool, optional): if True, the fingerprint of each
335 | molecule is computed and stored in the database (default: False)
336 | remove_invalid (bool, optional): if True, molecules that do not pass the
337 | valency or connectivity check are removed from the new database (note
338 | that a precomputed list of invalid molecules determined using the code in
339 | this file is fetched from our repository if possible, default: True)
340 | invalid_list (list of int, optional): precomputed list containing indices of
341 | molecules that are marked as invalid (because they did not pass the
342 | valency or connectivity checks in earlier runs, default: None)
343 | '''
344 | # convert paths
345 | datapath = Path(datapath)
346 | if new_db_path is None:
347 | new_db_path = datapath.parent / (datapath.stem + 'gen.db')
348 | else:
349 | new_db_path = Path(new_db_path)
350 |
351 | # compute array where the valency constraint of atom type i is stored at entry i
352 | max_type = max(valence_list[::2])
353 | valence = np.zeros(max_type + 1, dtype=int)
354 | valence[valence_list[::2]] = valence_list[1::2]
355 |
356 | def _print(x, end='\n', flush=False):
357 | if logging_print:
358 | logging.info(x)
359 | else:
360 | print(x, end=end, flush=flush)
361 |
362 | with connect(datapath) as db:
363 | n_all = db.count()
364 | if n_all == 0:
365 | _print('No molecules found in data base!')
366 | sys.exit(0)
367 | _print('\nPre-processing data...')
368 | if logging_print:
369 | _print(f'Processed: 0 / {n_all}...')
370 | else:
371 | _print('0.00%', end='', flush=True)
372 |
373 | # initial setup
374 | n_iterations = 0
375 | chunk_size = n_threads * n_mols_per_thread
376 | current = 0
377 | count = 0 # count number of discarded (invalid etc.) molecules
378 | disc = []
379 | inval = []
380 | stats = np.empty((len(get_count_statistics(get_stat_heads=True)), 0))
381 | working_flag = np.zeros(n_threads, dtype=bool)
382 | start_time = time.time()
383 | if invalid_list is not None and remove_invalid:
384 | invalid_list = {*invalid_list}
385 | n_inval = len(invalid_list)
386 | else:
387 | n_inval = 0
388 |
389 | with connect(new_db_path) as new_db:
390 |
391 | if precompute_fingerprint:
392 | # store the system's byte order of unsigned integers (for the fingerprints)
393 | new_db.metadata = \
394 | {'fingerprint_format': '>u4' if sys.byteorder == 'big' else '<u4'}
395 |
396 | if n_threads >= 1:
397 | # set up threads and queues
398 | threads = []
399 | qs_in = []
400 | qs_out = []
401 | for i in range(n_threads):
402 | qs_in += [Queue(1)]
403 | qs_out += [Queue(1)]
404 | threads += \
405 | [Process(target=_processing_worker,
406 | name=str(i),
407 | args=(qs_out[-1],
408 | qs_in[-1],
409 | lambda x:
410 | preprocess_molecules(x,
411 | datapath,
412 | valence,
413 | precompute_distances,
414 | precompute_fingerprint,
415 | remove_invalid,
416 | invalid_list)))]
417 | threads[-1].start()
418 |
419 | # submit first round of jobs
420 | working_flag, current = \
421 | _submit_jobs(qs_out, current, chunk_size, n_all,
422 | working_flag, n_mols_per_thread)
423 |
424 | while np.any(working_flag == 1):
425 | n_iterations += 1
426 |
427 | # initialize new iteration
428 | results = []
429 |
430 | # gather results
431 | for i, q in enumerate(qs_in):
432 | if working_flag[i]:
433 | results += [q.get()]
434 | working_flag[i] = 0
435 |
436 | # submit new jobs
437 | working_flag, current_new = \
438 | _submit_jobs(qs_out, current, chunk_size, n_all, working_flag,
439 | n_mols_per_thread)
440 |
441 | # store gathered results
442 | for res in results:
443 | mols, data_list, _stats, _inval, _disc, _c = res
444 | for (at, data) in zip(mols, data_list):
445 | new_db.write(at, data=data)
446 | stats = np.hstack((stats, _stats))
447 | inval += _inval
448 | disc += _disc
449 | count += _c
450 |
451 | # print progress
452 | if logging_print and n_iterations % 10 == 0:
453 | _print(f'Processed: {current:6d} / {n_all}...')
454 | elif not logging_print:
455 | _print('\033[K', end='\r', flush=True)
456 | _print(f'{100 * current / n_all:.2f}%', end='\r',
457 | flush=True)
458 | current = current_new # update current position in database
459 |
460 | # stop worker threads and join
461 | for i, q_out in enumerate(qs_out):
462 | q_out.put((True,))
463 | threads[i].join()
464 | threads[i].terminate()
465 | if logging_print:
466 | _print(f'Processed: {n_all} / {n_all}...')
467 |
468 | else:
469 | results = preprocess_molecules(range(n_all), datapath, valence,
470 | precompute_distances,
471 | precompute_fingerprint, remove_invalid,
472 | invalid_list, print_progress=True)
473 | mols, data_list, stats, inval, disc, count = results
474 | for (at, data) in zip(mols, data_list):
475 | new_db.write(at, data=data)
476 |
477 | if not logging_print:
478 | _print('\033[K', end='\n', flush=True)
479 | _print(f'... successfully validated {n_all - count - n_inval} data '
480 | f'points!', flush=True)
481 | if invalid_list is not None:
482 | _print(f'{n_inval} structures were removed because they are on the '
483 | f'pre-computed list of invalid molecules!', flush=True)
484 | if len(disc)+len(inval) > 0:
485 | _print(f'CAUTION: Could not validate {len(disc)+len(inval)} additional '
486 | f'molecules. These were also removed and their indices are '
487 | f'appended to the list of invalid molecules stored at '
488 | f'{datapath.parent / (datapath.stem + f"_invalid.txt")}',
489 | flush=True)
490 | np.savetxt(datapath.parent / (datapath.stem + f'_invalid.txt'),
491 | np.append(np.sort(list(invalid_list)), np.sort(inval + disc)),
492 | fmt='%d')
493 | elif remove_invalid:
494 | _print(f'Identified {len(disc)} disconnected structures, and {len(inval)} '
495 | f'structures with invalid valence!', flush=True)
496 | np.savetxt(datapath.parent / (datapath.stem + f'_invalid.txt'),
497 | np.sort(inval + disc), fmt='%d')
498 | _print('\nCompressing and storing statistics with numpy...')
499 | np.savez_compressed(new_db_path.parent/(new_db_path.stem+f'_statistics.npz'),
500 | stats=stats,
501 | stat_heads=get_count_statistics(get_stat_heads=True))
502 |
503 | end_time = time.time() - start_time
504 | m, s = divmod(end_time, 60)
505 | h, m = divmod(m, 60)
506 | h, m, s = int(h), int(m), int(s)
507 | _print(f'Done! Pre-processing needed {h:d}h{m:02d}m{s:02d}s.')
508 |
509 |
510 | if __name__ == '__main__':
511 | parser = get_parser()
512 | args = parser.parse_args()
513 | preprocess_dataset(**vars(args))
514 |
--------------------------------------------------------------------------------
/published_data/README.md:
--------------------------------------------------------------------------------
1 | # Generated molecules and pretrained models
2 |
3 | Here we provide links to the molecules generated with cG-SchNet in the studies reported in the paper. This allows for reproduction of the shown graphs and statistics, as well as further analysis of the obtained molecules. Furthermore, we provide two pretrained cG-SchNet models that were used in our experiments, which enables sampling of additional molecules.
4 |
5 | ## Data bases with generated molecules
6 | A zip-file containing the generated molecules can be found under [DOI 10.14279/depositonce-14977](http://dx.doi.org/10.14279/depositonce-14977).
7 | It includes five folders with molecules generated by five cG-SchNet models conditioned on different combinations of properties.
8 | They appear in the same order as in the paper, i.e. the first model is conditioned on isotropic polarizability, the second is conditioned on fingerprints, the third is conditioned on HOMO-LUMO gap and atomic composition, the fourth is conditioned on relative atomic energy and atomic composition, and the fifth is conditioned on HOMO-LUMO gap and relative atomic energy. Each folder contains several data bases following the naming convention _\<targets\>\_\<type\>.db_, where _\<targets\>_ lists the conditioning target values (separated with underscores) and _\<type\>_ is either _generated_ for raw, generated structures or _relaxed_ for the relaxed counterparts (wherever we computed them). For example _/4\_comp\_relenergy/c7o2h10\_-0.1\_relaxed.db_ contains relaxed molecules that were generated by the fourth cG-SchNet model (conditioned on atomic composition and relative atomic energy) using the composition c7o2h10 and a relative atomic energy of -0.1 eV as targets during sampling. Relaxation was carried out at the same level of theory used for the QM9 training data. Please refer to the publication for details on the procedure.
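
As a small illustration of the naming convention, the following sketch splits a data base file name such as _c7o2h10\_-0.1\_relaxed.db_ into its conditioning targets and its type. The helper `parse_db_name` is hypothetical (it is not part of this repository), and it assumes that no individual target value contains an underscore.

```python
from pathlib import Path


def parse_db_name(path):
    """Split a data base file name into (list of target strings, type).

    Hypothetical helper, assuming the last underscore-separated token is the
    type ('generated' or 'relaxed') and all preceding tokens are targets.
    """
    stem = Path(path).stem  # file name without directory and '.db' suffix
    *targets, kind = stem.split("_")
    if kind not in ("generated", "relaxed"):
        raise ValueError(f"unexpected data base type: {kind!r}")
    return targets, kind


targets, kind = parse_db_name("4_comp_relenergy/c7o2h10_-0.1_relaxed.db")
print(targets, kind)
```

The target strings are returned unparsed, since their meaning (composition, energy, gap, ...) depends on the model the folder belongs to.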
9 |
10 | For help with the content of each data base, please call the script _data\_base\_help.py_ which will print the metadata of a selected data base and help to get started:
11 |
12 | ```
13 | python ./cG-SchNet/published_data/data_base_help.py <path-to-db>
14 | ```
15 |
16 | For example, the output for the data base _/4\_comp\_relenergy/c7o2h10\_-0.1\_relaxed.db_ is as follows:
17 |
18 | ```
19 | INFO for data base at
20 | ===========================================================================================================
21 | Contains 2719 C7O2H10 isomers generated by cG-SchNet and subsequently relaxed with DFT calculations.
22 | All reported energy values are in eV and a systematic offset between the ORCA reference calculations and the QM9 reference values was approximated and compensated. For reference, the uncompensated energy values directly obtained with ORCA are also provided.
23 | The corresponding raw, generated version of each structure can be found in the data base of generated molecules using the provided "gen_idx" (caution, it starts from 0 whereas ase loads data bases counting from 1).
24 | The field "rmsd" holds the root mean square deviation between the atom positions of the raw, generated structure and the relaxed equilibrium molecule in Å (it includes hydrogen atoms).
25 | The field "changed" is 0 if the connectivity of the molecule before and after relaxation is identical (i.e. no bonds were broken or newly formed) and 1 if it did change.
26 | The field "known_relaxed" is 0, 3, or 6 to mark novel isomers, novel stereo-isomers, and unseen isomers (i.e. isomers resembling test data), respectively.
27 | Originally, all 3349 unique and valid C7O2H10 isomers among 100k molecules generated by cG-SchNet were chosen for relaxation. 7 of these did not converge to a valid configuration and 630 converged to equilibrium configurations already covered by other generated isomers (i.e. they were duplicate structures) and were therefore removed.
28 | ===========================================================================================================
29 |
30 | For example, here is the data stored with the first three molecules:
31 | 0: {'gen_idx': 0, 'computed_relative_atomic_energy': -0.11081359089212128, 'computed_energy_U0': -11512.699587841627, 'computed_energy_U0_uncompensated': -11512.578122651728, 'rmsd': 0.2492194704871389, 'changed': 0, 'known_relaxed': 3}
32 | 1: {'gen_idx': 1, 'computed_relative_atomic_energy': -0.13234655336327705, 'computed_energy_U0': -11513.108714128579, 'computed_energy_U0_uncompensated': -11512.98724893868, 'rmsd': 0.48381233235411153, 'changed': 0, 'known_relaxed': 3}
33 | 2: {'gen_idx': 2, 'computed_relative_atomic_energy': -0.11825264882338615, 'computed_energy_U0': -11512.84092994232, 'computed_energy_U0_uncompensated': -11512.719464752421, 'rmsd': 0.21804191287675742, 'changed': 0, 'known_relaxed': 3}
34 |
35 | You can load and access the molecules and accompanying data by connecting to the data base with ASE, e.g. using the following python code snippet:
36 | from ase.db import connect
37 | with connect() as con:
38 | row = con.get(1) # load the first molecule, 1-based indexing
39 | R = row.positions # positions of atoms as 3d coordinates
40 | Z = row.numbers # list of atomic numbers
41 | data = row.data # dictionary of data stored with the molecule
42 |
43 | You can visualize the molecules in the data base with ASE from the command line by calling:
44 | ase gui
45 | ```
46 |
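The offset between "gen_idx" (0-based) and ASE row ids (1-based) mentioned in the info text is easy to get wrong. Below is a minimal sketch of the lookup, assuming the rows in the data base of generated molecules are numbered contiguously from 1. The helper names and the path argument are illustrative and not part of this repository:

```python
def gen_idx_to_row_id(gen_idx):
    """Convert a 0-based gen_idx into the 1-based row id used by ase.db."""
    if gen_idx < 0:
        raise ValueError("gen_idx must be non-negative")
    return int(gen_idx) + 1


def load_raw_structure(relaxed_row, generated_db_path):
    """Return the raw, generated structure that belongs to a relaxed molecule.

    `generated_db_path` is a placeholder for the path to the data base of
    generated molecules.
    """
    # imported here so that gen_idx_to_row_id can be used without ASE installed
    from ase.db import connect

    with connect(generated_db_path) as con:
        row_id = gen_idx_to_row_id(relaxed_row.data["gen_idx"])
        return con.get(row_id).toatoms()
```

For the first relaxed molecule shown above ('gen_idx': 0), this would load row 1 of the data base of generated molecules.
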
47 | Note that references to molecules from the QM9 data set always correspond to the indices _after_ removing the invalid molecules listed in [the file of invalid structures](https://github.com/atomistic-machine-learning/cG-SchNet/blob/main/splits/qm9_invalid.txt). These structures were automatically removed if QM9 was downloaded with the data script in this repository (e.g. by starting [model training](https://github.com/atomistic-machine-learning/cG-SchNet#training-a-model)). The resulting data base with corresponding indices can be found in your data directory: ```/qm9gen.db```.
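
To translate an index from the original QM9 ordering into the index used here, one can subtract the number of invalid entries that precede it. The sketch below assumes the original index and the entries of `qm9_invalid.txt` use the same counting convention (please verify this against your local data). `to_filtered_index` is a hypothetical helper, not a function of this repository:

```python
from bisect import bisect_left


def to_filtered_index(original_idx, invalid_sorted):
    """Map an index in the original ordering to the index after removal.

    `invalid_sorted` is the sorted list of removed indices. Returns None if
    `original_idx` itself was removed.
    """
    # number of removed entries that are strictly smaller than original_idx
    pos = bisect_left(invalid_sorted, original_idx)
    if pos < len(invalid_sorted) and invalid_sorted[pos] == original_idx:
        return None
    return original_idx - pos


# first three entries of splits/qm9_invalid.txt, used here as toy input
invalid = [270, 281, 1116]
print(to_filtered_index(271, invalid))  # one removed entry precedes 271
```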
48 |
49 | ## Pretrained models
50 | A zip-file containing two pretrained cG-SchNet models can be found under [DOI 10.14279/depositonce-14978](http://dx.doi.org/10.14279/depositonce-14978). The archive consists of two folders, where _comp\_relenergy_ hosts the model that was conditioned on atomic composition and relative atomic energy and used for the study described in Fig. 4 in the paper. The other model was conditioned on the HOMO-LUMO gap and relative atomic energy and used in the study described in Fig. 5 in the paper.
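
The relative atomic energy used as a condition is the energy per atom minus a composition-based estimate. The sketch below restates the computation from `qm9_data.py` (the `energy_regression` weights of `QM9gen`) in plain Python. The function name is illustrative, and the energy is assumed to be given in the same units as the stored `energy_U0` values:

```python
# Regression weights copied from qm9_data.py, ordered as H, C, N, O, F.
WEIGHTS = (-1032.61411992, -2052.26704777, -2506.01057452,
           -3063.2081989, -3733.79421864)
OFFSET = 1016.2017264092275


def relative_atomic_energy(energy_u0, composition):
    """Energy per atom minus the estimate obtained from the composition.

    `composition` holds the number of H, C, N, O, and F atoms (in that order).
    """
    n_atoms = sum(composition)
    # the regression acts on concentrations, i.e. counts divided by n_atoms
    estimate = sum(w * c / n_atoms for w, c in zip(WEIGHTS, composition)) + OFFSET
    return energy_u0 / n_atoms - estimate


# first C7O2H10 molecule from the example output above
print(relative_atomic_energy(-11512.699587841627, (10, 7, 0, 2, 0)))
```

With the values of the first C7O2H10 molecule shown above, this gives roughly -0.111, in line with the stored "computed_relative_atomic_energy".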
51 |
52 | In order to generate molecules with the pretrained models, simply extract the folders into your model directory and adapt the call for generating molecules described in the [main readme](https://github.com/atomistic-machine-learning/cG-SchNet#generating-molecules) accordingly. For example, you can generate 20k molecules with the composition C7O2H10 and a relative atomic energy of -0.1 eV as targets with:
53 |
54 | ```
55 | python ./cG-SchNet/gschnet_cond_script.py generate gschnet ./models/comp_relenergy/ 20000 --conditioning "composition 10 7 0 2 0; n_atoms 19; relative_atomic_energy -0.1" --cuda
56 | ```
57 |
58 | Note that the model takes a string with three conditions as input to the --conditioning argument: the number of atoms of each type in the order H, C, N, O, F; the total number of atoms; and the relative atomic energy value. The three conditions are separated by semicolons.
59 | Similarly, you can generate 20k molecules with a HOMO-LUMO gap of 4 eV and relative atomic energy of -0.2 eV as targets with the other model:
60 |
61 | ```
62 | python ./cG-SchNet/gschnet_cond_script.py generate gschnet ./models/gap_relenergy/ 20000 --conditioning "gap 4.0; relative_atomic_energy -0.2" --cuda
63 | ```
64 |
65 | The second model takes two conditions, the gap value and the energy value, as targets.
66 | For more details on the generation of molecules and subsequent filtering, please refer to the [main readme](https://github.com/atomistic-machine-learning/cG-SchNet#filtering-and-analysis-of-generated-molecules).
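
The conditioning strings in the two calls above can also be assembled programmatically, which keeps the atom counts and `n_atoms` consistent. This is a small sketch only; both helper functions are hypothetical and simply reproduce the string format shown in the examples:

```python
def composition_condition(n_h, n_c, n_n, n_o, n_f, relative_atomic_energy):
    """Build the --conditioning string for the comp_relenergy model."""
    counts = [n_h, n_c, n_n, n_o, n_f]  # order: H, C, N, O, F
    counts_str = " ".join(str(c) for c in counts)
    return (f"composition {counts_str}; "
            f"n_atoms {sum(counts)}; "
            f"relative_atomic_energy {relative_atomic_energy}")


def gap_condition(gap, relative_atomic_energy):
    """Build the --conditioning string for the gap_relenergy model."""
    return f"gap {gap}; relative_atomic_energy {relative_atomic_energy}"


print(composition_condition(10, 7, 0, 2, 0, -0.1))
print(gap_condition(4.0, -0.2))
```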
67 |
--------------------------------------------------------------------------------
/published_data/data_base_help.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import sys
3 | from pathlib import Path
4 | from ase.db import connect
5 |
6 |
7 | def get_parser():
8 | """ Setup parser for command line arguments """
9 | main_parser = argparse.ArgumentParser()
10 | main_parser.add_argument('data_base_path', type=str,
11 | help='Path to data base (.db file) with molecules.')
12 |
13 | return main_parser
14 |
15 |
16 | if __name__ == '__main__':
17 | parser = get_parser()
18 | args = parser.parse_args()
19 |
20 | path = Path(args.data_base_path).resolve()
21 | if not path.exists():
22 | print(f'Argument error! There is no data base at "{path}"!')
23 | sys.exit(0)
24 |
25 | with connect(args.data_base_path) as con:
26 | if con.count() == 0:
27 | print(f'Error: The data base at "{path}" is empty!')
28 |             sys.exit(1)  # non-zero exit code signals the error to the caller
29 | elif 'info' not in con.metadata:
30 | print(f'Error: The metadata of data base at "{path}" does not contain the field "info"!')
31 | print(f'However, this is the data stored for the first molecule:\n{con.get(1).data}')
32 |             sys.exit(1)  # non-zero exit code signals the error to the caller
33 | print(f'\nINFO for data base at "{path}"')
34 | print('===========================================================================================================')
35 | print(con.metadata['info'])
36 | print('===========================================================================================================')
37 | print('\nFor example, here is the data stored with the first three molecules:')
38 | for i in range(3):
39 | print(f'{i}: {con.get(i+1).data}')
40 | print('\nYou can load and access the molecules and accompanying data by connecting to the data base with ASE, e.g. using the following python code snippet:')
41 | print(f'from ase.db import connect')
42 |         print(f'with connect("{path}") as con:\n'  # quote the path so the printed snippet is valid Python
43 |               f'\trow = con.get(1) # load the first molecule, 1-based indexing\n'
44 |               f'\tR = row.positions # positions of atoms as 3d coordinates\n'
45 |               f'\tZ = row.numbers # list of atomic numbers\n'
46 |               f'\tdata = row.data # dictionary of data stored with the molecule\n')
47 | print(f'You can visualize the molecules in the data base with ASE from the command line by calling:')
48 | print(f'ase gui {path}')
49 |
50 |
--------------------------------------------------------------------------------
/qm9_data.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import re
4 | import shutil
5 | import tarfile
6 | import tempfile
7 | from pathlib import Path
8 | from urllib import request as request
9 | from urllib.error import HTTPError, URLError
10 | from base64 import b64encode, b64decode
11 |
12 | import numpy as np
13 | import torch
14 | from ase.db import connect
15 | from ase.io.extxyz import read_xyz
16 | from ase.units import Debye, Bohr, Hartree, eV
17 |
18 | from schnetpack import Properties
19 | from schnetpack.datasets import DownloadableAtomsData
20 | from utility_classes import ConnectivityCompressor
21 | from preprocess_dataset import preprocess_dataset
22 |
23 |
24 | class QM9gen(DownloadableAtomsData):
25 | """ QM9 benchmark dataset for organic molecules with up to nine non-hydrogen atoms
26 | from {C, O, N, F}.
27 |
28 | This class adds convenience functions to download QM9 from figshare,
29 |     pre-process the data such that it can be used for molecule generation with the
30 | G-SchNet model, and load the data into pytorch.
31 |
32 | Args:
33 | path (str): path to directory containing qm9 database
34 | subset (list, optional): indices of subset, set to None for entire dataset
35 | (default: None).
36 |         download (bool, optional): enable downloading if qm9 database does not
37 |             exist (default: True)
38 | precompute_distances (bool, optional): if True and the pre-processed
39 | database does not yet exist, the pairwise distances of atoms in the
40 | dataset's molecules will be computed during pre-processing and stored in
41 | the database (increases storage demand of the dataset but decreases
42 | computational cost during training as otherwise the distances will be
43 | computed once in every epoch, default: True)
44 | remove_invalid (bool, optional): if True QM9 molecules that do not pass the
45 | valence check will be removed from the training data (note 1: the
46 | validity is per default inferred from a pre-computed list in our
47 | repository but will be assessed locally if the download fails,
48 | note2: only works if the pre-processed database does not yet exist,
49 | default: True)
50 |
51 | References:
52 | .. [#qm9_1] https://ndownloader.figshare.com/files/3195404
53 | """
54 |
55 | # general settings for the dataset
56 | available_atom_types = [1, 6, 7, 8, 9] # all atom types found in the dataset
57 | atom_types_valence = [1, 4, 3, 2, 1] # valence constraints of the atom types
58 | radial_limits = [0.9, 1.7] # minimum and maximum distance between neighboring atoms
59 |
60 | # properties
61 | A = 'rotational_constant_A'
62 | B = 'rotational_constant_B'
63 | C = 'rotational_constant_C'
64 | mu = 'dipole_moment'
65 | alpha = 'isotropic_polarizability'
66 | homo = 'homo'
67 | lumo = 'lumo'
68 | gap = 'gap'
69 | r2 = 'electronic_spatial_extent'
70 | zpve = 'zpve'
71 | U0 = 'energy_U0'
72 | U = 'energy_U'
73 | H = 'enthalpy_H'
74 | G = 'free_energy'
75 | Cv = 'heat_capacity'
76 |
77 | properties = [
78 | A, B, C, mu, alpha,
79 | homo, lumo, gap, r2, zpve,
80 | U0, U, H, G, Cv, 'n_atoms', 'relative_atomic_energy'
81 | ]
82 |
83 | units = [1., 1., 1., Debye, Bohr ** 3,
84 | Hartree, Hartree, Hartree,
85 | Bohr ** 2, Hartree,
86 | Hartree, Hartree, Hartree,
87 | Hartree, 1.,
88 | ]
89 |
90 | units_dict = dict(zip(properties, units))
91 |
92 | connectivity_compressor = ConnectivityCompressor()
93 |
94 | def __init__(self, path, subset=None, download=True, precompute_distances=True,
95 | remove_invalid=True, load_additionally=None):
96 | self.path = path
97 | self.dbpath = os.path.join(self.path, f'qm9gen.db')
98 | self.precompute_distances = precompute_distances
99 | self.remove_invalid = remove_invalid
100 | self.load_additionally = [] if load_additionally is None else load_additionally
101 | self.db_metadata = None
102 | # hard coded weights for regression of energy per atom from the concentration of atoms (i.e. the composition divided by the absolute number of atoms)
103 | self.energy_regression = lambda x: x.dot(np.array([-1032.61411992, -2052.26704777, -2506.01057452, -3063.2081989, -3733.79421864])) + 1016.2017264092275
104 |
105 | super().__init__(self.dbpath, subset=subset,
106 | available_properties=self.properties,
107 | units=self.units, download=download)
108 |
109 | with connect(self.dbpath) as db:
110 | if db.count() <= 0:
111 | logging.error('Error: Data base is empty, please provide a path '
112 | 'to a proper data base or an empty path where the '
113 | 'data can be downloaded to!')
114 | raise FileExistsError()
115 | if self.precompute_distances and 'dists' not in db.get(1).data:
116 | logging.info('Caution: Existing data base does not contain '
117 | 'pre-computed distances, distances will be computed '
118 | 'on the fly in every epoch.')
119 | if 'fingerprint' in self.load_additionally:
120 | if 'fingerprint' not in db.get(1).data:
121 | logging.error('Error: Fingerprints not found in the provided data '
122 | 'base, please provide another path to the correct '
123 | 'data base or an empty directory where a new data '
124 | 'base with pre-computed fingerprints can be '
125 | 'downloaded to.')
126 | raise FileExistsError()
127 | else:
128 | self.db_metadata = db.metadata
129 | for prop in self.load_additionally:
130 | if prop not in db.get(1).data and prop not in self.properties:
131 | logging.error(f'Error: Unknown property ({prop}) requested for '
132 | f'conditioning. Cannot obtain that property from '
133 | f'the database!')
134 | raise FileExistsError()
135 |
136 | def create_subset(self, idx):
137 | """
138 | Returns a new dataset that only consists of provided indices.
139 |
140 | Args:
141 | idx (numpy.ndarray): subset indices
142 |
143 | Returns:
144 | schnetpack.data.AtomsData: dataset with subset of original data
145 | """
146 | idx = np.array(idx)
147 | subidx = idx if self.subset is None or len(idx) == 0 \
148 | else np.array(self.subset)[idx]
149 | return type(self)(self.path, subidx, download=False,
150 | load_additionally=self.load_additionally)
151 |
152 | def get_properties(self, idx):
153 | _idx = self._subset_index(idx)
154 | with connect(self.dbpath) as db:
155 | row = db.get(_idx + 1)
156 | at = row.toatoms()
157 |
158 | # extract/calculate structure
159 | properties = {}
160 |         properties[Properties.Z] = torch.LongTensor(at.numbers.astype(np.int64))  # np.int was removed in NumPy 1.24
161 | positions = at.positions.astype(np.float32)
162 | positions -= at.get_center_of_mass() # center positions
163 | properties[Properties.R] = torch.FloatTensor(positions)
164 | properties[Properties.cell] = torch.FloatTensor(at.cell.astype(np.float32))
165 |
166 | # recover connectivity matrix from compressed format
167 | con_mat = self.connectivity_compressor.decompress(row.data['con_mat'])
168 | # save in dictionary
169 | properties['_con_mat'] = torch.FloatTensor(con_mat.astype(np.float32))
170 |
171 | # extract pre-computed distances (if they exist)
172 | if 'dists' in row.data:
173 | properties['dists'] = row.data['dists']
174 |
175 | # extract additional information
176 | for add_prop in self.load_additionally:
177 | if add_prop == 'fingerprint':
178 | fp = np.array(row.data[add_prop],
179 | dtype=self.db_metadata['fingerprint_format'])
180 | properties[add_prop] = torch.FloatTensor(
181 | np.unpackbits(fp.view(np.uint8), bitorder='little'))
182 | elif add_prop == 'n_atoms':
183 | properties['n_atoms'] = torch.FloatTensor([len(at.numbers)])
184 | elif add_prop == 'relative_atomic_energy':
185 |                 types = at.numbers.astype(np.int64)
186 | composition = np.array([np.sum(types == 1),
187 | np.sum(types == 6),
188 | np.sum(types == 7),
189 | np.sum(types == 8),
190 | np.sum(types == 9)],
191 | dtype=np.float32)
192 | concentration = composition/np.sum(composition)
193 | energy = row.data['energy_U0']
194 | energy_per_atom = energy/len(types)
195 | relative_atomic_energy = energy_per_atom - self.energy_regression(concentration)
196 | properties[add_prop] = torch.FloatTensor([relative_atomic_energy])
197 | else:
198 | properties[add_prop] = torch.FloatTensor([row.data[add_prop]])
199 |
200 | # get atom environment
201 | nbh_idx, offsets = self.environment_provider.get_environment(at)
202 | # store neighbors, cell, and index
203 |         properties[Properties.neighbors] = torch.LongTensor(nbh_idx.astype(np.int64))
204 | properties[Properties.cell_offset] = torch.FloatTensor(
205 | offsets.astype(np.float32))
206 |         properties["_idx"] = torch.LongTensor(np.array([idx], dtype=np.int64))
207 |
208 | return at, properties
209 |
210 | def _download(self):
211 | works = True
212 | if not os.path.exists(self.dbpath):
213 | qm9_path = os.path.join(self.path, f'qm9.db')
214 | if not os.path.exists(qm9_path):
215 | works = works and self._load_data()
216 | works = works and self._preprocess_qm9()
217 | return works
218 |
219 | def _load_data(self):
220 | logging.info('Downloading GDB-9 data...')
221 | tmpdir = tempfile.mkdtemp('gdb9')
222 | tar_path = os.path.join(tmpdir, 'gdb9.tar.gz')
223 | raw_path = os.path.join(tmpdir, 'gdb9_xyz')
224 | url = 'https://ndownloader.figshare.com/files/3195389'
225 |
226 | try:
227 | request.urlretrieve(url, tar_path)
228 | logging.info('Done.')
229 | except HTTPError as e:
230 |             logging.error('HTTP Error: %s %s', e.code, url)
231 | return False
232 | except URLError as e:
233 |             logging.error('URL Error: %s %s', e.reason, url)
234 | return False
235 |
236 | logging.info('Extracting data from tar file...')
237 | tar = tarfile.open(tar_path)
238 | tar.extractall(raw_path)
239 | tar.close()
240 | logging.info('Done.')
241 |
242 | logging.info('Parsing xyz files...')
243 | with connect(os.path.join(self.path, 'qm9.db')) as db:
244 | ordered_files = sorted(os.listdir(raw_path),
245 |                                    key=lambda x: (int(re.sub(r'\D', '', x)), x))
246 | for i, xyzfile in enumerate(ordered_files):
247 | xyzfile = os.path.join(raw_path, xyzfile)
248 |
249 | if (i + 1) % 10000 == 0:
250 | logging.info('Parsed: {:6d} / 133885'.format(i + 1))
251 | properties = {}
252 | tmp = os.path.join(tmpdir, 'tmp.xyz')
253 |
254 | with open(xyzfile, 'r') as f:
255 | lines = f.readlines()
256 | l = lines[1].split()[2:]
257 | for pn, p in zip(self.properties, l):
258 | properties[pn] = float(p) * self.units[pn]
259 | with open(tmp, "wt") as fout:
260 | for line in lines:
261 | fout.write(line.replace('*^', 'e'))
262 |
263 | with open(tmp, 'r') as f:
264 | ats = list(read_xyz(f, 0))[0]
265 | db.write(ats, data=properties)
266 | logging.info('Done.')
267 |
268 | shutil.rmtree(tmpdir)
269 |
270 | return True
271 |
272 | def _preprocess_qm9(self):
273 | # try to download pre-computed list of invalid molecules
274 | raw_path = os.path.join(self.path, 'qm9_invalid.txt')
275 | if os.path.exists(raw_path):
276 | logging.info(f'Found existing list with indices of molecules in QM9 that are invalid at "{raw_path}".'
277 | f' Please manually delete the file and restart training if you want to use the default list instead.')
278 | invalid_list = np.loadtxt(raw_path)
279 | else:
280 | logging.info('Downloading pre-computed list of invalid QM9 molecules...')
281 | # url = 'https://github.com/atomistic-machine-learning/cG-SchNet/blob/main/splits/' \
282 | # 'qm9_invalid.txt?raw=true'
283 | try:
284 | url = Path(__file__).parent.resolve() / 'splits/qm9_invalid.txt'
285 | request.urlretrieve(url.as_uri(), raw_path)
286 | logging.info('Done.')
287 | invalid_list = np.loadtxt(raw_path)
288 | except HTTPError as e:
289 |                 logging.error('HTTP Error: %s %s', e.code, url)
290 | logging.info('CAUTION: Could not download pre-computed list, will assess '
291 | 'validity during pre-processing.')
292 | invalid_list = None
293 | except URLError as e:
294 |                 logging.error('URL Error: %s %s', e.reason, url)
295 | logging.info('CAUTION: Could not download pre-computed list, will assess '
296 | 'validity during pre-processing.')
297 | invalid_list = None
298 | except ValueError as e:
299 |                 logging.error('Value Error: %s', e)
300 | logging.info('CAUTION: Could not download pre-computed list, will assess '
301 | 'validity during pre-processing.')
302 | invalid_list = None
303 | # check validity of molecules and store connectivity matrices and interatomic
304 | # distances in database as a pre-processing step
305 | qm9_db = os.path.join(self.path, f'qm9.db')
306 | valence_list = \
307 | np.array([self.available_atom_types, self.atom_types_valence]).flatten('F')
308 | precompute_fingerprint = 'fingerprint' in self.load_additionally
309 | preprocess_dataset(datapath=qm9_db, valence_list=list(valence_list),
310 | n_threads=8, n_mols_per_thread=125, logging_print=True,
311 | new_db_path=self.dbpath,
312 | precompute_distances=self.precompute_distances,
313 | precompute_fingerprint=precompute_fingerprint,
314 | remove_invalid=self.remove_invalid,
315 | invalid_list=invalid_list)
316 | return True
317 |
318 | def get_available_properties(self, available_properties):
319 | # we don't use properties other than stored connectivity matrices (and
320 | # distances, if they were precomputed) so we skip this part
321 | return available_properties
322 |
323 | def create_splits(self, num_train=None, num_val=None, split_file=None):
324 | """
325 | Splits the dataset into train/validation/test splits, writes split to
326 | an npz file and returns subsets. Either the sizes of training and
327 | validation split or an existing split file with split indices have to
328 | be supplied. The remaining data will be used in the test dataset.
329 | Args:
330 | num_train (int): number of training examples
331 | num_val (int): number of validation examples
332 | split_file (str): Path to split file. If file exists, splits will
333 | be loaded. Otherwise, a new file will be created
334 | where the generated split is stored.
335 | Returns:
336 | schnetpack.data.AtomsData: training dataset
337 | schnetpack.data.AtomsData: validation dataset
338 | schnetpack.data.AtomsData: test dataset
339 | """
340 | invalid_file_path = os.path.join(self.path, f"qm9_invalid.txt")
341 | if not os.path.exists(invalid_file_path):
342 | raise ValueError(f"Cannot find required file with indices of QM9 molecules that are invalid at {invalid_file_path}!")
343 | removed_idx = {*np.loadtxt(invalid_file_path).astype(int)}
344 | if split_file is not None and os.path.exists(split_file):
345 | S = np.load(split_file)
346 | train_idx = S["train_idx"].tolist()
347 | val_idx = S["val_idx"].tolist()
348 | test_idx = S["test_idx"].tolist()
349 | invalid_idx = {*S["invalid_idx"]}
350 | else:
351 | if num_train is None or num_val is None:
352 | raise ValueError(
353 | "You have to supply either split sizes (num_train /"
354 | + " num_val) or an npz file with splits."
355 | )
356 |
357 | assert num_train + num_val <= len(
358 | self
359 | ), "Dataset is smaller than num_train + num_val!"
360 |
361 | num_train = num_train if num_train > 1 else num_train * len(self)
362 | num_val = num_val if num_val > 1 else num_val * len(self)
363 | num_train = int(num_train)
364 | num_val = int(num_val)
365 |
366 | idx = np.random.permutation(len(self))
367 | train_idx = idx[:num_train].tolist()
368 | val_idx = idx[num_train : num_train + num_val].tolist()
369 | test_idx = idx[num_train + num_val :].tolist()
370 | invalid_idx = removed_idx
371 |
372 | if split_file is not None:
373 | np.savez(
374 | split_file, train_idx=train_idx, val_idx=val_idx, test_idx=test_idx,
375 | invalid_idx=sorted(list(invalid_idx))
376 | )
377 |
378 | if len(removed_idx) != len(invalid_idx) or len(removed_idx.difference(invalid_idx)) != 0:
379 | raise ValueError(f"Mismatch between the data base used to generate the provided split file and your local database. "
380 | + f"Please specify an empty data directory to re-download QM9 and try again.")
381 | train = self.create_subset(train_idx)
382 | val = self.create_subset(val_idx)
383 | test = self.create_subset(test_idx)
384 | return train, val, test
385 |
--------------------------------------------------------------------------------
/splits/1_polarizability_split.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/atomistic-machine-learning/cG-SchNet/d107fa20563db348c668fc2dd29d843010bd1174/splits/1_polarizability_split.npz
--------------------------------------------------------------------------------
/splits/2_fingerprint_split.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/atomistic-machine-learning/cG-SchNet/d107fa20563db348c668fc2dd29d843010bd1174/splits/2_fingerprint_split.npz
--------------------------------------------------------------------------------
/splits/3_gap_comp_split.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/atomistic-machine-learning/cG-SchNet/d107fa20563db348c668fc2dd29d843010bd1174/splits/3_gap_comp_split.npz
--------------------------------------------------------------------------------
/splits/4_comp_relenergy_split.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/atomistic-machine-learning/cG-SchNet/d107fa20563db348c668fc2dd29d843010bd1174/splits/4_comp_relenergy_split.npz
--------------------------------------------------------------------------------
/splits/5_gap_relenergy_split.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/atomistic-machine-learning/cG-SchNet/d107fa20563db348c668fc2dd29d843010bd1174/splits/5_gap_relenergy_split.npz
--------------------------------------------------------------------------------
/splits/qm9_invalid.txt:
--------------------------------------------------------------------------------
1 | 270
2 | 281
3 | 1116
4 | 1641
5 | 1647
6 | 1669
7 | 3837
8 | 3894
9 | 3898
10 | 4009
11 | 4014
12 | 4806
13 | 4870
14 | 4871
15 | 5033
16 | 5556
17 | 5598
18 | 5803
19 | 6032
20 | 6058
21 | 6074
22 | 6169
23 | 6329
24 | 6345
25 | 6682
26 | 6725
27 | 6732
28 | 7325
29 | 7334
30 | 7374
31 | 7375
32 | 7409
33 | 7416
34 | 7418
35 | 7428
36 | 7601
37 | 7653
38 | 7656
39 | 8397
40 | 9985
41 | 10015
42 | 10275
43 | 10294
44 | 10348
45 | 10784
46 | 11387
47 | 13480
48 | 13804
49 | 13832
50 | 14120
51 | 14317
52 | 15511
53 | 17807
54 | 17808
55 | 18307
56 | 20389
57 | 20397
58 | 20481
59 | 20489
60 | 20641
61 | 20678
62 | 20682
63 | 20691
64 | 21118
65 | 21126
66 | 21164
67 | 21363
68 | 21463
69 | 21520
70 | 21533
71 | 21534
72 | 21554
73 | 21566
74 | 21610
75 | 21619
76 | 21702
77 | 21705
78 | 21716
79 | 21717
80 | 21724
81 | 21740
82 | 21747
83 | 21757
84 | 21763
85 | 21837
86 | 21856
87 | 21858
88 | 21869
89 | 21874
90 | 21967
91 | 21968
92 | 21969
93 | 21970
94 | 21971
95 | 21972
96 | 21973
97 | 21974
98 | 21975
99 | 21976
100 | 21977
101 | 21978
102 | 21979
103 | 21980
104 | 21981
105 | 21982
106 | 21983
107 | 21984
108 | 21985
109 | 21986
110 | 21987
111 | 22008
112 | 22053
113 | 22089
114 | 22092
115 | 22096
116 | 22110
117 | 22115
118 | 22194
119 | 22202
120 | 22231
121 | 22237
122 | 22248
123 | 22264
124 | 22451
125 | 22459
126 | 22464
127 | 22468
128 | 22470
129 | 22471
130 | 22481
131 | 22494
132 | 22497
133 | 22498
134 | 22503
135 | 22542
136 | 22680
137 | 22685
138 | 22971
139 | 22973
140 | 22979
141 | 23148
142 | 23149
143 | 23371
144 | 23792
145 | 23798
146 | 23818
147 | 23821
148 | 23827
149 | 25530
150 | 25764
151 | 25811
152 | 25828
153 | 25829
154 | 25859
155 | 25868
156 | 25892
157 | 25914
158 | 26121
159 | 26152
160 | 26153
161 | 26186
162 | 26228
163 | 26229
164 | 26538
165 | 27271
166 | 27293
167 | 27322
168 | 27335
169 | 27388
170 | 27860
171 | 28082
172 | 28250
173 | 28383
174 | 28401
175 | 29149
176 | 29167
177 | 29539
178 | 29557
179 | 29563
180 | 30525
181 | 30526
182 | 30528
183 | 30529
184 | 30537
185 | 30539
186 | 30545
187 | 30546
188 | 30548
189 | 30550
190 | 30551
191 | 30705
192 | 30712
193 | 30760
194 | 30761
195 | 30762
196 | 30786
197 | 30787
198 | 30797
199 | 30901
200 | 30902
201 | 30903
202 | 30993
203 | 30994
204 | 30995
205 | 30999
206 | 31012
207 | 31106
208 | 31108
209 | 31109
210 | 31110
211 | 31111
212 | 31170
213 | 31502
214 | 31598
215 | 32413
216 | 32464
217 | 32759
218 | 32813
219 | 32865
220 | 32884
221 | 32941
222 | 32942
223 | 33399
224 | 36994
225 | 36995
226 | 37991
227 | 38082
228 | 42423
229 | 43212
230 | 43242
231 | 43474
232 | 43519
233 | 45540
234 | 45544
235 | 45545
236 | 45926
237 | 46610
238 | 49722
239 | 50308
240 | 50449
241 | 50619
242 | 50735
243 | 51245
244 | 51246
245 | 52007
246 | 53819
247 | 53820
248 | 53844
249 | 53891
250 | 53895
251 | 53938
252 | 53940
253 | 53942
254 | 53943
255 | 53953
256 | 54077
257 | 54078
258 | 54101
259 | 54118
260 | 54123
261 | 54125
262 | 54228
263 | 54243
264 | 54295
265 | 54383
266 | 54386
267 | 54399
268 | 54408
269 | 54409
270 | 54411
271 | 54421
272 | 54447
273 | 54448
274 | 54486
275 | 54537
276 | 54568
277 | 54581
278 | 54610
279 | 54614
280 | 54617
281 | 54618
282 | 54623
283 | 54628
284 | 54656
285 | 54690
286 | 54691
287 | 54765
288 | 54793
289 | 54794
290 | 54795
291 | 54810
292 | 54873
293 | 54895
294 | 54899
295 | 54903
296 | 54993
297 | 55144
298 | 55145
299 | 55186
300 | 55189
301 | 55266
302 | 55399
303 | 55407
304 | 55409
305 | 55437
306 | 55449
307 | 55475
308 | 55476
309 | 55478
310 | 55483
311 | 55498
312 | 55517
313 | 55557
314 | 55609
315 | 55610
316 | 55618
317 | 55619
318 | 55620
319 | 55700
320 | 55702
321 | 55790
322 | 55909
323 | 55943
324 | 56015
325 | 56054
326 | 56071
327 | 56240
328 | 56342
329 | 56343
330 | 57735
331 | 57736
332 | 57944
333 | 58280
334 | 58612
335 | 58613
336 | 58981
337 | 59826
338 | 59848
339 | 59965
340 | 59976
341 | 60659
342 | 60717
343 | 60779
344 | 61434
345 | 61439
346 | 61450
347 | 62028
348 | 62083
349 | 66510
350 | 66602
351 | 66603
352 | 71535
353 | 72316
354 | 72318
355 | 74136
356 | 74175
357 | 74199
358 | 74201
359 | 74240
360 | 74241
361 | 74242
362 | 74312
363 | 75052
364 | 75169
365 | 76134
366 | 76135
367 | 76142
368 | 76371
369 | 76372
370 | 76379
371 | 76393
372 | 76394
373 | 76396
374 | 77141
375 | 77459
376 | 80207
377 | 80594
378 | 80596
379 | 81048
380 | 81053
381 | 81056
382 | 81566
383 | 81567
384 | 81572
385 | 81577
386 | 81578
387 | 81579
388 | 81580
389 | 82081
390 | 83400
391 | 83410
392 | 83413
393 | 83414
394 | 83416
395 | 84309
396 | 84799
397 | 85156
398 | 85354
399 | 85487
400 | 85779
401 | 85951
402 | 85961
403 | 86562
404 | 86587
405 | 86635
406 | 86738
407 | 86741
408 | 87034
409 | 87036
410 | 89621
411 | 89625
412 | 89627
413 | 90286
414 | 90692
415 | 90693
416 | 90695
417 | 91152
418 | 91257
419 | 91258
420 | 91518
421 | 92759
422 | 93323
423 | 93346
424 | 93566
425 | 93571
426 | 93940
427 | 93941
428 | 93985
429 | 93987
430 | 93996
431 | 94181
432 | 94603
433 | 94605
434 | 95437
435 | 96611
436 | 96612
437 | 96636
438 | 96637
439 | 96639
440 | 96678
441 | 97115
442 | 97259
443 | 97324
444 | 97357
445 | 97362
446 | 97454
447 | 97457
448 | 97475
449 | 97528
450 | 97529
451 | 98010
452 | 98232
453 | 98233
454 | 98234
455 | 99224
456 | 99716
457 | 99725
458 | 99727
459 | 99730
460 | 99732
461 | 99744
462 | 99808
463 | 100075
464 | 100091
465 | 100442
466 | 100456
467 | 100514
468 | 100518
469 | 100625
470 | 100626
471 | 100709
472 | 100733
473 | 101806
474 | 101940
475 | 102014
476 | 102130
477 | 102224
478 | 102627
479 | 102633
480 | 102793
481 | 102795
482 | 102796
483 | 103797
484 | 103798
485 | 103812
486 | 103820
487 | 104600
488 | 104601
489 | 105193
490 | 105210
491 | 105214
492 | 105578
493 | 108409
494 | 108890
495 | 110173
496 | 112229
497 | 112337
498 | 112354
499 | 112496
500 | 112945
501 | 112946
502 | 112954
503 | 112989
504 | 113156
505 | 113160
506 | 113173
507 | 113174
508 | 113175
509 | 113183
510 | 115697
511 | 115698
512 | 116536
513 | 116638
514 | 116798
515 | 116943
516 | 117294
517 | 117522
518 | 117629
519 | 117642
520 | 118440
521 | 118447
522 | 119757
523 | 120430
524 | 120722
525 | 121012
526 | 121588
527 | 121595
528 | 121599
529 | 121610
530 | 121612
531 | 121779
532 | 121863
533 | 121881
534 | 122766
535 | 123125
536 | 123128
537 | 123544
538 | 123567
539 | 123588
540 | 123592
541 | 123615
542 | 123619
543 | 123629
544 | 123641
545 | 123654
546 | 123673
547 | 123685
548 | 123698
549 | 123901
550 | 123907
551 | 123964
552 | 123997
553 | 124017
554 | 124033
555 | 124121
556 | 124204
557 | 124221
558 | 124249
559 | 124709
560 | 124711
561 | 124713
562 | 124721
563 | 124722
564 | 124723
565 | 124730
566 | 124731
567 | 124736
568 | 124934
569 | 125054
570 | 125099
571 | 125275
572 | 125283
573 | 125360
574 | 125388
575 | 125470
576 | 125618
577 | 125629
578 | 125758
579 | 125792
580 | 125904
581 | 125916
582 | 126007
583 | 126024
584 | 126080
585 | 126088
586 | 126092
587 | 126291
588 | 126346
589 | 126350
590 | 126359
591 | 126864
592 | 126872
593 | 127082
594 | 127323
595 | 127355
596 | 127394
597 | 127406
598 | 127542
599 | 127605
600 | 127633
601 | 127777
602 | 127838
603 | 127892
604 | 127893
605 | 127894
606 | 128141
607 | 128142
608 | 128146
609 | 128170
610 | 128182
611 | 128194
612 | 128251
613 | 128259
614 | 128391
615 | 128393
616 | 128396
617 | 128406
618 | 128417
619 | 128421
620 | 128498
621 | 128527
622 | 128528
623 | 128557
624 | 128567
625 | 128618
626 | 128626
627 | 128932
628 | 128947
629 | 129099
630 | 129105
631 | 129113
632 | 129135
633 | 129136
634 | 129144
635 | 129145
636 | 129146
637 | 129148
638 | 129149
639 | 129150
640 | 129155
641 | 129156
642 | 129158
643 | 129169
644 | 129174
645 | 129176
646 | 129181
647 | 129242
648 | 129249
649 | 129316
650 | 129335
651 | 129336
652 | 129339
653 | 129392
654 | 129400
655 | 129405
656 | 129409
657 | 129410
658 | 129411
659 | 129577
660 | 129578
661 | 129580
662 | 129653
663 | 129735
664 | 129859
665 | 129867
666 | 129914
667 | 129939
668 | 129993
669 | 129995
670 | 129996
671 | 129998
672 | 130006
673 | 130008
674 | 130035
675 | 130037
676 | 130120
677 | 130181
678 | 130296
679 | 130335
680 | 130336
681 | 130337
682 | 130338
683 | 130344
684 | 130345
685 | 130354
686 | 130355
687 | 130356
688 | 130357
689 | 130365
690 | 130369
691 | 130373
692 | 130376
693 | 130381
694 | 130382
695 | 130384
696 | 130385
697 | 130386
698 | 130387
699 | 130392
700 | 130403
701 | 130405
702 | 130406
703 | 130415
704 | 130423
705 | 130434
706 | 130437
707 | 130439
708 | 130440
709 | 130449
710 | 130452
711 | 130453
712 | 130462
713 | 130466
714 | 130469
715 | 130475
716 | 130479
717 | 130530
718 | 130536
719 | 130537
720 | 130582
721 | 130583
722 | 130587
723 | 130591
724 | 130602
725 | 130619
726 | 130629
727 | 130634
728 | 130661
729 | 130663
730 | 130664
731 | 130665
732 | 130666
733 | 130668
734 | 130669
735 | 130679
736 | 130683
737 | 130685
738 | 130691
739 | 130740
740 | 130746
741 | 130793
742 | 130860
743 | 130878
744 | 130882
745 | 130918
746 | 131091
747 | 131164
748 | 131199
749 | 131224
750 | 131513
751 | 131541
752 | 131554
753 | 131658
754 | 131693
755 | 131695
756 | 131704
757 | 131881
758 | 131882
759 | 131883
760 | 131884
761 | 131885
762 | 131886
763 | 131887
764 | 131888
765 | 131889
766 | 131890
767 | 131891
768 | 131892
769 | 131893
770 | 131894
771 | 131895
772 | 131896
773 | 131897
774 | 131898
775 | 131899
776 | 131900
777 | 131901
778 | 131902
779 | 131903
780 | 131904
781 | 131905
782 | 131906
783 | 131907
784 | 131908
785 | 131909
786 | 131910
787 | 131911
788 | 131912
789 | 131913
790 | 131914
791 | 131915
792 | 131916
793 | 131917
794 | 131918
795 | 131919
796 | 131920
797 | 131921
798 | 131922
799 | 131923
800 | 131924
801 | 131925
802 | 131926
803 | 131927
804 | 131928
805 | 131929
806 | 131930
807 | 131931
808 | 131932
809 | 131933
810 | 131934
811 | 131935
812 | 131936
813 | 131937
814 | 131938
815 | 131939
816 | 131940
817 | 131941
818 | 131942
819 | 131943
820 | 131944
821 | 131945
822 | 131946
823 | 131947
824 | 131948
825 | 131949
826 | 131950
827 | 131951
828 | 131952
829 | 131953
830 | 131954
831 | 131955
832 | 131956
833 | 131957
834 | 131958
835 | 131959
836 | 131960
837 | 131961
838 | 131962
839 | 131963
840 | 131964
841 | 131965
842 | 131966
843 | 131967
844 | 131968
845 | 131969
846 | 131970
847 | 131971
848 | 131972
849 | 131973
850 | 131974
851 | 131975
852 | 131976
853 | 131977
854 | 131978
855 | 131979
856 | 131980
857 | 131981
858 | 131982
859 | 131983
860 | 131984
861 | 131985
862 | 131986
863 | 131987
864 | 131988
865 | 131989
866 | 131990
867 | 131991
868 | 131992
869 | 131993
870 | 131994
871 | 131995
872 | 131996
873 | 131997
874 | 131998
875 | 131999
876 | 132000
877 | 132001
878 | 132071
879 | 132883
880 | 133142
881 | 133166
882 | 133167
883 | 133262
884 | 133273
885 | 133310
886 | 133336
887 | 133337
888 | 133395
889 | 133396
890 | 133402
891 | 133403
892 | 133812
893 | 133815
894 | 133819
895 | 133821
896 | 133825
897 | 133827
898 | 133828
899 | 133831
900 | 133832
901 | 133833
902 | 133839
903 | 133842
904 | 133843
905 | 133844
906 | 133845
907 | 133846
908 | 133848
909 | 133850
910 | 133851
911 | 133853
912 | 133857
913 | 133863
914 | 133864
915 | 133865
916 |
--------------------------------------------------------------------------------
/utility_classes.py:
--------------------------------------------------------------------------------
1 | import operator
2 | import re
3 | import numpy as np
4 | import openbabel as ob
5 | import pybel
6 | from rdkit import Chem
7 | from multiprocessing import Process
8 | from scipy.spatial.distance import squareform
9 |
10 |
11 | class Molecule:
12 | '''
13 | Molecule class for computing statistics such as the connectivity matrix,
14 | molecular fingerprint, canonical SMILES representation, or ring count from
15 | the positions of atoms and their atomic numbers. Currently supports molecules made of
16 | carbon, nitrogen, oxygen, fluorine, and hydrogen (such as in the QM9 benchmark
17 | dataset). Mainly relies on routines from Open Babel and RDKit.
18 |
19 | Args:
20 | pos (numpy.ndarray): positions of atoms in euclidean space (n_atoms x 3)
21 | atomic_numbers (numpy.ndarray): list with nuclear charge/type of each atom
22 | (e.g. 1 for hydrogens, 6 for carbons etc.).
23 | connectivity_matrix (numpy.ndarray, optional): optionally, a pre-calculated
24 | connectivity matrix (n_atoms x n_atoms) containing the bond order between
25 | atom pairs can be provided (default: None).
26 | store_positions (bool, optional): set True to store the positions of atoms in
27 | self.positions (only for convenience, not needed for computations, default:
28 | False).
29 | '''
30 |
31 | type_infos = {1: {'name': 'H',
32 | 'n_bonds': 1},
33 | 6: {'name': 'C',
34 | 'n_bonds': 4},
35 | 7: {'name': 'N',
36 | 'n_bonds': 3},
37 | 8: {'name': 'O',
38 | 'n_bonds': 2},
39 | 9: {'name': 'F',
40 | 'n_bonds': 1},
41 | }
42 | type_charges = {'H': 1, 'C': 6, 'N': 7, 'O': 8, 'F': 9}
43 |
44 | def __init__(self, pos, atomic_numbers, connectivity_matrix=None,
45 | store_positions=False):
46 | # was causing problems with other scripts if imported outside of this class
47 | # import openbabel as ob
48 | # import pybel
49 | # from rdkit import Chem
50 |
51 | # set comparison metrics to None (will be computed just in time)
52 | self._fp = None
53 | self._fp_bits = None
54 | self._can = None
55 | self._mirror_can = None
56 | self._inchi_key = None
57 | self._bond_stats = None
58 | self._fixed_connectivity = False
59 | self._row_indices = {}
60 | self._obmol = None
61 | self._rings = None
62 | self._n_atoms_per_type = None
63 | self._connectivity = connectivity_matrix
64 |
65 | # set statistics
66 | self.n_atoms = len(pos)
67 | self.numbers = atomic_numbers
68 | self._unique_numbers = {*self.numbers} # set for fast query
69 | self.positions = pos
70 | if not store_positions:
71 | self._obmol = self.get_obmol() # create obmol before removing pos
72 | self.positions = None
73 |
74 | def sanity_check(self):
75 | '''
76 | Check whether the sum of the valences of all atoms is divisible by 2.
77 |
78 | Returns:
79 | bool: True if the test is passed, False otherwise
80 | '''
81 | count = 0
82 | for atom in self.numbers:
83 | count += self.type_infos[atom]['n_bonds']
84 | if count % 2 == 0:
85 | return True
86 | else:
87 | return False
88 |
89 | def get_obmol(self):
90 | '''
91 | Retrieve the underlying Open Babel OBMol object.
92 |
93 | Returns:
94 | OBMol object: Open Babel OBMol representation
95 | '''
96 | if self._obmol is None:
97 | if self.positions is None:
98 | print('Error, cannot create obmol without positions!')
99 | return
100 | if self.numbers is None:
101 | print('Error, cannot create obmol without atomic numbers!')
102 | return
103 | # use openbabel to infer bonds and bond order:
104 | obmol = ob.OBMol()
105 | obmol.BeginModify()
106 |
107 | # set positions and atomic numbers of all atoms in the molecule
108 | for p, n in zip(self.positions, self.numbers):
109 | obatom = obmol.NewAtom()
110 | obatom.SetAtomicNum(int(n))
111 | obatom.SetVector(*p.tolist())
112 |
113 | # infer bonds and bond order
114 | obmol.ConnectTheDots()
115 | obmol.PerceiveBondOrders()
116 |
117 | obmol.EndModify()
118 | self._obmol = obmol
119 | return self._obmol
120 |
121 | def get_fp(self):
122 | '''
123 | Retrieve the molecular fingerprint (the path-based FP2 from Open Babel is used,
124 | which means that paths of length up to 7 are considered).
125 |
126 | Returns:
127 | pybel.Fingerprint object: molecular fingerprint (use "fp1 | fp2" to
128 | calculate the Tanimoto coefficient of two fingerprints)
129 | '''
130 | if self._fp is None:
131 | # calculate fingerprint
132 | self._fp = pybel.Molecule(self.get_obmol()).calcfp()
133 | return self._fp
134 |
135 | def get_fp_bits(self):
136 | '''
137 | Retrieve the bits set in the molecular fingerprint.
138 |
139 | Returns:
140 | Set of int: object containing the bits set in the molecular fingerprint
141 | '''
142 | if self._fp_bits is None:
143 | self._fp_bits = {*self.get_fp().bits}
144 | return self._fp_bits
145 |
146 | def get_can(self):
147 | '''
148 | Retrieve the canonical SMILES representation of the molecule.
149 |
150 | Returns:
151 | String: canonical SMILES string
152 | '''
153 | if self._can is None:
154 | # calculate canonical SMILES
155 | self._can = pybel.Molecule(self.get_obmol()).write('can')
156 | return self._can
157 |
158 | def get_mirror_can(self):
159 | '''
160 | Retrieve the canonical SMILES representation of the mirrored molecule (the
161 | z-coordinates are flipped).
162 |
163 | Returns:
164 | String: canonical SMILES string of the mirrored molecule
165 | '''
166 | if self._mirror_can is None:
167 | # calculate canonical SMILES of mirrored molecule
168 | self._flip_z() # flip z to mirror molecule using x-y plane
169 | self._mirror_can = pybel.Molecule(self.get_obmol()).write('can')
170 | self._flip_z() # undo mirroring
171 | return self._mirror_can
172 |
173 | def get_inchi_key(self):
174 | '''
175 | Retrieve the InChI-key of the molecule.
176 |
177 | Returns:
178 | String: InChI-key
179 | '''
180 | if self._inchi_key is None:
181 | # calculate inchi key
182 | self._inchi_key = pybel.Molecule(self.get_obmol()).\
183 | write('inchikey')
184 | return self._inchi_key
185 |
186 | def _flip_z(self):
187 | '''
188 | Flips the z-coordinates of atom positions (to get a mirrored version of the
189 | molecule).
190 | '''
191 | if self._obmol is None:
192 | self.get_obmol()
193 | for atom in ob.OBMolAtomIter(self._obmol):
194 | x, y, z = atom.x(), atom.y(), atom.z()
195 | atom.SetVector(x, y, -z)
196 | self._obmol.ConnectTheDots()
197 | self._obmol.PerceiveBondOrders()
198 |
199 | def get_connectivity(self):
200 | '''
201 | Retrieve the connectivity matrix of the molecule.
202 |
203 | Returns:
204 | numpy.ndarray: (n_atoms x n_atoms) array containing the pairwise bond orders
205 | between atoms (0 for no bond).
206 | '''
207 | if self._connectivity is None:
208 | # get connectivity matrix
209 | connectivity = np.zeros((self.n_atoms, len(self.numbers)))
210 | for atom in ob.OBMolAtomIter(self.get_obmol()):
211 | index = atom.GetIdx() - 1
212 | # loop over all neighbors of atom
213 | for neighbor in ob.OBAtomAtomIter(atom):
214 | idx = neighbor.GetIdx() - 1
215 | bond_order = neighbor.GetBond(atom).GetBO()
216 | #print(f'{index}-{idx}: {bond_order}')
217 | # do not count bonds between two hydrogen atoms
218 | if (self.numbers[index] == 1 and self.numbers[idx] == 1
219 | and bond_order > 0):
220 | bond_order = 0
221 | connectivity[index, idx] = bond_order
222 | self._connectivity = connectivity
223 | return self._connectivity
224 |
225 | def get_ring_counts(self):
226 | '''
227 | Retrieve a list containing the sizes of rings in the symmetric smallest set
228 | of smallest rings (S-SSSR from RDKit) in the molecule (e.g. [5, 6, 5] for two
229 | rings of size 5 and one ring of size 6).
230 |
231 | Returns:
232 | List of int: list with ring sizes
233 | '''
234 | if self._rings is None:
235 | # calculate symmetric SSSR with RDKit using the canonical SMILES
236 | # representation as input
237 | can = self.get_can()
238 | mol = Chem.MolFromSmiles(can)
239 | if mol is not None:
240 | ssr = Chem.GetSymmSSSR(mol)
241 | self._rings = [len(ssr[i]) for i in range(len(ssr))]
242 | else:
243 | self._rings = [] # cannot count rings
244 | return self._rings
245 |
246 | def get_n_atoms_per_type(self):
247 | '''
248 | Retrieve the number of atoms in the molecule per type.
249 |
250 | Returns:
251 | numpy.ndarray: number of atoms in the molecule per type, where the order
252 | corresponds to the order specified in Molecule.type_infos
253 | '''
254 | if self._n_atoms_per_type is None:
255 | _types = np.array(list(self.type_infos.keys()), dtype=int)
256 | self._n_atoms_per_type =\
257 | np.bincount(self.numbers, minlength=np.max(_types)+1)[_types]
258 | return self._n_atoms_per_type
259 |
260 | def remove_unpicklable_attributes(self, restorable=True):
261 | '''
262 | Some attributes of the class cannot be processed by pickle. This method
263 | removes these attributes prior to pickling.
264 |
265 | Args:
266 | restorable (bool, optional): Set True to allow restoring the deleted
267 | attributes later on (default: True)
268 | '''
269 | # set attributes which are not picklable (SwigPyObjects) to None
270 | if restorable and self.positions is None and self._obmol is not None:
271 | # store positions to allow restoring obmol object later on
272 | pos = [atom.coords for atom in pybel.Molecule(self._obmol).atoms]
273 | self.positions = np.array(pos)
274 | self._obmol = None
275 | self._fp = None
276 |
277 | def tanimoto_similarity(self, other_mol, use_bits=True):
278 | '''
279 | Get the Tanimoto (fingerprint) similarity to another molecule.
280 |
281 | Args:
282 | other_mol (Molecule or pybel.Fingerprint/list of bits set):
283 | representation of the second molecule (if it is not a Molecule object,
284 | it needs to be a pybel.Fingerprint if use_bits is False and a list of bits
285 | set in the fingerprint if use_bits is True).
286 | use_bits (bool, optional): set True to calculate Tanimoto similarity
287 | from bits set in the fingerprint (default: True)
288 |
289 | Returns:
290 | float: Tanimoto similarity to the other molecule
291 | '''
292 | if use_bits:
293 | a = self.get_fp_bits()
294 | b = other_mol.get_fp_bits() if isinstance(other_mol, Molecule) \
295 | else other_mol
296 | n_equal = len(a.intersection(b))
297 | if len(a) + len(b) == 0: # edge case with no set bits
298 | return 1.
299 | return n_equal / (len(a)+len(b)-n_equal)
300 | else:
301 | fp_other = other_mol.get_fp() if isinstance(other_mol, Molecule)\
302 | else other_mol
303 | return self.get_fp() | fp_other
304 |
305 | def _update_bond_orders(self, idc_lists):
306 | '''
307 | Updates the bond orders in the underlying OBMol object.
308 |
309 | Args:
310 | idc_lists (list of list of int): nested list containing bonds, i.e. pairs
311 | of row indices (list1) and column indices (list2) which shall be updated
312 | '''
313 | con_mat = self.get_connectivity()
314 | self._obmol.BeginModify()
315 | for i in range(len(idc_lists[0])):
316 | idx1 = idc_lists[0][i]
317 | idx2 = idc_lists[1][i]
318 | obbond = self._obmol.GetBond(int(idx1+1), int(idx2+1))
319 | obbond.SetBO(int(con_mat[idx1, idx2]))
320 | self._obmol.EndModify()
321 |
322 | # reset fingerprints etc
323 | self._fp = None
324 | self._can = None
325 | self._mirror_can = None
326 | self._inchi_key = None
327 |
328 | def get_fixed_connectivity(self, recursive_call=False):
329 | '''
330 | Attempts to fix the connectivity matrix using some heuristics (as some valid
331 | QM9 molecules do not pass the valency check using the connectivity matrix
332 | obtained with Open Babel, which seems to have problems with assigning correct
333 | bond orders to aromatic rings containing Nitrogen).
334 |
335 | Args:
336 | recursive_call (bool, optional): flag that indicates a recursive
337 | call (used internally, do not set to True; default: False)
338 |
339 | Returns:
340 | numpy.ndarray: (n_atoms x n_atoms) array containing the pairwise bond orders
341 | between atoms (0 for no bond) after the attempted fix.
342 | '''
343 |
344 | # if fix has already been attempted, return the connectivity matrix
345 | if self._fixed_connectivity:
346 | return self._connectivity
347 |
348 | # define helpers:
349 | # increases bond order between two atoms in connectivity matrix
350 | def increase_bond(con_mat, idx1, idx2):
351 | con_mat[idx1, idx2] += 1
352 | con_mat[idx2, idx1] += 1
353 | return con_mat
354 |
355 | # decreases bond order between two atoms in connectivity matrix
356 | def decrease_bond(con_mat, idx1, idx2):
357 | con_mat[idx1, idx2] -= 1
358 | con_mat[idx2, idx1] -= 1
359 | return con_mat
360 |
361 | # returns only the rows of the connectivity matrix corresponding to atoms of
362 | # certain types (and the indices of these atoms)
363 | def get_typewise_connectivity(con_mat, types):
364 | idcs = []
365 | for type in types:
366 | idcs += list(self._get_row_idcs(type))
367 | return con_mat[idcs], np.array(idcs).astype(int)
368 |
369 | # store old connectivity matrix for later comparison
370 | old_mat = self.get_connectivity().copy()
371 |
372 | # get connectivity matrix and find indices of N and C atoms
373 | con_mat = self.get_connectivity()
374 | if 6 not in self._unique_numbers and 7 not in self._unique_numbers:
375 | # do not attempt fixing if there is no carbon and no nitrogen
376 | return con_mat
377 | N_mat, N_idcs = get_typewise_connectivity(con_mat, [7])
378 | C_mat, C_idcs = get_typewise_connectivity(con_mat, [6])
379 | NC_idcs = np.hstack((N_idcs, C_idcs)) # indices of all N and C atoms
380 | NC_valences = self._get_valences()[NC_idcs] # array with valency constraints
381 |
382 | # return connectivity if valency constraints of N and C atoms are already met
383 | if np.all(np.sum(con_mat[NC_idcs], axis=1) == NC_valences):
384 | return con_mat
385 |
386 | # if a C or N atom is "overcharged" (total bond order too high) we decrease
387 | # double to single bonds between N-N or N-C until it is not overcharged anymore
388 | # (e.g. C=N=C -> C=N-C)
389 | if 7 in self._unique_numbers: # only necessary if molecule contains N
390 | for cur in NC_idcs:
391 | type = self.numbers[cur]
392 | if np.sum(con_mat[cur]) <= self.type_infos[type]['n_bonds']:
393 | continue
394 | if type == 6: # for carbon look only at nitrogen neighbors
395 | neighbors = self._get_neighbors(cur, types=[7], strength=2)
396 | else:
397 | neighbors = self._get_neighbors(cur, types=[6, 7],
398 | strength=2)
399 | for neighbor in neighbors:
400 | con_mat = decrease_bond(con_mat, cur, neighbor)
401 | self._connectivity = con_mat
402 | if np.sum(con_mat[cur]) == \
403 | self.type_infos[type]['n_bonds']:
404 | break
405 |
406 | # get updated partial connectivity matrices for N and C
407 | N_mat, _ = get_typewise_connectivity(con_mat, [7])
408 | C_mat, _ = get_typewise_connectivity(con_mat, [6])
409 |
410 | # increase total number of bonds by transferring the strength of a
411 | # double C-N bond to two neighboring bonds, if the involved atoms
412 | # are not yet saturated (e.g. H2C-H2C=N-H2C -> H2C=H2C-N=H2C)
413 | if (np.sum(N_mat) < len(N_idcs) * 3 or np.sum(C_mat) < len(C_idcs) * 4) \
414 | and 7 in self._unique_numbers:
415 | for cur in NC_idcs:
416 | type = self.numbers[cur]
417 | if sum(con_mat[cur]) >= self.type_infos[type]['n_bonds']:
418 | continue
419 | CN_nbors = self._get_CN_neighbors(cur, order_1=1, order_2=2)
420 | for nbor_1, nbor_2 in CN_nbors:
421 | nbor_2_nbors = np.where(con_mat[nbor_2] == 1)[0]
422 | for nbor_2_nbor in nbor_2_nbors:
423 | nbor_2_nbor_type = self.numbers[nbor_2_nbor]
424 | if (np.sum(con_mat[nbor_2_nbor]) <
425 | self.type_infos[nbor_2_nbor_type]['n_bonds']):
426 | con_mat = increase_bond(con_mat, cur, nbor_1)
427 | con_mat = increase_bond(con_mat, nbor_2, nbor_2_nbor)
428 | con_mat = decrease_bond(con_mat, nbor_1, nbor_2)
429 | self._connectivity = con_mat
430 |
431 | # increase total number of bonds by transferring the strength of a
432 | # triple C-C bond to two neighboring bonds, if the involved atoms
433 | # are not yet saturated
434 | if (np.sum(N_mat) < len(N_idcs) * 3 or np.sum(C_mat) < len(C_idcs) * 4):
435 | for cur in NC_idcs:
436 | type = self.numbers[cur]
437 | if sum(con_mat[cur]) >= self.type_infos[type]['n_bonds']:
438 | continue
439 | CC_nbors = self._get_CN_neighbors(cur, order_1=1, order_2=3,
440 | only_CC=True, include_CCC=True)
441 | for nbor_1, nbor_2 in CC_nbors:
442 | nbor_2_nbors = np.where(con_mat[nbor_2] == 1)[0]
443 | for nbor_2_nbor in nbor_2_nbors:
444 | nbor_2_nbor_type = self.numbers[nbor_2_nbor]
445 | if (np.sum(con_mat[nbor_2_nbor]) <
446 | self.type_infos[nbor_2_nbor_type]['n_bonds']):
447 | con_mat = increase_bond(con_mat, cur, nbor_1)
448 | con_mat = increase_bond(con_mat, nbor_2, nbor_2_nbor)
449 | con_mat = decrease_bond(con_mat, nbor_1, nbor_2)
450 | self._connectivity = con_mat
451 |
452 | # increase bond strength between two undercharged neighbors C-N,
453 | # C-C or N-N (e.g. HN-CH2 -> HN=CH2, starting from those atoms with the fewest
454 | # available neighbors if there are multiple undercharged neighbors)
455 | undercharged_pairs = True
456 | while (undercharged_pairs):
457 | NC_charges = np.sum(con_mat[NC_idcs], axis=1)
458 | undercharged = NC_idcs[np.where(NC_charges < NC_valences)[0]]
459 | partial_con_mat = con_mat[undercharged][:, undercharged]
460 | # if none of the undercharged atoms are neighbors, stop
461 | if np.sum(partial_con_mat) == 0:
462 | break
463 | # sort by number of undercharged neighbors
464 | n_nbors = np.sum(partial_con_mat > 0, axis=0)
465 | # mask indices with zero undercharged neighbors to ignore them when sorting
466 | n_nbors[np.where(n_nbors == 0)[0]] = 1000
467 | cur = np.argmin(n_nbors)
468 | cur_nbor = np.where(partial_con_mat[cur] > 0)[0][0]
469 | con_mat = increase_bond(con_mat, undercharged[cur], undercharged[cur_nbor])
470 | self._connectivity = con_mat
471 |
472 | # if the molecule still is not valid, try to flip double bonds if an atom
473 | # forms a double bond and has at least one other neighbor that has too few bonds
474 | # (e.g. C-N=C -> C=N-C) and repeat above heuristics with a recursive call of
475 | # this function
476 | if not recursive_call and \
477 | not np.all(np.sum(con_mat[NC_idcs], axis=1) == NC_valences):
478 | changed = False
479 | candidates = np.where(np.any(con_mat[NC_idcs][:, NC_idcs] == 2, axis=0))[0]
480 | for cand in NC_idcs[candidates]:
481 | if np.sum(con_mat[cand, NC_idcs] == 2) == 0:
482 | continue
483 | NC_charges = np.sum(con_mat[NC_idcs], axis=1)
484 | undercharged = NC_charges < NC_valences
485 | uc_neighbors = np.logical_and(con_mat[cand, NC_idcs] == 1, undercharged)
486 | if np.any(uc_neighbors):
487 | uc_neighbor = NC_idcs[np.where(uc_neighbors)[0][0]]
488 | oc_neighbor = NC_idcs[
489 | np.where(con_mat[cand, NC_idcs] == 2)[0][0]]
490 | con_mat = increase_bond(con_mat, cand, uc_neighbor)
491 | con_mat = decrease_bond(con_mat, cand, oc_neighbor)
492 | self._connectivity = con_mat
493 | changed = True
494 | if changed:
495 | self._connectivity = self.get_fixed_connectivity(
496 | recursive_call=True)
497 |
498 | # store that fixing the connectivity matrix has already been attempted
499 | if not recursive_call:
500 | self._fixed_connectivity = True
501 | if np.any(old_mat != self._connectivity):
502 | # update bond orders in underlying OBMol object (where they changed)
503 | self._update_bond_orders(np.where(old_mat != self._connectivity))
504 |
505 | return self._connectivity
506 |
507 | def _get_valences(self):
508 | '''
509 | Retrieve the valency constraints of all atoms in the molecule.
510 |
511 | Returns:
512 | numpy.ndarray: valency constraints (one per atom)
513 | '''
514 | valence = []
515 | for atom in self.numbers:
516 | valence += [self.type_infos[atom]['n_bonds']]
517 | return np.array(valence)
518 |
519 | def _get_CN_neighbors(self, idx, order_1=1, order_2=1,
520 | only_CC=False, include_CCC=False):
521 | '''
522 | For a focus atom of type K returns indices of atoms C (carbon) and N (nitrogen)
523 | on two-step paths of the form K-C-N (and K-C-C only if K is nitrogen since one
524 | atom needs to be nitrogen, except if include_CCC is set True). The required
525 | bond order on the path can be constrained using the parameters order_1 and
526 | order_2.
527 |
528 | Args:
529 | idx (int): the index of the focus atom from which paths are examined
530 | order_1 (int, optional): the minimum bond order that neighbors must share
531 | with the focus atom in order to be added to the results (allows to
532 | constrain the paths to e.g. K=C-N or K=C-C with order_1=2, default: 1)
533 | order_2 (int, optional): the minimum bond order that second degree
534 | neighbors must share with the first degree neighbors in order to be
535 | added to the results (allows to constrain the paths to e.g. K-C=N or
536 | K-C=C with order_2=2, default: 1)
537 | only_CC (bool, optional): include only K-C-C paths (i.e. ignore K-C-N,
538 | default: False)
539 | include_CCC (bool, optional): also include K-C-C paths if K is carbon
540 | (default: False)
541 |
542 | Returns:
543 | list of lists: list1[i] contains an index of a direct neighbor of the
544 | focus atom and list2[i] contains the index of a second neighbor on the
545 | i-th identified two-step path
546 | '''
547 | con_mat = self.get_connectivity()
548 | nbors = con_mat[idx] >= order_1
549 | C_nbors = np.where(np.logical_and(self.numbers == 6, nbors))[0]
550 | type = self.numbers[idx]
551 | # mask types to exclude idx from neighborhood
552 | _numbers = self.numbers.copy()
553 | _numbers[idx] = 0
554 | CN_nbors = []
555 | CC_nbors = []
556 | if not only_CC:
557 | CN_nbors = np.where(np.logical_and(_numbers == 7,
558 | con_mat[C_nbors] >= order_2))
559 | CN_nbors = [(C_nbors[CN_nbors[0][i]], CN_nbors[1][i])
560 | for i in range(len(CN_nbors[0]))]
561 | if type == 7 or include_CCC: # for N atoms, also add C-C neighbors
562 | CC_nbors = np.where(np.logical_and(_numbers == 6,
563 | con_mat[C_nbors] >= order_2))
564 | CC_nbors = [(C_nbors[CC_nbors[0][i]], CC_nbors[1][i])
565 | for i in range(len(CC_nbors[0]))]
566 | return CN_nbors + CC_nbors
567 |
568 | def _get_neighbors(self, idx, types=None, strength=1):
569 | '''
570 | Retrieve the indices of neighbors of an atom.
571 |
572 | Args:
573 | idx (int): index of the atom
574 | types (list of int, optional): restrict the returned neighbors to
575 | contain only atoms of the specified types (set None to apply no type
576 | filter, default: None)
577 | strength (int, optional): restrict the returned neighbors to contain
578 | only atoms with a certain minimal bond order to the atom at idx
579 | (default: 1)
580 |
581 | Returns:
582 | list of int: indices of all neighbors that meet the requirements
583 | '''
584 | con_mat = self.get_connectivity()
585 | neighbors = con_mat[idx] >= strength
586 | if types is not None:
587 | # restrict to the requested atom types; with types=None all
588 | # neighbors with sufficient bond order are returned
589 | neighbors = np.logical_and(neighbors, np.isin(self.numbers, types))
590 | return np.where(neighbors)[0]
591 |
592 | def get_bond_stats(self):
593 | '''
594 | Retrieve the bond and ring count of the molecule. The bond count is
595 | calculated for every pair of types (e.g. C1N are all single bonds between
596 | carbon and nitrogen atoms in the molecule, C2N are all double bonds between
597 | such atoms etc.). The ring count is provided for rings from size 3 to 8 (R3,
598 | R4, ..., R8) and for rings greater than size eight (R>8).
599 |
600 | Returns:
601 | dict (str->int): bond and ring counts
602 | '''
603 | if self._bond_stats is None:
604 |
605 | # 1st analyze bonds
606 | unique_types = np.sort(list(self._unique_numbers))
607 | # get connectivity and read bonds from matrix
608 | con_mat = self.get_connectivity()
609 | d = {}
610 | for i, type1 in enumerate(unique_types):
611 | row_idcs = self._get_row_idcs(type1)
612 | n_bonds1 = self.type_infos[type1]['n_bonds']
613 | for type2 in unique_types[i:]:
614 | col_idcs = self._get_row_idcs(type2)
615 | n_bonds2 = self.type_infos[type2]['n_bonds']
616 | max_bond_strength = min(n_bonds1, n_bonds2)
617 | if n_bonds1 == n_bonds2: # exclude small trivial molecules
618 | max_bond_strength -= 1
619 | for n in range(1, max_bond_strength + 1):
620 | id = self.type_infos[type1]['name'] + str(n) + \
621 | self.type_infos[type2]['name']
622 | d[id] = np.sum(con_mat[row_idcs][:, col_idcs] == n)
623 | if type1 == type2:
624 | d[id] = int(d[id]/2) # remove twice counted bonds
625 |
626 | # 2nd analyze rings
627 | ring_counts = self.get_ring_counts()
628 | if len(ring_counts) > 0:
629 | ring_counts = np.bincount(np.array(ring_counts))
630 | n_bigger_8 = 0
631 | for i in np.nonzero(ring_counts)[0]:
632 | if i < 9:
633 | d[f'R{i}'] = ring_counts[i]
634 | else:
635 | n_bigger_8 += ring_counts[i]
636 | if n_bigger_8 > 0:
637 | d['R>8'] = n_bigger_8
638 | self._bond_stats = d
639 |
640 | return self._bond_stats
641 |
642 | def _get_row_idcs(self, type):
643 | '''
644 | Retrieve the indices of all atoms in the molecule corresponding to a selected
645 | type.
646 |
647 | Args:
648 | type (int): the atom type (atomic number, e.g. 6 for carbon)
649 |
650 | Returns:
651 | list of int: indices of all atoms with the selected type
652 | '''
653 | if type not in self._row_indices:
654 | self._row_indices[type] = np.where(self.numbers == type)[0]
655 | return self._row_indices[type]
656 |
657 |
658 | class ConnectivityCompressor():
659 | '''
660 | Utility class that provides methods to compress and decompress connectivity
661 | matrices.
662 | '''
663 |
664 | def __init__(self):
665 | pass
666 |
667 | def compress(self, connectivity_matrix):
668 | '''
669 | Compresses a single connectivity matrix.
670 |
671 | Args:
672 | connectivity_matrix (numpy.ndarray): array (n_atoms x n_atoms)
673 | containing the bond orders of bonds between atoms of a molecule
674 |
675 | Returns:
676 | dict (str/float->int/numpy.ndarray): the length of the non-redundant
677 | connectivity matrix (list with upper triangular part) and, for each bond
678 | order > 0, the indices of that list where the bond order occurs
679 | '''
680 | smaller = squareform(connectivity_matrix) # get list of upper triangular part
681 | d = {'n_entries': len(smaller)} # store length of list
682 | for i in np.unique(smaller): # store indices per bond order > 0
683 | if i > 0:
684 | d[float(i)] = np.where(smaller == i)[0]
685 | return d
686 |
687 | def decompress(self, idcs_dict):
688 | '''
689 | Retrieve the full (n_atoms x n_atoms) connectivity matrix from compressed
690 | format.
691 |
692 | Args:
693 | idcs_dict (dict str/int->int): compressed connectivity matrix
694 | (obtained with the compress method)
695 |
696 | Returns:
697 | numpy.ndarray: full connectivity matrix as an array of shape (n_atoms x
698 | n_atoms)
699 | '''
700 | n_entries = idcs_dict['n_entries']
701 | con_mat = np.zeros(n_entries)
702 | for i in idcs_dict:
703 | if i != 'n_entries':
704 | con_mat[idcs_dict[i]] = float(i)
705 | return squareform(con_mat)
706 |
707 | def compress_batch(self, connectivity_batch):
708 | '''
709 | Compress a batch of connectivity matrices.
710 |
711 | Args:
712 | connectivity_batch (list of numpy.ndarray): list of connectivity matrices
713 |
714 | Returns:
715 | list of dict: batch of compressed connectivity matrices (see compress)
716 | '''
717 | dict_list = []
718 | for matrix in connectivity_batch:
719 | dict_list += [self.compress(matrix)]
720 | return dict_list
721 |
722 | def decompress_batch(self, idcs_dict_batch):
723 | '''
724 | Retrieve a list of full connectivity matrices from a batch of compressed
725 | connectivity matrices.
726 |
727 | Args:
728 | idcs_dict_batch (list of dict): list with compressed connectivity
729 | matrices
730 |
731 | Returns:
732 | list of numpy.ndarray: batch of full connectivity matrices (see decompress)
733 | '''
734 | matrix_list = []
735 | for idcs_dict in idcs_dict_batch:
736 | matrix_list += [self.decompress(idcs_dict)]
737 | return matrix_list
738 |
739 |
740 | class IndexProvider():
741 | '''
742 | Class for filtering a large set of molecules for desired structures
743 | according to provided statistics. The filtering is done using a selection string
744 | of the general format 'Statistics_nameDelimiterOperatorTarget_value'
745 | (e.g. 'C,>8' to filter for all molecules with more than eight carbon atoms where
746 | 'C' is the statistic counting the number of carbon atoms in a molecule, ',' is the
747 | delimiter, '>' is the operator, and '8' is the target value).
748 |
749 | Args:
750 | statistics (numpy.ndarray):
751 | statistics of all molecules where columns correspond to molecules and rows
752 | correspond to available statistics (n_statistics x n_molecules)
753 | row_headlines (numpy.ndarray):
754 | the names of the statistics stored in each row (e.g. 'F' for the number of
755 | fluorine atoms or 'R5' for the number of rings of size 5)
756 | default_filter (str, optional):
757 | the default behaviour of the filter if no operator and target value are
758 | given (e.g. filtering for 'F' will give all molecules with at least 1
759 | fluorine atom if default_filter='>0' or all molecules with exactly 2
760 | fluorine atoms if default_filter='==2', default: '>0')
761 | delimiter (str, optional):
762 | the delimiter used to separate names of statistics from the operator and
763 | target value in the selection strings (default: ',')
764 | '''
765 |
766 | # dictionary mapping strings of available operators to corresponding function:
767 | op_dict = {'<': operator.lt,
768 | '<=': operator.le,
769 | '==': operator.eq,
770 | '=': operator.eq,
771 | '!=': operator.ne,
772 | '>': operator.gt,
773 | '>=': operator.ge}
774 |
775 | rel_re = re.compile('<=|<|={1,2}|!=|>=|>') # regular expression for operators
776 |     num_re = re.compile(r'-?[0-9]+\.?[0-9]*')  # regular expression for target values
777 |
778 | def __init__(self, statistics, row_headlines, default_filter='>0', delimiter=','):
779 | self.statistics = np.array(statistics)
780 | self.headlines = list(row_headlines)
781 | self.default_relation = self.rel_re.search(default_filter).group(0)
782 | self.default_number = float(self.num_re.search(default_filter).group(0))
783 | self.delimiter = delimiter
784 |
785 | def get_selected(self, selection_str, idcs=None):
786 | '''
787 | Retrieve the indices of all molecules which fulfill the selection criteria.
788 | The selection string is of the general format
789 | 'Statistics_nameDelimiterOperatorTarget_value' (e.g. 'C,>8' to filter for all
790 | molecules with more than eight carbon atoms where 'C' is the statistic counting
791 | the number of carbon atoms in a molecule, ',' is the delimiter, '>' is the
792 | operator, and '8' is the target value).
793 |
794 | The following operators are available:
795 | '<'
796 | '<='
797 |             '==' (or '=')
798 | '!='
799 | '>='
800 | '>'
801 |
802 | The target value can be any positive or negative integer or float value.
803 |
804 | Multiple statistics can be summed using '+' (e.g. 'F+N,=0' gives all
805 | molecules that have no fluorine and no nitrogen atoms).
806 |
807 | Multiple filters can be concatenated using '&' (e.g. 'H,>8&C,=5' gives all
808 | molecules that have more than 8 hydrogen atoms and exactly 5 carbon atoms).
809 |
810 | Args:
811 |             selection_str (str): string describing the criteria for filtering (built
812 |                 as described above)
813 |             idcs (numpy.ndarray, optional): if provided, only this subset of all
814 |                 molecules is filtered for structures fulfilling the selection criteria
815 |
816 |         Returns:
817 |             numpy.ndarray of int: indices of all the molecules in the dataset that
818 |                 fulfill the selection criteria
819 | '''
820 |
821 | delimiter = self.delimiter
822 | if idcs is None:
823 | idcs = np.arange(len(self.statistics[0])) # take all to begin with
824 | criterions = selection_str.split('&') # split criteria
825 | for criterion in criterions:
826 | rel_strs = criterion.split(delimiter)
827 |
828 | # add multiple statistics if specified
829 | heads = rel_strs[0].split('+')
830 | statistics = self.statistics[self.headlines.index(heads[0])][idcs]
831 | for head in heads[1:]:
832 | statistics += self.statistics[self.headlines.index(head)][idcs]
833 |
834 |             if len(rel_strs) == 1:  # no operator given, fall back to the default filter
835 |                 rel, num = self.default_relation, self.default_number
836 |             elif len(rel_strs) == 2:
837 |                 rel = self.rel_re.search(rel_strs[1]).group(0)
838 |                 num = float(self.num_re.search(rel_strs[1]).group(0))
839 |             else:  # more than one delimiter: previously left `relation` undefined or stale
840 |                 raise ValueError('Invalid selection criterion: %r' % criterion)
841 |             relation = self.op_dict[rel](statistics, num)
842 |             idcs = idcs[np.where(relation)[0]]
843 |
844 | return idcs
845 |
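The selection-string semantics described in the `get_selected` docstring can be sketched with plain NumPy masks. This is only an illustration of what the strings mean, not the class itself: the toy `headlines` and `statistics` arrays and the `row` helper are assumptions made up for the example.

```python
import numpy as np

# Hypothetical toy data: one row per statistic, one column per molecule.
headlines = ['H', 'C', 'F']
statistics = np.array([
    [10, 9, 4, 12],  # 'H': number of hydrogen atoms
    [5, 5, 2, 6],    # 'C': number of carbon atoms
    [0, 1, 0, 0],    # 'F': number of fluorine atoms
])


def row(name):
    """Return the statistics row stored under the given headline (hypothetical helper)."""
    return statistics[headlines.index(name)]


# 'H,>8&C,=5': criteria joined by '&' must all hold.
sel_and = np.where((row('H') > 8) & (row('C') == 5))[0]

# 'F+C,>5': statistics joined by '+' are summed before the comparison.
sel_sum = np.where(row('F') + row('C') > 5)[0]

# 'F' alone with default_filter='>0': at least one fluorine atom.
sel_default = np.where(row('F') > 0)[0]
```

With the class itself, the corresponding call would presumably be `IndexProvider(statistics, headlines).get_selected('H,>8&C,=5')`. It applies the criteria one after another to a shrinking index array instead of building one combined mask, which gives the same result.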
846 |
847 | class ProcessQ(Process):
848 | '''
849 |     multiprocessing.Process class that runs a provided function using provided
850 |     (keyword) arguments and puts the result into a provided multiprocessing.Queue
851 |     object (such that the result of the function can easily be obtained by the host
852 |     process).
853 |
854 |     Args:
855 |         queue (multiprocessing.Queue): the queue into which the results of running
856 |             the target function will be put (the object in the queue will be a tuple
857 |             containing the provided name as first entry and the function's return
858 |             value as second entry).
859 |         name (str): name of the object (is returned as first value in the tuple put
860 |             into the queue).
861 | target (callable object): the function that is executed in the process's run
862 | method
863 | args (list of any): sequential arguments target is called with
864 | kwargs (dict (str->any)): keyword arguments target is called with
865 | '''
866 |
867 |     def __init__(self, queue, name=None, target=None, args=(), kwargs=None):
868 |         kwargs = {} if kwargs is None else kwargs  # avoid a shared mutable default
869 |         super(ProcessQ, self).__init__(None, target, name, args, kwargs)
870 |         self._name = name
871 |         self._q = queue
872 |         self._target = target
873 |         self._args, self._kwargs = args, kwargs
874 |
875 | def run(self):
876 | '''
877 | Method representing the process's activity.
878 |
879 | Invokes the callable object passed as the target argument, if any, with
880 | sequential and keyword arguments taken from the args and kwargs arguments,
881 | respectively. Puts the string passed as name argument and the returned result
882 | of the callable object into the queue as (name, result).
883 | '''
884 | if self._target is not None:
885 | res = (self.name, self._target(*self._args, **self._kwargs))
886 | self._q.put(res)
887 |
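The pattern behind `ProcessQ` (run a function in a child process and hand back `(name, result)` through a queue) can be sketched with the standard library alone. The names `square`, `run_and_put`, and `collect` below are hypothetical and only serve the illustration; they are not part of this repository.

```python
import multiprocessing as mp


def square(x):
    """Toy target function (hypothetical)."""
    return x * x


def run_and_put(queue, name, target, args=(), kwargs=None):
    """Mirror of what ProcessQ.run does: call target, put (name, result) into queue."""
    queue.put((name, target(*args, **(kwargs or {}))))


def collect(jobs):
    """Start one process per (name, target, args) job and gather results by name."""
    queue = mp.Queue()
    procs = [mp.Process(target=run_and_put, args=(queue, name, target, args))
             for name, target, args in jobs]
    for proc in procs:
        proc.start()
    # Drain the queue before joining so that no child blocks on a full pipe.
    results = dict(queue.get() for _ in procs)
    for proc in procs:
        proc.join()
    return results
```

Called from under an `if __name__ == '__main__':` guard, `collect([('a', square, (3,)), ('b', square, (4,))])` would return the results keyed by name. Note that this sketch waits forever if a target raises, since nothing is put into the queue in that case.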
888 |
--------------------------------------------------------------------------------