├── .gitignore ├── LICENSE ├── README.md ├── dtnn ├── __init__.py ├── atoms.py ├── core.py ├── cost.py ├── data.py ├── datasets │ ├── __init__.py │ └── gdb9.py ├── layers.py ├── models │ ├── __init__.py │ └── dtnn.py ├── train.py └── utils.py └── examples ├── __init__.py ├── eval_dtnn_gdb9.py └── train_dtnn_gdb9.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | .nfs* 6 | 7 | # C extensions 8 | *.so 9 | 10 | # Distribution / packaging 11 | .Python 12 | env/ 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *,cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # IPython Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # dotenv 80 | .env 81 | 82 | # virtualenv 83 | venv/ 84 | ENV/ 85 | 86 | # Spyder project settings 87 | .spyderproject 88 | 89 | # Rope project settings 90 | .ropeproject 91 | .idea 92 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 atomistic-machine-learning 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Deep Tensor Neural Networks 2 | 3 | The deep tensor neural network (DTNN) enables spatially and chemically resolved 4 | insights into quantum-mechanical observables of molecular systems. 5 | 6 | Requirements: 7 | - python 3.4 8 | - ASE 9 | - numpy 10 | - tensorflow (>=1.0) 11 | 12 | See the `examples` folder for scripts that train and evaluate a DTNN 13 | model for predicting the total energy (U0) of the GDB-9 data set. 14 | The data set will be downloaded and converted automatically. 15 | 16 | Basic usage: 17 | 18 | python train_dtnn_gdb9.py -h 19 | 20 | 21 | If you use deep tensor neural networks in your research, please cite: 22 | 23 | *K.T. Schütt, F. Arbabzadah, S. Chmiela, K.-R. Müller, A. Tkatchenko. 24 | Quantum-chemical insights from deep tensor neural networks.* 25 | Nature Communications **8**, 13890 (2017) 26 | doi: [10.1038/ncomms13890](http://dx.doi.org/10.1038/ncomms13890) 27 | 28 | 29 | -------------------------------------------------------------------------------- /dtnn/__init__.py: -------------------------------------------------------------------------------- 1 | from .core import Calculator, Model 2 | from .data import LeapDataProvider, ASEDataProvider, split_ase_db 3 | from .cost import * 4 | from .train import early_stopping 5 | -------------------------------------------------------------------------------- /dtnn/atoms.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | 4 | 5 | def interatomic_distances(positions, cell, pbc, cutoff): 6 | with tf.variable_scope('distance'): 7 | # calculate heights 8 | 9 | # account for zero cell in case of no pbc 10 | c = tf.reduce_sum(tf.cast(pbc, tf.int32)) > 0 11 | icell = tf.cond(c, lambda: tf.matrix_inverse(cell), 12 | lambda: tf.eye(3)) 13 | height = 1.
/ tf.sqrt(tf.reduce_sum(tf.square(icell), 0)) 14 | 15 | extent = tf.where(tf.cast(pbc, tf.bool), 16 | tf.cast(tf.floor(cutoff / height), tf.int32), 17 | tf.cast(tf.zeros_like(height), tf.int32)) 18 | n_reps = tf.reduce_prod(2 * extent + 1) 19 | 20 | # replicate atoms 21 | r = tf.range(-extent[0], extent[0] + 1) 22 | v0 = tf.expand_dims(r, 1) 23 | v0 = tf.tile(v0, 24 | tf.stack(((2 * extent[1] + 1) * (2 * extent[2] + 1), 1))) 25 | v0 = tf.reshape(v0, tf.stack((n_reps, 1))) 26 | 27 | r = tf.range(-extent[1], extent[1] + 1) 28 | v1 = tf.expand_dims(r, 1) 29 | v1 = tf.tile(v1, tf.stack((2 * extent[2] + 1, 2 * extent[0] + 1))) 30 | v1 = tf.reshape(v1, tf.stack((n_reps, 1))) 31 | 32 | v2 = tf.expand_dims(tf.range(-extent[2], extent[2] + 1), 1) 33 | v2 = tf.tile(v2, 34 | tf.stack((1, (2 * extent[0] + 1) * (2 * extent[1] + 1)))) 35 | v2 = tf.reshape(v2, tf.stack((n_reps, 1))) 36 | 37 | v = tf.cast(tf.concat((v0, v1, v2), axis=1), tf.float32) 38 | offset = tf.matmul(v, cell) 39 | offset = tf.expand_dims(offset, 0) 40 | 41 | # add axes 42 | positions = tf.expand_dims(positions, 1) 43 | rpos = positions + offset 44 | rpos = tf.expand_dims(rpos, 0) 45 | positions = tf.expand_dims(positions, 1) 46 | 47 | euclid_dist = tf.sqrt( 48 | tf.reduce_sum(tf.square(positions - rpos), 49 | reduction_indices=3)) 50 | return euclid_dist 51 | 52 | 53 | def site_rdf(distances, cutoff, step, width, eps=1e-5, 54 | use_mean=False, lower_cutoff=None): 55 | with tf.variable_scope('srdf'): 56 | if lower_cutoff is None: 57 | vrange = cutoff 58 | else: 59 | vrange = cutoff - lower_cutoff 60 | distances = tf.expand_dims(distances, -1) 61 | n_centers = np.ceil(vrange / step) 62 | gap = vrange - n_centers * step 63 | n_centers = int(n_centers) 64 | 65 | if lower_cutoff is None: 66 | centers = tf.linspace(0., cutoff - gap, n_centers) 67 | else: 68 | centers = tf.linspace(lower_cutoff + 0.5 * gap, cutoff - 0.5 * gap, 69 | n_centers) 70 | centers = tf.reshape(centers, (1, 1, 1, -1)) 71 | 72 | gamma = -0.5 / width / step ** 2 73 | 74 | rdf = tf.exp(gamma * (distances - centers) ** 2) 75 | 76 | mask = tf.cast(distances >= eps, tf.float32) 77 | rdf *= mask 78 | rdf = tf.reduce_sum(rdf, 2) 79 | if use_mean: 80 | N = tf.reduce_sum(mask, 2) 81 | N = tf.maximum(N, 1) 82 | rdf /= N 83 | 84 | new_shape = [None, None, n_centers] 85 | rdf.set_shape(new_shape) 86 | 87 | return rdf 88 | -------------------------------------------------------------------------------- /dtnn/core.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import numpy as np 4 | 5 | import tensorflow as tf 6 | from dtnn.utils import shape 7 | 8 | 9 | def batching(features, batch_size, num_batch_threads): 10 | in_names = list(features.keys()) 11 | in_list = [features[name] for name in in_names] 12 | out_shapes = [shape(inpt) for inpt in in_list] 13 | 14 | features = tf.train.batch( 15 | in_list, batch_size, dynamic_pad=True, 16 | shapes=out_shapes, num_threads=num_batch_threads 17 | ) 18 | features = dict(list(zip(in_names, features))) 19 | return features 20 | 21 | 22 | class Calculator(object): 23 | pass 24 | 25 | 26 | class Model(object): 27 | def __init__(self, model_dir, preprocessor_fcn=None, model_fcn=None, 28 | **config): 29 | self.reuse = None 30 | self.preprocessor_fcn = preprocessor_fcn 31 | self.model_fcn = model_fcn 32 | 33 | self.model_dir = model_dir 34 | self.config = config 35 | self.config_path = os.path.join(self.model_dir, 'config.npz') 36 | 37 | if not os.path.exists(self.model_dir): 
38 | os.makedirs(self.model_dir) 39 | 40 | if os.path.exists(self.config_path): 41 | logging.warning('Config file exists in model directory. ' + 42 | 'Config arguments will be overwritten!') 43 | self.from_config(self.config_path) 44 | else: 45 | self.to_config(self.config_path) 46 | 47 | with tf.variable_scope(None, 48 | default_name=self.__class__.__name__) as scope: 49 | self.scope = scope 50 | self.saver = None 51 | 52 | def __getattr__(self, item): 53 | if item in list(self.config.keys()): 54 | return self.config[item] 55 | raise AttributeError 56 | 57 | def to_config(self, config_path): 58 | np.savez(config_path, **self.config) 59 | 60 | def from_config(self, config_path): 61 | cfg = np.load(config_path) 62 | for k, v in list(cfg.items()): 63 | if v.shape == (): 64 | v = v.item() 65 | self.config[k] = v 66 | 67 | def _preprocessor(self, features): 68 | if self.preprocessor_fcn is None: 69 | return features 70 | else: 71 | return self.preprocessor_fcn(features) 72 | 73 | def _model(self, features): 74 | if self.model_fcn is None: 75 | raise NotImplementedError 76 | else: 77 | return self.model_fcn(features) 78 | 79 | def init_model(self): 80 | pass 81 | 82 | def store(self, sess, iteration, name='best'): 83 | checkpoint_path = os.path.join(self.model_dir, name) 84 | if not os.path.exists(checkpoint_path): 85 | os.makedirs(checkpoint_path) 86 | 87 | if self.saver is None: 88 | raise ValueError('Saver is not initialized. ' + 89 | 'Build the model by calling `get_output`' + 90 | 'before storing it.') 91 | 92 | self.saver.save(sess, os.path.join(checkpoint_path, name), iteration) 93 | 94 | def restore(self, sess, name='best', iteration=None): 95 | checkpoint_path = os.path.join(self.model_dir, name) 96 | 97 | if not os.path.exists(checkpoint_path): 98 | return 0 99 | 100 | if self.saver is None: 101 | raise ValueError('Saver is not initialized. 
' + 102 | 'Build the model by calling `get_output`' + 103 | 'before restoring it.') 104 | 105 | if iteration is None: 106 | chkpt = tf.train.latest_checkpoint(checkpoint_path) 107 | else: 108 | chkpt = os.path.join(checkpoint_path, name + '-' + str(iteration)) 109 | logging.info('Restoring ' + chkpt) 110 | 111 | self.saver.restore(sess, chkpt) 112 | start_iter = int(chkpt.split('-')[-1]) 113 | return start_iter 114 | 115 | def get_output(self, features, is_training, batch_size=None, 116 | num_batch_threads=1): 117 | with tf.variable_scope(self.scope, reuse=self.reuse): 118 | with tf.variable_scope('preprocessing'): 119 | features = self._preprocessor(features) 120 | 121 | with tf.variable_scope('batching'): 122 | if batch_size is None: 123 | features = { 124 | k: tf.expand_dims(v, 0) for k, v in 125 | list(features.items()) 126 | } 127 | else: 128 | in_names = list(features.keys()) 129 | in_list = [features[name] for name in in_names] 130 | out_shapes = [shape(inpt) for inpt in in_list] 131 | 132 | features = tf.train.batch( 133 | in_list, batch_size, dynamic_pad=True, 134 | shapes=out_shapes, num_threads=num_batch_threads 135 | ) 136 | features = dict(list(zip(in_names, features))) 137 | 138 | with tf.variable_scope('model'): 139 | self.init_model() 140 | features['is_training'] = is_training 141 | output = self._model(features) 142 | features.update(output) 143 | 144 | if self.saver is None: 145 | model_vars = [v for v in tf.global_variables() 146 | if v.name.startswith(self.scope.name)] 147 | var_names = [v.name[len(self.scope.name):] for v in model_vars] 148 | vdict = dict(list(zip(var_names, model_vars))) 149 | self.saver = tf.train.Saver(vdict) 150 | self.reuse = True 151 | return features 152 | -------------------------------------------------------------------------------- /dtnn/cost.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | 4 | class CostFunction(object): 5 | def __init__(self, name): 6 | self.name = name 7 | 8 | def __call__(self, output): 9 | errors = self.calc_errors(output) 10 | return self.aggregate(errors) 11 | 12 | def calc_errors(self, output): 13 | raise NotImplementedError 14 | 15 | def aggregate(self, errors): 16 | return tf.reduce_mean(errors) 17 | 18 | 19 | class MeanAbsoluteError(CostFunction): 20 | def __init__(self, prediction, target, idx=None, name='MAE'): 21 | super(MeanAbsoluteError, self).__init__(name) 22 | self.prediction = prediction 23 | self.target = target 24 | self.idx = idx 25 | 26 | def calc_errors(self, output): 27 | tgt = output[self.target] 28 | pred = output[self.prediction] 29 | if self.idx is not None: 30 | tgt = tgt[:, self.idx] 31 | pred = pred[:, self.idx] 32 | return tf.abs(tgt - pred) 33 | 34 | 35 | class L2Loss(CostFunction): 36 | def __init__(self, prediction, target, idx=None, name='MSE'): 37 | super(L2Loss, self).__init__(name) 38 | self.prediction = prediction 39 | self.target = target 40 | self.idx = idx 41 | 42 | def calc_errors(self, output): 43 | tgt = output[self.target] 44 | pred = output[self.prediction] 45 | if self.idx is not None: 46 | tgt = tgt[:, self.idx] 47 | pred = pred[:, self.idx] 48 | return (tgt - pred) ** 2 49 | 50 | def aggregate(self, errors): 51 | return tf.reduce_sum(errors) 52 | 53 | 54 | class MeanSquaredError(CostFunction): 55 | def __init__(self, prediction, target, idx=None, name='MSE'): 56 | super(MeanSquaredError, self).__init__(name) 57 | self.prediction = prediction 58 | self.target = target 59 | self.idx = idx 60 | 61 | def 
calc_errors(self, output): 62 | tgt = output[self.target] 63 | pred = output[self.prediction] 64 | if self.idx is not None: 65 | tgt = tgt[:, self.idx] 66 | pred = pred[:, self.idx] 67 | return (tgt - pred) ** 2 68 | 69 | 70 | class RootMeanSquaredError(MeanSquaredError): 71 | def __init__(self, prediction, target, idx=None, name='RMSE'): 72 | super(RootMeanSquaredError, self).__init__(prediction, target, idx, 73 | name) 74 | 75 | def aggregate(self, errors): 76 | return tf.sqrt(tf.reduce_mean(errors)) 77 | 78 | 79 | class PAMeanAbsoluteError(CostFunction): 80 | def __init__(self, prediction, target, idx=None, name='MAE'): 81 | super(PAMeanAbsoluteError, self).__init__(name) 82 | self.prediction = prediction 83 | self.target = target 84 | self.idx = idx 85 | 86 | def calc_errors(self, output): 87 | Z = output['numbers'] 88 | N = tf.reduce_sum(tf.cast(tf.greater(Z, 0), tf.float32), 1) 89 | 90 | tgt = output[self.target] 91 | pred = output[self.prediction] 92 | if self.idx is not None: 93 | tgt = tgt[:, self.idx] 94 | pred = pred[:, self.idx] 95 | return tf.abs(tgt - pred) / N 96 | 97 | 98 | class PARmse(CostFunction): 99 | def __init__(self, prediction, target, idx=None, name='MSE'): 100 | super(PARmse, self).__init__(name) 101 | self.prediction = prediction 102 | self.target = target 103 | self.idx = idx 104 | 105 | def calc_errors(self, output): 106 | Z = output['numbers'] 107 | N = tf.reduce_sum(tf.cast(tf.greater(Z, 0), tf.float32), 1) 108 | 109 | tgt = output[self.target] 110 | pred = output[self.prediction] 111 | if self.idx is not None: 112 | tgt = tgt[:, self.idx] 113 | pred = pred[:, self.idx] 114 | return ((tgt - pred) / N) ** 2 115 | 116 | def aggregate(self, errors): 117 | return tf.sqrt(tf.reduce_mean(errors)) 118 | -------------------------------------------------------------------------------- /dtnn/data.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import threading 4 | from random import shuffle 5 | 6 | import numpy as np 7 | import tensorflow as tf 8 | from ase.db import connect 9 | 10 | 11 | def split_ase_db(asedb, dstdir, partitions, selection=None): 12 | partition_ids = list(partitions.keys()) 13 | partitions = np.array(list(partitions.values())) 14 | if len(partitions[partitions < -1]) > 1: 15 | raise ValueError( 16 | 'There must not be more than one partition of unknown size!') 17 | 18 | with connect(asedb) as con: 19 | ids = [] 20 | for row in con.select(selection=selection): 21 | ids.append(row.id) 22 | 23 | ids = np.random.permutation(ids) 24 | n_rows = len(ids) 25 | 26 | r = (0. < partitions) * (partitions < 1.) 
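# fractional sizes in (0, 1) are interpreted as a share of the database and scaled to absolute row counts below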
27 | partitions[r] *= n_rows 28 | partitions = partitions.astype(np.int) 29 | 30 | if np.any(partitions < 0): 31 | remaining = n_rows - np.sum(partitions[partitions > 0]) 32 | partitions[partitions < 0] = remaining 33 | 34 | if len(partitions[partitions < 0]) > 1: 35 | raise ValueError( 36 | 'Size of the partitions has to be <= the number of atom rows!') 37 | 38 | if not os.path.exists(dstdir): 39 | os.makedirs(dstdir) 40 | else: 41 | raise ValueError('Split destination directory already exists:', 42 | dstdir) 43 | 44 | split_dict = {} 45 | with connect(asedb) as con: 46 | offset = 0 47 | if partition_ids is None: 48 | partition_ids = list(range(len(partitions))) 49 | for pid, p in zip(partition_ids, partitions): 50 | with connect(os.path.join(dstdir, pid + '.db')) as dstcon: 51 | print(offset, p) 52 | split_dict[pid] = ids[offset:offset + p] 53 | for i in ids[offset:offset + p]: 54 | row = con.get(int(i)) 55 | if hasattr(row, 'data'): 56 | data = row.data 57 | else: 58 | data = None 59 | dstcon.write(row.toatoms(), 60 | key_value_pairs=row.key_value_pairs, 61 | data=data) 62 | offset += p 63 | np.savez(os.path.join(dstdir, 'split_ids.npz'), **split_dict) 64 | 65 | 66 | data_feeds = [] 67 | 68 | 69 | def add_data_feed(data_feed): 70 | data_feeds.append(data_feed) 71 | 72 | 73 | def start_data_feeds(sess, coord): 74 | for df in data_feeds: 75 | df.create_threads(sess, coord, True, True) 76 | 77 | 78 | class LeapDataProvider(object): 79 | def __init__(self, batch_size=1): 80 | self._is_running = False 81 | self.batch_size = batch_size 82 | add_data_feed(self) 83 | 84 | @property 85 | def num_examples(self): 86 | raise NotImplementedError 87 | 88 | def get_features(self): 89 | raise NotImplementedError 90 | 91 | def get_property(self, pname): 92 | raise NotImplementedError 93 | 94 | def _run(self, sess, coord=None): 95 | raise NotImplementedError 96 | 97 | def create_threads(self, sess, coord=None, daemon=False, start=False): 98 | if self._is_running: 99 | return [] 100 | 101 | thread = threading.Thread(target=self._run, args=(sess, coord)) 102 | 103 | if daemon: 104 | thread.daemon = True 105 | if start: 106 | thread.start() 107 | 108 | self._is_running = True 109 | return [thread] 110 | 111 | 112 | class ASEDataProvider(LeapDataProvider): 113 | def __init__(self, asedb, kvp={}, data={}, batch_size=1, 114 | selection=None, shuffle=True, prefetch=False, 115 | block_size=150000, 116 | capacity=5000, num_epochs=np.Inf, floatX=np.float32): 117 | super(ASEDataProvider, self).__init__(batch_size) 118 | 119 | self.asedb = asedb 120 | self.prefetch = prefetch 121 | self.selection = selection 122 | self.block_size = block_size 123 | self.shuffle = shuffle 124 | self.kvp = kvp 125 | self.data = data 126 | self.floatX = floatX 127 | self.feat_names = ['numbers', 'positions', 'cell', 128 | 'pbc'] + list(kvp.keys()) + list(data.keys()) 129 | self.shapes = [(None,), (None, 3), (3, 3), 130 | (3,)] + list(kvp.values()) + list(data.values()) 131 | 132 | self.epoch = 0 133 | self.num_epochs = num_epochs 134 | self.n_rows = 0 135 | 136 | # initialize queue 137 | with connect(self.asedb) as con: 138 | row = list(con.select(self.selection, limit=1))[0] 139 | 140 | feats = self.convert_atoms(row) 141 | dtypes = [np.array(feat).dtype for feat in feats] 142 | self.queue = tf.FIFOQueue(capacity, dtypes) 143 | 144 | self.placeholders = [ 145 | tf.placeholder(dt, name=name) 146 | for dt, name in zip(dtypes, self.feat_names) 147 | ] 148 | self.enqueue_op = self.queue.enqueue(self.placeholders) 149 | self.dequeue_op = 
self.queue.dequeue() 150 | 151 | self.preprocs = [] 152 | 153 | def convert_atoms(self, row): 154 | numbers = row.get('numbers') 155 | positions = row.get('positions').astype(self.floatX) 156 | pbc = row.get('pbc') 157 | cell = row.get('cell').astype(self.floatX) 158 | features = [numbers, positions, cell, pbc] 159 | 160 | for k in list(self.kvp.keys()): 161 | f = row[k] 162 | if np.isscalar(f): 163 | f = np.array([f]) 164 | if f.dtype in [np.float16, np.float32, np.float64]: 165 | f = f.astype(self.floatX) 166 | features.append(f) 167 | for k in list(self.data.keys()): 168 | f = np.array(row.data[k]) 169 | if np.isscalar(f): 170 | f = np.array([f]) 171 | if f.dtype in [np.float16, np.float32, np.float64]: 172 | f = f.astype(self.floatX) 173 | features.append(f) 174 | return features 175 | 176 | def do_reload(self): 177 | with connect(self.asedb) as con: 178 | n_rows = con.count(self.selection) 179 | if self.n_rows != n_rows: 180 | self.n_rows = n_rows 181 | return True 182 | return False 183 | 184 | @property 185 | def num_examples(self): 186 | with connect(self.asedb) as con: 187 | n_rows = con.count(self.selection) 188 | return n_rows 189 | 190 | def iterate(self): 191 | # get data base size 192 | with connect(self.asedb) as con: 193 | n_rows = con.count(self.selection) 194 | if self.block_size is None: 195 | block_size = n_rows 196 | else: 197 | block_size = self.block_size 198 | n_blocks = int(np.ceil(n_rows / block_size)) 199 | 200 | # shuffling 201 | if self.shuffle: 202 | permutation = np.random.permutation(n_blocks) 203 | else: 204 | permutation = range(n_blocks) 205 | 206 | # iterate over blocks 207 | for i in permutation: 208 | # load block 209 | with connect(self.asedb) as con: 210 | rows = list( 211 | con.select(self.selection, limit=block_size, 212 | offset=i * block_size) 213 | ) 214 | 215 | # iterate over rows 216 | for row in rows: 217 | yield self.convert_atoms(row) 218 | self.epoch += 1 219 | 220 | def _run(self, sess, coord=None): 221 | while self.epoch < self.num_epochs: 222 | if self.prefetch: 223 | if self.do_reload(): 224 | data = [] 225 | with connect(self.asedb) as con: 226 | for row in con.select(self.selection): 227 | data.append(self.convert_atoms(row)) 228 | if self.shuffle: 229 | shuffle(data) 230 | else: 231 | data = self.iterate() 232 | 233 | for feats in data: 234 | fdict = dict(zip(self.placeholders, feats)) 235 | sess.run(self.enqueue_op, feed_dict=fdict) 236 | 237 | def get_features(self): 238 | feat_dict = {} 239 | for name, feat, shape in zip(self.feat_names, self.dequeue_op, 240 | self.shapes): 241 | feat.set_shape(shape) 242 | feat_dict[name] = feat 243 | 244 | for preproc in self.preprocs: 245 | preproc(feat_dict) 246 | return feat_dict 247 | 248 | def add_preprocessor(self, preproc): 249 | self.preprocs.append(preproc) 250 | 251 | def get_property(self, pname): 252 | props = [] 253 | with connect(self.asedb) as con: 254 | for rows in con.select(self.selection): 255 | try: 256 | p = rows[pname] 257 | except Exception: 258 | p = rows.data[pname] 259 | props.append(p) 260 | return props 261 | -------------------------------------------------------------------------------- /dtnn/datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/atomistic-machine-learning/dtnn/30ef997f69f5293ae1eee03ec24716d4f0f3ce18/dtnn/datasets/__init__.py -------------------------------------------------------------------------------- /dtnn/datasets/gdb9.py: 
-------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import tarfile 4 | import tempfile 5 | import urllib.request, urllib.parse, urllib.error 6 | from urllib.error import URLError, HTTPError 7 | 8 | import numpy as np 9 | from ase.db import connect 10 | from ase.io.extxyz import read_xyz 11 | from ase.units import Hartree, eV, Bohr, Ang 12 | 13 | 14 | def load_atomrefs(at_path): 15 | logging.info('Downloading GDB-9 atom references...') 16 | at_url = 'https://ndownloader.figshare.com/files/3195395' 17 | tmpdir = tempfile.mkdtemp('gdb9') 18 | tmp_path = os.path.join(tmpdir, 'atomrefs.txt') 19 | 20 | try: 21 | urllib.request.urlretrieve(at_url, tmp_path) 22 | logging.info("Done.") 23 | except HTTPError as e: 24 | logging.error("HTTP Error:", e.code, at_url) 25 | return False 26 | except URLError as e: 27 | logging.error("URL Error:", e.reason, at_url) 28 | return False 29 | 30 | atref = np.zeros((100, 6)) 31 | labels = ['zpve', 'U0', 'U', 'H', 'G', 'Cv'] 32 | with open(tmp_path) as f: 33 | lines = f.readlines() 34 | for z, l in zip([1, 6, 7, 8, 9], lines[5:10]): 35 | atref[z, 0] = float(l.split()[1]) 36 | atref[z, 1] = float(l.split()[2]) * Hartree / eV 37 | atref[z, 2] = float(l.split()[3]) * Hartree / eV 38 | atref[z, 3] = float(l.split()[4]) * Hartree / eV 39 | atref[z, 4] = float(l.split()[5]) * Hartree / eV 40 | atref[z, 5] = float(l.split()[6]) 41 | np.savez(at_path, atom_ref=atref, labels=labels) 42 | return True 43 | 44 | 45 | def load_data(dbpath): 46 | logging.info('Downloading GDB-9 data...') 47 | tmpdir = tempfile.mkdtemp('gdb9') 48 | tar_path = os.path.join(tmpdir, 'gdb9.tar.gz') 49 | raw_path = os.path.join(tmpdir, 'gdb9_xyz') 50 | url = 'https://ndownloader.figshare.com/files/3195389' 51 | 52 | try: 53 | urllib.request.urlretrieve(url, tar_path) 54 | logging.info("Done.") 55 | except HTTPError as e: 56 | logging.error("HTTP Error:", e.code, url) 57 | return False 58 | except URLError as e: 59 | logging.error("URL Error:", e.reason, url) 60 | return False 61 | 62 | tar = tarfile.open(tar_path) 63 | tar.extractall(raw_path) 64 | tar.close() 65 | 66 | prop_names = ['rcA', 'rcB', 'rcC', 'mu', 'alpha', 'homo', 'lumo', 67 | 'gap', 'r2', 'zpve', 'energy_U0', 'energy_U', 'enthalpy_H', 68 | 'free_G', 'Cv'] 69 | conversions = [1., 1., 1., 1., Bohr ** 3 / Ang ** 3, 70 | Hartree / eV, Hartree / eV, Hartree / eV, 71 | Bohr ** 2 / Ang ** 2, Hartree / eV, 72 | Hartree / eV, Hartree / eV, Hartree / eV, 73 | Hartree / eV, 1.] 
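# the factors above convert raw GDB-9 values to ASE units (Hartree -> eV for energies, Bohr -> Angstrom for lengths); entries of 1. keep the original unit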
74 | 75 | logging.info('Parse xyz files...') 76 | with connect(dbpath) as con: 77 | for i, xyzfile in enumerate(os.listdir(raw_path)): 78 | xyzfile = os.path.join(raw_path, xyzfile) 79 | 80 | if i % 10000 == 0: 81 | logging.info('Parsed: ' + str(i) + ' / 133885') 82 | properties = {} 83 | tmp = os.path.join(tmpdir, 'tmp.xyz') 84 | 85 | with open(xyzfile, 'r') as f: 86 | lines = f.readlines() 87 | l = lines[1].split()[2:] 88 | for pn, p, c in zip(prop_names, l, conversions): 89 | properties[pn] = float(p) * c 90 | with open(tmp, "wt") as fout: 91 | for line in lines: 92 | fout.write(line.replace('*^', 'e')) 93 | 94 | with open(tmp, 'r') as f: 95 | ats = list(read_xyz(f, 0))[0] 96 | 97 | con.write(ats, key_value_pairs=properties) 98 | logging.info('Done.') 99 | 100 | return True 101 | -------------------------------------------------------------------------------- /dtnn/layers.py: -------------------------------------------------------------------------------- 1 | from dtnn.utils import shape 2 | 3 | import numpy as np 4 | import tensorflow as tf 5 | 6 | 7 | def glorot_uniform(shape, dtype, partition_info=None): 8 | if not dtype.is_floating: 9 | raise ValueError("Expected floating point type, got %s." % dtype) 10 | 11 | n_in = np.prod(shape[:-1]) 12 | n_out = shape[-1] 13 | 14 | r = tf.cast(tf.sqrt(6. / (n_in + n_out)), tf.float32) 15 | return tf.random_uniform(shape, -r, r, dtype=dtype) 16 | 17 | 18 | def reference_initializer(ref): 19 | def initializer(shape, dtype, partition_info=None): 20 | return tf.cast(tf.constant(np.reshape(ref, shape)), dtype) 21 | return initializer 22 | 23 | 24 | def dense(x, n_out, 25 | nonlinearity=None, 26 | use_bias=True, 27 | weight_init=glorot_uniform, 28 | bias_init=tf.constant_initializer(0.), 29 | trainable=True, 30 | scope=None, reuse=False, name='Dense'): 31 | x_shape = shape(x) 32 | ndims = len(x_shape) 33 | n_in = x_shape[-1] 34 | with tf.variable_scope(scope, default_name=name, values=[x], 35 | reuse=reuse) as scope: 36 | # reshape for broadcasting 37 | xr = tf.reshape(x, (-1, n_in)) 38 | 39 | W = tf.get_variable('W', shape=(n_in, n_out), 40 | initializer=weight_init, 41 | trainable=trainable) 42 | tf.add_to_collection(tf.GraphKeys.WEIGHTS, W) 43 | tf.summary.histogram('W', W) 44 | 45 | y = tf.matmul(xr, W) 46 | 47 | if use_bias: 48 | b = tf.get_variable('b', shape=(n_out,), 49 | initializer=bias_init, 50 | trainable=trainable) 51 | tf.add_to_collection(tf.GraphKeys.BIASES, b) 52 | tf.summary.histogram('b', b) 53 | y += b 54 | 55 | if nonlinearity: 56 | y = nonlinearity(y) 57 | 58 | new_shape = tf.concat([tf.shape(x)[:ndims - 1], [n_out]], axis=0) 59 | y = tf.reshape(y, new_shape) 60 | 61 | new_dims = x_shape[:-1] + [n_out] 62 | y.set_shape(new_dims) 63 | tf.summary.histogram('activations', y) 64 | 65 | return y 66 | 67 | 68 | def embedding(indices, n_vocabulary, n_out, 69 | weight_init=glorot_uniform, 70 | reference=None, 71 | trainable=True, 72 | scope=None, reuse=False, name='Embedding'): 73 | if type(n_out) is int: 74 | n_out = (n_out,) 75 | with tf.variable_scope(scope, default_name=name, reuse=reuse) as scope: 76 | if reference is None: 77 | W = tf.get_variable('W', shape=(n_vocabulary,) + n_out, 78 | initializer=weight_init, 79 | trainable=trainable) 80 | else: 81 | W = tf.get_variable('W', shape=(n_vocabulary,) + n_out, 82 | initializer=reference_initializer(reference), 83 | trainable=trainable) 84 | tf.add_to_collection(tf.GraphKeys.WEIGHTS, W) 85 | 86 | y = tf.nn.embedding_lookup(W, indices) 87 | return y 88 | 89 | 90 | def masked_reduce(x, 
mask=None, axes=None, 91 | reduce_op=tf.reduce_sum, 92 | keep_dims=False, 93 | scope=None, name='masked_reduce'): 94 | scope_vars = [x] 95 | if mask is not None: 96 | scope_vars.append(mask) 97 | 98 | with tf.variable_scope(scope, default_name=name, 99 | values=scope_vars) as scope: 100 | if mask is not None: 101 | mask = tf.cast(mask > 0, tf.float32) 102 | x *= mask 103 | 104 | y = reduce_op(x, axes, keep_dims) 105 | 106 | return y 107 | 108 | 109 | def masked_sum(x, mask=None, axes=None, 110 | keep_dims=False, 111 | scope=None, name='masked_sum'): 112 | return masked_reduce(x, mask, axes, tf.reduce_sum, 113 | keep_dims, scope, name) 114 | 115 | 116 | def masked_mean(x, mask=None, axes=None, 117 | keep_dims=False, 118 | scope=None, name='masked_mean'): 119 | if mask is None: 120 | mred = masked_reduce(x, mask, axes, tf.reduce_mean, 121 | keep_dims, scope, name) 122 | else: 123 | msum = masked_reduce(x, mask, axes, tf.reduce_sum, 124 | keep_dims, scope, name) 125 | mask = tf.cast(mask > 0, tf.float32) 126 | N = tf.reduce_sum(mask, axes, keep_dims) 127 | N = tf.maximum(N, 1) 128 | mred = msum / N 129 | return mred 130 | -------------------------------------------------------------------------------- /dtnn/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .dtnn import DTNN 2 | -------------------------------------------------------------------------------- /dtnn/models/dtnn.py: -------------------------------------------------------------------------------- 1 | import dtnn.layers as L 2 | import numpy as np 3 | import tensorflow as tf 4 | 5 | from dtnn.atoms import interatomic_distances, site_rdf 6 | from ..core import Model 7 | 8 | 9 | class DTNN(Model): 10 | """ 11 | Deep Tensor Neural Network (DTNN) 12 | 13 | DTNN receives molecular structures through a vector of atomic `numbers` 14 | and a matrix of atomic `positions` ensuring rotational and 15 | translational invariance by construction. 16 | Each atom is represented by a coefficient vector that 17 | is repeatedly refined by pairwise interactions with the surrounding atoms. 18 | 19 | For a detailed description, see [1]. 20 | 21 | :param str model_dir: path to location of the model 22 | :param int n_basis: number of basis functions describing an atom 23 | :param int n_factors: number of factors in tensor low-rank approximation 24 | :param int n_interactions: number of interaction passes 25 | :param float mu: mean energy per atom 26 | :param float std: std. dev. of energies per atom 27 | :param float cutoff: distance cutoff 28 | :param float rdf_spacing: gap between Gaussians in distance basis 29 | :param bool per_atom: `true` if predicted is normalized to the number 30 | of atoms 31 | :param ndarray atom_ref: array of reference energies of single atoms 32 | :param int max_atomic_number: the highest, occuring atomic number 33 | in the data 34 | 35 | References 36 | ---------- 37 | .. [1] K.T. Schütt. F. Arbabzadah. S. Chmiela, K.-R. Müller, A. Tkatchenko: 38 | Quantum-chemical Insights from Deep Tensor Neural Networks. 39 | Nature Communications 8. 
13890 (2017) 40 | http://dx.doi.org/10.1038/ncomms13890 41 | """ 42 | 43 | def __init__(self, model_dir, 44 | n_basis=30, n_factors=60, n_interactions=3, 45 | mu=0.0, std=1.0, cutoff=20., rdf_spacing=0.2, 46 | per_atom=False, atom_ref=None, max_atomic_number=20): 47 | super(DTNN, self).__init__( 48 | model_dir, cutoff=cutoff, rdf_spacing=rdf_spacing, 49 | n_basis=n_basis, n_factors=n_factors, per_atom=per_atom, 50 | n_interactions=n_interactions, max_z=max_atomic_number, 51 | mu=mu, std=std, atom_ref=atom_ref 52 | ) 53 | 54 | def _preprocessor(self, features): 55 | positions = features['positions'] 56 | pbc = features['pbc'] 57 | cell = features['cell'] 58 | 59 | distances = interatomic_distances( 60 | positions, cell, pbc, self.cutoff 61 | ) 62 | 63 | features['srdf'] = site_rdf( 64 | distances, self.cutoff, self.rdf_spacing, 1. 65 | ) 66 | return features 67 | 68 | def _model(self, features): 69 | Z = features['numbers'] 70 | C = features['srdf'] 71 | 72 | # masking 73 | mask = tf.cast(tf.expand_dims(Z, 1) * tf.expand_dims(Z, 2), 74 | tf.float32) 75 | diag = tf.matrix_diag_part(mask) 76 | diag = tf.ones_like(diag) 77 | offdiag = 1 - tf.matrix_diag(diag) 78 | mask *= offdiag 79 | mask = tf.expand_dims(mask, -1) 80 | 81 | I = np.eye(self.max_z).astype(np.float32) 82 | ZZ = tf.nn.embedding_lookup(I, Z) 83 | r = tf.sqrt(1. / tf.sqrt(float(self.n_basis))) 84 | X = L.dense(ZZ, self.n_basis, use_bias=False, 85 | weight_init=tf.random_normal_initializer(stddev=r)) 86 | 87 | fC = L.dense(C, self.n_factors, use_bias=True) 88 | 89 | reuse = None 90 | for i in range(self.n_interactions): 91 | tmp = tf.expand_dims(X, 1) 92 | 93 | fX = L.dense(tmp, self.n_factors, use_bias=True, 94 | scope='in2fac', reuse=reuse) 95 | 96 | fVj = fX * fC 97 | 98 | Vj = L.dense(fVj, self.n_basis, use_bias=False, 99 | weight_init=tf.constant_initializer(0.0), 100 | nonlinearity=tf.nn.tanh, 101 | scope='fac2out', reuse=reuse) 102 | 103 | V = L.masked_sum(Vj, mask, axes=2) 104 | 105 | X += V 106 | reuse = True 107 | 108 | # output 109 | o1 = L.dense(X, self.n_basis // 2, nonlinearity=tf.nn.tanh) 110 | yi = L.dense(o1, 1, 111 | weight_init=tf.constant_initializer(0.0), 112 | use_bias=True) 113 | 114 | mu = tf.get_variable('mu', shape=(1,), 115 | initializer=L.reference_initializer(self.mu), 116 | trainable=False) 117 | std = tf.get_variable('std', shape=(1,), 118 | initializer=L.reference_initializer(self.std), 119 | trainable=False) 120 | yi = yi * std + mu 121 | 122 | if self.atom_ref is not None: 123 | E0i = L.embedding(Z, 100, 1, 124 | reference=self.atom_ref, trainable=False) 125 | yi += E0i 126 | 127 | atom_mask = tf.expand_dims(Z, -1) 128 | if self.per_atom: 129 | y = L.masked_mean(yi, atom_mask, axes=1) 130 | #E0 = L.masked_mean(E0i, atom_mask, axes=1) 131 | else: 132 | y = L.masked_sum(yi, atom_mask, axes=1) 133 | #E0 = L.masked_sum(E0i, atom_mask, axes=1) 134 | 135 | return {'y': y, 'y_i': yi} #, 'E0': E0} 136 | -------------------------------------------------------------------------------- /dtnn/train.py: -------------------------------------------------------------------------------- 1 | """ Training procedures for machine learning models """ 2 | 3 | import os 4 | import logging 5 | 6 | import numpy as np 7 | import tensorflow as tf 8 | 9 | import dtnn 10 | 11 | 12 | def early_stopping(model, cost_fcn, optimizer, 13 | train_data, val_data, test_data=None, 14 | additional_cost_fcns=[], global_step=None, 15 | n_iterations=1000000, patience=float('inf'), 16 | checkpoint_interval=100000, summary_interval=1000, 17 | 
validation_interval=1000, coord=None, 18 | num_val_batches=1, num_test_batches=1, 19 | profile=False, session_config=None): 20 | """ 21 | Train model using early stopping with validation and test set. 22 | 23 | :param LeapModel model: The model to be trained. 24 | :param CostFunction cost_fcn: Cost function to be optimized. 25 | :param LeapDataProvider train_data: Training data provider. 26 | :param LeapDataProvider val_data: Validation data provider. 27 | :param LeapDataProvider test_data: 28 | Test data provider for estimating the error during training. (optional) 29 | :param tf.train.Optimizer optimizer: 30 | Tensorflow optimizer (e.g. SGD, Adam, ...) 31 | :param list(CostFunction) additional_cost_fcns: 32 | List of additional cost functions for monitoring 33 | :param tf.Variable global_step: 34 | Variable containing the global step. 35 | Pass if using learning rate decay (optional) 36 | :param int n_iterations: Number of optimizer steps. 37 | :param int patience: Stop after `patience` steps without improved 38 | validation cost. (optional) 39 | :param int checkpoint_interval: Save model with given frequency. 40 | :param int summary_interval: Store training summary with given frequency. 41 | :param int validation_interval: Validate model with given frequency. 42 | :param tf.train.Coordinator coord: Coordinator for threads. 43 | :param int num_val_batches: Iterate over `num_val_batches` 44 | number of batches for validation. 45 | :param int num_test_batches: Iterate over `num_val_batches` 46 | number of batches for testing. 47 | :param bool profile: If `true`, enable collection of runtime data in 48 | TensorFlow 49 | """ 50 | 51 | train_features = train_data.get_features() 52 | val_features = val_data.get_features() 53 | test_features = test_data.get_features() 54 | 55 | # retrieve model outputs 56 | train_output = model.get_output( 57 | train_features, 58 | is_training=True, 59 | batch_size=train_data.batch_size, 60 | num_batch_threads=4 61 | ) 62 | val_output = model.get_output( 63 | val_features, 64 | is_training=False, 65 | batch_size=val_data.batch_size, 66 | num_batch_threads=1 67 | ) 68 | test_output = model.get_output( 69 | test_features, 70 | is_training=False, 71 | batch_size=test_data.batch_size, 72 | num_batch_threads=1 73 | ) 74 | 75 | # assemble costs & summaries 76 | cost = cost_fcn(train_output) 77 | train_sums = [ 78 | tf.summary.scalar(add_cost_fcn.name, 79 | add_cost_fcn(train_output)) 80 | for add_cost_fcn in additional_cost_fcns 81 | ] 82 | train_sums.append(tf.summary.scalar('cost', cost)) 83 | train_summaries = tf.summary.merge_all() 84 | 85 | # validation 86 | val_errors = [cost_fcn(val_output)] 87 | val_errors += [add_cost_fcn.calc_errors(val_output) 88 | for add_cost_fcn in additional_cost_fcns] 89 | 90 | # test 91 | if test_output is not None: 92 | test_errors = [cost_fcn(test_output)] 93 | test_errors += [add_cost_fcn.calc_errors(test_output) 94 | for add_cost_fcn in additional_cost_fcns] 95 | else: 96 | test_errors = None 97 | 98 | # collect test & validation summaries 99 | errors = [tf.placeholder(tf.float32) for _ in val_errors] 100 | summaries = [ 101 | tf.summary.scalar(add_cost_fcn.name, 102 | add_cost_fcn.aggregate( 103 | err)) 104 | for add_cost_fcn, err in 105 | zip(additional_cost_fcns, errors) 106 | ] 107 | summaries = tf.summary.merge(summaries) 108 | 109 | # training ops 110 | if global_step is None: 111 | global_step = tf.Variable(0, name='global_step', trainable=False) 112 | train_op = optimizer.minimize(cost, global_step=global_step) 113 | 
init_op = tf.global_variables_initializer() 114 | 115 | # training loop 116 | best_error = np.Inf 117 | coord = tf.train.Coordinator() if coord is None else coord 118 | try: 119 | with tf.Session(config=session_config) as sess: 120 | # set up summary writers 121 | train_writer = tf.summary.FileWriter( 122 | os.path.join(model.model_dir, 'train'), sess.graph) 123 | val_writer = tf.summary.FileWriter( 124 | os.path.join(model.model_dir, 'validation')) 125 | test_writer = tf.summary.FileWriter( 126 | os.path.join(model.model_dir, 'test')) 127 | 128 | # initialize all variables 129 | sess.run(init_op) 130 | 131 | # setup Saver & restore if previous checkpoints are available 132 | chkpt_saver = tf.train.Saver() 133 | checkpoint_path = os.path.join(model.model_dir, 'chkpoints') 134 | if not os.path.exists(checkpoint_path): 135 | start_iter = 0 136 | os.makedirs(checkpoint_path) 137 | else: 138 | chkpt = tf.train.latest_checkpoint(checkpoint_path) 139 | chkpt_saver.restore(sess, chkpt) 140 | start_iter = int(chkpt.split('-')[-1]) 141 | chkpt = os.path.join(checkpoint_path, 'checkpoint') 142 | 143 | global_step.assign(start_iter).eval() 144 | 145 | dtnn.data.start_data_feeds(sess, coord) 146 | tf.train.start_queue_runners(sess=sess, coord=coord) 147 | logging.info('Starting at iteration ' + 148 | str(start_iter) + ' / ' + str(n_iterations)) 149 | last_best = start_iter 150 | for i in range(start_iter, n_iterations): 151 | if i % checkpoint_interval == 0: 152 | chkpt_saver.save(sess, chkpt, i) 153 | logging.info('Saved checkpoint at iteration %d / %d', 154 | i, n_iterations) 155 | 156 | if (i - last_best) > patience: 157 | logging.info('Out of patience.') 158 | break 159 | 160 | if i % summary_interval == 0: 161 | logging.debug('Store Summary.') 162 | if profile: 163 | run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE) 164 | run_metadata = tf.RunMetadata() 165 | _, train_sums = sess.run( 166 | [train_op, train_summaries], options=run_options, run_metadata=run_metadata) 167 | train_writer.add_run_metadata(run_metadata, 'step%d' % i) 168 | else: 169 | _, train_sums = sess.run([train_op, train_summaries]) 170 | train_writer.add_summary(train_sums, global_step=i) 171 | else: 172 | sess.run(train_op) 173 | 174 | if i % validation_interval == 0: 175 | val_costs = [] 176 | sums = [[] for _ in range(len(additional_cost_fcns))] 177 | for k in range(num_val_batches): 178 | results = sess.run(val_errors) 179 | val_costs.append(results[0]) 180 | for s, r in enumerate(results[1:]): 181 | sums[s].append(r) 182 | for s in range(len(sums)): 183 | sums[s] = np.vstack(sums[s]) 184 | val_cost = np.mean(val_costs) 185 | 186 | feed_dict = { 187 | err: vsum 188 | for err, vsum in zip(errors, sums) 189 | } 190 | val_sums = sess.run(summaries, feed_dict=feed_dict) 191 | val_writer.add_summary(val_sums, 192 | global_step=i) 193 | 194 | if val_cost < best_error: 195 | last_best = i 196 | best_error = val_cost 197 | 198 | test_costs = [] 199 | sums = [[] for _ in range(len(additional_cost_fcns))] 200 | for k in range(num_test_batches): 201 | results = sess.run(test_errors) 202 | test_costs.append(results[0]) 203 | for s, r in enumerate(results[1:]): 204 | sums[s].append(r) 205 | test_cost = np.mean(test_costs) 206 | for s in range(len(sums)): 207 | sums[s] = np.vstack(sums[s]) 208 | 209 | feed_dict = { 210 | err: vsum 211 | for err, vsum in zip(errors, sums) 212 | } 213 | test_sums = sess.run(summaries, feed_dict=feed_dict) 214 | 215 | test_writer.add_summary(test_sums, 216 | global_step=i) 217 | 
model.store(sess, i, 'best') 218 | logging.info( 219 | 'New best model at iteration %d /' + 220 | ' %d with loss %.2f', 221 | i, n_iterations, test_cost 222 | ) 223 | logging.info('Done') 224 | finally: 225 | logging.info('Saving chkpoint...') 226 | chkpt_saver.save(sess, chkpt, i) 227 | logging.info('Done.') 228 | 229 | logging.info('Stopping threads...') 230 | if coord is not None: 231 | coord.request_stop() 232 | logging.info('Done.') 233 | 234 | 235 | -------------------------------------------------------------------------------- /dtnn/utils.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | 4 | 5 | def shape(x): 6 | if isinstance(x, tf.Tensor): 7 | return x.get_shape().as_list() 8 | return np.shape(x) 9 | -------------------------------------------------------------------------------- /examples/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/atomistic-machine-learning/dtnn/30ef997f69f5293ae1eee03ec24716d4f0f3ce18/examples/__init__.py -------------------------------------------------------------------------------- /examples/eval_dtnn_gdb9.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """ 3 | Example script for evaluation a DTNN to predict 4 | the total energy at 0K (U0) for the GDB-9 data. 5 | """ 6 | import argparse 7 | import os 8 | 9 | import numpy as np 10 | import tensorflow as tf 11 | from ase.db import connect 12 | 13 | from dtnn.models import DTNN 14 | 15 | 16 | def evaluate(args): 17 | # define model inputs 18 | features = { 19 | 'numbers': tf.placeholder(tf.int64, shape=(None,)), 20 | 'positions': tf.placeholder(tf.float32, shape=(None, 3)), 21 | 'cell': np.eye(3).astype(np.float32), 22 | 'pbc': np.zeros((3,)).astype(np.int64) 23 | } 24 | 25 | # load model 26 | model = DTNN(args.model_dir) 27 | model_output = model.get_output(features, is_training=False) 28 | y = model_output['y'] 29 | 30 | with tf.Session() as sess: 31 | model.restore(sess) 32 | 33 | print('test_live.db') 34 | U0_live, U0_pred_live = predict( 35 | os.path.join(args.split_dir, 'test_live.db'), features, sess, y 36 | ) 37 | print('test.db') 38 | U0, U0_pred = predict( 39 | os.path.join(args.split_dir, 'test.db'), features, sess, y 40 | ) 41 | U0 += U0_live 42 | U0_pred += U0_pred_live 43 | U0 = np.vstack(U0) 44 | U0_pred = np.vstack(U0_pred) 45 | 46 | diff = U0 - U0_pred 47 | mae = np.mean(np.abs(diff)) 48 | rmse = np.sqrt(np.mean(diff ** 2)) 49 | print('MAE: %.3f eV, RMSE: %.3f eV' % (mae, rmse)) 50 | 51 | 52 | def predict(dbpath, features, sess, y): 53 | U0 = [] 54 | U0_pred = [] 55 | count = 0 56 | with connect(dbpath) as conn: 57 | n_structures = conn.count() 58 | for row in conn.select(): 59 | U0.append(row['U0']) 60 | 61 | at = row.toatoms() 62 | feed_dict = { 63 | features['numbers']: 64 | np.array(at.numbers).astype(np.int64), 65 | features['positions']: 66 | np.array(at.positions).astype(np.float32) 67 | } 68 | U0_p = sess.run(y, feed_dict=feed_dict) 69 | U0_pred.append(U0_p) 70 | if count % 1000 == 0: 71 | print(str(count) + ' / ' + str(n_structures)) 72 | count += 1 73 | return U0, U0_pred 74 | 75 | 76 | if __name__ == '__main__': 77 | parser = argparse.ArgumentParser() 78 | parser.add_argument('split_dir', 79 | help='Path to directory with data splits' + 80 | ' ("test.db", "test_live.db")') 81 | parser.add_argument('model_dir', help='Path to model directory.') 82 
| args = parser.parse_args() 83 | 84 | evaluate(args) 85 | -------------------------------------------------------------------------------- /examples/train_dtnn_gdb9.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """ 3 | Example script for training a DTNN to predict 4 | the total energy at 0K (U0) for the GDB-9 data. 5 | """ 6 | 7 | import os 8 | import sys 9 | import argparse 10 | import logging 11 | 12 | import numpy as np 13 | import tensorflow as tf 14 | 15 | import dtnn 16 | from dtnn.datasets.gdb9 import load_data, load_atomrefs 17 | from dtnn.models import DTNN 18 | 19 | logging.basicConfig( 20 | format='%(levelname)s - %(message)s', 21 | level=logging.INFO 22 | ) 23 | 24 | 25 | def prepare_data(dbpath, partitions, splitdst): 26 | if not os.path.exists(splitdst): 27 | logging.info('Partition data...') 28 | dtnn.split_ase_db(dbpath, splitdst, partitions) 29 | logging.info('Done.') 30 | 31 | train_data = dtnn.ASEDataProvider( 32 | os.path.join(splitdst, 'train.db'), 33 | kvp={'energy_U0': (1,)}, prefetch=False, shuffle=True 34 | ) 35 | val_data = dtnn.ASEDataProvider( 36 | os.path.join(splitdst, 'validation.db'), 37 | kvp={'energy_U0': (1,)}, prefetch=True, shuffle=False 38 | ) 39 | test_data = dtnn.ASEDataProvider( 40 | os.path.join(splitdst, 'test_live.db'), 41 | kvp={'energy_U0': (1,)}, prefetch=True, shuffle=False 42 | ) 43 | return train_data, val_data, test_data 44 | 45 | 46 | def main(args): 47 | n_iterations = 5000000 48 | 49 | dbpath = os.path.join(args.data_dir, 'gdb9.db') 50 | atom_reference = os.path.join(args.data_dir, 'atom_refs.npz') 51 | 52 | # load and partition data 53 | partitions = {'train': 49000, 'validation': 1000, 54 | 'test_live': 1000, 'test': -1} 55 | split_dst = os.path.join(args.output_dir, args.split_name) 56 | train_data, val_data, test_data = prepare_data(dbpath=dbpath, 57 | partitions=partitions, 58 | splitdst=split_dst) 59 | train_data.batch_size = 25 60 | val_data.batch_size = 100 61 | test_data.batch_size = 100 62 | num_val_batches = 10 63 | num_test_batches = 10 64 | 65 | # load atom energies 66 | atom_reference = np.load(atom_reference) 67 | e_atom = atom_reference['atom_ref'][:, 1:2] 68 | 69 | # calculate mean/std.dev. 
per atom 70 | U0 = np.array(train_data.get_property('energy_U0')) 71 | E = U0.reshape((-1, 1)) 72 | Z = train_data.get_property('numbers') 73 | E0 = np.vstack([np.sum(e_atom[np.array(z)], 0) for z in Z]).reshape((-1, 1)) 74 | N = np.array([len(z) for z in Z]).reshape((-1, 1)) 75 | E0n = (E - E0) / N.reshape((-1, 1)) 76 | mu = np.mean(E0n, axis=0) 77 | std = np.std(E0n, axis=0) 78 | 79 | logging.info('mu(E/N)=' + str(mu)) 80 | logging.info('std(E/N)=' + str(std)) 81 | 82 | # setup models 83 | mname = '{0}_{1}_{2}_{3}_{4}_{5}_{6}'.format(args.model, 84 | args.basis, 85 | args.factors, 86 | args.interactions, 87 | args.cutoff, 88 | args.split_name.split('/')[-1], 89 | args.name) 90 | if args.model == 'DTNN': 91 | model = DTNN(os.path.join(args.output_dir, mname), 92 | mu=mu, std=std, 93 | n_interactions=args.interactions, 94 | n_basis=args.basis, 95 | atom_ref=e_atom, 96 | n_factors=args.factors, 97 | cutoff=args.cutoff) 98 | 99 | # setup cost functions 100 | cost_fcn = dtnn.L2Loss(prediction='y', target='energy_U0') 101 | additional_cost_fcns = [ 102 | dtnn.MeanAbsoluteError(prediction='y', target='energy_U0', name='energy_U0_MAE'), 103 | dtnn.RootMeanSquaredError(prediction='y', target='energy_U0', name='energy_U0_RMSE'), 104 | dtnn.PAMeanAbsoluteError(prediction='y', target='energy_U0', 105 | name='energy_U0_MAE_atom'), 106 | dtnn.PARmse(prediction='y', target='energy_U0', name='energy_U0pN_RMSE_atom') 107 | ] 108 | 109 | # setup optimizer 110 | global_step = tf.Variable(0, name='global_step', trainable=False) 111 | lr = tf.train.exponential_decay(args.lr, global_step, 112 | 100000, 0.95) 113 | optimizer = tf.train.AdamOptimizer(lr) 114 | 115 | if args.half: 116 | gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.48) 117 | else: 118 | gpu_options = None 119 | 120 | session_config = tf.ConfigProto(gpu_options=gpu_options, 121 | intra_op_parallelism_threads=4) 122 | 123 | # train DTNN 124 | dtnn.early_stopping( 125 | model, cost_fcn, optimizer, 126 | train_data, val_data, test_data, 127 | additional_cost_fcns=additional_cost_fcns, 128 | n_iterations=n_iterations, 129 | global_step=global_step, 130 | num_test_batches=num_test_batches, 131 | num_val_batches=num_val_batches, 132 | session_config=session_config, 133 | validation_interval=2000, 134 | summary_interval=1000 135 | ) 136 | 137 | 138 | if __name__ == '__main__': 139 | parser = argparse.ArgumentParser() 140 | parser.add_argument('data_dir', help='Path to data (destination) directory.') 141 | parser.add_argument('output_dir', help='Output directory for model and training log.') 142 | parser.add_argument('--split_name', help='Name of data split.', 143 | default='split_1') 144 | parser.add_argument('--cutoff', type=float, help='Distance cutoff', 145 | default=20.)
146 | parser.add_argument('--interactions', type=int, help='Number of interaction passes', 147 | default=3) 148 | parser.add_argument('--basis', type=int, help='Basis set size', 149 | default=30) 150 | parser.add_argument('--factors', type=int, help='Factor space size', 151 | default=60) 152 | parser.add_argument('--model', type=str, 153 | help='ML model name [DTNN, DTNNv2]', 154 | default='DTNN') 155 | parser.add_argument('--name', help='Name of run', 156 | default='') 157 | parser.add_argument('--lr', type=float, help='Learning rate', 158 | default=1e-3) 159 | parser.add_argument('--half', action='store_true', 160 | help='Only use half of the GPU memory') 161 | args = parser.parse_args() 162 | 163 | if not os.path.exists(args.output_dir): 164 | os.makedirs(args.output_dir) 165 | 166 | if not os.path.exists(args.data_dir): 167 | os.makedirs(args.data_dir) 168 | 169 | dbpath = os.path.join(args.data_dir, 'gdb9.db') 170 | atom_refs = os.path.join(args.data_dir, 'atom_refs.npz') 171 | 172 | # download data set (if needed) 173 | if not os.path.exists(dbpath): 174 | do_download = input( 175 | 'No database found at `' + dbpath + '`. ' + 176 | 'Should GDB-9 data be downloaded to that location? [y/N]') 177 | 178 | success = False 179 | if do_download == 'y': 180 | success = load_data(dbpath) 181 | 182 | if not success: 183 | logging.info('Aborting.') 184 | sys.exit() 185 | 186 | # download atom reference energies (if needed) 187 | if not os.path.exists(atom_refs): 188 | do_download = input( 189 | 'No atom reference file found at `' + atom_refs + '`. ' + 190 | 'Should GDB-9 atom references be downloaded to that location? [y/N]') 191 | 192 | success = False 193 | if do_download == 'y': 194 | success = load_atomrefs(atom_refs) 195 | 196 | if not success: 197 | logging.info('Aborting.') 198 | sys.exit() 199 | 200 | # start training procedure 201 | main(args) 202 | --------------------------------------------------------------------------------
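The example scripts above cover training and batch evaluation over an ASE database. As a complement, the sketch below shows how a single structure could be fed to a trained model by hand. It is a minimal sketch, not part of the repository: the model directory path and the choice of molecule are placeholders, and it simply reuses the placeholder-based inputs and the `DTNN`/`get_output`/`restore` calls from `examples/eval_dtnn_gdb9.py`.

    #!/usr/bin/env python
    """Minimal sketch: predict U0 for a single ASE structure with a trained DTNN."""
    import numpy as np
    import tensorflow as tf
    from ase.build import molecule

    from dtnn.models import DTNN

    # placeholder path: point this at a model directory written by train_dtnn_gdb9.py
    model_dir = 'output/DTNN_30_60_3_20.0_split_1'

    # model inputs; cell and pbc are constant because GDB-9 molecules are non-periodic
    features = {
        'numbers': tf.placeholder(tf.int64, shape=(None,)),
        'positions': tf.placeholder(tf.float32, shape=(None, 3)),
        'cell': np.eye(3).astype(np.float32),
        'pbc': np.zeros((3,)).astype(np.int64)
    }

    model = DTNN(model_dir)
    y = model.get_output(features, is_training=False)['y']

    atoms = molecule('CH4')  # placeholder choice of a small organic molecule

    with tf.Session() as sess:
        model.restore(sess)
        energy = sess.run(y, feed_dict={
            features['numbers']: atoms.numbers.astype(np.int64),
            features['positions']: atoms.positions.astype(np.float32)
        })
        print('Predicted U0: %.3f eV' % float(energy))

Because the restored model reads its training-time configuration (basis size, cutoff, mean/std, atom references) back from `config.npz` in the model directory, no hyperparameters need to be repeated when constructing `DTNN(model_dir)` here.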