├── gcn_lib ├── dense │ ├── __init__.py │ ├── torch_nn.py │ ├── torch_edge.py │ └── torch_vertex.py └── sparse │ ├── __init__.py │ ├── __pycache__ │ ├── __init__.cpython-37.pyc │ ├── torch_nn.cpython-37.pyc │ ├── torch_edge.cpython-37.pyc │ ├── torch_message.cpython-37.pyc │ └── torch_vertex.cpython-37.pyc │ ├── torch_nn.py │ ├── torch_edge.py │ ├── torch_message.py │ └── torch_vertex.py ├── utils ├── __pycache__ │ ├── loss.cpython-37.pyc │ ├── optim.cpython-37.pyc │ ├── logger.cpython-37.pyc │ ├── metrics.cpython-37.pyc │ ├── __init__.cpython-37.pyc │ ├── ckpt_util.cpython-37.pyc │ └── data_util.cpython-37.pyc ├── __init__.py ├── metrics.py ├── loss.py ├── logger.py ├── tf_logger.py ├── ckpt_util.py ├── optim.py ├── pc_viz.py └── data_util.py ├── __init__.py ├── extract_fingerprint.py ├── README.md ├── random_forest.py ├── args.py ├── main.py ├── model.py ├── model_att.py └── finetune.py /gcn_lib/dense/__init__.py: -------------------------------------------------------------------------------- 1 | from .torch_nn import * 2 | from .torch_edge import * 3 | from .torch_vertex import * 4 | 5 | -------------------------------------------------------------------------------- /gcn_lib/sparse/__init__.py: -------------------------------------------------------------------------------- 1 | from .torch_nn import * 2 | from .torch_edge import * 3 | from .torch_vertex import * 4 | 5 | -------------------------------------------------------------------------------- /utils/__pycache__/loss.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yzhuoning/DeepAUC_OGB_Challenge/HEAD/utils/__pycache__/loss.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/optim.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/yzhuoning/DeepAUC_OGB_Challenge/HEAD/utils/__pycache__/optim.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/logger.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yzhuoning/DeepAUC_OGB_Challenge/HEAD/utils/__pycache__/logger.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/metrics.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yzhuoning/DeepAUC_OGB_Challenge/HEAD/utils/__pycache__/metrics.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yzhuoning/DeepAUC_OGB_Challenge/HEAD/utils/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/ckpt_util.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yzhuoning/DeepAUC_OGB_Challenge/HEAD/utils/__pycache__/ckpt_util.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/data_util.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yzhuoning/DeepAUC_OGB_Challenge/HEAD/utils/__pycache__/data_util.cpython-37.pyc -------------------------------------------------------------------------------- /gcn_lib/sparse/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- 
# https://raw.githubusercontent.com/yzhuoning/DeepAUC_OGB_Challenge/HEAD/gcn_lib/sparse/__pycache__/__init__.cpython-37.pyc
# ==== /gcn_lib/sparse/__pycache__/torch_nn.cpython-37.pyc ====
# https://raw.githubusercontent.com/yzhuoning/DeepAUC_OGB_Challenge/HEAD/gcn_lib/sparse/__pycache__/torch_nn.cpython-37.pyc
# ==== /gcn_lib/sparse/__pycache__/torch_edge.cpython-37.pyc ====
# https://raw.githubusercontent.com/yzhuoning/DeepAUC_OGB_Challenge/HEAD/gcn_lib/sparse/__pycache__/torch_edge.cpython-37.pyc
# ==== /gcn_lib/sparse/__pycache__/torch_message.cpython-37.pyc ====
# https://raw.githubusercontent.com/yzhuoning/DeepAUC_OGB_Challenge/HEAD/gcn_lib/sparse/__pycache__/torch_message.cpython-37.pyc
# ==== /gcn_lib/sparse/__pycache__/torch_vertex.cpython-37.pyc ====
# https://raw.githubusercontent.com/yzhuoning/DeepAUC_OGB_Challenge/HEAD/gcn_lib/sparse/__pycache__/torch_vertex.cpython-37.pyc
# ==== /__init__.py ====
# Package bootstrap: append an ancestor directory of this file to sys.path
# so absolute imports resolve from there.
import sys
import os

# Start from this file's absolute path and strip four trailing components.
# NOTE(review): per the directory listing this __init__.py sits at the
# repository root, so four dirname() steps land three directories *above*
# the repository — probably inherited from a more deeply nested copy;
# confirm the intended root before relying on it.
ROOT_DIR = os.path.abspath(__file__)
for _ in range(4):
    ROOT_DIR = os.path.dirname(ROOT_DIR)
del _
sys.path.append(ROOT_DIR)
# ==== /utils/__init__.py ====
from .ckpt_util import *
# from .data_util import *
from .loss import *
from .metrics import *
from .optim import *
# from .tf_logger import *

# ==== /utils/metrics.py ====
from math import log10


def PSNR(mse, peak=1.):
    """Return the peak signal-to-noise ratio in dB: ``10*log10(peak**2 / mse)``.

    Args:
        mse: mean squared error between signal and reconstruction.
        peak: maximum possible signal value.

    Returns:
        float: PSNR in decibels; ``inf`` when ``mse`` is 0 (identical signals)
        instead of raising ``ZeroDivisionError``.
    """
    if mse == 0:
        return float('inf')
    return 10 * log10((peak ** 2) / mse)


class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the last value, running sum, count and average."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record *val* as the latest value, weighted *n* times in the average."""
        self.val = val
        self.sum += val * n
        self.count += n
        # An n=0 update on a fresh meter would otherwise divide by zero.
        if self.count:
            self.avg = self.sum / self.count


# ==== /utils/loss.py ====
import torch
import torch.nn.functional as F


class SmoothCrossEntropy(torch.nn.Module):
    """Cross-entropy with optional label smoothing.

    With smoothing, the target keeps ``1 - eps`` on the true class and spreads
    ``eps`` uniformly over the remaining ``n_class - 1`` classes.
    """

    def __init__(self, smoothing=True, eps=0.2):
        super(SmoothCrossEntropy, self).__init__()
        self.smoothing = smoothing
        self.eps = eps

    def forward(self, pred, gt):
        """Return the mean loss.

        Args:
            pred: logits; class dimension is dim 1.
            gt: integer class indices, flattened to 1-D here.
        """
        gt = gt.contiguous().view(-1)

        if self.smoothing:
            n_class = pred.size(1)
            one_hot = torch.zeros_like(pred).scatter(1, gt.view(-1, 1), 1)
            one_hot = one_hot * (1 - self.eps) + (1 - one_hot) * self.eps / (n_class - 1)
            log_prb = F.log_softmax(pred, dim=1)

            loss = -(one_hot * log_prb).sum(dim=1).mean()
        else:
            loss = F.cross_entropy(pred, gt, reduction='mean')

        return loss

# ==== /utils/logger.py ====
import os
import shutil
import csv


def save_best_result(list_of_dict, file_name, dir_path='best_result'):
    """Append one CSV row per dict to ``<dir_path>/<file_name>.csv``.

    Args:
        list_of_dict: sequence of dicts; each dict's ``values()`` form one row
            (no header row is written).
        file_name: base name of the CSV file, without extension.
        dir_path: directory holding the file; created if missing.
    """
    # makedirs(exist_ok=True) also creates missing parents and tolerates a
    # directory that already exists.
    os.makedirs(dir_path, exist_ok=True)
    csv_file_name = '{}/{}.csv'.format(dir_path, file_name)
    # newline='' is required by the csv module to avoid doubled line endings.
    with open(csv_file_name, 'a+', newline='') as csv_file:
        csv_writer = csv.writer(csv_file)
        for row in list_of_dict:
            csv_writer.writerow(row.values())


def create_exp_dir(path, scripts_to_save=None):
    """Create experiment directory *path* and optionally snapshot scripts.

    Args:
        path: experiment directory; created (with parents) if missing.
        scripts_to_save: optional iterable of file paths copied into
            ``<path>/scripts/`` under their base names.
    """
    os.makedirs(path, exist_ok=True)

    if scripts_to_save is not None:
        scripts_dir = os.path.join(path, 'scripts')
        # exist_ok: reusing an experiment dir previously raised FileExistsError.
        os.makedirs(scripts_dir, exist_ok=True)
        for script in scripts_to_save:
            dst_file = os.path.join(scripts_dir, os.path.basename(script))
            shutil.copyfile(script, dst_file)


# ==== /extract_fingerprint.py ====
import os
import argparse
import numpy as np
import pandas as pd
from tqdm import tqdm

from rdkit.Chem import AllChem
from ogb.graphproppred import GraphPropPredDataset


def getmorganfingerprint(mol):
    """Return the radius-2 Morgan bit-vector fingerprint of *mol* as a list."""
    return list(AllChem.GetMorganFingerprintAsBitVect(mol, 2))


def getmaccsfingerprint(mol):
    """Return the MACCS keys fingerprint of *mol* as a list of 0/1 ints."""
    fp = AllChem.GetMACCSKeysFingerprint(mol)
    return [int(b) for b in fp.ToBitString()]


def main(dataset_name):
    """Compute Morgan and MACCS fingerprints for every molecule of an OGB dataset.

    Saves ``mgf_feat.npy`` and ``maccs_feat.npy`` (int64, one row per SMILES)
    under ``./dataset/<dataset_name with '-' -> '_'>``.

    Raises:
        ValueError: if RDKit cannot parse one of the SMILES strings.
    """
    # The returned object is unused; the call presumably downloads/processes
    # the dataset so the SMILES mapping file read below exists — confirm.
    GraphPropPredDataset(name=dataset_name)

    df_smi = pd.read_csv(f"dataset/{dataset_name}/mapping/mol.csv.gz".replace("-", "_"))
    smiles = df_smi["smiles"]

    mgf_feat_list = []
    maccs_feat_list = []
    for ii, smi in enumerate(tqdm(smiles)):
        rdkit_mol = AllChem.MolFromSmiles(smi)
        # MolFromSmiles returns None on failure; fail with a clear message
        # rather than an opaque error from the fingerprint call.
        if rdkit_mol is None:
            raise ValueError(f"RDKit could not parse SMILES at row {ii}: {smi!r}")

        mgf_feat_list.append(getmorganfingerprint(rdkit_mol))
        maccs_feat_list.append(getmaccsfingerprint(rdkit_mol))

    mgf_feat = np.array(mgf_feat_list, dtype="int64")
    maccs_feat = np.array(maccs_feat_list, dtype="int64")
    print("morgan feature shape: ", mgf_feat.shape)
    print("maccs feature shape: ", maccs_feat.shape)

    save_path = f"./dataset/{dataset_name}".replace("-", "_")
    print("saving feature in %s" % save_path)
    np.save(os.path.join(save_path, "mgf_feat.npy"), mgf_feat)
    np.save(os.path.join(save_path, "maccs_feat.npy"), maccs_feat)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='gnn')
    parser.add_argument("--dataset_name", type=str, default="ogbg-molhiv")
    args = parser.parse_args()

    main(args.dataset_name)


# ==== /gcn_lib/dense/torch_nn.py ====
import torch
from torch import nn
from torch.nn import Sequential as Seq, Linear as Lin, Conv2d


##############################
#    Basic layers
##############################
def act_layer(act, inplace=False, neg_slope=0.2, n_prelu=1):
    """Build an activation module from its case-insensitive name.

    Args:
        act: ``'relu'``, ``'leakyrelu'`` or ``'prelu'``.
        inplace: forwarded to ReLU / LeakyReLU.
        neg_slope: LeakyReLU negative slope / PReLU initial slope.
        n_prelu: number of learnable PReLU slopes.

    Raises:
        NotImplementedError: for an unknown activation name.
    """
    act = act.lower()
    if act == 'relu':
        layer = nn.ReLU(inplace)
    elif act == 'leakyrelu':
        layer = nn.LeakyReLU(neg_slope, inplace)
    elif act == 'prelu':
        layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope)
    else:
        raise NotImplementedError('activation layer [%s] is not found' % act)
    return layer


def norm_layer(norm, nc):
    """Build a 2-D normalization layer (``'batch'`` or ``'instance'``) for *nc* channels.

    Raises:
        NotImplementedError: for an unknown normalization name.
    """
    norm = norm.lower()
    if norm == 'batch':
        layer = nn.BatchNorm2d(nc, affine=True)
    elif norm == 'instance':
        layer = nn.InstanceNorm2d(nc, affine=False)
    else:
        raise NotImplementedError('normalization layer [%s] is not found' % norm)
    return layer


class MLP(Seq):
    """Stack of ``Linear -> [act] -> [norm]`` for consecutive sizes in *channels*.

    NOTE(review): the norm layers are 2-D (BatchNorm2d / InstanceNorm2d), so
    this presumably expects 4-D input to the norm — confirm with callers.
    """

    def __init__(self, channels, act='relu', norm=None, bias=True):
        m = []
        for i in range(1, len(channels)):
            m.append(Lin(channels[i - 1], channels[i], bias))
            if act is not None and act.lower() != 'none':
                m.append(act_layer(act))
            if norm is not None and norm.lower() != 'none':
                # Size the norm to this layer's output (was channels[-1],
                # which breaks any MLP whose hidden widths differ).
                m.append(norm_layer(norm, channels[i]))
        super(MLP, self).__init__(*m)


