├── LICENSE ├── README.md ├── setup.py └── src ├── msgnet ├── __init__.py ├── datahandler.py ├── dataloader.py ├── defaults.py ├── msgpassing.py ├── readout.py ├── train.py └── utilities.py └── scripts ├── get_matproj.py ├── get_oqmd.py ├── get_qm9.py ├── predict_with_model.py └── runner.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Peter Bjørn Jørgensen 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Msgnet 2 | Tensorflow implementation of message passing neural networks for molecules and materials. 3 | The framework implements the [SchNet model](https://arxiv.org/abs/1712.06113) and its extension with edge update network [NMP-EDGE](https://arxiv.org/abs/1806.03146) as well as the model used in [Materials property prediction using symmetry-labeled graphs as atomic-position independent descriptors](https://arxiv.org/abs/1905.06048). 4 | 5 | Currently the implementation does not enable training with forces, but this might be implemented in the future. 6 | For a more full-fledged implementation of the SchNet model, see [schnetpack](https://github.com/atomistic-machine-learning/schnetpack). 7 | 8 | The main difference between `msgnet` and `schnetpack` is that `msgnet` follows a message passing architecture and can therefore be more flexible in some cases, e.g. it can be used to train on graphs rather than on structures with full spatial information. 9 | 10 | # Install 11 | Install the dependency 12 | - [Vorosym](https://github.com/peterbjorgensen/vorosym) 13 | 14 | Set the `datadir` variable in `src/msgnet/defaults.py` to a preferred path in which the datasets will be saved. 15 | 16 | Then run `python setup.py install` or `python setup.py install --user` to install the module. 17 | 18 | ## Install datasets 19 | Run the script `src/scripts/get_qm9.py` to download the QM9 dataset 20 | 21 | Run the script `src/scripts/get_matproj.py MATPROJ_API_KEY` to download the materials project dataset. You need to create a user and obtain an API key from [Materials Project](https://materialsproject.org/). 22 | 23 | Run the script `python2 src/scripts/get_oqmd.py` to convert the OQMD database into an ASE database. 
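Once a dataset is in place in `datadir`, it can also be loaded programmatically through the loader classes in `src/msgnet/dataloader.py`. The snippet below is a minimal sketch using the QM9 loader (it assumes `get_qm9.py` has already created `qm9.db` in `datadir`); the first call converts the ASE database into graph objects and caches them as a compressed archive in `datadir`, so it can take a while:

```python
import msgnet

loader = msgnet.dataloader.Qm9DataLoader()  # uses the default constant cutoff
graph_objects = loader.load()  # list of FeatureGraph objects
print("Loaded %d graphs" % len(graph_objects))
```

For the OQMD dataset there are a couple of extra steps: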
You need to manually download and install the [OQMD database](http://oqmd.org/) on your machine to run this script. 24 | The OQMD API is only compatible with Python 2, so after running the script you must manually move the `oqmd12.db` to the `datadir` set in `src/msgnet/defaults.py`. 25 | 26 | # Running the model 27 | To train the model used in the NMP-EDGE paper: 28 | 29 | `python runner.py --cutoff const 100 --readout sumscalar --num_passes 3 --update_edges --node_embedding_size 64 --dataset qm9 --edge_idx 0 --edge_expand 0.0,0.1,15.0 --learning_rate 5e-4 --target U0` 30 | 31 | To train the model on OQMD structures using the voronoi graph with symmetry labels: 32 | `python runner.py --fold 0 --cutoff voronoi 0.2 --readout avgscalar --num_passes 3 --node_embedding_size 256 --dataset oqmd12 --learning_rate 0.0001 --edge_idx 5 6 7 8 9 10 11 12 13 --update_edges` 33 | 34 | After the model is done training get the test set results by running 35 | `python predict_with_model.py --modelpath logs/path/to/model/model.ckpt-STEP.meta --output modeloutput.txt --split test` 36 | 37 | # Future Development 38 | The model is implemented such that it avoids any padding/masking. This is achieved by reshaping the variable length inputs into the first dimension of the tensors, which is usually the batch dimension. However, this means we can't use the conventional Tensorflow methods for handling datasets as streams. If the framework is still used in the future I am planning to convert it into a tensorflow keras model when the RaggedTensor implementation is fully supported. 39 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup, Extension 2 | 3 | setup( 4 | name="msgnet", 5 | version="0.0.1", 6 | description="Python library for implementation of message passing neural networks with example scripts", 7 | author="Peter Bjørn Jørgensen", 8 | author_email="peterbjorgensen@gmail.com", 9 | # url = 'https://docs.python.org/extending/building', 10 | package_dir={"": "src"}, 11 | packages=find_packages("src"), 12 | ) 13 | -------------------------------------------------------------------------------- /src/msgnet/__init__.py: -------------------------------------------------------------------------------- 1 | from msgnet import msgpassing 2 | from msgnet.msgpassing import MsgpassingNetwork 3 | from msgnet import utilities 4 | from msgnet import train 5 | from msgnet import readout 6 | from msgnet import defaults 7 | from msgnet import datahandler 8 | from msgnet import dataloader 9 | -------------------------------------------------------------------------------- /src/msgnet/datahandler.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import math 3 | import numpy as np 4 | import sklearn.model_selection 5 | 6 | 7 | def set_len_to_segments(set_len): 8 | return np.repeat(np.arange(len(set_len)), set_len) 9 | 10 | 11 | def get_folds(objects, folds): 12 | return [o for o in objects if o.fold in folds] 13 | 14 | 15 | class DataHandler: 16 | """DataHandler class used for handling and serving graph objects for training and testing""" 17 | 18 | def __init__(self, graph_objects, graph_targets=["total_energy"]): 19 | self.graph_objects = graph_objects 20 | self.graph_targets = graph_targets 21 | self.train_index_generator = self.idx_epoch_gen(len(graph_objects)) 22 | 23 | def get_train_batch(self, 
batch_size): 24 | rand_choice = itertools.islice(self.train_index_generator, batch_size) 25 | training_dict = self.list_to_matrices( 26 | [self.graph_objects[idx] for idx in rand_choice], 27 | graph_targets=self.graph_targets, 28 | ) 29 | self.modify_dict(training_dict) 30 | return training_dict 31 | 32 | def get_test_batches(self, batch_size): 33 | num_test_batches = int(math.ceil(len(self.graph_objects) / batch_size)) 34 | for batch_idx in range(num_test_batches): 35 | test_dict = self.list_to_matrices( 36 | self.graph_objects[ 37 | batch_idx * batch_size : (batch_idx + 1) * batch_size 38 | ], 39 | graph_targets=self.graph_targets, 40 | ) 41 | self.modify_dict(test_dict) 42 | yield test_dict 43 | 44 | def __len__(self): 45 | return len(self.graph_objects) 46 | 47 | @staticmethod 48 | def idx_epoch_gen(num_objects): 49 | while 1: 50 | for n in np.random.permutation(num_objects): 51 | yield n 52 | 53 | @staticmethod 54 | def list_to_matrices(graph_list, graph_targets=["total_energy"]): 55 | """list_to_matrices 56 | Convert list of FeatureGraph objects to dictionary with concatenated properties 57 | 58 | :param graph_list: 59 | :return: dictionary of stacked vectors and matrices 60 | """ 61 | nodes_created = 0 62 | all_nodes = [] 63 | all_conn = [] 64 | all_conn_offsets = [] 65 | all_edges = [] 66 | all_graph_targets = [] 67 | all_X = [] 68 | all_unitcells = [] 69 | set_len = [] 70 | edges_len = [] 71 | for gr in graph_list: 72 | nodes, conn, conn_offset, edges, X, unitcell = ( 73 | gr.nodes, 74 | gr.conns, 75 | gr.conns_offset, 76 | gr.edges, 77 | gr.positions, 78 | gr.unitcell, 79 | ) 80 | conn_shifted = np.copy(conn) + nodes_created 81 | all_nodes.append(nodes) 82 | all_conn.append(conn_shifted) 83 | all_conn_offsets.append(conn_offset) 84 | all_unitcells.append(unitcell) 85 | all_edges.append(edges) 86 | all_graph_targets.append(np.array([getattr(gr, t) for t in graph_targets])) 87 | all_X.append(X) 88 | nodes_created += nodes.shape[0] 89 | set_len.append(nodes.shape[0]) 90 | edges_len.append(edges.shape[0]) 91 | cat = lambda x: np.concatenate(x, axis=0) 92 | outdict = { 93 | "nodes": cat(all_nodes), 94 | "nodes_xyz": cat(all_X), 95 | "edges": cat(all_edges), 96 | "connections": cat(all_conn), 97 | "connections_offsets": cat(all_conn_offsets), 98 | "graph_targets": np.vstack(all_graph_targets), 99 | "set_lengths": np.array(set_len), 100 | "unitcells": np.stack(all_unitcells, axis=0), 101 | "edges_lengths": np.array(edges_len), 102 | } 103 | outdict["segments"] = set_len_to_segments(outdict["set_lengths"]) 104 | return outdict 105 | 106 | def get_normalization(self, per_atom=False): 107 | x_sum = np.zeros(len(self.graph_targets)) 108 | x_2 = np.zeros(len(self.graph_targets)) 109 | num_objects = 0 110 | for obj in self.graph_objects: 111 | for i, target in enumerate(self.graph_targets): 112 | x = getattr(obj, target) 113 | if per_atom: 114 | x = x / obj.nodes.shape[0] 115 | x_sum[i] += x 116 | x_2[i] += x ** 2.0 117 | num_objects += 1 118 | # Var(X) = E[X^2] - E[X]^2 119 | x_mean = x_sum / num_objects 120 | x_var = x_2 / num_objects - (x_mean) ** 2.0 121 | 122 | return x_mean, np.sqrt(x_var) 123 | 124 | def train_test_split( 125 | self, 126 | split_type=None, 127 | num_folds=None, 128 | test_fold=None, 129 | validation_size=None, 130 | test_size=None, 131 | deterministic=True, 132 | ): 133 | if split_type == "count" or split_type == "fraction": 134 | if deterministic: 135 | random_state = 21 136 | else: 137 | random_state = None 138 | if test_size > 0: 139 | train, test = 
sklearn.model_selection.train_test_split( 140 | self.graph_objects, test_size=test_size, random_state=random_state 141 | ) 142 | else: 143 | train = self.graph_objects 144 | test = [] 145 | elif split_type == "fold": 146 | assert test_fold < num_folds 147 | assert test_fold >= 0 148 | train_folds = [i for i in range(num_folds) if i != test_fold] 149 | train, test = ( 150 | get_folds(self.graph_objects, train_folds), 151 | get_folds(self.graph_objects, [test_fold]), 152 | ) 153 | else: 154 | raise ValueError("Unknown split type %s" % split_type) 155 | 156 | if validation_size: 157 | if deterministic: 158 | random_state = 47 159 | else: 160 | random_state = None 161 | train, validation = sklearn.model_selection.train_test_split( 162 | train, test_size=validation_size, random_state=random_state 163 | ) 164 | else: 165 | validation = [] 166 | 167 | return self.from_self(train), self.from_self(test), self.from_self(validation) 168 | 169 | def modify_dict(self, train_dict): 170 | pass 171 | 172 | def from_self(self, objects): 173 | return self.__class__(objects, self.graph_targets) 174 | 175 | 176 | class EdgeSelectDataHandler(DataHandler): 177 | """EdgeSelectDataHandler datahandler that selects a subset of the edge features""" 178 | 179 | def __init__(self, graph_objects, graph_targets, edge_input_idx): 180 | super().__init__(graph_objects, graph_targets) 181 | self.edge_input_idx = edge_input_idx 182 | 183 | def from_self(self, objects): 184 | return self.__class__(objects, self.graph_targets, self.edge_input_idx) 185 | 186 | def modify_dict(self, train_dict): 187 | all_edges = train_dict["edges"] 188 | input_edges = all_edges[:, self.edge_input_idx] 189 | train_dict["edges"] = input_edges 190 | 191 | 192 | class EdgeOutDataHandler(DataHandler): 193 | """EdgeOutDataHandler datahandler that allows training with edge targets""" 194 | 195 | def __init__(self, graph_objects, graph_targets, edge_target_idx, edge_input_idx): 196 | super().__init__(graph_objects, graph_targets) 197 | self.edge_target_idx = edge_target_idx 198 | self.edge_input_idx = edge_input_idx 199 | 200 | def from_self(self, objects): 201 | return self.__class__( 202 | objects, self.graph_targets, self.edge_target_idx, self.edge_input_idx 203 | ) 204 | 205 | def modify_dict(self, train_dict): 206 | all_edges = train_dict["edges"] 207 | target_edges = all_edges[:, self.edge_target_idx] 208 | input_edges = all_edges[:, self.edge_input_idx] 209 | train_dict["edges"] = input_edges 210 | train_dict["edges_targets"] = target_edges 211 | -------------------------------------------------------------------------------- /src/msgnet/dataloader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import zlib 3 | import pickle 4 | import logging 5 | import gzip 6 | import tarfile 7 | import io 8 | import warnings 9 | import requests 10 | import numpy as np 11 | import ase 12 | import ase.db 13 | from ase.neighborlist import NeighborList 14 | from vorosym import voro_tessellate, graphdistance 15 | import msgnet 16 | 17 | 18 | class DataLoader: 19 | default_target = None 20 | default_datasplit_args = {"split_type": "fraction", "test_size": 0.1} 21 | 22 | def __init__(self): 23 | self.download_url = None 24 | self.download_dest = None 25 | self.cutoff_type = "const" 26 | self.cutoff_radius = 100.0 27 | self.self_interaction = False 28 | self.db_filter_query = None 29 | 30 | @property 31 | def final_dest(self): 32 | cutname = "%s-%.2f" % (self.cutoff_type, self.cutoff_radius) 33 | return 
"%s/%s_%s.pkz" % ( 34 | msgnet.defaults.datadir, 35 | self.__class__.__name__, 36 | cutname, 37 | ) 38 | 39 | def _download_data(self): 40 | response = requests.get(self.download_url) 41 | with open(self.download_dest, "wb") as f: 42 | for chunk in response.iter_content(): 43 | f.write(chunk) 44 | 45 | def _preprocess(self): 46 | graph_list = self.load_ase_data( 47 | db_path=self.download_dest, 48 | cutoff_type=self.cutoff_type, 49 | cutoff_radius=self.cutoff_radius, 50 | self_interaction=self.self_interaction, 51 | filter_query=self.db_filter_query, 52 | ) 53 | return graph_list 54 | 55 | def _save(self, obj_list): 56 | with tarfile.open(self.final_dest, "w") as tar: 57 | for number, obj in enumerate(obj_list): 58 | pbytes = pickle.dumps(obj) 59 | cbytes = zlib.compress(pbytes) 60 | fsize = len(cbytes) 61 | cbuf = io.BytesIO(cbytes) 62 | cbuf.seek(0) 63 | tarinfo = tarfile.TarInfo(name="%d" % number) 64 | tarinfo.size = fsize 65 | tar.addfile(tarinfo, cbuf) 66 | 67 | def _load_data(self): 68 | obj_list = [] 69 | with tarfile.open(self.final_dest, "r") as tar: 70 | for tarinfo in tar.getmembers(): 71 | buf = tar.extractfile(tarinfo) 72 | decomp = zlib.decompress(buf) 73 | obj_list.append(pickle.loads(decomp)) 74 | return obj_list 75 | 76 | @staticmethod 77 | def load_ase_data( 78 | db_path="oqmd_all_entries.db", 79 | dtype=float, 80 | cutoff_type="voronoi", 81 | cutoff_radius=2.0, 82 | filter_query=None, 83 | self_interaction=False, 84 | discard_unconnected=False, 85 | ): 86 | """load_ase_data 87 | Load atom structure data from ASE database 88 | 89 | :param db_path: path of the database to load 90 | :param dtype: dtype of returned numpy arrays 91 | :param cutoff_type: voronoi, const or coval 92 | :param cutoff_radius: cutoff radius of the sphere around each atom 93 | :param filter_query: query string or function to select a subset of database 94 | :param self_interaction: whether an atom includes itself as a neighbor (not only its images) 95 | :param discard_unconnected: whether to discard samples that ends up with no edges in the graph 96 | :return: list of FeatureGraph objects 97 | """ 98 | con = ase.db.connect(db_path) 99 | sel = filter_query 100 | 101 | for i, row in enumerate(select_wfilter(con, sel)): 102 | if i % 100 == 0: 103 | print("%010d " % i, sep="", end="\r") 104 | atoms = row.toatoms() 105 | if row.key_value_pairs: 106 | prop_dict = row.key_value_pairs 107 | else: 108 | prop_dict = row.data 109 | prop_dict["id"] = row.id 110 | try: 111 | graphobj = FeatureGraph( 112 | atoms, 113 | cutoff_type, 114 | cutoff_radius, 115 | lambda x: x, 116 | self_interaction=self_interaction, 117 | **prop_dict 118 | ) 119 | except RuntimeError: 120 | logging.error("Error during data conversion of row id %d", row.id) 121 | continue 122 | if discard_unconnected and (graphobj.conns.shape[0] == 0): 123 | logging.error("Discarding %i because no connections made %s", i, atoms) 124 | else: 125 | yield graphobj 126 | print("") 127 | 128 | def load(self): 129 | if not os.path.isfile(self.final_dest): 130 | logging.info("%s does not exist" % self.final_dest) 131 | if not os.path.isfile(self.download_dest): 132 | logging.info( 133 | "%s does not exist, downloading data..." 
% self.download_dest 134 | ) 135 | self._download_data() 136 | logging.info("Download complete") 137 | logging.info("Preprocessing") 138 | obj_list = self._preprocess() 139 | logging.info("Saving to %s" % self.final_dest) 140 | self._save(obj_list) 141 | del obj_list 142 | logging.info("Loading data") 143 | obj_list = self._load_data() 144 | logging.info("Data loaded") 145 | return obj_list 146 | 147 | 148 | class Oqmd12DataLoader(DataLoader): 149 | default_target = "delta_e" 150 | default_datasplit_args = { 151 | "split_type": "fold", 152 | "num_folds": 5, 153 | "validation_size": 5000, 154 | } 155 | 156 | def __init__(self, cutoff_type="voronoi", cutoff_radius=None): 157 | super().__init__() 158 | self.download_url = None 159 | self.download_dest = "%s/oqmd12.db" % (msgnet.defaults.datadir) 160 | self.cutoff_type = cutoff_type 161 | self.cutoff_radius = cutoff_radius 162 | self.self_interaction = False 163 | 164 | 165 | class Oqmd12MiniDataLoader(DataLoader): 166 | default_target = "delta_e" 167 | default_datasplit_args = { 168 | "split_type": "fold", 169 | "num_folds": 5, 170 | "validation_size": 500, 171 | } 172 | 173 | def __init__(self, cutoff_type="voronoi", cutoff_radius=None): 174 | super().__init__() 175 | self.download_url = None 176 | self.download_dest = "%s/oqmd12.db" % (msgnet.defaults.datadir) 177 | self.cutoff_type = cutoff_type 178 | self.cutoff_radius = cutoff_radius 179 | self.self_interaction = False 180 | self.db_filter_query = "id<1000" 181 | 182 | 183 | class MatprojDataLoader(DataLoader): 184 | default_target = "delta_e" 185 | default_datasplit_args = { 186 | "split_type": "fold", 187 | "num_folds": 5, 188 | "validation_size": 5000, 189 | } 190 | 191 | def __init__(self, cutoff_type="voronoi", cutoff_radius=None): 192 | super().__init__() 193 | self.download_url = None 194 | self.download_dest = "%s/matproj_2018.db" % (msgnet.defaults.datadir) 195 | self.cutoff_type = cutoff_type 196 | self.cutoff_radius = cutoff_radius 197 | self.self_interaction = False 198 | 199 | 200 | class Qm9DataLoader(DataLoader): 201 | default_target = "U0" 202 | default_datasplit_args = { 203 | "split_type": "count", 204 | "validation_size": 10000, 205 | "test_size": 133885 - 120000, 206 | } 207 | 208 | def __init__(self, cutoff_type="const", cutoff_radius=100.0): 209 | super().__init__() 210 | self.download_url = None 211 | self.download_dest = "%s/qm9.db" % (msgnet.defaults.datadir) 212 | self.cutoff_type = cutoff_type 213 | self.cutoff_radius = cutoff_radius 214 | self.self_interaction = False 215 | 216 | 217 | class FeatureGraph: 218 | def __init__( 219 | self, 220 | atoms_obj: ase.Atoms, 221 | cutoff_type, 222 | cutoff_radius, 223 | atom_to_node_fn, 224 | self_interaction=False, 225 | **kwargs 226 | ): 227 | self.atoms = atoms_obj 228 | 229 | if cutoff_type == "const": 230 | graph_tuple = self.atoms_to_graph_const_cutoff( 231 | self.atoms, 232 | cutoff_radius, 233 | atom_to_node_fn, 234 | self_interaction=self_interaction, 235 | ) 236 | self.edge_labels = ["distance"] 237 | elif cutoff_type == "coval": 238 | graph_tuple = self.atoms_to_graph_const_cutoff( 239 | self.atoms, 240 | cutoff_radius, 241 | atom_to_node_fn, 242 | self_interaction=self_interaction, 243 | cutoff_covalent=True, 244 | ) 245 | self.edge_labels = ["distance"] 246 | elif cutoff_type == "knearest": 247 | graph_tuple = self.atoms_to_graph_knearest( 248 | self.atoms, int(cutoff_radius), atom_to_node_fn 249 | ) 250 | self.edge_labels = ["distance"] 251 | elif cutoff_type == "voronoi": 252 | graph_tuple = 
self.atoms_to_graph_voronoi( 253 | self.atoms, 254 | atom_to_node_fn, 255 | cutoff_radius, 256 | symmetry_binarize_threshold=0.99, 257 | ) 258 | self.edge_labels = [ 259 | "distance", 260 | "distance_normalized", 261 | "area", 262 | "area_normalized", 263 | "solid_angle", 264 | "C2", 265 | "C3", 266 | "C4", 267 | "C6", 268 | "D1", 269 | "D2", 270 | "D3", 271 | "D4", 272 | "D6", 273 | ] 274 | else: 275 | raise ValueError("cutoff_type not valid, given: %s" % cutoff_type) 276 | 277 | self.nodes, self.positions, self.edges, self.conns, self.conns_offset, self.unitcell = ( 278 | graph_tuple 279 | ) 280 | 281 | for key, val in kwargs.items(): 282 | assert not hasattr(self, key), "Attribute %s is reserved" % key 283 | setattr(self, key, val) 284 | 285 | def remap_nodes(self, atom_to_node_fn): 286 | self.nodes = np.array( 287 | [atom_to_node_fn(n) for n in self.atoms.get_atomic_numbers()] 288 | ) 289 | 290 | @staticmethod 291 | def atoms_to_graph_voronoi( 292 | atoms: ase.atoms, 293 | atom_to_node_fn, 294 | min_solid_angle=None, 295 | symmetry_binarize_threshold=None, 296 | ): 297 | 298 | nodes = [] 299 | connections = [] 300 | connections_offset = [] 301 | edges = [] 302 | 303 | voronoi_cells = voro_tessellate(atoms) 304 | 305 | assert np.all( 306 | atoms.get_pbc() 307 | ), "Voronoi graph only supported for periodic structures" 308 | atom_numbers = atoms.get_atomic_numbers() 309 | atom_positions = atoms.get_positions(wrap=True) 310 | unitcell = atoms.get_cell() 311 | for ii in range(len(atoms)): 312 | nodes.append(atom_to_node_fn(atom_numbers[ii])) 313 | 314 | for cell in voronoi_cells: 315 | total_area = sum(face.area for face in cell.faces) 316 | total_weighted_bond_length = sum( 317 | face.distance * face.solid_angle / (4 * np.pi) for face in cell.faces 318 | ) 319 | # assert np.all(abs(atom_positions[i] - (cell.atom_pos - np.dot(cell.cell_offset, unitcell))) < 1e-4) 320 | for face in cell.faces: 321 | dist = face.distance 322 | area = face.area 323 | normed_area = face.area / total_area 324 | sangle = face.solid_angle 325 | if min_solid_angle and (sangle < min_solid_angle): 326 | continue 327 | connections.append([face.neighbor, cell.atom_idx]) # [from, to] 328 | connections_offset.append( 329 | np.vstack( 330 | (face.neighbor_offset - cell.cell_offset, np.zeros(3, float)) 331 | ) 332 | ) 333 | if symmetry_binarize_threshold: 334 | syms = [ 335 | float(x) 336 | for x in (face.symmetries > symmetry_binarize_threshold) 337 | ] 338 | else: 339 | syms = list(face.symmetries) 340 | edges.append( 341 | [dist, dist / total_weighted_bond_length, area, normed_area, sangle] 342 | + syms 343 | ) 344 | 345 | return ( 346 | np.array(nodes), 347 | atom_positions, 348 | np.array(edges), 349 | np.array(connections), 350 | np.stack(connections_offset, axis=0), 351 | unitcell, 352 | ) 353 | 354 | @staticmethod 355 | def atoms_to_graph_const_cutoff( 356 | atoms: ase.Atoms, 357 | cutoff, 358 | atom_to_node_fn, 359 | self_interaction=False, 360 | cutoff_covalent=False, 361 | ): 362 | 363 | atoms.wrap() 364 | atom_numbers = atoms.get_atomic_numbers() 365 | 366 | if cutoff_covalent: 367 | radii = ase.data.covalent_radii[atom_numbers] * cutoff 368 | else: 369 | radii = [cutoff] * len(atoms) 370 | neighborhood = NeighborList( 371 | radii, skin=0.0, self_interaction=self_interaction, bothways=True 372 | ) 373 | neighborhood.update(atoms) 374 | 375 | nodes = [] 376 | connections = [] 377 | connections_offset = [] 378 | edges = [] 379 | if np.any(atoms.get_pbc()): 380 | atom_positions = atoms.get_positions(wrap=True) 
381 | else: 382 | atom_positions = atoms.get_positions(wrap=False) 383 | unitcell = atoms.get_cell() 384 | 385 | for ii in range(len(atoms)): 386 | nodes.append(atom_to_node_fn(atom_numbers[ii])) 387 | 388 | for ii in range(len(atoms)): 389 | neighbor_indices, offset = neighborhood.get_neighbors(ii) 390 | for jj, offs in zip(neighbor_indices, offset): 391 | ii_pos = atom_positions[ii] 392 | jj_pos = atom_positions[jj] + np.dot(offs, unitcell) 393 | dist_vec = ii_pos - jj_pos 394 | dist = np.sqrt(np.dot(dist_vec, dist_vec)) 395 | 396 | connections.append([jj, ii]) 397 | connections_offset.append(np.vstack((offs, np.zeros(3, float)))) 398 | edges.append([dist]) 399 | 400 | if len(edges) == 0: 401 | warnings.warn("Generated graph has zero edges") 402 | edges = np.zeros((0, 1)) 403 | connections = np.zeros((0, 2)) 404 | connections_offset = np.zeros((0, 2, 3)) 405 | else: 406 | connections_offset = np.stack(connections_offset, axis=0) 407 | 408 | return ( 409 | np.array(nodes), 410 | atom_positions, 411 | np.array(edges), 412 | np.array(connections), 413 | connections_offset, 414 | unitcell, 415 | ) 416 | 417 | @staticmethod 418 | def atoms_to_graph_knearest( 419 | atoms: ase.Atoms, num_neighbors, atom_to_node_fn, initial_radius=3.0 420 | ): 421 | 422 | atoms.wrap() 423 | atom_numbers = atoms.get_atomic_numbers() 424 | unitcell = atoms.get_cell() 425 | 426 | for multiplier in range(1, 11): 427 | if multiplier == 10: 428 | raise RuntimeError("Reached maximum radius") 429 | radii = [initial_radius * multiplier] * len(atoms) 430 | neighborhood = NeighborList( 431 | radii, skin=0.0, self_interaction=False, bothways=True 432 | ) 433 | neighborhood.update(atoms) 434 | 435 | nodes = [] 436 | connections = [] 437 | connections_offset = [] 438 | edges = [] 439 | if np.any(atoms.get_pbc()): 440 | atom_positions = atoms.get_positions(wrap=True) 441 | else: 442 | atom_positions = atoms.get_positions(wrap=False) 443 | keep_connections = [] 444 | keep_connections_offset = [] 445 | keep_edges = [] 446 | 447 | for ii in range(len(atoms)): 448 | nodes.append(atom_to_node_fn(atom_numbers[ii])) 449 | 450 | early_exit = False 451 | for ii in range(len(atoms)): 452 | this_edges = [] 453 | this_connections = [] 454 | this_connections_offset = [] 455 | neighbor_indices, offset = neighborhood.get_neighbors(ii) 456 | if len(neighbor_indices) < num_neighbors: 457 | # Not enough neigbors, so exit and increase radius 458 | early_exit = True 459 | break 460 | for jj, offs in zip(neighbor_indices, offset): 461 | ii_pos = atom_positions[ii] 462 | jj_pos = atom_positions[jj] + np.dot(offs, unitcell) 463 | dist_vec = ii_pos - jj_pos 464 | dist = np.sqrt(np.dot(dist_vec, dist_vec)) 465 | 466 | this_connections.append([jj, ii]) # from, to 467 | this_connections_offset.append( 468 | np.vstack((offs, np.zeros(3, float))) 469 | ) 470 | this_edges.append([dist]) 471 | edges.append(np.array(this_edges)) 472 | connections.append(np.array(this_connections)) 473 | connections_offset.append(np.stack(this_connections_offset, axis=0)) 474 | if early_exit: 475 | continue 476 | else: 477 | for e, c, o in zip(edges, connections, connections_offset): 478 | # Keep only num_neighbors closest indices 479 | keep_ind = np.argsort(e[:, 0])[0:num_neighbors] 480 | keep_edges.append(e[keep_ind]) 481 | keep_connections.append(c[keep_ind]) 482 | keep_connections_offset.append(o[keep_ind]) 483 | break 484 | return ( 485 | np.array(nodes), 486 | atom_positions, 487 | np.concatenate(keep_edges), 488 | np.concatenate(keep_connections), 489 | 
np.concatenate(keep_connections_offset), 490 | unitcell, 491 | ) 492 | 493 | 494 | def get_voro_adjacency(atoms: ase.Atoms, min_solid_angle=None): 495 | 496 | # Count how many times each atom is a neighbor of any other atom 497 | voronoi_cells = voro_tessellate(atoms) 498 | adjacency_count = np.zeros((len(atoms), len(atoms)), dtype=int) 499 | assert np.all( 500 | np.array([v.atom_idx for v in voronoi_cells]) == np.arange(len(atoms)) 501 | ) 502 | for cell in voronoi_cells: 503 | for face in cell.faces: 504 | if min_solid_angle and (face.solid_angle < min_solid_angle): 505 | continue 506 | adjacency_count[face.neighbor, cell.atom_idx] += 1 507 | 508 | # Convert adjacency count matrix to edge list 509 | edge_list = np.transpose(np.nonzero(adjacency_count)).astype(np.int32) 510 | graph_dist = graphdistance(edge_list, len(atoms)) 511 | return adjacency_count, graph_dist 512 | 513 | 514 | def select_wfilter(con, filterobj): 515 | if filterobj is None: 516 | for row in con.select(): 517 | yield row 518 | elif isinstance(filterobj, str): 519 | for row in con.select(filterobj): 520 | yield row 521 | else: 522 | for row in con.select(): 523 | if filterobj(row): 524 | yield row 525 | -------------------------------------------------------------------------------- /src/msgnet/defaults.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.contrib import layers 3 | 4 | datadir = "/tmp/" 5 | 6 | 7 | def sh_softplus(x): 8 | """sh_softplus 9 | shifted softplus function log(1+exp(x))-log(2) 10 | 11 | :param x: 12 | """ 13 | return tf.nn.softplus(x) - tf.log(2.0) 14 | 15 | 16 | def mlp(x, hidden_units, activation=tf.tanh, last_activation=tf.identity, **kwargs): 17 | var = x 18 | for i, num_units in enumerate(hidden_units[:-1]): 19 | var = layers.fully_connected(var, num_units, activation_fn=activation, **kwargs) 20 | 21 | var = layers.fully_connected( 22 | var, hidden_units[-1], activation_fn=last_activation, **kwargs 23 | ) 24 | return var 25 | 26 | 27 | nonlinearity = sh_softplus 28 | initializer = tf.variance_scaling_initializer( 29 | scale=1.0, mode="fan_avg", distribution="uniform" 30 | ) 31 | -------------------------------------------------------------------------------- /src/msgnet/msgpassing.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import warnings 3 | import numpy as np 4 | import matplotlib.pyplot as plt 5 | import tensorflow as tf 6 | import os 7 | import msgnet 8 | from tensorflow.contrib import layers 9 | 10 | 11 | def compute_messages( 12 | nodes, 13 | conn, 14 | edges, 15 | message_fn, 16 | act_fn, 17 | include_receiver=True, 18 | include_sender=True, 19 | only_messages=False, 20 | mean_messages=True, 21 | ): 22 | """ 23 | :param nodes: (n_nodes, n_node_features) tensor of nodes, float32. 24 | :param conn: (n_edges, 2) tensor of indices indicating an edge between nodes at those indices, [from, to] int32. 25 | :param edges: (n_edges, n_edge_features) tensor of edge features, float32. 26 | :param message_fn: message function, will be called with two inputs with shapes (n_edges, K*n_node_features), (n_edges,n_edge_features), where K is 2 if include_receiver=True and 1 otherwise, and must return a tensor of size (n_edges, n_output) 27 | :param act_fn: A pointwise activation function applied after the sum. 28 | :param include_receiver: Include receiver node in computation of messages. 
29 | :param include_sender: Include sender node in computation of messages. 30 | :param only_messages: Do not sum up messages 31 | :param mean_messages: If true compute average over messages (instead of sum) 32 | :return: (n_edges, n_output) if only_messages is True, otherwise (n_nodes, n_output) Sum of messages arriving at each node. 33 | """ 34 | n_nodes = tf.shape(nodes)[0] 35 | n_node_features = nodes.get_shape()[1].value 36 | n_edge_features = edges.get_shape()[1].value 37 | 38 | if include_receiver and include_sender: 39 | # Use both receiver and sender node features in message computation 40 | message_inputs = tf.gather(nodes, conn) # n_edges, 2, n_node_features 41 | reshaped = tf.reshape(message_inputs, (-1, 2 * n_node_features)) 42 | elif include_sender: # Only use sender node features (index=0) 43 | message_inputs = tf.gather(nodes, conn[:, 0]) # n_edges, n_node_features 44 | reshaped = message_inputs 45 | elif include_receiver: # Only use receiver node features (index=1) 46 | message_inputs = tf.gather(nodes, conn[:, 1]) # n_edges, n_node_features 47 | reshaped = message_inputs 48 | else: 49 | raise ValueError( 50 | "Messages must include at least one of sender and receiver nodes" 51 | ) 52 | messages = message_fn(reshaped, edges) # n_edges, n_output 53 | 54 | if only_messages: 55 | return messages 56 | 57 | idx_dest = conn[:, 1] 58 | if mean_messages: 59 | # tf.bincount not supported on GPU in TF 1.4, so do this instead 60 | count = tf.unsorted_segment_sum( 61 | tf.ones_like(idx_dest, dtype=tf.float32), idx_dest, n_nodes 62 | ) 63 | count = tf.maximum(count, 1) # Avoid division by zero 64 | msg_pool = tf.unsorted_segment_sum( 65 | messages, idx_dest, n_nodes 66 | ) / tf.expand_dims(count, -1) 67 | else: 68 | msg_pool = tf.unsorted_segment_sum(messages, idx_dest, n_nodes) 69 | return act_fn(msg_pool) 70 | 71 | 72 | def create_dtnn_msg_function(num_outputs, num_hidden_neurons, **kwargs): 73 | """create_dtnn_msg_function 74 | Creates the message function from Deep Tensor Neural Networks (DTNN) 75 | 76 | :param num_outputs: output dimension 77 | :param num_hidden_neurons: number of hidden units 78 | :param **kwargs: 79 | """ 80 | 81 | def func(nodes, edges): 82 | num_node_features = nodes.get_shape()[1].value 83 | num_edge_features = edges.get_shape()[1].value 84 | Wcf = tf.get_variable( 85 | "W_atom_c", 86 | (num_node_features, num_hidden_neurons), 87 | initializer=layers.xavier_initializer(False), 88 | ) 89 | bcf = tf.get_variable( 90 | "b_atom_c", (num_hidden_neurons,), initializer=tf.constant_initializer(0) 91 | ) 92 | Wdf = tf.get_variable( 93 | "W_dist", 94 | (num_edge_features, num_hidden_neurons), 95 | initializer=layers.xavier_initializer(False), 96 | ) 97 | bdf = tf.get_variable( 98 | "b_dist", (num_hidden_neurons,), initializer=tf.constant_initializer(0) 99 | ) 100 | Wfc = tf.get_variable( 101 | "W_hidden_to_c", 102 | (num_hidden_neurons, num_node_features), 103 | initializer=layers.xavier_initializer(False), 104 | ) 105 | 106 | term1 = tf.matmul(nodes, Wcf) + bcf 107 | term2 = tf.matmul(edges, Wdf) + bdf 108 | output = tf.tanh(tf.matmul(term1 * term2, Wfc)) 109 | return output 110 | 111 | return func 112 | 113 | 114 | def create_msg_function(num_outputs, **kwargs): 115 | """create_msg_function 116 | Creates the message function used in the SchNet model 117 | 118 | :param num_outputs: number of output units 119 | :param **kwargs: 120 | """ 121 | 122 | def func(nodes, edges): 123 | tf.add_to_collection("msg_input_nodes", nodes) 124 | 
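        # SchNet-style message function: an MLP on the (possibly Gaussian-expanded)
        # edge features produces a filter ("gates"), which is multiplied elementwise
        # with a bias-free linear transform of the gathered node states ("pre").
        # The add_to_collection calls only expose intermediate tensors for later
        # inspection; they do not change the computation.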
tf.add_to_collection("msg_input_edges", edges) 125 | with tf.variable_scope("gates"): 126 | gates = msgnet.defaults.mlp( 127 | edges, 128 | [num_outputs, num_outputs], 129 | last_activation=msgnet.defaults.nonlinearity, 130 | activation=msgnet.defaults.nonlinearity, 131 | weights_initializer=msgnet.defaults.initializer, 132 | ) 133 | tf.add_to_collection("msg_gates", gates) 134 | with tf.variable_scope("pre"): 135 | pre = layers.fully_connected( 136 | nodes, 137 | num_outputs, 138 | activation_fn=tf.identity, 139 | weights_initializer=msgnet.defaults.initializer, 140 | biases_initializer=None, 141 | **kwargs 142 | ) 143 | tf.add_to_collection("msg_pregates", pre) 144 | output = pre * gates 145 | tf.add_to_collection("msg_outputs", output) 146 | return output 147 | 148 | return func 149 | 150 | 151 | def edge_update(node_states, edge_states): 152 | """edge_update 153 | 154 | :param node_states: Tensor of dimension [number of nodes, node embedding size] 155 | :param edge_states: Tensor of dimension [number of edges, edge embedding size] 156 | """ 157 | edge_states_len = int(edge_states.get_shape()[1]) 158 | nodes_states_len = int(node_states.get_shape()[1]) 159 | combined = tf.concat((node_states, edge_states), axis=1) 160 | new_edge = msgnet.defaults.mlp( 161 | combined, 162 | [nodes_states_len, nodes_states_len // 2], 163 | activation=msgnet.defaults.nonlinearity, 164 | weights_initializer=msgnet.defaults.initializer, 165 | ) 166 | return new_edge 167 | 168 | 169 | class MsgpassingNetwork: 170 | def __init__( 171 | self, 172 | n_node_features=1, 173 | n_edge_features=1, 174 | num_passes=3, 175 | embedding_shape=None, 176 | edge_feature_expand=None, 177 | msg_share_weights=False, 178 | use_edge_updates=False, 179 | readout_fn=None, 180 | edge_output_fn=None, 181 | avg_msg=False, 182 | target_mean=0.0, 183 | target_std=1.0, 184 | ): 185 | """__init__ 186 | 187 | :param n_node_features: Number of input node features 188 | :param n_edge_features: Number of inpute edge features 189 | :param num_passes: Number of interaction pases 190 | :param embedding_shape: Shape of the atomic element embedding e.g. 
(num_species, embedding_size) 191 | :param edge_feature_expand: List of tuples for expanding edge features [(start, step, end)] 192 | :param msg_share_weights: Share weights between the interaction layers 193 | :param use_edge_updates: If true also update edges between interaction passes 194 | :param readout_fn: An instance of the ReadoutFunction class 195 | :param edge_output_fn: An instace of the EdgeOutputFunction class 196 | :param avg_msg: If true interaction messages will be averaged rather than summed 197 | :param target_mean: Normalization constant used for training on appropriate scale 198 | :oaram target_std: Normalization constant used for training on appropriate scale 199 | """ 200 | 201 | # Symbolic input variables 202 | if embedding_shape is not None: 203 | self.sym_nodes = tf.placeholder(np.int32, shape=(None,), name="sym_nodes") 204 | else: 205 | self.sym_nodes = tf.placeholder( 206 | np.float32, shape=(None, n_node_features), name="sym_nodes" 207 | ) 208 | self.sym_edges = tf.placeholder( 209 | np.float32, shape=(None, n_edge_features), name="sym_edges" 210 | ) 211 | self.readout_fn = readout_fn 212 | self.edge_output_fn = edge_output_fn 213 | self.sym_conn = tf.placeholder(np.int32, shape=(None, 2), name="sym_conn") 214 | self.sym_segments = tf.placeholder( 215 | np.int32, shape=(None,), name="sym_segments_map" 216 | ) 217 | self.sym_set_len = tf.placeholder(np.int32, shape=(None,), name="sym_set_len") 218 | 219 | self.input_symbols = { 220 | "nodes": self.sym_nodes, 221 | "edges": self.sym_edges, 222 | "connections": self.sym_conn, 223 | "segments": self.sym_segments, 224 | "set_lengths": self.sym_set_len, 225 | } 226 | 227 | # Setup constants for normalizing/denormalizing graph level outputs 228 | self.sym_target_mean = tf.get_variable( 229 | "target_mean", 230 | dtype=tf.float32, 231 | shape=[], 232 | trainable=False, 233 | initializer=tf.constant_initializer(target_mean), 234 | ) 235 | self.sym_target_std = tf.get_variable( 236 | "target_std", 237 | dtype=tf.float32, 238 | shape=[], 239 | trainable=False, 240 | initializer=tf.constant_initializer(target_std), 241 | ) 242 | 243 | if edge_feature_expand is not None: 244 | init_edges = msgnet.utilities.gaussian_expansion( 245 | self.sym_edges, edge_feature_expand 246 | ) 247 | else: 248 | init_edges = self.sym_edges 249 | 250 | if embedding_shape is not None: 251 | # Setup embedding matrix 252 | stddev = np.sqrt(1.0 / np.sqrt(embedding_shape[1])) 253 | self.species_embedding = tf.Variable( 254 | initial_value=np.random.standard_normal(embedding_shape) * stddev, 255 | trainable=True, 256 | dtype=np.float32, 257 | name="species_embedding_matrix", 258 | ) 259 | hidden_state0 = tf.gather(self.species_embedding, self.sym_nodes) 260 | else: 261 | hidden_state0 = self.sym_nodes 262 | 263 | hidden_state = hidden_state0 264 | 265 | hidden_state_len = int(hidden_state.get_shape()[1]) 266 | 267 | # Setup edge update function 268 | if use_edge_updates: 269 | edge_msg_fn = edge_update 270 | edges = compute_messages( 271 | hidden_state, 272 | self.sym_conn, 273 | init_edges, 274 | edge_msg_fn, 275 | tf.identity, 276 | include_receiver=True, 277 | include_sender=True, 278 | only_messages=True, 279 | ) 280 | else: 281 | edges = init_edges 282 | 283 | # Setup interaction messages 284 | msg_fn = create_msg_function(hidden_state_len) 285 | act_fn = tf.identity 286 | for i in range(num_passes): 287 | if msg_share_weights: 288 | scope_suffix = "" 289 | reuse = i > 0 290 | else: 291 | scope_suffix = "%d" % i 292 | reuse = False 293 | with 
tf.variable_scope("msg" + scope_suffix, reuse=reuse): 294 | sum_msg = compute_messages( 295 | hidden_state, 296 | self.sym_conn, 297 | edges, 298 | msg_fn, 299 | act_fn, 300 | include_receiver=False, 301 | mean_messages=avg_msg, 302 | ) 303 | with tf.variable_scope("update" + scope_suffix, reuse=reuse): 304 | hidden_state += msgnet.defaults.mlp( 305 | sum_msg, 306 | [hidden_state_len, hidden_state_len], 307 | activation=msgnet.defaults.nonlinearity, 308 | weights_initializer=msgnet.defaults.initializer, 309 | ) 310 | with tf.variable_scope("edge_update" + scope_suffix, reuse=reuse): 311 | if use_edge_updates and (i < (num_passes - 1)): 312 | edges = compute_messages( 313 | hidden_state, 314 | self.sym_conn, 315 | edges, 316 | edge_msg_fn, 317 | tf.identity, 318 | include_receiver=True, 319 | include_sender=True, 320 | only_messages=True, 321 | ) 322 | 323 | nodes_out = tf.identity(hidden_state, name="nodes_out") 324 | 325 | self.nodes_out = nodes_out 326 | 327 | # Setup readout function 328 | with tf.variable_scope("readout_edge"): 329 | if self.edge_output_fn is not None: 330 | self.edge_out = edge_output_fn(edges) 331 | with tf.variable_scope("readout_graph"): 332 | if self.readout_fn is not None: 333 | graph_out = self.readout_fn(nodes_out, self.sym_segments) 334 | self.graph_out_normalized = tf.identity( 335 | graph_out, name="graph_out_normalized" 336 | ) 337 | 338 | # Denormalize graph_out for making predictions on original scale 339 | if self.readout_fn is not None: 340 | if self.readout_fn.is_sum: 341 | mean = self.sym_target_mean * tf.expand_dims( 342 | tf.cast(self.sym_set_len, tf.float32), -1 343 | ) 344 | else: 345 | mean = self.sym_target_mean 346 | self.graph_out = tf.add( 347 | graph_out * self.sym_target_std, mean, name="graph_out" 348 | ) 349 | 350 | self.saver = tf.train.Saver(keep_checkpoint_every_n_hours=24, max_to_keep=3) 351 | 352 | def save(self, session, destination, global_step): 353 | self.saver.save(session, destination, global_step=global_step) 354 | 355 | def load(self, session, path): 356 | self.saver.restore(session, path) 357 | 358 | def get_nodes_out(self): 359 | return self.nodes_out 360 | 361 | def get_graph_out(self): 362 | if self.readout_fn is None: 363 | raise NotImplementedError("No readout function given") 364 | return self.graph_out 365 | 366 | def get_graph_out_normalized(self): 367 | if self.readout_fn is None: 368 | raise NotImplementedError("No readout function given") 369 | return self.graph_out_normalized 370 | 371 | def get_normalization(self): 372 | return self.sym_target_mean, self.sym_target_std 373 | 374 | def get_readout_function(self): 375 | return self.readout_fn 376 | 377 | def get_edges_out(self): 378 | if self.edge_output_fn is None: 379 | raise NotImplementedError("No edges output network given") 380 | return self.edge_out 381 | 382 | def get_input_symbols(self): 383 | return self.input_symbols 384 | -------------------------------------------------------------------------------- /src/msgnet/readout.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | import msgnet 4 | 5 | 6 | def set2set(nodes_in, segments, num_passes=2): 7 | """set2set [https://arxiv.org/abs/1511.06391] 8 | Based on implementation from DeepChem (PandeLab) 9 | 10 | :param nodes_in: (num_nodes, n_node_features) tensor 11 | :param segments: (num_nodes, ) tensor, places each node into (batch_idx,), (generated by set_len_to_segments function) 12 | :returns: (batch_size, 
output_size) tensor 13 | """ 14 | batch_size = segments[-1] + 1 15 | num_features = nodes_in.get_shape()[1].value 16 | assert num_features is not None 17 | bias_init = np.concatenate( 18 | ( 19 | np.zeros(num_features, dtype=np.float32), 20 | np.ones(num_features, dtype=np.float32), 21 | np.zeros(num_features, dtype=np.float32), 22 | np.zeros(num_features, dtype=np.float32), 23 | ) 24 | ) 25 | lstm_b = tf.get_variable("b_lstm", initializer=bias_init) 26 | lstm_U = tf.get_variable( 27 | "U_lstm", 28 | (2 * num_features, 4 * num_features), 29 | initializer=tf.contrib.layers.xavier_initializer(False), 30 | ) 31 | 32 | def lstm_step(h, c): 33 | z = tf.nn.xw_plus_b(h, lstm_U, lstm_b) 34 | i = tf.nn.sigmoid(z[:, :num_features]) 35 | f = tf.nn.sigmoid(z[:, num_features : 2 * num_features]) 36 | o = tf.nn.sigmoid(z[:, 2 * num_features : 3 * num_features]) 37 | z3 = z[:, 3 * num_features :] 38 | c_out = f * c + i * tf.nn.tanh(z3) 39 | h_out = o * tf.nn.tanh(c_out) 40 | 41 | return h_out, c_out 42 | 43 | q = tf.zeros((batch_size, num_features)) 44 | lstm_c = tf.zeros((batch_size, num_features)) 45 | 46 | mi = nodes_in 47 | for i in range(num_passes): 48 | e_node = tf.reduce_sum(mi * tf.gather(q, segments), -1) # (total_nodes, ) 49 | # Subtract maximum for numerical stability in softmax 50 | e_max = tf.segment_max( 51 | e_node, segments 52 | ) # (bs, ) # Does not work on GPU in TF 1.4 53 | e_max_e = tf.gather(e_max, segments) # (total_nodes, ) 54 | e_node -= e_max_e 55 | # softmax 56 | # Use unsorted_segment_sum because segment_sum does not work on GPU in TF 1.4 57 | a_node = tf.exp(e_node) / tf.gather( 58 | tf.unsorted_segment_sum(tf.exp(e_node), segments, segments[-1] + 1), 59 | segments, 60 | ) # (total_nodes,) 61 | a_node = tf.expand_dims(a_node, -1) # (total_nodes, 1) 62 | # Use unsorted_segment_sum because segment_sum does not work on GPU in TF 1.4 63 | r = tf.unsorted_segment_sum( 64 | mi * a_node, segments, segments[-1] + 1 65 | ) # (bs, num_features) 66 | q_star = tf.concat((q, r), axis=1) 67 | q, lstm_c = lstm_step(q_star, lstm_c) 68 | 69 | num_features = q_star.get_shape()[1].value 70 | return q_star 71 | 72 | 73 | class EdgeOutputFunction: 74 | def __init__(self, output_size=1): 75 | self.output_size = output_size 76 | 77 | def __call__(self, edges): 78 | edge_out = tf.identity( 79 | msgnet.defaults.mlp( 80 | edges, 81 | [edges.get_shape()[1].value, self.output_size], 82 | activation=msgnet.defaults.nonlinearity, 83 | ), 84 | "edge_out", 85 | ) 86 | return edge_out 87 | 88 | 89 | class ReadoutFunction: 90 | """ReadoutFunction 91 | Base class readout function """ 92 | 93 | is_sum = True 94 | 95 | def __init__(self, output_size=1): 96 | self.output_size = output_size 97 | 98 | def __call__(self, nodes, segments): 99 | raise NotImplementedError() 100 | 101 | 102 | class ReadoutAvgscalar(ReadoutFunction): 103 | is_sum = False 104 | 105 | def __call__(self, nodes, segments): 106 | nodes_size = int(nodes.get_shape()[1]) 107 | pre_sum_fn = lambda x: msgnet.defaults.mlp( 108 | x, 109 | [nodes_size // 2, self.output_size], 110 | activation=msgnet.defaults.nonlinearity, 111 | weights_initializer=msgnet.defaults.initializer, 112 | ) 113 | pre_sum = tf.identity(pre_sum_fn(nodes), name="node_contribution") 114 | tf.add_to_collection("node_contribution", pre_sum) 115 | graph_out = tf.segment_mean(pre_sum, segments) 116 | return graph_out 117 | 118 | 119 | class ReadoutSumscalar(ReadoutFunction): 120 | is_sum = True 121 | 122 | def __call__(self, nodes, segments): 123 | nodes_size = 
int(nodes.get_shape()[1]) 124 | pre_sum_fn = lambda x: msgnet.defaults.mlp( 125 | x, 126 | [nodes_size // 2, self.output_size], 127 | activation=msgnet.defaults.nonlinearity, 128 | weights_initializer=msgnet.defaults.initializer, 129 | ) 130 | pre_sum = tf.identity(pre_sum_fn(nodes), name="node_contribution") 131 | tf.add_to_collection("node_contribution", pre_sum) 132 | graph_out = tf.segment_sum(pre_sum, segments) 133 | return graph_out 134 | 135 | 136 | class ReadoutSumvector(ReadoutFunction): 137 | is_sum = True 138 | 139 | def __call__(self, nodes, segments): 140 | nodes_size = int(nodes.get_shape()[1]) 141 | pre_sum_fn = lambda x: msgnet.defaults.mlp( 142 | x, 143 | [nodes_size, nodes_size // 2], 144 | activation=msgnet.defaults.nonlinearity, 145 | weights_initializer=msgnet.defaults.initializer, 146 | ) 147 | post_sum_fn = lambda x: msgnet.defaults.mlp( 148 | x, 149 | [nodes_size // 2, self.output_size], 150 | activation=msgnet.defaults.nonlinearity, 151 | weights_initializer=msgnet.defaults.initializer, 152 | ) 153 | pre_sum = tf.identity(pre_sum_fn(nodes), name="node_contribution") 154 | tf.add_to_collection("node_contribution", pre_sum) 155 | post_sum = tf.segment_sum(pre_sum, segments) 156 | graph_out = post_sum_fn(post_sum) 157 | return graph_out 158 | 159 | 160 | class ReadoutAvgvector(ReadoutFunction): 161 | is_sum = True 162 | 163 | def __call__(self, nodes, segments): 164 | nodes_size = int(nodes.get_shape()[1]) 165 | pre_sum_fn = lambda x: msgnet.defaults.mlp( 166 | x, 167 | [nodes_size, nodes_size // 2], 168 | activation=msgnet.defaults.nonlinearity, 169 | weights_initializer=msgnet.defaults.initializer, 170 | ) 171 | post_sum_fn = lambda x: msgnet.defaults.mlp( 172 | x, 173 | [nodes_size // 2, self.output_size], 174 | activation=msgnet.defaults.nonlinearity, 175 | weights_initializer=msgnet.defaults.initializer, 176 | ) 177 | pre_sum = tf.identity(pre_sum_fn(nodes), name="node_contribution") 178 | tf.add_to_collection("node_contribution", pre_sum) 179 | post_sum = tf.segment_mean(pre_sum, segments) 180 | graph_out = post_sum_fn(post_sum) 181 | return graph_out 182 | 183 | 184 | class ReadoutSet2set(ReadoutFunction): 185 | is_sum = False 186 | 187 | def __call__(self, nodes, segments): 188 | nodes_size = int(nodes.get_shape()[1]) 189 | graph_out = set2set(nodes, segments, num_passes=2) 190 | graph_out = msgnet.defaults.mlp(graph_out, [nodes_size // 2, self.output_size]) 191 | return graph_out 192 | -------------------------------------------------------------------------------- /src/msgnet/train.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | import msgnet 4 | 5 | 6 | class Trainer: 7 | def __init__(self, model, batchloader, initial_lr=1e-4, batch_size=32): 8 | self.model = model 9 | self.sym_learning_rate = tf.placeholder( 10 | tf.float32, shape=[], name="learning_rate" 11 | ) 12 | 13 | self.initial_lr = initial_lr 14 | 15 | self.input_symbols = self.setup_input_symbols() 16 | self.cost = self.setup_total_cost() 17 | self.train_op = self.setup_train_op() 18 | self.metric_tensors = self.setup_metrics() 19 | self.batchloader = batchloader 20 | self.batch_size = batch_size 21 | 22 | def get_learning_rate(self, step): 23 | learning_rate = self.initial_lr * (0.96 ** (step / 100000)) 24 | return learning_rate 25 | 26 | def setup_metrics(self): 27 | return {} 28 | 29 | def setup_input_symbols(self): 30 | input_symbols = self.model.get_input_symbols() 31 | return input_symbols 32 | 33 | 
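    # The training op below is a single Adam step. The learning rate is fed in
    # through the sym_learning_rate placeholder, so the schedule implemented in
    # get_learning_rate is evaluated in ordinary Python code at every step.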
def setup_train_op(self): 34 | optimizer = tf.train.AdamOptimizer(self.sym_learning_rate) 35 | gradients = optimizer.compute_gradients(self.cost) 36 | train_op = optimizer.apply_gradients(gradients, name="train_op") 37 | return train_op 38 | 39 | def setup_total_cost(self): 40 | raise NotImplementedError() 41 | 42 | def step(self, session, step): 43 | input_data = self.batchloader.get_train_batch(self.batch_size) 44 | feed_dict = {} 45 | for key in self.input_symbols.keys(): 46 | feed_dict[self.input_symbols[key]] = input_data[key] 47 | feed_dict[self.sym_learning_rate] = self.get_learning_rate(step) 48 | session.run([self.train_op], feed_dict=feed_dict) 49 | 50 | 51 | class GraphOutputTrainer(Trainer): 52 | def setup_input_symbols(self): 53 | input_symbols = self.model.get_input_symbols() 54 | output_size = self.model.get_graph_out().get_shape()[1].value 55 | 56 | self.sym_edge_targets = tf.placeholder( 57 | tf.float32, shape=(None, 1), name="sym_edge_targets" 58 | ) 59 | self.sym_graph_targets = tf.placeholder( 60 | tf.float32, shape=(None, output_size), name="sym_graph_targets" 61 | ) 62 | 63 | input_symbols.update({"graph_targets": self.sym_graph_targets}) 64 | 65 | return input_symbols 66 | 67 | def setup_metrics(self): 68 | graph_error = self.model.get_graph_out() - self.input_symbols["graph_targets"] 69 | 70 | metric_tensors = {"graph_error": graph_error} 71 | return metric_tensors 72 | 73 | def setup_total_cost(self): 74 | sym_graph_targets = self.input_symbols["graph_targets"] 75 | graph_cost = self.get_cost_graph_target(sym_graph_targets, self.model) 76 | total_cost = graph_cost 77 | return total_cost 78 | 79 | def evaluate_metrics(self, session, datahandler, prefix=""): 80 | target_mae = 0 81 | target_mse = 0 82 | num_graphs = 0 83 | for input_data in datahandler.get_test_batches(self.batch_size): 84 | feed_dict = {} 85 | for key in self.input_symbols.keys(): 86 | feed_dict[self.input_symbols[key]] = input_data[key] 87 | syms = [self.metric_tensors["graph_error"]] 88 | graph_error, = session.run(syms, feed_dict=feed_dict) 89 | target_mae += np.sum(np.abs(graph_error)) 90 | target_mse += np.sum(np.square(graph_error)) 91 | num_graphs += graph_error.shape[0] 92 | 93 | if prefix: 94 | prefix += "_" 95 | metrics = { 96 | prefix + "mae": target_mae / num_graphs, 97 | prefix + "rmse": np.sqrt(target_mse / num_graphs), 98 | } 99 | 100 | return metrics 101 | 102 | @staticmethod 103 | def get_cost_graph_target(sym_graph_target, model): 104 | target_mean, target_std = model.get_normalization() 105 | sym_set_len = model.get_input_symbols()["set_lengths"] 106 | target_normalizing = 1.0 / target_std 107 | if model.get_readout_function().is_sum: 108 | # When target is a sum of K numbers we normalize the target to zero mean and variance K 109 | graph_target_normalized = ( 110 | sym_graph_target 111 | - target_mean * tf.cast(tf.expand_dims(sym_set_len, -1), tf.float32) 112 | ) * target_normalizing 113 | else: 114 | # When target is an average normalize to zero mean and unit variance 115 | graph_target_normalized = ( 116 | sym_graph_target - target_mean 117 | ) * target_normalizing 118 | 119 | graph_cost = tf.reduce_mean( 120 | (model.get_graph_out_normalized() - graph_target_normalized) ** 2, 121 | name="graph_cost", 122 | ) 123 | 124 | return graph_cost 125 | 126 | 127 | class EdgeOutputTrainer(GraphOutputTrainer): 128 | def __init__( 129 | self, 130 | model, 131 | batchloader, 132 | edge_output_expand, 133 | initial_lr=1e-4, 134 | edge_cost_weight=0.5, 135 | ): 136 | self.edge_cost_weight = 
edge_cost_weight 137 | self.edge_output_expand = edge_output_expand 138 | 139 | super().__init__(model, batchloader, initial_lr=initial_lr) 140 | 141 | def setup_input_symbols(self): 142 | input_symbols = self.model.get_input_symbols() 143 | output_size = self.model.get_graph_out().get_shape()[1].value 144 | 145 | self.sym_edge_targets = tf.placeholder( 146 | tf.float32, shape=(None, 1), name="sym_edge_targets" 147 | ) 148 | self.sym_graph_targets = tf.placeholder( 149 | tf.float32, shape=(None, output_size), name="sym_graph_targets" 150 | ) 151 | 152 | input_symbols.update( 153 | { 154 | "edges_targets": self.sym_edge_targets, 155 | "graph_targets": self.sym_graph_targets, 156 | } 157 | ) 158 | 159 | return input_symbols 160 | 161 | def setup_metrics(self): 162 | graph_error = self.model.get_graph_out() - self.input_symbols["graph_targets"] 163 | sym_edge_targets = self.input_symbols["edges_targets"] 164 | edge_error = self.get_cost_edge_target(sym_edge_targets, self.model) 165 | 166 | metric_tensors = {"graph_error": graph_error, "edge_error": edge_error} 167 | return metric_tensors 168 | 169 | def setup_total_cost(self): 170 | sym_graph_targets = self.input_symbols["graph_targets"] 171 | sym_edge_targets = self.input_symbols["edges_targets"] 172 | graph_cost = self.get_cost_graph_target(sym_graph_targets, self.model) 173 | edge_cost = tf.reduce_mean( 174 | self.get_cost_edge_target(sym_edge_targets, self.model), name="edge_cost" 175 | ) 176 | total_cost = ( 177 | 1 - self.edge_cost_weight 178 | ) * graph_cost + self.edge_cost_weight * edge_cost 179 | return total_cost 180 | 181 | def evaluate_metrics(self, session, datahandler, prefix=""): 182 | target_mae = 0 183 | target_mse = 0 184 | edge_kl = 0 185 | num_graphs = 0 186 | num_edges = 0 187 | for input_data in datahandler.get_test_batches(self.batch_size): 188 | feed_dict = {} 189 | for key in self.input_symbols.keys(): 190 | feed_dict[self.input_symbols[key]] = input_data[key] 191 | syms = [ 192 | self.metric_tensors["graph_error"], 193 | self.metric_tensors["edge_error"], 194 | ] 195 | graph_error, edge_error = session.run(syms, feed_dict=feed_dict) 196 | edge_kl += np.sum(edge_error) 197 | target_mae += np.sum(np.abs(graph_error)) 198 | target_mse += np.sum(np.square(graph_error)) 199 | num_graphs += graph_error.shape[0] 200 | num_edges += edge_error.shape[0] 201 | 202 | if prefix: 203 | prefix += "_" 204 | metrics = { 205 | prefix + "mae": target_mae / num_graphs, 206 | prefix + "rmse": np.sqrt(target_mse / num_graphs), 207 | prefix + "kl": edge_kl / num_edges, 208 | } 209 | 210 | return metrics 211 | 212 | def get_cost_edge_target(self, sym_edge_target, model): 213 | edge_expanded = msgnet.utilities.gaussian_expansion( 214 | sym_edge_target, self.edge_output_expand 215 | ) 216 | edge_expanded_normalised = tf.divide( 217 | edge_expanded, tf.reduce_sum(edge_expanded, axis=1, keepdims=True) 218 | ) 219 | p_entropy = msgnet.utilities.entropy(edge_expanded_normalised) 220 | edge_cost = ( 221 | tf.nn.softmax_cross_entropy_with_logits_v2( 222 | labels=edge_expanded_normalised, logits=model.get_edges_out() 223 | ) 224 | - p_entropy 225 | ) 226 | 227 | return edge_cost 228 | -------------------------------------------------------------------------------- /src/msgnet/utilities.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import sklearn 4 | import sklearn.model_selection 5 | import math 6 | import os 7 | import scipy 8 | import tensorflow 
as tf 9 | import itertools 10 | 11 | 12 | def tf_compute_distances( 13 | positions, unitcells, connections, connections_offset, segments 14 | ): 15 | """tf_compute_distances 16 | 17 | :param positions: (N_atoms, 3) tensor 18 | :param unitcells: (N_structures, 3, 3) tensor 19 | :param connections: (N_edges, 2) tensor 20 | :param connections_offset: (N_edges, 2, 3) tensor 21 | :param segments: N_atoms tensor 22 | """ 23 | ## Compute absolute positions 24 | 25 | # Gather unit cells. We can assume that the unit cell is the same for sender and receiver 26 | unitcell_inds = tf.gather(segments, connections[:, 0]) # N_edges 27 | cells = tf.gather(unitcells, unitcell_inds) # N_edges, 3, 3 28 | offsets = tf.matmul(connections_offset, cells) # N_edges, 2, 3 29 | pos = tf.gather(positions, connections) # N_edges, 2, 3 30 | abs_pos = pos + offsets 31 | diffs = abs_pos[:, 0, :] - abs_pos[:, 1, :] # N_edges, 3 32 | dist = tf.sqrt(tf.reduce_sum(tf.square(diffs), axis=1, keepdims=True)) # N_edges, 1 33 | return dist 34 | 35 | 36 | def entropy(p): 37 | """entropy 38 | Compute entropy of normalised discrete probability distribution 39 | :param p: (batch_size, Np) tensor 40 | """ 41 | ent = tf.where(p > np.finfo(np.float32).eps, -p * tf.log(p), tf.zeros_like(p)) 42 | ent = tf.reduce_sum(ent, axis=1) 43 | return ent 44 | 45 | 46 | def gaussian_expansion(input_x, expand_params): 47 | """gaussian_expansion 48 | 49 | :param input_x: (num_edges, n_features) tensor 50 | :param expand_params: list of None or (start, step, stop) tuples 51 | :returns: (num_edges, ``ceil((stop - start)/step)``) tensor 52 | """ 53 | feat_list = tf.unstack(input_x, axis=1) 54 | expanded_list = [] 55 | for step_tuple, feat in itertools.zip_longest(expand_params, feat_list): 56 | assert feat is not None, "Too many expansion parameters given" 57 | if step_tuple: 58 | start, step, stop = step_tuple 59 | feat_expanded = tf.expand_dims(feat, axis=1) 60 | sigma = step 61 | mu = np.arange(start, stop, step) 62 | expanded_list.append( 63 | tf.exp(-((feat_expanded - mu) ** 2) / (2.0 * sigma ** 2)) 64 | ) 65 | else: 66 | expanded_list.append(tf.expand_dims(feat, 1)) 67 | return tf.concat(expanded_list, axis=1, name="expanded_edges") 68 | -------------------------------------------------------------------------------- /src/scripts/get_matproj.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import os 3 | import math 4 | import numpy as np 5 | import pymatgen.ext.matproj 6 | import ase 7 | import ase.db 8 | import msgnet 9 | 10 | # Available properties can be found here 11 | # https://materialsproject.org/wiki/index.php/The_Materials_API#query 12 | 13 | 14 | def get_all_results(): 15 | with pymatgen.ext.matproj.MPRester(MATPROJ_API_KEY) as r: 16 | mp_ids = r.query({}, ["material_id"]) 17 | material_ids = [m["material_id"] for m in mp_ids] 18 | print(len(material_ids)) 19 | chunk_size = 1000 20 | sublists = [ 21 | material_ids[i : i + chunk_size] 22 | for i in range(0, len(material_ids), chunk_size) 23 | ] 24 | for i, sublist in enumerate(sublists): 25 | results = r.query( 26 | {"material_id": {"$in": sublist}}, 27 | [ 28 | "formation_energy_per_atom", 29 | "final_structure", 30 | "e_above_hull", 31 | "material_id", 32 | "icsd_id", 33 | "band_gap", 34 | "spacegroup", 35 | ], 36 | ) 37 | print("Downloaded %d/%d" % (i + 1, len(sublists))) 38 | for res in results: 39 | yield res 40 | 41 | 42 | def generate_folds(num_entries, num_splits, ceil_first=True, seed=42): 43 | """generate_folds 44 | 45 |
:param num_entries: 46 | :param num_splits: 47 | :param ceil_first: if num_splits does not divide len(entry_list), 48 | first groups will be larger if ceil_first=True, otherwise they will be smaller 49 | """ 50 | rng = np.random.RandomState(seed) 51 | entries_left = num_entries 52 | fold_id = np.array([], dtype=int) 53 | for i in range(num_splits): 54 | if ceil_first: 55 | num_elements = math.ceil(entries_left / (num_splits - i)) 56 | else: 57 | num_elements = math.floor(entries_left / (num_splits - i)) 58 | print("fold %d: %d elements" % (i, num_elements)) 59 | entries_left -= num_elements 60 | fold_id = np.concatenate([fold_id, np.ones(num_elements, dtype=int) * i]) 61 | 62 | return rng.permutation(fold_id) 63 | 64 | 65 | def get_all_atoms(): 66 | all_atoms = [] 67 | is_common_arr = [] 68 | for res in get_all_results(): 69 | structure = res["final_structure"] 70 | delta_e = res["formation_energy_per_atom"] 71 | mp_id = res["material_id"] 72 | e_above_hull = res["e_above_hull"] 73 | icsd_id = res["icsd_id"] 74 | band_gap = res["band_gap"] 75 | sg_number = res["spacegroup"]["number"] 76 | sg_symbol = res["spacegroup"]["symbol"] 77 | sg_pointgroup = res["spacegroup"]["point_group"] 78 | sg_crystal_system = res["spacegroup"]["crystal_system"] 79 | 80 | cell = structure.lattice.matrix 81 | atomic_numbers = structure.atomic_numbers 82 | cart_coord = structure.cart_coords 83 | atoms = ase.Atoms( 84 | positions=cart_coord, 85 | numbers=atomic_numbers, 86 | cell=cell, 87 | pbc=[True, True, True], 88 | ) 89 | 90 | # Check if material contains noble gas 91 | numbers = set(atoms.get_atomic_numbers()) 92 | is_common = True 93 | for element in ["He", "Ne", "Ar", "Kr", "Xe"]: 94 | if ase.atom.atomic_numbers[element] in numbers: 95 | is_common = False 96 | break 97 | 98 | key_val_pairs = { 99 | "delta_e": delta_e, 100 | "mp_id": mp_id, 101 | "e_above_hull": e_above_hull, 102 | "band_gap": band_gap, 103 | "sg_number": sg_number, 104 | "sg_crystal_system": sg_crystal_system, 105 | } 106 | if icsd_id is not None: 107 | if isinstance(icsd_id, list): 108 | key_val_pairs["icsd_id"] = icsd_id[0] 109 | else: 110 | key_val_pairs["icsd_id"] = icsd_id 111 | 112 | all_atoms.append((atoms, key_val_pairs)) 113 | is_common_arr.append(is_common) 114 | return all_atoms, is_common_arr 115 | 116 | 117 | def main(): 118 | all_atoms, is_common_arr = get_all_atoms() 119 | common_atoms = [a for a, c in zip(all_atoms, is_common_arr) if c] 120 | uncommon_atoms = [a for a, c in zip(all_atoms, is_common_arr) if not c] 121 | 122 | fold_id = generate_folds(len(common_atoms), 5) 123 | 124 | print("Writing to DB") 125 | with ase.db.connect( 126 | os.path.join(msgnet.defaults.datadir, "matproj.db"), append=False 127 | ) as db: 128 | for atom_keyval, fold in zip(common_atoms, fold_id): 129 | atom = atom_keyval[0] 130 | keyval = atom_keyval[1] 131 | keyval["fold"] = int(fold) 132 | db.write(atom, key_value_pairs=keyval) 133 | for atom, keyval in uncommon_atoms: 134 | keyval["fold"] = "None" 135 | db.write(atom, key_value_pairs=keyval) 136 | 137 | 138 | if __name__ == "__main__": 139 | import sys 140 | 141 | try: 142 | MATPROJ_API_KEY = sys.argv[1] 143 | except IndexError: 144 | sys.exit("usage: python get_matproj.py MATPROJ_API_KEY") 145 | main() 146 | -------------------------------------------------------------------------------- /src/scripts/get_oqmd.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python2 2 | from __future__ import print_function 3 | from __future__ import
division 4 | import pdb 5 | import ase 6 | import ase.db 7 | import math 8 | import qmpy 9 | import sys 10 | import subprocess 11 | import numpy as np 12 | from django.db.models import F 13 | 14 | uniques = qmpy.Formation.objects.filter(entry__id=F("entry__duplicate_of__id")) 15 | slice_size = 20000 16 | db_path_tmp = "oqmd_tmp.db" 17 | db_path_final = "oqmd12.db" 18 | 19 | 20 | def write_slice(slice_id): 21 | if slice_id == 0: 22 | # We are writing a new database 23 | append = False 24 | else: 25 | # we are adding a slice to an existing database 26 | append = True 27 | with ase.db.connect(db_path_tmp, append=append) as ase_db: 28 | for i, formation in enumerate( 29 | uniques[(slice_id * slice_size) : ((slice_id + 1) * slice_size)] 30 | ): 31 | properties_dict = {} 32 | entry = formation.entry 33 | oqmd_id = entry.id 34 | # try: 35 | # spacegroup = entry.spacegroup.number 36 | # except AttributeError: 37 | # print("No spacegroup id for %d %d" % (i, oqmd_id)) 38 | # spacegroup = "None" 39 | try: 40 | properties_dict["prototype"] = entry.prototype.name 41 | except AttributeError: 42 | pass 43 | 44 | try: 45 | atomic_numbers = entry.structure.atomic_numbers 46 | cell = entry.structure.cell 47 | cart_coords = entry.structure.cartesian_coords 48 | 49 | atoms = ase.Atoms( 50 | positions=cart_coords, numbers=atomic_numbers, cell=cell, pbc=True 51 | ) 52 | except Exception as e: 53 | pdb.set_trace() 54 | 55 | properties_dict["delta_e"] = entry.energy 56 | properties_dict["oqmd_id"] = oqmd_id 57 | if entry.label: 58 | properties_dict["label"] = entry.label 59 | 60 | ase_db.write(atoms, **properties_dict) 61 | 62 | 63 | def generate_folds(num_entries, num_splits, ceil_first=True): 64 | """generate_folds 65 | 66 | :param num_entries: 67 | :param num_splits: 68 | :param ceil_first: if num_splits does not divide len(entry_list), 69 | first groups will be larger if ceil_first=True, otherwise they will be smaller 70 | """ 71 | entries_left = num_entries 72 | fold_id = np.array([], dtype=int) 73 | for i in range(num_splits): 74 | if ceil_first: 75 | num_elements = int(math.ceil(entries_left / (num_splits - i))) 76 | else: 77 | num_elements = int(math.floor(entries_left / (num_splits - i))) 78 | print("fold %d: %d elements" % (i, num_elements)) 79 | entries_left -= num_elements 80 | fold_id = np.concatenate([fold_id, np.ones(num_elements, dtype=int) * i]) 81 | 82 | return np.random.permutation(fold_id) 83 | 84 | 85 | def write_folds(input_db, output_db): 86 | # Set random seed 87 | np.random.seed(31) 88 | 89 | is_common_arr = [] 90 | is_icsd_arr = [] 91 | with ase.db.connect(input_db) as asedb: 92 | for row in asedb.select(): 93 | is_common = True 94 | if row.delta_e > 5.0: 95 | is_common = False 96 | numbers = set(row.numbers) 97 | for element in ["He", "Ne", "Ar", "Kr", "Xe"]: 98 | if ase.atom.atomic_numbers[element] in numbers: 99 | is_common = False 100 | break 101 | try: 102 | label = row.label 103 | is_icsd = bool("icsd" in label.lower()) 104 | except AttributeError: 105 | is_icsd = False 106 | is_common_arr.append(is_common) 107 | is_icsd_arr.append(is_icsd) 108 | 109 | is_common = np.array(is_common_arr, dtype=bool) 110 | is_icsd = np.array(is_icsd_arr, dtype=bool) 111 | folds = np.ones_like(is_common, dtype=int) * (-1) 112 | 113 | common_icsd = np.logical_and(is_common, is_icsd) 114 | common_nonicsd = np.logical_and(is_common, np.logical_not(is_icsd)) 115 | 116 | icsd_folds = generate_folds(np.count_nonzero(common_icsd), 5, ceil_first=True) 117 | others_folds = generate_folds( 118 | 
np.count_nonzero(common_nonicsd), 5, ceil_first=False 119 | ) 120 | folds[common_icsd] = icsd_folds 121 | folds[common_nonicsd] = others_folds 122 | 123 | with ase.db.connect(output_db, append=False) as foldsdb: 124 | for i, row in enumerate(asedb.select()): 125 | fold_id = folds[i] 126 | if fold_id < 0: 127 | fold_name = "None" 128 | else: 129 | fold_name = fold_id 130 | key_val_pairs = row.key_value_pairs 131 | key_val_pairs["fold"] = fold_name 132 | foldsdb.write(row.toatoms(), key_value_pairs=key_val_pairs) 133 | 134 | 135 | def main(): 136 | ## qmpy leaks memory (django cache?) so we need to process the database in slices 137 | try: 138 | slice_id = int(sys.argv[1]) 139 | is_master = False 140 | except IndexError: 141 | is_master = True 142 | 143 | if is_master: 144 | num_slices = int(math.ceil(float(len(uniques)) / float(slice_size))) 145 | for slice_id in range(num_slices): 146 | subprocess.check_call(["python2", "get_oqmd.py", str(slice_id)]) 147 | print("wrote %d/%d" % (slice_id + 1, num_slices)) 148 | print("writing folds") 149 | write_folds(db_path_tmp, db_path_final) 150 | else: 151 | write_slice(slice_id) 152 | return 0 153 | 154 | 155 | if __name__ == "__main__": 156 | main() 157 | -------------------------------------------------------------------------------- /src/scripts/get_qm9.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env/ python3 2 | import os 3 | import numpy as np 4 | import ase 5 | import ase.db 6 | import ase.data 7 | import msgnet 8 | import tarfile 9 | import requests 10 | 11 | 12 | def string_convert(string): 13 | try: 14 | return int(string) 15 | except ValueError: 16 | pass 17 | try: 18 | return float(string) 19 | except ValueError: 20 | pass 21 | return string.strip() 22 | 23 | 24 | class Molecule: 25 | def __init__(self, num_atoms): 26 | self.z = np.zeros(num_atoms, dtype=np.int32) 27 | self.coord = np.zeros((num_atoms, 3)) 28 | 29 | 30 | def download(url, dest): 31 | response = requests.get(url) 32 | with open(dest, "wb") as f: 33 | for chunk in response.iter_content(): 34 | f.write(chunk) 35 | 36 | 37 | def tar_to_xyz(tarpath, dest): 38 | tar = tarfile.open(tarpath, mode="r:bz2") 39 | with open(os.path.join(msgnet.defaults.datadir, "qm9.xyz"), "wb") as f: 40 | for tarinfo in tar: 41 | tarf = tar.extractfile(tarinfo) 42 | f.write(tarf.read()) 43 | 44 | 45 | def load_xyz_file(filename): 46 | predefined_keys = """tag 47 | index 48 | A 49 | B 50 | C 51 | mu 52 | alpha 53 | homo 54 | lumo 55 | gap 56 | r2 57 | zpve 58 | U0 59 | U 60 | H 61 | G 62 | Cv""".split() 63 | STATE_READ_NUMBER = 0 64 | STATE_READ_COMMENT = 1 65 | STATE_READ_ENTRY = 2 66 | STATE_READ_FREQUENCY = 3 67 | STATE_READ_SMILES = 4 68 | STATE_READ_INCHI = 5 69 | STATE_FAILURE = 6 70 | 71 | state = STATE_READ_NUMBER 72 | entries_read = 0 73 | cur_desc = None 74 | 75 | with open(filename, "r") as f: 76 | for line_no, line in enumerate(f): 77 | try: 78 | if state == STATE_READ_NUMBER: 79 | entries_to_read = int(line) 80 | cur_desc = Molecule(entries_to_read) 81 | entries_read = 0 82 | state = STATE_READ_COMMENT 83 | elif state == STATE_READ_COMMENT: 84 | # Read comment as whitespace separated values 85 | for key, value in zip(predefined_keys, line.split()): 86 | if hasattr(cur_desc, key): 87 | raise KeyError( 88 | "Molecule already contains property %s" % key 89 | ) 90 | else: 91 | setattr(cur_desc, key.strip(), string_convert(value)) 92 | state = STATE_READ_ENTRY 93 | elif state == STATE_READ_ENTRY: 94 | parts = line.split() 95 | assert 
len(parts) == 5 96 | atom = parts[0] 97 | el_number = ase.data.atomic_numbers[atom] 98 | strat_parts = map(lambda x: x.replace("*^", "E"), parts[1:4]) 99 | floats = list(map(float, strat_parts)) 100 | cur_desc.coord[entries_read, :] = np.array(floats) 101 | cur_desc.z[entries_read] = el_number 102 | entries_read += 1 103 | if entries_read == cur_desc.z.size: 104 | state = STATE_READ_FREQUENCY 105 | elif state == STATE_READ_FREQUENCY: 106 | cur_desc.frequency = np.array( 107 | list(map(string_convert, line.split())) 108 | ) 109 | state = STATE_READ_SMILES 110 | elif state == STATE_READ_SMILES: 111 | cur_desc.smiles = line.split() 112 | state = STATE_READ_INCHI 113 | elif state == STATE_READ_INCHI: 114 | cur_desc.inchi = line.split() 115 | yield cur_desc 116 | state = STATE_READ_NUMBER 117 | elif state == STATE_FAILURE: 118 | entries_to_read = None 119 | try: 120 | entries_to_read = int(line) 121 | except: 122 | pass 123 | if entries_to_read is not None: 124 | print("Resuming parsing on line %d" % line_no) 125 | cur_desc = Molecule(entries_to_read) 126 | entries_read = 0 127 | state = STATE_READ_COMMENT 128 | else: 129 | raise Exception("Invalid state") 130 | except Exception as e: 131 | print("Exception occured on line %d: %s" % (line_no, str(e))) 132 | state = STATE_FAILURE 133 | 134 | 135 | def xyz_to_ase(filename, output_name): 136 | """ 137 | Convert xyz descriptors to ase database 138 | """ 139 | 140 | """ 141 | ========================================================================================================= 142 | Ele- ZPVE U (0 K) U (298.15 K) H (298.15 K) G (298.15 K) CV 143 | ment Hartree Hartree Hartree Hartree Hartree Cal/(Mol Kelvin) 144 | ========================================================================================================= 145 | H 0.000000 -0.500273 -0.498857 -0.497912 -0.510927 2.981 146 | C 0.000000 -37.846772 -37.845355 -37.844411 -37.861317 2.981 147 | N 0.000000 -54.583861 -54.582445 -54.581501 -54.598897 2.981 148 | O 0.000000 -75.064579 -75.063163 -75.062219 -75.079532 2.981 149 | F 0.000000 -99.718730 -99.717314 -99.716370 -99.733544 2.981 150 | ========================================================================================================= 151 | """ 152 | HARTREE_TO_EV = 27.21138602 153 | REFERENCE_DICT = { 154 | ase.data.atomic_numbers["H"]: { 155 | "U0": -0.500273, 156 | "U": -0.498857, 157 | "H": -0.497912, 158 | "G": -0.510927, 159 | }, 160 | ase.data.atomic_numbers["C"]: { 161 | "U0": -37.846772, 162 | "U": -37.845355, 163 | "H": -37.844411, 164 | "G": -37.861317, 165 | }, 166 | ase.data.atomic_numbers["N"]: { 167 | "U0": -54.583861, 168 | "U": -54.582445, 169 | "H": -54.581501, 170 | "G": -54.598897, 171 | }, 172 | ase.data.atomic_numbers["O"]: { 173 | "U0": -75.064579, 174 | "U": -75.063163, 175 | "H": -75.062219, 176 | "G": -75.079532, 177 | }, 178 | ase.data.atomic_numbers["F"]: { 179 | "U0": -99.718730, 180 | "U": -99.717314, 181 | "H": -99.716370, 182 | "G": -99.733544, 183 | }, 184 | } 185 | 186 | # Make a transposed dictionary such that first dimension is property 187 | REFERENCE_DICT_T = {} 188 | atom_nums = [ase.data.atomic_numbers[x] for x in ["H", "C", "N", "O", "F"]] 189 | for prop in ["U0", "U", "H", "G"]: 190 | prop_dict = dict(zip(atom_nums, [REFERENCE_DICT[at][prop] for at in atom_nums])) 191 | REFERENCE_DICT_T[prop] = prop_dict 192 | 193 | # List of tag, whether to convert hartree to eV 194 | keywords = [ 195 | ["tag", False], 196 | ["index", False], 197 | ["A", False], 198 | ["B", False], 199 | ["C", False], 
200 | ["mu", False], 201 | ["alpha", False], 202 | ["homo", True], 203 | ["lumo", True], 204 | ["gap", True], 205 | ["r2", False], 206 | ["zpve", True], 207 | ["U0", True], 208 | ["U", True], 209 | ["H", True], 210 | ["G", True], 211 | ["Cv", False], 212 | ] 213 | # Load xyz file 214 | descriptors = load_xyz_file(filename) 215 | 216 | with ase.db.connect(output_name, append=False) as asedb: 217 | properties_dict = {} 218 | for desc in descriptors: 219 | # Convert attributes to dictionary and convert hartree to eV 220 | for key, convert in keywords: 221 | properties_dict[key] = getattr(desc, key) 222 | # Subtract reference energies for each atom 223 | if key in REFERENCE_DICT_T: 224 | for atom_num in desc.z: 225 | properties_dict[key] -= REFERENCE_DICT_T[key][atom_num] 226 | if convert: 227 | properties_dict[key] *= HARTREE_TO_EV 228 | atoms = ase.Atoms(numbers=desc.z, positions=desc.coord, pbc=False) 229 | asedb.write(atoms, data=properties_dict) 230 | 231 | 232 | if __name__ == "__main__": 233 | url = "https://ndownloader.figshare.com/files/3195389" 234 | filename = os.path.join(msgnet.defaults.datadir, "dsgdb9nsd.xyz.tar.bz2") 235 | xyz_name = os.path.join(msgnet.defaults.datadir, "qm9.xyz") 236 | final_dest = os.path.join(msgnet.defaults.datadir, "qm9.db") 237 | print("downloading dataset...") 238 | download(url, filename) 239 | print("extracting...") 240 | tar_to_xyz(filename, xyz_name) 241 | print("writing to ASE database...") 242 | xyz_to_ase(xyz_name, final_dest) 243 | print("done") 244 | -------------------------------------------------------------------------------- /src/scripts/predict_with_model.py: -------------------------------------------------------------------------------- 1 | import math 2 | import argparse 3 | import os 4 | import itertools 5 | import logging 6 | 7 | import numpy as np 8 | import tensorflow as tf 9 | import ase 10 | 11 | import msgnet 12 | import runner 13 | 14 | 15 | def get_arguments(): 16 | parser = argparse.ArgumentParser(description="Evaluate graph convolution network") 17 | parser.add_argument("--modelpath", type=str, default=None) 18 | parser.add_argument("--permutations", type=int, default=0) 19 | parser.add_argument("--dataset", type=str, default=None) 20 | parser.add_argument("--filter", type=str, default=None) 21 | parser.add_argument("--output", type=str, default=None) 22 | parser.add_argument( 23 | "--cutoff", 24 | type=runner.float_or_string, 25 | nargs="+", 26 | default=[], 27 | help="Cutoff method (voronoi, const or coval) followed by float", 28 | ) 29 | parser.add_argument( 30 | "--split", choices=["train", "test", "validation", "all"], default="all" 31 | ) 32 | return parser.parse_args() 33 | 34 | 35 | class SpeciesPermutationsDataHandler(msgnet.datahandler.EdgeSelectDataHandler): 36 | def __init__( 37 | self, 38 | graph_objects, 39 | graph_targets, 40 | edge_input_idx, 41 | replace_species=None, 42 | keep_species=[], 43 | ): 44 | super().__init__(graph_objects, graph_targets, edge_input_idx) 45 | self.replace_species = replace_species 46 | self.keep_species = keep_species 47 | 48 | def from_self(self, objects): 49 | self.__class__( 50 | objects, self.graph_targets, self.replace_species, self.keep_species 51 | ) 52 | 53 | def list_to_matrices(self, graph_list, graph_targets=["total_energy"]): 54 | """hack_list_to_matrices 55 | 56 | :param graph_list: 57 | :return: tuple of 58 | (nodes, conns, edges, node_targets, atom_pos, graph_target, set_len, edges_len) 59 | """ 60 | nodes_created = 0 61 | all_nodes = [] 62 | all_conn = [] 63 | 
all_conn_offsets = [] 64 | all_edges = [] 65 | all_graph_targets = [] 66 | all_X = [] 67 | all_unitcells = [] 68 | set_len = [] 69 | edges_len = [] 70 | for gr in graph_list: 71 | for rep_species in itertools.permutations(self.replace_species): 72 | nodes, conn, conn_offset, edges, X, unitcell = ( 73 | gr.nodes, 74 | gr.conns, 75 | gr.conns_offset, 76 | gr.edges, 77 | gr.positions, 78 | gr.unitcell, 79 | ) 80 | 81 | # Replace original species with given ones and ignore the keep_species 82 | to_be_replaced = set(nodes) 83 | to_be_replaced = to_be_replaced.difference(set(self.keep_species)) 84 | if len(to_be_replaced) == 1: 85 | rep = rep_species[0:1] 86 | else: 87 | rep = rep_species 88 | assert len(to_be_replaced) == len(rep_species) 89 | to_be_replaced = sorted(list(to_be_replaced)) 90 | newnodes = nodes.copy() 91 | for spec, replacement in zip(to_be_replaced, rep): 92 | newnodes[nodes == spec] = replacement 93 | nodes = newnodes 94 | 95 | conn_shifted = np.copy(conn) + nodes_created 96 | all_nodes.append(nodes) 97 | all_conn.append(conn_shifted) 98 | all_conn_offsets.append(conn_offset) 99 | all_unitcells.append(unitcell) 100 | all_edges.append(edges) 101 | all_graph_targets.append( 102 | np.array([getattr(gr, t) for t in graph_targets]) 103 | ) 104 | all_X.append(X) 105 | nodes_created += nodes.shape[0] 106 | set_len.append(nodes.shape[0]) 107 | edges_len.append(edges.shape[0]) 108 | cat = lambda x: np.concatenate(x, axis=0) 109 | outdict = { 110 | "nodes": cat(all_nodes), 111 | "nodes_xyz": cat(all_X), 112 | "edges": cat(all_edges), 113 | "connections": cat(all_conn), 114 | "connections_offsets": cat(all_conn_offsets), 115 | "graph_targets": np.vstack(all_graph_targets), 116 | "set_lengths": np.array(set_len), 117 | "unitcells": np.stack(all_unitcells, axis=0), 118 | "edges_lengths": np.array(edges_len), 119 | } 120 | outdict["segments"] = msgnet.datahandler.set_len_to_segments( 121 | outdict["set_lengths"] 122 | ) 123 | return outdict 124 | 125 | 126 | def species_count_filter(x, count): 127 | atm_nums = list(x.atoms.get_atomic_numbers()) 128 | set_atm_nums = set(atm_nums) 129 | ordered_set = sorted(list(set_atm_nums)) 130 | if not len(set_atm_nums) == count: 131 | return False 132 | return True 133 | 134 | 135 | def si_filter(x): 136 | atm_nums = list(x.atoms.get_atomic_numbers()) 137 | set_atm_nums = set(atm_nums) 138 | if len(set_atm_nums) > 1: 139 | return False 140 | if list(set_atm_nums)[0] == ase.data.atomic_numbers["Si"]: 141 | return True 142 | else: 143 | return False 144 | 145 | 146 | def cu_filter(x): 147 | atm_nums = list(x.atoms.get_atomic_numbers()) 148 | set_atm_nums = set(atm_nums) 149 | if len(set_atm_nums) > 1: 150 | return False 151 | if list(set_atm_nums)[0] == ase.data.atomic_numbers["Cu"]: 152 | return True 153 | else: 154 | return False 155 | 156 | 157 | def zn_filter(x): 158 | atm_nums = list(x.atoms.get_atomic_numbers()) 159 | set_atm_nums = set(atm_nums) 160 | if len(set_atm_nums) > 1: 161 | return False 162 | if list(set_atm_nums)[0] == ase.data.atomic_numbers["Zn"]: 163 | return True 164 | else: 165 | return False 166 | 167 | 168 | def prototype_filter(x, prototype="Si(oS16)"): 169 | return x.prototype == prototype 170 | 171 | 172 | def abse3_filter(x): 173 | atm_arr = x.atoms.get_atomic_numbers() 174 | atm_nums = list(atm_arr) 175 | set_atm_nums = set(atm_nums) 176 | if len(set_atm_nums) != 3: 177 | return False 178 | if ase.data.atomic_numbers["Se"] not in set_atm_nums: 179 | return False 180 | a, b = [c for c in list(set_atm_nums) if c != 
ase.data.atomic_numbers["Se"]] 181 | a_count = np.count_nonzero(a == atm_arr) 182 | b_count = np.count_nonzero(b == atm_arr) 183 | se_count = np.count_nonzero(ase.data.atomic_numbers["Se"] == atm_arr) 184 | if a_count == b_count and se_count / a_count == 3.0: 185 | return True 186 | else: 187 | return False 188 | 189 | 190 | def icsd_filter(x): 191 | try: 192 | return "icsd" in x.label.lower() 193 | except AttributeError: 194 | return False 195 | 196 | 197 | def unary_filter(x): 198 | return species_count_filter(x, 1) 199 | 200 | 201 | def binary_filter(x): 202 | return species_count_filter(x, 2) 203 | 204 | 205 | def ternary_filter(x): 206 | return species_count_filter(x, 3) 207 | 208 | 209 | def icsd_unary_filter(x): 210 | return unary_filter(x) and icsd_filter(x) 211 | 212 | 213 | def icsd_binary_filter(x): 214 | return binary_filter(x) and icsd_filter(x) 215 | 216 | 217 | def icsd_ternary_filter(x): 218 | return ternary_filter(x) and icsd_filter(x) 219 | 220 | 221 | def perovskite_filter(x): 222 | try: 223 | prototype = x.prototype 224 | except AttributeError: 225 | return False 226 | return prototype == "Perovskite" 227 | 228 | 229 | def cu_prototype_filter(x): 230 | try: 231 | prototype = x.prototype 232 | except AttributeError: 233 | return False 234 | return prototype == "Cu" 235 | 236 | 237 | def fe2p_filter(x): 238 | try: 239 | prototype = x.prototype 240 | except AttributeError: 241 | return False 242 | return prototype == "Fe2P" 243 | 244 | 245 | def main(): 246 | args = get_arguments() 247 | 248 | if args.filter: 249 | names = globals() 250 | filter_func = names[args.filter] 251 | else: 252 | filter_func = None 253 | 254 | metafile = args.modelpath 255 | checkpoint = metafile.replace(".meta", "") 256 | with open( 257 | os.path.join(os.path.dirname(metafile), "commandline_args.txt"), "r" 258 | ) as f: 259 | args_list = f.read().splitlines() 260 | runner_args = runner.get_arguments(args_list) 261 | 262 | if args.dataset: 263 | dataset = args.dataset 264 | else: 265 | dataset = runner_args.dataset 266 | if args.cutoff: 267 | cutoff = args.cutoff 268 | else: 269 | cutoff = runner_args.cutoff 270 | DataLoader = runner.get_dataloader_class(dataset) 271 | loader = DataLoader(cutoff_type=cutoff[0], cutoff_radius=cutoff[1]) 272 | graph_obj_list = loader.load() 273 | if filter_func: 274 | graph_obj_list = [g for g in graph_obj_list if filter_func(g)] 275 | if args.permutations: 276 | graph_obj_list = [ 277 | g for g in graph_obj_list if species_count_filter(g, args.permutations) 278 | ] 279 | if runner_args.target: 280 | target = runner_args.target 281 | else: 282 | target = loader.default_target 283 | 284 | if args.permutations: 285 | symbols = ["Ag", "C", "Na", "B", "Mg", "Cl"] 286 | else: 287 | symbols = ["Dummy"] 288 | with tf.Session() as sess: 289 | 290 | model = runner.get_model(runner_args) 291 | model.load(sess, checkpoint) 292 | 293 | for sym_i, symbol in enumerate(symbols): 294 | if args.permutations: 295 | if args.permutations == 1: 296 | replace_species = [ase.data.atomic_numbers[symbols[sym_i]]] 297 | elif args.permutations == 2: 298 | replace_species = [ 299 | ase.data.atomic_numbers[symbols[sym_i]], 300 | ase.data.atomic_numbers[symbols[sym_i - 1]], 301 | ] 302 | elif args.permutations == 3: 303 | replace_species = [ 304 | ase.data.atomic_numbers[symbols[sym_i]], 305 | ase.data.atomic_numbers[symbols[sym_i - 1]], 306 | ase.data.atomic_numbers[symbols[sym_i - 2]], 307 | ] 308 | else: 309 | raise Exception("Invalid number of species to replace") 310 | keep_species = [] 311 
| datahandler = SpeciesPermutationsDataHandler( 312 | graph_obj_list, 313 | [target], 314 | runner_args.edge_idx, 315 | replace_species=replace_species, 316 | keep_species=keep_species, 317 | ) 318 | else: 319 | datahandler = msgnet.datahandler.EdgeSelectDataHandler( 320 | graph_obj_list, [target], runner_args.edge_idx 321 | ) 322 | 323 | if args.split in ["train", "test", "validation"]: 324 | if filter_func: 325 | datasplit_args = dict(DataLoader.default_datasplit_args)  # copy so the class default is not mutated 326 | datasplit_args["validation_size"] = 0 327 | else: 328 | datasplit_args = dict(DataLoader.default_datasplit_args) 329 | splits = dict( 330 | zip( 331 | ["train", "test", "validation"], 332 | datahandler.train_test_split( 333 | **datasplit_args, 334 | test_fold=runner_args.fold 335 | ), 336 | ) 337 | ) 338 | datahandler = splits[args.split] 339 | 340 | target_values = np.array( 341 | [getattr(g, target) for g in datahandler.graph_objects] 342 | ) 343 | row_id = np.array([g.id for g in datahandler.graph_objects]) 344 | 345 | if args.permutations: 346 | repeats = math.factorial(args.permutations) 347 | target_values = np.repeat(target_values, repeats) 348 | row_id = np.repeat(row_id, repeats) 349 | 350 | model_predictions = [] 351 | print("computing predictions") 352 | for input_data in datahandler.get_test_batches(5): 353 | feed_dict = {} 354 | model_input_symbols = model.get_input_symbols() 355 | for key, val in model_input_symbols.items(): 356 | feed_dict[val] = input_data[key] 357 | graph_out, = sess.run([model.get_graph_out()], feed_dict=feed_dict) 358 | model_predictions.append(graph_out) 359 | 360 | model_predict = np.concatenate(model_predictions, axis=0).squeeze() 361 | 362 | if args.permutations: 363 | outpath = "%s_%s" % (args.output, symbol) 364 | else: 365 | outpath = args.output 366 | 367 | errors = model_predict - target_values 368 | mae = np.mean(np.abs(errors)) 369 | rmse = np.sqrt(np.mean(np.square(errors))) 370 | 371 | print("split=%s, num_samples=%d, mae=%s, rmse=%s" % (args.split, errors.shape[0], mae, rmse)) 372 | 373 | np.savetxt( 374 | outpath, np.stack((target_values, model_predict, row_id), axis=1) 375 | ) 376 | 377 | 378 | if __name__ == "__main__": 379 | logging.basicConfig( 380 | level=logging.DEBUG, 381 | format="%(asctime)s [%(levelname)-5.5s] %(message)s", 382 | handlers=[logging.StreamHandler()], 383 | ) 384 | 385 | main() 386 | -------------------------------------------------------------------------------- /src/scripts/runner.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import logging 4 | import argparse 5 | import timeit 6 | import ase 7 | import numpy as np 8 | import tensorflow as tf 9 | 10 | import msgnet 11 | 12 | 13 | def list_of_lists(string): 14 | """list_of_lists 15 | Parse a comma separated string of the form "1,2,3" into a list of floats
16 | 17 | :param string: 18 | """ 19 | if string: 20 | return [float(x) for x in string.strip().split(",")] 21 | else: 22 | return string 23 | 24 | 25 | def float_or_string(string): 26 | try: 27 | return float(string) 28 | except ValueError: 29 | return string 30 | 31 | 32 | def get_arguments(arg_list=None): 33 | parser = argparse.ArgumentParser( 34 | description="Train graph convolution network", fromfile_prefix_chars="@" 35 | ) 36 | parser.add_argument("--load_model", type=str, default=None) 37 | parser.add_argument("--fold", type=int, default=-1) 38 | parser.add_argument( 39 | "--cutoff", 40 | type=float_or_string, 41 | nargs="+", 42 | default=["voronoi"], 43 | help="Cutoff method (voronoi, const or coval) followed by float for const and coval", 44 | ) 45 | parser.add_argument( 46 | "--edge_idx", 47 | nargs="*", 48 | default=[], 49 | type=int, 50 | help="Space separated list of edge feature indices to use", 51 | ) 52 | parser.add_argument("--dataset", type=str, default="oqmd") 53 | parser.add_argument( 54 | "--edge_expand", 55 | nargs="*", 56 | type=list_of_lists, 57 | default=None, 58 | help="Space separated list of comma separated triplets start,step,stop for edge feature expansion", 59 | ) 60 | parser.add_argument("--msg_share_weights", action="store_true") 61 | parser.add_argument("--num_passes", type=int, default=3) 62 | parser.add_argument("--node_embedding_size", type=int, default=256) 63 | parser.add_argument("--update_edges", action="store_true") 64 | parser.add_argument( 65 | "--avg_msg", action="store_true", help="Average messages instead of sum" 66 | ) 67 | parser.add_argument("--readout", type=str, default="set2set") 68 | parser.add_argument("--target", type=str, default=None) 69 | parser.add_argument("--learning_rate", type=float, default=1e-3) 70 | 71 | return parser.parse_args(arg_list) 72 | 73 | 74 | def gen_prefix(namespace): 75 | prefix = [] 76 | argdict = vars(namespace) 77 | for key in [ 78 | "dataset", 79 | "cutoff", 80 | "edge_idx", 81 | "node_embedding_size", 82 | "msg_share_weights", 83 | "update_edges", 84 | "readout", 85 | "num_passes", 86 | "avg_msg", 87 | "fold", 88 | "target", 89 | "learning_rate", 90 | ]: 91 | if isinstance(argdict[key], list): 92 | val = "-".join([str(x) for x in argdict[key]]) 93 | else: 94 | val = str(argdict[key]) 95 | prefix.append(key[0] + val) 96 | return "_".join(prefix).replace(" ", "") 97 | 98 | 99 | def get_dataloader_class(data_name): 100 | names = dir(msgnet.dataloader) 101 | lower_names = [x.lower() for x in names] 102 | index = lower_names.index(data_name + "dataloader") 103 | DataLoader = getattr(msgnet.dataloader, names[index]) 104 | return DataLoader 105 | 106 | 107 | def get_readout_function(readout_name, output_size): 108 | names = dir(msgnet.readout) 109 | lower_names = [x.lower() for x in names] 110 | index = lower_names.index("readout" + readout_name) 111 | ReadoutFunctionClass = getattr(msgnet.readout, names[index]) 112 | return ReadoutFunctionClass(output_size) 113 | 114 | 115 | def main(args): 116 | DataLoader = get_dataloader_class(args.dataset) 117 | 118 | if args.target: 119 | target_name = args.target 120 | else: 121 | target_name = DataLoader.default_target 122 | 123 | graph_obj_list = DataLoader( 124 | cutoff_type=args.cutoff[0], cutoff_radius=args.cutoff[1] 125 | ).load() 126 | data_handler = msgnet.datahandler.EdgeSelectDataHandler( 127 | graph_obj_list, [target_name], args.edge_idx 128 | ) 129 | 130 | if not args.load_model: 131 | with open(logs_path + "commandline_args.txt", "w") as f: 132 | 
f.write("\n".join(sys.argv[1:])) 133 | 134 | train_obj, test_obj, val_obj = data_handler.train_test_split( 135 | test_fold=args.fold, **DataLoader.default_datasplit_args 136 | ) 137 | 138 | model = get_model(args, data_handler) 139 | 140 | train_model(logs_path, model, args, target_name, train_obj, test_obj, val_obj) 141 | 142 | 143 | def get_model(args, train_data_handler=None): 144 | readout_fn = get_readout_function(args.readout, 1) 145 | 146 | if train_data_handler: 147 | target_mean, target_std = train_data_handler.get_normalization( 148 | per_atom=readout_fn.is_sum 149 | ) 150 | else: 151 | target_mean = 0 152 | target_std = 1 153 | logging.debug("Target mean %f, target std = %f" % (target_mean, target_std)) 154 | 155 | net = msgnet.MsgpassingNetwork( 156 | n_node_features=1, 157 | n_edge_features=len(args.edge_idx), 158 | embedding_shape=(len(ase.data.chemical_symbols), args.node_embedding_size), 159 | edge_feature_expand=args.edge_expand, 160 | num_passes=args.num_passes, 161 | msg_share_weights=args.msg_share_weights, 162 | use_edge_updates=args.update_edges, 163 | readout_fn=readout_fn, 164 | avg_msg=args.avg_msg, 165 | target_mean=target_mean, 166 | target_std=target_std, 167 | ) 168 | return net 169 | 170 | 171 | def train_model(logs_path, model, args, target_name, train_obj, test_obj, val_obj=None): 172 | 173 | log_interval = len(train_obj) 174 | 175 | best_val_mae = np.inf 176 | best_val_step = 0 177 | 178 | # Write metadata for embedding visualisation 179 | with open(logs_path + "metadata.tsv", "w") as metaf: 180 | metaf.write("symbol\tnumber\t\n") 181 | for i, species in enumerate(ase.data.chemical_symbols): 182 | metaf.write("%s\t%d\n" % (species, i)) 183 | with open(logs_path + "projector_config.pbtxt", "w") as logcfg: 184 | logcfg.write("embeddings {\n") 185 | logcfg.write(" tensor_name: 'species_embedding_matrix'\n") 186 | logcfg.write(" metadata_path: 'metadata.tsv'\n") 187 | logcfg.write("}") 188 | 189 | start_time = timeit.default_timer() 190 | logging.info("Training") 191 | num_steps = int(1e7) 192 | 193 | trainer = msgnet.train.GraphOutputTrainer( 194 | model, train_obj, initial_lr=args.learning_rate 195 | ) 196 | 197 | with tf.Session() as sess: 198 | sess.run(tf.global_variables_initializer()) 199 | 200 | if args.load_model: 201 | if args.load_model.endswith(".meta"): 202 | checkpoint = args.load_model.replace(".meta", "") 203 | logging.info("loading model from %s", checkpoint) 204 | start_step = int(checkpoint.split("/")[-1].split("-")[-1]) 205 | model.load(sess, checkpoint) 206 | else: 207 | checkpoint = tf.train.get_checkpoint_state(args.load_model) 208 | logging.info("loading model from %s", checkpoint) 209 | start_step = int( 210 | checkpoint.model_checkpoint_path.split("/")[-1].split("-")[-1] 211 | ) 212 | model.load(sess, checkpoint.model_checkpoint_path) 213 | else: 214 | start_step = 0 215 | 216 | # Print shape of all trainable variables 217 | trainable_vars = tf.trainable_variables() 218 | for var, val in zip(trainable_vars, sess.run(trainable_vars)): 219 | logging.debug("%s %s", var.name, var.get_shape()) 220 | 221 | for update_step in range(start_step, num_steps): 222 | trainer.step(sess, update_step) 223 | 224 | if (update_step % log_interval == 0) or (update_step + 1) == num_steps: 225 | test_start_time = timeit.default_timer() 226 | 227 | # Evaluate training set 228 | train_metrics = trainer.evaluate_metrics( 229 | sess, train_obj, prefix="train" 230 | ) 231 | 232 | # Evaluate validation set 233 | if val_obj: 234 | val_metrics = 
trainer.evaluate_metrics(sess, val_obj, prefix="val") 235 | else: 236 | val_metrics = {} 237 | 238 | all_metrics = {**train_metrics, **val_metrics} 239 | metric_string = " ".join( 240 | ["%s=%f" % (key, val) for key, val in all_metrics.items()] 241 | ) 242 | 243 | end_time = timeit.default_timer() 244 | test_end_time = timeit.default_timer() 245 | logging.info( 246 | "t=%.1f (%.1f) %d %s lr=%f", 247 | end_time - start_time, 248 | test_end_time - test_start_time, 249 | update_step, 250 | metric_string, 251 | trainer.get_learning_rate(update_step), 252 | ) 253 | start_time = timeit.default_timer() 254 | 255 | # Do early stopping using validation data (if available) 256 | if val_obj: 257 | if all_metrics["val_mae"] < best_val_mae: 258 | model.save( 259 | sess, logs_path + "model.ckpt", global_step=update_step 260 | ) 261 | best_val_mae = all_metrics["val_mae"] 262 | best_val_step = update_step 263 | logging.info( 264 | "best_val_mae=%f, best_val_step=%d", 265 | best_val_mae, 266 | best_val_step, 267 | ) 268 | if (update_step - best_val_step) > 1e6: 269 | logging.info( 270 | "best_val_mae=%f, best_val_step=%d", 271 | best_val_mae, 272 | best_val_step, 273 | ) 274 | logging.info("No improvement in last 1e6 steps, stopping...") 275 | model.save( 276 | sess, logs_path + "model.ckpt", global_step=update_step 277 | ) 278 | return 279 | else: 280 | model.save(sess, logs_path + "model.ckpt", global_step=update_step) 281 | 282 | 283 | if __name__ == "__main__": 284 | args = get_arguments() 285 | 286 | logs_path = "logs/runner_%s/" % gen_prefix(args) 287 | os.makedirs(logs_path, exist_ok=True) 288 | logging.basicConfig( 289 | level=logging.DEBUG, 290 | format="%(asctime)s [%(levelname)-5.5s] %(message)s", 291 | handlers=[ 292 | logging.FileHandler(logs_path + "printlog.txt", mode="w"), 293 | logging.StreamHandler(), 294 | ], 295 | ) 296 | main(args) 297 | --------------------------------------------------------------------------------