├── .gitignore ├── CGVAE.py ├── GGNN_core.py ├── LICENSE ├── README.md ├── SECURITY.md ├── data ├── get_qm9.py ├── get_zinc.py ├── valid_idx_qm9.json └── valid_idx_zinc.json ├── data_augmentation.py ├── evaluate.py ├── get_sascorer.sh ├── install.sh ├── requirements.txt └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | -------------------------------------------------------------------------------- /CGVAE.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env/python 2 | """ 3 | Usage: 4 | CGVAE.py [options] 5 | 6 | Options: 7 | -h --help Show this screen 8 | --dataset NAME Dataset name: zinc, qm9, cep 9 | --config-file FILE Hyperparameter configuration file path (in JSON format) 10 | --config CONFIG Hyperparameter configuration dictionary (in JSON format) 11 | --log_dir NAME log dir name 12 | --data_dir NAME data dir name 13 | --restore FILE File to restore weights from. 14 | --freeze-graph-model Freeze weights of graph model components 15 | """ 16 | from typing import Sequence, Any 17 | from docopt import docopt 18 | from collections import defaultdict, deque 19 | import numpy as np 20 | import tensorflow as tf 21 | import sys, traceback 22 | import pdb 23 | import json 24 | import os 25 | from GGNN_core import ChemModel 26 | import utils 27 | from utils import * 28 | import pickle 29 | import random 30 | from numpy import linalg as LA 31 | from rdkit import Chem 32 | from copy import deepcopy 33 | from rdkit.Chem import QED 34 | import os 35 | import time 36 | from data_augmentation import * 37 | 38 | ''' 39 | Comments provide the expected tensor shapes where helpful. 
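For example, a comment of the form [b, v, h] marks a tensor holding one h-dimensional hidden vector per (padded) vertex per graph in the batch; the symbols are listed below.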
40 | 41 | Key to symbols in comments: 42 | --------------------------- 43 | [...]: a tensor 44 | ; ; : a list 45 | b: batch size 46 | e: number of edege types (3) 47 | es: maximum number of BFS transitions in this batch 48 | v: number of vertices per graph in this batch 49 | h: GNN hidden size 50 | ''' 51 | 52 | class DenseGGNNChemModel(ChemModel): 53 | def __init__(self, args): 54 | super().__init__(args) 55 | 56 | @classmethod 57 | def default_params(cls): 58 | params = dict(super().default_params()) 59 | params.update({ 60 | 'task_sample_ratios': {}, 61 | 'use_edge_bias': True, # whether use edge bias in gnn 62 | 63 | 'clamp_gradient_norm': 1.0, 64 | 'out_layer_dropout_keep_prob': 1.0, 65 | 66 | 'tie_fwd_bkwd': True, 67 | 'task_ids': [0], # id of property prediction 68 | 69 | 'random_seed': 0, # fixed for reproducibility 70 | 71 | 'batch_size': 8 if dataset=='zinc' or dataset=='cep' else 64, 72 | "qed_trade_off_lambda": 10, 73 | 'prior_learning_rate': 0.05, 74 | 'stop_criterion': 0.01, 75 | 'num_epochs': 3 if dataset=='zinc' or dataset=='cep' else 10, 76 | 'epoch_to_generate': 3 if dataset=='zinc' or dataset=='cep' else 10, 77 | 'number_of_generation': 30000, 78 | 'optimization_step': 0, 79 | 'maximum_distance': 50, 80 | "use_argmax_generation": False, # use random sampling or argmax during generation 81 | 'residual_connection_on': True, # whether residual connection is on 82 | 'residual_connections': { # For iteration i, specify list of layers whose output is added as an input 83 | 2: [0], 84 | 4: [0, 2], 85 | 6: [0, 2, 4], 86 | 8: [0, 2, 4, 6], 87 | 10: [0, 2, 4, 6, 8], 88 | 12: [0, 2, 4, 6, 8, 10], 89 | 14: [0, 2, 4, 6, 8, 10, 12], 90 | }, 91 | 'num_timesteps': 12, # gnn propagation step 92 | 'hidden_size': 100, 93 | "kl_trade_off_lambda": 0.3, # kl tradeoff 94 | 'learning_rate': 0.001, 95 | 'graph_state_dropout_keep_prob': 1, 96 | "compensate_num": 1, # how many atoms to be added during generation 97 | 98 | 'train_file': 'data/molecules_train_%s.json' % dataset, 99 | 'valid_file': 'data/molecules_valid_%s.json' % dataset, 100 | 101 | 'try_different_starting': True, 102 | "num_different_starting": 6, 103 | 104 | 'generation': False, # only for generation 105 | 'use_graph': True, # use gnn 106 | "label_one_hot": False, # one hot label or not 107 | "multi_bfs_path": False, # whether sample several BFS paths for each molecule 108 | "bfs_path_count": 30, 109 | "path_random_order": False, # False: canonical order, True: random order 110 | "sample_transition": False, # whether use transition sampling 111 | 'edge_weight_dropout_keep_prob': 1, 112 | 'check_overlap_edge': False, 113 | "truncate_distance": 10, 114 | }) 115 | 116 | return params 117 | 118 | def prepare_specific_graph_model(self) -> None: 119 | h_dim = self.params['hidden_size'] 120 | expanded_h_dim=self.params['hidden_size']+self.params['hidden_size'] + 1 # 1 for focus bit 121 | self.placeholders['graph_state_keep_prob'] = tf.placeholder(tf.float32, None, name='graph_state_keep_prob') 122 | self.placeholders['edge_weight_dropout_keep_prob'] = tf.placeholder(tf.float32, None, name='edge_weight_dropout_keep_prob') 123 | self.placeholders['initial_node_representation'] = tf.placeholder(tf.float32, 124 | [None, None, self.params['hidden_size']], 125 | name='node_features') # padded node symbols 126 | # mask out invalid node 127 | self.placeholders['node_mask'] = tf.placeholder(tf.float32, [None, None], name='node_mask') # [b x v] 128 | self.placeholders['num_vertices'] = tf.placeholder(tf.int32, ()) 129 | # adj for encoder 130 | 
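# one [v, v] adjacency slice per edge type (the docstring key above lists e = 3 edge types)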
self.placeholders['adjacency_matrix'] = tf.placeholder(tf.float32, 131 | [None, self.num_edge_types, None, None], name="adjacency_matrix") # [b, e, v, v] 132 | # labels for node symbol prediction 133 | self.placeholders['node_symbols'] = tf.placeholder(tf.float32, [None, None, self.params['num_symbols']]) # [b, v, edge_type] 134 | # node symbols used to enhance latent representations 135 | self.placeholders['latent_node_symbols'] = tf.placeholder(tf.float32, 136 | [None, None, self.params['hidden_size']], name='latent_node_symbol') # [b, v, h] 137 | # mask out cross entropies in decoder 138 | self.placeholders['iteration_mask']=tf.placeholder(tf.float32, [None, None]) # [b, es] 139 | # adj matrices used in decoder 140 | self.placeholders['incre_adj_mat']=tf.placeholder(tf.float32, [None, None, self.num_edge_types, None, None], name='incre_adj_mat') # [b, es, e, v, v] 141 | # distance 142 | self.placeholders['distance_to_others']=tf.placeholder(tf.int32, [None, None, None], name='distance_to_others') # [b, es,v] 143 | # maximum iteration number of this batch 144 | self.placeholders['max_iteration_num']=tf.placeholder(tf.int32, [], name='max_iteration_num') # number 145 | # node number in focus at each iteration step 146 | self.placeholders['node_sequence']=tf.placeholder(tf.float32, [None, None, None], name='node_sequence') # [b, es, v] 147 | # mask out invalid edge types at each iteration step 148 | self.placeholders['edge_type_masks']=tf.placeholder(tf.float32, [None, None, self.num_edge_types, None], name='edge_type_masks') # [b, es, e, v] 149 | # ground truth edge type labels at each iteration step 150 | self.placeholders['edge_type_labels']=tf.placeholder(tf.float32, [None, None, self.num_edge_types, None], name='edge_type_labels') # [b, es, e, v] 151 | # mask out invalid edge at each iteration step 152 | self.placeholders['edge_masks']=tf.placeholder(tf.float32, [None, None, None], name='edge_masks') # [b, es, v] 153 | # ground truth edge labels at each iteration step 154 | self.placeholders['edge_labels']=tf.placeholder(tf.float32, [None, None, None], name='edge_labels') # [b, es, v] 155 | # ground truth labels for whether it stops at each iteration step 156 | self.placeholders['local_stop']=tf.placeholder(tf.float32, [None, None], name='local_stop') # [b, es] 157 | # z_prior sampled from standard normal distribution 158 | self.placeholders['z_prior']=tf.placeholder(tf.float32, [None, None, self.params['hidden_size']], name='z_prior') # the prior of z sampled from normal distribution 159 | # put in front of kl latent loss 160 | self.placeholders['kl_trade_off_lambda']=tf.placeholder(tf.float32, [], name='kl_trade_off_lambda') # number 161 | # overlapped edge features 162 | self.placeholders['overlapped_edge_features']=tf.placeholder(tf.int32, [None, None, None], name='overlapped_edge_features') # [b, es, v] 163 | 164 | # weights for encoder and decoder GNN. 165 | if self.params["residual_connection_on"]: 166 | # weights for encoder and decoder GNN. 
Different weights for each iteration 167 | for scope in ['_encoder', '_decoder']: 168 | if scope == '_encoder': 169 | new_h_dim=h_dim 170 | else: 171 | new_h_dim=expanded_h_dim 172 | for iter_idx in range(self.params['num_timesteps']): 173 | with tf.variable_scope("gru_scope"+scope+str(iter_idx), reuse=False): 174 | self.weights['edge_weights'+scope+str(iter_idx)] = tf.Variable(glorot_init([self.num_edge_types, new_h_dim, new_h_dim])) 175 | if self.params['use_edge_bias']: 176 | self.weights['edge_biases'+scope+str(iter_idx)] = tf.Variable(np.zeros([self.num_edge_types, 1, new_h_dim]).astype(np.float32)) 177 | 178 | cell = tf.contrib.rnn.GRUCell(new_h_dim) 179 | cell = tf.nn.rnn_cell.DropoutWrapper(cell, 180 | state_keep_prob=self.placeholders['graph_state_keep_prob']) 181 | self.weights['node_gru'+scope+str(iter_idx)] = cell 182 | else: 183 | for scope in ['_encoder', '_decoder']: 184 | if scope == '_encoder': 185 | new_h_dim=h_dim 186 | else: 187 | new_h_dim=expanded_h_dim 188 | self.weights['edge_weights'+scope] = tf.Variable(glorot_init([self.num_edge_types, new_h_dim, new_h_dim])) 189 | if self.params['use_edge_bias']: 190 | self.weights['edge_biases'+scope] = tf.Variable(np.zeros([self.num_edge_types, 1, new_h_dim]).astype(np.float32)) 191 | with tf.variable_scope("gru_scope"+scope): 192 | cell = tf.contrib.rnn.GRUCell(new_h_dim) 193 | cell = tf.nn.rnn_cell.DropoutWrapper(cell, 194 | state_keep_prob=self.placeholders['graph_state_keep_prob']) 195 | self.weights['node_gru'+scope] = cell 196 | 197 | # weights for calculating mean and variance 198 | self.weights['mean_weights'] = tf.Variable(glorot_init([h_dim, h_dim])) 199 | self.weights['mean_biases'] = tf.Variable(np.zeros([1, h_dim]).astype(np.float32)) 200 | self.weights['variance_weights'] = tf.Variable(glorot_init([h_dim, h_dim])) 201 | self.weights['variance_biases'] = tf.Variable(np.zeros([1, h_dim]).astype(np.float32)) 202 | 203 | # The weights for generating nodel symbol logits 204 | self.weights['node_symbol_weights'] = tf.Variable(glorot_init([h_dim, self.params['num_symbols']])) 205 | self.weights['node_symbol_biases'] = tf.Variable(np.zeros([1, self.params['num_symbols']]).astype(np.float32)) 206 | 207 | feature_dimension=6*expanded_h_dim 208 | # record the total number of features 209 | self.params["feature_dimension"] = 6 210 | # weights for generating edge type logits 211 | for i in range(self.num_edge_types): 212 | self.weights['edge_type_%d' % i] = tf.Variable(glorot_init([feature_dimension, feature_dimension])) 213 | self.weights['edge_type_biases_%d' % i] = tf.Variable(np.zeros([1, feature_dimension]).astype(np.float32)) 214 | self.weights['edge_type_output_%d' % i] = tf.Variable(glorot_init([feature_dimension, 1])) 215 | # weights for generating edge logits 216 | self.weights['edge_iteration'] = tf.Variable(glorot_init([feature_dimension, feature_dimension])) 217 | self.weights['edge_iteration_biases'] = tf.Variable(np.zeros([1, feature_dimension]).astype(np.float32)) 218 | self.weights['edge_iteration_output'] = tf.Variable(glorot_init([feature_dimension, 1])) 219 | # Weights for the stop node 220 | self.weights["stop_node"] = tf.Variable(glorot_init([1, expanded_h_dim])) 221 | # Weight for distance embedding 222 | self.weights['distance_embedding'] = tf.Variable(glorot_init([self.params['maximum_distance'], expanded_h_dim])) 223 | # Weight for overlapped edge feature 224 | self.weights["overlapped_edge_weight"] = tf.Variable(glorot_init([2, expanded_h_dim])) 225 | # weights for linear projection on qed 
prediction input 226 | self.weights['qed_weights'] = tf.Variable(glorot_init([h_dim, h_dim])) 227 | self.weights['qed_biases'] = tf.Variable(np.zeros([1, h_dim]).astype(np.float32)) 228 | # use node embeddings 229 | self.weights["node_embedding"]= tf.Variable(glorot_init([self.params["num_symbols"], h_dim])) 230 | 231 | # graph state mask 232 | self.ops['graph_state_mask']= tf.expand_dims(self.placeholders['node_mask'], 2) 233 | 234 | # transform one hot vector to dense embedding vectors 235 | def get_node_embedding_state(self, one_hot_state): 236 | node_nums=tf.argmax(one_hot_state, axis=2) 237 | return tf.nn.embedding_lookup(self.weights["node_embedding"], node_nums) * self.ops['graph_state_mask'] 238 | 239 | def compute_final_node_representations_with_residual(self, h, adj, scope_name): # scope_name: _encoder or _decoder 240 | # h: initial representation, adj: adjacency matrix, different GNN parameters for encoder and decoder 241 | v = self.placeholders['num_vertices'] 242 | # _decoder uses a larger latent space because concat of symbol and latent representation 243 | if scope_name=="_decoder": 244 | h_dim = self.params['hidden_size'] + self.params['hidden_size'] + 1 245 | else: 246 | h_dim = self.params['hidden_size'] 247 | h = tf.reshape(h, [-1, h_dim]) # [b*v, h] 248 | # record all hidden states at each iteration 249 | all_hidden_states=[h] 250 | for iter_idx in range(self.params['num_timesteps']): 251 | with tf.variable_scope("gru_scope"+scope_name+str(iter_idx), reuse=None) as g_scope: 252 | for edge_type in range(self.num_edge_types): 253 | # the message passed from this vertice to other vertices 254 | m = tf.matmul(h, self.weights['edge_weights'+scope_name+str(iter_idx)][edge_type]) # [b*v, h] 255 | if self.params['use_edge_bias']: 256 | m += self.weights['edge_biases'+scope_name+str(iter_idx)][edge_type] # [b, v, h] 257 | m = tf.reshape(m, [-1, v, h_dim]) # [b, v, h] 258 | # collect the messages from other vertices to each vertice 259 | if edge_type == 0: 260 | acts = tf.matmul(adj[edge_type], m) 261 | else: 262 | acts += tf.matmul(adj[edge_type], m) 263 | # all messages collected for each node 264 | acts = tf.reshape(acts, [-1, h_dim]) # [b*v, h] 265 | # add residual connection here 266 | layer_residual_connections = self.params['residual_connections'].get(iter_idx) 267 | if layer_residual_connections is None: 268 | layer_residual_states = [] 269 | else: 270 | layer_residual_states = [all_hidden_states[residual_layer_idx] 271 | for residual_layer_idx in layer_residual_connections] 272 | # concat current hidden states with residual states 273 | acts= tf.concat([acts] + layer_residual_states, axis=1) # [b, (1+num residual connection)* h] 274 | 275 | # feed msg inputs and hidden states to GRU 276 | h = self.weights['node_gru'+scope_name+str(iter_idx)](acts, h)[1] # [b*v, h] 277 | # record the new hidden states 278 | all_hidden_states.append(h) 279 | last_h = tf.reshape(all_hidden_states[-1], [-1, v, h_dim]) 280 | return last_h 281 | 282 | def compute_final_node_representations_without_residual(self, h, adj, edge_weights, edge_biases, node_gru, gru_scope_name): 283 | # h: initial representation, adj: adjacency matrix, different GNN parameters for encoder and decoder 284 | v = self.placeholders['num_vertices'] 285 | if gru_scope_name=="gru_scope_decoder": 286 | h_dim = self.params['hidden_size'] + self.params['hidden_size'] 287 | else: 288 | h_dim = self.params['hidden_size'] 289 | h = tf.reshape(h, [-1, h_dim]) 290 | 291 | with tf.variable_scope(gru_scope_name) as scope: 292 | for 
i in range(self.params['num_timesteps']): 293 | if i > 0: 294 | tf.get_variable_scope().reuse_variables() 295 | for edge_type in range(self.num_edge_types): 296 | m = tf.matmul(h, tf.nn.dropout(edge_weights[edge_type], 297 | keep_prob=self.placeholders['edge_weight_dropout_keep_prob'])) # [b*v, h] 298 | if self.params['use_edge_bias']: 299 | m += edge_biases[edge_type] # [b, v, h] 300 | m = tf.reshape(m, [-1, v, h_dim]) # [b, v, h] 301 | if edge_type == 0: 302 | acts = tf.matmul(adj[edge_type], m) 303 | else: 304 | acts += tf.matmul(adj[edge_type], m) 305 | acts = tf.reshape(acts, [-1, h_dim]) # [b*v, h] 306 | h = node_gru(acts, h)[1] # [b*v, h] 307 | last_h = tf.reshape(h, [-1, v, h_dim]) 308 | return last_h 309 | 310 | def compute_mean_and_logvariance(self): 311 | h_dim = self.params['hidden_size'] 312 | reshped_last_h=tf.reshape(self.ops['final_node_representations'], [-1, h_dim]) 313 | mean=tf.matmul(reshped_last_h, self.weights['mean_weights']) + self.weights['mean_biases'] 314 | logvariance=tf.matmul(reshped_last_h, self.weights['variance_weights']) + self.weights['variance_biases'] 315 | return mean, logvariance 316 | 317 | def sample_with_mean_and_logvariance(self): 318 | v = self.placeholders['num_vertices'] 319 | h_dim = self.params['hidden_size'] 320 | # Sample from normal distribution 321 | z_prior = tf.reshape(self.placeholders['z_prior'], [-1, h_dim]) 322 | # Train: sample from u, Sigma. Generation: sample from 0,1 323 | z_sampled = tf.cond(self.placeholders['is_generative'], lambda: z_prior, # standard normal 324 | lambda: tf.add(self.ops['mean'], tf.multiply(tf.sqrt(tf.exp(self.ops['logvariance'])), z_prior))) # non-standard normal 325 | # filter 326 | z_sampled = tf.reshape(z_sampled, [-1, v, h_dim]) * self.ops['graph_state_mask'] 327 | return z_sampled 328 | 329 | def fully_connected(self, input, hidden_weight, hidden_bias, output_weight): 330 | output=tf.nn.relu(tf.matmul(input, hidden_weight) + hidden_bias) 331 | output=tf.matmul(output, output_weight) 332 | return output 333 | 334 | def generate_cross_entropy(self, idx, cross_entropy_losses, edge_predictions, edge_type_predictions): 335 | v = self.placeholders['num_vertices'] 336 | h_dim = self.params['hidden_size'] 337 | num_symbols = self.params['num_symbols'] 338 | batch_size = tf.shape(self.placeholders['initial_node_representation'])[0] 339 | # Use latent representation as decoder GNN'input 340 | filtered_z_sampled = self.ops["initial_repre_for_decoder"] # [b, v, h+h] 341 | # data needed in this iteration 342 | incre_adj_mat = self.placeholders['incre_adj_mat'][:,idx,:,:, :] # [b, e, v, v] 343 | distance_to_others = self.placeholders['distance_to_others'][:, idx, :] # [b,v] 344 | overlapped_edge_features = self.placeholders['overlapped_edge_features'][:, idx, :] # [b,v] 345 | node_sequence = self.placeholders['node_sequence'][:, idx, :] # [b, v] 346 | node_sequence = tf.expand_dims(node_sequence, axis=2) # [b,v,1] 347 | edge_type_masks = self.placeholders['edge_type_masks'][:, idx, :, :] # [b, e, v] 348 | # make invalid locations to be very small before using softmax function 349 | edge_type_masks = edge_type_masks * LARGE_NUMBER - LARGE_NUMBER 350 | edge_type_labels = self.placeholders['edge_type_labels'][:, idx, :, :] # [b, e, v] 351 | edge_masks=self.placeholders['edge_masks'][:, idx, :] # [b, v] 352 | # make invalid locations to be very small before using softmax function 353 | edge_masks = edge_masks * LARGE_NUMBER - LARGE_NUMBER 354 | edge_labels = self.placeholders['edge_labels'][:, idx, :] # [b, v] 355 | 
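# local_stop is 1 at steps where the ground-truth action is to stop adding edges to the node in focus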
local_stop = self.placeholders['local_stop'][:, idx] # [b] 356 | # concat the hidden states with the node in focus 357 | filtered_z_sampled = tf.concat([filtered_z_sampled, node_sequence], axis=2) # [b, v, h + h + 1] 358 | # Decoder GNN 359 | if self.params["use_graph"]: 360 | if self.params["residual_connection_on"]: 361 | new_filtered_z_sampled = self.compute_final_node_representations_with_residual(filtered_z_sampled, 362 | tf.transpose(incre_adj_mat, [1, 0, 2, 3]), 363 | "_decoder") # [b, v, h + h] 364 | else: 365 | new_filtered_z_sampled = self.compute_final_node_representations_without_residual(filtered_z_sampled, 366 | tf.transpose(incre_adj_mat, [1, 0, 2, 3]), 367 | self.weights['edge_weights_decoder'], 368 | self.weights['edge_biases_decoder'], 369 | self.weights['node_gru_decoder'], "gru_scope_decoder") # [b, v, h + h] 370 | else: 371 | new_filtered_z_sampled = filtered_z_sampled 372 | # Filter nonexist nodes 373 | new_filtered_z_sampled=new_filtered_z_sampled * self.ops['graph_state_mask'] 374 | # Take out the node in focus 375 | node_in_focus = tf.reduce_sum(node_sequence * new_filtered_z_sampled, axis=1)# [b, h + h] 376 | # edge pair representation 377 | edge_repr=tf.concat(\ 378 | [tf.tile(tf.expand_dims(node_in_focus, 1), [1,v,1]), new_filtered_z_sampled], axis=2) # [b, v, 2*(h+h)] 379 | #combine edge repre with local and global repr 380 | local_graph_repr_before_expansion = tf.reduce_sum(new_filtered_z_sampled, axis=1) / \ 381 | tf.reduce_sum(self.placeholders['node_mask'], axis=1, keep_dims=True) # [b, h + h] 382 | local_graph_repr = tf.expand_dims(local_graph_repr_before_expansion, 1) 383 | local_graph_repr = tf.tile(local_graph_repr, [1,v,1]) # [b, v, h+h] 384 | global_graph_repr_before_expansion = tf.reduce_sum(filtered_z_sampled, axis=1) / \ 385 | tf.reduce_sum(self.placeholders['node_mask'], axis=1, keep_dims=True) 386 | global_graph_repr = tf.expand_dims(global_graph_repr_before_expansion, 1) 387 | global_graph_repr = tf.tile(global_graph_repr, [1,v,1]) # [b, v, h+h] 388 | # distance representation 389 | distance_repr = tf.nn.embedding_lookup(self.weights['distance_embedding'], distance_to_others) # [b, v, h+h] 390 | # overlapped edge feature representation 391 | overlapped_edge_repr = tf.nn.embedding_lookup(self.weights['overlapped_edge_weight'], overlapped_edge_features) # [b, v, h+h] 392 | # concat and reshape. 
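# The concatenation below stacks six blocks of size h + h + 1 each (the expanded hidden size
# including the focus bit): the tiled node in focus, the candidate node states, the local graph
# representation, the global graph representation, the distance embedding and the overlapped-edge
# embedding -- which is why feature_dimension is set to 6 in prepare_specific_graph_model.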
393 | combined_edge_repr = tf.concat([edge_repr, local_graph_repr, 394 | global_graph_repr, distance_repr, overlapped_edge_repr], axis=2) 395 | 396 | combined_edge_repr = tf.reshape(combined_edge_repr, [-1, self.params["feature_dimension"]*(h_dim + h_dim + 1)]) 397 | # Calculate edge logits 398 | edge_logits=self.fully_connected(combined_edge_repr, self.weights['edge_iteration'], 399 | self.weights['edge_iteration_biases'], self.weights['edge_iteration_output']) 400 | edge_logits=tf.reshape(edge_logits, [-1, v]) # [b, v] 401 | # filter invalid terms 402 | edge_logits=edge_logits + edge_masks 403 | # Calculate whether it will stop at this step 404 | # prepare the data 405 | expanded_stop_node = tf.tile(self.weights['stop_node'], [batch_size, 1]) # [b, h + h] 406 | distance_to_stop_node = tf.nn.embedding_lookup(self.weights['distance_embedding'], tf.tile([0], [batch_size])) # [b, h + h] 407 | overlap_edge_stop_node = tf.nn.embedding_lookup(self.weights['overlapped_edge_weight'], tf.tile([0], [batch_size])) # [b, h + h] 408 | 409 | combined_stop_node_repr = tf.concat([node_in_focus, expanded_stop_node, local_graph_repr_before_expansion, 410 | global_graph_repr_before_expansion, distance_to_stop_node, overlap_edge_stop_node], axis=1) # [b, 6 * (h + h)] 411 | # logits for stop node 412 | stop_logits = self.fully_connected(combined_stop_node_repr, 413 | self.weights['edge_iteration'], self.weights['edge_iteration_biases'], 414 | self.weights['edge_iteration_output']) #[b, 1] 415 | edge_logits = tf.concat([edge_logits, stop_logits], axis=1) # [b, v + 1] 416 | 417 | # Calculate edge type logits 418 | edge_type_logits = [] 419 | for i in range(self.num_edge_types): 420 | edge_type_logit = self.fully_connected(combined_edge_repr, 421 | self.weights['edge_type_%d' % i], self.weights['edge_type_biases_%d' % i], 422 | self.weights['edge_type_output_%d' % i]) #[b * v, 1] 423 | edge_type_logits.append(tf.reshape(edge_type_logit, [-1, 1, v])) # [b, 1, v] 424 | 425 | edge_type_logits = tf.concat(edge_type_logits, axis=1) # [b, e, v] 426 | # filter invalid items 427 | edge_type_logits = edge_type_logits + edge_type_masks # [b, e, v] 428 | # softmax over edge type axis 429 | edge_type_probs = tf.nn.softmax(edge_type_logits, 1) # [b, e, v] 430 | 431 | # edge labels 432 | edge_labels = tf.concat([edge_labels,tf.expand_dims(local_stop, 1)], axis=1) # [b, v + 1] 433 | # softmax for edge 434 | edge_loss =- tf.reduce_sum(tf.log(tf.nn.softmax(edge_logits) + SMALL_NUMBER) * edge_labels, axis=1) 435 | # softmax for edge type 436 | edge_type_loss =- edge_type_labels * tf.log(edge_type_probs + SMALL_NUMBER) # [b, e, v] 437 | edge_type_loss = tf.reduce_sum(edge_type_loss, axis=[1, 2]) # [b] 438 | # total loss 439 | iteration_loss = edge_loss + edge_type_loss 440 | cross_entropy_losses = cross_entropy_losses.write(idx, iteration_loss) 441 | edge_predictions = edge_predictions.write(idx, tf.nn.softmax(edge_logits)) 442 | edge_type_predictions = edge_type_predictions.write(idx, edge_type_probs) 443 | return (idx+1, cross_entropy_losses, edge_predictions, edge_type_predictions) 444 | 445 | def construct_logit_matrices(self): 446 | v = self.placeholders['num_vertices'] 447 | batch_size=tf.shape(self.placeholders['initial_node_representation'])[0] 448 | h_dim = self.params['hidden_size'] 449 | 450 | # Initial state: embedding 451 | latent_node_state= self.get_node_embedding_state(self.placeholders["latent_node_symbols"]) 452 | # concat z_sampled with node symbols 453 | filtered_z_sampled = tf.concat([self.ops['z_sampled'], 
454 | latent_node_state], axis=2) # [b, v, h + h] 455 | self.ops["initial_repre_for_decoder"] = filtered_z_sampled 456 | # The tensor array used to collect the cross entropy losses at each step 457 | cross_entropy_losses = tf.TensorArray(dtype=tf.float32, size=self.placeholders['max_iteration_num']) 458 | edge_predictions= tf.TensorArray(dtype=tf.float32, size=self.placeholders['max_iteration_num']) 459 | edge_type_predictions = tf.TensorArray(dtype=tf.float32, size=self.placeholders['max_iteration_num']) 460 | idx_final, cross_entropy_losses_final, edge_predictions_final,edge_type_predictions_final=\ 461 | tf.while_loop(lambda idx, cross_entropy_losses,edge_predictions,edge_type_predictions: idx < self.placeholders['max_iteration_num'], 462 | self.generate_cross_entropy, 463 | (tf.constant(0), cross_entropy_losses,edge_predictions,edge_type_predictions,)) 464 | 465 | # record the predictions for generation 466 | self.ops['edge_predictions'] = edge_predictions_final.read(0) 467 | self.ops['edge_type_predictions'] = edge_type_predictions_final.read(0) 468 | 469 | # final cross entropy losses 470 | cross_entropy_losses_final = cross_entropy_losses_final.stack() 471 | self.ops['cross_entropy_losses'] = tf.transpose(cross_entropy_losses_final, [1,0]) # [b, es] 472 | 473 | # Logits for node symbols 474 | self.ops['node_symbol_logits']=tf.reshape(tf.matmul(tf.reshape(self.ops['z_sampled'],[-1, h_dim]), self.weights['node_symbol_weights']) + 475 | self.weights['node_symbol_biases'], [-1, v, self.params['num_symbols']]) 476 | 477 | def construct_loss(self): 478 | v = self.placeholders['num_vertices'] 479 | h_dim = self.params['hidden_size'] 480 | kl_trade_off_lambda =self.placeholders['kl_trade_off_lambda'] 481 | # Edge loss 482 | self.ops["edge_loss"] = tf.reduce_sum(self.ops['cross_entropy_losses'] * self.placeholders['iteration_mask'], axis=1) 483 | # KL loss 484 | kl_loss = 1 + self.ops['logvariance'] - tf.square(self.ops['mean']) - tf.exp(self.ops['logvariance']) 485 | kl_loss = tf.reshape(kl_loss, [-1, v, h_dim]) * self.ops['graph_state_mask'] 486 | self.ops['kl_loss'] = -0.5 * tf.reduce_sum(kl_loss, [1,2]) 487 | # Node symbol loss 488 | self.ops['node_symbol_prob'] = tf.nn.softmax(self.ops['node_symbol_logits']) 489 | self.ops['node_symbol_loss'] = -tf.reduce_sum(tf.log(self.ops['node_symbol_prob'] + SMALL_NUMBER) * 490 | self.placeholders['node_symbols'], axis=[1,2]) 491 | # Add in the loss for calculating QED 492 | for (internal_id, task_id) in enumerate(self.params['task_ids']): 493 | with tf.variable_scope("out_layer_task%i" % task_id): 494 | with tf.variable_scope("regression_gate"): 495 | self.weights['regression_gate_task%i' % task_id] = MLP(self.params['hidden_size'], 1, [], 496 | self.placeholders['out_layer_dropout_keep_prob']) 497 | with tf.variable_scope("regression"): 498 | self.weights['regression_transform_task%i' % task_id] = MLP(self.params['hidden_size'], 1, [], 499 | self.placeholders['out_layer_dropout_keep_prob']) 500 | normalized_z_sampled=tf.nn.l2_normalize(self.ops['z_sampled'], 2) 501 | self.ops['qed_computed_values']=computed_values = self.gated_regression(normalized_z_sampled, 502 | self.weights['regression_gate_task%i' % task_id], 503 | self.weights['regression_transform_task%i' % task_id], self.params["hidden_size"], 504 | self.weights['qed_weights'], self.weights['qed_biases'], 505 | self.placeholders['num_vertices'], self.placeholders['node_mask']) 506 | diff = computed_values - self.placeholders['target_values'][internal_id,:] # [b] 507 | task_target_mask = 
self.placeholders['target_mask'][internal_id,:] 508 | task_target_num = tf.reduce_sum(task_target_mask) + SMALL_NUMBER 509 | diff = diff * task_target_mask # Mask out unused values [b] 510 | self.ops['accuracy_task%i' % task_id] = tf.reduce_sum(tf.abs(diff)) / task_target_num 511 | task_loss = tf.reduce_sum(0.5 * tf.square(diff)) / task_target_num # number 512 | # Normalise loss to account for fewer task-specific examples in batch: 513 | task_loss = task_loss * (1.0 / (self.params['task_sample_ratios'].get(task_id) or 1.0)) 514 | self.ops['qed_loss'].append(task_loss) 515 | if task_id ==0: # Assume it is the QED score 516 | z_sampled_shape=tf.shape(self.ops['z_sampled']) 517 | flattened_z_sampled=tf.reshape(self.ops['z_sampled'], [z_sampled_shape[0], -1]) 518 | self.ops['l2_loss'] = 0.01* tf.reduce_sum(flattened_z_sampled * flattened_z_sampled, axis=1) /2 519 | # Calculate the derivative with respect to QED + l2 loss 520 | self.ops['derivative_z_sampled'] = tf.gradients(self.ops['qed_computed_values'] - 521 | self.ops['l2_loss'],self.ops['z_sampled']) 522 | self.ops['total_qed_loss'] = tf.reduce_sum(self.ops['qed_loss']) # number 523 | self.ops['mean_edge_loss'] = tf.reduce_mean(self.ops["edge_loss"]) # record the mean edge loss 524 | self.ops['mean_node_symbol_loss'] = tf.reduce_mean(self.ops["node_symbol_loss"]) 525 | self.ops['mean_kl_loss'] = tf.reduce_mean(kl_trade_off_lambda *self.ops['kl_loss']) 526 | self.ops['mean_total_qed_loss'] = self.params["qed_trade_off_lambda"]*self.ops['total_qed_loss'] 527 | return tf.reduce_mean(self.ops["edge_loss"] + self.ops['node_symbol_loss'] + \ 528 | kl_trade_off_lambda *self.ops['kl_loss'])\ 529 | + self.params["qed_trade_off_lambda"]*self.ops['total_qed_loss'] 530 | 531 | def gated_regression(self, last_h, regression_gate, regression_transform, hidden_size, projection_weight, projection_bias, v, mask): 532 | # last_h: [b x v x h] 533 | last_h = tf.reshape(last_h, [-1, hidden_size]) # [b*v, h] 534 | # linear projection on last_h 535 | last_h = tf.nn.relu(tf.matmul(last_h, projection_weight)+projection_bias) # [b*v, h] 536 | # same as last_h 537 | gate_input = last_h 538 | # linear projection and combine 539 | gated_outputs = tf.nn.sigmoid(regression_gate(gate_input)) * tf.nn.tanh(regression_transform(last_h)) # [b*v, 1] 540 | gated_outputs = tf.reshape(gated_outputs, [-1, v]) # [b, v] 541 | masked_gated_outputs = gated_outputs * mask # [b x v] 542 | output = tf.reduce_sum(masked_gated_outputs, axis = 1) # [b] 543 | output=tf.sigmoid(output) 544 | return output 545 | 546 | def calculate_incremental_results(self, raw_data, bucket_sizes, file_name): 547 | incremental_results=[] 548 | # copy the raw_data if more than 1 BFS path is added 549 | new_raw_data=[] 550 | for idx, d in enumerate(raw_data): 551 | # Use canonical order or random order here. canonical order starts from index 0. 
random order starts from random nodes 552 | if not self.params["path_random_order"]: 553 | # Use several different starting index if using multi BFS path 554 | if self.params["multi_bfs_path"]: 555 | list_of_starting_idx= list(range(self.params["bfs_path_count"])) 556 | else: 557 | list_of_starting_idx=[0] # the index 0 558 | else: 559 | # get the node length for this molecule 560 | node_length=len(d["node_features"]) 561 | if self.params["multi_bfs_path"]: 562 | list_of_starting_idx= np.random.choice(node_length, self.params["bfs_path_count"], replace=True) #randomly choose several 563 | else: 564 | list_of_starting_idx= [random.choice(list(range(node_length)))] # randomly choose one 565 | for list_idx, starting_idx in enumerate(list_of_starting_idx): 566 | # choose a bucket 567 | chosen_bucket_idx = np.argmax(bucket_sizes > max([v for e in d['graph'] 568 | for v in [e[0], e[2]]])) 569 | chosen_bucket_size = bucket_sizes[chosen_bucket_idx] 570 | 571 | # Calculate incremental results without master node 572 | nodes_no_master, edges_no_master = to_graph(d['smiles'], self.params["dataset"]) 573 | incremental_adj_mat,distance_to_others,node_sequence,edge_type_masks,edge_type_labels,local_stop, edge_masks, edge_labels, overlapped_edge_features=\ 574 | construct_incremental_graph(dataset, edges_no_master, chosen_bucket_size, 575 | len(nodes_no_master), nodes_no_master, self.params, initial_idx=starting_idx) 576 | if self.params["sample_transition"] and list_idx > 0: 577 | incremental_results[-1]=[x+y for x, y in zip(incremental_results[-1], [incremental_adj_mat,distance_to_others, 578 | node_sequence,edge_type_masks,edge_type_labels,local_stop, edge_masks, edge_labels, overlapped_edge_features])] 579 | else: 580 | incremental_results.append([incremental_adj_mat, distance_to_others, node_sequence, edge_type_masks, 581 | edge_type_labels, local_stop, edge_masks, edge_labels, overlapped_edge_features]) 582 | # copy the raw_data here 583 | new_raw_data.append(d) 584 | if idx % 50 == 0: 585 | print('finish calculating %d incremental matrices' % idx, end="\r") 586 | return incremental_results, new_raw_data 587 | 588 | # ----- Data preprocessing and chunking into minibatches: 589 | def process_raw_graphs(self, raw_data, is_training_data, file_name, bucket_sizes=None): 590 | if bucket_sizes is None: 591 | bucket_sizes = dataset_info(self.params["dataset"])["bucket_sizes"] 592 | incremental_results, raw_data=self.calculate_incremental_results(raw_data, bucket_sizes, file_name) 593 | bucketed = defaultdict(list) 594 | x_dim = len(raw_data[0]["node_features"][0]) 595 | 596 | for d, (incremental_adj_mat,distance_to_others,node_sequence,edge_type_masks,edge_type_labels,local_stop, edge_masks, edge_labels, overlapped_edge_features)\ 597 | in zip(raw_data, incremental_results): 598 | # choose a bucket 599 | chosen_bucket_idx = np.argmax(bucket_sizes > max([v for e in d['graph'] 600 | for v in [e[0], e[2]]])) 601 | chosen_bucket_size = bucket_sizes[chosen_bucket_idx] 602 | # total number of nodes in this data point 603 | n_active_nodes = len(d["node_features"]) 604 | bucketed[chosen_bucket_idx].append({ 605 | 'adj_mat': graph_to_adj_mat(d['graph'], chosen_bucket_size, self.num_edge_types, self.params['tie_fwd_bkwd']), 606 | 'incre_adj_mat': incremental_adj_mat, 607 | 'distance_to_others': distance_to_others, 608 | 'overlapped_edge_features': overlapped_edge_features, 609 | 'node_sequence': node_sequence, 610 | 'edge_type_masks': edge_type_masks, 611 | 'edge_type_labels': edge_type_labels, 612 | 'edge_masks': 
edge_masks, 613 | 'edge_labels': edge_labels, 614 | 'local_stop': local_stop, 615 | 'number_iteration': len(local_stop), 616 | 'init': d["node_features"] + [[0 for _ in range(x_dim)] for __ in 617 | range(chosen_bucket_size - n_active_nodes)], 618 | 'labels': [d["targets"][task_id][0] for task_id in self.params['task_ids']], 619 | 'mask': [1. for _ in range(n_active_nodes) ] + [0. for _ in range(chosen_bucket_size - n_active_nodes)] 620 | }) 621 | 622 | if is_training_data: 623 | for (bucket_idx, bucket) in bucketed.items(): 624 | np.random.shuffle(bucket) 625 | for task_id in self.params['task_ids']: 626 | task_sample_ratio = self.params['task_sample_ratios'].get(str(task_id)) 627 | if task_sample_ratio is not None: 628 | ex_to_sample = int(len(bucket) * task_sample_ratio) 629 | for ex_id in range(ex_to_sample, len(bucket)): 630 | bucket[ex_id]['labels'][task_id] = None 631 | 632 | bucket_at_step = [[bucket_idx for _ in range(len(bucket_data) // self.params['batch_size'])] 633 | for bucket_idx, bucket_data in bucketed.items()] 634 | bucket_at_step = [x for y in bucket_at_step for x in y] 635 | 636 | return (bucketed, bucket_sizes, bucket_at_step) 637 | 638 | def pad_annotations(self, annotations): 639 | return np.pad(annotations, 640 | pad_width=[[0, 0], [0, 0], [0, self.params['hidden_size'] - self.params["num_symbols"]]], 641 | mode='constant') 642 | 643 | def make_batch(self, elements, maximum_vertice_num): 644 | # get maximum number of iterations in this batch. used to control while_loop 645 | max_iteration_num=-1 646 | for d in elements: 647 | max_iteration_num=max(d['number_iteration'], max_iteration_num) 648 | batch_data = {'adj_mat': [], 'init': [], 'labels': [], 'edge_type_masks':[], 'edge_type_labels':[], 'edge_masks':[], 649 | 'edge_labels':[],'node_mask': [], 'task_masks': [], 'node_sequence':[], 650 | 'iteration_mask': [], 'local_stop': [], 'incre_adj_mat': [], 'distance_to_others': [], 651 | 'max_iteration_num': max_iteration_num, 'overlapped_edge_features': []} 652 | for d in elements: 653 | # sparse to dense for saving memory 654 | incre_adj_mat = incre_adj_mat_to_dense(d['incre_adj_mat'], self.num_edge_types, maximum_vertice_num) 655 | distance_to_others = distance_to_others_dense(d['distance_to_others'], maximum_vertice_num) 656 | overlapped_edge_features = overlapped_edge_features_to_dense(d['overlapped_edge_features'], maximum_vertice_num) 657 | node_sequence = node_sequence_to_dense(d['node_sequence'],maximum_vertice_num) 658 | edge_type_masks = edge_type_masks_to_dense(d['edge_type_masks'], maximum_vertice_num,self.num_edge_types) 659 | edge_type_labels = edge_type_labels_to_dense(d['edge_type_labels'], maximum_vertice_num,self.num_edge_types) 660 | edge_masks = edge_masks_to_dense(d['edge_masks'], maximum_vertice_num) 661 | edge_labels = edge_labels_to_dense(d['edge_labels'], maximum_vertice_num) 662 | 663 | batch_data['adj_mat'].append(d['adj_mat']) 664 | batch_data['init'].append(d['init']) 665 | batch_data['node_mask'].append(d['mask']) 666 | 667 | batch_data['incre_adj_mat'].append(incre_adj_mat + 668 | [np.zeros((self.num_edge_types, maximum_vertice_num,maximum_vertice_num)) 669 | for _ in range(max_iteration_num-d['number_iteration'])]) 670 | batch_data['distance_to_others'].append(distance_to_others + 671 | [np.zeros((maximum_vertice_num)) 672 | for _ in range(max_iteration_num-d['number_iteration'])]) 673 | batch_data['overlapped_edge_features'].append(overlapped_edge_features + 674 | [np.zeros((maximum_vertice_num)) 675 | for _ in 
range(max_iteration_num-d['number_iteration'])]) 676 | batch_data['node_sequence'].append(node_sequence + 677 | [np.zeros((maximum_vertice_num)) 678 | for _ in range(max_iteration_num-d['number_iteration'])]) 679 | batch_data['edge_type_masks'].append(edge_type_masks + 680 | [np.zeros((self.num_edge_types, maximum_vertice_num)) 681 | for _ in range(max_iteration_num-d['number_iteration'])]) 682 | batch_data['edge_masks'].append(edge_masks + 683 | [np.zeros((maximum_vertice_num)) 684 | for _ in range(max_iteration_num-d['number_iteration'])]) 685 | batch_data['edge_type_labels'].append(edge_type_labels + 686 | [np.zeros((self.num_edge_types, maximum_vertice_num)) 687 | for _ in range(max_iteration_num-d['number_iteration'])]) 688 | batch_data['edge_labels'].append(edge_labels + 689 | [np.zeros((maximum_vertice_num)) 690 | for _ in range(max_iteration_num-d['number_iteration'])]) 691 | batch_data['iteration_mask'].append([1 for _ in range(d['number_iteration'])]+ 692 | [0 for _ in range(max_iteration_num-d['number_iteration'])]) 693 | batch_data['local_stop'].append([int(s) for s in d["local_stop"]]+ 694 | [0 for _ in range(max_iteration_num-d['number_iteration'])]) 695 | 696 | target_task_values = [] 697 | target_task_mask = [] 698 | for target_val in d['labels']: 699 | if target_val is None: # This is one of the examples we didn't sample... 700 | target_task_values.append(0.) 701 | target_task_mask.append(0.) 702 | else: 703 | target_task_values.append(target_val) 704 | target_task_mask.append(1.) 705 | batch_data['labels'].append(target_task_values) 706 | batch_data['task_masks'].append(target_task_mask) 707 | 708 | return batch_data 709 | 710 | def get_dynamic_feed_dict(self, elements, latent_node_symbol, incre_adj_mat, num_vertices, 711 | distance_to_others, overlapped_edge_dense, node_sequence, edge_type_masks, edge_masks, random_normal_states): 712 | if incre_adj_mat is None: 713 | incre_adj_mat=np.zeros((1, 1, self.num_edge_types, 1, 1)) 714 | distance_to_others=np.zeros((1,1,1)) 715 | overlapped_edge_dense=np.zeros((1,1,1)) 716 | node_sequence=np.zeros((1,1,1)) 717 | edge_type_masks=np.zeros((1,1,self.num_edge_types,1)) 718 | edge_masks=np.zeros((1,1,1)) 719 | latent_node_symbol=np.zeros((1,1,self.params["num_symbols"])) 720 | return { 721 | self.placeholders['z_prior']: random_normal_states, # [1, v, h] 722 | self.placeholders['incre_adj_mat']: incre_adj_mat, # [1, 1, e, v, v] 723 | self.placeholders['num_vertices']: num_vertices, # v 724 | 725 | self.placeholders['initial_node_representation']: \ 726 | self.pad_annotations([elements['init']]), 727 | self.placeholders['node_symbols']: [elements['init']], 728 | self.placeholders['latent_node_symbols']: self.pad_annotations(latent_node_symbol), 729 | self.placeholders['adjacency_matrix']: [elements['adj_mat']], 730 | self.placeholders['node_mask']: [elements['mask']], 731 | 732 | self.placeholders['graph_state_keep_prob']: 1, 733 | self.placeholders['edge_weight_dropout_keep_prob']: 1, 734 | self.placeholders['iteration_mask']: [[1]], 735 | self.placeholders['is_generative']: True, 736 | self.placeholders['out_layer_dropout_keep_prob'] : 1.0, 737 | self.placeholders['distance_to_others'] : distance_to_others, # [1, 1,v] 738 | self.placeholders['overlapped_edge_features']: overlapped_edge_dense, 739 | self.placeholders['max_iteration_num']: 1, 740 | self.placeholders['node_sequence']: node_sequence, #[1, 1, v] 741 | self.placeholders['edge_type_masks']: edge_type_masks, #[1, 1, e, v] 742 | self.placeholders['edge_masks']: edge_masks, 
# [1, 1, v] 743 | } 744 | 745 | def get_node_symbol(self, batch_feed_dict): 746 | fetch_list = [self.ops['node_symbol_prob']] 747 | result = self.sess.run(fetch_list, feed_dict=batch_feed_dict) 748 | return result[0] 749 | 750 | def node_symbol_one_hot(self, sampled_node_symbol, real_n_vertices, max_n_vertices): 751 | one_hot_representations=[] 752 | for idx in range(max_n_vertices): 753 | representation = [0] * self.params["num_symbols"] 754 | if idx < real_n_vertices: 755 | atom_type=sampled_node_symbol[idx] 756 | representation[atom_type]=1 757 | one_hot_representations.append(representation) 758 | return one_hot_representations 759 | 760 | def search_and_generate_molecule(self, initial_idx, valences, 761 | sampled_node_symbol, real_n_vertices, random_normal_states, 762 | elements, max_n_vertices): 763 | # New molecule 764 | new_mol = Chem.MolFromSmiles('') 765 | new_mol = Chem.rdchem.RWMol(new_mol) 766 | # Add atoms 767 | add_atoms(new_mol, sampled_node_symbol, self.params["dataset"]) 768 | # Breadth first search over the molecule 769 | queue=deque([initial_idx]) 770 | # color 0: have not found 1: in the queue 2: searched already 771 | color = [0] * max_n_vertices 772 | color[initial_idx] = 1 773 | # Empty adj list at the beginning 774 | incre_adj_list=defaultdict(list) 775 | # record the log probabilities at each step 776 | total_log_prob=0 777 | while len(queue) > 0: 778 | node_in_focus = queue.popleft() 779 | # iterate until the stop node is selected 780 | while True: 781 | # Prepare data for one iteration based on the graph state 782 | edge_type_mask_sparse, edge_mask_sparse = generate_mask(valences, incre_adj_list, color, real_n_vertices, node_in_focus, self.params["check_overlap_edge"], new_mol) 783 | edge_type_mask = edge_type_masks_to_dense([edge_type_mask_sparse], max_n_vertices, self.num_edge_types) # [1, e, v] 784 | edge_mask = edge_masks_to_dense([edge_mask_sparse],max_n_vertices) # [1, v] 785 | node_sequence = node_sequence_to_dense([node_in_focus], max_n_vertices) # [1, v] 786 | distance_to_others_sparse = bfs_distance(node_in_focus, incre_adj_list) 787 | distance_to_others = distance_to_others_dense([distance_to_others_sparse],max_n_vertices) # [1, v] 788 | overlapped_edge_sparse = get_overlapped_edge_feature(edge_mask_sparse, color, new_mol) 789 | 790 | overlapped_edge_dense = overlapped_edge_features_to_dense([overlapped_edge_sparse],max_n_vertices) # [1, v] 791 | incre_adj_mat = incre_adj_mat_to_dense([incre_adj_list], 792 | self.num_edge_types, max_n_vertices) # [1, e, v, v] 793 | sampled_node_symbol_one_hot = self.node_symbol_one_hot(sampled_node_symbol, real_n_vertices, max_n_vertices) 794 | 795 | # get feed_dict 796 | feed_dict=self.get_dynamic_feed_dict(elements, [sampled_node_symbol_one_hot], 797 | [incre_adj_mat], max_n_vertices, [distance_to_others], [overlapped_edge_dense], 798 | [node_sequence], [edge_type_mask], [edge_mask], random_normal_states) 799 | 800 | # fetch nn predictions 801 | fetch_list = [self.ops['edge_predictions'], self.ops['edge_type_predictions']] 802 | edge_probs, edge_type_probs = self.sess.run(fetch_list, feed_dict=feed_dict) 803 | # select an edge 804 | if not self.params["use_argmax_generation"]: 805 | neighbor=np.random.choice(np.arange(max_n_vertices+1), p=edge_probs[0]) 806 | else: 807 | neighbor=np.argmax(edge_probs[0]) 808 | # update log prob 809 | total_log_prob+=np.log(edge_probs[0][neighbor]+SMALL_NUMBER) 810 | # stop it if stop node is picked 811 | if neighbor == max_n_vertices: 812 | break 813 | # or choose an edge type 814 | 
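# sample the bond type for the chosen neighbour from the edge-type softmax, or take the argmax when use_argmax_generation is set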
if not self.params["use_argmax_generation"]: 815 | bond=np.random.choice(np.arange(self.num_edge_types),p=edge_type_probs[0, :, neighbor]) 816 | else: 817 | bond=np.argmax(edge_type_probs[0, :, neighbor]) 818 | # update log prob 819 | total_log_prob+=np.log(edge_type_probs[0, :, neighbor][bond]+SMALL_NUMBER) 820 | #update valences 821 | valences[node_in_focus] -= (bond+1) 822 | valences[neighbor] -= (bond+1) 823 | #add the bond 824 | new_mol.AddBond(int(node_in_focus), int(neighbor), number_to_bond[bond]) 825 | # add the edge to increment adj list 826 | incre_adj_list[node_in_focus].append((neighbor, bond)) 827 | incre_adj_list[neighbor].append((node_in_focus, bond)) 828 | # Explore neighbor nodes 829 | if color[neighbor]==0: 830 | queue.append(neighbor) 831 | color[neighbor]=1 832 | color[node_in_focus]=2 # explored 833 | # Remove unconnected node 834 | remove_extra_nodes(new_mol) 835 | new_mol=Chem.MolFromSmiles(Chem.MolToSmiles(new_mol)) 836 | return new_mol, total_log_prob 837 | 838 | def gradient_ascent(self, random_normal_states, derivative_z_sampled): 839 | return random_normal_states + self.params['prior_learning_rate'] * derivative_z_sampled 840 | 841 | # optimization in latent space. generate one molecule for each optimization step 842 | def optimization_over_prior(self, random_normal_states, num_vertices, generated_all_similes, elements, count): 843 | # record how many optimization steps are taken 844 | step=0 845 | # generate a new molecule 846 | self.generate_graph_with_state(random_normal_states, num_vertices, generated_all_similes, elements, step, count) 847 | fetch_list = [self.ops['derivative_z_sampled'], self.ops['qed_computed_values'], self.ops['l2_loss']] 848 | for _ in range(self.params['optimization_step']): 849 | # get current qed and derivative 850 | batch_feed_dict=self.get_dynamic_feed_dict(elements, None, None, num_vertices, None, 851 | None, None, None, None, 852 | random_normal_states) 853 | derivative_z_sampled, qed_computed_values, l2_loss= self.sess.run(fetch_list, feed_dict=batch_feed_dict) 854 | # update the states 855 | random_normal_states=self.gradient_ascent(random_normal_states, 856 | derivative_z_sampled[0]) 857 | # generate a new molecule 858 | step+=1 859 | self.generate_graph_with_state(random_normal_states, num_vertices, 860 | generated_all_similes, elements, step, count) 861 | return random_normal_states 862 | 863 | 864 | def generate_graph_with_state(self, random_normal_states, num_vertices, 865 | generated_all_similes, elements, step, count): 866 | # Get back node symbol predictions 867 | # Prepare dict 868 | node_symbol_batch_feed_dict=self.get_dynamic_feed_dict(elements, None, None, 869 | num_vertices, None, None, None, None, None, random_normal_states) 870 | # Get predicted node probs 871 | predicted_node_symbol_prob=self.get_node_symbol(node_symbol_batch_feed_dict) 872 | # Node numbers for each graph 873 | real_length=get_graph_length([elements['mask']])[0] # [valid_node_number] 874 | # Sample node symbols 875 | sampled_node_symbol=sample_node_symbol(predicted_node_symbol_prob, [real_length], self.params["dataset"])[0] # [v] 876 | # Maximum valences for each node 877 | valences=get_initial_valence(sampled_node_symbol, self.params["dataset"]) # [v] 878 | # randomly pick the starting point or use zero 879 | if not self.params["path_random_order"]: 880 | # Try different starting points 881 | if self.params["try_different_starting"]: 882 | #starting_point=list(range(self.params["num_different_starting"])) 883 | 
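# sample up to num_different_starting distinct starting atoms from the real (non-padded) nodes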
starting_point=random.sample(range(real_length), 884 | min(self.params["num_different_starting"], real_length)) 885 | else: 886 | starting_point=[0] 887 | else: 888 | if self.params["try_different_starting"]: 889 | starting_point=random.sample(range(real_length), 890 | min(self.params["num_different_starting"], real_length)) 891 | else: 892 | starting_point=[random.choice(list(range(real_length)))] # randomly choose one 893 | # record all molecules from different starting points 894 | all_mol=[] 895 | for idx in starting_point: 896 | # generate a new molecule 897 | new_mol, total_log_prob=self.search_and_generate_molecule(idx, np.copy(valences), 898 | sampled_node_symbol, real_length, 899 | random_normal_states, elements, num_vertices) 900 | # record the molecule with largest number of shapes 901 | if dataset=='qm9' and new_mol is not None: 902 | all_mol.append((np.sum(shape_count(self.params["dataset"], True, 903 | [Chem.MolToSmiles(new_mol)])[1]), total_log_prob, new_mol)) 904 | # record the molecule with largest number of pentagon and hexagonal for zinc and cep 905 | elif dataset=='zinc' and new_mol is not None: 906 | counts=shape_count(self.params["dataset"], True,[Chem.MolToSmiles(new_mol)]) 907 | all_mol.append((0.5 * counts[1][2]+ counts[1][3], total_log_prob, new_mol)) 908 | elif dataset=='cep' and new_mol is not None: 909 | all_mol.append((np.sum(shape_count(self.params["dataset"], True, 910 | [Chem.MolToSmiles(new_mol)])[1][2:]), total_log_prob, new_mol)) 911 | # select one out 912 | best_mol = select_best(all_mol) 913 | # nothing generated 914 | if best_mol is None: 915 | return 916 | # visualize it 917 | make_dir('visualization_%s' % dataset) 918 | visualize_mol('visualization_%s/%d_%d.png' % (dataset, count, step), best_mol) 919 | # record the best molecule 920 | generated_all_similes.append(Chem.MolToSmiles(best_mol)) 921 | dump('generated_smiles_%s' % (dataset), generated_all_similes) 922 | print("Real QED value") 923 | print(QED.qed(best_mol)) 924 | if len(generated_all_similes) >= self.params['number_of_generation']: 925 | print("generation done") 926 | exit(0) 927 | 928 | def compensate_node_length(self, elements, bucket_size): 929 | maximum_length=bucket_size+self.params["compensate_num"] 930 | real_length=get_graph_length([elements['mask']])[0]+self.params["compensate_num"] 931 | elements['mask']=[1]*real_length + [0]*(maximum_length-real_length) 932 | elements['init']=np.zeros((maximum_length, self.params["num_symbols"])) 933 | elements['adj_mat']=np.zeros((self.num_edge_types, maximum_length, maximum_length)) 934 | return maximum_length 935 | 936 | def generate_new_graphs(self, data): 937 | # bucketed: data organized by bucket 938 | (bucketed, bucket_sizes, bucket_at_step) = data 939 | bucket_counters = defaultdict(int) 940 | # all generated similes 941 | generated_all_similes=[] 942 | # counter 943 | count = 0 944 | # shuffle the lengths 945 | np.random.shuffle(bucket_at_step) 946 | for step in range(len(bucket_at_step)): 947 | bucket = bucket_at_step[step] # bucket number 948 | # data index 949 | start_idx = bucket_counters[bucket] * self.params['batch_size'] 950 | end_idx = (bucket_counters[bucket] + 1) * self.params['batch_size'] 951 | # batch data 952 | elements_batch = bucketed[bucket][start_idx:end_idx] 953 | for elements in elements_batch: 954 | # compensate for the length during generation 955 | # (this is a result that BFS may not make use of all candidate nodes during generation) 956 | maximum_length=self.compensate_node_length(elements, 
bucket_sizes[bucket]) 957 | # initial state 958 | random_normal_states=generate_std_normal(1, maximum_length,\ 959 | self.params['hidden_size']) # [1, v, h] 960 | random_normal_states = self.optimization_over_prior(random_normal_states, 961 | maximum_length, generated_all_similes,elements, count) 962 | count+=1 963 | bucket_counters[bucket] += 1 964 | 965 | def make_minibatch_iterator(self, data, is_training: bool): 966 | (bucketed, bucket_sizes, bucket_at_step) = data 967 | if is_training: 968 | np.random.shuffle(bucket_at_step) 969 | for _, bucketed_data in bucketed.items(): 970 | np.random.shuffle(bucketed_data) 971 | bucket_counters = defaultdict(int) 972 | dropout_keep_prob = self.params['graph_state_dropout_keep_prob'] if is_training else 1. 973 | edge_dropout_keep_prob = self.params['edge_weight_dropout_keep_prob'] if is_training else 1. 974 | for step in range(len(bucket_at_step)): 975 | bucket = bucket_at_step[step] 976 | start_idx = bucket_counters[bucket] * self.params['batch_size'] 977 | end_idx = (bucket_counters[bucket] + 1) * self.params['batch_size'] 978 | elements = bucketed[bucket][start_idx:end_idx] 979 | batch_data = self.make_batch(elements, bucket_sizes[bucket]) 980 | 981 | num_graphs = len(batch_data['init']) 982 | initial_representations = batch_data['init'] 983 | initial_representations = self.pad_annotations(initial_representations) 984 | batch_feed_dict = { 985 | self.placeholders['initial_node_representation']: initial_representations, 986 | self.placeholders['node_symbols']: batch_data['init'], 987 | self.placeholders['latent_node_symbols']: initial_representations, 988 | self.placeholders['target_values']: np.transpose(batch_data['labels'], axes=[1,0]), 989 | self.placeholders['target_mask']: np.transpose(batch_data['task_masks'], axes=[1, 0]), 990 | self.placeholders['num_graphs']: num_graphs, 991 | self.placeholders['num_vertices']: bucket_sizes[bucket], 992 | self.placeholders['adjacency_matrix']: batch_data['adj_mat'], 993 | self.placeholders['node_mask']: batch_data['node_mask'], 994 | self.placeholders['graph_state_keep_prob']: dropout_keep_prob, 995 | self.placeholders['edge_weight_dropout_keep_prob']: edge_dropout_keep_prob, 996 | self.placeholders['iteration_mask']: batch_data['iteration_mask'], 997 | self.placeholders['incre_adj_mat']: batch_data['incre_adj_mat'], 998 | self.placeholders['distance_to_others']: batch_data['distance_to_others'], 999 | self.placeholders['node_sequence']: batch_data['node_sequence'], 1000 | self.placeholders['edge_type_masks']: batch_data['edge_type_masks'], 1001 | self.placeholders['edge_type_labels']: batch_data['edge_type_labels'], 1002 | self.placeholders['edge_masks']: batch_data['edge_masks'], 1003 | self.placeholders['edge_labels']: batch_data['edge_labels'], 1004 | self.placeholders['local_stop']: batch_data['local_stop'], 1005 | self.placeholders['max_iteration_num']: batch_data['max_iteration_num'], 1006 | self.placeholders['kl_trade_off_lambda']: self.params['kl_trade_off_lambda'], 1007 | self.placeholders['overlapped_edge_features']: batch_data['overlapped_edge_features'] 1008 | } 1009 | bucket_counters[bucket] += 1 1010 | yield batch_feed_dict 1011 | 1012 | if __name__ == "__main__": 1013 | args = docopt(__doc__) 1014 | dataset=args.get('--dataset') 1015 | try: 1016 | model = DenseGGNNChemModel(args) 1017 | evaluation = False 1018 | if evaluation: 1019 | model.example_evaluation() 1020 | else: 1021 | model.train() 1022 | except: 1023 | typ, value, tb = sys.exc_info() 1024 | traceback.print_exc() 1025 | 
pdb.post_mortem(tb) -------------------------------------------------------------------------------- /GGNN_core.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env/python 2 | 3 | from typing import Tuple, List, Any, Sequence 4 | import tensorflow as tf 5 | import time 6 | import os 7 | import json 8 | import numpy as np 9 | import pickle 10 | import random 11 | import utils 12 | from utils import MLP, dataset_info, ThreadedIterator, graph_to_adj_mat, SMALL_NUMBER, LARGE_NUMBER, graph_to_adj_mat 13 | 14 | class ChemModel(object): 15 | @classmethod 16 | def default_params(cls): 17 | return { 18 | 19 | } 20 | 21 | def __init__(self, args): 22 | self.args = args 23 | 24 | # Collect argument things: 25 | data_dir = '' 26 | if '--data_dir' in args and args['--data_dir'] is not None: 27 | data_dir = args['--data_dir'] 28 | self.data_dir = data_dir 29 | 30 | # Collect parameters: 31 | params = self.default_params() 32 | config_file = args.get('--config-file') 33 | if config_file is not None: 34 | with open(config_file, 'r') as f: 35 | params.update(json.load(f)) 36 | config = args.get('--config') 37 | if config is not None: 38 | params.update(json.loads(config)) 39 | self.params = params 40 | 41 | # Get which dataset in use 42 | self.params['dataset']=dataset=args.get('--dataset') 43 | # Number of atom types of this dataset 44 | self.params['num_symbols']=len(dataset_info(dataset)["atom_types"]) 45 | 46 | self.run_id = "_".join([time.strftime("%Y-%m-%d-%H-%M-%S"), str(os.getpid())]) 47 | log_dir = args.get('--log_dir') or '.' 48 | self.log_file = os.path.join(log_dir, "%s_log_%s.json" % (self.run_id, dataset)) 49 | self.best_model_file = os.path.join(log_dir, "%s_model.pickle" % self.run_id) 50 | 51 | with open(os.path.join(log_dir, "%s_params_%s.json" % (self.run_id,dataset)), "w") as f: 52 | json.dump(params, f) 53 | print("Run %s starting with following parameters:\n%s" % (self.run_id, json.dumps(self.params))) 54 | random.seed(params['random_seed']) 55 | np.random.seed(params['random_seed']) 56 | 57 | # Load data: 58 | self.max_num_vertices = 0 59 | self.num_edge_types = 0 60 | self.annotation_size = 0 61 | self.train_data = self.load_data(params['train_file'], is_training_data=True) 62 | self.valid_data = self.load_data(params['valid_file'], is_training_data=False) 63 | 64 | # Build the actual model 65 | config = tf.ConfigProto() 66 | config.gpu_options.allow_growth = True 67 | self.graph = tf.Graph() 68 | self.sess = tf.Session(graph=self.graph, config=config) 69 | with self.graph.as_default(): 70 | tf.set_random_seed(params['random_seed']) 71 | self.placeholders = {} 72 | self.weights = {} 73 | self.ops = {} 74 | self.make_model() 75 | self.make_train_step() 76 | 77 | # Restore/initialize variables: 78 | restore_file = args.get('--restore') 79 | if restore_file is not None: 80 | self.restore_model(restore_file) 81 | else: 82 | self.initialize_model() 83 | 84 | def load_data(self, file_name, is_training_data: bool): 85 | full_path = os.path.join(self.data_dir, file_name) 86 | 87 | print("Loading data from %s" % full_path) 88 | with open(full_path, 'r') as f: 89 | data = json.load(f) 90 | 91 | restrict = self.args.get("--restrict_data") 92 | if restrict is not None and restrict > 0: 93 | data = data[:restrict] 94 | 95 | # Get some common data out: 96 | num_fwd_edge_types = len(utils.bond_dict) - 1 97 | for g in data: 98 | self.max_num_vertices = max(self.max_num_vertices, max([v for e in g['graph'] for v in [e[0], e[2]]])) 99 | 100 | 
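        # bond_dict in utils.py maps SINGLE/DOUBLE/TRIPLE/AROMATIC to 0..3, and
        # aromatic bonds are removed by kekulization in to_graph(), so
        # num_fwd_edge_types is 3. If 'tie_fwd_bkwd' is False, forward and
        # backward directions count as distinct edge types, doubling the total.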
self.num_edge_types = max(self.num_edge_types, num_fwd_edge_types * (1 if self.params['tie_fwd_bkwd'] else 2)) 101 | self.annotation_size = max(self.annotation_size, len(data[0]["node_features"][0])) 102 | 103 | return self.process_raw_graphs(data, is_training_data, file_name) 104 | 105 | @staticmethod 106 | def graph_string_to_array(graph_string: str) -> List[List[int]]: 107 | return [[int(v) for v in s.split(' ')] 108 | for s in graph_string.split('\n')] 109 | 110 | def process_raw_graphs(self, raw_data, is_training_data, file_name, bucket_sizes=None): 111 | raise Exception("Models have to implement process_raw_graphs!") 112 | 113 | def make_model(self): 114 | self.placeholders['target_values'] = tf.placeholder(tf.float32, [len(self.params['task_ids']), None], 115 | name='target_values') 116 | self.placeholders['target_mask'] = tf.placeholder(tf.float32, [len(self.params['task_ids']), None], 117 | name='target_mask') 118 | self.placeholders['num_graphs'] = tf.placeholder(tf.int64, [], name='num_graphs') 119 | self.placeholders['out_layer_dropout_keep_prob'] = tf.placeholder(tf.float32, [], name='out_layer_dropout_keep_prob') 120 | # whether this session is for generating new graphs or not 121 | self.placeholders['is_generative'] = tf.placeholder(tf.bool, [], name='is_generative') 122 | 123 | with tf.variable_scope("graph_model"): 124 | self.prepare_specific_graph_model() 125 | 126 | # Initial state: embedding 127 | initial_state= self.get_node_embedding_state(self.placeholders['initial_node_representation']) 128 | 129 | # This does the actual graph work: 130 | if self.params['use_graph']: 131 | if self.params["residual_connection_on"]: 132 | self.ops['final_node_representations'] = self.compute_final_node_representations_with_residual( 133 | initial_state, 134 | tf.transpose(self.placeholders['adjacency_matrix'], [1, 0, 2, 3]), 135 | "_encoder") 136 | else: 137 | self.ops['final_node_representations'] = self.compute_final_node_representations_without_residual( 138 | initial_state, 139 | tf.transpose(self.placeholders['adjacency_matrix'], [1, 0, 2, 3]), self.weights['edge_weights_encoder'], 140 | self.weights['edge_biases_encoder'], self.weights['node_gru_encoder'], "gru_scope_encoder") 141 | else: 142 | self.ops['final_node_representations'] = initial_state 143 | 144 | # Calculate p(z|x)'s mean and log variance 145 | self.ops['mean'], self.ops['logvariance'] = self.compute_mean_and_logvariance() 146 | # Sample from a gaussian distribution according to the mean and log variance 147 | self.ops['z_sampled'] = self.sample_with_mean_and_logvariance() 148 | # Construct logit matrices for both edges and edge types 149 | self.construct_logit_matrices() 150 | 151 | # Obtain losses for edges and edge types 152 | self.ops['qed_loss'] = [] 153 | self.ops['loss']=self.construct_loss() 154 | 155 | def make_train_step(self): 156 | trainable_vars = self.sess.graph.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES) 157 | if self.args.get('--freeze-graph-model'): 158 | graph_vars = set(self.sess.graph.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="graph_model")) 159 | filtered_vars = [] 160 | for var in trainable_vars: 161 | if var not in graph_vars: 162 | filtered_vars.append(var) 163 | else: 164 | print("Freezing weights of variable %s." 
% var.name) 165 | trainable_vars = filtered_vars 166 | 167 | optimizer = tf.train.AdamOptimizer(self.params['learning_rate']) 168 | grads_and_vars = optimizer.compute_gradients(self.ops['loss'], var_list=trainable_vars) 169 | clipped_grads = [] 170 | for grad, var in grads_and_vars: 171 | if grad is not None: 172 | clipped_grads.append((tf.clip_by_norm(grad, self.params['clamp_gradient_norm']), var)) 173 | else: 174 | clipped_grads.append((grad, var)) 175 | grads_for_display=[] 176 | for grad, var in grads_and_vars: 177 | if grad is not None: 178 | grads_for_display.append((tf.clip_by_norm(grad, self.params['clamp_gradient_norm']), var)) 179 | self.ops['grads']= grads_for_display 180 | self.ops['train_step'] = optimizer.apply_gradients(clipped_grads) 181 | # Initialize newly-introduced variables: 182 | self.sess.run(tf.local_variables_initializer()) 183 | 184 | def gated_regression(self, last_h, regression_gate, regression_transform): 185 | raise Exception("Models have to implement gated_regression!") 186 | 187 | def prepare_specific_graph_model(self) -> None: 188 | raise Exception("Models have to implement prepare_specific_graph_model!") 189 | 190 | def compute_mean_and_logvariance(self): 191 | raise Exception("Models have to implement compute_mean_and_logvariance!") 192 | 193 | def sample_with_mean_and_logvariance(self): 194 | raise Exception("Models have to implement sample_with_mean_and_logvariance!") 195 | 196 | def construct_logit_matrices(self): 197 | raise Exception("Models have to implement construct_logit_matrices!") 198 | 199 | def construct_loss(self): 200 | raise Exception("Models have to implement construct_loss!") 201 | 202 | def make_minibatch_iterator(self, data: Any, is_training: bool): 203 | raise Exception("Models have to implement make_minibatch_iterator!") 204 | """ 205 | def save_intermediate_results(self, adjacency_matrix, edge_type_prob, edge_type_label, node_symbol_prob, node_symbol, edge_prob, edge_prob_label, qed_prediction, qed_labels, mean, logvariance): 206 | with open('intermediate_results_%s' % self.params["dataset"], 'wb') as out_file: 207 | pickle.dump([adjacency_matrix, edge_type_prob, edge_type_label, node_symbol_prob, node_symbol, edge_prob, edge_prob_label, qed_prediction, qed_labels, mean, logvariance], out_file, pickle.HIGHEST_PROTOCOL) 208 | """ 209 | 210 | def save_probs(self, all_results): 211 | with open('epoch_prob_matices_%s' % self.params["dataset"], 'wb') as out_file: 212 | pickle.dump([all_results], out_file, pickle.HIGHEST_PROTOCOL) 213 | 214 | def run_epoch(self, epoch_name: str, epoch_num, data, is_training: bool): 215 | loss = 0 216 | start_time = time.time() 217 | processed_graphs = 0 218 | batch_iterator = ThreadedIterator(self.make_minibatch_iterator(data, is_training), max_queue_size=5) 219 | 220 | for step, batch_data in enumerate(batch_iterator): 221 | num_graphs = batch_data[self.placeholders['num_graphs']] 222 | processed_graphs += num_graphs 223 | batch_data[self.placeholders['is_generative']] = False 224 | # Randomly sample from normal distribution 225 | batch_data[self.placeholders['z_prior']] = utils.generate_std_normal(\ 226 | self.params['batch_size'], batch_data[self.placeholders['num_vertices']],self.params['hidden_size']) 227 | if is_training: 228 | batch_data[self.placeholders['out_layer_dropout_keep_prob']] = self.params['out_layer_dropout_keep_prob'] 229 | fetch_list = [self.ops['loss'], self.ops['train_step'], 230 | self.ops["edge_loss"], self.ops['kl_loss'], 231 | self.ops['node_symbol_prob'], 
self.placeholders['node_symbols'], 232 | self.ops['qed_computed_values'], self.placeholders['target_values'], self.ops['total_qed_loss'], 233 | self.ops['mean'], self.ops['logvariance'], 234 | self.ops['grads'], self.ops['mean_edge_loss'], self.ops['mean_node_symbol_loss'], 235 | self.ops['mean_kl_loss'], self.ops['mean_total_qed_loss']] 236 | else: 237 | batch_data[self.placeholders['out_layer_dropout_keep_prob']] = 1.0 238 | fetch_list = [self.ops['mean_edge_loss'], self.ops['accuracy_task0']] 239 | result = self.sess.run(fetch_list, feed_dict=batch_data) 240 | 241 | """try: 242 | if is_training: 243 | self.save_intermediate_results(batch_data[self.placeholders['adjacency_matrix']], 244 | result[11], result[12], result[4], result[5], result[9], result[10], result[6], result[7], result[13], result[14]) 245 | except IndexError: 246 | pass""" 247 | 248 | batch_loss = result[0] 249 | loss += batch_loss * num_graphs 250 | 251 | print("Running %s, batch %i (has %i graphs). Loss so far: %.4f" % (epoch_name, 252 | step, 253 | num_graphs, 254 | loss / processed_graphs), end='\r') 255 | loss = loss / processed_graphs 256 | instance_per_sec = processed_graphs / (time.time() - start_time) 257 | return loss, instance_per_sec 258 | 259 | def generate_new_graphs(self, data): 260 | raise Exception("Models have to implement generate_new_graphs!") 261 | 262 | def train(self): 263 | log_to_save = [] 264 | total_time_start = time.time() 265 | with self.graph.as_default(): 266 | for epoch in range(1, self.params['num_epochs'] + 1): 267 | if not self.params['generation']: 268 | print("== Epoch %i" % epoch) 269 | 270 | train_loss, train_speed = self.run_epoch("epoch %i (training)" % epoch, epoch, 271 | self.train_data, True) 272 | print("\r\x1b[K Train: loss: %.5f| instances/sec: %.2f" % (train_loss, train_speed)) 273 | 274 | valid_loss,valid_speed = self.run_epoch("epoch %i (validation)" % epoch, epoch, 275 | self.valid_data, False) 276 | 277 | print("\r\x1b[K Valid: loss: %.5f | instances/sec: %.2f" % (valid_loss,valid_speed)) 278 | 279 | 280 | epoch_time = time.time() - total_time_start 281 | 282 | log_entry = { 283 | 'epoch': epoch, 284 | 'time': epoch_time, 285 | 'train_results': (train_loss, train_speed), 286 | } 287 | log_to_save.append(log_entry) 288 | with open(self.log_file, 'w') as f: 289 | json.dump(log_to_save, f, indent=4) 290 | self.save_model(str(epoch)+("_%s.pickle" % (self.params["dataset"]))) 291 | # Run epoches for graph generation 292 | if epoch >= self.params['epoch_to_generate']: 293 | self.generate_new_graphs(self.train_data) 294 | 295 | def save_model(self, path: str) -> None: 296 | weights_to_save = {} 297 | for variable in self.sess.graph.get_collection(tf.GraphKeys.GLOBAL_VARIABLES): 298 | assert variable.name not in weights_to_save 299 | weights_to_save[variable.name] = self.sess.run(variable) 300 | 301 | data_to_save = { 302 | "params": self.params, 303 | "weights": weights_to_save 304 | } 305 | 306 | with open(path, 'wb') as out_file: 307 | pickle.dump(data_to_save, out_file, pickle.HIGHEST_PROTOCOL) 308 | 309 | def initialize_model(self) -> None: 310 | init_op = tf.group(tf.global_variables_initializer(), 311 | tf.local_variables_initializer()) 312 | self.sess.run(init_op) 313 | 314 | def restore_model(self, path: str) -> None: 315 | print("Restoring weights from file %s." 
% path) 316 | with open(path, 'rb') as in_file: 317 | data_to_load = pickle.load(in_file) 318 | 319 | variables_to_initialize = [] 320 | with tf.name_scope("restore"): 321 | restore_ops = [] 322 | used_vars = set() 323 | for variable in self.sess.graph.get_collection(tf.GraphKeys.GLOBAL_VARIABLES): 324 | used_vars.add(variable.name) 325 | if variable.name in data_to_load['weights']: 326 | restore_ops.append(variable.assign(data_to_load['weights'][variable.name])) 327 | else: 328 | print('Freshly initializing %s since no saved value was found.' % variable.name) 329 | variables_to_initialize.append(variable) 330 | for var_name in data_to_load['weights']: 331 | if var_name not in used_vars: 332 | print('Saved weights for %s not used by model.' % var_name) 333 | restore_ops.append(tf.variables_initializer(variables_to_initialize)) 334 | self.sess.run(restore_ops) 335 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Microsoft Corporation. All rights reserved. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Constrained Graph Variational Autoencoders for Molecule Design 2 | 3 | This repository contains our implementation of [Constrained Graph Variational Autoencoders for Molecule Design](https://arxiv.org/abs/1805.09076) (CGVAE). 4 | 5 | ``` 6 | @article{liu2018constrained, 7 | title={Constrained Graph Variational Autoencoders for Molecule Design}, 8 | author={Liu, Qi and Allamanis, Miltiadis and Brockschmidt, Marc and Gaunt, Alexander L.}, 9 | journal={The Thirty-second Conference on Neural Information Processing Systems}, 10 | year={2018} 11 | } 12 | ``` 13 | 14 | # Requirements 15 | 16 | This code was tested in Python 3.5 with Tensorflow 1.3. `conda`, `docopt` and `rdkit` are also necessary. A Bash script is provided to install all these requirements. 17 | 18 | ``` 19 | source ./install.sh 20 | ``` 21 | 22 | To evaluate SAS scores, use `get_sascorer.sh` to download the SAS implementation from [rdkit](https://github.com/rdkit/rdkit/tree/master/Contrib/SA_Score) 23 | 24 | # Data Extraction 25 | 26 | Three datasets (QM9, ZINC and CEPDB) are in use. 
For downloading CEPDB, please refer to [CEPDB](http://cleanenergy.molecularspace.org/). 27 | 28 | For downloading QM9 and ZINC, please go to `data` directory and run `get_qm9.py` and `get_zinc.py`, respectively. 29 | 30 | ``` 31 | python get_qm9.py 32 | 33 | python get_zinc.py 34 | ``` 35 | 36 | # Running CGVAE 37 | 38 | We provide two settings of CGVAE. The first setting samples one breadth first search path for each molecule. The second setting samples transitions from multiple breadth first search paths for each molecule. 39 | 40 | To train and generate molecules using the first setting, use 41 | 42 | ``` 43 | python CGVAE.py --dataset qm9|zinc|cep 44 | ``` 45 | 46 | To avoid training and generate molecules with a pretrained model, use 47 | 48 | ``` 49 | python CGVAE.py --dataset qm9|zinc|cep --restore pretrained_model --config '{"generation": true}' 50 | ``` 51 | 52 | To train and generate molecules using the second setting, use 53 | 54 | ``` 55 | python CGVAE.py --dataset qm9|zinc|cep --config '{"sample_transition": true, "multi_bfs_path": true, "path_random_order": true}' 56 | ``` 57 | 58 | To use optimization in the latent space, set `optimization_step` to a positive number 59 | 60 | ``` 61 | python CGVAE.py --dataset qm9|zinc|cep --restore pretrained_model --config '{"generation": true, "optimization_step": 50}' 62 | ``` 63 | 64 | More configurations can be found at function `default_params` in `CGVAE.py` 65 | 66 | # Evaluation 67 | 68 | To evaluate the generated molecules, use 69 | 70 | ``` 71 | python evaluate.py --dataset qm9|zinc|cep 72 | ``` 73 | 74 | # Pretrained Models and Generated Molecules 75 | 98 | 99 | Generated molecules can be obtained upon request. 100 | 101 | A program in folder `molecules` is provided to read and visualize the molecules 102 | 103 | ``` 104 | python visualize.py molecule_file output_file 105 | ``` 106 | 107 | # Questions/Bugs 108 | 109 | Please submit a Github issue or contact [qiliu@u.nus.edu](mailto:qiliu@u.nus.edu). 110 | 111 | # Contributing 112 | 113 | This project welcomes contributions and suggestions. Most contributions require you to agree to a 114 | Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us 115 | the rights to use your contribution. For details, visit https://cla.microsoft.com. 116 | 117 | When you submit a pull request, a CLA-bot will automatically determine whether you need to provide 118 | a CLA and decorate the PR appropriately (e.g., label, comment). Simply follow the instructions 119 | provided by the bot. You will only need to do this once across all repos using our CLA. 120 | 121 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 122 | For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or 123 | contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments. 
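
# Configuration Example

Hyperparameters listed in `default_params` can be combined and overridden in a single JSON dictionary passed via `--config`. A minimal illustrative example (the values here are placeholders, not tuned recommendations):

```
python CGVAE.py --dataset qm9 --config '{"num_epochs": 10, "epoch_to_generate": 10, "kl_trade_off_lambda": 0.3, "try_different_starting": true, "num_different_starting": 6}'
```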
124 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Security 4 | 5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/). 6 | 7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/opensource/security/definition), please report it to us as described below. 8 | 9 | ## Reporting Security Issues 10 | 11 | **Please do not report security vulnerabilities through public GitHub issues.** 12 | 13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/opensource/security/create-report). 14 | 15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/opensource/security/pgpkey). 16 | 17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://aka.ms/opensource/security/msrc). 18 | 19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue: 20 | 21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.) 22 | * Full paths of source file(s) related to the manifestation of the issue 23 | * The location of the affected source code (tag/branch/commit or direct URL) 24 | * Any special configuration required to reproduce the issue 25 | * Step-by-step instructions to reproduce the issue 26 | * Proof-of-concept or exploit code (if possible) 27 | * Impact of the issue, including how an attacker might exploit the issue 28 | 29 | This information will help us triage your report more quickly. 30 | 31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/opensource/security/bounty) page for more details about our active programs. 32 | 33 | ## Preferred Languages 34 | 35 | We prefer all communications to be in English. 36 | 37 | ## Policy 38 | 39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/opensource/security/cvd). 40 | 41 | 42 | -------------------------------------------------------------------------------- /data/get_qm9.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env/python 2 | """ 3 | Usage: 4 | get_qm9.py 5 | 6 | Options: 7 | -h --help Show this screen. 
8 | """ 9 | 10 | import sys, os 11 | sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), '..')) 12 | from rdkit import Chem 13 | from rdkit.Chem import rdmolops 14 | from rdkit.Chem import QED 15 | import glob 16 | import json 17 | import numpy as np 18 | from utils import bond_dict, dataset_info, need_kekulize, to_graph, graph_to_adj_mat 19 | import utils 20 | import pickle 21 | import random 22 | from docopt import docopt 23 | 24 | dataset = 'qm9' 25 | 26 | def get_validation_file_names(unzip_path): 27 | print('loading train/validation split') 28 | with open('valid_idx_qm9.json', 'r') as f: 29 | valid_idx = json.load(f)['valid_idxs'] 30 | valid_files = [os.path.join(unzip_path, 'dsgdb9nsd_%s.xyz' % i) for i in valid_idx] 31 | return valid_files 32 | 33 | def read_xyz(file_path): 34 | with open(file_path, 'r') as f: 35 | lines = f.readlines() 36 | smiles = lines[-2].split('\t')[0] 37 | mu = QED.qed(Chem.MolFromSmiles(smiles)) 38 | return {'smiles': smiles, 'QED': mu} 39 | 40 | def train_valid_split(unzip_path): 41 | print('reading data...') 42 | raw_data = {'train': [], 'valid': []} # save the train, valid dataset. 43 | all_files = glob.glob(os.path.join(unzip_path, '*.xyz')) 44 | valid_files = get_validation_file_names(unzip_path) 45 | 46 | file_count = 0 47 | for file_idx, file_path in enumerate(all_files): 48 | if file_path not in valid_files: 49 | raw_data['train'].append(read_xyz(file_path)) 50 | else: 51 | raw_data['valid'].append(read_xyz(file_path)) 52 | file_count += 1 53 | if file_count % 2000 == 0: 54 | print('finished reading: %d' % file_count, end='\r') 55 | return raw_data 56 | 57 | def preprocess(raw_data, dataset): 58 | print('parsing smiles as graphs...') 59 | processed_data = {'train': [], 'valid': []} 60 | 61 | file_count = 0 62 | for section in ['train', 'valid']: 63 | all_smiles = [] # record all smiles in training dataset 64 | for i,(smiles, QED) in enumerate([(mol['smiles'], mol['QED']) 65 | for mol in raw_data[section]]): 66 | nodes, edges = to_graph(smiles, dataset) 67 | if len(edges) <= 0: 68 | continue 69 | processed_data[section].append({ 70 | 'targets': [[(QED)]], 71 | 'graph': edges, 72 | 'node_features': nodes, 73 | 'smiles': smiles 74 | }) 75 | all_smiles.append(smiles) 76 | if file_count % 2000 == 0: 77 | print('finished processing: %d' % file_count, end='\r') 78 | file_count += 1 79 | print('%s: 100 %% ' % (section)) 80 | # save the dataset 81 | with open('molecules_%s_%s.json' % (section, dataset), 'w') as f: 82 | json.dump(processed_data[section], f) 83 | # save all molecules in the training dataset 84 | if section == 'train': 85 | utils.dump('smiles_%s.pkl' % dataset, all_smiles) 86 | 87 | if __name__ == "__main__": 88 | # download 89 | download_path = 'dsgdb9nsd.xyz.tar.bz2' 90 | if not os.path.exists(download_path): 91 | print('downloading data to %s ...' % download_path) 92 | source = 'https://ndownloader.figshare.com/files/3195389' 93 | os.system('wget -O %s %s' % (download_path, source)) 94 | print('finished downloading') 95 | 96 | # unzip 97 | unzip_path = 'qm9_raw' 98 | if not os.path.exists(unzip_path): 99 | print('extracting data to %s ...' 
% unzip_path) 100 | os.mkdir(unzip_path) 101 | os.system('tar xvjf %s -C %s' % (download_path, unzip_path)) 102 | print('finished extracting') 103 | 104 | raw_data = train_valid_split(unzip_path) 105 | preprocess(raw_data, dataset) 106 | -------------------------------------------------------------------------------- /data/get_zinc.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env/python 2 | """ 3 | Usage: 4 | get_data.py --dataset zinc|qm9|cep 5 | 6 | Options: 7 | -h --help Show this screen. 8 | --dataset NAME Dataset name: zinc, qm9, cep 9 | """ 10 | 11 | import sys, os 12 | sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), '..')) 13 | from rdkit import Chem 14 | from rdkit.Chem import rdmolops 15 | from rdkit.Chem import QED 16 | import glob 17 | import csv, json 18 | import numpy as np 19 | from utils import bond_dict, dataset_info, need_kekulize, to_graph,graph_to_adj_mat 20 | import utils 21 | import pickle 22 | import random 23 | from docopt import docopt 24 | from get_qm9 import preprocess 25 | 26 | dataset = "zinc" 27 | 28 | def train_valid_split(download_path): 29 | # load validation dataset 30 | with open("valid_idx_zinc.json", 'r') as f: 31 | valid_idx = json.load(f) 32 | 33 | print('reading data...') 34 | raw_data = {'train': [], 'valid': []} # save the train, valid dataset. 35 | with open(download_path, 'r') as f: 36 | all_data = list(csv.DictReader(f)) 37 | 38 | file_count=0 39 | for i, data_item in enumerate(all_data): 40 | smiles = data_item['smiles'].strip() 41 | QED = float(data_item['qed']) 42 | if i not in valid_idx: 43 | raw_data['train'].append({'smiles': smiles, 'QED': QED}) 44 | else: 45 | raw_data['valid'].append({'smiles': smiles, 'QED': QED}) 46 | file_count += 1 47 | if file_count % 2000 ==0: 48 | print('finished reading: %d' % file_count, end='\r') 49 | return raw_data 50 | 51 | if __name__ == "__main__": 52 | download_path = '250k_rndm_zinc_drugs_clean_3.csv' 53 | if not os.path.exists(download_path): 54 | print('downloading data to %s ...' 
% download_path) 55 | source = 'https://raw.githubusercontent.com/aspuru-guzik-group/chemical_vae/master/models/zinc_properties/250k_rndm_zinc_drugs_clean_3.csv' 56 | os.system('wget -O %s %s' % (download_path, source)) 57 | print('finished downloading') 58 | 59 | raw_data = train_valid_split(download_path) 60 | preprocess(raw_data, dataset) 61 | -------------------------------------------------------------------------------- /data_augmentation.py: -------------------------------------------------------------------------------- 1 | from utils import * 2 | from copy import deepcopy 3 | import random 4 | 5 | # Generate the mask based on the valences and adjacent matrix so far 6 | # For a (node_in_focus, neighbor, edge_type) to be valid, neighbor's color < 2 and 7 | # there is no edge so far between node_in_focus and neighbor and it satisfy the valence constraint 8 | # and node_in_focus != neighbor 9 | def generate_mask(valences, adj_mat, color, real_n_vertices, node_in_focus, check_overlap_edge, new_mol): 10 | edge_type_mask=[] 11 | edge_mask=[] 12 | for neighbor in range(real_n_vertices): 13 | if neighbor != node_in_focus and color[neighbor] < 2 and \ 14 | not check_adjacent_sparse(adj_mat, node_in_focus, neighbor)[0]: 15 | min_valence = min(valences[node_in_focus], valences[neighbor], 3) 16 | # Check whether two cycles have more than two overlap edges here 17 | # the neighbor color = 1 and there are left valences and 18 | # adding that edge will not cause overlap edges. 19 | if check_overlap_edge and min_valence > 0 and color[neighbor] == 1: 20 | # attempt to add the edge 21 | new_mol.AddBond(int(node_in_focus), int(neighbor), number_to_bond[0]) 22 | # Check whether there are two cycles having more than two overlap edges 23 | ssr = Chem.GetSymmSSSR(new_mol) 24 | overlap_flag = False 25 | for idx1 in range(len(ssr)): 26 | for idx2 in range(idx1+1, len(ssr)): 27 | if len(set(ssr[idx1]) & set(ssr[idx2])) > 2: 28 | overlap_flag=True 29 | # remove that edge 30 | new_mol.RemoveBond(int(node_in_focus), int(neighbor)) 31 | if overlap_flag: 32 | continue 33 | for v in range(min_valence): 34 | assert v < 3 35 | edge_type_mask.append((node_in_focus, neighbor, v)) 36 | # there might be an edge between node in focus and neighbor 37 | if min_valence > 0: 38 | edge_mask.append((node_in_focus, neighbor)) 39 | return edge_type_mask, edge_mask 40 | 41 | # when a new edge is about to be added, we generate labels based on ground truth 42 | # if an edge is in ground truth and has not been added to incremental adj yet, we label it as positive 43 | def generate_label(ground_truth_graph, incremental_adj, node_in_focus, real_neighbor, real_n_vertices, params): 44 | edge_type_label=[] 45 | edge_label=[] 46 | for neighbor in range(real_n_vertices): 47 | adjacent, edge_type = check_adjacent_sparse(ground_truth_graph, node_in_focus, neighbor) 48 | incre_adjacent, incre_edge_type = check_adjacent_sparse(incremental_adj, node_in_focus, neighbor) 49 | if not params["label_one_hot"] and adjacent and not incre_adjacent: 50 | assert edge_type < 3 51 | edge_type_label.append((node_in_focus, neighbor, edge_type)) 52 | edge_label.append((node_in_focus, neighbor)) 53 | elif params["label_one_hot"] and adjacent and not incre_adjacent and neighbor==real_neighbor: 54 | edge_type_label.append((node_in_focus, neighbor, edge_type)) 55 | edge_label.append((node_in_focus, neighbor)) 56 | return edge_type_label, edge_label 57 | 58 | # add a incremental adj with one new edge 59 | def genereate_incremental_adj(last_adj, node_in_focus, 
neighbor, edge_type): 60 | # copy last incremental adj matrix 61 | new_adj= deepcopy(last_adj) 62 | # Add a new edge into it 63 | new_adj[node_in_focus].append((neighbor, edge_type)) 64 | new_adj[neighbor].append((node_in_focus, edge_type)) 65 | return new_adj 66 | 67 | def update_one_step(overlapped_edge_features, distance_to_others,node_sequence, node_in_focus, neighbor, edge_type, edge_type_masks, valences, incremental_adj_mat, 68 | color, real_n_vertices, graph, edge_type_labels, local_stop, edge_masks, edge_labels, local_stop_label, params, 69 | check_overlap_edge, new_mol, up_to_date_adj_mat,keep_prob): 70 | # check whether to keep this transition or not 71 | if params["sample_transition"] and random.random()> keep_prob: 72 | return 73 | # record the current node in focus 74 | node_sequence.append(node_in_focus) 75 | # generate mask based on current situation 76 | edge_type_mask, edge_mask=generate_mask(valences, up_to_date_adj_mat, 77 | color,real_n_vertices, node_in_focus, check_overlap_edge, new_mol) 78 | edge_type_masks.append(edge_type_mask) 79 | edge_masks.append(edge_mask) 80 | if not local_stop_label: 81 | # generate the label based on ground truth graph 82 | edge_type_label, edge_label=generate_label(graph, up_to_date_adj_mat, node_in_focus, neighbor,real_n_vertices, params) 83 | edge_type_labels.append(edge_type_label) 84 | edge_labels.append(edge_label) 85 | else: 86 | edge_type_labels.append([]) 87 | edge_labels.append([]) 88 | # update local stop 89 | local_stop.append(local_stop_label) 90 | # Calculate distance using bfs from the current node to all other node 91 | distances = bfs_distance(node_in_focus, up_to_date_adj_mat) 92 | distances = [(start, node, params["truncate_distance"]) if d > params["truncate_distance"] else (start, node, d) for start, node, d in distances] 93 | distance_to_others.append(distances) 94 | # Calculate the overlapped edge mask 95 | overlapped_edge_features.append(get_overlapped_edge_feature(edge_mask, color, new_mol)) 96 | # update the incremental adj mat at this step 97 | incremental_adj_mat.append(deepcopy(up_to_date_adj_mat)) 98 | 99 | def construct_incremental_graph(dataset, edges, max_n_vertices, real_n_vertices, node_symbol, params, initial_idx=0): 100 | # avoid calculating this if it is just for generating new molecules for speeding up 101 | if params["generation"]: 102 | return [], [], [], [], [], [], [], [], [] 103 | # avoid the initial index is larger than real_n_vertices: 104 | if initial_idx >= real_n_vertices: 105 | initial_idx=0 106 | # Maximum valences for each node 107 | valences=get_initial_valence([np.argmax(symbol) for symbol in node_symbol], dataset) 108 | # Add backward edges 109 | edges_bw=[(dst, edge_type, src) for src, edge_type, dst in edges] 110 | edges=edges+edges_bw 111 | # Construct a graph object using the edges 112 | graph=defaultdict(list) 113 | for src, edge_type, dst in edges: 114 | graph[src].append((dst, edge_type)) 115 | # Breadth first search over the molecule 116 | # color 0: have not found 1: in the queue 2: searched already 117 | color = [0] * max_n_vertices 118 | color[initial_idx] = 1 119 | queue=deque([initial_idx]) 120 | # create a adj matrix without any edges 121 | up_to_date_adj_mat=defaultdict(list) 122 | # record incremental adj mat 123 | incremental_adj_mat=[] 124 | # record the distance to other nodes at the moment 125 | distance_to_others=[] 126 | # soft constraint on overlapped edges 127 | overlapped_edge_features=[] 128 | # the exploration order of the nodes 129 | node_sequence=[] 130 | # 
edge type masks for nn predictions at each step 131 | edge_type_masks=[] 132 | # edge type labels for nn predictions at each step 133 | edge_type_labels=[] 134 | # edge masks for nn predictions at each step 135 | edge_masks=[] 136 | # edge labels for nn predictions at each step 137 | edge_labels=[] 138 | # local stop labels 139 | local_stop=[] 140 | # record the incremental molecule 141 | new_mol = Chem.MolFromSmiles('') 142 | new_mol = Chem.rdchem.RWMol(new_mol) 143 | # Add atoms 144 | add_atoms(new_mol, sample_node_symbol([node_symbol], [len(node_symbol)], dataset)[0], dataset) 145 | # calculate keep probability 146 | sample_transition_count= real_n_vertices + len(edges)/2 147 | keep_prob= float(sample_transition_count)/((real_n_vertices + len(edges)/2) * params["bfs_path_count"]) # to form a binomial distribution 148 | while len(queue) > 0: 149 | node_in_focus=queue.popleft() 150 | current_adj_list=graph[node_in_focus] 151 | # sort (canonical order) it or shuffle (random order) it 152 | if not params["path_random_order"]: 153 | current_adj_list=sorted(current_adj_list) 154 | else: 155 | random.shuffle(current_adj_list) 156 | for neighbor, edge_type in current_adj_list: 157 | # Add this edge if the color of neighbor node is not 2 158 | if color[neighbor]<2: 159 | update_one_step(overlapped_edge_features, distance_to_others,node_sequence, node_in_focus, neighbor, edge_type, 160 | edge_type_masks, valences, incremental_adj_mat, color, real_n_vertices, graph, 161 | edge_type_labels, local_stop, edge_masks, edge_labels, False, params, params["check_overlap_edge"], new_mol, 162 | up_to_date_adj_mat,keep_prob) 163 | # Add the edge and obtain a new adj mat 164 | up_to_date_adj_mat=genereate_incremental_adj( 165 | up_to_date_adj_mat, node_in_focus, neighbor, edge_type) 166 | # suppose the edge is selected and update valences after adding the 167 | valences[node_in_focus]-=(edge_type + 1) 168 | valences[neighbor]-=(edge_type + 1) 169 | # update the incremental mol 170 | new_mol.AddBond(int(node_in_focus), int(neighbor), number_to_bond[edge_type]) 171 | # Explore neighbor nodes 172 | if color[neighbor]==0: 173 | queue.append(neighbor) 174 | color[neighbor]=1 175 | # local stop here. We move on to another node for exploration or stop completely 176 | update_one_step(overlapped_edge_features, distance_to_others,node_sequence, node_in_focus, None, None, edge_type_masks, 177 | valences, incremental_adj_mat, color, real_n_vertices, graph, 178 | edge_type_labels, local_stop, edge_masks, edge_labels, True, params, params["check_overlap_edge"], new_mol, up_to_date_adj_mat,keep_prob) 179 | color[node_in_focus]=2 180 | 181 | return incremental_adj_mat,distance_to_others,node_sequence,edge_type_masks,edge_type_labels,local_stop, edge_masks, edge_labels, overlapped_edge_features 182 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env/python 2 | """ 3 | Usage: 4 | evaluate.py --dataset zinc|qm9|cep 5 | 6 | Options: 7 | -h --help Show this screen. 
8 | --dataset NAME Dataset name: zinc, qm9, cep 9 | """ 10 | 11 | import utils 12 | from utils import dataset_info 13 | import numpy as np 14 | from docopt import docopt 15 | 16 | if __name__ == '__main__': 17 | args = docopt(__doc__) 18 | dataset=args.get('--dataset') 19 | logpscorer, logp_score_per_molecule=utils.check_logp(dataset) 20 | qedscorer, qed_score_per_molecule=utils.check_qed(dataset) 21 | novelty=utils.novelty_metric(dataset) 22 | total, nonplanar=utils.check_planar(dataset) 23 | total, atom_counter, atom_per_molecule =utils.count_atoms(dataset) 24 | total, edge_type_counter, edge_type_per_molecule=utils.count_edge_type(dataset) 25 | total, shape_count, shape_count_per_molecule=utils.shape_count(dataset) 26 | total, tree_count=utils.check_cyclic(dataset) 27 | sascorer, sa_score_per_molecule=utils.check_sascorer(dataset) 28 | total, validity=utils.check_validity(dataset) 29 | 30 | print("------------------------------------------") 31 | print("Metrics") 32 | print("------------------------------------------") 33 | print("total molecule") 34 | print(total) 35 | print("------------------------------------------") 36 | print("percentage of nonplanar:") 37 | print(nonplanar/total) 38 | print("------------------------------------------") 39 | print("avg atom:") 40 | for atom_type, c in atom_counter.items(): 41 | print(dataset_info(dataset)['atom_types'][atom_type]) 42 | print(c/total) 43 | print("standard deviation") 44 | print(np.std(atom_per_molecule, axis=0)) 45 | print("------------------------------------------") 46 | print("avg edge_type:") 47 | for edge_type, c in edge_type_counter.items(): 48 | print(edge_type+1) 49 | print(c/total) 50 | print("standard deviation") 51 | print(np.std(edge_type_per_molecule, axis=0)) 52 | print("------------------------------------------") 53 | print("avg shape:") 54 | for shape, c in zip(utils.geometry_numbers, shape_count): 55 | print(shape) 56 | print(c/total) 57 | print("standard deviation") 58 | print(np.std(shape_count_per_molecule, axis=0)) 59 | print("------------------------------------------") 60 | print("percentage of tree:") 61 | print(tree_count/total) 62 | print("------------------------------------------") 63 | print("percentage of validity:") 64 | print(validity/total) 65 | print("------------------------------------------") 66 | print("avg sa_score:") 67 | print(sascorer) 68 | print("standard deviation") 69 | print(np.std(sa_score_per_molecule)) 70 | print("------------------------------------------") 71 | print("avg logp_score:") 72 | print(logpscorer) 73 | print("standard deviation") 74 | print(np.std(logp_score_per_molecule)) 75 | print("------------------------------------------") 76 | print("percentage of novelty:") 77 | print(novelty) 78 | print("------------------------------------------") 79 | print("avg qed_score:") 80 | print(qedscorer) 81 | print("standard deviation") 82 | print(np.std(qed_score_per_molecule)) 83 | print("------------------------------------------") 84 | print("uniqueness") 85 | print(utils.check_uniqueness(dataset)) 86 | print("------------------------------------------") 87 | print("percentage of SSSR") 88 | print(utils.sssr_metric(dataset)) 89 | -------------------------------------------------------------------------------- /get_sascorer.sh: -------------------------------------------------------------------------------- 1 | wget https://raw.githubusercontent.com/rdkit/rdkit/master/Contrib/SA_Score/sascorer.py 2 | wget 
https://raw.githubusercontent.com/rdkit/rdkit/master/Contrib/SA_Score/fpscores.pkl.gz -------------------------------------------------------------------------------- /install.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # download and install miniconda 4 | # please update the link below according to the platform you are using (https://conda.io/miniconda.html) 5 | # e.g. for Mac, change to https://repo.continuum.io/miniconda/Miniconda3-latest-MacOSX-x86_64.sh 6 | wget https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh 7 | bash ./Miniconda3-latest-Linux-x86_64.sh -b -p $HOME/miniconda 8 | export PATH="$HOME/miniconda/bin:$PATH" 9 | 10 | # create a new environment named cgvae 11 | conda create --name cgvae python=3.5 pip 12 | source activate cgvae 13 | 14 | # install cython 15 | pip install Cython --install-option="--no-cython-compile" 16 | 17 | # install rdkit 18 | conda install -c rdkit rdkit 19 | 20 | # install tensorflow 1.3 21 | pip install --upgrade https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.3.0-cp35-cp35m-linux_x86_64.whl 22 | 23 | # install other requirements 24 | pip install -r requirements.txt 25 | 26 | # remove conda bash 27 | rm ./Miniconda3-latest-Linux-x86_64.sh 28 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | docopt==0.6.2 2 | typing 3 | planarity 4 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env/python 2 | import numpy as np 3 | import tensorflow as tf 4 | import queue 5 | import threading 6 | import pickle 7 | from rdkit.Chem import AllChem 8 | from rdkit.Chem import Draw 9 | from rdkit import Chem 10 | from rdkit.Chem import rdmolops 11 | from collections import defaultdict, deque 12 | import os 13 | import heapq 14 | import planarity 15 | import sascorer 16 | from rdkit.Chem import Crippen 17 | from rdkit.Chem import QED 18 | 19 | SMALL_NUMBER = 1e-7 20 | LARGE_NUMBER= 1e10 21 | 22 | geometry_numbers=[3, 4, 5, 6] # triangle, square, pentagen, hexagon 23 | 24 | # bond mapping 25 | bond_dict = {'SINGLE': 0, 'DOUBLE': 1, 'TRIPLE': 2, "AROMATIC": 3} 26 | number_to_bond= {0: Chem.rdchem.BondType.SINGLE, 1:Chem.rdchem.BondType.DOUBLE, 27 | 2: Chem.rdchem.BondType.TRIPLE, 3:Chem.rdchem.BondType.AROMATIC} 28 | 29 | def dataset_info(dataset): #qm9, zinc, cep 30 | if dataset=='qm9': 31 | return { 'atom_types': ["H", "C", "N", "O", "F"], 32 | 'maximum_valence': {0: 1, 1: 4, 2: 3, 3: 2, 4: 1}, 33 | 'number_to_atom': {0: "H", 1: "C", 2: "N", 3: "O", 4: "F"}, 34 | 'bucket_sizes': np.array(list(range(4, 28, 2)) + [29]) 35 | } 36 | elif dataset=='zinc': 37 | return { 'atom_types': ['Br1(0)', 'C4(0)', 'Cl1(0)', 'F1(0)', 'H1(0)', 'I1(0)', 38 | 'N2(-1)', 'N3(0)', 'N4(1)', 'O1(-1)', 'O2(0)', 'S2(0)','S4(0)', 'S6(0)'], 39 | 'maximum_valence': {0: 1, 1: 4, 2: 1, 3: 1, 4: 1, 5:1, 6:2, 7:3, 8:4, 9:1, 10:2, 11:2, 12:4, 13:6, 14:3}, 40 | 'number_to_atom': {0: 'Br', 1: 'C', 2: 'Cl', 3: 'F', 4: 'H', 5:'I', 6:'N', 7:'N', 8:'N', 9:'O', 10:'O', 11:'S', 12:'S', 13:'S'}, 41 | 'bucket_sizes': np.array([28,31,33,35,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,53,55,58,84]) 42 | } 43 | 44 | elif dataset=="cep": 45 | return { 'atom_types': ["C", "S", "N", "O", "Se", "Si"], 46 | 'maximum_valence': {0: 4, 1: 2, 2: 3, 3: 2, 
4: 2, 5: 4}, 47 | 'number_to_atom': {0: "C", 1: "S", 2: "N", 3: "O", 4: "Se", 5: "Si"}, 48 | 'bucket_sizes': np.array([25,28,29,30, 32, 33,34,35,36,37,38,39,43,46]) 49 | } 50 | else: 51 | print("the datasets in use are qm9|zinc|cep") 52 | exit(1) 53 | 54 | # add one edge to adj matrix 55 | def add_edge_mat(amat, src, dest, e, considering_edge_type=True): 56 | if considering_edge_type: 57 | amat[e, dest, src] = 1 58 | amat[e, src, dest] = 1 59 | else: 60 | amat[src, dest] = 1 61 | amat[dest, src] = 1 62 | 63 | def graph_to_adj_mat(graph, max_n_vertices, num_edge_types, tie_fwd_bkwd=True, considering_edge_type=True): 64 | if considering_edge_type: 65 | amat = np.zeros((num_edge_types, max_n_vertices, max_n_vertices)) 66 | for src, e, dest in graph: 67 | add_edge_mat(amat, src, dest, e) 68 | else: 69 | amat = np.zeros((max_n_vertices, max_n_vertices)) 70 | for src, e, dest in graph: 71 | add_edge_mat(amat, src, dest, e, considering_edge_type=False) 72 | return amat 73 | 74 | def check_edge_prob(dataset): 75 | with open('intermediate_results_%s' % dataset, 'rb') as f: 76 | adjacency_matrix, edge_type_prob, edge_type_label, node_symbol_prob, node_symbol, edge_prob, edge_prob_label, qed_prediction, qed_labels,mean, logvariance=pickle.load(f) 77 | for ep, epl in zip(edge_prob, edge_prob_label): 78 | print("prediction") 79 | print(ep) 80 | print("label") 81 | print(epl) 82 | 83 | # check whether a graph is planar or not 84 | def is_planar(location, adj_list, is_dense=False): 85 | if is_dense: 86 | new_adj_list=defaultdict(list) 87 | for x in range(len(adj_list)): 88 | for y in range(len(adj_list)): 89 | if adj_list[x][y]==1: 90 | new_adj_list[x].append((y,1)) 91 | adj_list=new_adj_list 92 | edges=[] 93 | seen=set() 94 | for src, l in adj_list.items(): 95 | for dst, e in l: 96 | if (dst, src) not in seen: 97 | edges.append((src,dst)) 98 | seen.add((src,dst)) 99 | edges+=[location, (location[1], location[0])] 100 | return planarity.is_planar(edges) 101 | 102 | def check_edge_type_prob(filter=None): 103 | with open('intermediate_results_%s' % dataset, 'rb') as f: 104 | adjacency_matrix, edge_type_prob, edge_type_label, node_symbol_prob, node_symbol, edge_prob, edge_prob_label, qed_prediction, qed_labels,mean, logvariance=pickle.load(f) 105 | for ep, epl in zip(edge_type_prob, edge_type_label): 106 | print("prediction") 107 | print(ep) 108 | print("label") 109 | print(epl) 110 | 111 | def check_mean(dataset, filter=None): 112 | with open('intermediate_results_%s' % dataset, 'rb') as f: 113 | adjacency_matrix, edge_type_prob, edge_type_label, node_symbol_prob, node_symbol, edge_prob, edge_prob_label, qed_prediction, qed_labels,mean, logvariance=pickle.load(f) 114 | print(mean.tolist()[:40]) 115 | 116 | def check_variance(dataset, filter=None): 117 | with open('intermediate_results_%s' % dataset, 'rb') as f: 118 | adjacency_matrix, edge_type_prob, edge_type_label, node_symbol_prob, node_symbol, edge_prob, edge_prob_label, qed_prediction, qed_labels,mean, logvariance=pickle.load(f) 119 | print(np.exp(logvariance).tolist()[:40]) 120 | 121 | def check_node_prob(filter=None): 122 | print(dataset) 123 | with open('intermediate_results_%s' % dataset, 'rb') as f: 124 | adjacency_matrix, edge_type_prob, edge_type_label, node_symbol_prob, node_symbol, edge_prob, edge_prob_label, qed_prediction, qed_labels,mean, logvariance=pickle.load(f) 125 | print(node_symbol_prob[0]) 126 | print(node_symbol[0]) 127 | print(node_symbol_prob.shape) 128 | 129 | def check_qed(filter=None): 130 | with 
open('intermediate_results_%s' % dataset, 'rb') as f: 131 | adjacency_matrix, edge_type_prob, edge_type_label, node_symbol_prob, node_symbol, edge_prob, edge_prob_label, qed_prediction, qed_labels,mean, logvariance=pickle.load(f) 132 | print(qed_prediction) 133 | print(qed_labels[0]) 134 | print(np.mean(np.abs(qed_prediction-qed_labels[0]))) 135 | 136 | def onehot(idx, len): 137 | z = [0 for _ in range(len)] 138 | z[idx] = 1 139 | return z 140 | 141 | def generate_empty_adj_matrix(maximum_vertice_num): 142 | return np.zeros((1, 3, maximum_vertice_num, maximum_vertice_num)) 143 | 144 | # standard normal with shape [a1, a2, a3] 145 | def generate_std_normal(a1, a2, a3): 146 | return np.random.normal(0, 1, [a1, a2, a3]) 147 | 148 | def check_validity(dataset): 149 | with open('generated_smiles_%s' % dataset, 'rb') as f: 150 | all_smiles=set(pickle.load(f)) 151 | count=0 152 | for smiles in all_smiles: 153 | mol = Chem.MolFromSmiles(smiles) 154 | if mol is not None: 155 | count+=1 156 | return len(all_smiles), count 157 | 158 | # Get length for each graph based on node masks 159 | def get_graph_length(all_node_mask): 160 | all_lengths=[] 161 | for graph in all_node_mask: 162 | if 0 in graph: 163 | length=np.argmin(graph) 164 | else: 165 | length=len(graph) 166 | all_lengths.append(length) 167 | return all_lengths 168 | 169 | def make_dir(path): 170 | if not os.path.exists(path): 171 | os.mkdir(path) 172 | print('made directory %s' % path) 173 | 174 | # sample node symbols based on node predictions 175 | def sample_node_symbol(all_node_symbol_prob, all_lengths, dataset): 176 | all_node_symbol=[] 177 | for graph_idx, graph_prob in enumerate(all_node_symbol_prob): 178 | node_symbol=[] 179 | for node_idx in range(all_lengths[graph_idx]): 180 | symbol=np.random.choice(np.arange(len(dataset_info(dataset)['atom_types'])), p=graph_prob[node_idx]) 181 | node_symbol.append(symbol) 182 | all_node_symbol.append(node_symbol) 183 | return all_node_symbol 184 | 185 | def dump(file_name, content): 186 | with open(file_name, 'wb') as out_file: 187 | pickle.dump(content, out_file, pickle.HIGHEST_PROTOCOL) 188 | 189 | def load(file_name): 190 | with open(file_name, 'rb') as f: 191 | return pickle.load(f) 192 | 193 | # generate a new feature on whether adding the edges will generate more than two overlapped edges for rings 194 | def get_overlapped_edge_feature(edge_mask, color, new_mol): 195 | overlapped_edge_feature=[] 196 | for node_in_focus, neighbor in edge_mask: 197 | if color[neighbor] == 1: 198 | # attempt to add the edge 199 | new_mol.AddBond(int(node_in_focus), int(neighbor), number_to_bond[0]) 200 | # Check whether there are two cycles having more than two overlap edges 201 | try: 202 | ssr = Chem.GetSymmSSSR(new_mol) 203 | except: 204 | ssr = [] 205 | overlap_flag = False 206 | for idx1 in range(len(ssr)): 207 | for idx2 in range(idx1+1, len(ssr)): 208 | if len(set(ssr[idx1]) & set(ssr[idx2])) > 2: 209 | overlap_flag=True 210 | # remove that edge 211 | new_mol.RemoveBond(int(node_in_focus), int(neighbor)) 212 | if overlap_flag: 213 | overlapped_edge_feature.append((node_in_focus, neighbor)) 214 | return overlapped_edge_feature 215 | 216 | # adj_list [3, v, v] or defaultdict. 
bfs distance on a graph 217 | def bfs_distance(start, adj_list, is_dense=False): 218 | distances={} 219 | visited=set() 220 | queue=deque([(start, 0)]) 221 | visited.add(start) 222 | while len(queue) != 0: 223 | current, d=queue.popleft() 224 | for neighbor, edge_type in adj_list[current]: 225 | if neighbor not in visited: 226 | distances[neighbor]=d+1 227 | visited.add(neighbor) 228 | queue.append((neighbor, d+1)) 229 | return [(start, node, d) for node, d in distances.items()] 230 | 231 | def get_initial_valence(node_symbol, dataset): 232 | return [dataset_info(dataset)['maximum_valence'][s] for s in node_symbol] 233 | 234 | def add_atoms(new_mol, node_symbol, dataset): 235 | for number in node_symbol: 236 | if dataset=='qm9' or dataset=='cep': 237 | idx=new_mol.AddAtom(Chem.Atom(dataset_info(dataset)['number_to_atom'][number])) 238 | elif dataset=='zinc': 239 | new_atom = Chem.Atom(dataset_info(dataset)['number_to_atom'][number]) 240 | charge_num=int(dataset_info(dataset)['atom_types'][number].split('(')[1].strip(')')) 241 | new_atom.SetFormalCharge(charge_num) 242 | new_mol.AddAtom(new_atom) 243 | 244 | def visualize_mol(path, new_mol): 245 | AllChem.Compute2DCoords(new_mol) 246 | print(path) 247 | Draw.MolToFile(new_mol,path) 248 | 249 | def get_idx_of_largest_frag(frags): 250 | return np.argmax([len(frag) for frag in frags]) 251 | 252 | def remove_extra_nodes(new_mol): 253 | frags=Chem.rdmolops.GetMolFrags(new_mol) 254 | while len(frags) > 1: 255 | # Get the idx of the frag with largest length 256 | largest_idx = get_idx_of_largest_frag(frags) 257 | for idx in range(len(frags)): 258 | if idx != largest_idx: 259 | # Remove one atom that is not in the largest frag 260 | new_mol.RemoveAtom(frags[idx][0]) 261 | break 262 | frags=Chem.rdmolops.GetMolFrags(new_mol) 263 | 264 | def novelty_metric(dataset): 265 | with open('all_smiles_%s.pkl' % dataset, 'rb') as f: 266 | all_smiles=set(pickle.load(f)) 267 | with open('generated_smiles_%s' % dataset, 'rb') as f: 268 | generated_all_smiles=set(pickle.load(f)) 269 | total_new_molecules=0 270 | for generated_smiles in generated_all_smiles: 271 | if generated_smiles not in all_smiles: 272 | total_new_molecules+=1 273 | 274 | return float(total_new_molecules)/len(generated_all_smiles) 275 | 276 | def count_edge_type(dataset, generated=True): 277 | if generated: 278 | filename='generated_smiles_%s' % dataset 279 | else: 280 | filename='all_smiles_%s.pkl' % dataset 281 | with open(filename, 'rb') as f: 282 | all_smiles=set(pickle.load(f)) 283 | 284 | counter=defaultdict(int) 285 | edge_type_per_molecule=[] 286 | for smiles in all_smiles: 287 | nodes, edges=to_graph(smiles, dataset) 288 | edge_type_this_molecule=[0]* len(bond_dict) 289 | for edge in edges: 290 | edge_type=edge[1] 291 | edge_type_this_molecule[edge_type]+=1 292 | counter[edge_type]+=1 293 | edge_type_per_molecule.append(edge_type_this_molecule) 294 | total_sum=0 295 | return len(all_smiles), counter, edge_type_per_molecule 296 | 297 | def need_kekulize(mol): 298 | for bond in mol.GetBonds(): 299 | if bond_dict[str(bond.GetBondType())] >= 3: 300 | return True 301 | return False 302 | 303 | def check_planar(dataset): 304 | with open("generated_smiles_%s" % dataset, 'rb') as f: 305 | all_smiles=set(pickle.load(f)) 306 | total_non_planar=0 307 | for smiles in all_smiles: 308 | try: 309 | nodes, edges=to_graph(smiles, dataset) 310 | except: 311 | continue 312 | edges=[(src, dst) for src, e, dst in edges] 313 | if edges==[]: 314 | continue 315 | 316 | if not planarity.is_planar(edges): 317 
| total_non_planar+=1 318 | return len(all_smiles), total_non_planar 319 | 320 | def count_atoms(dataset): 321 | with open("generated_smiles_%s" % dataset, 'rb') as f: 322 | all_smiles=set(pickle.load(f)) 323 | counter=defaultdict(int) 324 | atom_count_per_molecule=[] # record the counts for each molecule 325 | for smiles in all_smiles: 326 | try: 327 | nodes, edges=to_graph(smiles, dataset) 328 | except: 329 | continue 330 | atom_count_this_molecule=[0]*len(dataset_info(dataset)['atom_types']) 331 | for node in nodes: 332 | atom_type=np.argmax(node) 333 | atom_count_this_molecule[atom_type]+=1 334 | counter[atom_type]+=1 335 | atom_count_per_molecule.append(atom_count_this_molecule) 336 | total_sum=0 337 | 338 | return len(all_smiles), counter, atom_count_per_molecule 339 | 340 | 341 | def to_graph(smiles, dataset): 342 | mol = Chem.MolFromSmiles(smiles) 343 | if mol is None: 344 | return [], [] 345 | # Kekulize it 346 | if need_kekulize(mol): 347 | rdmolops.Kekulize(mol) 348 | if mol is None: 349 | return None, None 350 | # remove stereo information, such as inward and outward edges 351 | Chem.RemoveStereochemistry(mol) 352 | 353 | edges = [] 354 | nodes = [] 355 | for bond in mol.GetBonds(): 356 | edges.append((bond.GetBeginAtomIdx(), bond_dict[str(bond.GetBondType())], bond.GetEndAtomIdx())) 357 | assert bond_dict[str(bond.GetBondType())] != 3 358 | for atom in mol.GetAtoms(): 359 | if dataset=='qm9' or dataset=="cep": 360 | nodes.append(onehot(dataset_info(dataset)['atom_types'].index(atom.GetSymbol()), len(dataset_info(dataset)['atom_types']))) 361 | elif dataset=='zinc': # transform using "()" notation 362 | symbol = atom.GetSymbol() 363 | valence = atom.GetTotalValence() 364 | charge = atom.GetFormalCharge() 365 | atom_str = "%s%i(%i)" % (symbol, valence, charge) 366 | 367 | if atom_str not in dataset_info(dataset)['atom_types']: 368 | print('unrecognized atom type %s' % atom_str) 369 | return [], [] 370 | 371 | nodes.append(onehot(dataset_info(dataset)['atom_types'].index(atom_str), len(dataset_info(dataset)['atom_types']))) 372 | 373 | return nodes, edges 374 | 375 | def check_uniqueness(dataset): 376 | with open('generated_smiles_%s' % dataset, 'rb') as f: 377 | all_smiles=pickle.load(f) 378 | original_num = len(all_smiles) 379 | all_smiles=set(all_smiles) 380 | new_num = len(all_smiles) 381 | return new_num/original_num 382 | 383 | def shape_count(dataset, remove_print=False, all_smiles=None): 384 | if all_smiles==None: 385 | with open('generated_smiles_%s' % dataset, 'rb') as f: 386 | all_smiles=set(pickle.load(f)) 387 | 388 | geometry_counts=[0]*len(geometry_numbers) 389 | geometry_counts_per_molecule=[] # record the geometry counts for each molecule 390 | for smiles in all_smiles: 391 | nodes, edges = to_graph(smiles, dataset) 392 | if len(edges)<=0: 393 | continue 394 | new_mol=Chem.MolFromSmiles(smiles) 395 | 396 | ssr = Chem.GetSymmSSSR(new_mol) 397 | counts_for_molecule=[0] * len(geometry_numbers) 398 | for idx in range(len(ssr)): 399 | ring_len=len(list(ssr[idx])) 400 | if ring_len in geometry_numbers: 401 | geometry_counts[geometry_numbers.index(ring_len)]+=1 402 | counts_for_molecule[geometry_numbers.index(ring_len)]+=1 403 | geometry_counts_per_molecule.append(counts_for_molecule) 404 | 405 | return len(all_smiles), geometry_counts, geometry_counts_per_molecule 406 | 407 | def check_adjacent_sparse(adj_list, node, neighbor_in_doubt): 408 | for neighbor, edge_type in adj_list[node]: 409 | if neighbor == neighbor_in_doubt: 410 | return True, edge_type 411 | return 
412 |
413 | def glorot_init(shape):
414 |     initialization_range = np.sqrt(6.0 / (shape[-2] + shape[-1]))
415 |     return np.random.uniform(low=-initialization_range, high=initialization_range, size=shape).astype(np.float32)
416 |
417 | class ThreadedIterator:
418 |     """An iterator object that computes its elements in a parallel thread to be ready to be consumed.
419 |     The iterator should *not* return None"""
420 |
421 |     def __init__(self, original_iterator, max_queue_size: int=2):
422 |         self.__queue = queue.Queue(maxsize=max_queue_size)
423 |         self.__thread = threading.Thread(target=lambda: self.worker(original_iterator))
424 |         self.__thread.start()
425 |
426 |     def worker(self, original_iterator):
427 |         for element in original_iterator:
428 |             assert element is not None, 'By convention, iterator elements must not be None'
429 |             self.__queue.put(element, block=True)
430 |         self.__queue.put(None, block=True)
431 |
432 |     def __iter__(self):
433 |         next_element = self.__queue.get(block=True)
434 |         while next_element is not None:
435 |             yield next_element
436 |             next_element = self.__queue.get(block=True)
437 |         self.__thread.join()
438 |
439 | # Implements multilayer perceptron
440 | class MLP(object):
441 |     def __init__(self, in_size, out_size, hid_sizes, dropout_keep_prob):
442 |         self.in_size = in_size
443 |         self.out_size = out_size
444 |         self.hid_sizes = hid_sizes
445 |         self.dropout_keep_prob = dropout_keep_prob
446 |         self.params = self.make_network_params()
447 |
448 |     def make_network_params(self):
449 |         dims = [self.in_size] + self.hid_sizes + [self.out_size]
450 |         weight_sizes = list(zip(dims[:-1], dims[1:]))
451 |         weights = [tf.Variable(self.init_weights(s), name='MLP_W_layer%i' % i)
452 |                    for (i, s) in enumerate(weight_sizes)]
453 |         biases = [tf.Variable(np.zeros(s[-1]).astype(np.float32), name='MLP_b_layer%i' % i)
454 |                   for (i, s) in enumerate(weight_sizes)]
455 |
456 |         network_params = {
457 |             "weights": weights,
458 |             "biases": biases,
459 |         }
460 |
461 |         return network_params
462 |
463 |     def init_weights(self, shape):
464 |         return np.sqrt(6.0 / (shape[-2] + shape[-1])) * (2 * np.random.rand(*shape).astype(np.float32) - 1)
465 |
466 |     def __call__(self, inputs):
467 |         acts = inputs
468 |         for W, b in zip(self.params["weights"], self.params["biases"]):
469 |             hid = tf.matmul(acts, tf.nn.dropout(W, self.dropout_keep_prob)) + b
470 |             acts = tf.nn.relu(hid)
471 |         last_hidden = hid
472 |         return last_hidden
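# --- Editor's illustrative sketch (not part of the original utils.py) ---
# A minimal sketch of wiring the MLP class above into a TensorFlow 1.x graph,
# which is the API style used by the rest of this repository. The sizes and the
# placeholder names below are made up for illustration only.
def _example_mlp_usage():
    import tensorflow as tf  # assumes the TF 1.x API used throughout this repo
    node_states = tf.placeholder(tf.float32, [None, 100], name='example_node_states')
    keep_prob = tf.placeholder(tf.float32, None, name='example_keep_prob')
    # Two hidden layers of width 100, projecting down to a single output unit.
    mlp = MLP(in_size=100, out_size=1, hid_sizes=[100, 100], dropout_keep_prob=keep_prob)
    scores = mlp(node_states)  # shape [batch, 1]; note: no ReLU on the last layer
    return scores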
473 |
474 | class Graph():
475 |
476 |     def __init__(self, V, g):
477 |         self.V = V
478 |         self.graph = g
479 |
480 |     def addEdge(self, v, w):
481 |         # Add w to v's list.
482 |         self.graph[v].append(w)
483 |         # Add v to w's list.
484 |         self.graph[w].append(v)
485 |
486 |     # A recursive function that uses visited[]
487 |     # and parent to detect cycle in subgraph
488 |     # reachable from vertex v.
489 |     def isCyclicUtil(self, v, visited, parent):
490 |
491 |         # Mark current node as visited
492 |         visited[v] = True
493 |
494 |         # Recur for all the vertices adjacent
495 |         # to this vertex
496 |         for i in self.graph[v]:
497 |             # If an adjacent is not visited,
498 |             # then recur for that adjacent
499 |             if visited[i] == False:
500 |                 if self.isCyclicUtil(i, visited, v) == True:
501 |                     return True
502 |
503 |             # If an adjacent is visited and not
504 |             # parent of current vertex, then there
505 |             # is a cycle.
506 |             elif i != parent:
507 |                 return True
508 |
509 |         return False
510 |
511 |     # Returns true if the graph is a tree,
512 |     # else false.
513 |     def isTree(self):
514 |         # Mark all the vertices as not visited
515 |         # and not part of recursion stack
516 |         visited = [False] * self.V
517 |
518 |         # The call to isCyclicUtil serves multiple
519 |         # purposes. It returns true if graph reachable
520 |         # from vertex 0 is cyclic. It also marks
521 |         # all vertices reachable from 0.
522 |         if self.isCyclicUtil(0, visited, -1) == True:
523 |             return False
524 |
525 |         # If we find a vertex which is not reachable
526 |         # from 0 (not marked by isCyclicUtil()),
527 |         # then we return false
528 |         for i in range(self.V):
529 |             if visited[i] == False:
530 |                 return False
531 |
532 |         return True
533 |
534 | # check whether the graph has no cycle (i.e., is a tree)
535 | def check_cyclic(dataset, generated=True):
536 |     if generated:
537 |         with open("generated_smiles_%s" % dataset, 'rb') as f:
538 |             all_smiles=set(pickle.load(f))
539 |     else:
540 |         with open("all_smiles_%s.pkl" % dataset, 'rb') as f:
541 |             all_smiles=set(pickle.load(f))
542 |
543 |     tree_count=0
544 |     for smiles in all_smiles:
545 |         nodes, edges=to_graph(smiles, dataset)
546 |         edges=[(src, dst) for src, e, dst in edges]
547 |         if edges==[]:
548 |             continue
549 |         new_adj_list=defaultdict(list)
550 |
551 |         for src, dst in edges:
552 |             new_adj_list[src].append(dst)
553 |             new_adj_list[dst].append(src)
554 |         graph=Graph(len(nodes), new_adj_list)
555 |         if graph.isTree():
556 |             tree_count+=1
557 |     return len(all_smiles), tree_count
558 |
559 | def check_sascorer(dataset):
560 |     with open('generated_smiles_%s' % dataset, 'rb') as f:
561 |         all_smiles=set(pickle.load(f))
562 |     sa_sum=0
563 |     total=0
564 |     sa_score_per_molecule=[]
565 |     for smiles in all_smiles:
566 |         new_mol=Chem.MolFromSmiles(smiles)
567 |         try:
568 |             val = sascorer.calculateScore(new_mol)
569 |         except:
570 |             continue
571 |         sa_sum+=val
572 |         sa_score_per_molecule.append(val)
573 |         total+=1
574 |     return sa_sum/total, sa_score_per_molecule
575 |
576 | def check_logp(dataset):
577 |     with open('generated_smiles_%s' % dataset, 'rb') as f:
578 |         all_smiles=set(pickle.load(f))
579 |     logp_sum=0
580 |     total=0
581 |     logp_score_per_molecule=[]
582 |     for smiles in all_smiles:
583 |         new_mol=Chem.MolFromSmiles(smiles)
584 |         try:
585 |             val = Crippen.MolLogP(new_mol)
586 |         except:
587 |             continue
588 |         logp_sum+=val
589 |         logp_score_per_molecule.append(val)
590 |         total+=1
591 |     return logp_sum/total, logp_score_per_molecule
592 |
593 | def check_qed(dataset):
594 |     with open('generated_smiles_%s' % dataset, 'rb') as f:
595 |         all_smiles=set(pickle.load(f))
596 |     qed_sum=0
597 |     total=0
598 |     qed_score_per_molecule=[]
599 |     for smiles in all_smiles:
600 |         new_mol=Chem.MolFromSmiles(smiles)
601 |         try:
602 |             val = QED.qed(new_mol)
603 |         except:
604 |             continue
605 |         qed_sum+=val
606 |         qed_score_per_molecule.append(val)
607 |         total+=1
608 |     return qed_sum/total, qed_score_per_molecule
609 |
610 | def sssr_metric(dataset):
611 |     with open('generated_smiles_%s' % dataset, 'rb') as f:
612 |         all_smiles=set(pickle.load(f))
613 |     overlapped_molecule=0
614 |     for smiles in all_smiles:
615 |         new_mol=Chem.MolFromSmiles(smiles)
616 |         ssr = Chem.GetSymmSSSR(new_mol)
617 |         overlap_flag=False
618 |         for idx1 in range(len(ssr)):
619 |             for idx2 in range(idx1+1, len(ssr)):
620 |                 if len(set(ssr[idx1]) & set(ssr[idx2])) > 2:
621 |                     overlap_flag=True
622 |         if overlap_flag:
623 |             overlapped_molecule+=1
624 |     return overlapped_molecule/len(all_smiles)
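# --- Editor's illustrative sketch (not part of the original utils.py) ---
# The Graph class above backs check_cyclic: a molecule whose bond skeleton is a
# tree (connected and acyclic) makes isTree() return True. The two toy graphs
# below are hypothetical.
def _example_is_tree():
    from collections import defaultdict  # also imported at the top of this module
    path = defaultdict(list)             # 0 - 1 - 2: connected, no cycle
    for src, dst in [(0, 1), (1, 2)]:
        path[src].append(dst)
        path[dst].append(src)
    triangle = defaultdict(list)         # 0 - 1 - 2 - 0: contains a cycle
    for src, dst in [(0, 1), (1, 2), (2, 0)]:
        triangle[src].append(dst)
        triangle[dst].append(src)
    assert Graph(3, path).isTree() is True
    assert Graph(3, triangle).isTree() is False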
625 |
626 | # select the best based on shapes and probs
627 | def select_best(all_mol):
628 |     # sort by shape
629 |     all_mol=sorted(all_mol)
630 |     best_shape=all_mol[-1][0]
631 |     all_mol=[(p, m) for s, p, m in all_mol if s==best_shape]
632 |     # sort by probs
633 |     all_mol=sorted(all_mol)
634 |     return all_mol[-1][1]
635 |
636 |
637 | # a series of util functions converting sparse matrix representations to dense
638 |
639 | def incre_adj_mat_to_dense(incre_adj_mat, num_edge_types, maximum_vertice_num):
640 |     new_incre_adj_mat=[]
641 |     for sparse_incre_adj_mat in incre_adj_mat:
642 |         dense_incre_adj_mat=np.zeros((num_edge_types, maximum_vertice_num, maximum_vertice_num))
643 |         for current, adj_list in sparse_incre_adj_mat.items():
644 |             for neighbor, edge_type in adj_list:
645 |                 dense_incre_adj_mat[edge_type][current][neighbor]=1
646 |         new_incre_adj_mat.append(dense_incre_adj_mat)
647 |     return new_incre_adj_mat # [number_iteration, num_edge_types, maximum_vertice_num, maximum_vertice_num]
648 |
649 | def distance_to_others_dense(distance_to_others, maximum_vertice_num):
650 |     new_all_distance=[]
651 |     for sparse_distances in distance_to_others:
652 |         dense_distances=np.zeros((maximum_vertice_num), dtype=int)
653 |         for x, y, d in sparse_distances:
654 |             dense_distances[y]=d
655 |         new_all_distance.append(dense_distances)
656 |     return new_all_distance # [number_iteration, maximum_vertice_num]
657 |
658 | def overlapped_edge_features_to_dense(overlapped_edge_features, maximum_vertice_num):
659 |     new_overlapped_edge_features=[]
660 |     for sparse_overlapped_edge_features in overlapped_edge_features:
661 |         dense_overlapped_edge_features=np.zeros((maximum_vertice_num), dtype=int)
662 |         for node_in_focus, neighbor in sparse_overlapped_edge_features:
663 |             dense_overlapped_edge_features[neighbor]=1
664 |         new_overlapped_edge_features.append(dense_overlapped_edge_features)
665 |     return new_overlapped_edge_features # [number_iteration, maximum_vertice_num]
666 |
667 | def node_sequence_to_dense(node_sequence, maximum_vertice_num):
668 |     new_node_sequence=[]
669 |     for node in node_sequence:
670 |         s=[0]*maximum_vertice_num
671 |         s[node]=1
672 |         new_node_sequence.append(s)
673 |     return new_node_sequence # [number_iteration, maximum_vertice_num]
674 |
675 | def edge_type_masks_to_dense(edge_type_masks, maximum_vertice_num, num_edge_types):
676 |     new_edge_type_masks=[]
677 |     for mask_sparse in edge_type_masks:
678 |         mask_dense=np.zeros([num_edge_types, maximum_vertice_num])
679 |         for node_in_focus, neighbor, bond in mask_sparse:
680 |             mask_dense[bond][neighbor]=1
681 |         new_edge_type_masks.append(mask_dense)
682 |     return new_edge_type_masks # [number_iteration, 3, maximum_vertice_num]
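# --- Editor's illustrative sketch (not part of the original utils.py) ---
# The sparse-to-dense helpers above and below turn per-step sparse records into
# padded numpy arrays. For incre_adj_mat_to_dense, each generation step is a
# dict node -> [(neighbor, edge_type), ...]; the toy input below is hypothetical.
def _example_incre_adj_mat_to_dense():
    import numpy as np  # also imported at the top of this module
    sparse_steps = [
        {0: [(1, 0)], 1: [(0, 0)]},   # step 1: a bond of type 0 between nodes 0 and 1
        {1: [(2, 1)], 2: [(1, 1)]},   # step 2: a new bond of type 1 between nodes 1 and 2
    ]
    dense = incre_adj_mat_to_dense(sparse_steps, num_edge_types=3, maximum_vertice_num=4)
    assert len(dense) == 2 and dense[0].shape == (3, 4, 4)
    assert dense[0][0][0][1] == 1 and dense[1][1][1][2] == 1
    return dense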
683 |
684 | def edge_type_labels_to_dense(edge_type_labels, maximum_vertice_num, num_edge_types):
685 |     new_edge_type_labels=[]
686 |     for labels_sparse in edge_type_labels:
687 |         labels_dense=np.zeros([num_edge_types, maximum_vertice_num])
688 |         for node_in_focus, neighbor, bond in labels_sparse:
689 |             labels_dense[bond][neighbor]= 1/float(len(labels_sparse)) # fix the probability bug here.
690 |         new_edge_type_labels.append(labels_dense)
691 |     return new_edge_type_labels # [number_iteration, 3, maximum_vertice_num]
692 |
693 | def edge_masks_to_dense(edge_masks, maximum_vertice_num):
694 |     new_edge_masks=[]
695 |     for mask_sparse in edge_masks:
696 |         mask_dense=[0] * maximum_vertice_num
697 |         for node_in_focus, neighbor in mask_sparse:
698 |             mask_dense[neighbor]=1
699 |         new_edge_masks.append(mask_dense)
700 |     return new_edge_masks # [number_iteration, maximum_vertice_num]
701 |
702 | def edge_labels_to_dense(edge_labels, maximum_vertice_num):
703 |     new_edge_labels=[]
704 |     for label_sparse in edge_labels:
705 |         label_dense=[0] * maximum_vertice_num
706 |         for node_in_focus, neighbor in label_sparse:
707 |             label_dense[neighbor]=1/float(len(label_sparse))
708 |         new_edge_labels.append(label_dense)
709 |     return new_edge_labels # [number_iteration, maximum_vertice_num]
--------------------------------------------------------------------------------