class BasicConv(Seq):
    """Stack of 1x1 ``Conv2d -> [act] -> [norm] -> [Dropout2d]`` blocks."""

    def __init__(self, channels, act='relu', norm=None, bias=True, drop=0.):
        m = []
        for i in range(1, len(channels)):
            m.append(Conv2d(channels[i - 1], channels[i], 1, bias=bias))
            if act is not None and act.lower() != 'none':
                m.append(act_layer(act))
            if norm is not None and norm.lower() != 'none':
                # Match this conv's output channels (was channels[-1]).
                m.append(norm_layer(norm, channels[i]))
            if drop > 0:
                m.append(nn.Dropout2d(drop))

        super(BasicConv, self).__init__(*m)

        self.reset_parameters()

    def reset_parameters(self):
        """Kaiming-normal conv weights, zero conv biases, unit/zero norm affine."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d)):
                # InstanceNorm2d(affine=False) has weight/bias set to None;
                # dereferencing them crashed with norm='instance'.
                if m.weight is not None:
                    m.weight.data.fill_(1)
                if m.bias is not None:
                    m.bias.data.zero_()


def batched_index_select(x, idx):
    r"""fetches neighbors features from a given neighbor idx

    Args:
        x (Tensor): input feature Tensor
                :math:`\mathbf{X} \in \mathbb{R}^{B \times C \times N \times 1}`.
        idx (Tensor): edge_idx
                :math:`\mathbf{X} \in \mathbb{R}^{B \times N \times l}`.
    Returns:
        Tensor: output neighbors features
            :math:`\mathbf{X} \in \mathbb{R}^{B \times C \times N \times k}`.
    """
    batch_size, num_dims, num_vertices = x.shape[:3]
    k = idx.shape[-1]
    # Offset each sample's indices so they address the flattened (B*N) rows.
    idx_base = torch.arange(0, batch_size, device=idx.device).view(-1, 1, 1) * num_vertices
    idx = idx + idx_base
    idx = idx.contiguous().view(-1)

    x = x.transpose(2, 1)
    feature = x.contiguous().view(batch_size * num_vertices, -1)[idx, :]
    feature = feature.view(batch_size, num_vertices, k, num_dims).permute(0, 3, 1, 2).contiguous()
    return feature


# ==== /gcn_lib/dense/torch_edge.py ====
import torch
from torch import nn
from torch_cluster import knn_graph


class DenseDilated(nn.Module):
    """
    Find dilated neighbor from neighbor list

    edge_index: (2, batch_size, num_points, k)
    """
    def __init__(self, k=9, dilation=1, stochastic=False, epsilon=0.0):
        super(DenseDilated, self).__init__()
        self.dilation = dilation
        self.stochastic = stochastic
        self.epsilon = epsilon
        self.k = k

    def forward(self, edge_index):
        """Keep every ``dilation``-th neighbor; in training with probability
        ``epsilon`` (stochastic mode) keep a random subset of ``k`` instead."""
        if self.stochastic:
            if torch.rand(1) < self.epsilon and self.training:
                num = self.k * self.dilation
                randnum = torch.randperm(num)[:self.k]
                edge_index = edge_index[:, :, :, randnum]
            else:
                edge_index = edge_index[:, :, :, ::self.dilation]
        else:
            edge_index = edge_index[:, :, :, ::self.dilation]
        return edge_index


def pairwise_distance(x):
    """
    Compute pairwise squared Euclidean distance of a point cloud.
    Args:
        x: tensor (batch_size, num_points, num_dims)
    Returns:
        pairwise distance: (batch_size, num_points, num_points)
    """
    # ||a||^2 - 2 a.b + ||b||^2
    x_inner = -2*torch.matmul(x, x.transpose(2, 1))
    x_square = torch.sum(torch.mul(x, x), dim=-1, keepdim=True)
    return x_square + x_inner + x_square.transpose(2, 1)


def dense_knn_matrix(x, k=16):
    """Get KNN based on the pairwise distance.
    Args:
        x: (batch_size, num_dims, num_points, 1)
        k: int
    Returns:
        edge_index: (2, batch_size, num_points, k) — index 0 holds the
        neighbor indices, index 1 the matching center indices.
    """
    with torch.no_grad():
        x = x.transpose(2, 1).squeeze(-1)
        batch_size, n_points, n_dims = x.shape
        # Largest negated distance == nearest; each point typically finds itself.
        _, nn_idx = torch.topk(-pairwise_distance(x.detach()), k=k)
        center_idx = torch.arange(0, n_points, device=x.device).repeat(batch_size, k, 1).transpose(2, 1)
    return torch.stack((nn_idx, center_idx), dim=0)


class DenseDilatedKnnGraph(nn.Module):
    """
    Find the neighbors' indices based on dilated knn
    """
    def __init__(self, k=9, dilation=1, stochastic=False, epsilon=0.0):
        super(DenseDilatedKnnGraph, self).__init__()
        self.dilation = dilation
        self.stochastic = stochastic
        self.epsilon = epsilon
        self.k = k
        self._dilated = DenseDilated(k, dilation, stochastic, epsilon)
        self.knn = dense_knn_matrix

    def forward(self, x):
        edge_index = self.knn(x, self.k * self.dilation)
        return self._dilated(edge_index)


class DilatedKnnGraph(nn.Module):
    """
    Find the neighbors' indices based on dilated knn (per-sample torch_cluster knn)
    """
    def __init__(self, k=9, dilation=1, stochastic=False, epsilon=0.0):
        super(DilatedKnnGraph, self).__init__()
        self.dilation = dilation
        self.stochastic = stochastic
        self.epsilon = epsilon
        self.k = k
        self._dilated = DenseDilated(k, dilation, stochastic, epsilon)
        self.knn = knn_graph

    def forward(self, x):
        x = x.squeeze(-1)
        B, C, N = x.shape
        edge_index = []
        for i in range(B):
            edgeindex = self.knn(x[i].contiguous().transpose(1, 0).contiguous(), self.k * self.dilation)
            edgeindex = edgeindex.view(2, N, self.k * self.dilation)
            edge_index.append(edgeindex)
        edge_index = torch.stack(edge_index, dim=1)
        return self._dilated(edge_index)


# ==== /gcn_lib/sparse/torch_nn.py ====
from torch import nn
from torch.nn import Sequential as Seq, Linear as Lin
from utils.data_util import get_atom_feature_dims, get_bond_feature_dims


##############################
#    Basic layers
##############################
def act_layer(act_type, inplace=False, neg_slope=0.2, n_prelu=1):
    """Build an activation module (``'relu'``, ``'leakyrelu'``, ``'prelu'``).

    Raises:
        NotImplementedError: for an unknown activation name.
    """
    act = act_type.lower()
    if act == 'relu':
        layer = nn.ReLU(inplace)
    elif act == 'leakyrelu':
        layer = nn.LeakyReLU(neg_slope, inplace)
    elif act == 'prelu':
        layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope)
    else:
        raise NotImplementedError('activation layer [%s] is not found' % act)
    return layer


def norm_layer(norm_type, nc):
    """Build a 1-D normalization layer (``'batch'``, ``'layer'``, ``'instance'``).

    Raises:
        NotImplementedError: for an unknown normalization name.
    """
    norm = norm_type.lower()
    if norm == 'batch':
        layer = nn.BatchNorm1d(nc, affine=True)
    elif norm == 'layer':
        layer = nn.LayerNorm(nc, elementwise_affine=True)
    elif norm == 'instance':
        layer = nn.InstanceNorm1d(nc, affine=False)
    else:
        raise NotImplementedError('normalization layer [%s] is not found' % norm)
    return layer


class MultiSeq(Seq):
    """Sequential that unpacks tuple outputs as positional args to the next module."""

    def __init__(self, *args):
        super(MultiSeq, self).__init__(*args)

    def forward(self, *inputs):
        for module in self._modules.values():
            if type(inputs) == tuple:
                inputs = module(*inputs)
            else:
                inputs = module(inputs)
        return inputs


class MLP(Seq):
    """Stack of ``Linear -> [norm] -> [act] -> [dropout]`` blocks.

    With ``last_lin=True`` the final block is a bare Linear layer.
    """

    def __init__(self, channels, act='relu',
                 norm=None, bias=True,
                 drop=0., last_lin=False):
        m = []

        for i in range(1, len(channels)):

            m.append(Lin(channels[i - 1], channels[i], bias))

            if (i == len(channels) - 1) and last_lin:
                pass
            else:
                if norm is not None and norm.lower() != 'none':
                    m.append(norm_layer(norm, channels[i]))
                if act is not None and act.lower() != 'none':
                    m.append(act_layer(act))
                if drop > 0:
                    # Element-wise dropout on (num_nodes, channels) features.
                    # Dropout2d on 2-D input was element-wise in old PyTorch
                    # but newer releases treat it as unbatched (C, L) and drop
                    # whole rows; nn.Dropout keeps the element-wise behavior.
                    m.append(nn.Dropout(drop))

        self.m = m
        super(MLP, self).__init__(*self.m)


class AtomEncoder(nn.Module):
    """Sum of per-column embeddings of integer atom features."""

    def __init__(self, emb_dim):
        super(AtomEncoder, self).__init__()

        self.atom_embedding_list = nn.ModuleList()
        full_atom_feature_dims = get_atom_feature_dims()

        for i, dim in enumerate(full_atom_feature_dims):
            emb = nn.Embedding(dim, emb_dim)
            nn.init.xavier_uniform_(emb.weight.data)
            self.atom_embedding_list.append(emb)

    def forward(self, x):
        x_embedding = 0
        for i in range(x.shape[1]):
            x_embedding += self.atom_embedding_list[i](x[:, i])

        return x_embedding


class BondEncoder(nn.Module):
    """Sum of per-column embeddings of integer bond features."""

    def __init__(self, emb_dim):
        super(BondEncoder, self).__init__()

        self.bond_embedding_list = nn.ModuleList()
        full_bond_feature_dims = get_bond_feature_dims()

        for i, dim in enumerate(full_bond_feature_dims):
            emb = nn.Embedding(dim, emb_dim)
            nn.init.xavier_uniform_(emb.weight.data)
            self.bond_embedding_list.append(emb)

    def forward(self, edge_attr):
        bond_embedding = 0
        for i in range(edge_attr.shape[1]):
            bond_embedding += self.bond_embedding_list[i](edge_attr[:, i])

        return bond_embedding


# ==== /gcn_lib/sparse/torch_edge.py ====
import torch
from torch import nn
from torch_cluster import knn_graph


class Dilated(nn.Module):
    """
    Find dilated neighbor from neighbor list
    """
    def __init__(self, k=9, dilation=1, stochastic=False, epsilon=0.0):
        super(Dilated, self).__init__()
        self.dilation = dilation
        self.stochastic = stochastic
        self.epsilon = epsilon
        self.k = k

    def forward(self, edge_index, batch=None):
        """Keep every ``dilation``-th edge of a flat (2, E) edge list; in
        stochastic training mode, with probability ``epsilon``, keep a random
        subset of ``k`` out of each ``k*dilation`` consecutive edges instead."""
        if self.stochastic:
            if torch.rand(1) < self.epsilon and self.training:
                num = self.k * self.dilation
                randnum = torch.randperm(num)[:self.k]
                edge_index = edge_index.view(2, -1, num)
                edge_index = edge_index[:, :, randnum]
                return edge_index.view(2, -1)
            else:
                edge_index = edge_index[:, ::self.dilation]
        else:
            edge_index = edge_index[:, ::self.dilation]
        return edge_index


class DilatedKnnGraph(nn.Module):
    """
    Find the neighbors' indices based on dilated knn
    """
    def __init__(self, k=9, dilation=1, stochastic=False, epsilon=0.0, knn='matrix'):
        super(DilatedKnnGraph, self).__init__()
        self.dilation = dilation
        self.stochastic = stochastic
        self.epsilon = epsilon
        self.k = k
        self._dilated = Dilated(k, dilation, stochastic, epsilon)
        if knn == 'matrix':
            self.knn = knn_graph_matrix
        else:
            self.knn = knn_graph

    def forward(self, x, batch):
        edge_index = self.knn(x, self.k * self.dilation, batch)
        return self._dilated(edge_index, batch)


def pairwise_distance(x):
    """
    Compute pairwise squared Euclidean distance of a point cloud.
    Args:
        x: tensor (batch_size, num_points, num_dims)
    Returns:
        pairwise distance: (batch_size, num_points, num_points)
    """
    x_inner = -2*torch.matmul(x, x.transpose(2, 1))
    x_square = torch.sum(torch.mul(x, x), dim=-1, keepdim=True)
    return x_square + x_inner + x_square.transpose(2, 1)


def knn_matrix(x, k=16, batch=None):
    """Get KNN based on the pairwise distance.

    Args:
        x: (num_points_total, num_dims) point features.
        k: int
        batch: optional (num_points_total,) graph ids; the ``view`` below
            requires every graph in the batch to have the same point count.
    Returns:
        (nn_idx, center_idx): each (1, num_points_total*k), global indices.
    """
    with torch.no_grad():
        batch_size = 1 if batch is None else int(batch[-1]) + 1
        x = x.view(batch_size, -1, x.shape[-1])

        neg_adj = -pairwise_distance(x.detach())
        _, nn_idx = torch.topk(neg_adj, k=k)
        del neg_adj

        n_points = x.shape[1]
        # Allocate on x's own device: .cuda() picked the *current* device and
        # broke inputs living on any other GPU.
        start_idx = torch.arange(0, n_points*batch_size, n_points,
                                 device=x.device).long().view(batch_size, 1, 1)
        nn_idx += start_idx
        del start_idx

        if x.is_cuda:
            torch.cuda.empty_cache()

        nn_idx = nn_idx.view(1, -1)
        center_idx = torch.arange(0, n_points*batch_size, device=x.device) \
            .repeat(k, 1).transpose(1, 0).contiguous().view(1, -1)
    return nn_idx, center_idx


def knn_graph_matrix(x, k=16, batch=None):
    """Construct edge feature for each point
    Args:
        x: (num_points, num_dims)
        batch: (num_points, )
        k: int
    Returns:
        edge_index: (2, num_points*k)
    """
    nn_idx, center_idx = knn_matrix(x, k, batch)
    return torch.cat((nn_idx, center_idx), dim=0)


# ==== /gcn_lib/sparse/torch_message.py ====
import torch
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing
from torch_scatter import scatter, scatter_softmax
from torch_geometric.utils import degree


class GenMessagePassing(MessagePassing):
    """MessagePassing with generalized softmax / power-mean aggregators.

    ``aggr`` may be ``'softmax_sg'``, ``'softmax'``, ``'softmax_sum'``,
    ``'power'``, ``'power_sum'``, or any aggregator accepted by the base class.
    ``*_sum`` variants multiply the result by ``degree ** sigmoid(y)``.
    """

    def __init__(self, aggr='softmax',
                 t=1.0, learn_t=False,
                 p=1.0, learn_p=False,
                 y=0.0, learn_y=False):

        if aggr in ['softmax_sg', 'softmax', 'softmax_sum']:

            super(GenMessagePassing, self).__init__(aggr=None)
            self.aggr = aggr

            if learn_t and (aggr == 'softmax' or aggr == 'softmax_sum'):
                self.learn_t = True
                self.t = torch.nn.Parameter(torch.Tensor([t]), requires_grad=True)
            else:
                self.learn_t = False
                self.t = t

            if aggr == 'softmax_sum':
                self.y = torch.nn.Parameter(torch.Tensor([y]), requires_grad=learn_y)

        elif aggr in ['power', 'power_sum']:

            super(GenMessagePassing, self).__init__(aggr=None)
            self.aggr = aggr

            if learn_p:
                self.p = torch.nn.Parameter(torch.Tensor([p]), requires_grad=True)
            else:
                self.p = p

            if aggr == 'power_sum':
                self.y = torch.nn.Parameter(torch.Tensor([y]), requires_grad=learn_y)
        else:
            super(GenMessagePassing, self).__init__(aggr=aggr)

    def aggregate(self, inputs, index, ptr=None, dim_size=None):

        if self.aggr in ['add', 'mean', 'max', None]:
            return super(GenMessagePassing, self).aggregate(inputs, index, ptr, dim_size)

        elif self.aggr in ['softmax_sg', 'softmax', 'softmax_sum']:

            # Per-target softmax of t*message; without a learnable t the
            # weights are computed under no_grad (treated as constants).
            if self.learn_t:
                out = scatter_softmax(inputs*self.t, index, dim=self.node_dim)
            else:
                with torch.no_grad():
                    out = scatter_softmax(inputs*self.t, index, dim=self.node_dim)

            out = scatter(inputs*out, index, dim=self.node_dim,
                          dim_size=dim_size, reduce='sum')

            if self.aggr == 'softmax_sum':
                self.sigmoid_y = torch.sigmoid(self.y)
                degrees = degree(index, num_nodes=dim_size).unsqueeze(1)
                out = torch.pow(degrees, self.sigmoid_y) * out

            return out

        elif self.aggr in ['power', 'power_sum']:
            min_value, max_value = 1e-7, 1e1
            # Out-of-place clamps: the previous in-place clamp_ mutated the
            # autograd message tensor and can trip "modified by an inplace
            # operation" errors. Forward values and gradients are unchanged.
            inputs = torch.clamp(inputs, min_value, max_value)
            out = scatter(torch.pow(inputs, self.p), index, dim=self.node_dim,
                          dim_size=dim_size, reduce='mean')
            out = torch.clamp(out, min_value, max_value)
            out = torch.pow(out, 1/self.p)

            if self.aggr == 'power_sum':
                self.sigmoid_y = torch.sigmoid(self.y)
                degrees = degree(index, num_nodes=dim_size).unsqueeze(1)
                out = torch.pow(degrees, self.sigmoid_y) * out

            return out

        else:
            raise NotImplementedError('To be implemented')


class MsgNorm(torch.nn.Module):
    """Rescale messages to unit p-norm times ``||x||_p`` times a (learnable) scale."""

    def __init__(self, learn_msg_scale=False):
        super(MsgNorm, self).__init__()

        self.msg_scale = torch.nn.Parameter(torch.Tensor([1.0]),
                                            requires_grad=learn_msg_scale)

    def forward(self, x, msg, p=2):
        msg = F.normalize(msg, p=p, dim=1)
        x_norm = x.norm(p=p, dim=1, keepdim=True)
        msg = msg * x_norm * self.msg_scale
        return msg


# ==== /utils/tf_logger.py ====
# Code referenced from https://gist.github.com/gyglim/1f8dfb1b5c82627ae3efcfbbadb9f514
try:
    import tensorflow as tf
    import tensorboard.plugins.mesh.summary as meshsummary
except ImportError:
    print('tensorflow is not installed.')
import numpy as np
import scipy.misc

# io.BytesIO exists on Python 2.6+ and 3.x; the old StringIO fallback left
# BytesIO undefined on Python 2.
from io import BytesIO


class TfLogger(object):
    """TensorBoard summary writer using the ``tf.compat.v1`` summary API throughout."""

    def __init__(self, log_dir):
        """Create a summary writer logging to log_dir."""
        self.writer = tf.compat.v1.summary.FileWriter(log_dir)

        # Camera and scene configuration.
        self.config_dict = {
            'camera': {'cls': 'PerspectiveCamera', 'fov': 75},
            'lights': [
                {
                    'cls': 'AmbientLight',
                    'color': '#ffffff',
                    'intensity': 0.75,
                }, {
                    'cls': 'DirectionalLight',
                    'color': '#ffffff',
                    'intensity': 0.75,
                    'position': [0, -1, 2],
                }],
            'material': {
                'cls': 'MeshStandardMaterial',
                'metalness': 0
            }
        }

    def scalar_summary(self, tag, value, step):
        """Log a scalar variable."""
        summary = tf.compat.v1.Summary(value=[tf.compat.v1.Summary.Value(tag=tag, simple_value=value)])
        self.writer.add_summary(summary, step)

    def image_summary(self, tag, images, step):
        """Log a list of images."""
        img_summaries = []
        for i, img in enumerate(images):
            # Write the image to a string
            s = BytesIO()
            # NOTE(review): scipy.misc.toimage was removed in SciPy 1.2; this
            # needs an older SciPy (or a PIL-based replacement).
            scipy.misc.toimage(img).save(s, format="png")

            # Create an Image object
            img_sum = tf.compat.v1.Summary.Image(encoded_image_string=s.getvalue(),
                                                 height=img.shape[0], width=img.shape[1])
            # Create a Summary value
            img_summaries.append(tf.compat.v1.Summary.Value(tag='%s/%d' % (tag, i), image=img_sum))

        # Create and write Summary (tf.Summary does not exist in TF2).
        summary = tf.compat.v1.Summary(value=img_summaries)
        self.writer.add_summary(summary, step)

    def mesh_summary(self, tag, vertices, faces=None, colors=None, step=0):
        """Log a list of mesh images."""
        if colors is None:
            colors = tf.constant(np.zeros_like(vertices))
        vertices = tf.constant(vertices)
        if faces is not None:
            faces = tf.constant(faces)
        meshes_summares = []
        # NOTE(review): every iteration builds the same op over the full
        # vertices tensor; per-mesh slicing may have been intended.
        for i in range(vertices.shape[0]):
            meshes_summares.append(meshsummary.op(
                tag, vertices=vertices, faces=faces, colors=colors, config_dict=self.config_dict))

        # NOTE(review): Session.run needs graph mode (TF1, or TF2 with eager
        # execution disabled).
        sess = tf.compat.v1.Session()
        summaries = sess.run(meshes_summares)
        for summary in summaries:
            self.writer.add_summary(summary, step)

    def histo_summary(self, tag, values, step, bins=1000):
        """Log a histogram of the tensor of values."""

        # Create a histogram using numpy
        counts, bin_edges = np.histogram(values, bins=bins)

        # Fill the fields of the histogram proto
        hist = tf.compat.v1.HistogramProto()
        hist.min = float(np.min(values))
        hist.max = float(np.max(values))
        hist.num = int(np.prod(values.shape))
        hist.sum = float(np.sum(values))
        hist.sum_squares = float(np.sum(values**2))

        # Drop the start of the first bin
        bin_edges = bin_edges[1:]

        # Add bin edges and counts
        for edge in bin_edges:
            hist.bucket_limit.append(edge)
        for c in counts:
            hist.bucket.append(c)

        # Create and write Summary
        summary = tf.compat.v1.Summary(value=[tf.compat.v1.Summary.Value(tag=tag, histo=hist)])
        self.writer.add_summary(summary, step)
        self.writer.flush()


# ==== /gcn_lib/dense/torch_vertex.py ====
import torch
from torch import nn
from .torch_nn import BasicConv, batched_index_select
from .torch_edge import DenseDilatedKnnGraph, DilatedKnnGraph
import torch.nn.functional as F


class MRConv2d(nn.Module):
    """
    Max-Relative Graph Convolution (Paper: https://arxiv.org/abs/1904.03751) for dense data type
    """
    def __init__(self, in_channels, out_channels, act='relu', norm=None, bias=True):
        super(MRConv2d, self).__init__()
        self.nn = BasicConv([in_channels*2, out_channels], act, norm, bias)

    def forward(self, x, edge_index):
        x_i = batched_index_select(x, edge_index[1])
        x_j = batched_index_select(x, edge_index[0])
        # Max over neighbors of (neighbor - center), concatenated with x.
        x_j, _ = torch.max(x_j - x_i, -1, keepdim=True)
        return self.nn(torch.cat([x, x_j], dim=1))


class EdgeConv2d(nn.Module):
    """
    Edge convolution layer (with activation, batch normalization) for dense data type
    """
    def __init__(self, in_channels, out_channels, act='relu', norm=None, bias=True):
        super(EdgeConv2d, self).__init__()
        self.nn = BasicConv([in_channels * 2, out_channels], act, norm, bias)

    def forward(self, x, edge_index):
        x_i = batched_index_select(x, edge_index[1])
        x_j = batched_index_select(x, edge_index[0])
        max_value, _ = torch.max(self.nn(torch.cat([x_i, x_j - x_i], dim=1)), -1, keepdim=True)
        return max_value


class GraphConv2d(nn.Module):
    """
    Static graph convolution layer
    """
    def __init__(self, in_channels, out_channels, conv='edge', act='relu', norm=None, bias=True):
        super(GraphConv2d, self).__init__()
        if conv == 'edge':
            self.gconv = EdgeConv2d(in_channels, out_channels, act, norm, bias)
        elif conv == 'mr':
            self.gconv = MRConv2d(in_channels, out_channels, act, norm, bias)
        else:
            raise NotImplementedError('conv:{} is not supported'.format(conv))

    def forward(self, x, edge_index):
        return self.gconv(x, edge_index)


class DynConv2d(GraphConv2d):
    """
    Dynamic graph convolution layer (graph rebuilt by dilated kNN each call)
    """
    def __init__(self, in_channels, out_channels, kernel_size=9, dilation=1, conv='edge', act='relu',
                 norm=None, bias=True, stochastic=False, epsilon=0.0, knn='matrix'):
        super(DynConv2d, self).__init__(in_channels, out_channels, conv, act, norm, bias)
        self.k = kernel_size
        self.d = dilation
        if knn == 'matrix':
            self.dilated_knn_graph = DenseDilatedKnnGraph(kernel_size, dilation, stochastic, epsilon)
        else:
            self.dilated_knn_graph = DilatedKnnGraph(kernel_size, dilation, stochastic, epsilon)

    def forward(self, x):
        edge_index = self.dilated_knn_graph(x)
        return super(DynConv2d, self).forward(x, edge_index)


class PlainDynBlock2d(nn.Module):
    """
    Plain Dynamic graph convolution block
    """
    def __init__(self, in_channels, kernel_size=9, dilation=1, conv='edge', act='relu', norm=None,
                 bias=True, stochastic=False, epsilon=0.0, knn='matrix'):
        super(PlainDynBlock2d, self).__init__()
        self.body = DynConv2d(in_channels, in_channels, kernel_size, dilation, conv,
                              act, norm, bias, stochastic, epsilon, knn)

    def forward(self, x):
        return self.body(x)


class ResDynBlock2d(nn.Module):
    """
    Residual Dynamic graph convolution block: body(x) + res_scale * x
    """
    def __init__(self, in_channels, kernel_size=9, dilation=1, conv='edge', act='relu', norm=None,
                 bias=True, stochastic=False, epsilon=0.0, knn='matrix', res_scale=1):
        super(ResDynBlock2d, self).__init__()
        self.body = DynConv2d(in_channels, in_channels, kernel_size, dilation, conv,
                              act, norm, bias, stochastic, epsilon, knn)
        self.res_scale = res_scale

    def forward(self, x):
        return self.body(x) + x*self.res_scale


class DenseDynBlock2d(nn.Module):
    """
    Dense Dynamic graph convolution block: concatenates x with body(x) on dim 1
    """
    def __init__(self, in_channels, out_channels=64, kernel_size=9, dilation=1, conv='edge',
                 act='relu', norm=None, bias=True, stochastic=False, epsilon=0.0, knn='matrix'):
        super(DenseDynBlock2d, self).__init__()
        self.body = DynConv2d(in_channels, out_channels, kernel_size, dilation, conv,
                              act, norm, bias, stochastic, epsilon, knn)

    def forward(self, x):
        dense = self.body(x)
        return torch.cat((x, dense), 1)


# ==== /README.md ====
# # Deep AUC Maximization on Graph Property Prediction
# This repo contains code submission for OGB challenge.
Here, we focus on [**ogbg-molhiv**](https://ogb.stanford.edu/docs/leader_graphprop/), a binary classification task that predicts a target molecular property, e.g., whether a molecule inhibits HIV replication or not. The evaluation metric is **AUROC**. To the best of our knowledge, this is the first solution that directly optimizes the AUC score for this task. Our [**AUC-Margin loss**](https://arxiv.org/abs/2012.03173) improves the baseline (DeepGCN) to **0.8159** and achieves SOTA performance of **0.8352** when jointly trained with Neural FingerPrints. Our approaches are implemented in **[LibAUC](https://github.com/Optimization-AI/LibAUC)**, an ML library for AUC optimization. 3 | 4 | ## Results on ogbg-molhiv 5 | **Our method ranks 1st on the leaderboard as of 10/11/2021!** We present our results on the ogbg-molhiv dataset alongside some strong baselines below: 6 | 7 | | Method |Test AUROC |Validation AUROC | Parameters | Hardware | 8 | | ------------------ |------------------- | ----------------- | -------------- |----------| 9 | | DeepGCN | 0.7858±0.0117 | 0.8427±0.0063 | 531,976 | Tesla V100 (32GB) | 10 | | DeeperGCN+FLAG | 0.7942±0.0120 | 0.8425±0.0061 | 531,976 | Tesla V100 (32GB) | 11 | | Neural FingerPrints| 0.8232±0.0047 | 0.8331±0.0054 | 2,425,102 | Tesla V100 (32GB) | 12 | | Graphormer | 0.8051±0.0053 | 0.8310±0.0089 | 47,183,040 | Tesla V100 (16GB) | 13 | | **DeepAUC (Ours)** | **0.8159±0.0059** | 0.8054±0.0080 | 1,019,407 | Tesla V100 (32GB) | 14 | | **DeepAUC+FPs (Ours)** | **0.8352±0.0054** | 0.8238±0.0061 | 1,019,407** | Tesla V100 (32GB) | 15 | 16 | - Note: the parameter count marked with ** does not include the parameters of the Random Forest model. 17 | 18 | 19 | 20 | ## Requirements 21 | 1. Install base packages: 22 | ```bash 23 | Python>=3.7 24 | Pytorch>=1.9.0 25 | tensorflow>=2.0.0 26 | pytorch_geometric>=1.6.0 27 | ogb>=1.3.2 28 | dgl>=0.5.3 29 | numpy==1.20.3 30 | pandas==1.2.5 31 | scikit-learn==0.24.2 32 | deep_gcns_torch 33 | ``` 34 | 2.
Install [**LibAUC**](https://github.com/Optimization-AI/LibAUC) (which provides the **AUC-Margin** loss and the **PESG** optimizer): 35 | ```bash 36 | pip install LibAUC 37 | ``` 38 | 39 | ## Training 40 | The training process has two steps: 1) we train a DeepGCN model from scratch using our **[AUC-margin loss](https://arxiv.org/abs/2012.03173)**; 2) we jointly fine-tune the pretrained model from (1) with FingerPrints models. 41 | ### Training from scratch using AUC-margin loss: 42 | - Train the [DeepGCN](https://github.com/lightaime/deep_gcns_torch) model with the AUC-Margin loss and the PESG optimizer using the default parameters: 43 | ``` 44 | python main.py --use_gpu --conv_encode_edge --num_layers 14 --block res+ --gcn_aggr softmax --t 1.0 --learn_t --dropout 0.2 \ 45 | --dataset ogbg-molhiv \ 46 | --loss auroc \ 47 | --optimizer pesg \ 48 | --batch_size 512 \ 49 | --lr 0.1 \ 50 | --gamma 500 \ 51 | --margin 1.0 \ 52 | --weight_decay 1e-5 \ 53 | --random_seed 0 \ 54 | --epochs 300 55 | ``` 56 | 57 | ### Jointly training with the FingerPrints Model 58 | - Extract fingerprints and train a Random Forest following [PaddleHelix](https://github.com/PaddlePaddle/PaddleHelix/tree/dev/competition/ogbg_molhiv): 59 | ``` 60 | python extract_fingerprint.py 61 | python random_forest.py 62 | ``` 63 | - Fine-tune the pretrained model with the FingerPrints model using the **[AUC-margin loss](https://arxiv.org/abs/2012.03173)** and the default parameters: 64 | ``` 65 | python finetune.py --use_gpu --conv_encode_edge --num_layers 14 --block res+ --gcn_aggr softmax --t 1.0 --learn_t --dropout 0.2 \ 66 | --dataset ogbg-molhiv \ 67 | --loss auroc \ 68 | --optimizer pesg \ 69 | --batch_size 512 \ 70 | --lr 0.01 \ 71 | --gamma 300 \ 72 | --margin 1.0 \ 73 | --weight_decay 1e-5 \ 74 | --random_seed 0 \ 75 | --epochs 100 76 | ``` 77 | 78 | ## Results 79 | Result (1) improves the original baseline (DeepGCN) to **0.8159**, a ~**3%** improvement.
Result (2) achieves a higher SOTA performance of **0.8352**, a ~**1%** improvement over previous baselines. For each stage, we train the model 10 times using different random seeds, i.e., 0 to 9. 80 | 81 | 82 | Citation 83 | --------- 84 | If you have any questions, please open a new issue in this repo or contact us @ [Zhuoning Yuan](https://homepage.divms.uiowa.edu/~zhuoning/) [yzhuoning@gmail.com]. If you find this work useful, please cite the following paper for our method and library: 85 | ``` 86 | @inproceedings{yuan2021robust, 87 | title={Large-scale Robust Deep AUC Maximization: A New Surrogate Loss and Empirical Studies on Medical Image Classification}, 88 | author={Yuan, Zhuoning and Yan, Yan and Sonka, Milan and Yang, Tianbao}, 89 | booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision}, 90 | year={2021} 91 | } 92 | ``` 93 | 94 | Reference 95 | --------- 96 | - https://libauc.org/ 97 | - https://github.com/Optimization-AI/LibAUC 98 | - https://github.com/PaddlePaddle/PaddleHelix/tree/dev/competition/ogbg_molhiv 99 | - https://github.com/lightaime/deep_gcns_torch/ 100 | - https://ogb.stanford.edu/ 101 | 102 | -------------------------------------------------------------------------------- /utils/ckpt_util.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import shutil 4 | from collections import OrderedDict 5 | import logging 6 | import numpy as np 7 | 8 | 9 | def save_ckpt(model, optimizer, loss, epoch, save_path, name_pre, name_post='best'):  # save a CPU copy of the model weights plus optimizer state, loss and epoch to <save_path>/<name_pre>_<name_post>.pth 10 | model_cpu = {k: v.cpu() for k, v in model.state_dict().items()} 11 | state = { 12 | 'epoch': epoch, 13 | 'model_state_dict': model_cpu, 14 | 'optimizer_state_dict': optimizer.state_dict(), 15 | 'loss': loss 16 | } 17 | 18 | # makedirs also creates missing parent directories and is a no-op when save_path already exists 19 | os.makedirs(save_path, exist_ok=True) 20 | # print("Directory ", save_path, " is created.") 21 | 22 | filename = '{}/{}_{}.pth'.format(save_path, name_pre,
name_post) 23 | torch.save(state, filename) 24 | # print('model has been saved as {}'.format(filename)) 25 | 26 | 27 | def load_pretrained_models(model, pretrained_model, phase, ismax=True): # ismax=True: a higher best_value is better 28 | if ismax: 29 | best_value = -np.inf 30 | else: 31 | best_value = np.inf 32 | epoch = -1 33 | 34 | if pretrained_model: 35 | if os.path.isfile(pretrained_model): 36 | logging.info("===> Loading checkpoint '{}'".format(pretrained_model)) 37 | checkpoint = torch.load(pretrained_model) 38 | try: 39 | best_value = checkpoint['best_value'] 40 | if best_value == -np.inf or best_value == np.inf: 41 | show_best_value = False 42 | else: 43 | show_best_value = True 44 | except KeyError:  # checkpoint was saved without a 'best_value' entry 45 | # keep the ismax-based default best_value set above 46 | show_best_value = False 47 | 48 | model_dict = model.state_dict() 49 | ckpt_model_state_dict = checkpoint['state_dict'] 50 | 51 | # rename ckpt keys so DataParallel ('module.'-prefixed) and plain parameter names line up 52 | is_model_multi_gpus = list(model_dict)[0].startswith('module.')  # test the full prefix stripped/added below, not just a leading 'm' 53 | is_ckpt_multi_gpus = list(ckpt_model_state_dict)[0].startswith('module.') 54 | 55 | if not (is_model_multi_gpus == is_ckpt_multi_gpus): 56 | temp_dict = OrderedDict() 57 | for k, v in ckpt_model_state_dict.items(): 58 | if is_ckpt_multi_gpus: 59 | name = k[7:] # remove 'module.' 60 | else: 61 | name = 'module.'+k # add 'module' 62 | temp_dict[name] = v 63 | # load params 64 | ckpt_model_state_dict = temp_dict 65 | 66 | model_dict.update(ckpt_model_state_dict)  # NOTE(review): line 67 loads ckpt_model_state_dict, not model_dict, so this update does not affect the model; confirm which dict was meant to be loaded 67 | model.load_state_dict(ckpt_model_state_dict) 68 | 69 | if show_best_value: 70 | logging.info("The pretrained_model is at checkpoint {}.
\t " 71 | "Best value: {}".format(checkpoint['epoch'], best_value)) 72 | else: 73 | logging.info("The pretrained_model is at checkpoint {}.".format(checkpoint['epoch'])) 74 | 75 | if phase == 'train': 76 | epoch = checkpoint['epoch'] 77 | else: 78 | epoch = -1 79 | else: 80 | raise ImportError("===> No checkpoint found at '{}'".format(pretrained_model))  # NOTE(review): FileNotFoundError would fit better; kept as ImportError for existing callers 81 | else: 82 | logging.info('===> No pre-trained model') 83 | return model, best_value, epoch 84 | 85 | 86 | def load_pretrained_optimizer(pretrained_model, optimizer, scheduler, lr, use_ckpt_lr=True):  # restore optimizer/scheduler state from a checkpoint file when present; returns (optimizer, scheduler, lr) 87 | if pretrained_model: 88 | if os.path.isfile(pretrained_model): 89 | checkpoint = torch.load(pretrained_model) 90 | if 'optimizer_state_dict' in checkpoint.keys(): 91 | optimizer.load_state_dict(checkpoint['optimizer_state_dict']) 92 | for state in optimizer.state.values(): 93 | for k, v in state.items(): 94 | if torch.is_tensor(v): 95 | state[k] = v.cuda()  # NOTE(review): hard-codes CUDA, so this fails on CPU-only machines 96 | if 'scheduler_state_dict' in checkpoint.keys(): 97 | scheduler.load_state_dict(checkpoint['scheduler_state_dict']) 98 | if use_ckpt_lr: 99 | try: 100 | lr = scheduler.get_lr()[0] 101 | except (AttributeError, IndexError):  # scheduler is None / has no get_lr, or returned no lrs 102 | pass  # keep the lr argument 103 | 104 | return optimizer, scheduler, lr 105 | 106 | 107 | def save_checkpoint(state, is_best, save_path, postname): 108 | filename = '{}/{}_{}.pth'.format(save_path, postname, int(state['epoch'])) 109 | torch.save(state, filename) 110 | if is_best: 111 | shutil.copyfile(filename, '{}/{}_best.pth'.format(save_path, postname)) 112 | 113 | 114 | def change_ckpt_dict(model, optimizer, scheduler, opt): 115 | 116 | for _ in range(opt.epoch): 117 | scheduler.step() 118 | is_best = (opt.test_value < opt.best_value) 119 | opt.best_value = min(opt.test_value, opt.best_value) 120 | 121 | model_cpu = {k: v.cpu() for k, v in model.state_dict().items()} 122 | # optim_cpu = {k: v.cpu() for k, v in optimizer.state_dict().items()} 123 | save_checkpoint({ 124 | 'epoch': opt.epoch, 125 | 'state_dict': model_cpu, 126 | 'optimizer_state_dict': optimizer.state_dict(),
127 | 'scheduler_state_dict': scheduler.state_dict(), 128 | 'best_value': opt.best_value, 129 | }, is_best, opt.save_path, opt.post) 130 | 131 | -------------------------------------------------------------------------------- /random_forest.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import random 4 | import argparse 5 | import numpy as np 6 | import pandas as pd 7 | from tqdm import tqdm 8 | 9 | from ogb.graphproppred import GraphPropPredDataset 10 | from sklearn.ensemble import RandomForestClassifier 11 | from sklearn.metrics import average_precision_score, roc_auc_score 12 | 13 | def seed(seed=0): 14 | random.seed(seed) 15 | np.random.seed(seed) 16 | 17 | def main(args):  # fit Random Forests on MGF+MACCS fingerprint features for run_times seeds and save each model's predict_proba over all molecules 18 | all_probs = {} 19 | all_ap = {} 20 | all_rocs = {} 21 | train_label_props = {} 22 | 23 | n_estimators = 1000 24 | max_tasks = None 25 | run_times = 10 26 | 27 | eval_scores = [] 28 | test_scores = [] 29 | 30 | mgf_file = "./dataset/%s/mgf_feat.npy" % (args.dataset_name.replace("-", "_")) 31 | maccs_file = "./dataset/%s/maccs_feat.npy" % (args.dataset_name.replace("-", "_")) 32 | mgf_feat = np.load(mgf_file) 33 | maccs_feat = np.load(maccs_file) 34 | 35 | dataset = GraphPropPredDataset(name=args.dataset_name, root="./dataset/") 36 | smiles_file = "./dataset/%s/mapping/mol.csv.gz" % (args.dataset_name.replace("-", "_")) 37 | df_smi = pd.read_csv(smiles_file) 38 | smiles = df_smi["smiles"] 39 | outcomes = df_smi.set_index("smiles").drop(["mol_id"], axis=1) 40 | 41 | feat = np.concatenate([mgf_feat, maccs_feat], axis=1) 42 | print("features size:", feat.shape[1]) 43 | 44 | X = pd.DataFrame(feat, 45 | index=smiles, 46 | columns=[i for i in range(feat.shape[1])]) 47 | 48 | # Split into train/val/test 49 | split_idx = dataset.get_idx_split() 50 | train_idx, val_idx, test_idx = split_idx["train"], split_idx["valid"], split_idx["test"] 51 | 52 | X_train, X_val, X_test = X.iloc[train_idx], X.iloc[val_idx], X.iloc[test_idx]
53 | 54 | for rep in range(run_times): 55 | for oo in tqdm(outcomes.columns[:max_tasks]): 56 | # Get probabilities 57 | val_key = args.dataset_name, oo, rep, "val" 58 | test_key = args.dataset_name, oo, rep, "test" 59 | 60 | # If re-running, skip finished runs 61 | if val_key in all_probs: 62 | print("Skipping", val_key[:-1]) 63 | continue 64 | 65 | # Split outcome into train/val/test 66 | Y = outcomes[oo] 67 | y_train, y_val, y_test = Y.loc[X_train.index], Y.loc[X_val.index], Y.loc[X_test.index] 68 | 69 | # Skip outcomes with no positive training examples 70 | if y_train.sum() == 0: 71 | continue 72 | 73 | # Remove missing labels in validation and test 74 | y_val, y_test = y_val.dropna(), y_test.dropna() 75 | X_v, X_t = X_val.loc[y_val.index], X_test.loc[y_test.index] 76 | 77 | # Remove missing values in the training labels (imbalance is handled via class_weight below, not by downsampling) 78 | y_tr = y_train.dropna() 79 | train_label_props[args.dataset_name, oo, rep] = y_tr.mean() 80 | print(f"Sampled label balance:\n{y_tr.value_counts()}") 81 | 82 | # Fit model 83 | print("Fitting model...") 84 | rf = RandomForestClassifier(min_samples_leaf=2, 85 | n_estimators=n_estimators, 86 | n_jobs=-1, 87 | criterion='entropy', 88 | class_weight={0:1, 1:10}, 89 | random_state=rep 90 | ) 91 | rf.fit(X_train.loc[y_tr.index], y_tr) 92 | 93 | # Calculate probabilities 94 | all_probs[val_key] = pd.Series(rf.predict_proba(X_v)[:, 1], index=X_v.index) 95 | all_probs[test_key] = pd.Series(rf.predict_proba(X_t)[:, 1], index=X_t.index) 96 | 97 | if y_val.sum() > 0: 98 | all_ap[val_key] = average_precision_score(y_val, all_probs[val_key]) 99 | all_rocs[val_key] = roc_auc_score(y_val, all_probs[val_key]) 100 | 101 | if y_test.sum() > 0: 102 | all_ap[test_key] = average_precision_score(y_test, all_probs[test_key]) 103 | all_rocs[test_key] = roc_auc_score(y_test, all_probs[test_key]) 104 | 105 | print(f'{oo}, rep {rep}, AP (val, test): {all_ap.get(val_key, np.nan):.3f}, {all_ap.get(test_key, np.nan):.3f}') 106 |
print(f'\tROC (val, test): {all_rocs.get(val_key, np.nan):.3f}, {all_rocs.get(test_key, np.nan):.3f}') 107 | eval_scores.append(all_rocs.get(val_key, np.nan)) 108 | test_scores.append(all_rocs.get(test_key, np.nan)) 109 | 110 | # save pred 111 | all_prob = np.array(rf.predict_proba(X)) 112 | print(all_prob.shape) 113 | os.makedirs('rf_preds', exist_ok=True) 114 | np.save("./rf_preds/rf_pred_auc_{:0.4f}_{:0.4f}_RS_{:d}.npy".format(all_rocs.get(val_key, np.nan), all_rocs.get(test_key, np.nan), rep), all_prob)  # .get: all_rocs has no entry when a split had no positives (lines 97-103) 115 | 116 | # final_index = np.argmax(eval_scores) 117 | # best_val_score = eval_scores[final_index] 118 | # final_test_score = test_scores[final_index] 119 | # shutil.copy2("./rf_preds/rf_pred_auc_{:0.4f}_{:0.4f}.npy".format(best_val_score, final_test_score), './rf_preds/rf_final_pred.npy') 120 | # print("Best preds saved in ./rf_preds/rf_final_pred.npy") 121 | 122 | if __name__=="__main__": 123 | parser = argparse.ArgumentParser(description='gnn') 124 | parser.add_argument("--dataset_name", type=str, default="ogbg-molhiv") 125 | args = parser.parse_args() 126 | main(args) -------------------------------------------------------------------------------- /args.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import uuid 3 | import logging 4 | import time 5 | import os 6 | import sys 7 | from utils.logger import create_exp_dir 8 | import glob 9 | 10 | def boolean_string(s): 11 | if s not in {'False', 'True'}: 12 | raise ValueError('Not a valid boolean string') 13 | return s == 'True' 14 | 15 | class ArgsInit(object): 16 | def __init__(self): 17 | parser = argparse.ArgumentParser(description='DeeperGCN') 18 | # dataset 19 | parser.add_argument('--dataset', type=str, default="ogbg-molhiv", 20 | help='dataset name (default: ogbg-molhiv)') 21 | parser.add_argument('--num_workers', type=int, default=8, 22 | help='number of workers (default: 8)') 23 | parser.add_argument('--batch_size', type=int, default=32, 24 | help='input batch
size for training (default: 32)') 25 | parser.add_argument('--feature', type=str, default='full', 26 | help='two options: full or simple') 27 | parser.add_argument('--add_virtual_node', action='store_true') 28 | # training & eval settings 29 | parser.add_argument('--use_gpu', action='store_true') 30 | parser.add_argument('--device', type=int, default=0, 31 | help='which gpu to use if any (default: 0)') 32 | parser.add_argument('--epochs', type=int, default=300, 33 | help='number of epochs to train (default: 300)') 34 | parser.add_argument('--lr', type=float, default=0.01, 35 | help='learning rate set for optimizer.') 36 | parser.add_argument('--dropout', type=float, default=0.5) 37 | parser.add_argument('--grad_clip', type=float, default=0., 38 | help='gradient clipping') 39 | 40 | # model 41 | parser.add_argument('--num_layers', type=int, default=3, 42 | help='the number of layers of the networks') 43 | parser.add_argument('--mlp_layers', type=int, default=1, 44 | help='the number of layers of mlp in conv') 45 | parser.add_argument('--hidden_channels', type=int, default=256, 46 | help='the dimension of embeddings of nodes and edges') 47 | parser.add_argument('--block', default='res+', type=str, 48 | help='graph backbone block type {res+, res, dense, plain}') 49 | parser.add_argument('--conv', type=str, default='gen', 50 | help='the type of GCNs') 51 | parser.add_argument('--gcn_aggr', type=str, default='max', 52 | help='the aggregator of GENConv [mean, max, add, softmax, softmax_sg, power]') 53 | parser.add_argument('--norm', type=str, default='batch', 54 | help='the type of normalization layer') 55 | parser.add_argument('--num_tasks', type=int, default=1, 56 | help='the number of prediction tasks') 57 | # learnable parameters 58 | parser.add_argument('--t', type=float, default=1.0, 59 | help='the temperature of SoftMax') 60 | parser.add_argument('--p', type=float, default=1.0, 61 | help='the power of PowerMean') 62 | parser.add_argument('--learn_t',
action='store_true') 63 | parser.add_argument('--learn_p', action='store_true') 64 | parser.add_argument('--y', type=float, default=0.0, 65 | help='the power of softmax_sum and powermean_sum') 66 | parser.add_argument('--learn_y', action='store_true') 67 | 68 | 69 | ''' 70 | Args for Deep AUC Maximization 71 | ''' 72 | parser.add_argument('--configs', type=str, default='', help='') 73 | parser.add_argument('--model_name', type=str, default='deepgcn', help='') 74 | parser.add_argument('--loss', type=str, default='auroc', help='') 75 | parser.add_argument('--optimizer', type=str, default='pesg', help='') 76 | parser.add_argument('--weight_decay', type=float, default=1e-4) 77 | parser.add_argument('--random_seed', type=int, default=0) # try different seeds 78 | parser.add_argument('--pretrained', type=boolean_string, default=False) 79 | parser.add_argument('--activations', type=str, default='relu') 80 | 81 | # AUC-margin loss 82 | parser.add_argument('--gamma', type=float, default=500) 83 | parser.add_argument('--margin', type=float, default=1.0) 84 | parser.add_argument('--imratio', type=float, default=0.01) 85 | 86 | 87 | # message norm 88 | parser.add_argument('--msg_norm', action='store_true') 89 | parser.add_argument('--learn_msg_scale', action='store_true') 90 | # encode edge in conv 91 | parser.add_argument('--conv_encode_edge', action='store_true') 92 | # graph pooling type 93 | parser.add_argument('--graph_pooling', type=str, default='mean', 94 | help='graph pooling method') 95 | # save model 96 | parser.add_argument('--model_save_path', type=str, default='model_ckpt', 97 | help='the directory used to save models') 98 | parser.add_argument('--save', type=str, default='EXP', help='experiment name') 99 | parser.add_argument('--save_dir', type=str, default='', help='experiment output directory (overwritten by save_exp)') 100 | # load pre-trained model 101 | parser.add_argument('--model_load_path', type=str, default='ogbg_molhiv_pretrained_model.pth', 102 | help='the path of pre-trained model') 103 |
104 | # for colab 105 | parser.add_argument('-f', type=str, default='kernel') 106 | self.args = parser.parse_args() 107 | 108 | def save_exp(self): 109 | self.args.save = '{}-B_{}-C_{}-L_{}-F_{}-DP_{}' \ 110 | '-GA_{}-T_{}-LT_{}-P_{}-LP_{}-Y_{}-LY_{}' \ 111 | '-MN_{}-LS_{}-RS_{}'.format(self.args.save, self.args.block, self.args.conv, 112 | self.args.num_layers, self.args.hidden_channels, 113 | self.args.dropout, self.args.gcn_aggr, 114 | self.args.t, self.args.learn_t, self.args.p, self.args.learn_p, 115 | self.args.y, self.args.learn_y, 116 | self.args.msg_norm, self.args.learn_msg_scale, 117 | self.args.random_seed) 118 | 119 | self.args.save_dir = './saved_models/{}'.format(self.args.save) #, time.strftime("%Y%m%d-%H%M%S")) #, str(uuid.uuid4())) 120 | self.args.model_save_path = os.path.join(self.args.save_dir, self.args.model_save_path) 121 | 122 | 123 | create_exp_dir(self.args.save_dir) #, scripts_to_save=glob.glob('*.py')) 124 | 125 | log_format = '%(asctime)s %(message)s' 126 | logging.basicConfig(stream=sys.stdout, 127 | level=logging.INFO, 128 | format=log_format, 129 | datefmt='%m/%d %I:%M:%S %p') 130 | fh = logging.FileHandler(os.path.join(self.args.save_dir, 'log.txt')) 131 | fh.setFormatter(logging.Formatter(log_format)) 132 | logging.getLogger().addHandler(fh) 133 | 134 | return self.args 135 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_geometric.data import DataLoader 3 | import torch.optim as optim 4 | from model import DeeperGCN 5 | from tqdm import tqdm 6 | import numpy as np 7 | import pandas as pd 8 | from args import ArgsInit 9 | from utils.ckpt_util import save_ckpt 10 | import logging 11 | import time, os 12 | import statistics 13 | from ogb.graphproppred import PygGraphPropPredDataset, Evaluator 14 | import torch.nn.functional as F 15 | from utils.logger import create_exp_dir 16 |
17 | # for AUC margin loss 18 | from libauc.losses import AUCMLoss 19 | from libauc.optimizers import PESG 20 | 21 | def set_all_seeds(SEED): 22 | # REPRODUCIBILITY 23 | torch.manual_seed(SEED) 24 | np.random.seed(SEED) 25 | torch.backends.cudnn.deterministic = True 26 | torch.backends.cudnn.benchmark = False 27 | 28 | args = ArgsInit().save_exp() 29 | 30 | 31 | def train(model, device, loader, optimizer, task_type, grad_clip=0.):  # one epoch of AUC-margin training; returns the mean batch loss 32 | loss_list = [] 33 | model.train() 34 | 35 | for step, batch in enumerate(loader): 36 | batch = batch.to(device) 37 | 38 | if batch.x.shape[0] == 1 or batch.batch[-1] == 0:  # skip batches with a single node or a single graph 39 | pass 40 | else: 41 | optimizer.zero_grad() 42 | pred = model(batch) 43 | pred = torch.sigmoid(pred) 44 | is_labeled = batch.y == batch.y  # NaN != NaN, so this masks out missing labels 45 | loss = aucm_criterion(pred.to(torch.float32)[is_labeled].reshape(-1, 1), batch.y.to(torch.float32)[is_labeled].reshape(-1, 1))  # NOTE(review): aucm_criterion is a global created only under __main__ (line 195), so train() fails if main.py is imported 46 | loss.backward() 47 | 48 | if grad_clip > 0: 49 | torch.nn.utils.clip_grad_value_( 50 | model.parameters(), 51 | grad_clip) 52 | 53 | optimizer.step() 54 | 55 | loss_list.append(loss.item()) 56 | return statistics.mean(loss_list) 57 | 58 | 59 | @torch.no_grad() 60 | def eval(model, device, loader, evaluator):  # score sigmoid predictions over loader with the OGB evaluator (note: shadows builtin eval) 61 | model.eval() 62 | y_true = [] 63 | y_pred = [] 64 | 65 | for step, batch in enumerate(loader): 66 | batch = batch.to(device)  # no .cuda() first: device may be the CPU (lines 88-91), and .to(device) already handles GPUs 67 | 68 | if batch.x.shape[0] == 1: 69 | pass 70 | else: 71 | pred = model(batch) 72 | pred = torch.sigmoid(pred) 73 | y_true.append(batch.y.view(pred.shape).detach().cpu()) 74 | y_pred.append(pred.detach().cpu()) 75 | 76 | y_true = torch.cat(y_true, dim=0).numpy() 77 | y_pred = torch.cat(y_pred, dim=0).numpy() 78 | 79 | input_dict = {"y_true": y_true, 80 | "y_pred": y_pred} 81 | 82 | return evaluator.eval(input_dict) 83 | 84 | 85 | def main(): 86 | 87 | #args = ArgsInit().save_exp() 88 | if args.use_gpu: 89 | device = torch.device("cuda:" + str(args.device)) if torch.cuda.is_available() else torch.device("cpu") 90 | else: 91 | device =
torch.device('cpu') 92 | 93 | sub_dir = 'BS_{}-NF_{}'.format(args.batch_size, args.feature) 94 | set_all_seeds(args.random_seed) 95 | dataset = PygGraphPropPredDataset(name=args.dataset) 96 | args.num_tasks = dataset.num_tasks 97 | #logging.info('%s' % args) 98 | 99 | if args.feature == 'full': 100 | pass 101 | elif args.feature == 'simple': 102 | print('using simple feature') 103 | # only retain the top two node/edge features 104 | dataset.data.x = dataset.data.x[:, :2] 105 | dataset.data.edge_attr = dataset.data.edge_attr[:, :2] 106 | 107 | evaluator = Evaluator(args.dataset) 108 | split_idx = dataset.get_idx_split() 109 | 110 | set_all_seeds(args.random_seed) 111 | train_loader = DataLoader(dataset[split_idx["train"]], batch_size=args.batch_size, shuffle=True, 112 | num_workers=args.num_workers) 113 | valid_loader = DataLoader(dataset[split_idx["valid"]], batch_size=args.batch_size, shuffle=False, 114 | num_workers=args.num_workers) 115 | test_loader = DataLoader(dataset[split_idx["test"]], batch_size=args.batch_size, shuffle=False, 116 | num_workers=args.num_workers) 117 | 118 | 119 | set_all_seeds(args.random_seed) 120 | model = DeeperGCN(args).to(device) 121 | 122 | num_params = sum(p.numel() for p in model.parameters()) 123 | print(f'#Params: {num_params}') 124 | 125 | optimizer = PESG(model, 126 | a=aucm_criterion.a, 127 | b=aucm_criterion.b, 128 | alpha=aucm_criterion.alpha, 129 | lr=args.lr, 130 | gamma=args.gamma, 131 | margin=args.margin, 132 | weight_decay=args.weight_decay) 133 | 134 | 135 | # get imbalance ratio from train set 136 | args.imratio = float((train_loader.dataset.data.y.sum()/train_loader.dataset.data.y.shape[0]).numpy())  # NOTE(review): train_loader.dataset is an index-sliced dataset; if its .data still holds the full collated storage, this ratio covers all splits, not only train (confirm) 137 | aucm_criterion.p = args.imratio 138 | print (aucm_criterion.p) 139 | 140 | # save 141 | datetime_now = '2021-10-09' 142 | pretrained_prefix = 'pre_' if args.pretrained else '' 143 | virtual_node_prefilx = '-vt' if args.add_virtual_node else '' 144 | args.configs =
'[%s]Train_%s_im_%.4f_rd_%s_%s%s_%s_%s_wd_%s_lr_%s_B_%s_E_%s_%s_%s_g_%s_m_%s'%(datetime_now, args.dataset, args.imratio, args.random_seed, pretrained_prefix, args.model_name, virtual_node_prefilx, args.activations, args.weight_decay, args.lr, args.batch_size, args.epochs, args.loss, args.optimizer, args.gamma, args.margin) 145 | logging.info(args.save) 146 | logging.info(args.configs) 147 | 148 | results = {'highest_valid': 0, 149 | 'final_train': 0, 150 | 'final_test': 0, 151 | 'highest_train': 0} 152 | 153 | start_time = time.time() 154 | start_time_local = time.time() 155 | for epoch in range(1, args.epochs + 1): 156 | 157 | if epoch in [int(args.epochs*0.33), int(args.epochs*0.66)] and args.loss!= 'ce': 158 | optimizer.update_regularizer(decay_factor=10) 159 | 160 | epoch_loss = train(model, device, train_loader, optimizer, dataset.task_type, grad_clip=args.grad_clip) 161 | 162 | #logging.info('Evaluating...') 163 | train_result = eval(model, device, train_loader, evaluator)[dataset.eval_metric] 164 | valid_result = eval(model, device, valid_loader, evaluator)[dataset.eval_metric] 165 | test_result = eval(model, device, test_loader, evaluator)[dataset.eval_metric] 166 | 167 | logging.info("Epoch:%s, train_auc:%.4f, valid_auc:%.4f, test_auc:%.4f, lr:%.4f, time:%.4f"%(epoch, train_result, valid_result, test_result, optimizer.lr, time.time()-start_time_local)) 168 | start_time_local = time.time() 169 | # model.print_params(epoch=epoch) 170 | 171 | if train_result > results['highest_train']: 172 | results['highest_train'] = train_result 173 | 174 | if valid_result > results['highest_valid'] and epoch > 200:  # NOTE(review): hard-coded 200 means no checkpoint is saved and final results stay 0 when --epochs <= 200 175 | results['highest_valid'] = valid_result 176 | results['final_train'] = train_result 177 | results['final_test'] = test_result 178 | 179 | save_ckpt(model, optimizer, 180 | round(epoch_loss, 4), epoch, 181 | args.model_save_path, 182 | sub_dir, name_post='valid_best_AUC_E_%s_R%s'%(epoch, args.random_seed)) 183 | 184 | logging.info("%s" % results) 185 | 186 |
end_time = time.time() 187 | total_time = end_time - start_time 188 | logging.info('Total time: {}'.format(time.strftime('%H:%M:%S', time.gmtime(total_time)))) 189 | 190 | 191 | if __name__ == "__main__": 192 | cls_criterion = torch.nn.BCEWithLogitsLoss() 193 | reg_criterion = torch.nn.MSELoss() 194 | # https://github.com/Optimization-AI/LibAUC 195 | aucm_criterion = AUCMLoss() 196 | main() 197 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import __init__ 3 | import torch.nn.functional as F 4 | from torch_geometric.nn import global_add_pool, global_mean_pool, global_max_pool 5 | from ogb.graphproppred.mol_encoder import AtomEncoder, BondEncoder 6 | from gcn_lib.sparse.torch_vertex import GENConv 7 | from gcn_lib.sparse.torch_nn import norm_layer, MLP 8 | import logging 9 | 10 | 11 | class DeeperGCN(torch.nn.Module): 12 | def __init__(self, args): 13 | super(DeeperGCN, self).__init__() 14 | 15 | self.num_layers = args.num_layers 16 | self.dropout = args.dropout 17 | self.block = args.block 18 | self.conv_encode_edge = args.conv_encode_edge 19 | self.add_virtual_node = args.add_virtual_node 20 | 21 | hidden_channels = args.hidden_channels 22 | num_tasks = args.num_tasks 23 | conv = args.conv 24 | aggr = args.gcn_aggr 25 | t = args.t 26 | self.learn_t = args.learn_t 27 | p = args.p 28 | self.learn_p = args.learn_p 29 | y = args.y 30 | self.learn_y = args.learn_y 31 | 32 | self.msg_norm = args.msg_norm 33 | learn_msg_scale = args.learn_msg_scale 34 | self.activation_func = F.relu if args.activations=='relu' else F.elu 35 | 36 | norm = args.norm 37 | mlp_layers = args.mlp_layers 38 | 39 | graph_pooling = args.graph_pooling 40 | 41 | print('The number of layers {}'.format(self.num_layers), 42 | 'Aggr aggregation method {}'.format(aggr), 43 | 'block: {}'.format(self.block)) 44 | if self.block == 'res+': 45 |
print('LN/BN->ReLU->GraphConv->Res') 46 | elif self.block == 'res': 47 | print('GraphConv->LN/BN->ReLU->Res') 48 | elif self.block == 'dense': 49 | raise NotImplementedError('To be implemented') 50 | elif self.block == "plain": 51 | print('GraphConv->LN/BN->ReLU') 52 | else: 53 | raise Exception('Unknown block Type') 54 | 55 | self.gcns = torch.nn.ModuleList() 56 | self.norms = torch.nn.ModuleList() 57 | 58 | if self.add_virtual_node: 59 | self.virtualnode_embedding = torch.nn.Embedding(1, hidden_channels) 60 | torch.nn.init.constant_(self.virtualnode_embedding.weight.data, 0) 61 | 62 | self.mlp_virtualnode_list = torch.nn.ModuleList() 63 | 64 | for layer in range(self.num_layers - 1): 65 | self.mlp_virtualnode_list.append(MLP([hidden_channels]*3, 66 | norm=norm)) 67 | 68 | for layer in range(self.num_layers): 69 | if conv == 'gen': 70 | gcn = GENConv(hidden_channels, hidden_channels, 71 | aggr=aggr, 72 | t=t, learn_t=self.learn_t, 73 | p=p, learn_p=self.learn_p, 74 | y=y, learn_y=self.learn_y,  # was self.learn_p (copy-paste slip), which made --learn_y a no-op 75 | msg_norm=self.msg_norm, learn_msg_scale=learn_msg_scale, 76 | encode_edge=self.conv_encode_edge, bond_encoder=True, 77 | norm=norm, mlp_layers=mlp_layers) 78 | else: 79 | raise Exception('Unknown Conv Type') 80 | self.gcns.append(gcn) 81 | self.norms.append(norm_layer(norm, hidden_channels)) 82 | 83 | self.atom_encoder = AtomEncoder(emb_dim=hidden_channels) 84 | 85 | if not self.conv_encode_edge: 86 | self.bond_encoder = BondEncoder(emb_dim=hidden_channels) 87 | 88 | if graph_pooling == "sum": 89 | self.pool = global_add_pool 90 | elif graph_pooling == "mean": 91 | self.pool = global_mean_pool 92 | elif graph_pooling == "max": 93 | self.pool = global_max_pool 94 | else: 95 | raise Exception('Unknown Pool Type') 96 | 97 | self.graph_pred_linear = torch.nn.Linear(hidden_channels, num_tasks) 98 | 99 | def forward(self, input_batch): 100 | 101 | x = input_batch.x 102 | edge_index = input_batch.edge_index 103 | edge_attr = input_batch.edge_attr 104 | batch =
input_batch.batch 105 | 106 | h = self.atom_encoder(x) 107 | 108 | if self.add_virtual_node: 109 | virtualnode_embedding = self.virtualnode_embedding( 110 | torch.zeros(batch[-1].item() + 1).to(edge_index.dtype).to(edge_index.device)) 111 | h = h + virtualnode_embedding[batch] 112 | 113 | if self.conv_encode_edge: 114 | edge_emb = edge_attr 115 | else: 116 | edge_emb = self.bond_encoder(edge_attr) 117 | 118 | if self.block == 'res+': 119 | 120 | h = self.gcns[0](h, edge_index, edge_emb) 121 | 122 | for layer in range(1, self.num_layers): 123 | h1 = self.norms[layer - 1](h) 124 | h2 = self.activation_func(h1) 125 | h2 = F.dropout(h2, p=self.dropout, training=self.training) 126 | 127 | if self.add_virtual_node: 128 | virtualnode_embedding_temp = global_add_pool(h2, batch) + virtualnode_embedding 129 | virtualnode_embedding = F.dropout( 130 | self.mlp_virtualnode_list[layer-1](virtualnode_embedding_temp), 131 | self.dropout, training=self.training) 132 | 133 | h2 = h2 + virtualnode_embedding[batch] 134 | 135 | h = self.gcns[layer](h2, edge_index, edge_emb) + h 136 | 137 | h = self.norms[self.num_layers - 1](h) 138 | h = F.dropout(h, p=self.dropout, training=self.training) 139 | 140 | elif self.block == 'res': 141 | 142 | h = self.activation_func(self.norms[0](self.gcns[0](h, edge_index, edge_emb))) 143 | h = F.dropout(h, p=self.dropout, training=self.training) 144 | 145 | for layer in range(1, self.num_layers): 146 | h1 = self.gcns[layer](h, edge_index, edge_emb) 147 | h2 = self.norms[layer](h1) 148 | h = self.activation_func(h2) + h 149 | h = F.dropout(h, p=self.dropout, training=self.training) 150 | 151 | elif self.block == 'dense': 152 | raise NotImplementedError('To be implemented') 153 | 154 | elif self.block == 'plain': 155 | 156 | h = self.activation_func(self.norms[0](self.gcns[0](h, edge_index, edge_emb))) 157 | h = F.dropout(h, p=self.dropout, training=self.training) 158 | 159 | for layer in range(1, self.num_layers): 160 | h1 = self.gcns[layer](h, 
edge_index, edge_emb) 161 | h2 = self.norms[layer](h1) 162 | if layer != (self.num_layers - 1): 163 | h = self.activation_func(h2) 164 | else: 165 | h = h2 166 | h = F.dropout(h, p=self.dropout, training=self.training) 167 | else: 168 | raise Exception('Unknown block Type') 169 | 170 | h_graph = self.pool(h, batch) # N, 256 171 | #print (h_graph.shape) 172 | #h_graph= self.dropout_fc(h_graph) 173 | return self.graph_pred_linear(h_graph) 174 | 175 | def print_params(self, epoch=None, final=False): 176 | 177 | if self.learn_t: 178 | ts = [] 179 | for gcn in self.gcns: 180 | ts.append(gcn.t.item()) 181 | if final: 182 | print('Final t {}'.format(ts)) 183 | else: 184 | logging.info('Epoch {}, t {}'.format(epoch, ts)) 185 | if self.learn_p: 186 | ps = [] 187 | for gcn in self.gcns: 188 | ps.append(gcn.p.item()) 189 | if final: 190 | print('Final p {}'.format(ps)) 191 | else: 192 | logging.info('Epoch {}, p {}'.format(epoch, ps)) 193 | if self.msg_norm: 194 | ss = [] 195 | for gcn in self.gcns: 196 | ss.append(gcn.msg_norm.msg_scale.item()) 197 | if final: 198 | print('Final s {}'.format(ss)) 199 | else: 200 | logging.info('Epoch {}, s {}'.format(epoch, ss)) 201 | -------------------------------------------------------------------------------- /model_att.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch_geometric.nn import global_add_pool, global_mean_pool, global_max_pool 4 | from ogb.graphproppred.mol_encoder import AtomEncoder, BondEncoder 5 | from gcn_lib.sparse.torch_vertex import GENConv 6 | from gcn_lib.sparse.torch_nn import norm_layer, MLP 7 | import logging 8 | 9 | 10 | class DeeperGCN(torch.nn.Module): 11 | def __init__(self, args): 12 | super(DeeperGCN, self).__init__() 13 | 14 | self.num_layers = args.num_layers 15 | self.dropout = args.dropout 16 | self.block = args.block 17 | self.conv_encode_edge = args.conv_encode_edge 18 | self.add_virtual_node = 
args.add_virtual_node 19 | 20 | hidden_channels = args.hidden_channels 21 | num_tasks = args.num_tasks 22 | conv = args.conv 23 | aggr = args.gcn_aggr 24 | t = args.t 25 | self.learn_t = args.learn_t 26 | p = args.p 27 | self.learn_p = args.learn_p 28 | y = args.y 29 | self.learn_y = args.learn_y 30 | 31 | self.beta = torch.nn.Parameter(torch.Tensor([0.5]), requires_grad=True) 32 | 33 | self.msg_norm = args.msg_norm 34 | learn_msg_scale = args.learn_msg_scale 35 | self.activation_func = F.relu if args.activations=='relu' else F.elu 36 | 37 | norm = args.norm 38 | mlp_layers = args.mlp_layers 39 | 40 | graph_pooling = args.graph_pooling 41 | 42 | print('The number of layers {}'.format(self.num_layers), 43 | 'Aggr aggregation method {}'.format(aggr), 44 | 'block: {}'.format(self.block)) 45 | if self.block == 'res+': 46 | print('LN/BN->ReLU->GraphConv->Res') 47 | elif self.block == 'res': 48 | print('GraphConv->LN/BN->ReLU->Res') 49 | elif self.block == 'dense': 50 | raise NotImplementedError('To be implemented') 51 | elif self.block == "plain": 52 | print('GraphConv->LN/BN->ReLU') 53 | else: 54 | raise Exception('Unknown block Type') 55 | 56 | self.gcns = torch.nn.ModuleList() 57 | self.norms = torch.nn.ModuleList() 58 | 59 | if self.add_virtual_node: 60 | self.virtualnode_embedding = torch.nn.Embedding(1, hidden_channels) 61 | torch.nn.init.constant_(self.virtualnode_embedding.weight.data, 0) 62 | 63 | self.mlp_virtualnode_list = torch.nn.ModuleList() 64 | 65 | for layer in range(self.num_layers - 1): 66 | self.mlp_virtualnode_list.append(MLP([hidden_channels]*3, 67 | norm=norm)) 68 | 69 | for layer in range(self.num_layers): 70 | if conv == 'gen': 71 | gcn = GENConv(hidden_channels, hidden_channels, 72 | aggr=aggr, 73 | t=t, learn_t=self.learn_t, 74 | p=p, learn_p=self.learn_p, 75 | y=y, learn_y=self.learn_p, 76 | msg_norm=self.msg_norm, learn_msg_scale=learn_msg_scale, 77 | encode_edge=self.conv_encode_edge, bond_encoder=True, 78 | norm=norm, 
mlp_layers=mlp_layers) 79 | else: 80 | raise Exception('Unknown Conv Type') 81 | self.gcns.append(gcn) 82 | self.norms.append(norm_layer(norm, hidden_channels)) 83 | 84 | self.atom_encoder = AtomEncoder(emb_dim=hidden_channels) 85 | 86 | if not self.conv_encode_edge: 87 | self.bond_encoder = BondEncoder(emb_dim=hidden_channels) 88 | 89 | if graph_pooling == "sum": 90 | self.pool = global_add_pool 91 | elif graph_pooling == "mean": 92 | self.pool = global_mean_pool 93 | elif graph_pooling == "max": 94 | self.pool = global_max_pool 95 | else: 96 | raise Exception('Unknown Pool Type') 97 | 98 | self.graph_pred_linear = torch.nn.Linear(hidden_channels, num_tasks) 99 | 100 | def forward(self, input_batch, mode='train'): 101 | x = input_batch.x 102 | 103 | edge_index = input_batch.edge_index 104 | edge_attr = input_batch.edge_attr 105 | batch = input_batch.batch 106 | 107 | h = self.atom_encoder(x) 108 | 109 | if self.add_virtual_node: 110 | virtualnode_embedding = self.virtualnode_embedding( 111 | torch.zeros(batch[-1].item() + 1).to(edge_index.dtype).to(edge_index.device)) 112 | h = h + virtualnode_embedding[batch] 113 | 114 | if self.conv_encode_edge: 115 | edge_emb = edge_attr 116 | else: 117 | edge_emb = self.bond_encoder(edge_attr) 118 | 119 | if self.block == 'res+': 120 | 121 | h = self.gcns[0](h, edge_index, edge_emb) 122 | 123 | for layer in range(1, self.num_layers): 124 | h1 = self.norms[layer - 1](h) 125 | h2 = self.activation_func(h1) 126 | h2 = F.dropout(h2, p=self.dropout, training=self.training) 127 | 128 | if self.add_virtual_node: 129 | virtualnode_embedding_temp = global_add_pool(h2, batch) + virtualnode_embedding 130 | virtualnode_embedding = F.dropout( 131 | self.mlp_virtualnode_list[layer-1](virtualnode_embedding_temp), 132 | self.dropout, training=self.training) 133 | 134 | h2 = h2 + virtualnode_embedding[batch] 135 | 136 | h = self.gcns[layer](h2, edge_index, edge_emb) + h 137 | 138 | h = self.norms[self.num_layers - 1](h) 139 | h = F.dropout(h, 
p=self.dropout, training=self.training) 140 | 141 | elif self.block == 'res': 142 | 143 | h = self.activation_func(self.norms[0](self.gcns[0](h, edge_index, edge_emb))) 144 | h = F.dropout(h, p=self.dropout, training=self.training) 145 | 146 | for layer in range(1, self.num_layers): 147 | h1 = self.gcns[layer](h, edge_index, edge_emb) 148 | h2 = self.norms[layer](h1) 149 | h = self.activation_func(h2) + h 150 | h = F.dropout(h, p=self.dropout, training=self.training) 151 | 152 | elif self.block == 'dense': 153 | raise NotImplementedError('To be implemented') 154 | 155 | elif self.block == 'plain': 156 | 157 | h = self.activation_func(self.norms[0](self.gcns[0](h, edge_index, edge_emb))) 158 | h = F.dropout(h, p=self.dropout, training=self.training) 159 | 160 | for layer in range(1, self.num_layers): 161 | h1 = self.gcns[layer](h, edge_index, edge_emb) 162 | h2 = self.norms[layer](h1) 163 | if layer != (self.num_layers - 1): 164 | h = self.activation_func(h2) 165 | else: 166 | h = h2 167 | h = F.dropout(h, p=self.dropout, training=self.training) 168 | else: 169 | raise Exception('Unknown block Type') 170 | 171 | h_graph = self.pool(h, batch) # N, 256 172 | 173 | dcn_pred = self.graph_pred_linear(h_graph) 174 | rf_pred = input_batch.y[:, 2] 175 | return (1-self.beta)*torch.sigmoid(dcn_pred).reshape(-1, 1) + (self.beta) * rf_pred.reshape(-1,1) 176 | 177 | 178 | def print_params(self, epoch=None, final=False): 179 | 180 | if self.learn_t: 181 | ts = [] 182 | for gcn in self.gcns: 183 | ts.append(gcn.t.item()) 184 | if final: 185 | print('Final t {}'.format(ts)) 186 | else: 187 | logging.info('Epoch {}, t {}'.format(epoch, ts)) 188 | if self.learn_p: 189 | ps = [] 190 | for gcn in self.gcns: 191 | ps.append(gcn.p.item()) 192 | if final: 193 | print('Final p {}'.format(ps)) 194 | else: 195 | logging.info('Epoch {}, p {}'.format(epoch, ps)) 196 | if self.msg_norm: 197 | ss = [] 198 | for gcn in self.gcns: 199 | ss.append(gcn.msg_norm.msg_scale.item()) 200 | if final: 201 
| print('Final s {}'.format(ss)) 202 | else: 203 | logging.info('Epoch {}, s {}'.format(epoch, ss)) 204 | -------------------------------------------------------------------------------- /finetune.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_geometric.data import DataLoader 3 | import torch.optim as optim 4 | from model_att import DeeperGCN 5 | from tqdm import tqdm 6 | import numpy as np 7 | import pandas as pd 8 | from args import ArgsInit 9 | from utils.ckpt_util import save_ckpt 10 | import logging 11 | import time, os 12 | import statistics 13 | from ogb.graphproppred import PygGraphPropPredDataset, Evaluator 14 | import torch.nn.functional as F 15 | from utils.logger import create_exp_dir 16 | 17 | # for AUC margin loss 18 | from libauc.losses import AUCMLoss 19 | from libauc.optimizers import PESG 20 | 21 | 22 | def set_all_seeds(SEED): 23 | # REPRODUCIBILITY 24 | torch.manual_seed(SEED) 25 | np.random.seed(SEED) 26 | torch.backends.cudnn.deterministic = True 27 | torch.backends.cudnn.benchmark = False 28 | 29 | args = ArgsInit().save_exp() 30 | 31 | 32 | def train(model, device, loader, optimizer, task_type, grad_clip=0.): 33 | loss_list = [] 34 | model.train() 35 | 36 | for step, batch in enumerate(loader): 37 | batch = batch.to(device) 38 | 39 | if batch.x.shape[0] == 1 or batch.batch[-1] == 0: 40 | pass 41 | else: 42 | optimizer.zero_grad() 43 | pred = model(batch) 44 | is_labeled = batch.y[:,0] == batch.y[:,0] 45 | loss = aucm_criterion(pred.to(torch.float32)[is_labeled].reshape(-1, 1), batch.y[:,0:1].to(torch.float32)[is_labeled].reshape(-1, 1)) 46 | loss.backward() 47 | 48 | if grad_clip > 0: 49 | torch.nn.utils.clip_grad_value_( 50 | model.parameters(), 51 | grad_clip) 52 | 53 | optimizer.step() 54 | 55 | loss_list.append(loss.item()) 56 | return statistics.mean(loss_list) 57 | 58 | 59 | @torch.no_grad() 60 | def eval(model, device, loader, evaluator): 61 | model.eval() 62 | 
y_true = [] 63 | y_pred = [] 64 | 65 | for step, batch in enumerate(loader): 66 | batch = batch.to(device) 67 | 68 | if batch.x.shape[0] == 1: 69 | pass 70 | else: 71 | pred = model(batch, mode='test') 72 | y_true.append(batch.y[:,0:1].view(pred.shape).detach().cpu()) # remove random forest pred 73 | y_pred.append(pred.detach().cpu()) 74 | 75 | y_true = torch.cat(y_true, dim=0).numpy() 76 | y_pred = torch.cat(y_pred, dim=0).numpy() 77 | 78 | input_dict = {"y_true": y_true, 79 | "y_pred": y_pred} 80 | 81 | return evaluator.eval(input_dict) 82 | 83 | 84 | def main(): 85 | 86 | if args.use_gpu: 87 | device = torch.device("cuda:" + str(args.device)) if torch.cuda.is_available() else torch.device("cpu") 88 | else: 89 | device = torch.device('cpu') 90 | 91 | sub_dir = 'BS_{}-NF_{}'.format(args.batch_size, args.feature) 92 | set_all_seeds(args.random_seed) 93 | dataset = PygGraphPropPredDataset(name=args.dataset) 94 | 95 | # Load RF predictions 96 | npy = os.listdir('rf_preds')[args.random_seed] 97 | rf_pred = np.load(os.path.join('rf_preds', npy)) 98 | print (npy) 99 | dataset.data.y = torch.cat((dataset.data.y, torch.from_numpy(rf_pred)), 1) 100 | 101 | args.num_tasks = dataset.num_tasks 102 | #logging.info('%s' % args) 103 | 104 | if args.feature == 'full': 105 | pass 106 | elif args.feature == 'simple': 107 | print('using simple feature') 108 | # only retain the top two node/edge features 109 | dataset.data.x = dataset.data.x[:, :2] 110 | dataset.data.edge_attr = dataset.data.edge_attr[:, :2] 111 | 112 | evaluator = Evaluator(args.dataset) 113 | split_idx = dataset.get_idx_split() 114 | 115 | set_all_seeds(args.random_seed) 116 | train_loader = DataLoader(dataset[split_idx["train"]], batch_size=args.batch_size, shuffle=True, 117 | num_workers=args.num_workers) 118 | valid_loader = DataLoader(dataset[split_idx["valid"]], batch_size=args.batch_size, shuffle=False, 119 | num_workers=args.num_workers) 120 | test_loader = DataLoader(dataset[split_idx["test"]], 
batch_size=args.batch_size, shuffle=False, 121 | num_workers=args.num_workers) 122 | 123 | 124 | set_all_seeds(args.random_seed) 125 | model = DeeperGCN(args).to(device) 126 | 127 | if True: 128 | checkpoint_path = './saved_models/EXP-B_res+-C_gen-L_14-F_256-DP_0.2-GA_softmax-T_1.0-LT_True-P_1.0-LP_False-Y_0.0-LY_False-MN_False-LS_False-RS_%s/model_ckpt/'%(args.random_seed) 129 | best_pth = sorted(os.listdir(checkpoint_path))[-1] 130 | args.model_load_path = os.path.join(checkpoint_path, best_pth) 131 | trained_stat_dict = torch.load(args.model_load_path)['model_state_dict'] 132 | #trained_stat_dict.pop('graph_pred_linear.weight', None) 133 | ##trained_stat_dict.pop('graph_pred_linear.bias', None) 134 | model.load_state_dict(trained_stat_dict, strict=False) 135 | 136 | num_params = sum(p.numel() for p in model.parameters()) 137 | print(f'#Params: {num_params}') 138 | 139 | optimizer = PESG(model, 140 | a=aucm_criterion.a, 141 | b=aucm_criterion.b, 142 | alpha=aucm_criterion.alpha, 143 | lr=args.lr, 144 | gamma=args.gamma, 145 | margin=args.margin, 146 | weight_decay=args.weight_decay) 147 | 148 | 149 | # get imbalance ratio from train set 150 | args.imratio = float((train_loader.dataset.data.y[:, 0].sum()/train_loader.dataset.data.y[:,0].shape[0]).numpy()) 151 | aucm_criterion.p = args.imratio 152 | print (aucm_criterion.p) 153 | 154 | # save 155 | datetime_now = '2021-10-09' 156 | pretrained_prefix = 'pre_' if args.pretrained else '' 157 | virtual_node_prefilx = '-vt' if args.add_virtual_node else '' 158 | args.configs = '[%s]Train_%s_im_%.4f_rd_%s_%s%s-FP_%s_%s_wd_%s_lr_%s_B_%s_E_%s_%s_%s_g_%s_m_%s'%(datetime_now, args.dataset, args.imratio, args.random_seed, pretrained_prefix, args.model_name, virtual_node_prefilx, args.activations, args.weight_decay, args.lr, args.batch_size, args.epochs, args.loss, args.optimizer, args.gamma, args.margin) 159 | logging.info(args.save) 160 | logging.info(args.configs) 161 | 162 | results = {'highest_valid': 0, 163 | 
'final_train': 0, 164 | 'final_test': 0, 165 | 'highest_train': 0} 166 | 167 | start_time = time.time() 168 | start_time_local = time.time() 169 | for epoch in range(1, args.epochs + 1): 170 | 171 | if epoch in [int(args.epochs*0.33), int(args.epochs*0.66)] and args.loss!= 'ce': 172 | optimizer.update_regularizer(decay_factor=10) 173 | 174 | epoch_loss = train(model, device, train_loader, optimizer, dataset.task_type, grad_clip=args.grad_clip) 175 | 176 | #logging.info('Evaluating...') 177 | train_result = eval(model, device, train_loader, evaluator)[dataset.eval_metric] 178 | valid_result = eval(model, device, valid_loader, evaluator)[dataset.eval_metric] 179 | test_result = eval(model, device, test_loader, evaluator)[dataset.eval_metric] 180 | 181 | logging.info("Epoch:%s, train_auc:%.4f, valid_auc:%.4f, test_auc:%.4f, lr:%.4f, time:%.4f"%(epoch, train_result, valid_result, test_result, optimizer.lr, time.time()-start_time_local)) 182 | start_time_local = time.time() 183 | # model.print_params(epoch=epoch) 184 | 185 | if train_result > results['highest_train']: 186 | results['highest_train'] = train_result 187 | 188 | if valid_result > results['highest_valid']: 189 | results['highest_valid'] = valid_result 190 | results['final_train'] = train_result 191 | results['final_test'] = test_result 192 | 193 | save_ckpt(model, optimizer, 194 | round(epoch_loss, 4), epoch, 195 | args.model_save_path, 196 | sub_dir, name_post='valid_best_AUC-FP_E_%s_R%s'%(epoch, args.random_seed)) 197 | 198 | 199 | logging.info("%s" % results) 200 | 201 | end_time = time.time() 202 | total_time = end_time - start_time 203 | logging.info('Total time: {}'.format(time.strftime('%H:%M:%S', time.gmtime(total_time)))) 204 | 205 | 206 | if __name__ == "__main__": 207 | cls_criterion = torch.nn.BCEWithLogitsLoss() 208 | reg_criterion = torch.nn.MSELoss() 209 | 210 | # https://github.com/Optimization-AI/LibAUC 211 | aucm_criterion = AUCMLoss() 212 | main() 213 | 
-------------------------------------------------------------------------------- /utils/optim.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch.optim.optimizer import Optimizer, required 4 | 5 | 6 | class RAdam(Optimizer): 7 | 8 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0): 9 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) 10 | self.buffer = [[None, None, None] for ind in range(10)] 11 | super(RAdam, self).__init__(params, defaults) 12 | 13 | def __setstate__(self, state): 14 | super(RAdam, self).__setstate__(state) 15 | 16 | def step(self, closure=None): 17 | 18 | loss = None 19 | if closure is not None: 20 | loss = closure() 21 | 22 | for group in self.param_groups: 23 | 24 | for p in group['params']: 25 | if p.grad is None: 26 | continue 27 | grad = p.grad.data.float() 28 | if grad.is_sparse: 29 | raise RuntimeError('RAdam does not support sparse gradients') 30 | 31 | p_data_fp32 = p.data.float() 32 | 33 | state = self.state[p] 34 | 35 | if len(state) == 0: 36 | state['step'] = 0 37 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 38 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) 39 | else: 40 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 41 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) 42 | 43 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 44 | beta1, beta2 = group['betas'] 45 | 46 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 47 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 48 | 49 | state['step'] += 1 50 | buffered = self.buffer[int(state['step'] % 10)] 51 | if state['step'] == buffered[0]: 52 | N_sma, step_size = buffered[1], buffered[2] 53 | else: 54 | buffered[0] = state['step'] 55 | beta2_t = beta2 ** state['step'] 56 | N_sma_max = 2 / (1 - beta2) - 1 57 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) 58 | buffered[1] = N_sma 59 
| 60 | # more conservative since it's an approximated value 61 | if N_sma >= 5: 62 | step_size = group['lr'] * math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step']) 63 | else: 64 | step_size = group['lr'] / (1 - beta1 ** state['step']) 65 | buffered[2] = step_size 66 | 67 | if group['weight_decay'] != 0: 68 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) 69 | 70 | # more conservative since it's an approximated value 71 | if N_sma >= 5: 72 | denom = exp_avg_sq.sqrt().add_(group['eps']) 73 | p_data_fp32.addcdiv_(-step_size, exp_avg, denom) 74 | else: 75 | p_data_fp32.add_(-step_size, exp_avg) 76 | 77 | p.data.copy_(p_data_fp32) 78 | 79 | return loss 80 | 81 | class PlainRAdam(Optimizer): 82 | 83 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0): 84 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) 85 | 86 | super(RAdam, self).__init__(params, defaults) 87 | 88 | def __setstate__(self, state): 89 | super(RAdam, self).__setstate__(state) 90 | 91 | def step(self, closure=None): 92 | 93 | loss = None 94 | if closure is not None: 95 | loss = closure() 96 | 97 | for group in self.param_groups: 98 | 99 | for p in group['params']: 100 | if p.grad is None: 101 | continue 102 | grad = p.grad.data.float() 103 | if grad.is_sparse: 104 | raise RuntimeError('RAdam does not support sparse gradients') 105 | 106 | p_data_fp32 = p.data.float() 107 | 108 | state = self.state[p] 109 | 110 | if len(state) == 0: 111 | state['step'] = 0 112 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 113 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) 114 | else: 115 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 116 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) 117 | 118 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 119 | beta1, beta2 = group['betas'] 120 | 121 | 
exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 122 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 123 | 124 | state['step'] += 1 125 | beta2_t = beta2 ** state['step'] 126 | N_sma_max = 2 / (1 - beta2) - 1 127 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) 128 | 129 | if group['weight_decay'] != 0: 130 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) 131 | 132 | # more conservative since it's an approximated value 133 | if N_sma >= 5: 134 | step_size = group['lr'] * math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step']) 135 | denom = exp_avg_sq.sqrt().add_(group['eps']) 136 | p_data_fp32.addcdiv_(-step_size, exp_avg, denom) 137 | else: 138 | step_size = group['lr'] / (1 - beta1 ** state['step']) 139 | p_data_fp32.add_(-step_size, exp_avg) 140 | 141 | p.data.copy_(p_data_fp32) 142 | 143 | return loss 144 | 145 | 146 | class AdamW(Optimizer): 147 | 148 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, warmup = 0): 149 | defaults = dict(lr=lr, betas=betas, eps=eps, 150 | weight_decay=weight_decay, amsgrad=amsgrad, use_variance=True, warmup = warmup) 151 | super(AdamW, self).__init__(params, defaults) 152 | 153 | def __setstate__(self, state): 154 | super(AdamW, self).__setstate__(state) 155 | 156 | def step(self, closure=None): 157 | loss = None 158 | if closure is not None: 159 | loss = closure() 160 | 161 | for group in self.param_groups: 162 | 163 | for p in group['params']: 164 | if p.grad is None: 165 | continue 166 | grad = p.grad.data.float() 167 | if grad.is_sparse: 168 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 169 | 170 | p_data_fp32 = p.data.float() 171 | 172 | state = self.state[p] 173 | 174 | if len(state) == 0: 175 | state['step'] = 0 176 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 177 | state['exp_avg_sq'] = 
torch.zeros_like(p_data_fp32) 178 | else: 179 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 180 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) 181 | 182 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 183 | beta1, beta2 = group['betas'] 184 | 185 | state['step'] += 1 186 | 187 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 188 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 189 | 190 | denom = exp_avg_sq.sqrt().add_(group['eps']) 191 | bias_correction1 = 1 - beta1 ** state['step'] 192 | bias_correction2 = 1 - beta2 ** state['step'] 193 | 194 | if group['warmup'] > state['step']: 195 | scheduled_lr = 1e-8 + state['step'] * group['lr'] / group['warmup'] 196 | else: 197 | scheduled_lr = group['lr'] 198 | 199 | step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1 200 | 201 | if group['weight_decay'] != 0: 202 | p_data_fp32.add_(-group['weight_decay'] * scheduled_lr, p_data_fp32) 203 | 204 | p_data_fp32.addcdiv_(-step_size, exp_avg, denom) 205 | 206 | p.data.copy_(p_data_fp32) 207 | 208 | return loss -------------------------------------------------------------------------------- /utils/pc_viz.py: -------------------------------------------------------------------------------- 1 | import vtk 2 | import numpy as np 3 | import random 4 | import os 5 | 6 | print('Using', vtk.vtkVersion.GetVTKSourceVersion()) 7 | 8 | 9 | class MyInteractorStyle(vtk.vtkInteractorStyleTrackballCamera): 10 | def __init__(self, parent, pointcloud): 11 | self.parent = parent 12 | self.pointcloud = pointcloud 13 | self.AddObserver("KeyPressEvent", self.keyPressEvent) 14 | 15 | def keyPressEvent(self, obj, event): 16 | key = self.parent.GetKeySym() 17 | if key == '+': 18 | point_size = self.pointcloud.vtkActor.GetProperty().GetPointSize() 19 | self.pointcloud.vtkActor.GetProperty().SetPointSize(point_size + 1) 20 | print(str(point_size) + " " + key) 21 | return 22 | 23 | 24 | class VtkPointCloud: 25 | 26 | def __init__(self, 
point_size=18, maxNumPoints=1e8): 27 | self.maxNumPoints = maxNumPoints 28 | self.vtkPolyData = vtk.vtkPolyData() 29 | self.clear_points() 30 | 31 | self.colors = vtk.vtkUnsignedCharArray() 32 | self.colors.SetNumberOfComponents(3) 33 | self.colors.SetName("Colors") 34 | 35 | mapper = vtk.vtkPolyDataMapper() 36 | mapper.SetInputData(self.vtkPolyData) 37 | 38 | self.vtkActor = vtk.vtkActor() 39 | self.vtkActor.SetMapper(mapper) 40 | self.vtkActor.GetProperty().SetPointSize(point_size) 41 | 42 | def add_point(self, point, color): 43 | if self.vtkPoints.GetNumberOfPoints() < self.maxNumPoints: 44 | pointId = self.vtkPoints.InsertNextPoint(point[:]) 45 | self.colors.InsertNextTuple(color) 46 | self.vtkDepth.InsertNextValue(point[2]) 47 | self.vtkCells.InsertNextCell(1) 48 | self.vtkCells.InsertCellPoint(pointId) 49 | else: 50 | print("VIZ: Reached max number of points!") 51 | r = random.randint(0, self.maxNumPoints) 52 | self.vtkPoints.SetPoint(r, point[:]) 53 | self.vtkPolyData.GetPointData().SetScalars(self.colors) 54 | self.vtkCells.Modified() 55 | self.vtkPoints.Modified() 56 | self.vtkDepth.Modified() 57 | 58 | def clear_points(self): 59 | self.vtkPoints = vtk.vtkPoints() 60 | self.vtkCells = vtk.vtkCellArray() 61 | self.vtkDepth = vtk.vtkDoubleArray() 62 | self.vtkDepth.SetName('DepthArray') 63 | self.vtkPolyData.SetPoints(self.vtkPoints) 64 | self.vtkPolyData.SetVerts(self.vtkCells) 65 | self.vtkPolyData.GetPointData().SetScalars(self.vtkDepth) 66 | self.vtkPolyData.GetPointData().SetActiveScalars('DepthArray') 67 | 68 | 69 | def getActorCircle(radius_inner=100, radius_outer=99, color=(1, 0, 0)): 70 | """""" 71 | # create source 72 | source = vtk.vtkDiskSource() 73 | source.SetInnerRadius(radius_inner) 74 | source.SetOuterRadius(radius_outer) 75 | source.SetRadialResolution(100) 76 | source.SetCircumferentialResolution(100) 77 | 78 | # Transformer 79 | transform = vtk.vtkTransform() 80 | transform.RotateWXYZ(90, 1, 0, 0) 81 | transformFilter = 
vtk.vtkTransformPolyDataFilter() 82 | transformFilter.SetTransform(transform) 83 | transformFilter.SetInputConnection(source.GetOutputPort()) 84 | transformFilter.Update() 85 | 86 | # mapper 87 | mapper = vtk.vtkPolyDataMapper() 88 | mapper.SetInputConnection(transformFilter.GetOutputPort()) 89 | 90 | # actor 91 | actor = vtk.vtkActor() 92 | actor.GetProperty().SetColor(color) 93 | actor.SetMapper(mapper) 94 | 95 | return actor 96 | 97 | 98 | def show_pointclouds(points, colors, text=[], title="Default", png_path="", interactive=True, orientation='horizontal'): 99 | """ 100 | Show multiple point clouds specified as lists. First clouds at the bottom. 101 | :param points: list of pointclouds, item: numpy (N x 3) XYZ 102 | :param colors: list of corresponding colors, item: numpy (N x 3) RGB [0..255] 103 | :param title: window title 104 | :param text: text per point cloud 105 | :param png_path: where to save png image 106 | :param interactive: wether to display window or not, useful if you only want to take screenshot 107 | :return: nothing 108 | """ 109 | 110 | # make sure pointclouds is a list 111 | assert isinstance(points, type([])), \ 112 | "Pointclouds argument must be a list" 113 | 114 | # make sure colors is a list 115 | assert isinstance(colors, type([])), \ 116 | "Colors argument must be a list" 117 | 118 | # make sure number of pointclouds and colors are the same 119 | assert len(points) == len(colors), \ 120 | "Number of pointclouds (%d) is different then number of colors (%d)" % (len(points), len(colors)) 121 | 122 | while len(text) < len(points): 123 | text.append("") 124 | 125 | # Number of pointclouds to be displayed in this window 126 | num_pointclouds = len(points) 127 | 128 | point_size = 10 129 | pointclouds = [VtkPointCloud(point_size) for _ in range(num_pointclouds)] 130 | renderers = [vtk.vtkRenderer() for _ in range(num_pointclouds)] 131 | 132 | height = 1.0 / max(num_pointclouds, 1) 133 | viewports = [(i*height, (i+1)*height) for i in 
range(num_pointclouds)] 134 | #print(viewports) 135 | 136 | # iterate over all point clouds 137 | for i, pc in enumerate(points): 138 | pc = pc.squeeze() 139 | co = colors[i].squeeze() 140 | assert pc.shape[0] == co.shape[0], \ 141 | "expected same number of points (%d) then colors (%d), cloud index = %d" % (pc.shape[0], co.shape[0], i) 142 | assert pc.shape[1] == 3, "expected points to be N x 3, got N x %d" % pc.shape[1] 143 | assert co.shape[1] == 3, "expected colors to be N x 3, got N x %d" % co.shape[1] 144 | 145 | # for each point cloud iterate over all points 146 | for j in range(pc.shape[0]): 147 | point = pc[j, :] 148 | color = co[j, :] 149 | pointclouds[i].add_point(point, color) 150 | 151 | renderers[i].AddActor(pointclouds[i].vtkActor) 152 | # renderers[i].AddActor(vtk.vtkAxesActor()) 153 | renderers[i].SetBackground(1.0, 1.0, 1.0) 154 | if orientation == 'horizontal': 155 | print(viewports[i][0]) 156 | renderers[i].SetViewport(viewports[i][0], 0.0, viewports[i][1], 1.0) 157 | elif orientation == 'vertical': 158 | renderers[i].SetViewport(0.0, viewports[i][0], 1.0, viewports[i][1]) 159 | else: 160 | raise Exception('Not a valid orientation!') 161 | renderers[i].ResetCamera() 162 | 163 | # Add circle to first render 164 | renderers[0].AddActor(getActorCircle()) 165 | renderers[0].AddActor(getActorCircle(50, 49, color=(0, 1, 0))) 166 | 167 | # Text actors 168 | text_actors = [vtk.vtkTextActor() for _ in text] 169 | for i, ta in enumerate(text_actors): 170 | if orientation == 'horizontal': 171 | ta.SetInput(' ' + text[i]) 172 | elif orientation == 'vertical': 173 | ta.SetInput(text[i] + '\n\n\n\n\n\n') 174 | else: 175 | raise Exception('Not a valid orientation!') 176 | txtprop = ta.GetTextProperty() 177 | txtprop.SetFontFamilyToArial() 178 | txtprop.SetFontSize(0) 179 | txtprop.SetColor(0, 0, 0) 180 | # txtprop.SetJustificationToCentered() 181 | # ta.SetDisplayPosition(500, 10) 182 | # ta.SetAlignmentPoint() 183 | renderers[i].AddActor(ta) 184 | 185 | # 
Render Window 186 | render_window = vtk.vtkRenderWindow() 187 | for renderer in renderers: 188 | render_window.AddRenderer(renderer) 189 | 190 | render_window_interactor = vtk.vtkRenderWindowInteractor() 191 | render_window_interactor.SetInteractorStyle(vtk.vtkInteractorStyleTrackballCamera()) 192 | render_window_interactor.SetRenderWindow(render_window) 193 | 194 | [center_x, center_y, center_z] = np.mean(points[0].squeeze(), axis=0) 195 | camera = vtk.vtkCamera() 196 | # d = 10 197 | # camera.SetViewUp(0, -1, 0) 198 | 199 | # camera.SetPosition(center_x + d, center_y + d, center_z + d / 2) 200 | # camera.SetFocalPoint(center_x, center_y, center_z) 201 | # camera.SetFocalPoint(0, 0, 0) 202 | 203 | camera.SetViewUp(0, 0, 1) 204 | if orientation == 'horizontal': 205 | camera.SetPosition(3, -10, 2) 206 | camera.SetFocalPoint(3, 1.5, 1.5) 207 | elif orientation == 'vertical': 208 | camera.SetPosition(1.5, -6, 2) 209 | camera.SetFocalPoint(1.5, 1.5, 1.5) 210 | else: 211 | raise Exception('Not a valid orientation!') 212 | 213 | camera.SetClippingRange(0.002, 1000) 214 | for renderer in renderers: 215 | renderer.SetActiveCamera(camera) 216 | 217 | # Begin Interaction 218 | render_window.Render() 219 | render_window.SetWindowName(title) 220 | if orientation == 'horizontal': 221 | render_window.SetSize(1940, 720) 222 | elif orientation == 'vertical': 223 | render_window.SetSize(600, 1388) 224 | else: 225 | raise Exception('Not a valid orientation!') 226 | 227 | if interactive: 228 | render_window_interactor.Start() 229 | 230 | if png_path: 231 | # screenshot code: 232 | w2if = vtk.vtkWindowToImageFilter() 233 | w2if.SetInput(render_window) 234 | w2if.Update() 235 | 236 | writer = vtk.vtkPNGWriter() 237 | writer.SetFileName(png_path) 238 | writer.SetInputConnection(w2if.GetOutputPort()) 239 | writer.Write() 240 | 241 | 242 | def get_points_colors_from_obj(filename, limit=1): 243 | points = [] 244 | colors = [] 245 | with open(filename) as f: 246 | for line in f: 247 | parts 
= line.strip().split() 248 | points.append(np.array([float(parts[1]), float(parts[2]), float(parts[3])])) 249 | colors.append(np.array([float(parts[4]), float(parts[5]), float(parts[6])])) 250 | points = np.array(points) 251 | colors = np.array(colors) 252 | idx = points[:, 1] >= limit 253 | return points[idx, :], colors[idx, :] 254 | 255 | 256 | def visualize_part_seg(file_name_pred, file_name_gt, comparison_folder_list, limit=1, text=[], png_path="", 257 | interactive=True, orientation='horizontal'): 258 | # load base point cloud 259 | gt_points, gt_colors = get_points_colors_from_obj(os.path.join(comparison_folder_list[0], file_name_gt), limit) 260 | 261 | idx_gt = gt_points[:, 1] >= limit 262 | 263 | all_points = [gt_points[idx_gt, :3]] 264 | all_colors = [gt_colors[idx_gt, :3]] 265 | 266 | for folder in comparison_folder_list: 267 | pts, col = get_points_colors_from_obj(os.path.join(folder, file_name_pred), limit=limit) 268 | 269 | all_points.append(pts) 270 | all_colors.append(col) 271 | 272 | print(np.asarray(all_points).shape) 273 | show_pointclouds(all_points, all_colors, text=text, png_path=png_path, interactive=interactive, 274 | orientation=orientation) 275 | 276 | -------------------------------------------------------------------------------- /gcn_lib/sparse/torch_vertex.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | import torch_geometric as tg 5 | from .torch_nn import MLP, act_layer, norm_layer, BondEncoder 6 | from .torch_edge import DilatedKnnGraph 7 | from .torch_message import GenMessagePassing, MsgNorm 8 | from torch_geometric.utils import remove_self_loops, add_self_loops 9 | 10 | 11 | class GENConv(GenMessagePassing): 12 | """ 13 | GENeralized Graph Convolution (GENConv): https://arxiv.org/pdf/2006.07739.pdf 14 | SoftMax & PowerMean Aggregation 15 | """ 16 | def __init__(self, in_dim, emb_dim, 17 | aggr='softmax', 18 | 
t=1.0, learn_t=False, 19 | p=1.0, learn_p=False, 20 | y=0.0, learn_y=False, 21 | msg_norm=False, learn_msg_scale=True, 22 | encode_edge=False, bond_encoder=False, 23 | edge_feat_dim=None, 24 | norm='batch', mlp_layers=2, 25 | eps=1e-7): 26 | 27 | super(GENConv, self).__init__(aggr=aggr, 28 | t=t, learn_t=learn_t, 29 | p=p, learn_p=learn_p, 30 | y=y, learn_y=learn_y) 31 | 32 | channels_list = [in_dim] 33 | 34 | for i in range(mlp_layers-1): 35 | channels_list.append(in_dim*2) 36 | 37 | channels_list.append(emb_dim) 38 | 39 | self.mlp = MLP(channels=channels_list, 40 | norm=norm, 41 | last_lin=True) 42 | 43 | self.msg_encoder = torch.nn.ReLU() 44 | self.eps = eps 45 | 46 | self.msg_norm = msg_norm 47 | self.encode_edge = encode_edge 48 | self.bond_encoder = bond_encoder 49 | 50 | if msg_norm: 51 | self.msg_norm = MsgNorm(learn_msg_scale=learn_msg_scale) 52 | else: 53 | self.msg_norm = None 54 | 55 | if self.encode_edge: 56 | if self.bond_encoder: 57 | self.edge_encoder = BondEncoder(emb_dim=in_dim) 58 | else: 59 | self.edge_encoder = torch.nn.Linear(edge_feat_dim, in_dim) 60 | 61 | def forward(self, x, edge_index, edge_attr=None): 62 | if self.encode_edge and edge_attr is not None: 63 | edge_emb = self.edge_encoder(edge_attr) 64 | else: 65 | edge_emb = edge_attr 66 | 67 | m = self.propagate(edge_index, x=x, edge_attr=edge_emb) 68 | 69 | if self.msg_norm is not None: 70 | m = self.msg_norm(x, m) 71 | 72 | h = x + m 73 | out = self.mlp(h) 74 | 75 | return out 76 | 77 | def message(self, x_j, edge_attr=None): 78 | 79 | if edge_attr is not None: 80 | msg = x_j + edge_attr 81 | else: 82 | msg = x_j 83 | 84 | return self.msg_encoder(msg) + self.eps 85 | 86 | def update(self, aggr_out): 87 | return aggr_out 88 | 89 | 90 | class MRConv(nn.Module): 91 | """ 92 | Max-Relative Graph Convolution (Paper: https://arxiv.org/abs/1904.03751) 93 | """ 94 | def __init__(self, in_channels, out_channels, act='relu', norm=None, bias=True, aggr='max'): 95 | super(MRConv, self).__init__() 
96 | self.nn = MLP([in_channels*2, out_channels], act, norm, bias) 97 | self.aggr = aggr 98 | 99 | def forward(self, x, edge_index): 100 | """""" 101 | x_j = tg.utils.scatter_(self.aggr, torch.index_select(x, 0, edge_index[0]) - torch.index_select(x, 0, edge_index[1]), edge_index[1], dim_size=x.shape[0]) 102 | return self.nn(torch.cat([x, x_j], dim=1)) 103 | 104 | 105 | class EdgConv(tg.nn.EdgeConv): 106 | """ 107 | Edge convolution layer (with activation, batch normalization) 108 | """ 109 | def __init__(self, in_channels, out_channels, act='relu', norm=None, bias=True, aggr='max'): 110 | super(EdgConv, self).__init__(MLP([in_channels*2, out_channels], act, norm, bias), aggr) 111 | 112 | def forward(self, x, edge_index): 113 | return super(EdgConv, self).forward(x, edge_index) 114 | 115 | 116 | class GATConv(nn.Module): 117 | """ 118 | Graph Attention Convolution layer (with activation, batch normalization) 119 | """ 120 | def __init__(self, in_channels, out_channels, act='relu', norm=None, bias=True, heads=8): 121 | super(GATConv, self).__init__() 122 | self.gconv = tg.nn.GATConv(in_channels, out_channels, heads, bias=bias) 123 | m =[] 124 | if act: 125 | m.append(act_layer(act)) 126 | if norm: 127 | m.append(norm_layer(norm, out_channels)) 128 | self.unlinear = nn.Sequential(*m) 129 | 130 | def forward(self, x, edge_index): 131 | out = self.unlinear(self.gconv(x, edge_index)) 132 | return out 133 | 134 | 135 | class SAGEConv(tg.nn.SAGEConv): 136 | r"""The GraphSAGE operator from the `"Inductive Representation Learning on 137 | Large Graphs" `_ paper 138 | 139 | .. math:: 140 | \mathbf{\hat{x}}_i &= \mathbf{\Theta} \cdot 141 | \mathrm{mean}_{j \in \mathcal{N(i) \cup \{ i \}}}(\mathbf{x}_j) 142 | 143 | \mathbf{x}^{\prime}_i &= \frac{\mathbf{\hat{x}}_i} 144 | {\| \mathbf{\hat{x}}_i \|_2}. 145 | 146 | Args: 147 | in_channels (int): Size of each input sample. 148 | out_channels (int): Size of each output sample. 
149 | normalize (bool, optional): If set to :obj:`False`, output features 150 | will not be :math:`\ell_2`-normalized. (default: :obj:`True`) 151 | bias (bool, optional): If set to :obj:`False`, the layer will not learn 152 | an additive bias. (default: :obj:`True`) 153 | **kwargs (optional): Additional arguments of 154 | :class:`torch_geometric.nn.conv.MessagePassing`. 155 | """ 156 | 157 | def __init__(self, 158 | in_channels, 159 | out_channels, 160 | nn, 161 | norm=True, 162 | bias=True, 163 | relative=False, 164 | **kwargs): 165 | self.relative = relative 166 | if norm is not None: 167 | super(SAGEConv, self).__init__(in_channels, out_channels, True, bias, **kwargs) 168 | else: 169 | super(SAGEConv, self).__init__(in_channels, out_channels, False, bias, **kwargs) 170 | self.nn = nn 171 | 172 | def forward(self, x, edge_index, size=None): 173 | """""" 174 | if size is None: 175 | edge_index, _ = remove_self_loops(edge_index) 176 | edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0)) 177 | 178 | x = x.unsqueeze(-1) if x.dim() == 1 else x 179 | return self.propagate(edge_index, size=size, x=x) 180 | 181 | def message(self, x_i, x_j): 182 | if self.relative: 183 | x = torch.matmul(x_j - x_i, self.weight) 184 | else: 185 | x = torch.matmul(x_j, self.weight) 186 | return x 187 | 188 | def update(self, aggr_out, x): 189 | out = self.nn(torch.cat((x, aggr_out), dim=1)) 190 | if self.bias is not None: 191 | out = out + self.bias 192 | if self.normalize: 193 | out = F.normalize(out, p=2, dim=-1) 194 | return out 195 | 196 | 197 | class RSAGEConv(SAGEConv): 198 | """ 199 | Residual SAGE convolution layer (with activation, batch normalization) 200 | """ 201 | 202 | def __init__(self, in_channels, out_channels, act='relu', norm=None, bias=True, relative=False): 203 | nn = MLP([out_channels + in_channels, out_channels], act, norm, bias) 204 | super(RSAGEConv, self).__init__(in_channels, out_channels, nn, norm, bias, relative) 205 | 206 | 207 | class 
SemiGCNConv(nn.Module): 208 | """ 209 | SemiGCN convolution layer (with activation, batch normalization) 210 | """ 211 | 212 | def __init__(self, in_channels, out_channels, act='relu', norm=None, bias=True): 213 | super(SemiGCNConv, self).__init__() 214 | self.gconv = tg.nn.GCNConv(in_channels, out_channels, bias=bias) 215 | m = [] 216 | if act: 217 | m.append(act_layer(act)) 218 | if norm: 219 | m.append(norm_layer(norm, out_channels)) 220 | self.unlinear = nn.Sequential(*m) 221 | 222 | def forward(self, x, edge_index): 223 | out = self.unlinear(self.gconv(x, edge_index)) 224 | return out 225 | 226 | 227 | class GinConv(tg.nn.GINConv): 228 | """ 229 | GINConv layer (with activation, batch normalization) 230 | """ 231 | def __init__(self, in_channels, out_channels, act='relu', norm=None, bias=True, aggr='add'): 232 | super(GinConv, self).__init__(MLP([in_channels, out_channels], act, norm, bias)) 233 | 234 | def forward(self, x, edge_index): 235 | return super(GinConv, self).forward(x, edge_index) 236 | 237 | 238 | class GraphConv(nn.Module): 239 | """ 240 | Static graph convolution layer 241 | """ 242 | def __init__(self, in_channels, out_channels, conv='edge', 243 | act='relu', norm=None, bias=True, heads=8): 244 | super(GraphConv, self).__init__() 245 | if conv.lower() == 'edge': 246 | self.gconv = EdgConv(in_channels, out_channels, act, norm, bias) 247 | elif conv.lower() == 'mr': 248 | self.gconv = MRConv(in_channels, out_channels, act, norm, bias) 249 | elif conv.lower() == 'gat': 250 | self.gconv = GATConv(in_channels, out_channels//heads, act, norm, bias, heads) 251 | elif conv.lower() == 'gcn': 252 | self.gconv = SemiGCNConv(in_channels, out_channels, act, norm, bias) 253 | elif conv.lower() == 'gin': 254 | self.gconv = GinConv(in_channels, out_channels, act, norm, bias) 255 | elif conv.lower() == 'sage': 256 | self.gconv = RSAGEConv(in_channels, out_channels, act, norm, bias, False) 257 | elif conv.lower() == 'rsage': 258 | self.gconv = 
RSAGEConv(in_channels, out_channels, act, norm, bias, True) 259 | else: 260 | raise NotImplementedError('conv {} is not implemented'.format(conv)) 261 | 262 | def forward(self, x, edge_index): 263 | return self.gconv(x, edge_index) 264 | 265 | 266 | class DynConv(GraphConv): 267 | """ 268 | Dynamic graph convolution layer 269 | """ 270 | def __init__(self, in_channels, out_channels, kernel_size=9, dilation=1, conv='edge', act='relu', 271 | norm=None, bias=True, heads=8, **kwargs): 272 | super(DynConv, self).__init__(in_channels, out_channels, conv, act, norm, bias, heads) 273 | self.k = kernel_size 274 | self.d = dilation 275 | self.dilated_knn_graph = DilatedKnnGraph(kernel_size, dilation, **kwargs) 276 | 277 | def forward(self, x, batch=None): 278 | edge_index = self.dilated_knn_graph(x, batch) 279 | return super(DynConv, self).forward(x, edge_index) 280 | 281 | 282 | class PlainDynBlock(nn.Module): 283 | """ 284 | Plain Dynamic graph convolution block 285 | """ 286 | def __init__(self, channels, kernel_size=9, dilation=1, conv='edge', act='relu', norm=None, 287 | bias=True, res_scale=1, **kwargs): 288 | super(PlainDynBlock, self).__init__() 289 | self.body = DynConv(channels, channels, kernel_size, dilation, conv, 290 | act, norm, bias, **kwargs) 291 | self.res_scale = res_scale 292 | 293 | def forward(self, x, batch=None): 294 | return self.body(x, batch), batch 295 | 296 | 297 | class ResDynBlock(nn.Module): 298 | """ 299 | Residual Dynamic graph convolution block 300 | """ 301 | def __init__(self, channels, kernel_size=9, dilation=1, conv='edge', act='relu', norm=None, 302 | bias=True, res_scale=1, **kwargs): 303 | super(ResDynBlock, self).__init__() 304 | self.body = DynConv(channels, channels, kernel_size, dilation, conv, 305 | act, norm, bias, **kwargs) 306 | self.res_scale = res_scale 307 | 308 | def forward(self, x, batch=None): 309 | return self.body(x, batch) + x*self.res_scale, batch 310 | 311 | 312 | class DenseDynBlock(nn.Module): 313 | """ 314 | 
Dense Dynamic graph convolution block 315 | """ 316 | def __init__(self, in_channels, out_channels=64, kernel_size=9, dilation=1, conv='edge', act='relu', norm=None, bias=True, **kwargs): 317 | super(DenseDynBlock, self).__init__() 318 | self.body = DynConv(in_channels, out_channels, kernel_size, dilation, conv, 319 | act, norm, bias, **kwargs) 320 | 321 | def forward(self, x, batch=None): 322 | dense = self.body(x, batch) 323 | return torch.cat((x, dense), 1), batch 324 | 325 | 326 | class ResGraphBlock(nn.Module): 327 | """ 328 | Residual Static graph convolution block 329 | """ 330 | def __init__(self, channels, conv='edge', act='relu', norm=None, bias=True, heads=8, res_scale=1): 331 | super(ResGraphBlock, self).__init__() 332 | self.body = GraphConv(channels, channels, conv, act, norm, bias, heads) 333 | self.res_scale = res_scale 334 | 335 | def forward(self, x, edge_index): 336 | return self.body(x, edge_index) + x*self.res_scale, edge_index 337 | 338 | 339 | class DenseGraphBlock(nn.Module): 340 | """ 341 | Dense Static graph convolution block 342 | """ 343 | def __init__(self, in_channels, out_channels, conv='edge', act='relu', norm=None, bias=True, heads=8): 344 | super(DenseGraphBlock, self).__init__() 345 | self.body = GraphConv(in_channels, out_channels, conv, act, norm, bias, heads) 346 | 347 | def forward(self, x, edge_index): 348 | dense = self.body(x, edge_index) 349 | return torch.cat((x, dense), 1), edge_index 350 | 351 | -------------------------------------------------------------------------------- /utils/data_util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import h5py 3 | import os 4 | import os.path as osp 5 | import shutil 6 | from glob import glob 7 | import torch 8 | from torch_scatter import scatter 9 | from torch_geometric.data import InMemoryDataset, Data, extract_zip 10 | from tqdm import tqdm 11 | import torch_geometric as tg 12 | 13 | 14 | def intersection(lst1, 
lst2): 15 | return list(set(lst1) & set(lst2)) 16 | 17 | 18 | def process_indexes(idx_list): 19 | idx_dict = {} 20 | for i, idx in enumerate(idx_list): 21 | idx_dict[idx] = i 22 | 23 | return [idx_dict[i] for i in sorted(idx_dict.keys())] 24 | 25 | 26 | def add_zeros(data): 27 | data.x = torch.zeros(data.num_nodes, dtype=torch.long) 28 | return data 29 | 30 | 31 | def extract_node_feature(data, reduce='add'): 32 | if reduce in ['mean', 'max', 'add']: 33 | data.x = scatter(data.edge_attr, 34 | data.edge_index[0], 35 | dim=0, 36 | dim_size=data.num_nodes, 37 | reduce=reduce) 38 | else: 39 | raise Exception('Unknown Aggregation Type') 40 | return data 41 | 42 | # random partition graph 43 | def random_partition_graph(num_nodes, cluster_number=10): 44 | parts = np.random.randint(cluster_number, size=num_nodes) 45 | return parts 46 | 47 | 48 | def generate_sub_graphs(adj, parts, cluster_number=10, batch_size=1): 49 | # convert sparse tensor to scipy csr 50 | adj = adj.to_scipy(layout='csr') 51 | 52 | num_batches = cluster_number // batch_size 53 | 54 | sg_nodes = [[] for _ in range(num_batches)] 55 | sg_edges = [[] for _ in range(num_batches)] 56 | 57 | for cluster in range(num_batches): 58 | sg_nodes[cluster] = np.where(parts == cluster)[0] 59 | sg_edges[cluster] = tg.utils.from_scipy_sparse_matrix(adj[sg_nodes[cluster], :][:, sg_nodes[cluster]])[0] 60 | 61 | return sg_nodes, sg_edges 62 | 63 | def random_rotate(points): 64 | theta = np.random.uniform(0, np.pi * 2) 65 | rotation_matrix = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]) 66 | rotation_matrix = torch.from_numpy(rotation_matrix).float() 67 | points[:, 0:2] = torch.matmul(points[:, [0, 1]].transpose(1, 3), rotation_matrix).transpose(1, 3) 68 | return points 69 | 70 | 71 | def random_translate(points, mean=0, std=0.02): 72 | points += torch.randn(points.shape)*std + mean 73 | return points 74 | 75 | 76 | def random_points_augmentation(points, rotate=False, translate=False, 
**kwargs): 77 | if rotate: 78 | points = random_rotate(points) 79 | if translate: 80 | points = random_translate(points, **kwargs) 81 | 82 | return points 83 | 84 | 85 | def scale_translate_pointcloud(pointcloud, shift=[-0.2, 0.2], scale=[2. / 3., 3. /2.]): 86 | """ 87 | for scaling and shifting the point cloud 88 | :param pointcloud: 89 | :return: 90 | """ 91 | B, C, N = pointcloud.shape[0:3] 92 | scale = scale[0] + torch.rand([B, C, 1, 1])*(scale[1]-scale[0]) 93 | shift = shift[0] + torch.rand([B, C, 1, 1]) * (shift[1]-shift[0]) 94 | translated_pointcloud = torch.mul(pointcloud, scale) + shift 95 | return translated_pointcloud 96 | 97 | 98 | class PartNet(InMemoryDataset): 99 | r"""The PartNet dataset from 100 | the `"PartNet: A Large-scale Benchmark for Fine-grained and Hierarchical Part-level 3D Object Understanding" 101 | `_ 102 | paper, containing 3D objects annotated with fine-grained, instance-level, and hierarchical 3D part information. 103 | 104 | Args: 105 | root (string): Root directory where the dataset should be saved. 106 | dataset (str, optional): Which dataset to use (ins_seg_h5, or sem_seg_h5). 107 | (default: :obj:`sem_seg_h5`) 108 | obj_category (str, optional): which category to load. 109 | (default: :obj:`Bed`) 110 | level (str, optional): Which level of part semantic segmentation to use. 111 | (default: :obj:`3`) 112 | phase (str, optional): If :obj:`test`, loads the testing dataset, 113 | If :obj:`val`, loads the validation dataset, 114 | otherwise the training dataset. (default: :obj:`train`) 115 | transform (callable, optional): A function/transform that takes in an 116 | :obj:`torch_geometric.data.Data` object and returns a transformed 117 | version. The data object will be transformed before every access. 118 | (default: :obj:`None`) 119 | pre_transform (callable, optional): A function/transform that takes in 120 | an :obj:`torch_geometric.data.Data` object and returns a 121 | transformed version. 
The data object will be transformed before 122 | being saved to disk. (default: :obj:`None`) 123 | pre_filter (callable, optional): A function that takes in an 124 | :obj:`torch_geometric.data.Data` object and returns a boolean 125 | value, indicating whether the data object should be included in the 126 | final dataset. (default: :obj:`None`) 127 | """ 128 | # the dataset we use for our paper is pre-released version 129 | def __init__(self, 130 | root, 131 | dataset='sem_seg_h5', 132 | obj_category='Bed', 133 | level=3, 134 | phase='train', 135 | transform=None, 136 | pre_transform=None, 137 | pre_filter=None): 138 | self.dataset = dataset 139 | self.level = level 140 | self.obj_category = obj_category 141 | self.object = '-'.join([self.obj_category, str(self.level)]) 142 | self.level_folder = 'level_'+str(self.level) 143 | self.processed_file_folder = osp.join(self.dataset, self.level_folder, self.object) 144 | super(PartNet, self).__init__(root, transform, pre_transform, pre_filter) 145 | if phase == 'test': 146 | path = self.processed_paths[1] 147 | elif phase == 'val': 148 | path = self.processed_paths[2] 149 | else: 150 | path = self.processed_paths[0] 151 | self.data, self.slices = torch.load(path) 152 | 153 | @property 154 | def raw_file_names(self): 155 | return [self.dataset] 156 | 157 | @property 158 | def processed_file_names(self): 159 | return osp.join(self.processed_file_folder, 'train.pt'), osp.join(self.processed_file_folder, 'test.pt'), \ 160 | osp.join(self.processed_file_folder, 'val.pt') 161 | 162 | def download(self): 163 | path = osp.join(self.raw_dir, self.dataset) 164 | if not osp.exists(path): 165 | raise FileExistsError('PartNet can only downloaded via application. 
' 166 | 'See details in https://cs.stanford.edu/~kaichun/partnet/') 167 | # path = download_url(self.url, self.root) 168 | extract_zip(path, self.root) 169 | os.unlink(path) 170 | shutil.rmtree(self.raw_dir) 171 | name = self.url.split(os.sep)[-1].split('.')[0] 172 | os.rename(osp.join(self.root, name), self.raw_dir) 173 | 174 | def process(self): 175 | # save to processed_paths 176 | processed_path = osp.join(self.processed_dir, self.processed_file_folder) 177 | if not osp.exists(processed_path): 178 | os.makedirs(osp.join(processed_path)) 179 | torch.save(self.process_set('train'), self.processed_paths[0]) 180 | torch.save(self.process_set('test'), self.processed_paths[1]) 181 | torch.save(self.process_set('val'), self.processed_paths[2]) 182 | 183 | def process_set(self, dataset): 184 | if self.dataset == 'ins_seg_h5': 185 | raw_path = osp.join(self.raw_dir, 'ins_seg_h5_for_sgpn', self.dataset) 186 | categories = glob(osp.join(raw_path, '*')) 187 | categories = sorted([x.split(os.sep)[-1] for x in categories]) 188 | 189 | data_list = [] 190 | for target, category in enumerate(tqdm(categories)): 191 | folder = osp.join(raw_path, category) 192 | paths = glob('{}/{}-*.h5'.format(folder, dataset)) 193 | labels, nors, opacitys, pts, rgbs = [], [], [], [], [] 194 | for path in paths: 195 | f = h5py.File(path) 196 | pts += torch.from_numpy(f['pts'][:]).unbind(0) 197 | labels += torch.from_numpy(f['label'][:]).to(torch.long).unbind(0) 198 | nors += torch.from_numpy(f['nor'][:]).unbind(0) 199 | opacitys += torch.from_numpy(f['opacity'][:]).unbind(0) 200 | rgbs += torch.from_numpy(f['rgb'][:]).to(torch.float32).unbind(0) 201 | 202 | for i, (pt, label, nor, opacity, rgb) in enumerate(zip(pts, labels, nors, opacitys, rgbs)): 203 | data = Data(pos=pt[:, :3], y=label, norm=nor[:, :3], x=torch.cat((opacity.unsqueeze(-1), rgb/255.), 1)) 204 | 205 | if self.pre_filter is not None and not self.pre_filter(data): 206 | continue 207 | if self.pre_transform is not None: 208 | data = 
self.pre_transform(data) 209 | data_list.append(data) 210 | else: 211 | raw_path = osp.join(self.raw_dir, self.dataset) 212 | categories = glob(osp.join(raw_path, self.object)) 213 | categories = sorted([x.split(os.sep)[-1] for x in categories]) 214 | data_list = [] 215 | # class_name = [] 216 | for target, category in enumerate(tqdm(categories)): 217 | folder = osp.join(raw_path, category) 218 | paths = glob('{}/{}-*.h5'.format(folder, dataset)) 219 | labels, pts = [], [] 220 | # clss = category.split('-')[0] 221 | 222 | for path in paths: 223 | f = h5py.File(path) 224 | pts += torch.from_numpy(f['data'][:].astype(np.float32)).unbind(0) 225 | labels += torch.from_numpy(f['label_seg'][:].astype(np.float32)).to(torch.long).unbind(0) 226 | for i, (pt, label) in enumerate(zip(pts, labels)): 227 | data = Data(pos=pt[:, :3], y=label) 228 | # data = PartData(pos=pt[:, :3], y=label, clss=clss) 229 | if self.pre_filter is not None and not self.pre_filter(data): 230 | continue 231 | if self.pre_transform is not None: 232 | data = self.pre_transform(data) 233 | data_list.append(data) 234 | return self.collate(data_list) 235 | 236 | 237 | class PartData(Data): 238 | def __init__(self, 239 | y=None, 240 | pos=None, 241 | clss=None): 242 | super(PartData).__init__(pos=pos, y=y) 243 | self.clss = clss 244 | 245 | 246 | # allowable multiple choice node and edge features 247 | # code from https://github.com/snap-stanford/ogb/blob/master/ogb/utils/features.py 248 | allowable_features = { 249 | 'possible_atomic_num_list' : list(range(1, 119)) + ['misc'], 250 | 'possible_chirality_list' : [ 251 | 'CHI_UNSPECIFIED', 252 | 'CHI_TETRAHEDRAL_CW', 253 | 'CHI_TETRAHEDRAL_CCW', 254 | 'CHI_OTHER' 255 | ], 256 | 'possible_degree_list' : [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 'misc'], 257 | 'possible_formal_charge_list' : [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 'misc'], 258 | 'possible_numH_list' : [0, 1, 2, 3, 4, 5, 6, 7, 8, 'misc'], 259 | 'possible_number_radical_e_list': [0, 1, 2, 3, 4, 'misc'], 
260 | 'possible_hybridization_list' : [ 261 | 'SP', 'SP2', 'SP3', 'SP3D', 'SP3D2', 'misc' 262 | ], 263 | 'possible_is_aromatic_list': [False, True], 264 | 'possible_is_in_ring_list': [False, True], 265 | 'possible_bond_type_list' : [ 266 | 'SINGLE', 267 | 'DOUBLE', 268 | 'TRIPLE', 269 | 'AROMATIC', 270 | 'misc' 271 | ], 272 | 'possible_bond_stereo_list': [ 273 | 'STEREONONE', 274 | 'STEREOZ', 275 | 'STEREOE', 276 | 'STEREOCIS', 277 | 'STEREOTRANS', 278 | 'STEREOANY', 279 | ], 280 | 'possible_is_conjugated_list': [False, True], 281 | } 282 | 283 | 284 | def safe_index(l, e): 285 | """ 286 | Return index of element e in list l. If e is not present, return the last index 287 | """ 288 | try: 289 | return l.index(e) 290 | except: 291 | return len(l) - 1 292 | 293 | 294 | def atom_to_feature_vector(atom): 295 | """ 296 | Converts rdkit atom object to feature list of indices 297 | :param mol: rdkit atom object 298 | :return: list 299 | """ 300 | atom_feature = [ 301 | safe_index(allowable_features['possible_atomic_num_list'], atom.GetAtomicNum()), 302 | allowable_features['possible_chirality_list'].index(str(atom.GetChiralTag())), 303 | safe_index(allowable_features['possible_degree_list'], atom.GetTotalDegree()), 304 | safe_index(allowable_features['possible_formal_charge_list'], atom.GetFormalCharge()), 305 | safe_index(allowable_features['possible_numH_list'], atom.GetTotalNumHs()), 306 | safe_index(allowable_features['possible_number_radical_e_list'], atom.GetNumRadicalElectrons()), 307 | safe_index(allowable_features['possible_hybridization_list'], str(atom.GetHybridization())), 308 | allowable_features['possible_is_aromatic_list'].index(atom.GetIsAromatic()), 309 | allowable_features['possible_is_in_ring_list'].index(atom.IsInRing()), 310 | ] 311 | return atom_feature 312 | 313 | 314 | def get_atom_feature_dims(): 315 | return list(map(len, [ 316 | allowable_features['possible_atomic_num_list'], 317 | allowable_features['possible_chirality_list'], 318 | 
allowable_features['possible_degree_list'], 319 | allowable_features['possible_formal_charge_list'], 320 | allowable_features['possible_numH_list'], 321 | allowable_features['possible_number_radical_e_list'], 322 | allowable_features['possible_hybridization_list'], 323 | allowable_features['possible_is_aromatic_list'], 324 | allowable_features['possible_is_in_ring_list'] 325 | ])) 326 | 327 | 328 | def bond_to_feature_vector(bond): 329 | """ 330 | Converts rdkit bond object to feature list of indices 331 | :param mol: rdkit bond object 332 | :return: list 333 | """ 334 | bond_feature = [ 335 | safe_index(allowable_features['possible_bond_type_list'], str(bond.GetBondType())), 336 | allowable_features['possible_bond_stereo_list'].index(str(bond.GetStereo())), 337 | allowable_features['possible_is_conjugated_list'].index(bond.GetIsConjugated()), 338 | ] 339 | return bond_feature 340 | 341 | 342 | def get_bond_feature_dims(): 343 | return list(map(len, [ 344 | allowable_features['possible_bond_type_list'], 345 | allowable_features['possible_bond_stereo_list'], 346 | allowable_features['possible_is_conjugated_list'] 347 | ])) 348 | 349 | 350 | def atom_feature_vector_to_dict(atom_feature): 351 | [atomic_num_idx, 352 | chirality_idx, 353 | degree_idx, 354 | formal_charge_idx, 355 | num_h_idx, 356 | number_radical_e_idx, 357 | hybridization_idx, 358 | is_aromatic_idx, 359 | is_in_ring_idx] = atom_feature 360 | 361 | feature_dict = { 362 | 'atomic_num': allowable_features['possible_atomic_num_list'][atomic_num_idx], 363 | 'chirality': allowable_features['possible_chirality_list'][chirality_idx], 364 | 'degree': allowable_features['possible_degree_list'][degree_idx], 365 | 'formal_charge': allowable_features['possible_formal_charge_list'][formal_charge_idx], 366 | 'num_h': allowable_features['possible_numH_list'][num_h_idx], 367 | 'num_rad_e': allowable_features['possible_number_radical_e_list'][number_radical_e_idx], 368 | 'hybridization': 
allowable_features['possible_hybridization_list'][hybridization_idx], 369 | 'is_aromatic': allowable_features['possible_is_aromatic_list'][is_aromatic_idx], 370 | 'is_in_ring': allowable_features['possible_is_in_ring_list'][is_in_ring_idx] 371 | } 372 | 373 | return feature_dict 374 | 375 | 376 | def bond_feature_vector_to_dict(bond_feature): 377 | [bond_type_idx, 378 | bond_stereo_idx, 379 | is_conjugated_idx] = bond_feature 380 | 381 | feature_dict = { 382 | 'bond_type': allowable_features['possible_bond_type_list'][bond_type_idx], 383 | 'bond_stereo': allowable_features['possible_bond_stereo_list'][bond_stereo_idx], 384 | 'is_conjugated': allowable_features['possible_is_conjugated_list'][is_conjugated_idx] 385 | } 386 | 387 | return feature_dict 388 | --------------------------------------------------------------------------------