├── README.md ├── config └── params_universal_graph_embedding_model.py ├── dataloader ├── data.py ├── graph_dataloader.py ├── graph_datasets.py ├── read_graph_datasets.py └── read_qm8_dataset.py ├── hyperopt └── run_train_hyperopt.py ├── torch_dgl ├── layers │ ├── graph_capsule_layer.py │ └── ugd_input_layer.py └── models │ ├── model_universal_graph_embedding.py │ └── universal_graph_encoder.py ├── train └── train_ugd.py └── utils ├── compute_fgsd_features.py ├── compute_wl_kernel.py ├── fast_fgsd_features.py ├── fast_wl_kernel.py ├── read_sdf_file.py └── utils.py /README.md: -------------------------------------------------------------------------------- 1 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/190910086/graph-classification-on-collab)](https://paperswithcode.com/sota/graph-classification-on-collab?p=190910086) 2 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/190910086/graph-classification-on-imdb-m)](https://paperswithcode.com/sota/graph-classification-on-imdb-m?p=190910086) 3 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/190910086/graph-classification-on-proteins)](https://paperswithcode.com/sota/graph-classification-on-proteins?p=190910086) 4 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/190910086/graph-classification-on-dd)](https://paperswithcode.com/sota/graph-classification-on-dd?p=190910086) 5 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/190910086/graph-classification-on-imdb-b)](https://paperswithcode.com/sota/graph-classification-on-imdb-b?p=190910086) 6 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/190910086/graph-classification-on-ptc)](https://paperswithcode.com/sota/graph-classification-on-ptc?p=190910086) 7 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/190910086/graph-classification-on-enzymes)](https://paperswithcode.com/sota/graph-classification-on-enzymes?p=190910086) 8 | 9 | 10 | # Paper Code: Learning Universal Graph Neural Network Embeddings With Aid Of Transfer Learning 11 | Universal-Graph-Embedding-Neural-Network Paper: https://arxiv.org/abs/1909.10086 12 | 13 | ## Package Requirements: 14 | 15 | pytorch 16 | 17 | sklearn 18 | 19 | scipy 20 | 21 | shutil 22 | 23 | comet_ml (optional, required for reproducible experiments) 24 | 25 | matlab.engine (optional, required for computing FGSD Graph Kernel) 26 | 27 | ## Running main script: 28 | 29 | python train/train_ugd.py 30 | 31 | 32 | -------------------------------------------------------------------------------- /config/params_universal_graph_embedding_model.py: -------------------------------------------------------------------------------- 1 | from tsalib import dim_vars 2 | from sklearn.model_selection import ParameterSampler 3 | 4 | 5 | def split(a, n): 6 | k, m = divmod(len(a), n) 7 | return (a[i * k + min(i, m):(i + 1) * k + min(i + 1, m)] for i in range(n)) 8 | 9 | 10 | B, N, E = dim_vars('Batch_Size Graph_Nodes Edge_List') 11 | D, H, K, C = dim_vars('Node_Input_Features Node_Hidden_Features Graph_Kernel_Dim Num_Classes') 12 | 13 | hyperparams_grid = { 14 | 'model_name': ['universal-graph-embedding'], 15 | 'dataset_name': ['DD'], 16 | 'save_steps': [200], 17 | 'run_on_comet': [True], 18 | 'gpu_device': [0], 19 | 20 | 'hidden_dim': [16, 32, 64], 21 | 'num_gconv_layers': [5, 7], 22 | 'num_gfc_layers': [2, 4], 23 | 'batch_size': [128, 64, 32], 24 | 
'drop_prob': [0, 0.2], 25 | 'num_epochs': [3000] 26 | } 27 | 28 | 29 | gen_params_set = 1 30 | for key, val in hyperparams_grid.items(): 31 | gen_params_set = gen_params_set * len(val) 32 | 33 | params_list = list(ParameterSampler(hyperparams_grid, n_iter=gen_params_set)) 34 | -------------------------------------------------------------------------------- /dataloader/data.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class Data(object): 4 | def __init__(self, 5 | x=None, 6 | edge_index=None, 7 | edge_attr=None, 8 | y=None, 9 | pos=None): 10 | self.x = x 11 | self.edge_index = edge_index 12 | self.edge_attr = edge_attr 13 | self.y = y 14 | self.pos = pos 15 | 16 | @staticmethod 17 | def from_dict(dictionary): 18 | data = Data() 19 | for key, item in dictionary.items(): 20 | data[key] = item 21 | return data 22 | 23 | def __getitem__(self, key): 24 | return getattr(self, key) 25 | 26 | def __setitem__(self, key, item): 27 | setattr(self, key, item) 28 | 29 | @property 30 | def keys(self): 31 | return [key for key in self.__dict__.keys() if self[key] is not None] 32 | 33 | def __len__(self): 34 | return len(self.keys) 35 | 36 | def __contains__(self, key): 37 | return key in self.keys 38 | 39 | def __iter__(self): 40 | for key in sorted(self.keys): 41 | yield key, self[key] 42 | 43 | def __call__(self, *keys): 44 | for key in sorted(self.keys) if not keys else keys: 45 | if self[key] is not None: 46 | yield key, self[key] 47 | 48 | def cat_dim(self, key): 49 | return -1 if self[key].dtype == torch.long else 0 50 | 51 | @property 52 | def num_nodes(self): 53 | for key, item in self('x', 'pos'): 54 | return item.size(self.cat_dim(key)) 55 | if self.edge_index is not None: 56 | return maybe_num_nodes(self.edge_index) 57 | return None 58 | 59 | @property 60 | def num_edges(self): 61 | for key, item in self('edge_index', 'edge_attr'): 62 | return item.size(self.cat_dim(key)) 63 | return None 64 | 65 | @property 66 | def num_features(self): 67 | return 1 if self.x.dim() == 1 else self.x.size(1) 68 | 69 | @property 70 | def num_classes(self): 71 | return self.y.max().item() + 1 if self.y.dim() == 1 else self.y.size(1) 72 | 73 | def is_coalesced(self): 74 | row, col = self.edge_index 75 | index = self.num_nodes * row + col 76 | return self.row.size(0) == torch.unique(index).size(0) 77 | 78 | def apply(self, func, *keys): 79 | for key, item in self(*keys): 80 | self[key] = func(item) 81 | return self 82 | 83 | def contiguous(self, *keys): 84 | return self.apply(lambda x: x.contiguous(), *keys) 85 | 86 | def to(self, device, *keys): 87 | return self.apply(lambda x: x.to(device), *keys) 88 | 89 | def __repr__(self): 90 | info = ['{}={}'.format(key, list(item.size())) for key, item in self] 91 | return '{}({})'.format(self.__class__.__name__, ', '.join(info)) 92 | -------------------------------------------------------------------------------- /dataloader/graph_dataloader.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data 2 | import networkx as nx 3 | import time 4 | import numpy as np 5 | from scipy.linalg import block_diag 6 | from scipy.sparse import csr_matrix, coo_matrix 7 | 8 | 9 | class DataLoader(torch.utils.data.DataLoader): 10 | 11 | def __init__(self, dataset, dataset_name=None, batch_size=1, shuffle=True, batch_prep_func=None, **kwargs): 12 | 13 | def batch_collate(batch_graph_list): 14 | t = time.time() 15 | batch = dict() 16 | num_graphs = len(batch_graph_list) 17 | 
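            # batch_collate packs the whole mini-batch into one disconnected graph:
            # every graph's edge indices are shifted by the running node count, so the
            # batch adjacency matrix A is block-diagonal, while batch_sample_matrix is a
            # sparse (num_graphs x total_num_nodes) indicator matrix used later to pool
            # node embeddings back into per-graph embeddings with a single sparse matmul.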
18 | batch_sample_idx = [] 19 | batch_node_feature_matrix = [] 20 | batch_graph_labels = [] 21 | batch_edge_matrix = [] 22 | # batch_adj_mask_matrix = [] 23 | prev_num_nodes = 0 24 | 25 | for i, G in enumerate(batch_graph_list): 26 | num_nodes = G['num_nodes'] 27 | edge_matrix = G['edge_matrix'] 28 | label = G['graph_label'] 29 | node_feature_matrix = G['node_feature_matrix'] 30 | # node_feature_matrix = torch.ones(num_nodes, 1) 31 | 32 | curr_num_nodes = prev_num_nodes 33 | edge_matrix = edge_matrix + int(curr_num_nodes) 34 | prev_num_nodes = curr_num_nodes + num_nodes 35 | 36 | # mask_matrix = np.ones((num_nodes, num_nodes)) 37 | # np.fill_diagonal(mask_matrix, 0) 38 | # adj_mask_matrix = torch.LongTensor(coo_matrix(mask_matrix).nonzero()) + int(curr_num_nodes) 39 | # batch_adj_mask_matrix.append(adj_mask_matrix) 40 | 41 | batch_edge_matrix.append(edge_matrix) 42 | batch_sample_idx.append(i * torch.ones(num_nodes, dtype=torch.int64)) 43 | batch_node_feature_matrix.append(node_feature_matrix) 44 | batch_graph_labels.append(label) 45 | 46 | total_num_nodes = prev_num_nodes 47 | 48 | edge_matrix = torch.cat(batch_edge_matrix, dim=1) 49 | total_num_edges = edge_matrix.shape[-1] 50 | val = torch.ones(total_num_edges) 51 | A = torch.sparse.FloatTensor(edge_matrix, val, torch.Size([total_num_nodes, total_num_nodes])) 52 | 53 | # adj_mask_matrix = torch.cat(batch_adj_mask_matrix, dim=1) 54 | # val = torch.ones(adj_mask_matrix.shape[-1]) 55 | # A_mask = torch.sparse.FloatTensor(adj_mask_matrix, val, torch.Size([total_num_nodes, total_num_nodes])) 56 | 57 | batch_sample_idx = torch.cat(batch_sample_idx) 58 | sparse_idx = torch.stack((batch_sample_idx, torch.arange(0, int(total_num_nodes), dtype=torch.long))) 59 | val = torch.ones(total_num_nodes) 60 | batch_sample_matrix = torch.sparse.FloatTensor(sparse_idx, val, torch.Size([num_graphs, total_num_nodes])) 61 | 62 | batch['edge_matrix'] = edge_matrix 63 | batch['adjacency_matrix'] = A 64 | batch['node_feature_matrix'] = torch.cat(batch_node_feature_matrix) 65 | batch['graph_labels'] = torch.LongTensor(batch_graph_labels) 66 | batch['batch_sample_matrix'] = batch_sample_matrix 67 | batch['num_graphs'] = num_graphs 68 | batch['adjacency_mask'] = None 69 | 70 | if self.batch_prep_func is not None: 71 | batch = self.batch_prep_func(batch, batch_graph_list, dataset_name=dataset_name) 72 | 73 | batch['prep_time'] = time.time() - t 74 | return batch 75 | 76 | super(DataLoader, self).__init__(dataset, batch_size, shuffle, collate_fn=batch_collate, **kwargs) 77 | self.batch_prep_func = batch_prep_func 78 | 79 | 80 | -------------------------------------------------------------------------------- /dataloader/graph_datasets.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import collections 4 | import tarfile 5 | import zipfile 6 | import torch.utils.data 7 | from dataloader.read_graph_datasets import read_graph_data 8 | import random 9 | import numpy as np 10 | from six.moves import urllib 11 | import errno 12 | from utils.compute_wl_kernel import compute_reduce_wl_kernel, compute_full_wl_kernel 13 | # from utils.compute_fgsd_features import compute_reduce_fgsd_features 14 | import time 15 | import logging 16 | 17 | 18 | def extract_tar(path, folder, mode='r:gz', log=True): 19 | maybe_log(path, log) 20 | with tarfile.open(path, mode) as f: 21 | f.extractall(folder) 22 | 23 | 24 | def extract_zip(path, folder, log=True): 25 | maybe_log(path, log) 26 | with zipfile.ZipFile(path, 
'r') as f: 27 | f.extractall(folder) 28 | 29 | 30 | def makedirs(path): 31 | try: 32 | os.makedirs(osp.expanduser(osp.normpath(path))) 33 | except OSError as e: 34 | if e.errno != errno.EEXIST and osp.isdir(path): 35 | raise e 36 | 37 | 38 | def download_url(url, folder, log=True): 39 | if log: 40 | print('Downloading', url) 41 | 42 | makedirs(folder) 43 | 44 | data = urllib.request.urlopen(url) 45 | filename = url.rpartition('/')[2] 46 | path = osp.join(folder, filename) 47 | 48 | with open(path, 'wb') as f: 49 | f.write(data.read()) 50 | 51 | return path 52 | 53 | 54 | def to_list(x): 55 | if not isinstance(x, collections.Iterable) or isinstance(x, str): 56 | x = [x] 57 | return x 58 | 59 | 60 | def files_exist(files): 61 | return all([osp.exists(f) for f in files]) 62 | 63 | 64 | class TUDataset(torch.utils.data.Dataset): 65 | url = 'https://ls11-www.cs.uni-dortmund.de/people/morris/graphkerneldatasets' 66 | 67 | def __init__(self, data_path, name, shuffle=True, compute_graph_kernel_features=False, wl_num_iter=5, wl_node_labels='degree', compute_fgsd_features=False): 68 | 69 | self.name = name 70 | self.compute_graph_kernel_features = compute_graph_kernel_features 71 | self.wl_num_iter = wl_num_iter 72 | self.wl_node_labels = wl_node_labels 73 | 74 | self.root = osp.expanduser(osp.normpath(data_path)) 75 | self.raw_dir = osp.join(self.root, 'raw') 76 | self.processed_dir = osp.join(self.root, 'processed') 77 | if self.name != 'ALL' and self.name != 'QM8' and self.name != 'NCI_FULL': 78 | self._download() 79 | self._process() 80 | logging.info('Loading save graph list from: ' + os.path.abspath(self.processed_paths[0])) 81 | self.graph_list = torch.load(self.processed_paths[0]) 82 | if shuffle: 83 | logging.info('Shuffling the dataset...') 84 | random.shuffle(self.graph_list) 85 | if self.name == 'QM8': 86 | self.num_graph_labels = len(self.graph_list[0]['graph_label']) 87 | elif self.name != 'ALL': 88 | self.num_graph_labels = max(G['graph_label'].item() for G in self.graph_list) + 1 89 | else: 90 | self.num_graph_labels = None 91 | 92 | def __getitem__(self, index): 93 | # return self.graph_list[index] 94 | return [self.graph_list[i] for i in index] 95 | 96 | def __len__(self): 97 | return len(self.graph_list) 98 | 99 | def raw_file_names(self): 100 | names = ['A', 'graph_indicator'] 101 | return ['{}_{}.txt'.format(self.name, name) for name in names] 102 | 103 | @property 104 | def raw_paths(self): 105 | files = to_list(self.raw_file_names()) 106 | return [osp.join(self.raw_dir, f) for f in files] 107 | 108 | @property 109 | def processed_paths(self): 110 | files = to_list(self.processed_file_names) 111 | return [osp.join(self.processed_dir, f) for f in files] 112 | 113 | @property 114 | def processed_file_names(self): 115 | return 'data.pt' 116 | 117 | def _download(self): 118 | if files_exist(self.raw_paths): 119 | return 120 | makedirs(self.raw_dir) 121 | self.download() 122 | 123 | @property 124 | def num_features(self): 125 | return self.graph_list[0]['node_feature_matrix'].shape[-1] 126 | 127 | @property 128 | def wl_kernel_feature_dim(self): 129 | return self.graph_list[0]['WL_kernel_features'].shape[-1] 130 | 131 | @property 132 | def num_classes(self): 133 | return self.num_graph_labels 134 | 135 | def _process(self): 136 | if files_exist(self.processed_paths): 137 | return 138 | print('Processing...') 139 | makedirs(self.processed_dir) 140 | self.process() 141 | print('Done!') 142 | 143 | def download(self): 144 | path = download_url('{}/{}.zip'.format(self.url, self.name), 
self.root) 145 | extract_zip(path, self.root) 146 | os.unlink(path) 147 | os.rename(osp.join(self.root, self.name), self.raw_dir) 148 | 149 | def process(self): 150 | if self.name == 'ALL': 151 | if os.path.exists('data/ALL/processed/graph_list.pt'): 152 | self.graph_list = torch.load('data/ALL/processed/graph_list.pt') 153 | else: 154 | self.graph_list = [] 155 | for i in range(len(graph_dataset_names)): 156 | logging.info('Reading graph dataset: ' + graph_dataset_names[i]) 157 | dataset_graph_list, _ = read_graph_data('data/' + graph_dataset_names[i], graph_dataset_names[i]) 158 | self.graph_list = self.graph_list + dataset_graph_list 159 | logging.info('Done') 160 | self.num_graph_labels = None 161 | torch.save(self.graph_list, 'data/ALL/processed/graph_list.pt') 162 | elif self.name == 'QM8': 163 | from dataloader.read_qm8_dataset import read_qm8_data 164 | self.graph_list, self.num_graph_labels = read_qm8_data(self.raw_dir, self.name) 165 | elif self.name == 'NCI_FULL': 166 | from dataloader.read_nci_full_dataset import read_nci_full_data 167 | self.graph_list, self.num_graph_labels = read_nci_full_data(self.raw_dir, self.name) 168 | else: 169 | logging.info('Reading graph dataset from: ' + os.path.abspath(self.raw_dir)) 170 | self.graph_list, self.num_graph_labels = read_graph_data(self.raw_dir, self.name) 171 | 172 | if self.compute_graph_kernel_features: 173 | feature_matrix = compute_full_wl_kernel(self.graph_list, num_iter=self.wl_num_iter, type_node_labels=self.wl_node_labels) 174 | for i in range(len(self.graph_list)): 175 | if (i + 1) % 1000 == 0: 176 | logging.info('wl features loaded in num graphs so far: ' + str(i + 1)) 177 | self.graph_list[i]['WL_kernel_features'] = feature_matrix[i] 178 | 179 | logging.info('Saving the process data at: ' + os.path.abspath(self.processed_paths[0])) 180 | t = time.time() 181 | torch.save(self.graph_list, self.processed_paths[0]) 182 | logging.info('Time Taken: ' + str(time.time() - t)) 183 | 184 | def __repr__(self): 185 | return '{}({})'.format(self.name, len(self)) 186 | -------------------------------------------------------------------------------- /dataloader/read_graph_datasets.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import pandas as pd 3 | import torch 4 | import numpy as np 5 | from dataloader.data import Data 6 | from sklearn import preprocessing 7 | import networkx as nx 8 | import os 9 | import scipy.io 10 | from scipy import sparse 11 | 12 | 13 | names = [ 14 | 'A', 'graph_indicator', 'node_labels', 'node_attributes' 15 | 'edge_labels', 'edge_attributes', 'graph_labels', 'graph_attributes' 16 | ] 17 | 18 | 19 | def read_graph_data(folder, prefix): 20 | 21 | edge_list_file = folder + '/' + prefix + '_A.txt' 22 | graph_ind_file = folder + '/' + prefix + '_graph_indicator.txt' 23 | graph_labels_file = folder + '/' + prefix + '_graph_labels.txt' 24 | node_label_file = folder + '/' + prefix + '_node_labels.txt' 25 | fgsd_mat_file = os.path.dirname(folder) + '/' + prefix + '_X_fgsd.mat' 26 | 27 | df_edge_list = pd.read_csv(edge_list_file, header=None) 28 | df_graph_ind = pd.read_csv(graph_ind_file, header=None) 29 | df_graph_labels = pd.read_csv(graph_labels_file, header=None) 30 | 31 | edge_index_array = np.array(df_edge_list, dtype=np.int64) 32 | graph_labels = np.array(df_graph_labels, dtype=np.int64).reshape(-1) 33 | 34 | graph_ind = np.array(df_graph_ind, dtype=np.int64).reshape(-1) 35 | graph_ind_encoder = preprocessing.LabelEncoder() 36 | graph_ind = 
graph_ind_encoder.fit_transform(graph_ind) 37 | graph_ind_vals, graph_ind_start_idx = np.unique(graph_ind, return_index=True) 38 | num_graphs = len(graph_ind_vals) 39 | big_num_edges = edge_index_array.shape[0] 40 | 41 | if os.path.exists(node_label_file): 42 | df_node_labels = pd.read_csv(node_label_file, header=None) 43 | node_labels = np.array(df_node_labels, dtype=np.int64).reshape(-1) 44 | node_label_encoder = preprocessing.LabelEncoder() 45 | node_labels = node_label_encoder.fit_transform(node_labels) 46 | else: 47 | node_labels = np.zeros(len(graph_ind), dtype=np.int64) 48 | 49 | graph_label_encoder = preprocessing.LabelEncoder() 50 | graph_labels = graph_label_encoder.fit_transform(graph_labels) 51 | num_classes = graph_labels.max() + 1 52 | 53 | load_fgsd_features = False 54 | if os.path.exists(fgsd_mat_file): 55 | mat = scipy.io.loadmat(fgsd_mat_file) 56 | fgsd_feature_matrix = sparse.csr_matrix(mat['X'], dtype=np.float64) 57 | preprocessing.normalize(fgsd_feature_matrix, norm='l2', axis=1, copy=False, return_norm=False) 58 | assert (fgsd_feature_matrix.shape[0] == num_graphs) 59 | load_fgsd_features = True 60 | 61 | assert (graph_ind[-1] + 1 == num_graphs) 62 | assert (len(graph_labels) == num_graphs) 63 | assert (len(graph_ind) == len(node_labels)) 64 | assert (node_labels.min() == 0) 65 | assert (node_labels.max() + 1 == len(np.unique(node_labels))) 66 | 67 | onehot_encoder = preprocessing.OneHotEncoder(sparse=False) 68 | big_node_feature_matrix = onehot_encoder.fit_transform(node_labels.reshape(-1, 1)) 69 | big_node_feature_matrix = torch.FloatTensor(big_node_feature_matrix) 70 | 71 | graph_ind_start_idx = np.concatenate((graph_ind_start_idx, [len(graph_ind)])) 72 | graph_list = [] 73 | prev_node_start_id = 1 74 | prev_edge_end_idx = 0 75 | for i in range(num_graphs): 76 | if (i+1) % 1000 == 0: 77 | print('num graphs processed so far: ', i) 78 | 79 | curr_num_nodes = graph_ind_start_idx[i+1] - graph_ind_start_idx[i] 80 | node_feature_matrix = big_node_feature_matrix[graph_ind_start_idx[i]: graph_ind_start_idx[i+1]] 81 | graph_node_labels = node_labels[graph_ind_start_idx[i]: graph_ind_start_idx[i+1]] 82 | 83 | if i == num_graphs-1: 84 | curr_edge_end_idx = big_num_edges 85 | else: 86 | curr_edge_end_idx = np.argmax(edge_index_array[:, 0] > graph_ind_start_idx[i + 1]) 87 | curr_edge_list = edge_index_array[prev_edge_end_idx: curr_edge_end_idx] 88 | if curr_edge_list.size == 0: 89 | prev_node_start_id = prev_node_start_id + curr_num_nodes 90 | continue 91 | curr_edge_list = curr_edge_list - prev_node_start_id 92 | 93 | assert(curr_edge_list.min() >= 0) 94 | assert(curr_edge_list.max() <= curr_num_nodes - 1) 95 | 96 | prev_edge_end_idx = curr_edge_end_idx 97 | prev_node_start_id = graph_ind_start_idx[i+1] + 1 98 | 99 | G_nx = nx.Graph() 100 | G_nx.add_edges_from(curr_edge_list) 101 | G_nx.remove_edges_from(G_nx.selfloop_edges()) 102 | num_edges = G_nx.number_of_edges() 103 | 104 | G = dict() 105 | edge_matrix = np.array(list(G_nx.edges())) 106 | row = torch.LongTensor(np.concatenate((edge_matrix[:, 0], edge_matrix[:, 1]))) 107 | col = torch.LongTensor(np.concatenate((edge_matrix[:, 1], edge_matrix[:, 0]))) 108 | row_col_idx = torch.zeros((2, 2*num_edges), dtype=torch.int64) 109 | row_col_idx[0] = row 110 | row_col_idx[1] = col 111 | 112 | node_degree = [] 113 | for node_id in range(curr_num_nodes): 114 | if G_nx.has_node(node_id): 115 | node_degree.append(G_nx.degree[node_id]) 116 | else: 117 | node_degree.append(0) 118 | 119 | G['graph_id'] = i 120 | G['edge_matrix'] = 
row_col_idx 121 | G['node_feature_matrix'] = node_feature_matrix 122 | G['node_labels'] = graph_node_labels 123 | G['num_nodes'] = curr_num_nodes 124 | G['graph_label'] = graph_labels[i] 125 | G['node_degree_vec'] = np.array(node_degree) 126 | G['adj_matrix'] = nx.adjacency_matrix(G_nx, nodelist=range(0, G['num_nodes'])) 127 | if load_fgsd_features: 128 | G['fgsd_features'] = fgsd_feature_matrix[i] 129 | 130 | assert(G['node_degree_vec'].shape[0] == G['num_nodes']) 131 | assert(G['adj_matrix'].shape[0] == G['num_nodes']) 132 | assert(G['node_feature_matrix'].shape[0] == G['num_nodes']) 133 | assert(G['node_feature_matrix'].shape[1] == node_labels.max() + 1) 134 | assert(G_nx.number_of_selfloops() == 0) 135 | 136 | graph_list.append(G) 137 | 138 | return graph_list, num_classes 139 | 140 | 141 | 142 | 143 | -------------------------------------------------------------------------------- /dataloader/read_qm8_dataset.py: -------------------------------------------------------------------------------- 1 | import networkx as nx 2 | from scipy import sparse 3 | import pandas as pd 4 | import os 5 | import scipy.io 6 | import torch 7 | import numpy as np 8 | from sklearn import preprocessing 9 | from utils.read_sdf_dataset import read_from_sdf 10 | from rdkit import Chem 11 | from operator import itemgetter 12 | 13 | 14 | atom_list = ['C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na', 'Ca', 'Fe', 'As', 'Al', 'I', 'B', 'V', 'K', 'Tl', 'Yb', 'Sb', 'Sn', 'Ag', 'Pd', 'Co', 'Se', 'Ti', 'Zn', 'H', 'Li', 'Ge', 'Cu', 'Au', 'Ni', 'Cd', 'In', 'Mn', 'Zr', 'Cr', 'Pt', 'Hg', 'Pb', 'Unknown'] 15 | 16 | 17 | def one_of_k_encoding(x, allowable_set): 18 | if x not in allowable_set: 19 | raise Exception("input {0} not in allowable set{1}:".format( 20 | x, allowable_set)) 21 | return list(map(lambda s: x == s, allowable_set)) 22 | 23 | 24 | def one_of_k_encoding_unk(x, allowable_set): 25 | """Maps inputs not in the allowable set to the last element.""" 26 | if x not in allowable_set: 27 | x = allowable_set[-1] 28 | return list(map(lambda s: x == s, allowable_set)) 29 | 30 | 31 | def atom_features(atom): 32 | 33 | results = one_of_k_encoding_unk(atom.GetSymbol(), atom_list)\ 34 | + one_of_k_encoding(atom.GetDegree(), [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) \ 35 | + one_of_k_encoding_unk(atom.GetImplicitValence(), [0, 1, 2, 3, 4, 5, 6]) \ 36 | + [atom.GetFormalCharge(), atom.GetNumRadicalElectrons()] \ 37 | + one_of_k_encoding_unk(atom.GetHybridization(), [Chem.rdchem.HybridizationType.SP, Chem.rdchem.HybridizationType.SP2, Chem.rdchem.HybridizationType.SP3, Chem.rdchem.HybridizationType.SP3D, Chem.rdchem.HybridizationType.SP3D2]) \ 38 | + [atom.GetIsAromatic()] 39 | 40 | return np.array(results) 41 | 42 | 43 | def read_qm8_data(folder, prefix): 44 | 45 | sdf_file = folder + '/' + prefix + '.sdf' 46 | csv_file = folder + '/' + prefix + '.csv' 47 | graph_list = read_from_sdf(sdf_file) 48 | num_graphs = len(graph_list) 49 | df = pd.read_csv(csv_file) 50 | Y = np.array(df) 51 | graph_labels = Y[:, 1:] 52 | assert (Y.shape[0] == num_graphs) 53 | num_targets = graph_labels.shape[-1] 54 | 55 | X = [G['node_labels'] for G in graph_list] 56 | unique_node_labels = list(set(x for l in X for x in l)) 57 | num_unique_node_labels = len(unique_node_labels) 58 | label_encoder = preprocessing.LabelEncoder() 59 | label_encoder.fit(unique_node_labels) 60 | onehot_encoder = preprocessing.OneHotEncoder(n_values=num_unique_node_labels, sparse=False) 61 | onehot_vec = np.array(range(0, num_unique_node_labels)) 62 | 
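    # The label and one-hot encoders are fit once over all atom types seen in the
    # dataset, so every molecule's node feature matrix has the same width
    # (num_unique_node_labels) regardless of which atoms it contains.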
onehot_encoder.fit(onehot_vec.reshape(-1, 1)) 63 | 64 | process_graph_list = [] 65 | for i in range(num_graphs): 66 | if (i + 1) % 1000 == 0: 67 | print('num graphs processed so far: ', i) 68 | 69 | curr_edge_list = graph_list[i]['edge_list'] 70 | node_labels = graph_list[i]['node_labels'] 71 | curr_num_nodes = len(node_labels) 72 | assert(curr_num_nodes == graph_list[i]['num_nodes']) 73 | assert(len(curr_edge_list) == graph_list[i]['num_edges']) 74 | 75 | node_labels = label_encoder.transform(node_labels) 76 | node_feature_matrix = onehot_encoder.transform(node_labels.reshape(-1, 1)) 77 | 78 | G_nx = nx.Graph() 79 | G_nx.add_nodes_from(np.array(range(0, curr_num_nodes))) 80 | G_nx.add_edges_from(curr_edge_list) 81 | G_nx.remove_edges_from(G_nx.selfloop_edges()) 82 | num_edges = G_nx.number_of_edges() 83 | nx_curr_num_nodes = G_nx.number_of_nodes() 84 | assert(curr_num_nodes == nx_curr_num_nodes) 85 | 86 | G = dict() 87 | edge_matrix = np.array(list(G_nx.edges())) 88 | row = torch.LongTensor(np.concatenate((edge_matrix[:, 0], edge_matrix[:, 1]))) 89 | col = torch.LongTensor(np.concatenate((edge_matrix[:, 1], edge_matrix[:, 0]))) 90 | row_col_idx = torch.zeros((2, 2 * num_edges), dtype=torch.int64) 91 | row_col_idx[0] = row 92 | row_col_idx[1] = col 93 | 94 | node_degree = [] 95 | for node_id in range(curr_num_nodes): 96 | if G_nx.has_node(node_id): 97 | node_degree.append(G_nx.degree[node_id]) 98 | else: 99 | node_degree.append(0) 100 | 101 | G['graph_id'] = i 102 | G['edge_matrix'] = row_col_idx 103 | G['node_feature_matrix'] = torch.FloatTensor(node_feature_matrix) 104 | G['num_nodes'] = curr_num_nodes 105 | G['graph_label'] = graph_labels[i] 106 | G['node_degree_vec'] = np.array(node_degree) 107 | G['adj_matrix'] = nx.adjacency_matrix(G_nx, nodelist=range(0, G['num_nodes'])) 108 | G['node_labels'] = np.array(node_labels) 109 | 110 | assert (G['node_degree_vec'].shape[0] == G['num_nodes']) 111 | assert (G['adj_matrix'].shape[0] == G['num_nodes']) 112 | assert (G['node_feature_matrix'].shape[0] == G['num_nodes']) 113 | assert (G['node_feature_matrix'].shape[1] == num_unique_node_labels) 114 | assert (G_nx.number_of_selfloops() == 0) 115 | 116 | process_graph_list.append(G) 117 | 118 | return process_graph_list, num_targets 119 | -------------------------------------------------------------------------------- /hyperopt/run_train_hyperopt.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import subprocess 4 | from itertools import islice 5 | import shutil 6 | import errno 7 | sys.path.extend('..') 8 | from config.params_universal_graph_embedding_model import params_list, hyperparams_grid 9 | import uuid 10 | import argparse 11 | 12 | 13 | def contruct_py_cmd_args(params_dict): 14 | cmd_args = '' 15 | for key, value in params_dict.items(): 16 | cmd_args = cmd_args + '--' + str(key) + ' ' + str(value) + ' ' 17 | return cmd_args 18 | 19 | 20 | if __name__ == '__main__': 21 | 22 | experiment_id = str(uuid.uuid4()) 23 | dest_path = os.path.join('./hyperopt_runs/', experiment_id) 24 | if not os.path.exists(dest_path): 25 | os.makedirs(dest_path) 26 | 27 | print('Copying files...') 28 | src_list = ['./train', './utils', './torch_dgl', './dataloader', './config'] 29 | dest_list = [os.path.join(dest_path, 'train'), 30 | os.path.join(dest_path, 'utils'), 31 | os.path.join(dest_path, 'torch_dgl'), 32 | os.path.join(dest_path, 'dataloader'), 33 | os.path.join(dest_path, 'config')] 34 | 35 | for dataset in 
hyperparams_grid['dataset_name']:
36 |         src_list = src_list + ['./data/' + dataset]
37 |         dest_list = dest_list + [dest_path + '/data/' + dataset]
38 |     src_list = src_list + ['./data/prime_numbers_list_v2.npy']
39 |     dest_list = dest_list + [dest_path + '/data/prime_numbers_list_v2.npy']
40 | 
41 |     for src, dest in zip(src_list, dest_list):
42 |         if os.path.isdir(dest):
43 |             shutil.rmtree(dest)
44 |         try:
45 |             shutil.copytree(src, dest)
46 |         except NotADirectoryError:
47 |             shutil.copy(src, dest)
48 | 
49 |     os.chdir(dest_path)
50 |     print('Done...')
51 | 
52 |     parser = argparse.ArgumentParser(description='Model Arguments')
53 |     parser.add_argument('--train_filename', type=str, default='train_universal_graph_embedding.py')
54 |     parser.add_argument('--max_workers', type=int, default=2)
55 |     args, unknown = parser.parse_known_args()
56 | 
57 |     for arg, value in sorted(vars(args).items()):
58 |         print("Hyperparameter: %s: %r" % (arg, value))
59 | 
60 |     train_filename = args.train_filename
61 |     commands = []
62 |     for i in range(len(params_list)):
63 |         py_cmd_args = contruct_py_cmd_args(params_list[i])
64 |         run_cmd = 'python' + ' ' + os.path.join('train/', train_filename) + ' ' + py_cmd_args # + ' &'
65 |         commands.append(run_cmd)
66 | 
67 |     max_workers = args.max_workers
68 |     processes = (subprocess.Popen(cmd, shell=True) for cmd in commands)
69 |     running_processes = list(islice(processes, max_workers)) # start new processes
70 |     while running_processes:
71 |         for i, process in enumerate(running_processes):
72 |             if process.poll() is not None: # the process has finished
73 |                 running_processes[i] = next(processes, None) # start new process
74 |                 if running_processes[i] is None: # no new processes
75 |                     del running_processes[i]
76 |                     break
77 | 
--------------------------------------------------------------------------------
/torch_dgl/layers/graph_capsule_layer.py:
--------------------------------------------------------------------------------
1 | import collections
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | 
6 | 
7 | class GraphCapsuleConv(nn.Module):
8 |     def __init__(self, input_dim, hidden_dim, num_gfc_layers=2, num_stats_in=1, num_stats_out=1):
9 |         super(GraphCapsuleConv, self).__init__()
10 | 
11 |         self.input_dim = input_dim
12 |         self.hidden_dim = hidden_dim
13 |         self.num_stats_in = num_stats_in
14 |         self.num_stats_out = num_stats_out
15 |         self.num_gfc_layers = num_gfc_layers
16 | 
17 |         self.stat_layers = nn.ModuleList()
18 |         for _ in range(self.num_stats_out):
19 |             gfc_layers = nn.ModuleList()
20 |             curr_input_dim = input_dim * num_stats_in
21 |             for _ in range(self.num_gfc_layers):
22 |                 gfc_layers.append(nn.Linear(curr_input_dim, hidden_dim))
23 |                 curr_input_dim = hidden_dim
24 |             self.stat_layers.append(gfc_layers)
25 | 
26 |     def forward(self, x_in, A):
27 | 
28 |         x = x_in
29 |         output = []
30 |         for i in range(self.num_stats_out):
31 |             out = torch.spmm(A, x) + x
32 |             for j in range(self.num_gfc_layers):
33 |                 out = self.stat_layers[i][j](out)
34 |                 out = F.selu(out)
35 |             output.append(out)
36 |             x = torch.mul(x, x_in)
37 | 
38 |         output = torch.cat(output, 1)
39 |         return output
40 | 
--------------------------------------------------------------------------------
/torch_dgl/layers/ugd_input_layer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn import Sequential, Linear, ReLU
4 | import torch.nn.functional as F
5 | from config.params_universal_graph_embedding_model import B, N, E, D, H, C, K
6 | 
7 | 
8 
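# InputTransformer lifts raw node features of size input_dim to hidden_dim:
# one propagation step A @ x + x (a self-loop style neighborhood aggregation)
# followed by a small stack of Linear + SELU layers, applied before the
# graph-capsule encoder layers.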
| class InputTransformer(nn.Module): 9 | def __init__(self, input_dim, hidden_dim, num_input_layers=1, drop_prob=0.0): 10 | super(InputTransformer, self).__init__() 11 | 12 | self.input_dim = input_dim 13 | self.drop_prob = drop_prob 14 | self.hidden_dim = hidden_dim 15 | self.num_input_layers = num_input_layers 16 | self.output_dim = hidden_dim 17 | 18 | self.input_layers = nn.ModuleList() 19 | curr_input_dim = input_dim 20 | for i in range(self.num_input_layers): 21 | self.input_layers.append(nn.Linear(curr_input_dim, hidden_dim)) 22 | curr_input_dim = hidden_dim 23 | 24 | def forward(self, x: (N, D), A: (N, N)): 25 | 26 | out = torch.spmm(A, x) 27 | out = out + x 28 | for i in range(self.num_input_layers): 29 | out = self.input_layers[i](out) 30 | out = F.selu(out) 31 | return out 32 | -------------------------------------------------------------------------------- /torch_dgl/models/model_universal_graph_embedding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import Sequential, Linear, ReLU 4 | import torch.nn.functional as F 5 | from torch_dgl.models.universal_graph_encoder import GraphEncoder 6 | from config.params_universal_graph_embedding_model import B, N, E, D, H, C, K 7 | from torch_dgl.layers.ugd_input_layer import InputTransformer 8 | 9 | 10 | class UniversalGraphEmbedder(nn.Module): 11 | def __init__(self, input_dim, num_classes, hidden_dim, num_gconv_layers, num_gfc_layers, drop_prob, use_fgsd_features=False, use_spk_features=False): 12 | super(UniversalGraphEmbedder, self).__init__() 13 | 14 | self.num_classes = num_classes 15 | self.drop_prob = drop_prob 16 | self.num_encoder_layers = num_gconv_layers 17 | self.hidden_dim = hidden_dim 18 | self.num_gfc_layers = num_gfc_layers 19 | self.use_fgsd_features = use_fgsd_features 20 | self.use_spk_features = use_spk_features 21 | self.drop_prob_kernel = 0 22 | 23 | self.input_transform = InputTransformer(input_dim, hidden_dim, num_input_layers=1) 24 | self.graph_encoder = GraphEncoder(hidden_dim, hidden_dim, num_gconv_layers, num_gfc_layers, drop_prob) 25 | self.output_dim = hidden_dim + num_gconv_layers * hidden_dim 26 | 27 | self.num_fc_layers = 1 28 | self.fc_layers = nn.ModuleList() 29 | curr_input_dim = self.output_dim 30 | for i in range(self.num_fc_layers): 31 | self.fc_layers.append(Linear(curr_input_dim, hidden_dim)) 32 | curr_input_dim = hidden_dim 33 | 34 | self.class_fc_layer = Linear(hidden_dim, num_classes) 35 | 36 | self.adj_linear_layer = Linear(self.output_dim, self.output_dim) 37 | 38 | self.kernel_linear_layer = Linear(self.output_dim, self.output_dim) 39 | if use_fgsd_features: 40 | self.fgsd_linear_layer = Linear(self.output_dim, self.output_dim) 41 | 42 | if use_spk_features: 43 | self.spk_linear_layer = Linear(self.output_dim, self.output_dim) 44 | 45 | def forward(self, x: (N, D), A: (N, N), A_mask: (N, N), batch_sample_matrix: (B, N)): 46 | 47 | x: (N, H) = self.input_transform(x, A) 48 | x: (N, H) = self.graph_encoder(x, A) 49 | x_emb: (B, H) = torch.spmm(batch_sample_matrix, x) 50 | 51 | # A_pred: (N, N) = torch.matmul(self.adj_linear_layer(x), x.t()) 52 | # A_pred: (N, N) = torch.matmul(x, x.t()) 53 | # A_pred: (N, N) = torch.mul(A_pred, A_mask.to_dense()) 54 | A_pred = None 55 | 56 | x_curr = x_emb 57 | for i in range(self.num_fc_layers): 58 | x_curr: (B, H) = F.selu(self.fc_layers[i](x_curr)) 59 | x_curr: (B, H) = F.dropout(x_curr, p=self.drop_prob, training=self.training) 60 | 61 | x: (B, C) = 
self.class_fc_layer(x_curr) 62 | class_logits: (B, C) = F.log_softmax(x, dim=-1) 63 | 64 | # kernel_matrix_pred = torch.matmul(self.kernel_linear_layer(x_emb), x_emb.t()) 65 | # kernel_matrix_pred = F.dropout(kernel_matrix_pred, p=self.drop_prob_kernel, training=self.training) 66 | # kernel_matrix_pred = torch.sigmoid(kernel_matrix_pred) 67 | 68 | kernel_matrix_pred = torch.matmul(x_emb, x_emb.t()) 69 | kernel_matrix_pred = F.dropout(kernel_matrix_pred, p=self.drop_prob_kernel, training=self.training) 70 | kernel_matrix_pred = torch.sigmoid(kernel_matrix_pred) 71 | 72 | # x: (B, H) = F.normalize(x_emb, p=2, dim=1) 73 | # kernel_matrix_pred: (B, B) = torch.matmul(x, x.t()) 74 | 75 | fgsd_kernel_matrix_pred = None 76 | if self.use_fgsd_features: 77 | fgsd_kernel_matrix_pred = torch.matmul(self.fgsd_linear_layer(x_emb), x_emb.t()) 78 | fgsd_kernel_matrix_pred = F.dropout(fgsd_kernel_matrix_pred, p=self.drop_prob_kernel, training=self.training) 79 | fgsd_kernel_matrix_pred = torch.sigmoid(fgsd_kernel_matrix_pred) 80 | 81 | spk_kernel_matrix_pred = None 82 | if self.use_spk_features: 83 | spk_kernel_matrix_pred = torch.matmul(self.spk_linear_layer(x_emb), x_emb.t()) 84 | spk_kernel_matrix_pred = F.dropout(spk_kernel_matrix_pred, p=self.drop_prob_kernel, training=self.training) 85 | spk_kernel_matrix_pred = torch.sigmoid(spk_kernel_matrix_pred) 86 | 87 | return class_logits, kernel_matrix_pred, fgsd_kernel_matrix_pred, spk_kernel_matrix_pred, x_emb, A_pred 88 | -------------------------------------------------------------------------------- /torch_dgl/models/universal_graph_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import Sequential, Linear, ReLU 4 | from torch_dgl.layers.graph_capsule_layer import GraphCapsuleConv 5 | import torch.nn.functional as F 6 | from config.params_universal_graph_embedding_model import B, N, E, D, H, C 7 | 8 | 9 | class GraphEncoder(nn.Module): 10 | def __init__(self, input_dim, hidden_dim, num_gconv_layers, num_gfc_layers, drop_prob): 11 | super(GraphEncoder, self).__init__() 12 | 13 | self.input_dim = input_dim 14 | self.hidden_dim = hidden_dim 15 | self.drop_prob = drop_prob 16 | self.num_gconv_layers = num_gconv_layers 17 | self.num_gfc_layers = num_gfc_layers 18 | 19 | self.gconv_layers = nn.ModuleList() 20 | self.batchnorm_layers = nn.ModuleList() 21 | curr_input_dim = input_dim 22 | for i in range(self.num_gconv_layers): 23 | self.gconv_layers.append(GraphCapsuleConv(curr_input_dim, hidden_dim, num_gfc_layers=num_gfc_layers)) 24 | self.batchnorm_layers.append(torch.nn.BatchNorm1d(hidden_dim)) 25 | curr_input_dim = curr_input_dim + hidden_dim 26 | 27 | def forward(self, x: (N, D), edge_index: (2, E)): 28 | 29 | x_prev = x 30 | for i in range(self.num_gconv_layers): 31 | x_curr: (N, H) = F.selu(self.gconv_layers[i](x_prev, edge_index)) 32 | x_curr: (N, H) = self.batchnorm_layers[i](x_curr) 33 | x_prev = torch.cat((x_prev, x_curr), dim=-1) 34 | return x_prev 35 | -------------------------------------------------------------------------------- /train/train_ugd.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import sys 4 | sys.path.extend('..') 5 | import numpy as np 6 | import time 7 | import random 8 | import argparse 9 | from comet_ml import Experiment 10 | # import matlab.engine 11 | import torch 12 | import torch.optim as optim 13 | import torch.nn.functional as F 14 | from torch 
import nn 15 | from dataloader.graph_datasets import TUDataset 16 | from dataloader.graph_dataloader import DataLoader 17 | from utils.utils import save_model, load_model, NoamLR, load_partial_model_state 18 | from torch_dgl.models.model_universal_graph_embedding import UniversalGraphEmbedder 19 | from config.utils_comet import API_KEY 20 | from sklearn.model_selection import StratifiedKFold, cross_val_score 21 | from sklearn.metrics import classification_report, precision_recall_fscore_support, roc_auc_score, precision_score, recall_score, accuracy_score 22 | from scipy import sparse 23 | from sklearn import random_projection, svm 24 | from tensorboardX import SummaryWriter 25 | import shutil 26 | from sklearn.preprocessing import normalize 27 | 28 | 29 | def batch_prep_input(batch, batch_graph_list, dataset_name=None): 30 | 31 | feature_matrix = [] 32 | idx_graphid = [] 33 | for i, G in enumerate(batch_graph_list): 34 | feature_matrix.append(G['WL_kernel_features']) 35 | idx_graphid.append(G['graph_id']) 36 | 37 | idx_graphid = np.array(idx_graphid, dtype=np.int64) 38 | batch_feature_matrix = sparse.vstack(feature_matrix) 39 | batch_kernel_matrix = batch_feature_matrix.dot(batch_feature_matrix.transpose()) 40 | batch_kernel_matrix = np.array(batch_kernel_matrix.todense()) 41 | batch['wl_kernel_matrix'] = torch.FloatTensor(batch_kernel_matrix) 42 | 43 | if args.use_fgsd_features: 44 | 45 | batch_fgsd_kernel_matrix = fgsd_kernel[idx_graphid] 46 | batch_fgsd_kernel_matrix = batch_fgsd_kernel_matrix[:, idx_graphid] 47 | batch['fgsd_kernel_matrix'] = torch.FloatTensor(batch_fgsd_kernel_matrix) 48 | 49 | if args.use_spk_features: 50 | batch_spk_kernel_matrix = spk_kernel[idx_graphid] 51 | batch_spk_kernel_matrix = batch_spk_kernel_matrix[:, idx_graphid] 52 | batch['spk_kernel_matrix'] = torch.FloatTensor(batch_spk_kernel_matrix) 53 | 54 | graph_labels = batch['graph_labels'].data.cpu().numpy() 55 | class_kernel_matrix = np.equal(graph_labels, graph_labels[:, np.newaxis]).astype(int) 56 | batch['class_kernel_matrix'] = torch.FloatTensor(class_kernel_matrix) 57 | 58 | return batch 59 | 60 | 61 | def eval_model(eval_loader): 62 | 63 | model.eval() 64 | num_samples = len(eval_loader.dataset) 65 | total_loss = 0 66 | total_class_loss = 0 67 | total_kernel_loss = 0 68 | total_fgsd_kernel_loss = 0 69 | total_spk_kernel_loss = 0 70 | total_adj_reconst_loss = 0 71 | correct = 0 72 | num_batches = 0 73 | X_eval = [] 74 | Y_eval = [] 75 | 76 | with torch.no_grad(): 77 | for idx_batch, batch in enumerate(eval_loader): 78 | 79 | X = batch['node_feature_matrix'].to(device) 80 | A = batch['adjacency_matrix'].to(device) 81 | # A_mask = batch['adjacency_mask'].to(device) 82 | A_mask = None 83 | batch_sample_matrix = batch['batch_sample_matrix'].to(device) 84 | graph_labels = batch['graph_labels'].to(device) 85 | kernel_true = batch['wl_kernel_matrix'].to(device) 86 | 87 | class_kernel_true = None 88 | spk_kernel_true = None 89 | fgsd_kernel_true = None 90 | if args.use_adaptive_kernel_loss: 91 | class_kernel_true = batch['class_kernel_matrix'].to(device) 92 | if args.use_spk_features: 93 | spk_kernel_true = batch['spk_kernel_matrix'].to(device) 94 | if args.use_fgsd_features: 95 | fgsd_kernel_true = batch['fgsd_kernel_matrix'].to(device) 96 | 97 | logits, kernel_pred, fgsd_kernel_pred, spk_kernel_pred, graph_emb, A_pred = model(X, A, A_mask, batch_sample_matrix) 98 | loss, loss_class, loss_kernel, loss_fgsd_kernel, loss_spk_kernel, loss_adj_reconst = compute_loss(logits, graph_labels, kernel_pred, kernel_true, 
fgsd_kernel_pred, fgsd_kernel_true, spk_kernel_pred, spk_kernel_true, class_kernel_true, A_pred, A) 99 | 100 | pred = logits.max(dim=1)[1] 101 | correct += pred.eq(graph_labels).sum().item() 102 | 103 | total_loss += loss.item() 104 | total_class_loss += loss_class.item() 105 | total_kernel_loss += loss_kernel.item() 106 | total_fgsd_kernel_loss += loss_fgsd_kernel.item() 107 | total_spk_kernel_loss += loss_spk_kernel.item() 108 | total_adj_reconst_loss += loss_adj_reconst.item() 109 | num_batches = num_batches + 1 110 | X_eval.append(graph_emb.data.cpu().numpy()) 111 | Y_eval.append(graph_labels.data.cpu().numpy()) 112 | 113 | acc_val = correct / num_samples 114 | loss_per_sample = total_loss / num_batches 115 | loss_class_per_sample = total_class_loss / num_batches 116 | loss_kernel_per_sample = total_kernel_loss / num_batches 117 | loss_fgsd_kernel_per_sample = total_fgsd_kernel_loss / num_batches 118 | loss_spk_kernel_per_sample = total_spk_kernel_loss / num_batches 119 | loss_adj_reconst_per_sample = total_adj_reconst_loss / num_batches 120 | 121 | return loss_per_sample, loss_class_per_sample, loss_kernel_per_sample, loss_fgsd_kernel_per_sample, loss_spk_kernel_per_sample, loss_adj_reconst_per_sample, acc_val, X_eval, Y_eval 122 | 123 | 124 | def compute_loss(logits, labels, kernel_pred, kernel_true, fgsd_kernel_pred=None, fgsd_kernel_true=None, spk_kernel_pred=None, spk_kernel_true=None, class_kernel_true=None, A_pred=None, A=None): 125 | 126 | loss_class = args.lambda_class_reg * F.nll_loss(logits, labels) 127 | loss_fgsd_kernel = torch.FloatTensor([0]).to(device) 128 | loss_spk_kernel = torch.FloatTensor([0]).to(device) 129 | 130 | if class_kernel_true is not None: 131 | kernel_max = torch.max(kernel_true, torch.max(fgsd_kernel_true, spk_kernel_true)) 132 | kernel_min = torch.min(kernel_true, torch.min(fgsd_kernel_true, spk_kernel_true)) 133 | kernel_combo = class_kernel_true * kernel_max + (1.0 - class_kernel_true) * kernel_min 134 | loss_kernel = args.lambda_kernel_reg * F.mse_loss(kernel_pred, kernel_combo) 135 | else: 136 | loss_kernel = args.lambda_kernel_reg * F.mse_loss(kernel_pred, kernel_true) 137 | if fgsd_kernel_true is not None: 138 | loss_fgsd_kernel = args.lambda_fgsd_kernel_reg * F.mse_loss(fgsd_kernel_pred, fgsd_kernel_true) 139 | if spk_kernel_true is not None: 140 | loss_spk_kernel = args.lambda_spk_kernel_reg * F.mse_loss(spk_kernel_pred, spk_kernel_true) 141 | 142 | loss_adj_reconst = torch.FloatTensor([0]).to(device) 143 | # if A_pred is not None: 144 | # num_nodes = A.shape[0] 145 | # num_edges = torch.sparse.sum(A) 146 | # edge_ratio = (num_nodes * num_nodes - num_edges) / num_edges 147 | # A_true = A.to_dense() + torch.eye(num_nodes).to(device) 148 | # # A_true = A.to_dense() 149 | # loss_adj_reconst = args.lambda_adj_reconst_reg * F.binary_cross_entropy_with_logits(A_pred, A_true, pos_weight=edge_ratio) 150 | # # loss_adj_reconst = args.lambda_adj_reconst_reg * F.mse_loss(torch.sigmoid(A_pred), A_true) 151 | 152 | total_loss = 0 * loss_kernel 153 | if args.use_class_loss: 154 | total_loss = total_loss + loss_class 155 | if args.use_wl_features: 156 | total_loss = total_loss + loss_kernel 157 | if args.use_fgsd_features: 158 | total_loss = total_loss + loss_fgsd_kernel 159 | if args.use_spk_features: 160 | total_loss = total_loss + loss_spk_kernel 161 | if args.use_adj_reconst_loss: 162 | total_loss = total_loss + loss_adj_reconst 163 | 164 | return total_loss, loss_class, loss_kernel, loss_fgsd_kernel, loss_spk_kernel, loss_adj_reconst 165 | 166 | 167 | 
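# train_per_epoch runs one pass over train_loader: the classification loss is
# combined with MSE losses that align the (sigmoid-squashed) Gram matrix of the
# learned graph embeddings with the precomputed WL / FGSD / shortest-path kernel
# matrices, after which the model is evaluated on the validation split and the
# metrics are logged to TensorBoard and comet_ml.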
def train_per_epoch(): 168 | global best_validation_loss, best_acc_validation, best_acc_validation_loss 169 | 170 | t = time.time() 171 | model.train() 172 | num_samples = len(train_loader.dataset) 173 | total_loss = 0 174 | total_class_loss = 0 175 | total_kernel_loss = 0 176 | total_fgsd_kernel_loss = 0 177 | total_spk_kernel_loss = 0 178 | total_adj_reconst_loss = 0 179 | correct = 0 180 | num_batches = 0 181 | total_batch_prep_time = 0 182 | X_train = [] 183 | Y_train = [] 184 | 185 | for idx_batch, batch in enumerate(train_loader): 186 | 187 | X = batch['node_feature_matrix'].to(device) 188 | A = batch['adjacency_matrix'].to(device) 189 | # A_mask = batch['adjacency_mask'].to(device) 190 | A_mask = None 191 | batch_sample_matrix = batch['batch_sample_matrix'].to(device) 192 | graph_labels = batch['graph_labels'].to(device) 193 | kernel_true = batch['wl_kernel_matrix'].to(device) 194 | class_kernel_true = None 195 | spk_kernel_true = None 196 | fgsd_kernel_true = None 197 | if args.use_adaptive_kernel_loss: 198 | class_kernel_true = batch['class_kernel_matrix'].to(device) 199 | if args.use_spk_features: 200 | spk_kernel_true = batch['spk_kernel_matrix'].to(device) 201 | if args.use_fgsd_features: 202 | fgsd_kernel_true = batch['fgsd_kernel_matrix'].to(device) 203 | 204 | optimizer.zero_grad() 205 | logits, kernel_pred, fgsd_kernel_pred, spk_kernel_pred, graph_emb, A_pred = model(X, A, A_mask, batch_sample_matrix) 206 | loss, loss_class, loss_kernel, loss_fgsd_kernel, loss_spk_kernel, loss_adj_reconst = compute_loss(logits, graph_labels, kernel_pred, kernel_true, fgsd_kernel_pred, fgsd_kernel_true, spk_kernel_pred, spk_kernel_true, class_kernel_true, A_pred, A) 207 | loss.backward() 208 | optimizer.step() 209 | optimizer_lr_scheduler.step() 210 | 211 | pred = logits.max(dim=1)[1] 212 | correct += pred.eq(graph_labels).sum().item() 213 | 214 | total_loss += loss.item() 215 | total_class_loss += loss_class.item() 216 | total_kernel_loss += loss_kernel.item() 217 | total_fgsd_kernel_loss += loss_fgsd_kernel.item() 218 | total_spk_kernel_loss += loss_spk_kernel.item() 219 | total_adj_reconst_loss += loss_adj_reconst.item() 220 | num_batches = num_batches + 1 221 | total_batch_prep_time = total_batch_prep_time + batch['prep_time'] 222 | X_train.append(graph_emb.data.cpu().numpy()) 223 | Y_train.append(graph_labels.data.cpu().numpy()) 224 | 225 | acc_train = correct / num_samples 226 | loss_train_per_sample = total_loss / num_batches 227 | loss_train_class_per_sample = total_class_loss / num_batches 228 | loss_kernel_train_per_sample = total_kernel_loss / num_batches 229 | loss_fgsd_kernel_train_per_sample = total_fgsd_kernel_loss / num_batches 230 | loss_spk_kernel_train_per_sample = total_spk_kernel_loss / num_batches 231 | loss_adj_reconst_train_per_sample = total_adj_reconst_loss / num_batches 232 | 233 | # loss_train_per_sample, loss_train_class_per_sample, loss_kernel_train_per_sample, loss_fgsd_kernel_train_per_sample, loss_spk_kernel_train_per_sample, loss_adj_reconst_train_per_sample, acc_train, X_train, Y_train = eval_model(train_loader) 234 | loss_validation_per_sample, loss_validation_class_per_sample, loss_kernel_validation_per_sample, loss_fgsd_kernel_validation_per_sample, loss_spk_kernel_validation_per_sample, loss_adj_reconst_validation_per_sample, acc_validation, X_validation, Y_validation = eval_model(validation_loader) 235 | 236 | X_train = np.concatenate(X_train) 237 | Y_train = np.concatenate(Y_train) 238 | X_validation = np.concatenate(X_validation) 239 | Y_validation = 
np.concatenate(Y_validation) 240 | 241 | if args.use_svm_classifier: 242 | clf = svm.SVC(C=1, gamma='scale', class_weight='balanced') 243 | clf.fit(X_train, Y_train) 244 | Y_pred = clf.predict(X_validation) 245 | acc_validation = accuracy_score(Y_validation, Y_pred) 246 | Y_pred = clf.predict(X_train) 247 | acc_train = accuracy_score(Y_train, Y_pred) 248 | 249 | if best_acc_validation < acc_validation: 250 | best_acc_validation = acc_validation 251 | if best_validation_loss > loss_validation_per_sample: 252 | best_validation_loss = loss_validation_per_sample 253 | best_acc_validation_loss = acc_validation 254 | 255 | writer.add_scalars('acc/best', {'best_acc_validation_loss': best_acc_validation_loss, 'best_acc_validation': best_acc_validation}, epoch + 1) 256 | writer.add_scalars('acc', {'acc_train': acc_train, 'acc_validation': acc_validation}, epoch + 1) 257 | writer.add_scalars('loss/loss_total', {'loss_train_per_sample': loss_train_per_sample, 'loss_validation_per_sample': loss_validation_per_sample}, epoch + 1) 258 | writer.add_scalars('loss/loss_wl', {'loss_kernel_train_per_sample': loss_kernel_train_per_sample, 'loss_kernel_validation_per_sample': loss_kernel_validation_per_sample}, epoch + 1) 259 | writer.add_scalars('loss/loss_fgsd', {'loss_fgsd_kernel_train_per_sample': loss_fgsd_kernel_train_per_sample, 'loss_fgsd_kernel_validation_per_sample': loss_fgsd_kernel_validation_per_sample}, epoch + 1) 260 | writer.add_scalars('loss/loss_spk', {'loss_spk_kernel_train_per_sample': loss_spk_kernel_train_per_sample, 'loss_spk_kernel_validation_per_sample': loss_spk_kernel_validation_per_sample}, epoch + 1) 261 | writer.add_scalars('loss/loss_adj', {'loss_adj_reconst_train_per_sample': loss_adj_reconst_train_per_sample, 'loss_adj_reconst_validation_per_sample': loss_adj_reconst_validation_per_sample}, epoch + 1) 262 | 263 | logging.info('Epoch: {:04d}'.format(epoch + 1) + 264 | ' acc_train: {:.4f}'.format(acc_train) + 265 | ' acc_validation: {:.4f}'.format(acc_validation) + 266 | ' best_acc_validation: {:.4f}'.format(best_acc_validation) + 267 | ' best_acc_validation_loss: {:.4f}'.format(best_acc_validation_loss) + 268 | ' loss_train: {:08.5f}'.format(loss_train_per_sample) + 269 | ' loss_class_train: {:08.5f}'.format(loss_train_class_per_sample) + 270 | ' loss_kernel_train: {:08.5f}'.format(loss_kernel_train_per_sample) + 271 | ' loss_fgsd_kernel_train: {:08.5f}'.format(loss_fgsd_kernel_train_per_sample) + 272 | ' loss_spk_kernel_train: {:08.5f}'.format(loss_spk_kernel_train_per_sample) + 273 | ' loss_adj_reconst_train: {:08.5f}'.format(loss_adj_reconst_train_per_sample) + 274 | ' loss_validation: {:08.5f}'.format(loss_validation_per_sample) + 275 | ' loss_class_validation: {:08.5f}'.format(loss_validation_class_per_sample) + 276 | ' loss_kernel_validation: {:08.5f}'.format(loss_kernel_validation_per_sample) + 277 | ' loss_fgsd_kernel_validation: {:08.5f}'.format(loss_fgsd_kernel_validation_per_sample) + 278 | ' loss_spk_kernel_validation: {:08.5f}'.format(loss_spk_kernel_validation_per_sample) + 279 | ' loss_adj_reconst_validation: {:08.5f}'.format(loss_adj_reconst_validation_per_sample) + 280 | ' lr: {:.2e}'.format(optimizer.param_groups[0]['lr']) + 281 | ' batch_prep_time: {:.4f}s'.format(total_batch_prep_time) + 282 | ' crossval_split: {:04d}'.format(args.crossval_split) + 283 | ' time: {:.4f}s'.format(time.time() - t)) 284 | 285 | with experiment.train(): 286 | experiment.log_metric("loss", loss_train_per_sample, step=epoch) 287 | experiment.log_metric("accuracy", 
float('{:.4f}'.format(acc_train)), step=epoch) 288 | experiment.log_metric("loss_train", float('{:.4f}'.format(loss_train_per_sample)), step=epoch) 289 | experiment.log_metric("loss_class_train", float('{:.4f}'.format(loss_train_class_per_sample)), step=epoch) 290 | experiment.log_metric("loss_kernel_train", float('{:.4f}'.format(loss_kernel_train_per_sample)), step=epoch) 291 | experiment.log_metric("loss_fgsd_kernel_train", float('{:.4f}'.format(loss_fgsd_kernel_train_per_sample)), step=epoch) 292 | experiment.log_metric("loss_spk_kernel_train", float('{:.4f}'.format(loss_spk_kernel_train_per_sample)), step=epoch) 293 | experiment.log_metric("loss_adj_reconst_train", float('{:.4f}'.format(loss_adj_reconst_train_per_sample)), step=epoch) 294 | 295 | with experiment.validation(): 296 | experiment.log_metric("loss", loss_validation_per_sample, step=epoch) 297 | experiment.log_metric("accuracy", float('{:.4f}'.format(acc_validation)), step=epoch) 298 | experiment.log_metric("best_acc", float('{:.4f}'.format(best_acc_validation)), step=epoch) 299 | experiment.log_metric("loss_best_acc", float('{:.4f}'.format(best_acc_validation_loss)), step=epoch) 300 | experiment.log_metric("epoch", float('{:04d}'.format(epoch + 1)), step=epoch) 301 | experiment.log_metric("loss_validation", float('{:.4f}'.format(loss_validation_per_sample)), step=epoch) 302 | experiment.log_metric("loss_class_validation", float('{:.4f}'.format(loss_validation_class_per_sample)), step=epoch) 303 | experiment.log_metric("loss_kernel_validation", float('{:.4f}'.format(loss_kernel_validation_per_sample)), step=epoch) 304 | experiment.log_metric("loss_fgsd_kernel_validation", float('{:.4f}'.format(loss_fgsd_kernel_validation_per_sample)), step=epoch) 305 | experiment.log_metric("loss_spk_kernel_validation", float('{:.4f}'.format(loss_spk_kernel_validation_per_sample)), step=epoch) 306 | experiment.log_metric("loss_adj_reconst_validation", float('{:.4f}'.format(loss_adj_reconst_validation_per_sample)), step=epoch) 307 | 308 | 309 | if __name__ == '__main__': 310 | 311 | seed = 42 312 | np.random.seed(seed) 313 | random.seed(seed) 314 | torch.manual_seed(seed) 315 | torch.cuda.manual_seed(seed) 316 | torch.cuda.manual_seed_all(seed) 317 | # torch.backends.cudnn.deterministic = True 318 | # torch.backends.cudnn.benchmark = False 319 | 320 | if 'pydevd' in sys.modules: 321 | DEBUGGING = True 322 | else: 323 | DEBUGGING = False 324 | 325 | parser = argparse.ArgumentParser(description='Model Arguments') 326 | parser.add_argument('--data_dir', type=str, default='data/') 327 | parser.add_argument('--log_dir', type=str, default='logs/') 328 | parser.add_argument('--pretrained_model_file', type=str, default=None) 329 | parser.add_argument('--gpu_device', type=int, default=1) 330 | parser.add_argument('--dataset_name', type=str, default='DD') 331 | parser.add_argument('--crossval_split', type=int, default=2) 332 | parser.add_argument('--save_steps', type=int, default=100) 333 | parser.add_argument('--run_on_comet', type=lambda x: (str(x).lower() == 'true'), default=not DEBUGGING) 334 | 335 | parser.add_argument('--batch_size', type=int, default=64) 336 | parser.add_argument('--hidden_dim', type=int, default=32) 337 | parser.add_argument('--drop_prob', type=float, default=0) 338 | parser.add_argument('--num_gconv_layers', type=int, default=5) 339 | parser.add_argument('--num_gfc_layers', type=int, default=2) 340 | parser.add_argument('--num_epochs', type=int, default=1000) 341 | 342 | parser.add_argument('--use_class_loss', type=lambda x: 
(str(x).lower() == 'true'), default=True) 343 | parser.add_argument('--use_wl_features', type=lambda x: (str(x).lower() == 'true'), default=True) 344 | parser.add_argument('--use_fgsd_features', type=lambda x: (str(x).lower() == 'true'), default=True) 345 | parser.add_argument('--use_spk_features', type=lambda x: (str(x).lower() == 'true'), default=True) 346 | parser.add_argument('--use_adj_reconst_loss', type=lambda x: (str(x).lower() == 'true'), default=False) 347 | parser.add_argument('--use_adaptive_kernel_loss', type=lambda x: (str(x).lower() == 'true'), default=True) 348 | parser.add_argument('--use_svm_classifier', type=lambda x: (str(x).lower() == 'true'), default=False) 349 | parser.add_argument('--lambda_class_reg', type=float, default=1.0) 350 | parser.add_argument('--lambda_kernel_reg', type=float, default=1.0) 351 | parser.add_argument('--lambda_fgsd_kernel_reg', type=float, default=1.0) 352 | parser.add_argument('--lambda_spk_kernel_reg', type=float, default=1.0) 353 | parser.add_argument('--lambda_adj_reconst_reg', type=float, default=1.0) 354 | 355 | parser.add_argument('--warmup_epochs', type=float, nargs='*', default=[2.0], help='Number of epochs during which learning rate increases linearly from init_lr to max_lr. Afterwards, learning rate decreases exponentially from max_lr to final_lr.') 356 | parser.add_argument('--init_lr', type=float, nargs='*', default=[1e-4], help='Initial learning rate') 357 | parser.add_argument('--max_lr', type=float, nargs='*', default=[1e-3], help='Maximum learning rate') 358 | parser.add_argument('--final_lr', type=float, nargs='*', default=[1e-4], help='Final learning rate') 359 | parser.add_argument('--lr_scaler', type=float, nargs='*', default=[1.0], help='Amount by which to scale init_lr, max_lr, and final_lr (for convenience)') 360 | parser.add_argument('--lr_decay_rate', type=float, default=0.9, help='lr decay per epoch, for decay scheduler') 361 | 362 | args, unknown = parser.parse_known_args() 363 | 364 | experiment = Experiment(api_key=API_KEY, project_name="universal-graph-embedding", workspace="saurabh08", disabled=not args.run_on_comet) 365 | experiment_id = experiment.get_key() 366 | 367 | data_path = os.path.join(args.data_dir, args.dataset_name) 368 | log_path = os.path.join(args.log_dir, experiment_id) 369 | if not os.path.exists(log_path): 370 | os.makedirs(log_path) 371 | logging.basicConfig(format='%(message)s', level=logging.INFO, handlers=[logging.StreamHandler(), logging.FileHandler(os.path.join(log_path, 'console_output.txt'))]) 372 | 373 | run_filepath = os.path.abspath(__file__) 374 | shutil.copy(run_filepath, log_path) 375 | src_list = ['./train', './utils', './torch_dgl', './dataloader', './config'] 376 | dest_list = [os.path.join(log_path, 'train'), os.path.join(log_path, 'utils'), os.path.join(log_path, 'torch_dgl'), os.path.join(log_path, 'dataloader'), os.path.join(log_path, 'config')] 377 | for src, dest in zip(src_list, dest_list): 378 | shutil.copytree(src, dest) 379 | 380 | for arg, value in sorted(vars(args).items()): 381 | logging.info("Hyperparameter: %s: %r", arg, value) 382 | 383 | writer = SummaryWriter('tensorboard/') 384 | 385 | dataset = TUDataset(data_path, name=args.dataset_name, shuffle=False, compute_graph_kernel_features=True, wl_node_labels='node_label') 386 | cross_val_path = os.path.join(args.data_dir, args.dataset_name, 'crossval_10fold_idx/') 387 | idx_train = np.loadtxt(cross_val_path + 'idx_train_split_' + str(args.crossval_split) + '.txt', dtype=np.int64) 388 | idx_validation = 
np.loadtxt(cross_val_path + 'idx_validation_split_' + str(args.crossval_split) + '.txt', dtype=np.int64) 389 | train_dataset = dataset[idx_train] 390 | validation_dataset = dataset[idx_validation] 391 | train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, batch_prep_func=batch_prep_input) 392 | validation_loader = DataLoader(validation_dataset, batch_size=args.batch_size, shuffle=True, batch_prep_func=batch_prep_input) 393 | 394 | num_train_samples = len(train_dataset) 395 | num_features = dataset.num_features 396 | num_classes = dataset.num_classes 397 | device = torch.device('cuda', args.gpu_device) 398 | 399 | if args.use_spk_features: 400 | spk_path = os.path.join(args.data_dir, args.dataset_name, args.dataset_name + '_K_shortest_path.npy') 401 | spk_kernel = np.load(spk_path) 402 | if args.use_fgsd_features: 403 | fgsd_path = os.path.join(args.data_dir, args.dataset_name, args.dataset_name + '_K_fgsd.npy') 404 | fgsd_kernel = np.load(fgsd_path) 405 | 406 | #################################### 407 | 408 | model = UniversalGraphEmbedder(num_features, num_classes, args.hidden_dim, args.num_gconv_layers, args.num_gfc_layers, args.drop_prob, args.use_fgsd_features, args.use_spk_features).to(device) 409 | loss_func = nn.MSELoss() 410 | 411 | # optimizer = optim.Adam(model.parameters(), lr=0.01) 412 | # lr_decay_factor = 0.1 413 | # lr_decay_at_epochs = 50 414 | # optimizer_lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=lr_decay_at_epochs, gamma=lr_decay_factor) 415 | 416 | optimizer = optim.Adam(model.parameters(), lr=args.init_lr[0], weight_decay=0.1) 417 | optimizer_lr_scheduler = NoamLR(optimizer=optimizer, warmup_epochs=args.warmup_epochs, 418 | total_epochs=[args.num_epochs], 419 | steps_per_epoch=num_train_samples // args.batch_size, 420 | init_lr=args.init_lr, max_lr=args.max_lr, final_lr=args.final_lr) 421 | 422 | logging.info('Starting epoch: {:04d}'.format(1) + ' current optimizer lr: {:.2e}'.format(optimizer.param_groups[0]['lr'])) 423 | 424 | if args.pretrained_model_file is not None: 425 | logging.info('Loading previous trained model file from: ' + str(args.pretrained_model_file)) 426 | pretrain_model_state = load_model(args.pretrained_model_file) 427 | model_state_dict = load_partial_model_state(model.state_dict(), pretrain_model_state['model_state_dict']) 428 | model.load_state_dict(model_state_dict) 429 | # optimizer.load_state_dict(model_state['optimizer_state_dict']) 430 | else: 431 | logging.info('Training from scratch...
') 432 | 433 | best_validation_loss = 1e10 434 | best_acc_validation = 0.0 435 | best_acc_validation_loss = 0.0 436 | for epoch in range(args.num_epochs): 437 | 438 | train_per_epoch() 439 | # optimizer_lr_scheduler.step() 440 | 441 | if (epoch+1) % args.save_steps == 0: 442 | model_checkpoint_name = "model_state_epoch_" + str(epoch+1) + ".pt" 443 | save_path = os.path.join(log_path, model_checkpoint_name) 444 | save_model(model, optimizer, save_path, device) 445 | 446 | writer.close() 447 | 448 | -------------------------------------------------------------------------------- /utils/compute_fgsd_features.py: -------------------------------------------------------------------------------- 1 | from scipy import sparse 2 | import numpy as np 3 | from utils.fast_fgsd_features import fgsd_features 4 | import time 5 | from sklearn import random_projection 6 | 7 | 8 | def compute_reduce_fgsd_features(graph_list): 9 | 10 | print('Computing fgsd features...') 11 | t = time.time() 12 | feature_matrix = fgsd_features([graph['adj_matrix'] for graph in graph_list]) # TODO: quality testing 13 | print('Feature matrix shape: ', feature_matrix.shape) 14 | print('Time Taken: ', time.time() - t) 15 | 16 | print('Performing dimension reduction...') 17 | t = time.time() 18 | transformer = random_projection.SparseRandomProjection(n_components=1000) 19 | feature_matrix_reduce = transformer.fit_transform(feature_matrix) 20 | feature_matrix_reduce[feature_matrix_reduce < 0] = 0 21 | feature_matrix_reduce[feature_matrix_reduce > 0] = 1 22 | print('feature matrix density: ', feature_matrix_reduce.getnnz() / (feature_matrix_reduce.shape[0] * feature_matrix_reduce.shape[1])) 23 | feature_matrix_reduce = feature_matrix_reduce.todense() 24 | print('Time Taken: ', time.time() - t) 25 | 26 | return feature_matrix_reduce 27 | -------------------------------------------------------------------------------- /utils/compute_wl_kernel.py: -------------------------------------------------------------------------------- 1 | from scipy import sparse 2 | import numpy as np 3 | from utils.fast_wl_kernel import wl_kernel, wl_kernel_batch 4 | import time 5 | from sklearn.decomposition import TruncatedSVD, PCA, SparsePCA, NMF 6 | from sklearn import random_projection 7 | import logging 8 | 9 | 10 | def compute_reduce_wl_kernel(graph_list, num_iter=5, reduce_feature_dim=1000): 11 | 12 | logging.info('Preparing WL kernel input data...') 13 | t = time.time() 14 | node_labels = graph_list[0]['node_degree_vec'] 15 | graph_indicator = 0 * np.ones(graph_list[0]['num_nodes'], dtype=np.int64) 16 | for i in range(1, len(graph_list)): 17 | if (i + 1) % 1000 == 0: 18 | logging.info('num graphs processed so far: ' + str(i + 1)) 19 | node_labels = np.concatenate((node_labels, graph_list[i]['node_degree_vec'])) 20 | graph_indicator = np.concatenate((graph_indicator, i * np.ones(graph_list[i]['num_nodes'], dtype=np.int64))) 21 | logging.info('Time Taken: ' + str(time.time() - t)) 22 | 23 | logging.info('Computing WL Kernel...') 24 | t = time.time() 25 | feature_matrix = wl_kernel_batch([graph['adj_matrix'] for graph in graph_list], node_labels, graph_indicator, num_iter, compute_kernel_matrix=False) # TODO: quality testing 26 | logging.info('Feature matrix shape: ' + str(feature_matrix.shape)) 27 | logging.info('Time Taken: ' + str(time.time() - t)) 28 | 29 | feature_matrix = sparse.csr_matrix(feature_matrix) 30 | logging.info('Performing dimension reduction...') 31 | t = time.time() 32 | transformer = 
random_projection.SparseRandomProjection(n_components=reduce_feature_dim) 33 | feature_matrix_reduce = transformer.fit_transform(feature_matrix) 34 | feature_matrix_reduce[feature_matrix_reduce < 0] = 0 35 | feature_matrix_reduce[feature_matrix_reduce > 0] = 1 36 | logging.info('feature matrix density: ' + str(feature_matrix_reduce.getnnz() / (feature_matrix_reduce.shape[0] * feature_matrix_reduce.shape[1]))) 37 | feature_matrix_reduce = feature_matrix_reduce.todense() 38 | logging.info('Time Taken: ' + str(time.time() - t)) 39 | 40 | return feature_matrix_reduce 41 | 42 | 43 | def compute_full_wl_kernel(graph_list, num_iter=5, type_node_labels='degree'): 44 | 45 | logging.info('Preparing WL kernel input data...') 46 | t = time.time() 47 | if type_node_labels == 'degree': 48 | node_labels = graph_list[0]['node_degree_vec'] 49 | elif type_node_labels == 'node_label': 50 | node_labels = graph_list[0]['node_labels'] 51 | 52 | graph_indicator = 0 * np.ones(graph_list[0]['num_nodes'], dtype=np.int64) 53 | for i in range(1, len(graph_list)): 54 | if (i + 1) % 1000 == 0: 55 | logging.info('num graphs processed so far: ' + str(i + 1)) 56 | if type_node_labels == 'degree': 57 | curr_node_labels = graph_list[i]['node_degree_vec'] 58 | elif type_node_labels == 'node_label': 59 | curr_node_labels = graph_list[i]['node_labels'] 60 | node_labels = np.concatenate((node_labels, curr_node_labels)) 61 | graph_indicator = np.concatenate((graph_indicator, i * np.ones(graph_list[i]['num_nodes'], dtype=np.int64))) 62 | logging.info('Time Taken: ' + str(time.time() - t)) 63 | 64 | logging.info('Computing WL Kernel...') 65 | t = time.time() 66 | feature_matrix = wl_kernel_batch([graph['adj_matrix'] for graph in graph_list], node_labels, graph_indicator, num_iter, compute_kernel_matrix=False, normalize_feature_matrix=True) 67 | logging.info('Feature matrix shape: ' + str(feature_matrix.shape)) 68 | logging.info('Time Taken: ' + str(time.time() - t)) 69 | 70 | return feature_matrix 71 | -------------------------------------------------------------------------------- /utils/fast_fgsd_features.py: -------------------------------------------------------------------------------- 1 | from scipy import sparse 2 | import numpy as np 3 | import time 4 | from fast_histogram import histogram1d 5 | import matlab.engine 6 | 7 | eng = matlab.engine.start_matlab() 8 | 9 | 10 | def fgsd_features(graph_list): 11 | 12 | S_max = 0 13 | S_list = [] 14 | print('Computing pseudo inverse...') 15 | t = time.time() 16 | for i, A in enumerate(graph_list): 17 | if (i + 1) % 1000 == 0: 18 | print('num graphs processed so far: ', i + 1) 19 | A = np.array(A.todense(), dtype=np.float32) 20 | D = np.sum(A, axis=0) 21 | L = np.diag(D) - A 22 | 23 | ones_vector = np.ones(L.shape[0]) 24 | try: 25 | fL = np.linalg.pinv(L) 26 | except np.linalg.LinAlgError: 27 | fL = np.array(eng.fgsd_fast_pseudo_inverse(matlab.double(L.tolist()), nargout=1)) 28 | fL[np.isinf(fL)] = 0 29 | fL[np.isnan(fL)] = 0 30 | 31 | S = np.outer(np.diag(fL), ones_vector) + np.outer(ones_vector, np.diag(fL)) - 2 * fL 32 | if S.max() > S_max: 33 | S_max = S.max() 34 | S_list.append(S) 35 | 36 | print('S_max: ', S_max) 37 | print('Time Taken: ', time.time() - t) 38 | 39 | feature_matrix = [] 40 | nbins = 1000000 41 | range_hist = (0, S_max) 42 | print('Computing histogram...') 43 | t = time.time() 44 | for i, S in enumerate(S_list): 45 | if (i + 1) % 1000 == 0: 46 | print('num graphs processed so far: ', i + 1) 47 | # hist, _ = np.histogram(S.flatten(), bins=nbins, range=range_hist) 
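        # histogram1d (from the fast_histogram package) is a drop-in for the commented-out np.histogram call
        # above and is considerably faster for uniform bins over a fixed range; each graph's multiset of
        # spectral distances in S becomes an nbins-long count vector, stored as a sparse row below.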
48 | hist = histogram1d(S.flatten(), bins=nbins, range=range_hist) 49 | hist = sparse.csr_matrix(hist) 50 | feature_matrix.append(hist) 51 | print('Time Taken: ', time.time() - t) 52 | 53 | feature_matrix = sparse.vstack(feature_matrix) 54 | return feature_matrix 55 | -------------------------------------------------------------------------------- /utils/fast_wl_kernel.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import sympy 3 | from scipy import sparse 4 | from itertools import product 5 | import networkx as nx 6 | import scipy.io 7 | from sklearn.preprocessing import normalize 8 | 9 | 10 | prime_numbers_list = np.load('data/prime_numbers_list_v2.npy') 11 | 12 | 13 | def uniquetol(ar, tol=1e-12, return_index=False, return_inverse=False, return_counts=False, axis=None): 14 | ar = np.asanyarray(ar) 15 | if axis is None: 16 | return unique1dtol(ar, tol, return_index, return_inverse, return_counts) 17 | if not (-ar.ndim <= axis < ar.ndim): 18 | raise ValueError('Invalid axis kwarg specified for unique') 19 | 20 | ar = np.swapaxes(ar, axis, 0) 21 | orig_shape, orig_dtype = ar.shape, ar.dtype 22 | ar = ar.reshape(orig_shape[0], -1) 23 | ar = np.ascontiguousarray(ar) 24 | 25 | dtype = [('f{i}'.format(i=i), ar.dtype) for i in range(ar.shape[1])] 26 | 27 | try: 28 | consolidated = ar.view(dtype) 29 | except TypeError: 30 | msg = 'The axis argument to unique is not supported for dtype {dt}' 31 | raise TypeError(msg.format(dt=ar.dtype)) 32 | 33 | def reshape_uniq(uniq): 34 | uniq = uniq.view(orig_dtype) 35 | uniq = uniq.reshape(-1, *orig_shape[1:]) 36 | uniq = np.swapaxes(uniq, 0, axis) 37 | return uniq 38 | 39 | output = unique1dtol(consolidated, tol, return_index, return_inverse, return_counts) 40 | if not (return_index or return_inverse or return_counts): 41 | return reshape_uniq(output) 42 | else: 43 | uniq = reshape_uniq(output[0]) 44 | return (uniq,) + output[1:] 45 | 46 | 47 | def unique1dtol(ar, tol, return_index=False, return_inverse=False, return_counts=False): 48 | 49 | ar = np.asanyarray(ar).flatten() 50 | 51 | optional_indices = return_index or return_inverse 52 | optional_returns = optional_indices or return_counts 53 | 54 | if ar.size == 0: 55 | if not optional_returns: 56 | ret = ar 57 | else: 58 | ret = (ar,) 59 | if return_index: 60 | ret += (np.empty(0, np.intp),) 61 | if return_inverse: 62 | ret += (np.empty(0, np.intp),) 63 | if return_counts: 64 | ret += (np.empty(0, np.intp),) 65 | return ret 66 | 67 | if optional_indices: 68 | perm = ar.argsort(kind='mergesort' if return_index else 'quicksort') 69 | aux = ar[perm] 70 | else: 71 | ar.sort() 72 | aux = ar 73 | flag = np.concatenate(([True], np.absolute(aux[1:] - aux[:-1]) >= tol * np.max(np.absolute(aux[:])))) 74 | 75 | if not optional_returns: 76 | ret = aux[flag] 77 | else: 78 | ret = (aux[flag],) 79 | if return_index: 80 | ret += (perm[flag],) 81 | if return_inverse: 82 | iflag = np.cumsum(flag) - 1 83 | inv_idx = np.empty(ar.shape, dtype=np.intp) 84 | inv_idx[perm] = iflag 85 | ret += (inv_idx,) 86 | if return_counts: 87 | idx = np.concatenate(np.nonzero(flag) + ([ar.size],)) 88 | ret += (np.diff(idx),) 89 | return ret 90 | 91 | 92 | def wl_transformation(A, node_labels): 93 | 94 | num_labels = max(node_labels) + 1 95 | log_primes = np.log(prime_numbers_list[0:num_labels]) 96 | 97 | signatures = node_labels + A.dot(log_primes[node_labels]) 98 | _, new_labels = uniquetol(signatures.flatten('F'), return_inverse=True) 99 | return new_labels 100 | 101 | 102 | 
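# Note on wl_transformation above: each node's signature is its own label value plus the sum of log-primes of
# its neighbors' labels, i.e. signature(v) = label(v) + sum over u in N(v) of log(prime[label(u)]). Because
# products of primes factor uniquely, the log-prime sum identifies the multiset of neighbor labels, so in
# practice two nodes collide only if they share both their own label and their neighborhood label multiset.
# Assuming the stored prime list starts 2, 3, 5, ..., a node labelled 0 with neighbors labelled 0 and 1 gets
# 0 + log(2) + log(3) = log(6) ~ 1.79, while one with two neighbors labelled 1 gets 0 + 2*log(3) = log(9) ~ 2.20.
# uniquetol then maps the distinct signature values (up to a small tolerance) to consecutive integer labels,
# which is exactly one Weisfeiler-Lehman refinement step.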
def wl_kernel(A, node_labels, graph_ind, num_iterations): 103 | 104 | num_graphs = max(graph_ind) + 1 105 | K = sparse.csr_matrix((num_graphs, num_graphs)) 106 | feature_matrix = None 107 | 108 | for i in range(num_iterations+1): 109 | 110 | num_nodes = len(graph_ind) 111 | num_node_labels = max(node_labels) + 1 112 | counts = sparse.coo_matrix((np.ones(num_nodes), (graph_ind, node_labels)), shape=(num_graphs, num_node_labels)) 113 | if feature_matrix is None: 114 | feature_matrix = counts 115 | else: 116 | feature_matrix = sparse.hstack([feature_matrix, counts]) 117 | 118 | # K_new = counts.dot(counts.transpose()) 119 | # K = K + K_new 120 | node_labels = wl_transformation(A, node_labels) 121 | 122 | return K, feature_matrix 123 | 124 | 125 | def wl_transformation_batch(A_batch, node_labels): 126 | 127 | num_labels = max(node_labels) + 1 128 | log_primes = np.log(prime_numbers_list[0:num_labels]) 129 | 130 | signatures = [] 131 | prev_total_num_nodes = 0 132 | for A in A_batch: 133 | curr_total_num_nodes = prev_total_num_nodes + A.shape[0] 134 | curr_node_labels = node_labels[prev_total_num_nodes: curr_total_num_nodes] 135 | out = curr_node_labels + A.dot(log_primes[curr_node_labels]) 136 | signatures.append(out) 137 | prev_total_num_nodes = curr_total_num_nodes 138 | 139 | signatures = np.concatenate(signatures, axis=0) 140 | _, new_labels = uniquetol(signatures.flatten('F'), return_inverse=True) 141 | return new_labels 142 | 143 | 144 | def wl_kernel_batch(A_batch, node_labels, graph_ind, num_iterations, compute_kernel_matrix=False, normalize_feature_matrix=False): 145 | 146 | num_graphs = max(graph_ind) + 1 147 | feature_matrix = None 148 | 149 | for i in range(num_iterations+1): 150 | num_nodes = len(graph_ind) 151 | num_node_labels = max(node_labels) + 1 152 | counts = sparse.coo_matrix((np.ones(num_nodes), (graph_ind, node_labels)), shape=(num_graphs, num_node_labels)) 153 | if feature_matrix is None: 154 | feature_matrix = counts 155 | else: 156 | feature_matrix = sparse.hstack([feature_matrix, counts]) 157 | node_labels = wl_transformation_batch(A_batch, node_labels) 158 | 159 | feature_matrix = sparse.csr_matrix(feature_matrix) 160 | if normalize_feature_matrix: 161 | normalize(feature_matrix, norm='l2', axis=1, copy=False, return_norm=False) 162 | if compute_kernel_matrix: 163 | K = feature_matrix.dot(feature_matrix.transpose()) 164 | return K, feature_matrix 165 | else: 166 | return feature_matrix 167 | -------------------------------------------------------------------------------- /utils/read_sdf_file.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | 4 | 5 | def read_from_sdf(fname): 6 | 7 | file = open(fname, 'r') 8 | 9 | graph_list = [] 10 | while True: 11 | G = _read_sdf_molecule(file) 12 | if G == False: 13 | break 14 | graph_list.append(G) 15 | 16 | file.close() 17 | return graph_list 18 | 19 | 20 | # read a single molecule from file 21 | def _read_sdf_molecule(file): 22 | # read the header 3 lines 23 | line = file.readline().split('\t') 24 | if line[0] == '': 25 | return False 26 | gdb_id = int(line[0].split(' ')[1]) 27 | for i in range(2): 28 | file.readline() 29 | line = file.readline() 30 | 31 | # this does not work for 123456 which must be 123 and 456 32 | # (atoms, bonds) = [t(s) for t,s in zip((int,int),line.split())] 33 | num_atoms = int(line[:3]) 34 | num_bonds = int(line[3:6]) 35 | 36 | v = [] 37 | node_labels = [] 38 | for i in range(num_atoms): 39 | line = file.readline() 40 
| atom_symbol = line.split()[3] 41 | v.append(i + 1) 42 | node_labels.append(atom_symbol) 43 | 44 | edge_list = [] 45 | for i in range(num_bonds): 46 | line = file.readline() 47 | u = int(line[:3]) - 1 48 | v = int(line[3:6]) - 1 49 | edge_list.append((u, v)) 50 | 51 | while line != '': 52 | line = file.readline() 53 | if line[:4] == "$$$$": 54 | break 55 | 56 | G = dict() 57 | G['node_labels'] = node_labels 58 | G['edge_list'] = edge_list 59 | G['gdb_id'] = gdb_id 60 | G['num_nodes'] = num_atoms 61 | G['num_edges'] = num_bonds 62 | 63 | return G 64 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from torch.optim.lr_scheduler import _LRScheduler 4 | 5 | 6 | def load_model(path): 7 | model_filename = path 8 | model_state = torch.load(model_filename) 9 | return model_state 10 | 11 | 12 | def load_partial_model_state(model_state_dict, pretrain_model_state_dict): 13 | 14 | new_pretrain_model_state_dict = dict() 15 | for param_name, param in pretrain_model_state_dict.items(): 16 | if param_name in model_state_dict and param.size() == model_state_dict[param_name].size(): 17 | new_pretrain_model_state_dict[param_name] = param 18 | 19 | model_state_dict.update(new_pretrain_model_state_dict) 20 | return model_state_dict 21 | 22 | 23 | def save_model(model, optimizer, path, device): 24 | 25 | model.cpu() 26 | for state in optimizer.state.values(): 27 | for k, v in state.items(): 28 | if torch.is_tensor(v): 29 | state[k] = v.cpu() 30 | 31 | model_state = { 32 | 'model_state_dict': model.state_dict(), 33 | 'optimizer_state_dict': optimizer.state_dict(), 34 | } 35 | torch.save(model_state, path) 36 | 37 | model.to(device) 38 | for state in optimizer.state.values(): 39 | for k, v in state.items(): 40 | if torch.is_tensor(v): 41 | state[k] = v.to(device) 42 | 43 | 44 | class NoamLR(_LRScheduler): 45 | """ 46 | Noam learning rate scheduler with piecewise linear increase and exponential decay. 47 | 48 | The learning rate increases linearly from init_lr to max_lr over the course of 49 | the first warmup_steps (where warmup_steps = warmup_epochs * steps_per_epoch). 50 | Then the learning rate decreases exponentially from max_lr to final_lr over the 51 | course of the remaining total_steps - warmup_steps (where total_steps = 52 | total_epochs * steps_per_epoch). This is roughly based on the learning rate 53 | schedule from Attention is All You Need, section 5.3 (https://arxiv.org/abs/1706.03762). 54 | """ 55 | def __init__(self, optimizer, warmup_epochs, total_epochs, steps_per_epoch, init_lr, max_lr, final_lr): 56 | """ 57 | Initializes the learning rate scheduler. 58 | 59 | :param optimizer: A PyTorch optimizer. 60 | :param warmup_epochs: The number of epochs during which to linearly increase the learning rate. 61 | :param total_epochs: The total number of epochs. 62 | :param steps_per_epoch: The number of steps (batches) per epoch. 63 | :param init_lr: The initial learning rate. 64 | :param max_lr: The maximum learning rate (achieved after warmup_epochs). 65 | :param final_lr: The final learning rate (achieved after total_epochs). 
66 | """ 67 | assert len(optimizer.param_groups) == len(warmup_epochs) == len(total_epochs) == len(init_lr) == \ 68 | len(max_lr) == len(final_lr) 69 | 70 | self.num_lrs = len(optimizer.param_groups) 71 | 72 | self.optimizer = optimizer 73 | self.warmup_epochs = np.array(warmup_epochs) 74 | self.total_epochs = np.array(total_epochs) 75 | self.steps_per_epoch = steps_per_epoch 76 | self.init_lr = np.array(init_lr) 77 | self.max_lr = np.array(max_lr) 78 | self.final_lr = np.array(final_lr) 79 | 80 | self.current_step = 0 81 | self.lr = init_lr 82 | self.warmup_steps = (self.warmup_epochs * self.steps_per_epoch).astype(int) 83 | self.total_steps = self.total_epochs * self.steps_per_epoch 84 | self.linear_increment = (self.max_lr - self.init_lr) / self.warmup_steps 85 | 86 | self.exponential_gamma = (self.final_lr / self.max_lr) ** (1 / (self.total_steps - self.warmup_steps)) 87 | 88 | super(NoamLR, self).__init__(optimizer) 89 | 90 | def get_lr(self): 91 | """Gets a list of the current learning rates.""" 92 | return list(self.lr) 93 | 94 | def step(self, current_step=None): 95 | """ 96 | Updates the learning rate by taking a step. 97 | 98 | :param current_step: Optionally specify what step to set the learning rate to. 99 | If None, current_step = self.current_step + 1. 100 | """ 101 | if current_step is not None: 102 | self.current_step = current_step 103 | else: 104 | self.current_step += 1 105 | 106 | for i in range(self.num_lrs): 107 | if self.current_step <= self.warmup_steps[i]: 108 | self.lr[i] = self.init_lr[i] + self.current_step * self.linear_increment[i] 109 | elif self.current_step <= self.total_steps[i]: 110 | self.lr[i] = self.max_lr[i] * (self.exponential_gamma[i] ** (self.current_step - self.warmup_steps[i])) 111 | else: # theoretically this case should never be reached since training should stop at total_steps 112 | self.lr[i] = self.final_lr[i] 113 | 114 | self.optimizer.param_groups[i]['lr'] = self.lr[i] 115 | --------------------------------------------------------------------------------