├── GIN ├── GIN_model.py ├── GIN_utils.py └── __init__.py ├── Logs └── regularGin_cv_0_s1337_.log ├── MoleculeTasks ├── Duvenaud-kary.py ├── RNN-DFS.py └── rp_duvenaud.py ├── README.md ├── Run_Gin_Experiment.py ├── Synthetic_Data ├── X_eye_list_Kary_Deterministic_Graphs.pkl ├── X_unity_list_Kary_Deterministic_Graphs.pkl ├── graphs_Kary_Deterministic_Graphs.pkl └── y_Kary_Deterministic_Graphs.pt └── training_utils.py /GIN/GIN_model.py: -------------------------------------------------------------------------------- 1 | ########################################## 2 | # 3 | # Ryan L Murphy and Balasubramaniam Srinivasan, 2019 4 | # 5 | # Implement Graph Isomorphism Network (GIN) https://arxiv.org/pdf/1810.00826.pdf 6 | # for the purposes of our project 7 | # 8 | ########################################## 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | import sys 13 | import scipy.sparse as sps 14 | from training_utils import * 15 | from torch.nn import init 16 | from sklearn.utils import shuffle as sparse_shuffle 17 | 18 | class MLP(nn.Module): 19 | """Define a multilayer perceptron 20 | assume that all intermediate hidden layers have the same dimension (number of neurons) 21 | """ 22 | def __init__(self, in_dim, hidden_dim, out_dim, num_hidden_layers, act=F.relu, other_mlp_parameters={}): 23 | """ :param: other_mlp_parameters: dictionary with keys of dropout and/or batchnorm. values are dropout prob""" 24 | super(MLP, self).__init__() 25 | assert num_hidden_layers > 0, "MLP should have at least one hidden layer" 26 | assert isinstance(other_mlp_parameters, dict), 'other_mlp_parameters should be dict or none.' 27 | 28 | # Check that the other mlp parameters are valid 29 | for key_ in other_mlp_parameters.keys(): 30 | if key_ not in ['dropout', 'batchnorm']: 31 | raise ValueError("The key entered into other_mlp_parameters is invalid. Must be in ['dropout', 'batchnorm']. Entered: "+ str(key_)) 32 | 33 | if 'dropout' in other_mlp_parameters: 34 | assert isinstance(other_mlp_parameters['dropout'], float), "dropout prob should be a float" 35 | assert 0.0 <= other_mlp_parameters['dropout'] < 1.0, "dropout prob needs to be in half-open interval [0, 1)" 36 | self.dropout_prob = other_mlp_parameters['dropout'] 37 | self.do_dropout = True 38 | self.dropout_layer = nn.Dropout(p=self.dropout_prob) 39 | logging.info("Dropout will be used in MLP. 
Probability: {}".format(self.dropout_prob)) 40 | else: 41 | self.do_dropout = False 42 | logging.info("Dropout will NOT be used in MLP.") 43 | 44 | # Set batchnorm flags, add layers later 45 | # If batchnorm is a key and its value is true: 46 | if 'batchnorm' in other_mlp_parameters: 47 | assert isinstance(other_mlp_parameters['batchnorm'], bool), "batchnorm value must be bool" 48 | if other_mlp_parameters['batchnorm']: 49 | self.do_batchnorm = True 50 | logging.info("batchnorm WILL be applied in MLP") 51 | # If (batchnorm is not a key) OR (it has value False) 52 | else: 53 | self.do_batchnorm = False 54 | logging.info("batchnorm will NOT be applied in MLP.") 55 | 56 | if self.do_dropout and self.do_batchnorm: 57 | raise Warning("User selected both batchnorm and dropout in the MLP") 58 | 59 | self.act = act 60 | self.num_hidden_layers = num_hidden_layers 61 | self.layers = [] 62 | 63 | for ii in range(num_hidden_layers + 1): 64 | # Input to hidden 65 | if ii == 0: 66 | self.layers.append(nn.Linear(in_dim, hidden_dim)) 67 | # Hidden to output 68 | elif ii == num_hidden_layers: 69 | self.layers.append(nn.Linear(hidden_dim, out_dim)) 70 | # Hidden to hidden 71 | else: 72 | self.layers.append(nn.Linear(hidden_dim, hidden_dim)) 73 | 74 | # 75 | # Init weights with Xavier Glorot and set biases to zero 76 | # 77 | init.xavier_uniform_(self.layers[-1].weight) 78 | self.layers[-1].bias.data.fill_(0.0) 79 | 80 | self.add_module("layer_{}".format(ii), self.layers[-1]) 81 | # 82 | # Batchnorm 83 | # 84 | if self.do_batchnorm and ii < num_hidden_layers: 85 | # Get out_features in a robust way by calling getattr 86 | lin = getattr(self, "layer_{}".format(ii)) 87 | self.layers.append(nn.BatchNorm1d(lin.out_features)) 88 | self.add_module("batchnorm_{}".format(ii), self.layers[-1]) 89 | 90 | 91 | def forward(self, x): 92 | for jj in range(self.num_hidden_layers + 1): 93 | layer = getattr(self, "layer_{}".format(jj)) 94 | x = layer(x) 95 | if jj < self.num_hidden_layers: 96 | x = self.act(x) 97 | 98 | # Batchnorm and/or dropout 99 | # warning is raised if both are selected in constructor 100 | if self.do_batchnorm: 101 | bn = getattr(self, "batchnorm_{}".format(jj)) 102 | x = bn(x) 103 | 104 | if self.do_dropout: 105 | x = self.dropout_layer(x) 106 | return x 107 | # ================================================================ 108 | # 109 | # Parent class generates MLPs for the vertex embedding 110 | # for both the whole-graph and one-graph classes 111 | # (carry-over from previous implementations) 112 | # ================================================================ 113 | class GinParent(nn.Module): 114 | def __init__(self, input_data_dim, num_agg_steps, vertex_embed_dim, mlp_num_hidden, mlp_hidden_dim, vertices_are_onehot, other_mlp_parameters={}): 115 | """ 116 | :param input_data_dim: Dimension of the vertex attributes 117 | :param num_agg_steps: K, the number of WL iterations. The number of neighborhood aggregations 118 | :param vertex_embed_dim: Dimension of the `hidden' vertex attributes iteration to iteration 119 | :param mlp_num_hidden: Number of layers. 1 layer is sigmoid(Wx). 2 layers is Theta sigmoid(Wx) 120 | :param mlp_hidden_dim: Number of neurons in each layer 121 | :param vertices_are_onehot: Are the vertex features one-hot-encoded? Boolean 122 | :param vertex_embed_only: We are only interested in the vertex embeddings at layer K. 
123 | ..not forming a graph-wide embedding 124 | ..note: this is helpful for debug 125 | """ 126 | assert num_agg_steps > 0, "Number of aggregation steps should be positive" 127 | assert isinstance(vertices_are_onehot, bool) 128 | 129 | super(GinParent, self).__init__() 130 | 131 | self.vertices_are_onehot = vertices_are_onehot 132 | self.input_data_dim = input_data_dim 133 | self.num_agg_steps = num_agg_steps 134 | self.vertex_embed_dim = vertex_embed_dim 135 | # ~~~~~~~~~~~~~~~~~~~~~~~~~~ 136 | # Init layers for embedding 137 | # ~~~~~~~~~~~~~~~~~~~~~~~~~~ 138 | self.gin_layers = [] 139 | # 140 | # If vertex attributes are one-hot, we don't need an MLP before summation in the first layer 141 | # 142 | if not vertices_are_onehot: 143 | logging.info("User indicated: Vertex attributes are NOT one hot") 144 | # We need an extra MLP for embedding the features 145 | self.gin_layers.append(MLP(in_dim=self.input_data_dim, 146 | hidden_dim=mlp_hidden_dim, 147 | out_dim=vertex_embed_dim, 148 | num_hidden_layers=mlp_num_hidden, 149 | other_mlp_parameters=other_mlp_parameters)) 150 | self.add_module("raw_embedding_layer", self.gin_layers[-1]) 151 | # 152 | # MLP after aggregation is different here, because of the input dimension 153 | # 154 | self.gin_layers.append(MLP(in_dim=vertex_embed_dim, 155 | hidden_dim=mlp_hidden_dim, 156 | out_dim=vertex_embed_dim, 157 | num_hidden_layers=mlp_num_hidden, 158 | other_mlp_parameters=other_mlp_parameters)) 159 | 160 | self.add_module("agg_0", self.gin_layers[-1]) 161 | else: 162 | logging.info("User indicated: Vertex attributes ARE one hot") 163 | 164 | for itr in range(num_agg_steps): 165 | if itr == 0 and vertices_are_onehot: 166 | self.gin_layers.append(MLP(in_dim=self.input_data_dim, 167 | hidden_dim=mlp_hidden_dim, 168 | out_dim=vertex_embed_dim, 169 | num_hidden_layers=mlp_num_hidden, 170 | other_mlp_parameters=other_mlp_parameters)) 171 | # Assume all 'hidden' vertex features are of the same dim 172 | else: 173 | self.gin_layers.append(MLP(in_dim=vertex_embed_dim, 174 | hidden_dim=mlp_hidden_dim, 175 | out_dim=vertex_embed_dim, 176 | num_hidden_layers=mlp_num_hidden, 177 | other_mlp_parameters=other_mlp_parameters)) 178 | 179 | self.add_module("agg_{}".format(itr), self.gin_layers[-1]) 180 | # 181 | # Compute graph embedding dim (note it won't be used if 182 | # we only want vertex embeds, but that's fine) 183 | self.graph_embed_dim = self.input_data_dim + vertex_embed_dim * num_agg_steps 184 | 185 | # ======================================== 186 | class GinMultiGraph(GinParent): 187 | """ 188 | Designed for graph classification 189 | """ 190 | def __init__(self, adjmat_list, input_data_dim, num_agg_steps, vertex_embed_dim, mlp_num_hidden, mlp_hidden_dim, vertices_are_onehot, target_dim, epsilon_tunable=False, dense_layer_dropout=0.0, other_mlp_parameters={}): 191 | """ 192 | Most parameters defined in the parent class 193 | 194 | :param adjmat_list: List of all adjmats to be considered 195 | Purpose: force input validation, but not saved to any variable. 196 | The user will enter the graphs in the dataset. 
In principle, the graphs passed to 197 | initialize could be different than those used in the forward method; it is up 198 | to the user to properly do input validation on all desired graphs 199 | 200 | This is NOT stored as a self object; rest easy we're not wasting memory 201 | 202 | :param target_dim: Dimension of the response variable (the target) 203 | 204 | :param epsilon_tunable: Do we make epsilon in equation 4.1 tunable 205 | :param dense_layer_dropout: Dropout to apply to the dense layer. 206 | In accordance with the GIN paper's experimental section 207 | """ 208 | # Make sure all entered matrices are coo 209 | def is_coo(mat): 210 | return isinstance(mat, sps.coo.coo_matrix) 211 | 212 | # Make sure there are ones on the diagonal. 213 | def diags_all_one(mat): 214 | return np.array_equal(mat.diagonal(), np.ones(mat.shape[0])) 215 | 216 | assert all(list(map(is_coo, adjmat_list))), "All adjacency matrices must be scipy sparse coo" 217 | assert all(list(map(diags_all_one, adjmat_list))), "All adjacency matrices must have ones on the diag" 218 | assert isinstance(dense_layer_dropout, float), "Dense layer dropout must be a float in 0 <= p < 1" 219 | assert 0 <= dense_layer_dropout < 1, "Dense layer dropout must be a float in 0 <= p < 1" 220 | 221 | super(GinMultiGraph, self).__init__(input_data_dim=input_data_dim, 222 | num_agg_steps=num_agg_steps, 223 | vertex_embed_dim=vertex_embed_dim, 224 | mlp_num_hidden=mlp_num_hidden, 225 | mlp_hidden_dim=mlp_hidden_dim, 226 | vertices_are_onehot=vertices_are_onehot, 227 | other_mlp_parameters=other_mlp_parameters 228 | ) 229 | 230 | self.target_dim = target_dim 231 | self.add_module("last_linear", nn.Linear(self.graph_embed_dim, target_dim)) 232 | 233 | self.epsilon_tunable = epsilon_tunable 234 | 235 | logging.info("Dense layer dropout: {}".format(dense_layer_dropout)) 236 | self.dense_layer_dropout = nn.Dropout(p=dense_layer_dropout) 237 | 238 | if epsilon_tunable: 239 | logging.info("User indicated: epsilon_tunable = True") 240 | logging.info("Epsilon_k WILL be LEARNED via backprop") 241 | logging.info("It is initialized to zero") 242 | 243 | self.epsilons = nn.ParameterList() 244 | for ll in range(num_agg_steps): 245 | epsilon_k = nn.Parameter(torch.zeros(1), requires_grad=True) 246 | self.epsilons.append(epsilon_k) 247 | else: 248 | logging.info("User indicated: epsilon_tunable = False") 249 | logging.info("Epsilon_k WILL NOT be learned via backprop (and set to zero implicitly)") 250 | 251 | 252 | def construct_sparse_operator_tensors(self, sparse_adjmats): 253 | """ Construct the matrices needed to perform 254 | hidden layer updates (pre-MLP) 255 | 256 | :param: Sparse adjmat in a BATCH (thus different from the list passed to the constructor) 257 | 258 | :return: Adjacency: A sparse block-diagonal torch tensor, where the blocks 259 | are adjmats 260 | 261 | :return Summation matrix: A matrix of ones and zeros such that 262 | matrix multiplication will effectively 263 | compute the row sums of chunks of a matrix B 264 | 265 | Because B will store the vertex embeddings 266 | for every vertex, for every graph. 267 | We want to compute the sums within a graph. 
268 | 269 | Example: 270 | 271 | Suppose we had a two-node graph, a three-node graph, and another two-node 272 | our matrix would look like 273 | [1, 1, 0, 0, 0, 0, 0] 274 | S= [0, 0, 1, 1, 1, 0, 0] 275 | [0, 0, 0, 0, 0, 1, 1] 276 | 277 | Then we will do S @ B 278 | """ 279 | assert isinstance(sparse_adjmats, list) 280 | # 281 | # ADJMAT: 282 | # 283 | # > Make diagonal scipy sparse matrix of adjmats 284 | diag_mat = sps.block_diag(sparse_adjmats) 285 | # 286 | # turn it into a torch sparse tensor 287 | # 288 | rows, cols = sps.find(diag_mat)[0:2] # indices of nonzero rows and cols 289 | indx_tens = torch.stack([torch.LongTensor(rows), torch.LongTensor(cols)], dim=0) 290 | vals_tens = torch.ones(len(rows)) 291 | self.block_adj = torch.sparse.FloatTensor(indx_tens, vals_tens) # One may think this should be an int tensor, but we cannot multiply ints with floats in PyTorch 292 | # 293 | # MAKE THE SUMMATION MATRIX 294 | # >> non-zero indices since we will make a sparse matrix 295 | sum_mat_cols = list(range(self.block_adj.shape[0])) 296 | sum_mat_rows = [] 297 | for iii in range(len(sparse_adjmats)): 298 | num_nodes = sparse_adjmats[iii].shape[0] 299 | sum_mat_rows.extend([iii for jjj in range(num_nodes)]) 300 | 301 | sum_indx = torch.stack([torch.LongTensor(sum_mat_rows), torch.LongTensor(sum_mat_cols)], dim=0) 302 | sum_vals = torch.ones(sum_indx.shape[1]) 303 | self.sum_tensor = torch.sparse.FloatTensor(sum_indx, sum_vals) 304 | 305 | def forward(self, adjmat_list, X): 306 | """ 307 | Get a graph-level prediction for a list of graphs 308 | :param X: Vertex attributes for every vertex in every batch 309 | :param adjmat_list: List of adjacency matrices in batch 310 | """ 311 | # check that #vertices and X dimension coincide 312 | total_vertices = np.sum([mat.shape[0] for mat in adjmat_list]) 313 | assert total_vertices == X.shape[0], "Total vertices must match the number of rows in X" 314 | assert X.shape[1] == self.input_data_dim, "Number of columns in X must match self.input_data_dim" 315 | 316 | # Construct matrices that will allow vectorized operations of 317 | # "sum neighbors" and "sum all vertices within a graph" 318 | self.construct_sparse_operator_tensors(adjmat_list) 319 | 320 | # Get embedding from X 321 | self.graph_embedding = torch.mm(self.sum_tensor, X) 322 | 323 | if not self.vertices_are_onehot: 324 | embedding = getattr(self, "raw_embedding_layer") 325 | H = embedding(X) 326 | else: 327 | H = X.clone() 328 | 329 | for kk in range(self.num_agg_steps): 330 | # Sum self and neighbor 331 | if not self.epsilon_tunable: 332 | # Aggregation in matrix form: (A + I)H 333 | agg_pre_mlp = torch.mm(self.block_adj, H) 334 | # print(agg_pre_mlp) 335 | else: 336 | # 337 | # Add epsilon to h_v, as in equation 4.1 338 | # Note that the proper matrix multiplication is 339 | # (A + (1+epsilon)I)H = (A+I)H + epsilon H 340 | # 341 | # Our implementation avoids making epsilon interact with the 342 | # adjacency matrix, which would make PyTorch want to 343 | # track gradients through the adjmat by default 344 | # 345 | epsilon_k = self.epsilons[kk] 346 | agg_pre_mlp = torch.mm(self.block_adj, H) + epsilon_k*H 347 | 348 | 349 | mlp = getattr(self, "agg_{}".format(kk)) 350 | H = mlp(agg_pre_mlp) 351 | # 352 | layer_k_embed = torch.mm(self.sum_tensor, H) 353 | self.graph_embedding = torch.cat((self.graph_embedding, 354 | layer_k_embed), 355 | dim=1) 356 | # 357 | last_layer = getattr(self, "last_linear") 358 | final = last_layer(self.graph_embedding) 359 | 360 | # apply dropout and return (note 
dropout is 0.0 by default) 361 | return self.dense_layer_dropout(final) 362 | 363 | # ======================================== 364 | # 365 | # RP-GIN. Use GIN as \harrow{f} in 366 | # relational pooling model. 367 | # ======================================== 368 | class RpGin(GinMultiGraph): 369 | """ 370 | Wrap GIN in relational pooling. 371 | Here we randomly permute the adjacency matrix (while preserving isomorphic invariance) 372 | A_new = P^T @ A @ P 373 | where @ denotes matrix multiplication and P is a permutation matrix. 374 | 375 | We then forward the shuffled mats (see paper for theoretical explanation) 376 | """ 377 | def __init__(self, adjmat_list, input_data_dim, num_agg_steps, vertex_embed_dim, mlp_num_hidden, mlp_hidden_dim, target_dim, featureless_case, vertices_are_onehot=False, epsilon_tunable = False, dense_layer_dropout = 0.0, other_mlp_parameters = {}): 378 | """ Parameters are defined in parent class 379 | :param: featureless_case: bool, is the input featureless? 380 | if featureless, input_data_dim is used as the largest expected graph 381 | """ 382 | assert isinstance(featureless_case, bool) 383 | self.featureles_case = featureless_case 384 | if featureless_case: 385 | logging.info("User has indicated that the graphs are featureless") 386 | logging.info("The number of vertices in the largest expected graph is {}".format(input_data_dim)) 387 | self.featureles_case = True 388 | else: 389 | self.featureles_case = False 390 | raise NotImplementedError("Have only considered featureless case thus far. Set input dim to 0 for featureless case") 391 | 392 | super(RpGin, self).__init__(adjmat_list, input_data_dim, num_agg_steps, vertex_embed_dim, mlp_num_hidden, mlp_hidden_dim, vertices_are_onehot, target_dim, epsilon_tunable, dense_layer_dropout, other_mlp_parameters) 393 | 394 | def permute_adjmat(self, mat): 395 | """ 396 | :param mat: A scipy sparse matrix representing an adjacency matrix 397 | :return: A permuted matrix corresponding to an isomorphic graph, ie 398 | P^T @ mat @ P 399 | where @ denotes matrix multiplication and P is a permutation matrix. 
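        Example (illustrative, for a hypothetical 3-vertex path graph with self-loops):
            A = [[1, 1, 0],
                 [1, 1, 1],
                 [0, 1, 1]]
        With the permutation matrix P that swaps vertices 1 and 2,
            P^T @ A @ P = [[1, 0, 1],
                           [0, 1, 1],
                           [1, 1, 1]]
        which is the adjacency matrix of the same graph with its vertices relabeled.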
400 | """ 401 | # form a permutation matrix by shuffling an identity matrix 402 | # the result of sparse_shuffle will be compressed row, need to coo it 403 | P = sps.coo_matrix(sparse_shuffle(sps.eye(mat.shape[0]))) 404 | return P.transpose() @ mat @ P 405 | 406 | def forward(self, sparse_adjmats, X): 407 | """ 408 | :param sparse_adjmats: List of adjacency matrices for the graphs in the batch 409 | :param X: Vertex features 410 | :return: 411 | """ 412 | # (1) Permute all the adjacency matrices in the list 413 | # (2) Forward to GIN 414 | if self.featureles_case: 415 | return super(RpGin, self).forward(adjmat_list=list(map(self.permute_adjmat, sparse_adjmats)), 416 | X=X) 417 | else: 418 | pass 419 | 420 | def inference(self, sparse_adjmats, X, num_inf_perms=5): 421 | """ 422 | To do proper inference, we sample multiple 423 | permutations and average 424 | :param num_inf_perms: Number of random permutations to do at inference time 425 | """ 426 | divisor = (1.0/num_inf_perms) 427 | preds = divisor * self.forward(sparse_adjmats, X) 428 | for iii in range(num_inf_perms-1): 429 | preds += divisor * self.forward(sparse_adjmats, X) 430 | 431 | return preds 432 | -------------------------------------------------------------------------------- /GIN/GIN_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | def construct_onehot_ids(graph_size, onehot_dim): 3 | """Assign one hot identifiers of dimension onehot_dim 4 | :return: A matrix of dim (graph_size by onehot_dim) 5 | where each row is a one-hot identifier 6 | E.g. if onehot_dim = 3 7 | [1, 0, 0] 8 | [0, 1, 0] 9 | [0, 0, 1] 10 | . . . 11 | . . . 12 | . . . 13 | """ 14 | ID = torch.zeros(graph_size, onehot_dim) 15 | col = 0 16 | for row in range(graph_size): 17 | ID[row, col] = 1 18 | col = (col + 1) % onehot_dim 19 | 20 | return ID -------------------------------------------------------------------------------- /GIN/__init__.py: -------------------------------------------------------------------------------- 1 | ########################################## 2 | # Implement Graph Isomorphism Network (GIN) https://arxiv.org/pdf/1810.00826.pdf 3 | # for the purposes of our project 4 | ########################################## 5 | 6 | -------------------------------------------------------------------------------- /Logs/regularGin_cv_0_s1337_.log: -------------------------------------------------------------------------------- 1 | 2019-01-23 04:28:48,987:INFO: {'--cv-fold': 0, 2 | '--dense-dropout-prob': 0.0, 3 | '--learning-rate': 0.01, 4 | '--mlp-hidden-dim': 16, 5 | '--model-type': 'regularGin', 6 | '--num-epochs': 12, 7 | '--num-gnn-layers': 5, 8 | '--num-inf-perm': 5, 9 | '--num-mlp-hidden': 2, 10 | '--onehot-id-dim': 41, 11 | '--out-weight-dir': '/scratch-data/murph213/', 12 | '--seed-val': 1337, 13 | '--set-epsilon-zero': False, 14 | '--use-batchnorm': False, 15 | '--vertex-embed-dim': 16} 16 | 2019-01-23 04:28:48,988:INFO: ---Loading data...--- 17 | 2019-01-23 04:28:48,993:INFO: 140 Adjacency matrices were loaded 18 | 2019-01-23 04:28:49,002:INFO: ---splitting into training and validation folds--- 19 | 2019-01-23 04:28:49,002:INFO: The indices are shuffled, and the shuffle is consistent on many machines as long as python3 is used 20 | 2019-01-23 04:28:49,006:INFO: Building model... 21 | 2019-01-23 04:28:49,010:INFO: User indicated: Vertex attributes are NOT one hot 22 | 2019-01-23 04:28:49,010:INFO: Dropout will NOT be used in MLP. 
23 | 2019-01-23 04:28:49,010:INFO: batchnorm will NOT be applied in MLP. 24 | 2019-01-23 04:28:49,011:INFO: Dropout will NOT be used in MLP. 25 | 2019-01-23 04:28:49,011:INFO: batchnorm will NOT be applied in MLP. 26 | 2019-01-23 04:28:49,012:INFO: Dropout will NOT be used in MLP. 27 | 2019-01-23 04:28:49,012:INFO: batchnorm will NOT be applied in MLP. 28 | 2019-01-23 04:28:49,012:INFO: Dropout will NOT be used in MLP. 29 | 2019-01-23 04:28:49,012:INFO: batchnorm will NOT be applied in MLP. 30 | 2019-01-23 04:28:49,013:INFO: Dropout will NOT be used in MLP. 31 | 2019-01-23 04:28:49,013:INFO: batchnorm will NOT be applied in MLP. 32 | 2019-01-23 04:28:49,013:INFO: Dropout will NOT be used in MLP. 33 | 2019-01-23 04:28:49,014:INFO: batchnorm will NOT be applied in MLP. 34 | 2019-01-23 04:28:49,014:INFO: Dropout will NOT be used in MLP. 35 | 2019-01-23 04:28:49,014:INFO: batchnorm will NOT be applied in MLP. 36 | 2019-01-23 04:28:49,015:INFO: Dense layer dropout: 0.0 37 | 2019-01-23 04:28:49,015:INFO: User indicated: epsilon_tunable = True 38 | 2019-01-23 04:28:49,015:INFO: Epsilon_k WILL be LEARNED via backprop 39 | 2019-01-23 04:28:49,015:INFO: It is initialized to zero 40 | 2019-01-23 04:28:49,015:INFO: GinMultiGraph( 41 | (raw_embedding_layer): MLP( 42 | (layer_0): Linear(in_features=1, out_features=16, bias=True) 43 | (layer_1): Linear(in_features=16, out_features=16, bias=True) 44 | (layer_2): Linear(in_features=16, out_features=16, bias=True) 45 | ) 46 | (agg_0): MLP( 47 | (layer_0): Linear(in_features=16, out_features=16, bias=True) 48 | (layer_1): Linear(in_features=16, out_features=16, bias=True) 49 | (layer_2): Linear(in_features=16, out_features=16, bias=True) 50 | ) 51 | (agg_1): MLP( 52 | (layer_0): Linear(in_features=16, out_features=16, bias=True) 53 | (layer_1): Linear(in_features=16, out_features=16, bias=True) 54 | (layer_2): Linear(in_features=16, out_features=16, bias=True) 55 | ) 56 | (agg_2): MLP( 57 | (layer_0): Linear(in_features=16, out_features=16, bias=True) 58 | (layer_1): Linear(in_features=16, out_features=16, bias=True) 59 | (layer_2): Linear(in_features=16, out_features=16, bias=True) 60 | ) 61 | (agg_3): MLP( 62 | (layer_0): Linear(in_features=16, out_features=16, bias=True) 63 | (layer_1): Linear(in_features=16, out_features=16, bias=True) 64 | (layer_2): Linear(in_features=16, out_features=16, bias=True) 65 | ) 66 | (agg_4): MLP( 67 | (layer_0): Linear(in_features=16, out_features=16, bias=True) 68 | (layer_1): Linear(in_features=16, out_features=16, bias=True) 69 | (layer_2): Linear(in_features=16, out_features=16, bias=True) 70 | ) 71 | (last_linear): Linear(in_features=81, out_features=10, bias=True) 72 | (dense_layer_dropout): Dropout(p=0.0) 73 | (epsilons): ParameterList( 74 | (0): Parameter containing: [torch.FloatTensor of size 1] 75 | (1): Parameter containing: [torch.FloatTensor of size 1] 76 | (2): Parameter containing: [torch.FloatTensor of size 1] 77 | (3): Parameter containing: [torch.FloatTensor of size 1] 78 | (4): Parameter containing: [torch.FloatTensor of size 1] 79 | ) 80 | ) 81 | 2019-01-23 04:28:49,016:INFO: ------Training Model--------- 82 | 2019-01-23 04:28:49,016:INFO: Train X has shape torch.Size([4592, 1]) 83 | 2019-01-23 04:28:49,016:INFO: Val X has shape torch.Size([1148, 1]) 84 | 2019-01-23 04:28:49,016:INFO: Train y has shape torch.Size([112]) 85 | 2019-01-23 04:28:49,016:INFO: Validation y has shape torch.Size([28]) 86 | 2019-01-23 04:28:51,107:INFO: ~~~~~ 87 | 2019-01-23 04:28:51,107:INFO: Epoch: 0 | Train Loss: 287.92242 | 
Val Loss: 19.08630 | Train Accuracy : 0.10714 | Val Accuracy : 0.10714 88 | 2019-01-23 04:28:53,154:INFO: ~~~~~ 89 | 2019-01-23 04:28:53,155:INFO: Epoch: 1 | Train Loss: 17.59324 | Val Loss: 21.50815 | Train Accuracy : 0.09821 | Val Accuracy : 0.10714 90 | 2019-01-23 04:28:55,386:INFO: ~~~~~ 91 | 2019-01-23 04:28:55,387:INFO: Epoch: 2 | Train Loss: 21.36885 | Val Loss: 11.38049 | Train Accuracy : 0.09821 | Val Accuracy : 0.14286 92 | 2019-01-23 04:28:57,384:INFO: ~~~~~ 93 | 2019-01-23 04:28:57,385:INFO: Epoch: 3 | Train Loss: 11.11135 | Val Loss: 15.29234 | Train Accuracy : 0.08929 | Val Accuracy : 0.10714 94 | 2019-01-23 04:28:59,158:INFO: ~~~~~ 95 | 2019-01-23 04:28:59,159:INFO: Epoch: 4 | Train Loss: 14.30012 | Val Loss: 14.56442 | Train Accuracy : 0.09821 | Val Accuracy : 0.10714 96 | 2019-01-23 04:29:00,638:INFO: ~~~~~ 97 | 2019-01-23 04:29:00,639:INFO: Epoch: 5 | Train Loss: 15.00910 | Val Loss: 10.52087 | Train Accuracy : 0.09821 | Val Accuracy : 0.10714 98 | 2019-01-23 04:29:02,622:INFO: ~~~~~ 99 | 2019-01-23 04:29:02,623:INFO: Epoch: 6 | Train Loss: 11.12828 | Val Loss: 8.97679 | Train Accuracy : 0.09821 | Val Accuracy : 0.10714 100 | 2019-01-23 04:29:04,575:INFO: ~~~~~ 101 | 2019-01-23 04:29:04,575:INFO: Epoch: 7 | Train Loss: 9.14753 | Val Loss: 10.08619 | Train Accuracy : 0.09821 | Val Accuracy : 0.07143 102 | 2019-01-23 04:29:07,378:INFO: ~~~~~ 103 | 2019-01-23 04:29:07,379:INFO: Epoch: 8 | Train Loss: 10.37136 | Val Loss: 7.75332 | Train Accuracy : 0.10714 | Val Accuracy : 0.03571 104 | 2019-01-23 04:29:09,634:INFO: ~~~~~ 105 | 2019-01-23 04:29:09,635:INFO: Epoch: 9 | Train Loss: 7.82652 | Val Loss: 7.68935 | Train Accuracy : 0.11607 | Val Accuracy : 0.10714 106 | 2019-01-23 04:29:11,682:INFO: ~~~~~ 107 | 2019-01-23 04:29:11,683:INFO: Epoch: 10 | Train Loss: 8.33960 | Val Loss: 5.62804 | Train Accuracy : 0.09821 | Val Accuracy : 0.14286 108 | 2019-01-23 04:29:14,166:INFO: ~~~~~ 109 | 2019-01-23 04:29:14,167:INFO: Epoch: 11 | Train Loss: 6.56234 | Val Loss: 5.54927 | Train Accuracy : 0.08929 | Val Accuracy : 0.10714 110 | 2019-01-23 04:29:14,167:INFO: Saving model to file 111 | 2019-01-23 04:29:14,167:INFO: /scratch-data/murph213/regularGin_cv_0_s1337_.pth 112 | 2019-01-23 04:29:14,274:INFO: ...done saving 113 | 2019-01-23 04:29:14,275:INFO: Saving metrics 114 | 2019-01-23 04:29:14,275:INFO: /scratch-data/murph213/regularGin_cv_0_s1337_.pkl 115 | 2019-01-23 04:29:14,333:INFO: ... 
done saving 116 | -------------------------------------------------------------------------------- /MoleculeTasks/Duvenaud-kary.py: -------------------------------------------------------------------------------- 1 | """ 2 | Balasubramanian Srinivasan and Ryan L Murphy 3 | This code implements so-called k-ary RP approaches 4 | """ 5 | from __future__ import division 6 | from __future__ import print_function 7 | from __future__ import unicode_literals 8 | 9 | import numpy as np 10 | import tensorflow as tf 11 | import deepchem as dc 12 | import sys 13 | 14 | from deepchem.models.tensorgraph.models.graph_models import GraphConvTensorGraph 15 | from random import shuffle 16 | from deepchem.models.tensorgraph.layers import Feature 17 | from deepchem.models.tensorgraph.layers import Dense, GraphConv, BatchNorm 18 | from deepchem.models.tensorgraph.layers import GraphPool, GraphGather 19 | from deepchem.models.tensorgraph.layers import Dense, SoftMax, SoftMaxCrossEntropy, WeightedError, Stack 20 | from deepchem.models.tensorgraph.layers import Label, Weights 21 | from deepchem.metrics import to_one_hot 22 | from deepchem.feat.mol_graphs import ConvMol 23 | from deepchem.models.tensorgraph.tensor_graph import TensorGraph 24 | tg = TensorGraph(use_queue=False) 25 | 26 | 27 | TASK = sys.argv[1] # 'tox_21', 'hiv', 'muv 28 | K = int(sys.argv[2]) 29 | technique = 'dfs' 30 | batch_size = 96 31 | NUM_EPOCHS = 100 32 | 33 | 34 | def randomize_perm(a): 35 | ordering = list(range(a)) 36 | shuffle(ordering) 37 | return ordering 38 | 39 | 40 | def depth_first_search(neighbour_list, root_node): 41 | """ DFS can be used as a poly-canonical ordering to reduce computational cost in the RP sum """ 42 | visited_nodes = set() 43 | order = [] 44 | stack = [root_node] 45 | while stack: 46 | node = stack.pop() 47 | if node not in visited_nodes: 48 | visited_nodes.add(node) 49 | order.append(node) 50 | stack.extend(set(neighbour_list[node]) - visited_nodes) 51 | return order 52 | 53 | 54 | def breadth_first_search(neighbour_list, root_node): 55 | """ BFS can be used as a poly-canonical ordering to reduce computational cost in the RP sum """ 56 | visited_nodes = set() 57 | order = [] 58 | queue = [root_node] 59 | while queue: 60 | node = queue.pop(0) 61 | if node not in visited_nodes: 62 | visited_nodes.add(node) 63 | order.append(node) 64 | queue.extend(set(neighbour_list[node]) - visited_nodes) 65 | return order 66 | 67 | 68 | def generate_new_X(dataset, K, technique): 69 | """ Reduce to k-ary and run poly-canonical ordering""" 70 | count = 0 71 | new_array = [] 72 | size = dataset.shape[0] 73 | for i in range(size): 74 | mol = dataset[i] 75 | min_degree, max_degree = 1000, 0 76 | atom_feats = mol.get_atom_features() 77 | adjacent_list = mol.get_adjacency_list() 78 | num_atoms = mol.get_num_atoms() 79 | if num_atoms > K: 80 | #Reduce to k-ary 81 | count+=1 82 | ordering = randomize_perm(num_atoms) 83 | if technique == 'dfs': 84 | order = depth_first_search(adjacent_list,ordering[0]) 85 | elif technique == 'bfs': 86 | order = breadth_first_search(adjacent_list,ordering[0]) 87 | else : 88 | order = ordering 89 | if (len(order) < K): 90 | order = ordering 91 | order = order[:K] 92 | atom_feats = atom_feats[order] 93 | new_atom_feats = atom_feats 94 | create_adjacency = [] 95 | for i in order: 96 | edges = [] 97 | for neighbor in adjacent_list[i]: 98 | if neighbor in order: 99 | get_new_index = int(order.index(neighbor)) 100 | edges.append(get_new_index) 101 | create_adjacency.append(edges) 102 | new_mol = 
dc.feat.mol_graphs.ConvMol(new_atom_feats, create_adjacency) 103 | else : 104 | new_mol = dc.feat.mol_graphs.ConvMol(atom_feats, adjacent_list) 105 | new_array.append(new_mol) 106 | print(count) 107 | return np.array(new_array) 108 | 109 | 110 | def data_generator(dataset, epochs=1, predict=False, pad_batches=True): 111 | for epoch in range(epochs): 112 | if not predict: 113 | print('Starting epoch %i' % epoch) 114 | for ind, (X_b, y_b, w_b, ids_b) in enumerate( 115 | dataset.iterbatches(batch_size, pad_batches=pad_batches, deterministic=True)): 116 | d = {} 117 | for index, label in enumerate(labels): 118 | d[label] = to_one_hot(y_b[:, index]) 119 | d[weights] = w_b 120 | multiConvMol = ConvMol.agglomerate_mols(X_b) 121 | d[atom_features] = multiConvMol.get_atom_features() 122 | d[degree_slice] = multiConvMol.deg_slice 123 | d[membership] = multiConvMol.membership 124 | for i in range(1, len(multiConvMol.get_deg_adjacency_lists())): 125 | d[deg_adjs[i - 1]] = multiConvMol.get_deg_adjacency_lists()[i] 126 | yield d 127 | 128 | 129 | def reshape_y_pred(y_true, y_pred): 130 | """ 131 | TensorGraph.Predict returns a list of arrays, one for each output 132 | We also have to remove the padding on the last batch 133 | Metrics taks results of shape (samples, n_task, prob_of_class) 134 | """ 135 | n_samples = len(y_true) 136 | retval = np.stack(y_pred, axis=1) 137 | return retval[:n_samples] 138 | 139 | 140 | if TASK == 'tox_21': 141 | from deepchem.molnet import load_tox21 as dataloader 142 | NUM_TASKS = 12 143 | elif TASK == 'hiv': 144 | from deepchem.molnet import load_hiv as dataloader 145 | NUM_TASKS = 1 146 | elif TASK == 'muv': 147 | from deepchem.molnet import load_muv as dataloader 148 | NUM_TASKS = 17 149 | 150 | # ------------------------------------------------- 151 | # Load datasets, tasks, and transformers 152 | # The number of tasks in each dataset can be found in Table 1 of MoleculeNet: A Benchmark for Molecular Machine Learning 153 | # by Wu et. al. 
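#
# Illustrative example of the k-ary reduction performed by generate_new_X above: with
# K = 3 and a hypothetical 4-atom path molecule whose adjacency list is
# [[1], [0, 2], [1, 3], [2]], a DFS rooted at atom 1 visits the atoms in an order such
# as [1, 2, 3, 0]; the first K atoms [1, 2, 3] are kept, their features are re-ordered
# accordingly, and their neighbours are re-indexed to give the new adjacency list
# [[1], [0, 2], [1]].
#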
154 | current_tasks, current_datasets, transformers = dataloader(featurizer='GraphConv',reload=True,split='random') 155 | train_dataset, valid_dataset, test_dataset = current_datasets 156 | # 157 | # Build up model object 158 | # Follow: https://deepchem.io/docs/notebooks/graph_convolutional_networks_for_tox21.html 159 | # 160 | atom_features = Feature(shape=(None, 75)) 161 | degree_slice = Feature(shape=(None, 2), dtype=tf.int32) 162 | membership = Feature(shape=(None,), dtype=tf.int32) 163 | 164 | deg_adjs = [] 165 | for i in range(0, 10 + 1): 166 | deg_adj = Feature(shape=(None, i + 1), dtype=tf.int32) 167 | deg_adjs.append(deg_adj) 168 | 169 | 170 | gc1 = GraphConv( 171 | 64, 172 | activation_fn=tf.nn.relu, 173 | in_layers=[atom_features, degree_slice, membership] + deg_adjs) 174 | batch_norm1 = BatchNorm(in_layers=[gc1]) 175 | gp1 = GraphPool(in_layers=[batch_norm1, degree_slice, membership] + deg_adjs) 176 | gc2 = GraphConv( 177 | 64, 178 | activation_fn=tf.nn.relu, 179 | in_layers=[gp1, degree_slice, membership] + deg_adjs) 180 | batch_norm2 = BatchNorm(in_layers=[gc2]) 181 | gp2 = GraphPool(in_layers=[batch_norm2, degree_slice, membership] + deg_adjs) 182 | dense = Dense(out_channels=128, activation_fn=tf.nn.relu, in_layers=[gp2]) 183 | batch_norm3 = BatchNorm(in_layers=[dense]) 184 | readout = GraphGather( 185 | batch_size=batch_size, 186 | activation_fn=tf.nn.tanh, 187 | in_layers=[batch_norm3, degree_slice, membership] + deg_adjs) 188 | 189 | costs = [] 190 | labels = [] 191 | for task in range(len(current_tasks)): 192 | classification = Dense( 193 | out_channels=2, activation_fn=None, in_layers=[readout]) 194 | 195 | softmax = SoftMax(in_layers=[classification]) 196 | tg.add_output(softmax) 197 | 198 | label = Label(shape=(None, 2)) 199 | labels.append(label) 200 | cost = SoftMaxCrossEntropy(in_layers=[label, classification]) 201 | costs.append(cost) 202 | 203 | all_cost = Stack(in_layers=costs, axis=1) 204 | weights = Weights(shape=(None, len(current_tasks))) 205 | loss = WeightedError(in_layers=[all_cost, weights]) 206 | tg.set_loss(loss) 207 | # Data splits 208 | # Tox21 is treated differently: we manually (randomly) split into test, train, and valid directly from train_dataset.X 209 | # (rather than letting deepchem provide the data directly) 210 | # Reason: In the early stages of developing the code, the valid_dataset and test_dataset were empty for tox and 211 | # we observed a comment in the deepchem source code leading us to believe this was intended. 212 | # Thus, when we access valid_dataset.X and test_dataset.X, we don't do it for tox21. We only later 213 | # found that we could access tox21 validation and test. But we do this for all models, so the treatment is fair 214 | # 215 | # 216 | # This treatment is done for all models, so the comparison is fair. 
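# Concretely, the manual tox_21 split below uses the first 3800 molecules of
# train_dataset for training, indices 3800-4999 for validation, and everything from
# index 5000 onward for testing.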
217 | # 218 | if TASK != 'tox_21': 219 | new_train_data = generate_new_X(train_dataset.X, K, technique) 220 | new_train_dataset = dc.data.datasets.DiskDataset.from_numpy(new_train_data, train_dataset.y, train_dataset.w ,train_dataset.ids, data_dir=None) 221 | print("Train Data - added RP") 222 | new_valid_data = generate_new_X(valid_dataset.X, K, technique) 223 | new_valid_dataset = dc.data.datasets.DiskDataset.from_numpy(new_valid_data, valid_dataset.y, valid_dataset.w ,valid_dataset.ids, data_dir=None) 224 | print("Valid Data - added RP") 225 | new_test_data = generate_new_X(test_dataset.X, K, technique) 226 | new_test_dataset = dc.data.datasets.DiskDataset.from_numpy(new_test_data, test_dataset.y, test_dataset.w ,test_dataset.ids, data_dir=None) 227 | print("Test Data - added RP") 228 | else: 229 | new_train_data = generate_new_X(train_dataset.X[:3800], K, technique) 230 | new_train_dataset = dc.data.datasets.DiskDataset.from_numpy(new_train_data, train_dataset.y[:3800], train_dataset.w[:3800] ,train_dataset.ids[:3800], data_dir=None) 231 | print("Train Data - added RP - tox21") 232 | new_valid_data = generate_new_X(train_dataset.X[3800:5000], K, technique) 233 | new_valid_dataset = dc.data.datasets.DiskDataset.from_numpy(new_valid_data, train_dataset.y[3800:5000], train_dataset.w[3800:5000] ,train_dataset.ids[3800:5000], data_dir=None) 234 | print("Valid Data - added RP - tox21") 235 | new_test_data = generate_new_X(train_dataset.X[5000:], K, technique) 236 | new_test_dataset = dc.data.datasets.DiskDataset.from_numpy(new_test_data, train_dataset.y[5000:], train_dataset.w[5000:] ,train_dataset.ids[5000:], data_dir=None) 237 | print("Test Data - added RP - tox21") 238 | 239 | 240 | tg.fit_generator(data_generator(new_train_dataset, epochs=NUM_EPOCHS)) 241 | 242 | metric = dc.metrics.Metric( 243 | dc.metrics.roc_auc_score, np.mean, mode="classification") 244 | 245 | 246 | print("Evaluating model") 247 | train_predictions = tg.predict_on_generator(data_generator(new_train_dataset, predict=True)) 248 | train_predictions = reshape_y_pred(new_train_dataset.y, train_predictions) 249 | train_scores = metric.compute_metric(new_train_dataset.y, train_predictions, new_train_dataset.w) 250 | print("Training ROC-AUC Score: %f" % train_scores) 251 | 252 | valid_predictions = tg.predict_on_generator(data_generator(new_valid_dataset, predict=True)) 253 | valid_predictions = reshape_y_pred(new_valid_dataset.y, valid_predictions) 254 | valid_scores = metric.compute_metric(new_valid_dataset.y, valid_predictions, new_valid_dataset.w) 255 | print("Valid ROC-AUC Score: %f" % valid_scores) 256 | 257 | test_predictions = tg.predict_on_generator(data_generator(new_test_dataset, predict=True)) 258 | test_predictions = reshape_y_pred(new_test_dataset.y, test_predictions) 259 | test_scores = metric.compute_metric(new_test_dataset.y, test_predictions, new_test_dataset.w) 260 | print("test ROC-AUC Score: %f" % test_scores) 261 | 262 | 263 | -------------------------------------------------------------------------------- /MoleculeTasks/RNN-DFS.py: -------------------------------------------------------------------------------- 1 | """ 2 | Balasubramanian Srinivasan and Ryan L Murphy 3 | This code implements neural-network-based methods as \harrow{f} in the RP framework 4 | """ 5 | 6 | import torch 7 | import torch.nn as nn 8 | import numpy as np 9 | import time 10 | import pickle 11 | import sys 12 | from itertools import permutations, product 13 | from scipy import sparse 14 | from torch.nn import init 15 | from 
random import shuffle 16 | from sklearn.metrics import roc_auc_score 17 | 18 | TASK = sys.argv[1] 19 | batch_size = 96 20 | num_epochs = 50 21 | inference_permutations = 20 22 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 23 | 24 | #Since the individual features and the pairwise features aren't aligned across the conv and the weave models from deepchem, we align them using the atom features 25 | def align_adjacency(a, b): 26 | data0 = a.get_atom_features() 27 | data1 = b.get_atom_features() 28 | index = list(range(data0.shape[0])) 29 | remain = list(range(data1.shape[0])) 30 | mapping_dict = dict((k, k) for k in index) 31 | for i in index: 32 | for j in remain: 33 | if np.array_equal(data0[i], data1[j]): 34 | mapping_dict[i] = j 35 | remain.remove(j) 36 | break 37 | return mapping_dict 38 | 39 | 40 | #Returns a random ordering of the atoms in the molecule 41 | def randomize_perm(a): 42 | ordering = list(range(a)) 43 | shuffle(ordering) 44 | return ordering 45 | 46 | 47 | #Permuting the adjacency tensor in accordance with joint erxchangeability 48 | def permute_array(a, ordering, mapping_dict): 49 | pair_features = a.get_pair_features() 50 | new_array = np.zeros(pair_features.shape) 51 | mod_factor = pair_features.shape[0] 52 | m, n = 0, 0 53 | for i in ordering: 54 | for j in ordering: 55 | new_array[m][n] = pair_features[mapping_dict[i]][mapping_dict[j]] 56 | n += 1 57 | n = n % mod_factor 58 | m += 1 59 | m = m % mod_factor 60 | return new_array 61 | 62 | #Perform depth first search based on an input root node 63 | def depth_first_search(neighbour_list, root_node): 64 | visited_nodes = set() 65 | order = [] 66 | stack = [root_node] 67 | while stack: 68 | node = stack.pop() 69 | if node not in visited_nodes: 70 | visited_nodes.add(node) 71 | order.append(node) 72 | stack.extend(set(neighbour_list[node]) - visited_nodes) 73 | return order 74 | 75 | #Construct the complete tensors using the pairwise and individual tensors adhering to joint exchangeability 76 | def construct_tensor(dataset_conv, dataset_weave, y): 77 | size = dataset_weave.shape[0] 78 | arr_pair = [] 79 | arr_indv = [] 80 | y_true = [] 81 | for i in range(size): 82 | a = dataset_conv[i] 83 | b = dataset_weave[i] 84 | ordering = randomize_perm(a.get_num_atoms()) 85 | mapping_dict = align_adjacency(a, b) 86 | order = depth_first_search(a.get_adjacency_list(), ordering[0]) 87 | pair_array = permute_array(b, order, mapping_dict) 88 | indv_array = a.get_atom_features()[order] 89 | if len(pair_array) == len(indv_array): 90 | arr_pair.append(torch.Tensor(pair_array).to(device)) 91 | arr_indv.append(torch.Tensor(indv_array).to(device)) 92 | y_true.append(y[i]) 93 | return (arr_indv, arr_pair, y_true) 94 | 95 | 96 | def unison_shuffled(a, b, c): 97 | p = np.random.permutation(len(a)) 98 | return a[p], b[p], c[p] 99 | 100 | #The RNN model as described in the appendix of the paper 101 | class RNNModel(nn.Module): 102 | def __init__(self): 103 | super(RNNModel, self).__init__() 104 | self.rnn_unit_1 = nn.LSTM(14, 100, batch_first=True) 105 | self.indv_linear_1 = nn.Linear(75, 100) 106 | init.xavier_uniform_(self.indv_linear_1.weight) 107 | self.indv_act_1 = nn.ReLU() 108 | self.rnn_unit_2 = nn.LSTM(200, 100, batch_first=True) 109 | self.rho_lin_1 = nn.Linear(100, 100) 110 | init.xavier_uniform_(self.rho_lin_1.weight) 111 | self.rho_act_1 = nn.ReLU() 112 | self.final_lin = nn.Linear(100, NUM_TASKS) 113 | init.xavier_uniform_(self.final_lin.weight) 114 | self.loss_func = nn.BCEWithLogitsLoss() 115 | 116 | def 
forward(self, pair_inp, indv_inp): 117 | rho_input = torch.zeros((1, 100)).to(device) 118 | for i in range(len(pair_inp)): 119 | out_rnn_1, (h_n, c_n) = self.rnn_unit_1(pair_inp[i]) 120 | out_indv = self.indv_linear_1(indv_inp[i]) 121 | out_indv = self.indv_act_1(out_indv).unsqueeze(0) 122 | inp_rnn_2 = torch.cat((c_n, out_indv), 2) 123 | out_rnn_2, (h_n, c_n) = self.rnn_unit_2(inp_rnn_2) 124 | c_n = c_n.squeeze(0) 125 | rho_input = torch.cat((rho_input, c_n), 0) 126 | rho_input = rho_input[1:] 127 | rho_out = self.rho_lin_1(rho_input) 128 | rho_out = self.rho_act_1(rho_out) 129 | final_out = self.final_lin(rho_out) 130 | return final_out 131 | 132 | def compute_loss(self, pair_inp, indv_inp, y_true): 133 | pred = self.forward(pair_inp, indv_inp) 134 | return self.loss_func(pred, y_true) 135 | 136 | def compute_proba(self, pair_inp, indv_inp): 137 | return torch.sigmoid(self.forward(pair_inp, indv_inp)) 138 | 139 | 140 | if TASK == 'tox_21': 141 | from deepchem.molnet import load_tox21 as dataloader 142 | NUM_TASKS = 12 143 | elif TASK == 'hiv': 144 | from deepchem.molnet import load_hiv as dataloader 145 | NUM_TASKS = 1 146 | elif TASK == 'muv': 147 | from deepchem.molnet import load_muv as dataloader 148 | NUM_TASKS = 17 149 | 150 | 151 | current_tasks_weave, current_datasets_weave, transformers_weave = dataloader(featurizer='Weave') 152 | current_tasks_conv, current_datasets_conv, transformers_conv = dataloader(featurizer='GraphConv') 153 | 154 | train_dataset_weave, valid_dataset_weave, test_dataset_weave = current_datasets_weave 155 | train_dataset_conv, valid_dataset_conv, test_dataset_conv = current_datasets_conv 156 | 157 | train_shuffled_conv = train_dataset_conv.X 158 | train_shuffled_weave = train_dataset_weave.X 159 | train_shuffled_y = train_dataset_conv.y 160 | 161 | # Data splits 162 | # Tox21 is treated differently: we manually (randomly) split into test, train, and valid directly from train_dataset.X 163 | # (rather than letting deepchem provide the data directly) 164 | # Reason: In the early stages of developing the code, the valid_dataset and test_dataset were empty for tox and 165 | # we observed a comment in the deepchem source code leading us to believe this was intended. 166 | # Thus, when we access valid_dataset.X and test_dataset.X, we don't do it for tox21. We only later 167 | # found that we could access tox21 validation and test. But we do this for all models, so the treatment is fair 168 | # 169 | # 170 | # This treatment is done for all models, so the comparison is fair. 
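# Note: construct_tensor (defined above) silently drops any molecule whose Weave
# pair-feature block and GraphConv atom-feature block end up with different lengths
# (the len(pair_array) == len(indv_array) check), so the lists it returns can be
# shorter than the arrays passed in.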
171 | # 172 | if TASK != 'tox_21': 173 | train_pair = train_dataset_weave.X 174 | train_indv = train_dataset_conv.X 175 | train_y = train_dataset_conv.y 176 | valid_pair = valid_dataset_weave.X 177 | valid_indv = valid_dataset_conv.X 178 | valid_y = valid_dataset_conv.y 179 | test_pair = test_dataset_weave.X 180 | test_indv = test_dataset_conv.X 181 | test_y = test_dataset_conv.y 182 | else : 183 | train_shuffled_conv, train_shuffled_weave, train_shuffled_y = unison_shuffled(train_shuffled_conv, train_shuffled_weave, train_shuffled_y) 184 | train_indv = train_shuffled_conv[:3800] 185 | train_pair = train_shuffled_weave[:3800] 186 | train_y = train_shuffled_y[:3800] 187 | valid_indv = train_shuffled_conv[3800:5000] 188 | valid_pair = train_shuffled_weave[3800:5000] 189 | valid_y = train_shuffled_y[3800:5000] 190 | test_indv = train_shuffled_conv[5000:] 191 | test_pair = train_shuffled_weave[5000:] 192 | test_y = train_shuffled_y[5000:] 193 | 194 | # train_shuffled_conv, train_shuffled_weave, train_shuffled_y = unison_shuffled(train_shuffled_conv, train_shuffled_weave, train_shuffled_y) 195 | 196 | #Construct Valid and Test 197 | indv, pair, y_true = construct_tensor(train_indv,train_pair, train_y) 198 | valid_indv, valid_pair, valid_y_true = construct_tensor(valid_indv, valid_pair, valid_y) 199 | #test_indv, test_pair, test_y_true = construct_tensor(test_indv, test_pair, test_y) 200 | 201 | # Train over multiple epochs 202 | val_score_tracker, train_loss_tracker = [], [] 203 | NUM_TRAINING_EXAMPLES = len(indv) 204 | start_time = time.time() 205 | num_batches = int(NUM_TRAINING_EXAMPLES / batch_size) 206 | val_loss_tracker = [] 207 | num_steps_tracker = [] 208 | checkpoint_file_name = "rnn_dfs_{}.model".format(TASK) 209 | checker = RNNModel().to(device) 210 | optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, checker.parameters()), lr=0.003) 211 | count = 0 212 | best_roc_auc = 0.0 213 | 214 | for epoch in range(num_epochs): 215 | print("Epoch Num: ", epoch) 216 | # Do seed and random shuffle of the input 217 | print("Performing Random Shuffle") 218 | train_indv, train_pair, train_y = unison_shuffled(train_indv,train_pair, train_y) 219 | indv, pair, y_true = construct_tensor(train_indv,train_pair, train_y) 220 | y_true_tensor = torch.FloatTensor(y_true).to(device) 221 | print("Random Shuffle Done") 222 | for batch in range(num_batches): 223 | optimizer.zero_grad() 224 | batch_pair = pair[batch_size * batch:batch_size * batch + batch_size] 225 | batch_indv = indv[batch_size * batch:batch_size * batch + batch_size] 226 | batch_y = y_true_tensor[batch_size * batch:batch_size * batch + batch_size] 227 | loss = checker.compute_loss(batch_pair, batch_indv, batch_y) 228 | loss.backward() 229 | optimizer.step() 230 | count += 1 231 | 232 | if count % 100 == 0: 233 | with torch.no_grad(): 234 | val_loss = checker.compute_loss(valid_pair, valid_indv, torch.FloatTensor(valid_y_true).to(device)) 235 | val_loss_tracker.append(val_loss.item()) 236 | pickle.dump(val_loss_tracker, open("val_loss_dfs_rnn_{}.p".format(TASK), "wb")) 237 | print("Val Loss at Step ", count, " : ", val_loss.item()) 238 | num_steps_tracker.append(count) 239 | 240 | val_out = checker.compute_proba(valid_pair, valid_indv) 241 | val_y_pred = np.round(val_out.detach().cpu().numpy()) 242 | val_score = roc_auc_score(np.array(valid_y_true), val_y_pred) 243 | val_score_tracker.append(val_score) 244 | pickle.dump(val_score_tracker, open("val_score_dfs_rnn_{}.p".format(TASK), "wb" )) 245 | if val_score > best_roc_auc: 246 | 
print("Best Val ROC AUC Score till now: ", val_score) 247 | best_roc_auc = val_score 248 | torch.save(checker.state_dict(),checkpoint_file_name) 249 | 250 | with torch.no_grad(): 251 | loss = checker.compute_loss(pair, indv, y_true_tensor) 252 | print("Epoch Training Loss: ", loss.item()) 253 | train_loss_tracker.append(loss.item()) 254 | pickle.dump(train_loss_tracker, open("train_loss_dfs_rnn_{}.p".format(TASK), "wb" )) 255 | 256 | PATH = "train_loss_dfs_rnn_{}.p".format(TASK) 257 | end_time = time.time() 258 | total_training_time = end_time - start_time 259 | print("Total Time: ", total_training_time) 260 | 261 | # 262 | # Run test-set prediction (TODO: paste separate script that uses trained model here) 263 | # 264 | 265 | # Load best model which was saved 266 | best_model = RNNModel().to(device) 267 | best_model.load_state_dict(torch.load(PATH)) 268 | 269 | # Use the test set 270 | with torch.no_grad(): 271 | for num_perm in range(inference_permutations): 272 | # Do a random shuffle and then compute probabilities 273 | test_indv, test_pair, test_y = unison_shuffled(test_indv,test_pair, test_y) 274 | indv, pair, y_true = construct_tensor(test_indv,test_pair, test_y) 275 | y_true_tensor = torch.FloatTensor(y_true).to(device) 276 | if num_perm == 0 : 277 | test_out = best_model.compute_proba(pair, indv) 278 | else : 279 | test_out += best_model.compute_proba(pair, indv) 280 | test_out = test_out/inference_permutations 281 | test_y_pred = np.round(test_out.detach().cpu().numpy()) 282 | test_score = roc_auc_score(np.array(y_true), test_y_pred) 283 | 284 | print("Test ROC AUC score: ", test_score) 285 | -------------------------------------------------------------------------------- /MoleculeTasks/rp_duvenaud.py: -------------------------------------------------------------------------------- 1 | """ 2 | Balasubramanian Srinivasan and Ryan L Murphy 3 | This code implements so-called RP-Duvenaud 4 | (1) \harrow{f} is defined as the Graph Conv model based on Duvenaud and implemented by the deepchem team 5 | (2) We assign unique one-hot identifiers to increase representational power, rendering the model perm-sensitive, 6 | but wrapping it in our pooling makes it permutation invariant again 7 | """ 8 | 9 | from __future__ import division 10 | from __future__ import print_function 11 | from __future__ import unicode_literals 12 | 13 | import numpy as np 14 | import sys 15 | import pickle 16 | import tensorflow as tf 17 | import deepchem as dc 18 | from random import shuffle 19 | from deepchem.models.tensorgraph.models.graph_models import GraphConvTensorGraph 20 | from deepchem.metrics import to_one_hot 21 | from deepchem.feat.mol_graphs import ConvMol 22 | from deepchem.models.tensorgraph.layers import Feature 23 | from deepchem.models.tensorgraph.layers import Dense, GraphConv, BatchNorm 24 | from deepchem.models.tensorgraph.layers import GraphPool, GraphGather 25 | from deepchem.models.tensorgraph.layers import Dense, SoftMax, SoftMaxCrossEntropy, WeightedError, Stack 26 | from deepchem.models.tensorgraph.layers import Label, Weights 27 | from deepchem.models.tensorgraph.tensor_graph import TensorGraph 28 | tg = TensorGraph(use_queue=False) 29 | 30 | TASK = sys.argv[1] # 'tox_21', 'hiv', 'muv 31 | METHOD = sys.argv[2] # Either 'unique_ids' or 'unique_local' 32 | RUN_NUM = sys.argv[3] 33 | 34 | batch_size = 96 35 | NUM_EPOCHS = 100 36 | INFERENCE_TIME_PERMUTATIONS = 10 37 | 38 | 39 | def generate_rp_vertex_feats_unique_ids(vertex_feats): 40 | """ 41 | Use the "unique-ids" scheme for assigning 
one-hot unique IDs to atoms. 42 | Unique scheme gives each and every atom in the molecule it's own ID, in contrast with the "local" method below 43 | 44 | Implicitly, max_atoms comes from the global env 45 | 46 | :param vertex_feats: Matrix of endowed vertex (atom) attributes Xv that come with the data 47 | :return: concat(Xv, Ids) where Ids is a matrix of one-hot identifiers for the atom 48 | """ 49 | K = vertex_feats.shape[0] 50 | appender = np.zeros((K, max_atoms)) 51 | atom_permute = list(range(K)) 52 | shuffle(atom_permute) 53 | for i in range(K): 54 | atom_number = atom_permute[0] 55 | atom_permute.pop(0) 56 | appender[i][atom_number] = 1 57 | vertex_feats_appended = np.concatenate((vertex_feats, appender),axis=1) 58 | return vertex_feats_appended 59 | 60 | 61 | def generate_rp_vertex_feats_unique_local(vertex_feats): 62 | """ 63 | Use the "unique-local" scheme for assigning one-hot unique IDs to atoms. 64 | Here, atoms of the same type get unique one-hot IDS but atoms of a different type might have the same ID 65 | 66 | For example, given two carbons and two hydrogens, 67 | (C, (1 ,0)) 68 | (C, (0 ,1)) 69 | (H, (1 ,0)) 70 | (H, (0 ,1)) 71 | 72 | Implicitly, max_atoms comes from the global env 73 | 74 | :param vertex_feats: Matrix of endowed vertex (atom) attributes Xv that come with the data 75 | :return: concat(Xv, Ids) where Ids is a matrix of one-hot identifiers for the atom 76 | """ 77 | K = vertex_feats.shape[0] 78 | appender = np.zeros((K, max_atoms)) 79 | # 80 | # Find the number of atoms of each type 81 | # If there are m atoms, allocate a list [0, 1, ..., m-1] of identifiers 82 | # that will get mapped to one-hot encodings 83 | # 84 | count_tracker = {} 85 | for i in range(K): 86 | atom = np.array_str(vertex_feats[i],max_line_width=300) 87 | if atom in count_tracker: 88 | count_tracker[atom].append(len(count_tracker[atom])) 89 | else : 90 | count_tracker[atom]=[0] 91 | # 92 | # Shuffle the lists: pi sgd 93 | # 94 | for item in count_tracker: 95 | shuffle(count_tracker[item]) 96 | # 97 | # Map IDs to one-hot encodings 98 | # 99 | for i in range(K): 100 | atom = np.array_str(vertex_feats[i],max_line_width=300) 101 | unique_local_id = count_tracker[atom][0] 102 | count_tracker[atom].pop(0) 103 | appender[i][unique_local_id] = 1 104 | 105 | vertex_feats_appended = np.concatenate((vertex_feats, appender),axis=1) 106 | return vertex_feats_appended 107 | 108 | 109 | def generate_new_X(dataset): 110 | new_array = [] 111 | size = dataset.shape[0] 112 | for i in range(size): 113 | mol = dataset[i] 114 | atom_feats = mol.get_atom_features() 115 | if METHOD == 'unique_local': 116 | new_atom_feats = generate_rp_vertex_feats_unique_local(atom_feats) 117 | elif METHOD == 'unique_ids': 118 | new_atom_feats = generate_rp_vertex_feats_unique_ids(atom_feats) 119 | adjacent_list = mol.get_adjacency_list() 120 | new_mol = dc.feat.mol_graphs.ConvMol(new_atom_feats, adjacent_list) 121 | new_array.append(new_mol) 122 | return np.array(new_array) 123 | 124 | 125 | def data_generator(dataset, epochs=1, predict=False, pad_batches=True): 126 | for epoch in range(epochs): 127 | for ind, (X_b, y_b, w_b, ids_b) in enumerate( 128 | dataset.iterbatches(batch_size, pad_batches=pad_batches, deterministic=True)): 129 | d = {} 130 | for index, label in enumerate(labels): 131 | d[label] = to_one_hot(y_b[:, index]) 132 | d[weights] = w_b 133 | multiConvMol = ConvMol.agglomerate_mols(X_b) 134 | d[atom_features] = multiConvMol.get_atom_features() 135 | d[degree_slice] = multiConvMol.deg_slice 136 | d[membership] = 
multiConvMol.membership 137 | for i in range(1, len(multiConvMol.get_deg_adjacency_lists())): 138 | d[deg_adjs[i - 1]] = multiConvMol.get_deg_adjacency_lists()[i] 139 | yield d 140 | 141 | 142 | def reshape_y_pred(y_true, y_pred): 143 | """ 144 | TensorGraph.Predict returns a list of arrays, one for each output 145 | We also have to remove the padding on the last batch 146 | Metrics taks results of shape (samples, n_task, prob_of_class) 147 | """ 148 | if TASK != 'hiv': 149 | n_samples = len(y_true) 150 | retval = np.stack(y_pred, axis=1) 151 | return retval[:n_samples] 152 | else : 153 | n_samples = len(y_true) 154 | retval = y_pred 155 | return retval[:n_samples] 156 | 157 | 158 | def proper_inference(): 159 | for iteration in range(INFERENCE_TIME_PERMUTATIONS): 160 | if TASK == 'tox_21': 161 | new_test_data = generate_new_X(train_dataset.X[5000:]) 162 | rp_dataset = dc.data.datasets.DiskDataset.from_numpy(new_test_data, train_dataset.y[5000:], train_dataset.w[5000:] ,train_dataset.ids[5000:], data_dir=None) 163 | else: 164 | new_test_data = generate_new_X(test_dataset.X) 165 | rp_dataset = dc.data.datasets.DiskDataset.from_numpy(new_test_data, test_dataset.y, test_dataset.w ,test_dataset.ids, data_dir=None) 166 | preds = tg.predict_on_generator(data_generator(rp_dataset, predict=True)) 167 | preds = reshape_y_pred(rp_dataset.y, preds) 168 | if iteration == 0: 169 | sum_out = np.zeros((preds.shape)) 170 | sum_out += preds 171 | 172 | one_random_perm = preds 173 | sum_out = sum_out/float(INFERENCE_TIME_PERMUTATIONS) 174 | return (one_random_perm, sum_out) 175 | 176 | # 177 | # Set up logging information 178 | # 179 | default_stdout = sys.stdout 180 | logger_file = str(TASK) + "_" + str(METHOD) + "_" + str(RUN_NUM) + ".log" 181 | lfile = open(logger_file, 'w') 182 | sys.stdout = lfile 183 | val_roc_tracker = [] 184 | test_roc_tracker = [] 185 | val_pfile = str(TASK) + "_" + str(METHOD) + "_" + str(RUN_NUM) + "_val.pickle" 186 | test_pfile = str(TASK) + "_" + str(METHOD) + "_" + str(RUN_NUM) + "_test.pickle" 187 | 188 | print("Dataset is ", TASK) 189 | print("RP Technique used is ", METHOD) 190 | print("Run number is ", RUN_NUM) 191 | sys.stdout.flush() 192 | 193 | if TASK == 'tox_21': 194 | from deepchem.molnet import load_tox21 as dataloader 195 | NUM_TASKS = 12 196 | elif TASK == 'hiv': 197 | from deepchem.molnet import load_hiv as dataloader 198 | NUM_TASKS = 1 199 | elif TASK == 'muv': 200 | from deepchem.molnet import load_muv as dataloader 201 | NUM_TASKS = 17 202 | 203 | 204 | # ------------------------------------------------- 205 | # Load datasets, tasks, and transformers 206 | # The number of tasks in each dataset can be found in Table 1 of MoleculeNet: A Benchmark for Molecular Machine Learning 207 | # by Wu et. al. 208 | current_tasks, current_datasets, transformers = dataloader(featurizer='GraphConv', reload=True, split='random') 209 | train_dataset, valid_dataset, test_dataset = current_datasets 210 | 211 | # 212 | # Determine the max number of atoms 213 | # 214 | max_atoms = 0 215 | for mol in train_dataset.X: 216 | num_atom = mol.get_num_atoms() 217 | if num_atom > max_atoms : 218 | max_atoms = num_atom 219 | # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 220 | # Look ahead to largets molecule in test-valid 221 | # 222 | # We assume for our experiments that we can, in general, look ahead to the test set to find the 223 | # largest molecule. 
This is needed to allocate a fixed-size feature vector (padding with zeros as needed) 224 | # Further discussion is found in our appendix. 225 | # 226 | # Tox21 is treated differently: we manually (randomly) split into test, train, and valid directly from train_dataset.X 227 | # (rather than letting deepchem provide the data directly) 228 | # Reason: In the early stages of developing the code, the valid_dataset and test_dataset were empty for tox and 229 | # we observed a comment in the deepchem source code leading us to believe this was intended. 230 | # Thus, when we access valid_dataset.X and test_dataset.X, we don't do it for tox21. We only later 231 | # found that we could access tox21 validation and test. But we do this for all models, so the treatment is fair 232 | # 233 | # 234 | # This treatment is done for all models, so the comparison is fair. 235 | # 236 | if TASK != 'tox_21': 237 | for mol in valid_dataset.X: 238 | num_atom = mol.get_num_atoms() 239 | if num_atom > max_atoms : 240 | max_atoms = num_atom 241 | for mol in test_dataset.X: 242 | num_atom = mol.get_num_atoms() 243 | if num_atom > max_atoms : 244 | max_atoms = num_atom 245 | # 246 | # Build up model object 247 | # Follow: https://deepchem.io/docs/notebooks/graph_convolutional_networks_for_tox21.html 248 | # 249 | atom_features = Feature(shape=(None, 75+max_atoms)) 250 | degree_slice = Feature(shape=(None, 2), dtype=tf.int32) 251 | membership = Feature(shape=(None,), dtype=tf.int32) 252 | 253 | deg_adjs = [] 254 | for i in range(0, 10 + 1): 255 | deg_adj = Feature(shape=(None, i + 1), dtype=tf.int32) 256 | deg_adjs.append(deg_adj) 257 | 258 | gc1 = GraphConv( 259 | 64, 260 | activation_fn=tf.nn.relu, 261 | in_layers=[atom_features, degree_slice, membership] + deg_adjs) 262 | batch_norm1 = BatchNorm(in_layers=[gc1]) 263 | gp1 = GraphPool(in_layers=[batch_norm1, degree_slice, membership] + deg_adjs) 264 | gc2 = GraphConv( 265 | 64, 266 | activation_fn=tf.nn.relu, 267 | in_layers=[gp1, degree_slice, membership] + deg_adjs) 268 | batch_norm2 = BatchNorm(in_layers=[gc2]) 269 | gp2 = GraphPool(in_layers=[batch_norm2, degree_slice, membership] + deg_adjs) 270 | dense = Dense(out_channels=128, activation_fn=tf.nn.relu, in_layers=[gp2]) 271 | batch_norm3 = BatchNorm(in_layers=[dense]) 272 | readout = GraphGather( 273 | batch_size=batch_size, 274 | activation_fn=tf.nn.tanh, 275 | in_layers=[batch_norm3, degree_slice, membership] + deg_adjs) 276 | 277 | costs = [] 278 | labels = [] 279 | for task in range(len(current_tasks)): 280 | classification = Dense( 281 | out_channels=2, activation_fn=None, in_layers=[readout]) 282 | 283 | softmax = SoftMax(in_layers=[classification]) 284 | tg.add_output(softmax) 285 | 286 | label = Label(shape=(None, 2)) 287 | labels.append(label) 288 | cost = SoftMaxCrossEntropy(in_layers=[label, classification]) 289 | costs.append(cost) 290 | 291 | all_cost = Stack(in_layers=costs, axis=1) 292 | weights = Weights(shape=(None, len(current_tasks))) 293 | loss = WeightedError(in_layers=[all_cost, weights]) 294 | tg.set_loss(loss) 295 | 296 | if TASK != 'tox_21': 297 | new_train_data = generate_new_X(train_dataset.X) 298 | new_train_dataset = dc.data.datasets.DiskDataset.from_numpy(new_train_data, train_dataset.y, train_dataset.w ,train_dataset.ids, data_dir=None) 299 | print("Train Data - added RP") 300 | new_valid_data = generate_new_X(valid_dataset.X) 301 | new_valid_dataset = dc.data.datasets.DiskDataset.from_numpy(new_valid_data, valid_dataset.y, valid_dataset.w ,valid_dataset.ids, data_dir=None) 302
| print("Valid Data - added RP") 303 | new_test_data = generate_new_X(test_dataset.X) 304 | new_test_dataset = dc.data.datasets.DiskDataset.from_numpy(new_test_data, test_dataset.y, test_dataset.w ,test_dataset.ids, data_dir=None) 305 | print("Test Data - added RP") 306 | else : 307 | new_train_data = generate_new_X(train_dataset.X[:3800]) 308 | new_train_dataset = dc.data.datasets.DiskDataset.from_numpy(new_train_data, train_dataset.y[:3800], train_dataset.w[:3800] ,train_dataset.ids[:3800], data_dir=None) 309 | print("Train Data - added RP - tox21") 310 | new_valid_data = generate_new_X(train_dataset.X[3800:5000]) 311 | new_valid_dataset = dc.data.datasets.DiskDataset.from_numpy(new_valid_data, train_dataset.y[3800:5000], train_dataset.w[3800:5000] ,train_dataset.ids[3800:5000], data_dir=None) 312 | print("Valid Data - added RP - tox21") 313 | new_test_data = generate_new_X(train_dataset.X[5000:]) 314 | new_test_dataset = dc.data.datasets.DiskDataset.from_numpy(new_test_data, train_dataset.y[5000:], train_dataset.w[5000:] ,train_dataset.ids[5000:], data_dir=None) 315 | print("Test Data - added RP - tox21") 316 | 317 | metric = dc.metrics.Metric( 318 | dc.metrics.roc_auc_score, np.mean, mode="classification") 319 | 320 | best_auc_score = 0.0 321 | for i in range(NUM_EPOCHS): 322 | print("Epoch Num: ", i) 323 | sys.stdout.flush() 324 | tg.fit_generator(data_generator(new_train_dataset, epochs=1)) 325 | if TASK != 'tox_21': 326 | new_train_data = generate_new_X(train_dataset.X) 327 | new_train_dataset = dc.data.datasets.DiskDataset.from_numpy(new_train_data, train_dataset.y, train_dataset.w ,train_dataset.ids, data_dir=None) 328 | else : 329 | new_train_data = generate_new_X(train_dataset.X[:3800]) 330 | new_train_dataset = dc.data.datasets.DiskDataset.from_numpy(new_train_data, train_dataset.y[:3800], train_dataset.w[:3800] ,train_dataset.ids[:3800], data_dir=None) 331 | print("Validation Loss") 332 | valid_predictions = tg.predict_on_generator(data_generator(new_valid_dataset, predict=True)) 333 | valid_predictions = reshape_y_pred(new_valid_dataset.y, valid_predictions) 334 | valid_scores = metric.compute_metric(new_valid_dataset.y, valid_predictions, new_valid_dataset.w) 335 | print("Valid ROC-AUC Score: %f" % valid_scores) 336 | val_roc_tracker.append(valid_scores) 337 | if valid_scores > best_auc_score: 338 | best_auc_score = valid_scores 339 | one_random_perm, sum_out = proper_inference() 340 | 341 | test_scores = metric.compute_metric(new_test_dataset.y, one_random_perm, new_test_dataset.w) 342 | print("test ROC-AUC 1 inference Score: %f" % test_scores) 343 | test_scores = metric.compute_metric(new_test_dataset.y, sum_out, new_test_dataset.w) 344 | print("test ROC-AUC proper inference Score: %f" % test_scores) 345 | test_roc_tracker.append(test_scores) 346 | with open(val_pfile, 'wb') as file: 347 | pickle.dump(val_roc_tracker, file) 348 | with open(test_pfile, 'wb') as file: 349 | pickle.dump(test_roc_tracker, file) 350 | 351 | 352 | print("Evaluating model") 353 | train_predictions = tg.predict_on_generator(data_generator(new_train_dataset, predict=True)) 354 | train_predictions = reshape_y_pred(new_train_dataset.y, train_predictions) 355 | train_scores = metric.compute_metric(new_train_dataset.y, train_predictions, new_train_dataset.w) 356 | print("Training ROC-AUC Score: %f" % train_scores) 357 | 358 | valid_predictions = tg.predict_on_generator(data_generator(new_valid_dataset, predict=True)) 359 | valid_predictions = reshape_y_pred(new_valid_dataset.y, valid_predictions) 360 | 
valid_scores = metric.compute_metric(new_valid_dataset.y, valid_predictions, new_valid_dataset.w) 361 | print("Valid ROC-AUC Score: %f" % valid_scores) 362 | 363 | 364 | one_random_perm, sum_out = proper_inference() 365 | test_predictions = tg.predict_on_generator(data_generator(new_test_dataset, predict=True)) 366 | test_predictions = reshape_y_pred(new_test_dataset.y, test_predictions) 367 | test_scores = metric.compute_metric(new_test_dataset.y, one_random_perm, new_test_dataset.w) 368 | print("test ROC-AUC 1 inference Score: %f" % test_scores) 369 | test_scores = metric.compute_metric(new_test_dataset.y, sum_out, new_test_dataset.w) 370 | print("test ROC-AUC 20 inference Score: %f" % test_scores) 371 | 372 | 373 | sys.stdout = default_stdout 374 | lfile.close() 375 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Relational Pooling for Graph Representations 2 | 3 | ## Overview 4 | This is the code associated with the paper [Relational Pooling for Graph Representations](https://arxiv.org/abs/1903.02541). 5 | Accepted at ICML, 2019. 6 | 7 | Our first task evaluates RP-GIN, a powerful model we propose to make Graph Isomorphism Network (GIN) [Xu et. al. 2019](https://arxiv.org/abs/1810.00826) more powerful than its corresponding WL[1] test. 8 | Our second set of tasks uses molecule datasets to evaluate different instantiations of RP. 9 | 10 | The models are described in plain English in the appendix of our paper, but feel free to contact us with any questions (see below). 11 | 12 | ## Requirements 13 | * [PyTorch](https://www.pytorch.org) 14 | * Python 3 15 | 16 | For the first set of tasks, you will need 17 | * SciPy 18 | * scikit-learn 19 | * docopt and schema for parsing arguments from command line 20 | 21 | For the molecular tasks, you will need 22 | * [DeepChem](https://github.com/deepchem/deepchem) and its associated dependencies 23 | 24 | ## Examples: How to Run 25 | * An example call for the synthetic tasks follows. We trained these models on CPUs. Please see the docstring for further details 26 | ``` 27 | python Run_Gin_Experiment.py --cv-fold 0 --model-type rpGin --out-weight-dir /some/path --out-log-dir /another/path/ --onehot-id-dim 10 28 | ``` 29 | * Now we show examples for the molecular tasks. The Tox 21 dataset is smaller so we demonstrate with that. 30 | For the molecular k-ary tasks: 31 | ``` 32 | python Duvenaud-kary.py 'tox_21' 20 33 | ``` 34 | * For the molecular RP-Duvenaud tasks: 35 | ``` 36 | python rp_duvenaud.py 'tox_21' 'unique_local' 0 37 | ``` 38 | * For the molecular RNN task: 39 | ``` 40 | python RNN-DFS.py 'tox_21' 41 | ``` 42 | 43 | ## Data 44 | * The datasets for the first set of tasks are available in the Synthetic_Data directory. 45 | * The datasets for the molecular tasks are all available in the DeepChem package. 46 | 47 | ## Questions and Contact 48 | Please feel free to reach out to Ryan Murphy (murph213@purdue.edu) if you have any questions. 49 | 50 | ## Citation 51 | If you use this code, please consider citing our paper. 
Here is the BibTeX entry: 52 | ``` 53 | @InProceedings{murphy19a, 54 | title = {Relational Pooling for Graph Representations}, 55 | author = {Murphy, Ryan and Srinivasan, Balasubramaniam and Rao, Vinayak and Ribeiro, Bruno}, 56 | booktitle = {Proceedings of the 36th International Conference on Machine Learning}, 57 | pages = {4663--4673}, 58 | year = {2019}, 59 | editor = {Chaudhuri, Kamalika and Salakhutdinov, Ruslan}, 60 | volume = {97}, 61 | series = {Proceedings of Machine Learning Research}, 62 | address = {Long Beach, California, USA}, 63 | month = {09--15 Jun}, 64 | publisher = {PMLR}, 65 | pdf = {http://proceedings.mlr.press/v97/murphy19a/murphy19a.pdf}, 66 | url = {http://proceedings.mlr.press/v97/murphy19a.html} 67 | } 68 | ``` -------------------------------------------------------------------------------- /Run_Gin_Experiment.py: -------------------------------------------------------------------------------- 1 | """ 2 | (anon). Synthetic experiments with GIN and RP-GIN 3 | 4 | Usage: 5 | Run_Gin_Experiment.py (--cv-fold ) (--out-weight-dir ) (--out-log-dir ) [--use-batchnorm] [--dense-dropout-prob ] 6 | [--num-mlp-hidden ] [--num-gnn-layers ] 7 | [--model-type ] 8 | [--set-epsilon-zero] [--vertex-embed-dim ] 9 | [--mlp-hidden-dim ] [--learning-rate ] 10 | [--num-epochs ] [--num-inf-perm ] 11 | [--onehot-id-dim ] [--seed-val ] 12 | 13 | Options: 14 | --cv-fold Which fold in cross-validation: 0 through 4 15 | --out-weight-dir Output directory where trained weights (and any other objects) will be stored 16 | --out-log-dir Output directory where logfiles will be saved 17 | --use-batchnorm Boolean flag, should batch normalization be implemented? 18 | --dense-dropout-prob Dropout probability for the dense layer [default: 0.0] 19 | --num-mlp-hidden Number of hidden layers in the MLP [default: 2] 20 | --num-gnn-layers Number of iterations of WL-like aggregation [default: 5] 21 | --model-type Either 'regularGin' or 'dataAugGin' or 'rpGin'. Note: the model choice influences how the data is loaded/used [default: regularGin] 22 | --set-epsilon-zero Boolean flag, should epsilon be set to zero? By default, we train epsilon via backprop 23 | --vertex-embed-dim Dimension of each vertex's embedding [default: 16] 24 | --mlp-hidden-dim Number of hidden units in the aggregator's multilayer perceptron [default: 16] 25 | --learning-rate Learning rate for Adam Optimizer [default: 0.01] 26 | --num-epochs Number of epochs for training [default: 200] 27 | --num-inf-perm Number of inference-time permutations [default: 5] 28 | --onehot-id-dim For use with rpGin. Dimension of the one-hot ID.
[default: 41] 29 | --seed-val Seed value, to get different random inits and variability [default: 1337] 30 | """ 31 | # python Run_Gin_Experiment.py --cv-fold 0 --model-type regularGin --num-epochs 100 --out-weight-dir some/folder --out-log-dir some/other/folder 32 | # 33 | import docopt 34 | import os 35 | import pickle 36 | import random 37 | from GIN.GIN_model import * 38 | from GIN.GIN_utils import construct_onehot_ids 39 | from training_utils import * 40 | from schema import Schema, Use, And, Or 41 | from operator import itemgetter 42 | 43 | def get_filename_prefix(args): 44 | """ Create a string to name weights file, log file, etc.""" 45 | prefix = args['--model-type'] + "_cv_" + str(args['--cv-fold']) 46 | if args['--use-batchnorm']: 47 | prefix += "batchnorm" 48 | 49 | if args['--dense-dropout-prob'] > 0.0: 50 | prefix += "_dropout_{}".format(args['--dense-dropout-prob']) 51 | 52 | if args['--set-epsilon-zero']: 53 | prefix += "_no_epsilon" 54 | 55 | if args['--num-gnn-layers'] != 5: 56 | prefix += "_gnn_layers_{}".format(args['--num-gnn-layers']) 57 | 58 | if args['--num-mlp-hidden'] != 2: 59 | prefix += "_mlp_hidden_{}".format(args['--num-mlp-hidden']) 60 | 61 | if args['--num-inf-perm'] != 5: 62 | prefix += "_num_inf_perm_{}".format(args['--num-inf-perm']) 63 | 64 | if args['--onehot-id-dim'] != 41: 65 | prefix += "_onehot_id_dim_{}".format(args['--onehot-id-dim']) 66 | 67 | prefix += "_s" + str(args['--seed-val']) + "_epochs_" + str(args['--num-epochs']) + "_" 68 | return prefix 69 | 70 | 71 | def get_train_val_idx(num_graphs, cv_fold): 72 | """ Return a tuple of the train and val indices, 73 | depending on the cv_fold 74 | This method shuffles the index (with a seed) 75 | The shuffle is consistent across machines with python3""" 76 | # 77 | # Extract indices of train and val in terms of the shuffled list 78 | # Balanced across test and train 79 | # Assumes 10-class 80 | # 81 | random.seed(1) 82 | num_classes = 10 83 | num_per_class = int(num_graphs/num_classes) 84 | val_size = int(0.2 * num_per_class) 85 | idx_to_classes = {} 86 | val_idx = [] 87 | train_idx = [] 88 | for cc in range(num_classes): 89 | idx_to_classes[cc] = list(range(cc*num_per_class, (cc+1)*num_per_class)) 90 | random.shuffle(idx_to_classes[cc]) 91 | # These indices correspond to the validation for this class. 92 | class_val_idx = slice(cv_fold * val_size, cv_fold * val_size + val_size, 1) 93 | # Extract validation. 
94 | vals = idx_to_classes[cc][class_val_idx] 95 | val_idx.extend(vals) 96 | train_idx.extend(list(set(idx_to_classes[cc]) - set(vals))) 97 | # 98 | return tuple(train_idx), tuple(val_idx) 99 | 100 | def accuracy(yhat, y, print_scores=False): 101 | """ Compute accuracy """ 102 | scores = torch.argmax(yhat, dim=1) 103 | if print_scores: 104 | logging.info(scores) 105 | num_correct = torch.sum(scores == y).item() 106 | return num_correct/float(len(y)) 107 | 108 | if __name__ == '__main__': 109 | requirements = { 110 | '--use-batchnorm': Use(bool), 111 | '--dense-dropout-prob': And(Use(float), lambda fff: 0.0 <= fff < 1.0), 112 | '--num-mlp-hidden': Use(int), 113 | '--num-gnn-layers': Use(int), 114 | '--cv-fold': And(Use(int), lambda nnn: 0 <= nnn < 5), 115 | '--out-weight-dir': Use(str), 116 | '--out-log-dir': Use(str), 117 | '--model-type': And(Use(str), lambda sss: sss in ['regularGin', 'dataAugGin', 'rpGin']), 118 | '--set-epsilon-zero': Use(bool), 119 | '--vertex-embed-dim': And(Use(int), lambda mmm: mmm > 0), 120 | '--mlp-hidden-dim': And(Use(int), lambda lll: lll > 0), 121 | '--learning-rate': And(Use(float), lambda flo: flo > 0.0), 122 | '--num-epochs': And(Use(int), lambda epo: epo > 9), 123 | '--num-inf-perm': Use(int), 124 | '--onehot-id-dim': And(Use(int), lambda idd: idd > 0), 125 | '--seed-val': Use(int) 126 | } 127 | args = docopt.docopt(__doc__) 128 | args = Schema(requirements).validate(args) 129 | assert os.path.isdir(args['--out-weight-dir']), "Must enter a valid output weights directory" 130 | assert os.path.isdir(args['--out-log-dir']), "Must enter a valid output logs directory" 131 | # 132 | # Set up paths for logging and weight saving. 133 | # 134 | base_dir = os.getcwd() 135 | filename_pre = get_filename_prefix(args) 136 | log_file = os.path.join(args['--out-log-dir'], 137 | filename_pre + '.log') 138 | 139 | weights_file = os.path.join(args['--out-weight-dir'], 140 | filename_pre + '.pth') 141 | training_metrics_file = os.path.join(args['--out-weight-dir'], 142 | filename_pre + '.pkl') 143 | 144 | set_logger(log_file) 145 | logging.info(args) 146 | # 147 | # Load graphs, y 148 | # 149 | logging.info("---Loading data...---") 150 | sparse_adjmats = pickle.load(open(os.path.join(base_dir, 'Synthetic_Data', 'graphs_Kary_Deterministic_Graphs.pkl'), 'rb')) 151 | y = torch.load(os.path.join(base_dir, 'Synthetic_Data', 'y_Kary_Deterministic_Graphs.pt')) 152 | 153 | num_graphs = len(sparse_adjmats) 154 | logging.info("{} Adjacency matrices were loaded".format(num_graphs)) 155 | # 156 | # Load X 157 | # Standard WL-approach: featureless implies use a constant vertex attribute, for every vertex 158 | # (such data could be generated here rather than loaded, but this coding structure easily 159 | # lends itself to future extensions) 160 | # 161 | if args['--model-type'] == 'regularGin': 162 | # X_all = torch.load(os.path.join(base_dir, 'Synthetic_Data', 'X_unity_Kary_Deterministic_Graphs.pt')) 163 | X_list = pickle.load(open(os.path.join(base_dir, 'Synthetic_Data', 'X_unity_list_Kary_Deterministic_Graphs.pkl'), 'rb')) 164 | elif args['--model-type'] == 'dataAugGin': 165 | X_list = pickle.load(open(os.path.join(base_dir, 'Synthetic_Data', 'X_eye_list_Kary_Deterministic_Graphs.pkl'), 'rb')) 166 | elif args['--model-type'] == 'rpGin': 167 | # 168 | # Set the dimension of the one hot id 169 | # (redefine it if the user makes it too big) 170 | largest_adjmat = np.max([adjmat.shape[0] for adjmat in sparse_adjmats]) 171 | if args['--onehot-id-dim'] > largest_adjmat: 172 | 
logging.info("Your selected value of onehot-id-dim, {}, is larger than the largest graph".format(args['--onehot-id-dim'])) 173 | logging.info("I am resetting onehot-id-dim = {}, the largest adjmat".format(largest_adjmat)) 174 | onehot_id_dim = largest_adjmat 175 | else: 176 | onehot_id_dim = args['--onehot-id-dim'] 177 | # 178 | # Construct one hot ids 179 | # 180 | X_list = [] 181 | for mat in sparse_adjmats: 182 | X_list.append(construct_onehot_ids(mat.shape[0], onehot_id_dim)) 183 | # 184 | # split according to cv fold 185 | # 186 | logging.info("---splitting into training and validation folds---") 187 | logging.info(" The indices are shuffled, and the shuffle is consistent on many machines as long as python3 is used") 188 | train_idx, val_idx = get_train_val_idx(num_graphs, args['--cv-fold']) 189 | 190 | train_adjmats = list(itemgetter(*train_idx)(sparse_adjmats)) 191 | val_adjmats = list(itemgetter(*val_idx)(sparse_adjmats)) 192 | y_train = y[torch.tensor(train_idx)] 193 | y_val = y[torch.tensor(val_idx)] 194 | # 195 | # Print class distribution 196 | # 197 | logging.info("------Class distributions---------") 198 | logging.info("train:") 199 | logging.info(np.unique(y_train.numpy(), return_counts=True)) 200 | logging.info("test:") 201 | logging.info(np.unique(y_val.numpy(), return_counts=True)) 202 | 203 | X_train = torch.cat(itemgetter(*train_idx)(X_list), dim=0) 204 | X_val = torch.cat(itemgetter(*val_idx)(X_list), dim=0) 205 | # 206 | # Define model 207 | # 208 | torch.manual_seed(args['--seed-val']) 209 | np.random.seed(args['--seed-val']) # Used with rpGin, since random permutations are generated with scipy sparse (which uses np seed) 210 | logging.info("Building model...") 211 | 212 | if args['--use-batchnorm']: 213 | other_mlp_params = {'batchnorm': True} 214 | else: 215 | other_mlp_params = {} 216 | 217 | if args['--set-epsilon-zero']: 218 | eps_tunable = False 219 | else: 220 | eps_tunable = True 221 | 222 | if args['--model-type'] == 'regularGin': 223 | model = GinMultiGraph(adjmat_list=train_adjmats, 224 | input_data_dim=X_train.shape[1], 225 | num_agg_steps=args['--num-gnn-layers'], 226 | vertex_embed_dim=args['--vertex-embed-dim'], 227 | mlp_num_hidden=args['--num-mlp-hidden'], 228 | mlp_hidden_dim=args['--mlp-hidden-dim'], 229 | vertices_are_onehot=False, 230 | target_dim=10, 231 | epsilon_tunable=eps_tunable, 232 | dense_layer_dropout=args['--dense-dropout-prob'], 233 | other_mlp_parameters=other_mlp_params) 234 | 235 | elif args['--model-type'] == 'dataAugGin': 236 | model = GinMultiGraph(adjmat_list=train_adjmats, 237 | input_data_dim=X_train.shape[1], 238 | num_agg_steps=args['--num-gnn-layers'], 239 | vertex_embed_dim=args['--vertex-embed-dim'], 240 | mlp_num_hidden=args['--num-mlp-hidden'], 241 | mlp_hidden_dim=args['--mlp-hidden-dim'], 242 | vertices_are_onehot=True, 243 | target_dim=10, 244 | epsilon_tunable=eps_tunable, 245 | dense_layer_dropout=args['--dense-dropout-prob'], 246 | other_mlp_parameters=other_mlp_params) 247 | 248 | elif args['--model-type'] == 'rpGin': 249 | model = RpGin(adjmat_list=train_adjmats, 250 | input_data_dim=X_train.shape[1], 251 | num_agg_steps=args['--num-gnn-layers'], 252 | vertex_embed_dim=args['--vertex-embed-dim'], 253 | mlp_num_hidden=args['--num-mlp-hidden'], 254 | mlp_hidden_dim=args['--mlp-hidden-dim'], 255 | target_dim=10, 256 | featureless_case=True, 257 | vertices_are_onehot=True, 258 | epsilon_tunable=eps_tunable, 259 | dense_layer_dropout=args['--dense-dropout-prob'], 260 | other_mlp_parameters=other_mlp_params) 261 | 
262 | logging.info(model) 263 | # 264 | # Train 265 | # 266 | metrics = {'acc_train': [], 'acc_val': [], 'loss_train': [], 'loss_val': []} 267 | 268 | logging.info("------Training Model---------") 269 | learning_rate = args['--learning-rate'] 270 | num_epochs = args['--num-epochs'] 271 | 272 | logging.info("Train X has shape {}".format(X_train.shape)) 273 | logging.info("Val X has shape {}".format(X_val.shape)) 274 | logging.info("Train y has shape {}".format(y_train.shape)) 275 | logging.info("Validation y has shape {}".format(y_val.shape)) 276 | 277 | optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate) 278 | loss_func = nn.CrossEntropyLoss() 279 | for epoch in range(num_epochs): 280 | model.train() 281 | optimizer.zero_grad() 282 | pred = model(train_adjmats, X_train) 283 | loss_train = loss_func(pred, y_train) 284 | loss_train.backward() 285 | optimizer.step() 286 | # 287 | # Evaluate model. 288 | # > loss and accuracy over validation 289 | # > accuracy over train 290 | model.eval() 291 | with torch.no_grad(): 292 | pred_val = model(val_adjmats, X_val) 293 | loss_val = loss_func(pred_val, y_val) 294 | 295 | # get accuracy and print predictions if it's regular GIN 296 | acc_train = accuracy(pred, y_train, print_scores=(epoch % 10 == 0)) 297 | acc_val = accuracy(pred_val, y_val, print_scores=(epoch % 10 == 0)) 298 | 299 | logging.info("~"*5) 300 | logging.info( 301 | "Epoch: %3d | Train Loss: %.5f | Val Loss: %.5f | Train Accuracy : %.5f | Val Accuracy : %.5f" % (epoch, loss_train, loss_val, acc_train, acc_val)) 302 | 303 | metrics['acc_train'].append(acc_train) 304 | metrics['acc_val'].append(acc_val) 305 | metrics['loss_val'].append(loss_val.item()) 306 | metrics['loss_train'].append(loss_train.item()) 307 | 308 | if args['--model-type'] == 'rpGin': 309 | with torch.no_grad(): 310 | pred_inf = model.inference(val_adjmats, X_val, args['--num-inf-perm']) 311 | final_accuracy = accuracy(pred_inf, y_val) 312 | logging.info("="*10) 313 | logging.info("Final accuracy: {}".format(final_accuracy)) 314 | logging.info("="*10) 315 | metrics['final_accuracy'] = final_accuracy 316 | # 317 | # Save model 318 | # 319 | logging.info("Saving model to file") 320 | logging.info(weights_file) 321 | torch.save(model.state_dict(), weights_file) 322 | logging.info("...done saving") 323 | # 324 | # Save metrics 325 | # 326 | logging.info("Saving metrics") 327 | logging.info(training_metrics_file) 328 | pickle.dump(metrics, open(training_metrics_file, 'wb')) 329 | logging.info("... 
done saving") 330 | -------------------------------------------------------------------------------- /Synthetic_Data/X_eye_list_Kary_Deterministic_Graphs.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PurdueMINDS/RelationalPooling/a20f707deb437f171d4b2c0b28ca8aa3efbfac3d/Synthetic_Data/X_eye_list_Kary_Deterministic_Graphs.pkl -------------------------------------------------------------------------------- /Synthetic_Data/X_unity_list_Kary_Deterministic_Graphs.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PurdueMINDS/RelationalPooling/a20f707deb437f171d4b2c0b28ca8aa3efbfac3d/Synthetic_Data/X_unity_list_Kary_Deterministic_Graphs.pkl -------------------------------------------------------------------------------- /Synthetic_Data/graphs_Kary_Deterministic_Graphs.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PurdueMINDS/RelationalPooling/a20f707deb437f171d4b2c0b28ca8aa3efbfac3d/Synthetic_Data/graphs_Kary_Deterministic_Graphs.pkl -------------------------------------------------------------------------------- /Synthetic_Data/y_Kary_Deterministic_Graphs.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PurdueMINDS/RelationalPooling/a20f707deb437f171d4b2c0b28ca8aa3efbfac3d/Synthetic_Data/y_Kary_Deterministic_Graphs.pt -------------------------------------------------------------------------------- /training_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utilities for training: hyperparm tuning and logging. 3 | 4 | Taken & modified from https://raw.githubusercontent.com/cs230-stanford/cs230-code-examples/master/pytorch/nlp/utils.py 5 | and referencing https://cs230-stanford.github.io/logging-hyperparams.html 6 | """ 7 | 8 | import json 9 | import logging 10 | import numpy as np 11 | 12 | 13 | class Params(): 14 | """Class that loads hyperparameters from a json file. 15 | 16 | Example: 17 | ``` 18 | params = Params(json_path) 19 | print(params.learning_rate) 20 | params.learning_rate = 0.5 # change the value of learning_rate in params 21 | ``` 22 | """ 23 | 24 | def __init__(self, json_path): 25 | with open(json_path) as f: 26 | params = json.load(f) 27 | self.__dict__.update(params) 28 | 29 | def save(self, json_path): 30 | with open(json_path, 'w') as f: 31 | json.dump(self.__dict__, f, indent=4) 32 | 33 | def update(self, json_path): 34 | """Loads parameters from json file""" 35 | with open(json_path) as f: 36 | params = json.load(f) 37 | self.__dict__.update(params) 38 | 39 | @property 40 | def dict(self): 41 | """Gives dict-like access to Params instance by `params.dict['learning_rate']""" 42 | return self.__dict__ 43 | 44 | 45 | class RunningAverage(): 46 | """A simple class that maintains the running average of a quantity 47 | 48 | Example: 49 | ``` 50 | loss_avg = RunningAverage() 51 | loss_avg.update(2) 52 | loss_avg.update(4) 53 | loss_avg() = 3 54 | ``` 55 | """ 56 | 57 | def __init__(self): 58 | self.steps = 0 59 | self.total = 0 60 | 61 | def update(self, val): 62 | self.total += val 63 | self.steps += 1 64 | 65 | def __call__(self): 66 | return self.total / float(self.steps) 67 | 68 | 69 | def set_logger(log_path): 70 | """Set the logger to log info in terminal and file `log_path`. 
71 | 72 | In general, it is useful to have a logger so that every output to the terminal is saved 73 | in a permanent file. Here we save it to `model_dir/train.log`. 74 | 75 | Example: 76 | ``` 77 | logging.info("Starting training...") 78 | ``` 79 | 80 | Args: 81 | log_path: (string) where to log 82 | """ 83 | logger = logging.getLogger() 84 | logger.setLevel(logging.INFO) 85 | 86 | if not logger.handlers: 87 | # Logging to a file 88 | file_handler = logging.FileHandler(log_path) 89 | file_handler.setFormatter(logging.Formatter('%(asctime)s:%(levelname)s: %(message)s')) 90 | logger.addHandler(file_handler) 91 | 92 | # Logging to console 93 | stream_handler = logging.StreamHandler() 94 | stream_handler.setFormatter(logging.Formatter('%(message)s')) 95 | logger.addHandler(stream_handler) 96 | 97 | 98 | def save_dict_to_json(d, json_path): 99 | """Saves dict of floats in json file 100 | 101 | Args: 102 | d: (dict) of float-castable values (np.float, int, float, etc.) 103 | json_path: (string) path to json file 104 | """ 105 | with open(json_path, 'w') as f: 106 | # We need to convert the values to float for json (it doesn't accept np.array, np.float, ) 107 | d = {k: float(v) for k, v in d.items()} 108 | json.dump(d, f, indent=4) 109 | 110 | 111 | def get_n_params(model): 112 | """ 113 | Count number of parameters in the model 114 | Courtesy: 115 | https://discuss.pytorch.org/t/how-do-i-check-the-number-of-parameters-of-a-model/4325/5 116 | """ 117 | model_parameters = filter(lambda p: p.requires_grad, model.parameters()) 118 | params = sum([np.prod(p.size()) for p in model_parameters]) 119 | return params 120 | --------------------------------------------------------------------------------